Skip to content

Commit

Permalink
adding sink tests
Browse files Browse the repository at this point in the history
  • Loading branch information
antonymayi committed Sep 26, 2022
1 parent 7fbfdea commit ef2a45b
Show file tree
Hide file tree
Showing 12 changed files with 345 additions and 47 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Expand Up @@ -43,7 +43,7 @@ repos:
name: Black
args: [--config=pyproject.toml]
- repo: https://github.com/asottile/pyupgrade
rev: v2.37.3
rev: v2.38.2
hooks:
- id: pyupgrade
name: Python syntax upgrade
Expand Down
47 changes: 25 additions & 22 deletions constraints.txt
Expand Up @@ -12,7 +12,7 @@ anyio==3.6.1
# via starlette
appdirs==1.4.4
# via sphinx-immaterial
astroid==2.12.9
astroid==2.12.10
# via pylint
asttokens==2.0.8
# via stack-data
Expand All @@ -33,7 +33,7 @@ bleach==5.0.1
# via nbconvert
build==0.8.0
# via pip-tools
certifi==2022.6.15.1
certifi==2022.9.24
# via requests
cfgv==3.3.1
# via pre-commit
Expand All @@ -56,7 +56,7 @@ cloudpickle==2.2.0
# mlflow
coverage==6.4.4
# via pytest-cov
dask==2022.9.0
dask==2022.9.1
# via forml (setup.py)
databricks-cli==0.17.3
# via mlflow
Expand All @@ -68,7 +68,7 @@ dill==0.3.5.1
# via pylint
distlib==0.3.6
# via virtualenv
docker==5.0.3
docker==6.0.0
# via mlflow
docutils==0.19
# via
Expand All @@ -80,9 +80,9 @@ entrypoints==0.4
# mlflow
execnet==1.9.0
# via pytest-xdist
executing==1.0.0
executing==1.1.0
# via stack-data
fastjsonschema==2.16.1
fastjsonschema==2.16.2
# via nbformat
filelock==3.8.0
# via virtualenv
Expand All @@ -91,7 +91,7 @@ flake8==5.0.4
# flake8-bugbear
# flake8-colors
# flake8-typing-imports
flake8-bugbear==22.9.11
flake8-bugbear==22.9.23
# via forml (setup.py)
flake8-colors==0.1.9
# via forml (setup.py)
Expand All @@ -115,11 +115,11 @@ greenlet==1.1.3
# via sqlalchemy
gunicorn==20.1.0
# via mlflow
h11==0.13.0
h11==0.14.0
# via uvicorn
identify==2.5.5
# via pre-commit
idna==3.3
idna==3.4
# via
# anyio
# requests
Expand All @@ -146,7 +146,7 @@ jinja2==3.1.2
# nbconvert
# nbsphinx
# sphinx
joblib==1.1.0
joblib==1.2.0
# via scikit-learn
jsonschema==4.16.0
# via nbformat
Expand All @@ -167,7 +167,7 @@ locket==1.0.0
# via partd
lxml==4.9.1
# via nbconvert
mako==1.2.2
mako==1.2.3
# via alembic
markupsafe==2.1.1
# via
Expand All @@ -184,7 +184,7 @@ mccabe==0.7.0
# pylint
mistune==2.0.4
# via nbconvert
mlflow==1.28.0
mlflow==1.29.0
# via forml (setup.py)
mypy-extensions==0.4.3
# via
Expand All @@ -194,7 +194,7 @@ nbclient==0.6.8
# via nbconvert
nbconvert==7.0.0
# via nbsphinx
nbformat==5.4.0
nbformat==5.6.0
# via
# nbclient
# nbconvert
Expand All @@ -220,12 +220,13 @@ packaging==21.3
# via
# build
# dask
# docker
# forml (setup.py)
# mlflow
# nbconvert
# pytest
# sphinx
pandas==1.4.4
pandas==1.5.0
# via
# forml (setup.py)
# mlflow
Expand Down Expand Up @@ -264,7 +265,7 @@ prometheus-flask-exporter==0.20.3
# via mlflow
prompt-toolkit==3.0.31
# via ipython
protobuf==4.21.5
protobuf==4.21.6
# via mlflow
ptyprocess==0.7.0
# via pexpect
Expand All @@ -291,9 +292,9 @@ pygments==2.13.0
# sphinx
pyhive==0.6.5
# via forml (setup.py)
pyjwt==2.4.0
pyjwt==2.5.0
# via databricks-cli
pylint==2.15.2
pylint==2.15.3
# via forml (setup.py)
pyparsing==3.0.9
# via packaging
Expand Down Expand Up @@ -330,7 +331,7 @@ pyyaml==6.0
# mlflow
# pre-commit
# pycln
pyzmq==23.2.1
pyzmq==24.0.1
# via jupyter-client
querystring-parser==1.2.4
# via mlflow
Expand Down Expand Up @@ -364,7 +365,7 @@ snowballstemmer==2.2.0
# via sphinx
soupsieve==2.3.2.post1
# via beautifulsoup4
sphinx==5.1.1
sphinx==5.2.1
# via
# forml (setup.py)
# nbsphinx
Expand Down Expand Up @@ -399,9 +400,9 @@ sqlalchemy==1.4.41
# alembic
# forml (setup.py)
# mlflow
sqlparse==0.4.2
sqlparse==0.4.3
# via mlflow
stack-data==0.5.0
stack-data==0.5.1
# via ipython
starlette==0.20.4
# via forml (setup.py)
Expand Down Expand Up @@ -457,7 +458,9 @@ typing-extensions==4.3.0
typing-inspect==0.8.0
# via libcst
urllib3==1.26.12
# via requests
# via
# docker
# requests
uvicorn==0.18.3
# via forml (setup.py)
virtualenv==20.16.5
Expand Down
2 changes: 1 addition & 1 deletion docs/platform.rst
Expand Up @@ -100,7 +100,7 @@ The meaning of the different placeholders and keywords is:
keyword to preselect a configuration instance for situations when no explicit choice is
specified during some particular execution.

.. attention::
.. note::
The ``FEED`` provider type can specify a list of *multiple* instances as *default*
(contextual :ref:`feed selection <feed-selection>` is then performed at runtime).

Expand Down
4 changes: 2 additions & 2 deletions forml/io/_output/__init__.py
Expand Up @@ -75,8 +75,8 @@ def consumer(cls, schema: typing.Optional['dsl.Source.Schema'], **kwargs: typing


class Exporter:
"""Sink exporter is a lazy wrapper around alternative sink specifiers providing a particular Sink instance upon
request.
"""Sink exporter is a lazy wrapper around alternative sink specifiers providing a particular
Sink instance upon request.
"""

def __init__(self, sink: typing.Union[setup.Sink.Mode, str, Sink]):
Expand Down
2 changes: 1 addition & 1 deletion forml/runtime/_pad.py
Expand Up @@ -162,7 +162,7 @@ def tune(self) -> typing.Callable[[typing.Optional[dsl.Native], typing.Optional[
"""Return the tune handler.
Returns:
Tune handler
Tune handler.
"""
raise NotImplementedError()

Expand Down
23 changes: 11 additions & 12 deletions forml/runtime/_pseudo.py
Expand Up @@ -38,17 +38,6 @@
LOGGER = logging.getLogger(__name__)


class Sink(io.Sink):
"""Sniffer sink."""

def __init__(self, sniffer: payload.Sniff):
super().__init__()
self._sniffer: payload.Sniff = sniffer

def save(self, schema: typing.Optional['dsl.Source.Schema']) -> 'flow.Composable':
return self._sniffer


class Virtual:
"""Custom launcher allowing to execute the provided artifact using the default or an
explicit runner in combination with a special pipeline sink to capture and return any output
Expand Down Expand Up @@ -185,6 +174,16 @@ def tune(self) -> typing.Callable[[typing.Optional['dsl.Native'], typing.Optiona
"""
return self._launcher.tune

class Sink(io.Sink):
"""Sniffer sink."""

def __init__(self, sniffer: payload.Sniff):
super().__init__()
self._sniffer: payload.Sniff = sniffer

def save(self, schema: typing.Optional['dsl.Source.Schema']) -> 'flow.Composable':
return self._sniffer

def __init__(self, artifact: 'project.Artifact'):
class Manifest(types.ModuleType):
"""Fake manifest module."""
Expand All @@ -210,7 +209,7 @@ def __call__(
runner: typing.Optional[typing.Union[setup.Runner, str]] = None,
feeds: typing.Optional[typing.Iterable[typing.Union[setup.Feed, str, io.Feed]]] = None,
) -> 'runtime.Virtual.Handler':
launcher = _pad.Platform(runner, self._registry, feeds, Sink(self._sniffer)).launcher(self._project)
launcher = _pad.Platform(runner, self._registry, feeds, self.Sink(self._sniffer)).launcher(self._project)
return self.Handler(launcher, self._sniffer)

def __getitem__(self, runner: typing.Union[setup.Runner, str]) -> 'runtime.Virtual.Handler':
Expand Down
96 changes: 96 additions & 0 deletions tests/application/test_strategy.py
@@ -0,0 +1,96 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Strategy unit tests.
"""
import abc
import typing

import pytest

from forml import application, runtime
from forml.io import asset


class Strategy(abc.ABC):
"""Base class for strategy tests."""

@staticmethod
@abc.abstractmethod
@pytest.fixture(scope='function')
def strategy() -> application.Selector:
"""Strategy fixture."""

@staticmethod
@abc.abstractmethod
@pytest.fixture(scope='function')
def instance() -> asset.Instance:
"""Instance fixture."""

@staticmethod
@pytest.fixture(scope='function')
def context() -> typing.Any:
"""Context fixture."""
return None

@staticmethod
@pytest.fixture(scope='function')
def stats() -> runtime.Stats:
"""Stats fixture."""
return runtime.Stats()

def test_select(
self,
strategy: application.Selector,
instance: asset.Instance,
directory: asset.Directory,
context: typing.Any,
stats: runtime.Stats,
):
"""Strategy select test."""
assert strategy.select(directory, context, stats) == instance


class TestExplicit(Strategy):
"""Explicit strategy unit tests."""

@staticmethod
@pytest.fixture(scope='function')
def strategy(
project_name: asset.Project.Key, project_release: asset.Release.Key, valid_generation: asset.Generation.Key
) -> application.Selector:
return application.Explicit(project_name, project_release, valid_generation)

@staticmethod
@pytest.fixture(scope='function')
def instance(valid_instance: asset.Instance) -> asset.Instance:
return valid_instance


class TestLatest(Strategy):
"""Latest strategy unit tests."""

@staticmethod
@pytest.fixture(scope='function', params=[False, True])
def strategy(request, project_name: asset.Project.Key, project_release: asset.Release.Key) -> application.Selector:
release = project_release if request.param else None
return application.Latest(project_name, release)

@staticmethod
@pytest.fixture(scope='function')
def instance(valid_instance: asset.Instance) -> asset.Instance:
return valid_instance
9 changes: 3 additions & 6 deletions tests/provider/feed/__init__.py
Expand Up @@ -37,8 +37,7 @@ class Launcher:
"""Feed test launcher."""

def __init__(self, feed: io.Feed, source: project.Source):
self._sniff = payload.Sniff()
self._handler: runtime.Virtual.Handler = source.bind(self._sniff).launcher(runner='dask', feeds=[feed])
self._handler: runtime.Virtual.Handler = source.bind(payload.Sniff()).launcher(runner='dask', feeds=[feed])

@property
def apply(self) -> numpy.array:
Expand All @@ -48,10 +47,8 @@ def apply(self) -> numpy.array:
@property
def train(self) -> tuple[numpy.array, numpy.array]:
"""Train-mode result."""
with self._sniff as future:
self._handler.train()
features, labels = future.result()
return numpy.array(features, dtype=object), numpy.array(labels, dtype=object)
result = self._handler.train()
return numpy.array(result.features, dtype=object), numpy.array(result.labels, dtype=object)

@staticmethod
@abc.abstractmethod
Expand Down

0 comments on commit ef2a45b

Please sign in to comment.