Skip to content

Commit

Permalink
Change @component decorator so it doesn't add default to_dict and fro…
Browse files Browse the repository at this point in the history
…m_dict (#98)
  • Loading branch information
silvanocerza committed Aug 18, 2023
1 parent 8b5b405 commit a7141ff
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 122 deletions.
51 changes: 14 additions & 37 deletions canals/component/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,18 +70,16 @@

import logging
import inspect
from typing import Protocol, Union, Dict, Type, Any, get_origin, get_args
from typing import Protocol, Union, Dict, Any, get_origin, get_args
from functools import wraps

from canals.errors import ComponentError, ComponentDeserializationError
from canals.errors import ComponentError


logger = logging.getLogger(__name__)


# We ignore too-few-public-methods Pylint error as this is only meant to be
# the definition of the Component interface.
class Component(Protocol): # pylint: disable=too-few-public-methods
class Component(Protocol):
"""
Abstract interface of a Component.
This is only used by type checking tools.
Expand Down Expand Up @@ -231,11 +229,21 @@ def _component(self, class_):
"""
logger.debug("Registering %s as a component", class_)

# Check for run()
# Check for required methods
if not hasattr(class_, "run"):
raise ComponentError(f"{class_.__name__} must have a 'run()' method. See the docs for more information.")
run_signature = inspect.signature(class_.run)

if not hasattr(class_, "to_dict"):
raise ComponentError(
f"{class_.__name__} must have a 'to_dict()' method. See the docs for more information."
)

if not hasattr(class_, "from_dict"):
raise ComponentError(
f"{class_.__name__} must have a 'from_dict()' method. See the docs for more information."
)

# Create the input sockets
class_.run.__canals_input__ = {
param: {
Expand All @@ -259,12 +267,6 @@ def _component(self, class_):

setattr(class_, "__canals_component__", True)

if not hasattr(class_, "to_dict"):
class_.to_dict = _default_component_to_dict

if not hasattr(class_, "from_dict"):
class_.from_dict = classmethod(_default_component_from_dict)

return class_

def __call__(self, class_=None):
Expand All @@ -283,28 +285,3 @@ def _is_optional(type_: type) -> bool:
Utility method that returns whether a type is Optional.
"""
return get_origin(type_) is Union and type(None) in get_args(type_)


def _default_component_to_dict(comp: Component) -> Dict[str, Any]:
"""
Default component serializer.
Serializes a component to a dictionary.
"""
return {
"hash": id(comp),
"type": comp.__class__.__name__,
"init_parameters": getattr(comp, "init_parameters", {}),
}


def _default_component_from_dict(cls: Type[Component], data: Dict[str, Any]) -> Component:
"""
Default component deserializer.
The "type" field in `data` must match the class that is being deserialized into.
"""
init_params = data.get("init_parameters", {})
if "type" not in data:
raise ComponentDeserializationError("Missing 'type' in component serialization data")
if data["type"] != cls.__name__:
raise ComponentDeserializationError(f"Component '{data['type']}' can't be deserialized as '{cls.__name__}'")
return cls(**init_params)
9 changes: 9 additions & 0 deletions canals/testing/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Any, Dict, Optional, Tuple, Type

from canals import component, Component
from canals.serialization import default_to_dict, default_from_dict


def component_class(
Expand Down Expand Up @@ -98,9 +99,17 @@ def run(self, **kwargs): # pylint: disable=unused-argument
return output
return {name: None for name in output_types.keys()}

def to_dict(self):
return default_to_dict(self)

def from_dict(cls, data: Dict[str, Any]):
return default_from_dict(cls, data)

fields = {
"__init__": init,
"run": run,
"to_dict": to_dict,
"from_dict": classmethod(from_dict),
}
if extra_fields is not None:
fields = {**fields, **extra_fields}
Expand Down
113 changes: 50 additions & 63 deletions test/component/test_component.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,21 @@
from typing import Any
from unittest.mock import Mock

import pytest

from canals import component
from canals.component.component import _default_component_to_dict, _default_component_from_dict
from canals.testing import factory
from canals.errors import ComponentError, ComponentDeserializationError
from canals.errors import ComponentError


def test_correct_declaration():
@component
class MockComponent:
def to_dict(self):
return {}

@classmethod
def from_dict(cls, data):
return cls()

@component.output_types(output_value=int)
def run(self, input_value: int):
return {"output_value": input_value}
Expand All @@ -28,6 +32,13 @@ class MockComponent:
def store(self):
return "test_store"

def to_dict(self):
return {}

@classmethod
def from_dict(cls, data):
return cls()

@component.output_types(output_value=int)
def run(self, input_value: int):
return {"output_value": input_value}
Expand All @@ -49,6 +60,13 @@ def store(self):
def store(self, value):
self._store = value

def to_dict(self):
return {}

@classmethod
def from_dict(cls, data):
return cls()

@component.output_types(output_value=int)
def run(self, input_value: int):
return {"output_value": input_value}
Expand All @@ -74,6 +92,13 @@ class MockComponent:
def __init__(self):
component.set_input_types(self, value=Any)

def to_dict(self):
return {}

@classmethod
def from_dict(cls, data):
return cls()

@component.output_types(value=int)
def run(self, **kwargs):
return {"value": 1}
Expand All @@ -95,6 +120,13 @@ class MockComponent:
def __init__(self):
component.set_output_types(self, value=int)

def to_dict(self):
return {}

@classmethod
def from_dict(cls, data):
return cls()

def run(self, value: int):
return {"value": 1}

Expand All @@ -114,6 +146,13 @@ class MockComponent:
def run(self, value: int):
return {"value": 1}

def to_dict(self):
return {}

@classmethod
def from_dict(cls, data):
return cls()

comp = MockComponent()
assert comp.run.__canals_output__ == {
"value": {
Expand All @@ -130,64 +169,12 @@ class MockComponent:
def run(self, value: int):
return {"value": 1}

comp = MockComponent()
assert comp.__canals_component__


def test_default_component_to_dict():
MyComponent = factory.component_class("MyComponent")
comp = MyComponent()
res = _default_component_to_dict(comp)
assert res == {
"hash": id(comp),
"type": "MyComponent",
"init_parameters": {},
}
def to_dict(self):
return {}

@classmethod
def from_dict(cls, data):
return cls()

def test_default_component_to_dict_with_init_parameters():
extra_fields = {"init_parameters": {"some_key": "some_value"}}
MyComponent = factory.component_class("MyComponent", extra_fields=extra_fields)
comp = MyComponent()
res = _default_component_to_dict(comp)
assert res == {
"hash": id(comp),
"type": "MyComponent",
"init_parameters": {"some_key": "some_value"},
}


def test_default_component_from_dict():
def custom_init(self, some_param):
self.some_param = some_param

extra_fields = {"__init__": custom_init}
MyComponent = factory.component_class("MyComponent", extra_fields=extra_fields)
comp = _default_component_from_dict(
MyComponent,
{
"type": "MyComponent",
"init_parameters": {
"some_param": 10,
},
"hash": 1234,
},
)
assert isinstance(comp, MyComponent)
assert comp.some_param == 10


def test_default_component_from_dict_without_type():
with pytest.raises(ComponentDeserializationError, match="Missing 'type' in component serialization data"):
_default_component_from_dict(Mock, {})


def test_default_component_from_dict_unregistered_component(request):
# We use the test function name as component name to make sure it's not registered.
# Since the registry is global we risk to have a component with the same name registered in another test.
component_name = request.node.name

with pytest.raises(
ComponentDeserializationError, match=f"Component '{component_name}' can't be deserialized as 'Mock'"
):
_default_component_from_dict(Mock, {"type": component_name})
comp = MockComponent()
assert comp.__canals_component__
28 changes: 6 additions & 22 deletions test/pipelines/unit/test_connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,17 +378,10 @@ def test_connect_many_outputs_to_the_same_input():


def test_connect_many_connections_possible_name_matches():
@component
class Component1:
@component.output_types(value=str)
def run(self, value: str):
return {"value": value}

@component
class Component2:
@component.output_types(value=str)
def run(self, value: str, othervalue: str, yetanothervalue: str):
return {"value": value}
Component1 = factory.component_class("Component1", output_types={"value": str})
Component2 = factory.component_class(
"Component2", input_types={"value": str, "othervalue": str, "yetanothervalue": str}
)

pipe = Pipeline()
pipe.add_component("c1", Component1())
Expand All @@ -398,17 +391,8 @@ def run(self, value: str, othervalue: str, yetanothervalue: str):


def test_connect_many_connections_possible_no_name_matches():
@component
class Component1:
@component.output_types(value=str)
def run(self, value: str):
return {"value": value}

@component
class Component2:
@component.output_types(value=str)
def run(self, value1: str, value2: str, value3: str):
return {"value": value1}
Component1 = factory.component_class("Component1", output_types={"value": str})
Component2 = factory.component_class("Component2", input_types={"value1": str, "value2": str, "value3": str})

expected_message = re.escape(
"""Cannot connect 'c1' with 'c2': more than one connection is possible between these components. Please specify the connection name, like: pipeline.connect('c1.value', 'c2.value1').
Expand Down

0 comments on commit a7141ff

Please sign in to comment.