Skip to content

Commit

Permalink
Fix missing classes in spockspace (#201)
Browse files Browse the repository at this point in the history
* working implementation that catches nested config class defs and enums.

* fixed bug in enum of class that is optional leading to NoneType in dict. Added unit test to check that serialization/de-serialization works correctly

* linted
  • Loading branch information
ncilfone committed Jan 10, 2022
1 parent d8eee88 commit 14aa1ab
Show file tree
Hide file tree
Showing 13 changed files with 174 additions and 245 deletions.
46 changes: 40 additions & 6 deletions spock/backend/field_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,31 +60,48 @@ def __call__(self, attr_space: AttributeSpace, builder_space: BuilderSpace):
Returns:
"""
if self._is_attribute_in_config_arguments(attr_space, builder_space.arguments):
if self._is_attribute_in_config_arguments(attr_space, builder_space):
self.handle_attribute_from_config(attr_space, builder_space)
elif self._is_attribute_optional(attr_space.attribute):
if isinstance(attr_space.attribute.default, type):
self.handle_optional_attribute_type(attr_space, builder_space)
else:
self.handle_optional_attribute_value(attr_space, builder_space)

@staticmethod
def _is_attribute_in_config_arguments(
attr_space: AttributeSpace, arguments: SpockArguments
self, attr_space: AttributeSpace, builder_space: BuilderSpace
):
"""Checks if an attribute is in the configuration file or keyword arguments dictionary
Will recurse spock classes as dependencies might be defined in the configs class
Args:
attr_space: holds information about a single attribute that is mapped to a ConfigSpace
arguments: map of the read/cmd-line parameter dictionary to general or class level arguments
builder_space: map of the read/cmd-line parameter dictionary to general or class level arguments
Returns:
boolean if in dictionary
"""
# Instances might have other instances that might be defined in the configs
# Recurse to try and catch all config defs
# Only map if default is not None -- do so by evolving the attribute
if (
_is_spock_instance(attr_space.attribute.type)
and attr_space.attribute.default is not None
):
attr_space.field, special_keys = RegisterSpockCls().recurse_generate(
attr_space.attribute.type, builder_space
)
attr_space.attribute = attr_space.attribute.evolve(default=attr_space.field)
builder_space.spock_space[
attr_space.attribute.type.__name__
] = attr_space.field
self.special_keys.update(special_keys)
return (
attr_space.config_space.name in arguments
and attr_space.attribute.name in arguments[attr_space.config_space.name]
attr_space.config_space.name in builder_space.arguments
and attr_space.attribute.name
in builder_space.arguments[attr_space.config_space.name]
)

@staticmethod
Expand Down Expand Up @@ -276,6 +293,23 @@ def handle_optional_attribute_type(
attr_space.attribute.default, attr_space, builder_space
)

def handle_optional_attribute_value(
self, attr_space: AttributeSpace, builder_space: BuilderSpace
):
"""Handles setting an optional value with its default
Args:
attr_space: holds information about a single attribute that is mapped to a ConfigSpace
builder_space: named_tuple containing the arguments and spock_space
Returns:
"""
super().handle_optional_attribute_value(attr_space, builder_space)
if attr_space.field is not None:
builder_space.spock_space[
type(attr_space.field).__name__
] = attr_space.field

def _handle_and_register_enum(
self, enum_cls, attr_space: AttributeSpace, builder_space: BuilderSpace
):
Expand Down
2 changes: 2 additions & 0 deletions spock/backend/saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,8 @@ def _clean_up_values(self, payload):
)
# Convert values
clean_dict = self._clean_output(out_dict)
# Clip any empty dictionaries
clean_dict = {k: v for k, v in clean_dict.items() if len(v) > 0}
return clean_dict

def _clean_tuner_values(self, payload):
Expand Down
5 changes: 5 additions & 0 deletions spock/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,11 @@ def _build(self):
for input_class in self._input_classes:
dep_classes = self._find_all_spock_classes(input_class)
for v in dep_classes:
if v not in nodes:
raise ValueError(
f"Missing @spock decorated class -- `{v.__name__}` was not passed as an *arg to "
f"ConfigArgBuilder"
)
nodes.get(v).append(input_class)
nodes = {key: set(val) for key, val in nodes.items()}
return nodes
Expand Down
20 changes: 19 additions & 1 deletion tests/base/attr_configs_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,12 @@ class NestedStuff:
two: str


@spock
class NestedStuffOpt:
one: int = 1
two: str = 'boo'


@spock
class NestedListStuff:
one: int
Expand Down Expand Up @@ -186,7 +192,7 @@ class TypeOptConfig:
# Required list of list of choice -- Str
list_list_choice_p_opt_no_def_str: Optional[List[List[StrChoice]]]
# Nested configuration
nested_opt_no_def: Optional[NestedStuff]
nested_opt_no_def: Optional[NestedStuffOpt]
# Nested list configuration
nested_list_opt_no_def: Optional[List[NestedListStuff]]
# Class Enum
Expand Down Expand Up @@ -302,3 +308,15 @@ class TypeInherited(TypeConfig, TypeDefaultOptConfig):
"""This tests inheritance with mixed default and non-default arguments"""

...


all_configs = [
TypeConfig,
NestedStuff,
NestedListStuff,
TypeOptConfig,
SingleNestedConfig,
FirstDoubleNestedConfig,
SecondDoubleNestedConfig,
NestedStuffOpt
]
7 changes: 2 additions & 5 deletions tests/base/base_asserts_test.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
# -*- coding: utf-8 -*-

# Copyright FMR LLC <opensource@fidelity.com>
# SPDX-License-Identifier: Apache-2.0
from tests.base.attr_configs_test import (
FirstDoubleNestedConfig,
SecondDoubleNestedConfig,
)

from tests.base.attr_configs_test import FirstDoubleNestedConfig


class AllTypes:
Expand Down
83 changes: 25 additions & 58 deletions tests/base/test_config_arg_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,15 @@ class TestBasic(AllTypes):
def arg_builder(monkeypatch):
with monkeypatch.context() as m:
m.setattr(sys, "argv", ["", "--config", "./tests/conf/yaml/test.yaml"])
config = ConfigArgBuilder(
TypeConfig, NestedStuff, NestedListStuff, TypeOptConfig, SingleNestedConfig,
FirstDoubleNestedConfig, SecondDoubleNestedConfig
)
config = ConfigArgBuilder(*all_configs)
return config.generate()


class TestConfigDict:
def test_config_2_dict(self, monkeypatch):
with monkeypatch.context() as m:
m.setattr(sys, "argv", ["", "--config", "./tests/conf/yaml/test.yaml"])
config_dict = ConfigArgBuilder(
TypeConfig, NestedStuff, NestedListStuff, TypeOptConfig, SingleNestedConfig,
FirstDoubleNestedConfig, SecondDoubleNestedConfig
).config_2_dict
config_dict = ConfigArgBuilder(*all_configs).config_2_dict
assert isinstance(config_dict, dict) is True


Expand All @@ -42,13 +36,7 @@ class TestNoCmdLineKwarg(AllTypes):
def arg_builder(monkeypatch):
with monkeypatch.context() as m:
config = ConfigArgBuilder(
TypeConfig,
NestedStuff,
NestedListStuff,
TypeOptConfig,
SingleNestedConfig,
FirstDoubleNestedConfig,
SecondDoubleNestedConfig,
*all_configs,
no_cmd_line=True,
configs=["./tests/conf/yaml/test.yaml"],
)
Expand All @@ -62,13 +50,7 @@ def test_cmd_line_kwarg_raise(self, monkeypatch):
with monkeypatch.context() as m:
with pytest.raises(TypeError):
config = ConfigArgBuilder(
TypeConfig,
NestedStuff,
NestedListStuff,
TypeOptConfig,
SingleNestedConfig,
FirstDoubleNestedConfig,
SecondDoubleNestedConfig,
*all_configs,
no_cmd_line=True,
configs="./tests/conf/yaml/test.yaml",
)
Expand All @@ -82,13 +64,7 @@ def test_choice_raise(self, monkeypatch):
with monkeypatch.context() as m:
with pytest.raises(ValueError):
ConfigArgBuilder(
TypeConfig,
NestedStuff,
NestedListStuff,
TypeOptConfig,
SingleNestedConfig,
FirstDoubleNestedConfig,
SecondDoubleNestedConfig,
*all_configs,
no_cmd_line=True,
)

Expand All @@ -102,13 +78,7 @@ def arg_builder(monkeypatch):
with monkeypatch.context() as m:
m.setattr(sys, "argv", [""])
config = ConfigArgBuilder(
TypeConfig,
NestedStuff,
NestedListStuff,
TypeOptConfig,
SingleNestedConfig,
FirstDoubleNestedConfig,
SecondDoubleNestedConfig,
*all_configs,
desc="Test Builder",
configs=["./tests/conf/yaml/test.yaml"],
)
Expand All @@ -134,19 +104,26 @@ class AttrFail:
failed_attr: int

config = ConfigArgBuilder(
TypeConfig,
NestedStuff,
NestedListStuff,
TypeOptConfig,
SingleNestedConfig,
FirstDoubleNestedConfig,
SecondDoubleNestedConfig,
*all_configs,
AttrFail,
configs=["./tests/conf/yaml/test.yaml"],
)
return config.generate()


class TestRaisesMissingClass:
"""Testing basic functionality"""

def test_raises_missing_class(self, monkeypatch):
with monkeypatch.context() as m:
m.setattr(sys, "argv", ["", "--config", "./tests/conf/yaml/test.yaml"])
with pytest.raises(ValueError):
config = ConfigArgBuilder(
*all_configs[:-1]
)
return config.generate()


class TestRaiseWrongInputType:
"""Check all required types work as expected"""

Expand All @@ -155,13 +132,7 @@ def test_wrong_input_raise(self, monkeypatch):
m.setattr(sys, "argv", ["", "--config", "./tests/conf/yaml/test.foo"])
with pytest.raises(TypeError):
config = ConfigArgBuilder(
TypeConfig,
NestedStuff,
NestedListStuff,
TypeOptConfig,
SingleNestedConfig,
FirstDoubleNestedConfig,
SecondDoubleNestedConfig,
*all_configs,
desc="Test Builder",
)
return config.generate()
Expand All @@ -175,8 +146,7 @@ def test_type_unknown(self, monkeypatch):
)
with pytest.raises(ValueError):
ConfigArgBuilder(
TypeConfig, NestedStuff, NestedListStuff, SingleNestedConfig,
FirstDoubleNestedConfig, SecondDoubleNestedConfig, desc="Test Builder"
*all_configs, desc="Test Builder"
)


Expand All @@ -190,8 +160,7 @@ def test_class_parameter_unknown(self, monkeypatch):
)
with pytest.raises(ValueError):
ConfigArgBuilder(
TypeConfig, NestedStuff, NestedListStuff, SingleNestedConfig,
FirstDoubleNestedConfig, SecondDoubleNestedConfig, desc="Test Builder"
*all_configs, desc="Test Builder"
)


Expand All @@ -205,8 +174,7 @@ def test_class_unknown(self, monkeypatch):
)
with pytest.raises(TypeError):
ConfigArgBuilder(
TypeConfig, NestedStuff, NestedListStuff, SingleNestedConfig,
FirstDoubleNestedConfig, SecondDoubleNestedConfig, desc="Test Builder"
*all_configs, desc="Test Builder"
)


Expand All @@ -224,6 +192,5 @@ def test_class_unknown(self, monkeypatch):
)
with pytest.raises(ValueError):
ConfigArgBuilder(
TypeConfig, NestedStuff, NestedListStuff, SingleNestedConfig,
FirstDoubleNestedConfig, SecondDoubleNestedConfig, desc="Test Builder"
*all_configs, desc="Test Builder"
)
Loading

0 comments on commit 14aa1ab

Please sign in to comment.