Skip to content

Commit

Permalink
Merge 463b1cc into 95bdfdb
Browse files Browse the repository at this point in the history
  • Loading branch information
ncilfone authored Apr 6, 2021
2 parents 95bdfdb + 463b1cc commit a583b65
Show file tree
Hide file tree
Showing 16 changed files with 349 additions and 189 deletions.
2 changes: 1 addition & 1 deletion spock/backend/attr/payload.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def _update_payload(base_payload, input_classes, payload):
if isinstance(values, list):
# Check for incorrect specific override of global def
if keys not in attr_fields:
raise TypeError(f'Referring to a class space {keys} that is undefined')
raise ValueError(f'Referring to a class space {keys} that is undefined')
# We are in a repeated class def
# Raise if the key set is different from the defined set (i.e. incorrect arguments)
key_set = set(list(chain(*[list(val.keys()) for val in values])))
Expand Down
57 changes: 40 additions & 17 deletions spock/backend/attr/typed.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,21 @@ def __new__(cls, x):
return super().__new__(cls, x)


def _get_name_py_version(typed):
"""Gets the name of the type depending on the python version
*Args*:
typed: the type of the parameter
*Returns*:
name of the type
"""
return typed._name if hasattr(typed, '_name') else typed.__name__


def _extract_base_type(typed):
"""Extracts the name of the type from a _GenericAlias
Expand All @@ -43,10 +58,7 @@ def _extract_base_type(typed):
name of type
"""
if hasattr(typed, '__args__'):
if minor < 7:
name = typed.__name__
else:
name = typed._name
name = _get_name_py_version(typed=typed)
bracket_val = f"{name}[{_extract_base_type(typed.__args__[0])}]"
return bracket_val
else:
Expand Down Expand Up @@ -229,8 +241,6 @@ def _in_type(instance, attribute, value, options):
*Returns*:
"""
if type(options) not in [list, tuple, EnumMeta]:
raise TypeError(f'options argument must be of type List, Tuple, or Enum -- given {type(options)}')
if type(value) not in options:
raise ValueError(f'{attribute.name} must be in {options}')

Expand Down Expand Up @@ -294,10 +304,7 @@ def _type_katra(typed, default=None, optional=False):
if isinstance(typed, type):
name = typed.__name__
elif isinstance(typed, _GenericAlias):
if minor < 7:
name = typed.__name__
else:
name = typed._name
name = _get_name_py_version(typed=typed)
else:
raise TypeError('Encountered an uxpected type in _type_katra')
special_key = None
Expand Down Expand Up @@ -348,18 +355,32 @@ def _handle_optional_typing(typed):
type_args = typed.__args__
# Optional[X] has type_args = (X, None) and is equal to Union[X, None]
if (len(type_args) == 2) and (typed == Union[type_args[0], None]):
# Since this is true we need to strip out the OG type
# Grab all the types that are not NoneType and collapse to a list
type_list = [val for val in type_args if val is not type(None)]
if len(type_list) > 1:
raise TypeError(f"Passing multiple subscript types to GenericAlias is not supported: {type_list}")
else:
typed = type_list[0]
typed = type_args[0]
# Set the optional flag to true
optional = True
return typed, optional


def _check_generic_recursive_single_type(typed):
"""Checks generics for the single types -- mixed types of generics are not allowed
*Args*:
typed: type
*Returns*:
"""
# Check if it has __args__ to look for optionality as it is a GenericAlias
if hasattr(typed, '__args__'):
if len(set(typed.__args__)) > 1:
type_list = [str(val) for val in typed.__args__]
raise TypeError(f"Passing multiple different subscript types to GenericAlias is not supported: {type_list}")
else:
for val in typed.__args__:
_check_generic_recursive_single_type(typed=val)


def katra(typed, default=None):
"""Public interface to create a katra
Expand All @@ -380,6 +401,8 @@ def katra(typed, default=None):
"""
# Handle optionals
typed, optional = _handle_optional_typing(typed)
# Check generic types for consistent types
_check_generic_recursive_single_type(typed)
# We need to check if the type is a _GenericAlias so that we can handle subscripted general types
# If it is subscript typed it will not be T which python uses as a generic type name
if isinstance(typed, _GenericAlias) and (not isinstance(typed.__args__[0], TypeVar)):
Expand Down
44 changes: 2 additions & 42 deletions spock/backend/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,22 +416,6 @@ def _make_group_override_parser(self, parser, class_obj):
group_parser = make_argument(arg_name, val_type, group_parser)
return parser

@staticmethod
def _check_protected_keys(all_attr):
"""Test for protected keys
Tests to see if an attribute has been defined at the genreral level that is within the protected list that
would break basic command line handling.
Args:
all_attr: dictionary of all attr
"""
protected_names = ['config', 'help']
if any([val in all_attr for val in protected_names]):
raise ValueError(f"Using a protected name from {protected_names} at general class level which prevents "
f"command line overrides")

@staticmethod
def _get_from_arg_parser(desc):
"""Get configs from command line
Expand Down Expand Up @@ -472,7 +456,7 @@ def _get_from_kwargs(args, configs):
if type(configs).__name__ == 'list':
args.config.extend(configs)
else:
raise TypeError('configs kwarg must be of type list')
raise TypeError(f'configs kwarg must be of type list -- given {type(configs)}')
return args

@staticmethod
Expand All @@ -481,7 +465,7 @@ def _find_attribute_idx(newline_split_docs):
*Args*:
newline_split_docs:
newline_split_docs: new line split text
Returns:
Expand Down Expand Up @@ -943,27 +927,3 @@ def _handle_payload_override(payload, key, value):
else:
curr_ref = curr_ref[split]
return payload

@staticmethod
def _dict_payload_override(payload, dict_key, val_name, value):
"""Updates the payload at the dictionary level
First checks to see if there is an existing dictionary to insert into, if not creates an empty one. Then it
inserts the updated value at the correct dictionary level
*Args*:
payload: current payload dictionary
dict_key: dictionary key to check
val_name: value name to update
value: value to update
*Returns*:
payload: updated payload dictionary
"""
if dict_key not in payload:
payload.update({dict_key: {}})
payload[dict_key][val_name] = value
return payload
20 changes: 9 additions & 11 deletions spock/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from spock.backend.attr.builder import AttrBuilder
from spock.backend.attr.payload import AttrPayload
from spock.backend.attr.saver import AttrSaver
from spock.backend.base import Spockspace
from spock.utils import check_payload_overwrite
from spock.utils import deep_payload_update

Expand Down Expand Up @@ -61,22 +60,17 @@ def __call__(self, *args, **kwargs):
"""
return ConfigArgBuilder(*args, **kwargs)

def generate(self, unclass=False):
def generate(self):
"""Generate method that returns the actual argument namespace
*Args*:
unclass: swaps the backend attr class type for dictionaries
*Returns*:
argument namespace consisting of all config classes
"""
if unclass:
self._arg_namespace = Spockspace(**{k: Spockspace(**{
val.name: getattr(v, val.name) for val in v.__attrs_attrs__})
for k, v in self._arg_namespace.__dict__.items()})
return self._arg_namespace

@staticmethod
Expand All @@ -95,11 +89,15 @@ def _set_backend(args):
# Gather if all attr backend
type_attrs = all([attr.has(arg) for arg in args])
if not type_attrs:
raise TypeError("*args must be of all attrs backend")
elif type_attrs:
backend = {'builder': AttrBuilder, 'payload': AttrPayload, 'saver': AttrSaver}
which_idx = [attr.has(arg) for arg in args].index(False)
if hasattr(args[which_idx], '__name__'):
raise TypeError(f"*args must be of all attrs backend -- missing a @spock decorator on class "
f"{args[which_idx].__name__}")
else:
raise TypeError(f"*args must be of all attrs backend -- invalid type "
f"{type(args[which_idx])}")
else:
raise TypeError("*args must be of all attrs backend")
backend = {'builder': AttrBuilder, 'payload': AttrPayload, 'saver': AttrSaver}
return backend

def _get_config_paths(self):
Expand Down
4 changes: 2 additions & 2 deletions spock/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def add_repo_info(out_dict):
out_dict: output dictionary
"""
try:
try: # pragma: no cover
# Assume we are working out of a repo
repo = git.Repo(os.getcwd(), search_parent_directories=True)
# Check if we are really in a detached head state as later info will fail if we are
Expand All @@ -159,7 +159,7 @@ def add_repo_info(out_dict):
git_status = 'CLEAN'
out_dict.update({'# Git Status': git_status})
out_dict.update({'# Git Origin': repo.active_branch.commit.repo.remotes.origin.url})
except git.InvalidGitRepositoryError:
except git.InvalidGitRepositoryError: # pragma: no cover
# But it's okay if we are not
out_dict = make_blank_git(out_dict)
return out_dict
Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,11 @@ class StrChoice(Enum):
option_2 = 'option_2'


class FailedEnum(Enum):
str_type = 'hello'
float_type = 10.0


class IntChoice(Enum):
option_1 = 10
option_2 = 20
Expand Down
Loading

0 comments on commit a583b65

Please sign in to comment.