Skip to content

Commit

Permalink
testing API refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
antonymayi committed Sep 26, 2021
1 parent 66260b5 commit 8ba54d6
Show file tree
Hide file tree
Showing 7 changed files with 119 additions and 114 deletions.
38 changes: 3 additions & 35 deletions forml/testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,40 +18,8 @@
"""
Testing framework.
"""
import typing

from forml import flow
from forml.testing import routine, spec
from ._routine import Suite, operator
from ._spec import Case, Scenario


def operator(subject: type[flow.Operator]) -> type[routine.Suite]:
"""Operator base class generator.
Args:
subject: Operator to be tested within given suite.
"""

class Operator(routine.Suite, metaclass=routine.Meta):
"""Generated base class."""

@property
def __operator__(self) -> type[flow.Operator]:
"""Attached operator.
Returns:
Operator instance.
"""
return subject

return Operator


class Case(spec.Appliable):
"""Test case entrypoint."""

def __init__(self, *args, **kwargs):
super().__init__(spec.Scenario.Params(*args, **kwargs))

def train(self, features: typing.Any, labels: typing.Any = None) -> spec.Trained:
"""Train input dataset definition."""
return spec.Trained(self._params, spec.Scenario.Input(train=features, label=labels))
__all__ = ['Case', 'operator', 'Suite', 'Scenario']
10 changes: 5 additions & 5 deletions forml/testing/facility.py → forml/testing/_facility.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from forml.conf.parsed import provider as provcfg
from forml.io import dsl, layout
from forml.runtime import launcher
from forml.testing import spec
from forml.testing import _spec

LOGGER = logging.getLogger(__name__)

Expand All @@ -44,9 +44,9 @@ class DataSet(dsl.Schema):
class Feed(io.Feed[None, typing.Any], alias='testing'):
"""Special feed to input the test cases."""

def __init__(self, scenario: spec.Scenario.Input, **kwargs):
def __init__(self, scenario: _spec.Scenario.Input, **kwargs):
super().__init__(**kwargs)
self._scenario: spec.Scenario.Input = scenario
self._scenario: _spec.Scenario.Input = scenario

# pylint: disable=unused-argument
@classmethod
Expand Down Expand Up @@ -100,8 +100,8 @@ def visit_node(self, node: flow.Worker) -> None:
self._gids.add(node.gid)
node.spec()

def __init__(self, params: spec.Scenario.Params, scenario: spec.Scenario.Input, runner: provcfg.Runner):
self._params: spec.Scenario.Params = params
def __init__(self, params: _spec.Scenario.Params, scenario: _spec.Scenario.Input, runner: provcfg.Runner):
self._params: _spec.Scenario.Params = params
self._source: project.Source = project.Source.query(DataSet.select(DataSet.feature), DataSet.label)
self._feed: Feed = Feed(scenario)
self._runner: provcfg.Runner = runner
Expand Down
62 changes: 42 additions & 20 deletions forml/testing/routine.py → forml/testing/_routine.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@
import typing
import unittest

from forml import flow
from forml.conf.parsed import provider as provcfg
from forml.flow._suite import member
from forml.runtime import launcher as launchmod
from forml.testing import facility, spec
from forml.testing import _facility, _spec

LOGGER = logging.getLogger(__name__)

Expand All @@ -41,7 +41,7 @@ def __repr__(self):

@property
@abc.abstractmethod
def __operator__(self) -> type[member.Operator]:
def __operator__(self) -> type[flow.Operator]:
"""Operator instance."""


Expand All @@ -51,7 +51,7 @@ class Meta(abc.ABCMeta):
def __new__(mcs, name: str, bases: tuple[type], namespace: dict[str, typing.Any], **kwargs):
if not any(issubclass(b, Suite) for b in bases):
raise TypeError(f'{name} not a valid {Suite.__name__}')
for title, scenario in [(t, s) for t, s in namespace.items() if isinstance(s, spec.Scenario)]:
for title, scenario in [(t, s) for t, s in namespace.items() if isinstance(s, _spec.Scenario)]:
namespace[f'test_{title}'] = Case(title, scenario)
del namespace[title]
return super().__new__(mcs, name, bases, namespace)
Expand All @@ -60,8 +60,8 @@ def __new__(mcs, name: str, bases: tuple[type], namespace: dict[str, typing.Any]
class Test:
"""Base class for test implementations."""

def __init__(self, launcher: facility.Launcher):
self._launcher: facility.Launcher = launcher
def __init__(self, launcher: _facility.Launcher):
self._launcher: _facility.Launcher = launcher

def __call__(self, suite: Suite) -> None:
launcher: typing.Optional[launchmod.Virtual.Builder] = self.init(suite)
Expand Down Expand Up @@ -113,9 +113,9 @@ def test(self, launcher: launchmod.Virtual.Builder) -> typing.Any:
class RaisableTest(Test):
"""Base test class for raising test cases."""

def __init__(self, launcher: facility.Launcher, exception: spec.Scenario.Exception):
def __init__(self, launcher: _facility.Launcher, exception: _spec.Scenario.Exception):
super().__init__(launcher)
self._exception: spec.Scenario.Exception = exception
self._exception: _spec.Scenario.Exception = exception

def raises(self, suite: Suite) -> typing.ContextManager:
"""Context manager for wrapping raising assertions.
Expand All @@ -134,9 +134,9 @@ def raises(self, suite: Suite) -> typing.ContextManager:
class ReturnableTest(Test):
"""Base test class for returning test cases."""

def __init__(self, launcher: facility.Launcher, output: spec.Scenario.Output):
def __init__(self, launcher: _facility.Launcher, output: _spec.Scenario.Output):
super().__init__(launcher)
self._output: spec.Scenario.Output = output
self._output: _spec.Scenario.Output = output

def matches(self, suite: Suite, value: typing.Any) -> None:
"""Context manager for wrapping raising assertions.
Expand Down Expand Up @@ -212,13 +212,13 @@ class TestStateApplyRaises(RaisableTest, StateApplyTest):
class Case:
"""Test case routine."""

def __init__(self, name: str, scenario: spec.Scenario, launcher: provcfg.Runner = provcfg.Runner.default):
def __init__(self, name: str, scenario: _spec.Scenario, launcher: provcfg.Runner = provcfg.Runner.default):
self._name: str = name
launcher = facility.Launcher(scenario.params, scenario.input, launcher)
launcher = _facility.Launcher(scenario.params, scenario.input, launcher)
self._test: Test = self.select(scenario, launcher)

@staticmethod
def select(scenario: spec.Scenario, launcher: facility.Launcher) -> Test:
def select(scenario: _spec.Scenario, launcher: _facility.Launcher) -> Test:
"""Selecting and setting up the test implementation for given scenario.
Args:
Expand All @@ -228,19 +228,19 @@ def select(scenario: spec.Scenario, launcher: facility.Launcher) -> Test:
Returns:
Test case instance.
"""
if scenario.outcome is spec.Scenario.Outcome.INIT_RAISES:
if scenario.outcome is _spec.Scenario.Outcome.INIT_RAISES:
return TestInitRaises(launcher, scenario.exception)
if scenario.outcome is spec.Scenario.Outcome.PLAINAPPLY_RAISES:
if scenario.outcome is _spec.Scenario.Outcome.PLAINAPPLY_RAISES:
return TestPlainApplyRaises(launcher, scenario.exception)
if scenario.outcome is spec.Scenario.Outcome.STATETRAIN_RAISES:
if scenario.outcome is _spec.Scenario.Outcome.STATETRAIN_RAISES:
return TestStateTrainRaises(launcher, scenario.exception)
if scenario.outcome is spec.Scenario.Outcome.STATEAPPLY_RAISES:
if scenario.outcome is _spec.Scenario.Outcome.STATEAPPLY_RAISES:
return TestStateApplyRaises(launcher, scenario.exception)
if scenario.outcome is spec.Scenario.Outcome.PLAINAPPLY_RETURNS:
if scenario.outcome is _spec.Scenario.Outcome.PLAINAPPLY_RETURNS:
return TestPlainApplyReturns(launcher, scenario.output)
if scenario.outcome is spec.Scenario.Outcome.STATETRAIN_RETURNS:
if scenario.outcome is _spec.Scenario.Outcome.STATETRAIN_RETURNS:
return TestStateTrainReturns(launcher, scenario.output)
if scenario.outcome is spec.Scenario.Outcome.STATEAPPLY_RETURNS:
if scenario.outcome is _spec.Scenario.Outcome.STATEAPPLY_RETURNS:
return TestStateApplyReturns(launcher, scenario.output)
raise RuntimeError('Unexpected scenario outcome')

Expand All @@ -252,3 +252,25 @@ def case():

case.__doc__ = f'Test of {string.capwords(self._name.replace("_", " "))}'
return case


def operator(subject: type[flow.Operator]) -> type[Suite]:
"""Operator base class generator.
Args:
subject: Operator to be tested within given suite.
"""

class Operator(Suite, metaclass=Meta):
"""Generated base class."""

@property
def __operator__(self) -> type[flow.Operator]:
"""Attached operator.
Returns:
Operator instance.
"""
return subject

return Operator
11 changes: 11 additions & 0 deletions forml/testing/spec.py → forml/testing/_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,3 +210,14 @@ def returns(
) -> Scenario:
"""Assertion on expected return value."""
return Scenario(self._params, self._input, Scenario.Output(train=output, matcher=matcher))


class Case(Appliable):
"""Test case entrypoint."""

def __init__(self, *args, **kwargs):
super().__init__(Scenario.Params(*args, **kwargs))

def train(self, features: typing.Any, labels: typing.Any = None) -> Trained:
"""Train input dataset definition."""
return Trained(self._params, Scenario.Input(train=features, label=labels))
25 changes: 12 additions & 13 deletions tests/testing/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,17 +26,16 @@
import pytest

from forml import testing
from forml.testing import spec


@pytest.fixture(scope='session')
def exception() -> spec.Scenario.Exception:
def exception() -> testing.Scenario.Exception:
"""Exception fixture.
Returns:
Exception type.
"""
return spec.Scenario.Exception(RuntimeError, 'This is an Error')
return testing.Scenario.Exception(RuntimeError, 'This is an Error')


@pytest.fixture(scope='session')
Expand Down Expand Up @@ -90,37 +89,37 @@ def train_output() -> str:


@pytest.fixture(scope='session')
def init_raises(hyperparams: typing.Mapping[str, int], exception: spec.Scenario.Exception) -> spec.Scenario:
def init_raises(hyperparams: typing.Mapping[str, int], exception: testing.Scenario.Exception) -> testing.Scenario:
"""Scenario fixture."""
return testing.Case(**hyperparams).raises(exception.kind, exception.message)


@pytest.fixture(scope='session')
def plainapply_raises(
hyperparams: typing.Mapping[str, int], apply_input: str, exception: spec.Scenario.Exception
) -> spec.Scenario:
hyperparams: typing.Mapping[str, int], apply_input: str, exception: testing.Scenario.Exception
) -> testing.Scenario:
"""Scenario fixture."""
return testing.Case(**hyperparams).apply(apply_input).raises(exception.kind, exception.message)


@pytest.fixture(scope='session')
def plainapply_returns(hyperparams: typing.Mapping[str, int], apply_input: str, apply_output: str) -> spec.Scenario:
def plainapply_returns(hyperparams: typing.Mapping[str, int], apply_input: str, apply_output: str) -> testing.Scenario:
"""Scenario fixture."""
return testing.Case(**hyperparams).apply(apply_input).returns(apply_output)


@pytest.fixture(scope='session')
def statetrain_raises(
hyperparams: typing.Mapping[str, int], train_input: str, label_input: str, exception: spec.Scenario.Exception
) -> spec.Scenario:
hyperparams: typing.Mapping[str, int], train_input: str, label_input: str, exception: testing.Scenario.Exception
) -> testing.Scenario:
"""Scenario fixture."""
return testing.Case(**hyperparams).train(train_input, label_input).raises(exception.kind, exception.message)


@pytest.fixture(scope='session')
def statetrain_returns(
hyperparams: typing.Mapping[str, int], train_input: str, label_input: str, train_output: str
) -> spec.Scenario:
) -> testing.Scenario:
"""Scenario fixture."""
return testing.Case(**hyperparams).train(train_input, label_input).returns(train_output)

Expand All @@ -131,8 +130,8 @@ def stateapply_raises(
train_input: str,
label_input: str,
apply_input,
exception: spec.Scenario.Exception,
) -> spec.Scenario:
exception: testing.Scenario.Exception,
) -> testing.Scenario:
"""Scenario fixture."""
return (
testing.Case(**hyperparams)
Expand All @@ -145,6 +144,6 @@ def stateapply_raises(
@pytest.fixture(scope='session')
def stateapply_returns(
hyperparams: typing.Mapping[str, int], train_input: str, label_input: str, apply_input: str, apply_output: str
) -> spec.Scenario:
) -> testing.Scenario:
"""Scenario fixture."""
return testing.Case(**hyperparams).train(train_input, label_input).apply(apply_input).returns(apply_output)

0 comments on commit 8ba54d6

Please sign in to comment.