Skip to content

Commit

Permalink
Merge b639cfd into cbbcf1a
Browse files Browse the repository at this point in the history
  • Loading branch information
ncilfone committed Jan 21, 2022
2 parents cbbcf1a + b639cfd commit de482aa
Show file tree
Hide file tree
Showing 14 changed files with 296 additions and 86 deletions.
13 changes: 12 additions & 1 deletion spock/__init__.pyi
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from typing import Any, Callable, Tuple, TypeVar, Union, overload
from typing import Any, Callable, Tuple, TypeVar, Union, overload, Optional, List, Type

from spock.builder import ConfigArgBuilder

from attr import attrib, field

Expand Down Expand Up @@ -39,3 +41,12 @@ def spock(
make_init: bool = True,
dynamic: bool = False,
) -> Callable[[_C], _C]: ...
def SpockBuilder(
*args: _C,
configs: Optional[List] = None,
desc: str = "",
lazy: bool = False,
no_cmd_line: bool = False,
s3_config: Optional[_C] = None,
**kwargs
) -> ConfigArgBuilder: ...
2 changes: 1 addition & 1 deletion spock/addons/tune/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
ChoiceHyperParameter,
OptunaTunerConfig,
RangeHyperParameter,
spockTuner,
spockTuner
)

__all__ = [
Expand Down
39 changes: 39 additions & 0 deletions spock/addons/tune/__init__.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from typing import Any, Callable, Tuple, TypeVar, Union, overload

from attr import attrib, field

_T = TypeVar("_T")
_C = TypeVar("_C", bound=type)

# Note: from here
# https://github.com/python-attrs/attrs/blob/main/src/attr/__init__.pyi

# Static type inference support via __dataclass_transform__ implemented as per:
# https://github.com/microsoft/pyright/blob/main/specs/dataclass_transforms.md
# This annotation must be applied to all overloads of "spock_attr"

# NOTE: This is a typing construct and does not exist at runtime. Extensions
# wrapping attrs decorators should declare a separate __dataclass_transform__
# signature in the extension module using the specification linked above to
# provide pyright support -- this currently doesn't work in PyCharm
def __dataclass_transform__(
*,
eq_default: bool = True,
order_default: bool = False,
kw_only_default: bool = False,
field_descriptors: Tuple[Union[type, Callable[..., Any]], ...] = (()),
) -> Callable[[_T], _T]: ...
@overload
@__dataclass_transform__(kw_only_default=True, field_descriptors=(attrib, field))
def spockTuner(
maybe_cls: _C,
kw_only: bool = True,
make_init: bool = True,
) -> _C: ...
@overload
@__dataclass_transform__(kw_only_default=True, field_descriptors=(attrib, field))
def spockTuner(
maybe_cls: None = ...,
kw_only: bool = True,
make_init: bool = True,
) -> Callable[[_C], _C]: ...
14 changes: 14 additions & 0 deletions spock/addons/tune/config.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,17 @@ def _spock_tune(
kw_only: bool = True,
make_init: bool = True,
) -> Callable[[_C], _C]: ...
@overload
@__dataclass_transform__(kw_only_default=True, field_descriptors=(attrib, field))
def spockTuner(
maybe_cls: _C,
kw_only: bool = True,
make_init: bool = True,
) -> _C: ...
@overload
@__dataclass_transform__(kw_only_default=True, field_descriptors=(attrib, field))
def spockTuner(
maybe_cls: None = ...,
kw_only: bool = True,
make_init: bool = True,
) -> Callable[[_C], _C]: ...
26 changes: 10 additions & 16 deletions spock/addons/tune/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ def sample(self):
Spockspace of the current hyper-parameter draw
"""
pass

@abstractmethod
def _construct(self):
Expand All @@ -56,19 +55,16 @@ def _construct(self):
flat dictionary of all hyper-parameters named with dot notation (class.param_name)
"""
pass

@property
@abstractmethod
def _get_sample(self):
"""Gets the sample parameter dictionary from the underlying backend"""
pass

@property
@abstractmethod
def tuner_status(self):
"""Returns a dictionary of all the necessary underlying tuner internals to report the result"""
pass

@property
@abstractmethod
Expand Down Expand Up @@ -173,13 +169,16 @@ def _try_choice_cast(self, val, type_string: str):
try:
val.choices = [caster(v) for v in val.choices]
return val
except TypeError:
print(
f"Attempted to cast into type: {val.type} but failed -- check the inputs to {type_string}"
except Exception as e:
raise TypeError(
f"Attempted to cast into type: `{val.type}` but failed -- check the inputs to `{type_string}`: {e}"
)

def _try_range_cast(self, val, type_string: str):
"""Try/except for casting range parameters
"""Casting range parameters
Note that we don't need to try/except here as the range is already constrained to be a float/int which
will always be able to be cast into float/int
Args:
val: current attr val
Expand All @@ -191,11 +190,6 @@ def _try_range_cast(self, val, type_string: str):
"""
caster = self._get_caster(val)
try:
low = caster(val.bounds[0])
high = caster(val.bounds[1])
return low, high
except TypeError:
print(
f"Attempted to cast into type: {val.type} but failed -- check the inputs to {type_string}"
)
low = caster(val.bounds[0])
high = caster(val.bounds[1])
return low, high
20 changes: 4 additions & 16 deletions spock/addons/tune/payload.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,18 +29,6 @@ def __init__(self, s3_config=None):
"""
super().__init__(s3_config=s3_config)

def __call__(self, *args, **kwargs):
"""Call to allow self chaining
Args:
*args:
**kwArgs:
Returns:
Payload: instance of self
"""
return TunerPayload(*args, **kwargs)

@staticmethod
def _update_payload(base_payload, input_classes, ignore_classes, payload):
# Get basic args
Expand All @@ -56,12 +44,12 @@ def _update_payload(base_payload, input_classes, ignore_classes, payload):
# Check for incorrect specific override of global def
if k not in attr_fields:
raise TypeError(
f"Referring to a class space {k} that is undefined"
f"Referring to a class space `{k}` that is undefined"
)
for i_keys in v.keys():
if i_keys not in attr_fields[k]:
raise ValueError(
f"Provided an unknown argument named {k}.{i_keys}"
f"Provided an unknown argument named `{k}.{i_keys}`"
)
if k in payload and isinstance(v, dict):
payload[k].update(v)
Expand Down Expand Up @@ -92,8 +80,8 @@ def _handle_payload_override(payload, key, value):
curr_ref[split] = value
else:
raise ValueError(
f"cmd-line override failed for {key} -- "
f"Failed to find key {split} within lowest level Dict"
f"cmd-line override failed for `{key}` -- "
f"Failed to find key `{split}` within lowest level Dict"
)
# If it's not keep walking the current payload
else:
Expand Down
9 changes: 2 additions & 7 deletions spock/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from _warnings import warn

from spock.exceptions import _SpockDuplicateArgumentError
from spock.graph import Graph


Expand Down Expand Up @@ -114,7 +115,7 @@ def _attribute_name_to_config_name_mapping(
if self._is_duplicated_key(
attribute_name_to_config_name_mapping, attr.name, n.__name__
):
raise SpockDuplicateArgumentError(
raise _SpockDuplicateArgumentError(
f"`{attr.name}` key is located in more than one config and cannot be resolved automatically."
f"Either specify the config name (`<config>.{attr.name}`) or change the key name in the config."
)
Expand Down Expand Up @@ -187,9 +188,3 @@ def _clean_arguments(arguments: dict, general_arguments: dict):
if arg not in general_arguments:
clean_arguments[arg] = value
return clean_arguments


class SpockDuplicateArgumentError(Exception):
"""Custom exception type for duplicated values"""

pass
54 changes: 25 additions & 29 deletions spock/backend/typed.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
"""Handles the definitions of arguments types for Spock (backend: attrs)"""

import sys
from enum import EnumMeta
from enum import Enum, EnumMeta
from functools import partial
from typing import TypeVar, Union

Expand Down Expand Up @@ -213,6 +213,24 @@ def _enum_katra(typed, default=None, optional=False):
return x


def _cast_enum_default(default):
"""Allows the enum default to be the specific value or the Enum structured value
Checks if enum type and extracts the value from the Enum
Args:
default: the default value to assign if given
Returns:
default value or the Enum extracted value
"""
if isinstance(default, Enum):
return default.value
else:
return default


def _enum_base_katra(typed, base_type, allowed, default=None, optional=False):
"""Private interface to create a base Enum typed katra
Expand All @@ -239,7 +257,7 @@ def _enum_base_katra(typed, base_type, allowed, default=None, optional=False):
attr.validators.instance_of(base_type),
attr.validators.in_(allowed),
],
default=default,
default=_cast_enum_default(default),
type=typed,
metadata={"base": typed.__name__},
)
Expand All @@ -248,7 +266,7 @@ def _enum_base_katra(typed, base_type, allowed, default=None, optional=False):
validator=attr.validators.optional(
[attr.validators.instance_of(base_type), attr.validators.in_(allowed)]
),
default=default,
default=_cast_enum_default(default),
type=typed,
metadata={"base": typed.__name__, "optional": True},
)
Expand All @@ -267,7 +285,8 @@ def _enum_base_katra(typed, base_type, allowed, default=None, optional=False):
def _in_type(instance, attribute, value, options):
"""attrs validator for class type enum
Checks if the type of the class (e.g. value) is in the specified set of types provided
Checks if the type of the class (e.g. value) is in the specified set of types provided. Also checks if the value
is specified via the Enum definition
Args:
instance: current object instance
Expand Down Expand Up @@ -304,14 +323,14 @@ def _enum_class_katra(typed, allowed, default=None, optional=False):
if default is not None:
x = attr.ib(
validator=[partial(_in_type, options=allowed)],
default=default,
default=_cast_enum_default(default),
type=typed,
metadata={"base": typed.__name__},
)
elif optional:
x = attr.ib(
validator=attr.validators.optional([partial(_in_type, options=allowed)]),
default=default,
default=_cast_enum_default(default),
type=typed,
metadata={"base": typed.__name__, "optional": True},
)
Expand Down Expand Up @@ -412,27 +431,6 @@ def _handle_optional_typing(typed):
return typed, optional


def _check_generic_recursive_single_type(typed):
"""Checks generics for the single types -- mixed types of generics are not allowed
DEPRECATED -- NOW SUPPORTS MIXED TYPES OF TUPLES
Args:
typed: type
Returns:
"""
# Check if it has __args__ to look for optionality as it is a GenericAlias
# if hasattr(typed, '__args__'):
# if len(set(typed.__args__)) > 1:
# type_list = [str(val) for val in typed.__args__]
# raise TypeError(f"Passing multiple different subscript types to GenericAlias is not supported: {type_list}")
# else:
# for val in typed.__args__:
# _check_generic_recursive_single_type(typed=val)
pass


def katra(typed, default=None):
"""Public interface to create a katra
Expand All @@ -451,8 +449,6 @@ def katra(typed, default=None):
"""
# Handle optionals
typed, optional = _handle_optional_typing(typed)
# Check generic types for consistent types
_check_generic_recursive_single_type(typed)
# We need to check if the type is a _GenericAlias so that we can handle subscripted general types
# If it is subscript typed it will not be T which python uses as a generic type name
if isinstance(typed, _GenericAlias) and (
Expand Down
Loading

0 comments on commit de482aa

Please sign in to comment.