Skip to content
This repository has been archived by the owner on Jun 22, 2022. It is now read-only.

Commit

Permalink
Fix issue #28: Unintuitive adapter syntax (#42)
Browse files Browse the repository at this point in the history
* Write tests for new adapter syntax

* Refactor adapter

* Improve handling of caches and logs in tests

* Fix minor issues mentioned in PR comments

* Rewrite tests in pytest framework

* Move adapting to seperate class, alter behaviour

* Correction: mutable object as default argument in Step initializer
  • Loading branch information
grzes314 authored and Kamil A. Kaczmarek committed May 17, 2018
1 parent 4785256 commit f6f6364
Show file tree
Hide file tree
Showing 8 changed files with 355 additions and 25 deletions.
14 changes: 13 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,16 @@
# IDE files
.idea/
.ipynb_checkpoints/

# Generated files
*.pyc
*.log
.pytest_cache

# Working directories
examples/cache/
tests/.cache

# Unwanted notebook files
Untitled*.ipynb
.ipynb_checkpoints/

57 changes: 57 additions & 0 deletions steps/adapter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
from typing import Dict, Union, Tuple, List, Dict, Any, NamedTuple

E = NamedTuple('E', [('input_name', str), ('key', str)])

AdaptingRecipe = Any
Results = Dict[str, Any]
AllInputs = Dict[str, Any]


class AdapterError(Exception):
pass


class Adapter:
def __init__(self, adapting_recipes: Dict[str, AdaptingRecipe]):
self.adapting_recipes = adapting_recipes

def adapt(self, all_inputs: AllInputs) -> Dict[str, Any]:
adapted = {}
for name, recipe in self.adapting_recipes.items():
adapted[name] = self._construct(all_inputs, recipe)
return adapted

def _construct(self, all_inputs: AllInputs, recipe: AdaptingRecipe) -> Any:
return {
E: self._construct_element,
tuple: self._construct_tuple,
list: self._construct_list,
dict: self._construct_dict,
}.get(recipe.__class__, self._construct_constant)(all_inputs, recipe)

def _construct_constant(self, _: AllInputs, constant) -> Any:
return constant

def _construct_element(self, all_inputs: AllInputs, element: E):
input_name = element.input_name
key = element.key
try:
input_results = all_inputs[input_name]
try:
return input_results[key]
except KeyError:
msg = "Input '{}' didn't have '{}' in its result.".format(input_name, key)
raise AdapterError(msg)
except KeyError:
msg = "No such input: '{}'".format(input_name)
raise AdapterError(msg)

def _construct_list(self, all_inputs: AllInputs, lst: List[AdaptingRecipe]):
return [self._construct(all_inputs, recipe) for recipe in lst]

def _construct_tuple(self, all_inputs: AllInputs, tup: Tuple):
return tuple(self._construct(all_inputs, recipe) for recipe in tup)

def _construct_dict(self, all_inputs: AllInputs, dic: Dict[AdaptingRecipe, AdaptingRecipe]):
return {self._construct(all_inputs, k): self._construct(all_inputs, v)
for k, v in dic.items()}
53 changes: 29 additions & 24 deletions steps/base.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import os
import pprint
import shutil
from collections import defaultdict

from sklearn.externals import joblib

from .adapters import take_first_inputs
from .utils import view_graph, plot_graph, get_logger, initialize_logger
from .adapter import AdapterError

initialize_logger()
logger = get_logger()
Expand All @@ -29,8 +30,8 @@ class Step:
def __init__(self,
name,
transformer,
input_steps=[],
input_data=[],
input_steps=None,
input_data=None,
adapter=None,
cache_dirpath=None,
cache_output=False,
Expand Down Expand Up @@ -128,11 +129,10 @@ def __init__(self,
"""

self.name = name

self.transformer = transformer

self.input_steps = input_steps
self.input_data = input_data
self.input_steps = input_steps or []
self.input_data = input_data or []
self.adapter = adapter

self.force_fitting = force_fitting
Expand Down Expand Up @@ -371,29 +371,30 @@ def _cached_transform(self, step_inputs):

def _adapt(self, step_inputs):
logger.info('step {} adapting inputs'.format(self.name))
adapted_steps = {}
for adapted_name, mapping in self.adapter.items():
if isinstance(mapping, str):
adapted_steps[adapted_name] = step_inputs[mapping]
else:
if len(mapping) == 2:
(step_mapping, func) = mapping
elif len(mapping) == 1:
step_mapping = mapping
func = take_first_inputs
else:
raise ValueError('wrong mapping specified')

raw_inputs = [step_inputs[step_name][step_var] for step_name, step_var in step_mapping]
adapted_steps[adapted_name] = func(raw_inputs)
return adapted_steps
try:
return self.adapter.adapt(step_inputs)
except AdapterError as e:
msg = "Error while adapting step '{}'".format(self.name)
raise StepsError(msg) from e

def _unpack(self, step_inputs):
logger.info('step {} unpacking inputs'.format(self.name))
unpacked_steps = {}
key_to_step_names = defaultdict(list)
for step_name, step_dict in step_inputs.items():
unpacked_steps = {**unpacked_steps, **step_dict}
return unpacked_steps
unpacked_steps.update(step_dict)
for key in step_dict.keys():
key_to_step_names[key].append(step_name)

repeated_keys = [(key, step_names) for key, step_names in key_to_step_names.items()
if len(step_names) > 1]
if len(repeated_keys) == 0:
return unpacked_steps
else:
msg = "Could not unpack inputs. Following keys are present in multiple input steps:\n"\
"\n".join([" '{}' present in steps {}".format(key, step_names)
for key, step_names in repeated_keys])
raise StepsError(msg)

def _get_steps(self, all_steps):
for input_step in self.input_steps:
Expand Down Expand Up @@ -514,3 +515,7 @@ class NoOperation(BaseTransformer):
"""
def transform(self, **kwargs):
return kwargs


class StepsError(Exception):
pass
Empty file added tests/__init__.py
Empty file.
15 changes: 15 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import steps.base # To make sure logger is initialized before running prepare_steps_logger

from .steps_test_utils import prepare_steps_logger, remove_cache


def pytest_sessionstart(session):
prepare_steps_logger()


def pytest_runtest_setup(item):
remove_cache()


def pytest_runtest_teardown(item):
remove_cache()
33 changes: 33 additions & 0 deletions tests/steps_test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import logging
import os
import shutil
from pathlib import Path


CACHE_DIRPATH = '.cache'
LOGS_PATH = 'steps.log'


def remove_cache():
if Path(CACHE_DIRPATH).exists():
shutil.rmtree(CACHE_DIRPATH)


def remove_logs():
if Path(LOGS_PATH).exists():
os.remove(LOGS_PATH)


def prepare_steps_logger():
print("Redirecting logging to {}.".format(LOGS_PATH))
remove_logs()
logger = logging.getLogger('steps')
for h in logger.handlers:
logger.removeHandler(h)
message_format = logging.Formatter(fmt='%(asctime)s %(name)s >>> %(message)s',
datefmt='%Y-%m-%d %H:%M:%S')
fh = logging.FileHandler(LOGS_PATH)
fh.setLevel(logging.INFO)
fh.setFormatter(fmt=message_format)
logger.addHandler(fh)

142 changes: 142 additions & 0 deletions tests/test_adapter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
import pytest
import numpy as np

from steps.adapter import Adapter, E


@pytest.fixture
def data():
return {
'input_1': {
'features': np.array([
[1, 6],
[2, 5],
[3, 4]
]),
'labels': np.array([2, 5, 3])
},
'input_2': {
'extra_features': np.array([
[5, 7, 3],
[67, 4, 5],
[6, 13, 14]
])
},
'input_3': {
'images': np.array([
[[0, 255], [255, 0]],
[[255, 0], [0, 255]],
[[255, 255], [0, 0]],
]),
'labels': np.array([1, 1, 0])
}
}


def test_adapter_creates_defined_keys(data):
adapter = Adapter({
'X': [E('input_1', 'features')],
'Y': [E('input_2', 'extra_features')]
})
res = adapter.adapt(data)
assert {'X', 'Y'} == set(res.keys())


def test_recipe_with_single_item(data):
adapter = Adapter({
'X': E('input_1', 'labels'),
'Y': E('input_3', 'labels'),
})
res = adapter.adapt(data)
assert np.array_equal(res['X'], data['input_1']['labels'])
assert np.array_equal(res['Y'], data['input_3']['labels'])


def test_recipe_with_list(data):
adapter = Adapter({
'X': [],
'Y': [E('input_1', 'features')],
'Z': [E('input_1', 'features'), E('input_2', 'extra_features')]
})
res = adapter.adapt(data)

for i, key in enumerate(('X', 'Y', 'Z')):
assert isinstance(res[key], list)
assert len(res[key]) == i

assert res['X'] == []

assert np.array_equal(res['Y'][0], data['input_1']['features'])

assert np.array_equal(res['Z'][0], data['input_1']['features'])
assert np.array_equal(res['Z'][1], data['input_2']['extra_features'])


def test_recipe_with_tuple(data):
adapter = Adapter({
'X': (),
'Y': (E('input_1', 'features'),),
'Z': (E('input_1', 'features'), E('input_2', 'extra_features'))
})
res = adapter.adapt(data)

for i, key in enumerate(('X', 'Y', 'Z')):
assert isinstance(res[key], tuple)
assert len(res[key]) == i

assert res['X'] == ()

assert np.array_equal(res['Y'][0], data['input_1']['features'])

assert np.array_equal(res['Z'][0], data['input_1']['features'])
assert np.array_equal(res['Z'][1], data['input_2']['extra_features'])


def test_recipe_with_dictionary(data):
adapter = Adapter({
'X': {},
'Y': {'a': E('input_1', 'features')},
'Z': {'a': E('input_1', 'features'), 'b': E('input_2', 'extra_features')}
})
res = adapter.adapt(data)

for i, key in enumerate(('X', 'Y', 'Z')):
assert isinstance(res[key], dict)
assert len(res[key]) == i

assert res['X'] == {}

assert np.array_equal(res['Y']['a'], data['input_1']['features'])

assert np.array_equal(res['Z']['a'], data['input_1']['features'])
assert np.array_equal(res['Z']['b'], data['input_2']['extra_features'])


def test_recipe_with_constants(data):
adapter = Adapter({
'A': 112358,
'B': 3.14,
'C': "lorem ipsum",
'D': ('input_1', 'features'),
'E': {112358: 112358, 'a': 'a', 3.14: 3.14},
'F': [112358, 3.14, "lorem ipsum", ('input_1', 'features')]
})
res = adapter.adapt(data)

assert res['A'] == 112358
assert res['B'] == 3.14
assert res['C'] == "lorem ipsum"
assert res['D'] == ('input_1', 'features')
assert res['E'] == {112358: 112358, 'a': 'a', 3.14: 3.14}
assert res['F'] == [112358, 3.14, "lorem ipsum", ('input_1', 'features')]


def test_nested_recipes(data):
adapter = Adapter({
'X': [{'a': [E('input_1', 'features')]}],
'Y': {'a': [{'b': E('input_2', 'extra_features')}]}
})
res = adapter.adapt(data)

assert res['X'] == [{'a': [data['input_1']['features']]}]
assert res['Y'] == {'a': [{'b': data['input_2']['extra_features']}]}

0 comments on commit f6f6364

Please sign in to comment.