Skip to content

Commit

Permalink
Nested Classes -- Bug Fix (#141)
Browse files Browse the repository at this point in the history
* fixes issues wrt more than 2 levels of class nesting references. was ssuming to fall back on defaults instead of recursing the config space to set the correct parameters.

* linted
  • Loading branch information
ncilfone authored Oct 28, 2021
1 parent 7c9f9ae commit e9ca40b
Show file tree
Hide file tree
Showing 17 changed files with 281 additions and 32 deletions.
64 changes: 59 additions & 5 deletions spock/backend/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,34 +266,79 @@ def _handle_late_defaults(self, args, fields, input_class):
]
for val in names:
if val not in field_list:
# Gets the name of the class to default to
default_type_name = type(
getattr(input_class.__attrs_attrs__, val).default
).__name__
if default_type_name not in exclude_list:
# Gets the default class object
default_attr = getattr(input_class.__attrs_attrs__, val).default
# If the default is given for a class then it's the actual class and not a type -- logic needs
# to deal with both
if type(default_attr).__name__ == "type":
default_name = default_attr.__name__
else:
default_name = type(default_attr).__name__
# Skip if in the exclude list
else:
default_name = None
# if we need to fall back onto the default and ff it's in the arg_list then we have a
# definition coming in from the config file
if default_name is not None and default_name in arg_list:
# This handles lists of class type repeats -- these cannot be nested as the logic would be too
# confusing to map to
if isinstance(args.get(default_name), list):
default_value = [
self.input_classes[class_names.index(default_name)](
**arg_val
)
for arg_val in args.get(default_name)
]
# This handles basics and references to other classes -- here we need to recurse to grab any nested
# defs since classes are passed as strings to the config but are defined via Enums (handled #139)
else:
recurse_args = self._handle_recursive_defaults(
args.get(default_name), args, class_names
)
default_value = self.input_classes[
class_names.index(default_name)
](**args.get(default_name))
](**recurse_args)
fields.update({val: default_value})
return fields

def _handle_recursive_defaults(self, curr_arg, all_args, class_names):
"""Recurses through the args from the config read to determine if it can map to a definition
*Args*:
curr_arg: current argument
all_args: all argument dictionary
class_names: list of class names
*Returns*:
out_dict: recursively mapped dictionary of attributes
"""
out_dict = {}
for k, v in curr_arg.items():
# If the value is a reference to another class we need to recurse
if v in class_names:
# Recurse only if in the all_args dict (from the config file)
if v in all_args:
bubbled_dict = self._handle_recursive_defaults(
all_args.get(v), all_args, class_names
)
out_dict.update(
{k: self.input_classes[class_names.index(v)](**bubbled_dict)}
)
# Else fall back on default instantiation
else:
out_dict.update({k: self.input_classes[class_names.index(v)]()})
else:
out_dict.update({k: v})
return out_dict

def build_override_parsers(self, parser):
"""Creates parsers for command-line overrides
Expand Down Expand Up @@ -815,18 +860,27 @@ def _handle_nested_class(self, args, check_value, class_names):
"Match error -- multiple classes with the same name definition"
)
else:
if args.get(self.input_classes[match_idx[0]].__name__) is None:
if (args.get(self.input_classes[match_idx[0]].__name__) is None) and (
check_value not in class_names
):
raise ValueError(
f"Missing config file definition for the referenced class "
f"Cannot map a definition for the referenced class "
f"{self.input_classes[match_idx[0]].__name__}"
)
current_arg = args.get(self.input_classes[match_idx[0]].__name__)
current_arg = args.get(self.input_classes[match_idx[0]].__name__, {})
if isinstance(current_arg, list):
class_value = [
self.input_classes[match_idx[0]](**val) for val in current_arg
]
else:
class_value = self.input_classes[match_idx[0]](**current_arg)
recurse_args = (
self._handle_recursive_defaults(
args.get(check_value), args, class_names
)
if check_value in args
else {}
)
class_value = self.input_classes[match_idx[0]](**recurse_args)
return_value = class_value
# else return the expected value
else:
Expand Down
1 change: 1 addition & 0 deletions spock/backend/payload.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ def _payload(self, input_classes, ignore_classes, path, deps, root=False):
base_payload = self._supported_extensions.get(config_extension)().load(
path, s3_config=self._s3_config
)
base_payload = {} if base_payload is None else base_payload
# Check and? update the dependencies
deps = self._handle_dependencies(deps, path, root)
if "config" in base_payload:
Expand Down
35 changes: 35 additions & 0 deletions tests/base/attr_configs_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,37 @@ class ChoiceFail:
choice_p_str: StrChoice


@spock
class BaseDoubleNestedConfig:
morph_kernels_thickness: int = 1


@spock
class FirstDoubleNestedConfig(BaseDoubleNestedConfig):
h_factor: float = 0.95
v_factor: float = 0.95


@spock
class SecondDoubleNestedConfig(BaseDoubleNestedConfig):
morph_tolerance: float = 0.1


class DoubleNestedEnum(Enum):
first = FirstDoubleNestedConfig
second = SecondDoubleNestedConfig


@spock
class SingleNestedConfig:
"""Configuration for image based ops
Attributes:
kernel_config: MorphKernelConfig object
"""
double_nested_config: DoubleNestedEnum = SecondDoubleNestedConfig()


@spock
class TypeConfig:
"""This creates a test Spock config of all supported variable types as required parameters"""
Expand Down Expand Up @@ -111,6 +142,8 @@ class TypeConfig:
nested_list: List[NestedListStuff]
# Class Enum
class_enum: ClassChoice
# Double Nested class ref
high_config: SingleNestedConfig


@spock
Expand Down Expand Up @@ -206,6 +239,8 @@ class TypeDefaultConfig:
nested_list_def: List[NestedListStuff] = NestedListStuff
# Class Enum
class_enum_def: ClassChoice = NestedStuff
# Double Nested class ref
high_config_def: SingleNestedConfig = SingleNestedConfig()


@spock
Expand Down
16 changes: 15 additions & 1 deletion tests/base/base_asserts_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

# Copyright FMR LLC <opensource@fidelity.com>
# SPDX-License-Identifier: Apache-2.0

from tests.base.attr_configs_test import FirstDoubleNestedConfig, SecondDoubleNestedConfig

class AllTypes:
# Required #
Expand Down Expand Up @@ -41,6 +41,10 @@ def test_all_set(self, arg_builder):
assert arg_builder.TypeConfig.nested_list[1].two == "bye"
assert arg_builder.TypeConfig.class_enum.one == 11
assert arg_builder.TypeConfig.class_enum.two == "ciao"
assert isinstance(arg_builder.TypeConfig.high_config.double_nested_config, FirstDoubleNestedConfig) is True
assert arg_builder.TypeConfig.high_config.double_nested_config.h_factor == 0.99
assert arg_builder.TypeConfig.high_config.double_nested_config.v_factor == 0.90

# Optional #
assert arg_builder.TypeOptConfig.int_p_opt_no_def is None
assert arg_builder.TypeOptConfig.float_p_opt_no_def is None
Expand Down Expand Up @@ -92,6 +96,11 @@ def test_all_defaults(self, arg_builder):
assert arg_builder.TypeDefaultConfig.nested_list_def[1].two == "bye"
assert arg_builder.TypeDefaultConfig.class_enum_def.one == 11
assert arg_builder.TypeDefaultConfig.class_enum_def.two == "ciao"
assert isinstance(arg_builder.TypeDefaultConfig.high_config_def.double_nested_config,
FirstDoubleNestedConfig) is True
assert arg_builder.TypeDefaultConfig.high_config_def.double_nested_config.h_factor == 0.99
assert arg_builder.TypeDefaultConfig.high_config_def.double_nested_config.v_factor == 0.90

# Optional w/ Defaults #
assert arg_builder.TypeDefaultOptConfig.int_p_opt_def == 10
assert arg_builder.TypeDefaultOptConfig.float_p_opt_def == 10.0
Expand Down Expand Up @@ -163,6 +172,11 @@ def test_all_inherited(self, arg_builder):
assert arg_builder.TypeInherited.nested_list[1].two == "bye"
assert arg_builder.TypeInherited.class_enum.one == 11
assert arg_builder.TypeInherited.class_enum.two == "ciao"
assert isinstance(arg_builder.TypeInherited.high_config.double_nested_config,
FirstDoubleNestedConfig) is True
assert arg_builder.TypeInherited.high_config.double_nested_config.h_factor == 0.99
assert arg_builder.TypeInherited.high_config.double_nested_config.v_factor == 0.90

# Optional w/ Defaults #
assert arg_builder.TypeInherited.int_p_opt_def == 10
assert arg_builder.TypeInherited.float_p_opt_def == 10.0
Expand Down
26 changes: 22 additions & 4 deletions tests/base/test_cmd_line.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,17 @@ def arg_builder(monkeypatch):
"[11, 21]",
"--TypeConfig.nested_list.NestedListStuff.two",
"['Hooray', 'Working']",
"--TypeConfig.high_config",
"SingleNestedConfig",
"--SingleNestedConfig.double_nested_config",
"SecondDoubleNestedConfig",
"--SecondDoubleNestedConfig.morph_tolerance",
"0.2"
],
)
config = ConfigArgBuilder(
TypeConfig, NestedStuff, NestedListStuff, desc="Test Builder"
TypeConfig, NestedStuff, NestedListStuff, SingleNestedConfig,
FirstDoubleNestedConfig, SecondDoubleNestedConfig,desc="Test Builder"
)
return config.generate()

Expand Down Expand Up @@ -110,6 +117,8 @@ def test_class_overrides(self, arg_builder):
assert arg_builder.NestedListStuff[0].two == "Hooray"
assert arg_builder.NestedListStuff[1].one == 21
assert arg_builder.NestedListStuff[1].two == "Working"
assert isinstance(arg_builder.SingleNestedConfig.double_nested_config, SecondDoubleNestedConfig) is True
assert arg_builder.SecondDoubleNestedConfig.morph_tolerance == 0.2


class TestClassOnlyCmdLine:
Expand Down Expand Up @@ -177,10 +186,13 @@ def arg_builder(monkeypatch):
"[11, 21]",
"--TypeConfig.nested_list.NestedListStuff.two",
"['Hooray', 'Working']",
"--TypeConfig.high_config",
"SingleNestedConfig"
],
)
config = ConfigArgBuilder(
TypeConfig, NestedStuff, NestedListStuff, desc="Test Builder"
TypeConfig, NestedStuff, NestedListStuff, SingleNestedConfig,
FirstDoubleNestedConfig, SecondDoubleNestedConfig, desc="Test Builder"
)
return config.generate()

Expand Down Expand Up @@ -211,6 +223,10 @@ def test_class_overrides(self, arg_builder):
assert arg_builder.TypeConfig.list_choice_p_float == [20.0]
assert arg_builder.TypeConfig.class_enum.one == 12
assert arg_builder.TypeConfig.class_enum.two == "ancora"
assert isinstance(arg_builder.TypeConfig.high_config.double_nested_config,
SecondDoubleNestedConfig) is True
assert arg_builder.TypeConfig.high_config.double_nested_config.morph_kernels_thickness == 1
assert arg_builder.TypeConfig.high_config.double_nested_config.morph_tolerance == 0.1
assert arg_builder.NestedListStuff[0].one == 11
assert arg_builder.NestedListStuff[0].two == "Hooray"
assert arg_builder.NestedListStuff[1].one == 21
Expand All @@ -235,7 +251,8 @@ def test_cmd_line_no_key(self, monkeypatch):
)
with pytest.raises(SystemExit):
config = ConfigArgBuilder(
TypeConfig, NestedStuff, NestedListStuff, desc="Test Builder"
TypeConfig, NestedStuff, NestedListStuff, SingleNestedConfig,
FirstDoubleNestedConfig, SecondDoubleNestedConfig, desc="Test Builder"
)
return config.generate()

Expand All @@ -258,6 +275,7 @@ def test_cmd_line_list_len(self, monkeypatch):
)
with pytest.raises(ValueError):
config = ConfigArgBuilder(
TypeConfig, NestedStuff, NestedListStuff, desc="Test Builder"
TypeConfig, NestedStuff, NestedListStuff, SingleNestedConfig,
FirstDoubleNestedConfig, SecondDoubleNestedConfig, desc="Test Builder"
)
return config.generate()
Loading

0 comments on commit e9ca40b

Please sign in to comment.