Skip to content

Commit

Permalink
io API refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
antonymayi committed Sep 26, 2021
1 parent 7d1eff6 commit 7731b39
Show file tree
Hide file tree
Showing 22 changed files with 91 additions and 92 deletions.
5 changes: 5 additions & 0 deletions forml/io/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,8 @@
"""
ETL layer.
"""

from ._input import Feed, Importer
from ._output import Exporter, Sink

__all__ = ['Feed', 'Sink', 'Importer', 'Exporter']
29 changes: 16 additions & 13 deletions forml/io/feed/__init__.py → forml/io/_input/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,16 @@
from forml.conf.parsed import provider as provcfg
from forml.io import dsl, layout
from forml.io.dsl import parser
from forml.io.feed import extract
from forml.project import component

from . import extract

if typing.TYPE_CHECKING:
from forml.project import component

LOGGER = logging.getLogger(__name__)


class Provider(
class Feed(
provmod.Interface,
typing.Generic[parser.Source, parser.Feature],
path=provcfg.Feed.path, # pylint: disable=no-member
Expand Down Expand Up @@ -189,7 +192,7 @@ def features(self) -> typing.Mapping['dsl.Feature', parser.Feature]:
return {}


class Pool:
class Importer:
"""Pool of (possibly) lazily instantiated feeds. If configured without any explicit feeds, all of the feeds
registered in the provider cache are added.
Expand All @@ -204,14 +207,14 @@ class Pool:
class Slot:
"""Representation of a single feed provided either explicitly s an instance or lazily as a descriptor."""

def __init__(self, feed: typing.Union[provcfg.Feed, str, Provider]):
def __init__(self, feed: typing.Union[provcfg.Feed, str, Feed]):
if isinstance(feed, str):
feed = provcfg.Feed.resolve(feed)
descriptor, instance = (feed, None) if isinstance(feed, provcfg.Feed) else (None, feed)
self._descriptor: typing.Optional[provcfg.Feed] = descriptor
self._instance: typing.Optional[Provider] = instance
self._instance: typing.Optional[Feed] = instance

def __lt__(self, other: 'Pool.Slot'):
def __lt__(self, other: 'Importer.Slot'):
return self.priority < other.priority

@property
Expand All @@ -224,15 +227,15 @@ def priority(self) -> float:
return self._descriptor.priority if self._descriptor else float('inf')

@property
def instance(self) -> Provider:
def instance(self) -> Feed:
"""Return the feed instance possibly creating it on the fly if lazy.
Returns:
Feed instance.
"""
if self._instance is None:
LOGGER.debug('Instantiating feed %s', self._descriptor.reference)
self._instance = Provider[self._descriptor.reference](**self._descriptor.params)
self._instance = Feed[self._descriptor.reference](**self._descriptor.params)
return self._instance

class Matcher(dsl.Source.Visitor):
Expand Down Expand Up @@ -268,14 +271,14 @@ def visit_table(self, source: 'dsl.Table') -> None:
if source not in self._sources:
self._matches = False

def __init__(self, *feeds: typing.Union[provcfg.Feed, str, Provider]):
self._feeds: tuple[Pool.Slot] = tuple(sorted((self.Slot(f) for f in feeds or Provider), reverse=True))
def __init__(self, *feeds: typing.Union[provcfg.Feed, str, Feed]):
self._feeds: tuple[Importer.Slot] = tuple(sorted((self.Slot(f) for f in feeds or Feed), reverse=True))

def __iter__(self) -> typing.Iterable[Provider]:
def __iter__(self) -> typing.Iterable[Feed]:
for feed in self._feeds:
yield feed.instance

def match(self, source: 'dsl.Source') -> Provider:
def match(self, source: 'dsl.Source') -> Feed:
"""Select a feed that can provide for (be used to construct) the given source.
Args:
Expand Down
File renamed without changes.
19 changes: 10 additions & 9 deletions forml/io/sink/__init__.py → forml/io/_output/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,11 @@
from forml import provider as provmod
from forml.conf.parsed import provider as provcfg
from forml.io import layout
from forml.io.sink import publish

from . import publish

class Provider(provmod.Interface, default=provcfg.Sink.default, path=provcfg.Sink.path):

class Sink(provmod.Interface, default=provcfg.Sink.default, path=provcfg.Sink.path):
"""Sink is an implementation of a specific data consumer."""

class Writer(publish.Writer, metaclass=abc.ABCMeta):
Expand Down Expand Up @@ -59,22 +60,22 @@ def writer(cls, **kwargs: typing.Any) -> typing.Callable[[layout.ColumnMajor], N
return cls.Writer(**kwargs) # pylint: disable=abstract-class-instantiated


class Handle:
"""Sink handle is a lazy wrapper around alternative sink specifiers providing a particular Sink instance upon
class Exporter:
"""Sink exporter is a lazy wrapper around alternative sink specifiers providing a particular Sink instance upon
request.
"""

def __init__(self, sink: typing.Union[provcfg.Sink.Mode, str, Provider]):
def __init__(self, sink: typing.Union[provcfg.Sink.Mode, str, Sink]):
if isinstance(sink, str):
sink = provcfg.Sink.Mode.resolve(sink)
self._sink: typing.Union[provcfg.Sink.Mode, Provider] = sink
self._sink: typing.Union[provcfg.Sink.Mode, Sink] = sink

def __call__(self, getter: property) -> 'Provider':
if isinstance(self._sink, Provider): # already a Sink instance
def __call__(self, getter: property) -> 'Sink':
if isinstance(self._sink, Sink): # already a Sink instance
return self._sink
assert isinstance(self._sink, provcfg.Sink.Mode)
descriptor: provcfg.Sink = getter.fget(self._sink)
return Provider[descriptor.reference](**descriptor.params)
return Sink[descriptor.reference](**descriptor.params)

# pylint: disable=no-member
train = property(lambda self: self(provcfg.Sink.Mode.train))
Expand Down
File renamed without changes.
4 changes: 2 additions & 2 deletions forml/lib/feed/reader/sql/alchemy.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@
from sqlalchemy.dialects.sqlite import base as sqlite
from sqlalchemy.engine import interfaces

from forml import io
from forml.io import dsl, layout
from forml.io.dsl import error, function
from forml.io.dsl import parser as parsmod
from forml.io.feed import extract

LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -334,7 +334,7 @@ def generate_reference(
return ref, ref


class Reader(extract.Reader[sql.Selectable, sql.ColumnElement, pandas.DataFrame]):
class Reader(io.Feed.Reader[sql.Selectable, sql.ColumnElement, pandas.DataFrame]):
"""SQLAlchemy based reader."""

def __init__(
Expand Down
4 changes: 2 additions & 2 deletions forml/lib/feed/reader/sql/dbapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@
import re
import typing

from forml import io
from forml.io import dsl, layout
from forml.io.dsl import error, function
from forml.io.dsl import parser as parsmod
from forml.io.feed import extract

LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -332,7 +332,7 @@ def generate_reference(self, instance: str, name: str) -> tuple[str, str]: # py
return f'{self.Wrap.word(instance)} AS "{name}"', name


class Reader(extract.Reader[str, str, layout.RowMajor], metaclass=abc.ABCMeta):
class Reader(io.Feed.Reader[str, str, layout.RowMajor], metaclass=abc.ABCMeta):
"""SQL reader base class for PEP249 compliant DB APIs."""

@classmethod
Expand Down
5 changes: 3 additions & 2 deletions forml/lib/feed/static.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,12 @@
import types
import typing

from forml.io import dsl, feed, layout
from forml import io
from forml.io import dsl, layout
from forml.io.dsl import error


class Feed(feed.Provider[None, layout.Vector]):
class Feed(io.Feed[None, layout.Vector]):
"""Static feed is initialized with actual data which can only be returned in primitive column-wise fashion. No
advanced ETL can be applied.
"""
Expand Down
8 changes: 3 additions & 5 deletions forml/lib/runner/dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,7 @@

import dask

from forml import runtime
from forml.io import feed as feedmod
from forml.io import sink as sinkmod
from forml import io, runtime
from forml.runtime import code
from forml.runtime.asset import access

Expand Down Expand Up @@ -91,8 +89,8 @@ def __repr__(self):
def __init__(
self,
assets: typing.Optional[access.Assets] = None,
feed: typing.Optional[feedmod.Provider] = None,
sink: typing.Optional[sinkmod.Provider] = None,
feed: typing.Optional[io.Feed] = None,
sink: typing.Optional[io.Sink] = None,
scheduler: typing.Optional[str] = None,
):
super().__init__(assets, feed, sink)
Expand Down
8 changes: 3 additions & 5 deletions forml/lib/runner/graphviz.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,7 @@

import graphviz as grviz

from forml import conf, runtime
from forml.io import feed as feedmod
from forml.io import sink as sinkmod
from forml import conf, io, runtime
from forml.runtime import code
from forml.runtime.asset import access
from forml.runtime.code import instruction
Expand All @@ -41,8 +39,8 @@ class Runner(runtime.Runner, alias='graphviz'):
def __init__(
self,
assets: typing.Optional[access.Assets] = None,
feed: typing.Optional[feedmod.Provider] = None,
sink: typing.Optional[sinkmod.Provider] = None,
feed: typing.Optional[io.Feed] = None,
sink: typing.Optional[io.Sink] = None,
filepath: typing.Optional[str] = None,
**gvkw: typing.Any,
):
Expand Down
7 changes: 4 additions & 3 deletions forml/lib/sink/stdout.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,14 @@
"""
import typing

from forml.io import layout, sink
from forml import io
from forml.io import layout


class Sink(sink.Provider, alias='stdout'):
class Sink(io.Sink, alias='stdout'):
"""Stdout sink."""

class Writer(sink.Provider.Writer):
class Writer(io.Sink.Writer):
"""Sink writer implementation."""

@classmethod
Expand Down
30 changes: 14 additions & 16 deletions forml/runtime/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,10 @@
import abc
import typing

from forml import error, flow
from forml import error, flow, io
from forml import provider as provmod
from forml.conf.parsed import provider as provcfg
from forml.io import dsl
from forml.io import feed as feedmod
from forml.io import sink as sinkmod
from forml.runtime import code
from forml.runtime.asset import access, directory, persistent
from forml.runtime.asset.directory import root
Expand All @@ -43,13 +41,13 @@ class Runner(provmod.Interface, default=provcfg.Runner.default, path=provcfg.Run
def __init__(
self,
assets: typing.Optional[access.Assets] = None,
feed: typing.Optional['feedmod.Provider'] = None,
sink: typing.Optional['sinkmod.Provider'] = None,
feed: typing.Optional['io.Feed'] = None,
sink: typing.Optional['io.Feed'] = None,
**_,
):
self._assets: access.Assets = assets or access.Assets()
self._feed: feedmod.Provider = feed or feedmod.Provider()
self._sink: typing.Optional[sinkmod.Provider] = sink
self._feed: io.Feed = feed or io.Feed()
self._sink: typing.Optional[io.Sink] = sink

def train(self, lower: typing.Optional['dsl.Native'] = None, upper: typing.Optional['dsl.Native'] = None) -> None:
"""Run the training code.
Expand Down Expand Up @@ -148,12 +146,12 @@ class Launcher:
"""Runner handle."""

def __init__(
self, provider: provcfg.Runner, assets: access.Assets, feeds: 'Platform.Feeds', sink: sinkmod.Handle
self, provider: provcfg.Runner, assets: access.Assets, feeds: 'Platform.Feeds', sink: 'io.Exporter'
):
self._provider: provcfg.Runner = provider
self._assets: access.Assets = assets
self._feeds: Platform.Feeds = feeds
self._sink: sinkmod.Handle = sink
self._sink: io.Exporter = sink

@property
def train(self) -> typing.Callable[[typing.Optional['dsl.Native'], typing.Optional['dsl.Native']], None]:
Expand Down Expand Up @@ -191,7 +189,7 @@ def tune(self) -> typing.Callable[[typing.Optional['dsl.Native'], typing.Optiona
"""
raise NotImplementedError()

def __call__(self, query: 'dsl.Query', sink: 'sinkmod.Provider') -> Runner:
def __call__(self, query: 'dsl.Query', sink: 'io.Feed') -> Runner:
return Runner[self._provider.reference](
self._assets, self._feeds.match(query), sink, **self._provider.params
)
Expand Down Expand Up @@ -250,10 +248,10 @@ def list(
class Feeds:
"""Feed pool and util handle."""

def __init__(self, *configs: typing.Union[provcfg.Feed, 'feedmod.Provider']):
self._pool: feedmod.Pool = feedmod.Pool(*configs)
def __init__(self, *configs: typing.Union[provcfg.Feed, 'io.Feed']):
self._pool: io.Importer = io.Importer(*configs)

def match(self, query: 'dsl.Query') -> 'feedmod.Provider':
def match(self, query: 'dsl.Query') -> 'io.Feed':
"""Select the feed that can provide for given query.
Args:
Expand All @@ -275,15 +273,15 @@ def __init__(
self,
runner: typing.Optional[typing.Union[provcfg.Runner, str]] = None,
registry: typing.Optional[typing.Union[provcfg.Registry, persistent.Registry]] = None,
feeds: typing.Optional[typing.Iterable[typing.Union[provcfg.Feed, str, 'feedmod.Provider']]] = None,
sink: typing.Optional[typing.Union[provcfg.Sink.Mode, str, sinkmod.Provider]] = None,
feeds: typing.Optional[typing.Iterable[typing.Union[provcfg.Feed, str, 'io.Feed']]] = None,
sink: typing.Optional[typing.Union[provcfg.Sink.Mode, str, 'io.Sink']] = None,
):
if isinstance(runner, str):
runner = provcfg.Runner.resolve(runner)
self._runner: provcfg.Runner = runner or provcfg.Runner.default
self._registry: Platform.Registry = self.Registry(registry or provcfg.Registry.default)
self._feeds: Platform.Feeds = self.Feeds(*(feeds or provcfg.Feed.default))
self._sink: sinkmod.Handle = sinkmod.Handle(sink or provcfg.Sink.Mode.default)
self._sink: io.Exporter = io.Exporter(sink or provcfg.Sink.Mode.default)

def launcher(
self,
Expand Down

0 comments on commit 7731b39

Please sign in to comment.