Skip to content

Commit

Permalink
clean up of some old code. added a few more tests for tune coverage.
Browse files Browse the repository at this point in the history
  • Loading branch information
ncilfone committed Jan 21, 2022
1 parent 9169079 commit b639cfd
Show file tree
Hide file tree
Showing 13 changed files with 259 additions and 79 deletions.
2 changes: 1 addition & 1 deletion spock/addons/tune/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
ChoiceHyperParameter,
OptunaTunerConfig,
RangeHyperParameter,
spockTuner,
spockTuner
)

__all__ = [
Expand Down
39 changes: 39 additions & 0 deletions spock/addons/tune/__init__.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from typing import Any, Callable, Tuple, TypeVar, Union, overload

from attr import attrib, field

_T = TypeVar("_T")
_C = TypeVar("_C", bound=type)

# Note: from here
# https://github.com/python-attrs/attrs/blob/main/src/attr/__init__.pyi

# Static type inference support via __dataclass_transform__ implemented as per:
# https://github.com/microsoft/pyright/blob/main/specs/dataclass_transforms.md
# This annotation must be applied to all overloads of "spock_attr"

# NOTE: This is a typing construct and does not exist at runtime. Extensions
# wrapping attrs decorators should declare a separate __dataclass_transform__
# signature in the extension module using the specification linked above to
# provide pyright support -- this currently doesn't work in PyCharm
def __dataclass_transform__(
*,
eq_default: bool = True,
order_default: bool = False,
kw_only_default: bool = False,
field_descriptors: Tuple[Union[type, Callable[..., Any]], ...] = (()),
) -> Callable[[_T], _T]: ...
@overload
@__dataclass_transform__(kw_only_default=True, field_descriptors=(attrib, field))
def spockTuner(
maybe_cls: _C,
kw_only: bool = True,
make_init: bool = True,
) -> _C: ...
@overload
@__dataclass_transform__(kw_only_default=True, field_descriptors=(attrib, field))
def spockTuner(
maybe_cls: None = ...,
kw_only: bool = True,
make_init: bool = True,
) -> Callable[[_C], _C]: ...
14 changes: 14 additions & 0 deletions spock/addons/tune/config.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,17 @@ def _spock_tune(
kw_only: bool = True,
make_init: bool = True,
) -> Callable[[_C], _C]: ...
@overload
@__dataclass_transform__(kw_only_default=True, field_descriptors=(attrib, field))
def spockTuner(
maybe_cls: _C,
kw_only: bool = True,
make_init: bool = True,
) -> _C: ...
@overload
@__dataclass_transform__(kw_only_default=True, field_descriptors=(attrib, field))
def spockTuner(
maybe_cls: None = ...,
kw_only: bool = True,
make_init: bool = True,
) -> Callable[[_C], _C]: ...
26 changes: 10 additions & 16 deletions spock/addons/tune/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ def sample(self):
Spockspace of the current hyper-parameter draw
"""
pass

@abstractmethod
def _construct(self):
Expand All @@ -56,19 +55,16 @@ def _construct(self):
flat dictionary of all hyper-parameters named with dot notation (class.param_name)
"""
pass

@property
@abstractmethod
def _get_sample(self):
"""Gets the sample parameter dictionary from the underlying backend"""
pass

@property
@abstractmethod
def tuner_status(self):
"""Returns a dictionary of all the necessary underlying tuner internals to report the result"""
pass

@property
@abstractmethod
Expand Down Expand Up @@ -173,13 +169,16 @@ def _try_choice_cast(self, val, type_string: str):
try:
val.choices = [caster(v) for v in val.choices]
return val
except TypeError:
print(
f"Attempted to cast into type: {val.type} but failed -- check the inputs to {type_string}"
except Exception as e:
raise TypeError(
f"Attempted to cast into type: `{val.type}` but failed -- check the inputs to `{type_string}`: {e}"
)

def _try_range_cast(self, val, type_string: str):
"""Try/except for casting range parameters
"""Casting range parameters
Note that we don't need to try/except here as the range is already constrained to be a float/int which
will always be able to be cast into float/int
Args:
val: current attr val
Expand All @@ -191,11 +190,6 @@ def _try_range_cast(self, val, type_string: str):
"""
caster = self._get_caster(val)
try:
low = caster(val.bounds[0])
high = caster(val.bounds[1])
return low, high
except TypeError:
print(
f"Attempted to cast into type: {val.type} but failed -- check the inputs to {type_string}"
)
low = caster(val.bounds[0])
high = caster(val.bounds[1])
return low, high
20 changes: 4 additions & 16 deletions spock/addons/tune/payload.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,18 +29,6 @@ def __init__(self, s3_config=None):
"""
super().__init__(s3_config=s3_config)

def __call__(self, *args, **kwargs):
"""Call to allow self chaining
Args:
*args:
**kwArgs:
Returns:
Payload: instance of self
"""
return TunerPayload(*args, **kwargs)

@staticmethod
def _update_payload(base_payload, input_classes, ignore_classes, payload):
# Get basic args
Expand All @@ -56,12 +44,12 @@ def _update_payload(base_payload, input_classes, ignore_classes, payload):
# Check for incorrect specific override of global def
if k not in attr_fields:
raise TypeError(
f"Referring to a class space {k} that is undefined"
f"Referring to a class space `{k}` that is undefined"
)
for i_keys in v.keys():
if i_keys not in attr_fields[k]:
raise ValueError(
f"Provided an unknown argument named {k}.{i_keys}"
f"Provided an unknown argument named `{k}.{i_keys}`"
)
if k in payload and isinstance(v, dict):
payload[k].update(v)
Expand Down Expand Up @@ -92,8 +80,8 @@ def _handle_payload_override(payload, key, value):
curr_ref[split] = value
else:
raise ValueError(
f"cmd-line override failed for {key} -- "
f"Failed to find key {split} within lowest level Dict"
f"cmd-line override failed for `{key}` -- "
f"Failed to find key `{split}` within lowest level Dict"
)
# If it's not keep walking the current payload
else:
Expand Down
9 changes: 2 additions & 7 deletions spock/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from _warnings import warn

from spock.exceptions import _SpockDuplicateArgumentError
from spock.graph import Graph


Expand Down Expand Up @@ -114,7 +115,7 @@ def _attribute_name_to_config_name_mapping(
if self._is_duplicated_key(
attribute_name_to_config_name_mapping, attr.name, n.__name__
):
raise SpockDuplicateArgumentError(
raise _SpockDuplicateArgumentError(
f"`{attr.name}` key is located in more than one config and cannot be resolved automatically."
f"Either specify the config name (`<config>.{attr.name}`) or change the key name in the config."
)
Expand Down Expand Up @@ -187,9 +188,3 @@ def _clean_arguments(arguments: dict, general_arguments: dict):
if arg not in general_arguments:
clean_arguments[arg] = value
return clean_arguments


class SpockDuplicateArgumentError(Exception):
"""Custom exception type for duplicated values"""

pass
23 changes: 0 additions & 23 deletions spock/backend/typed.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,27 +431,6 @@ def _handle_optional_typing(typed):
return typed, optional


def _check_generic_recursive_single_type(typed):
"""Checks generics for the single types -- mixed types of generics are not allowed
DEPRECATED -- NOW SUPPORTS MIXED TYPES OF TUPLES
Args:
typed: type
Returns:
"""
# Check if it has __args__ to look for optionality as it is a GenericAlias
# if hasattr(typed, '__args__'):
# if len(set(typed.__args__)) > 1:
# type_list = [str(val) for val in typed.__args__]
# raise TypeError(f"Passing multiple different subscript types to GenericAlias is not supported: {type_list}")
# else:
# for val in typed.__args__:
# _check_generic_recursive_single_type(typed=val)
pass


def katra(typed, default=None):
"""Public interface to create a katra
Expand All @@ -470,8 +449,6 @@ def katra(typed, default=None):
"""
# Handle optionals
typed, optional = _handle_optional_typing(typed)
# Check generic types for consistent types
_check_generic_recursive_single_type(typed)
# We need to check if the type is a _GenericAlias so that we can handle subscripted general types
# If it is subscript typed it will not be T which python uses as a generic type name
if isinstance(typed, _GenericAlias) and (
Expand Down
22 changes: 8 additions & 14 deletions spock/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,12 +162,8 @@ def sample(self):
argument namespace(s) -- fixed + drawn sample from tuner backend
"""
if self._tune_obj is None:
raise ValueError(
f"Called sample method without passing any @spockTuner decorated classes"
)
if self._tuner_interface is None:
raise ValueError(
raise RuntimeError(
f"Called sample method without first calling the tuner method that initializes the "
f"backend library"
)
Expand All @@ -187,10 +183,12 @@ def tuner(self, tuner_config):
self so that functions can be chained
"""

if self._tune_obj is None:
raise ValueError(
raise RuntimeError(
f"Called tuner method without passing any @spockTuner decorated classes"
)

try:
from spock.addons.tune.tuner import TunerInterface

Expand All @@ -200,12 +198,8 @@ def tuner(self, tuner_config):
fixed_namespace=self._arg_namespace,
)
self._tuner_state = self._tuner_interface.sample()
except ImportError as e:
print(
"Missing libraries to support tune functionality. Please re-install with the extra tune "
"dependencies -- pip install spock-config[tune]."
f"Error: {e}"
)
except Exception as e:
raise e
return self

def _print_usage_and_exit(self, msg=None, sys_exit=True, exit_code=1):
Expand Down Expand Up @@ -500,7 +494,7 @@ def save(
if add_tuner_sample:
if self._tune_obj is None:
raise ValueError(
f"Called save method with add_tuner_sample as {add_tuner_sample} without passing any @spockTuner "
f"Called save method with add_tuner_sample as `{add_tuner_sample}` without passing any @spockTuner "
f"decorated classes -- please use the add_tuner_sample flag for saving only hyper-parameter tuning "
f"runs"
)
Expand Down Expand Up @@ -553,7 +547,7 @@ def save_best(
"""
if self._tune_obj is None:
raise ValueError(
f"Called save_best method without passing any @spockTuner decorated classes -- please use the save()"
f"Called save_best method without passing any @spockTuner decorated classes -- please use the `save()`"
f" method for saving non hyper-parameter tuning runs"
)
file_name = f"hp.best" if file_name is None else f"{file_name}.hp.best"
Expand Down
6 changes: 6 additions & 0 deletions spock/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,9 @@ class _SpockNotOptionalError(Exception):
"""Custom exception for missing value"""

pass


class _SpockDuplicateArgumentError(Exception):
"""Custom exception type for duplicated values"""

pass
52 changes: 52 additions & 0 deletions tests/base/test_addons.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# -*- coding: utf-8 -*-
import sys

import pytest

from spock.builder import ConfigArgBuilder
from tests.base.attr_configs_test import *
import datetime


class TestBasicBuilder:
"""Testing when builder is calling an add on functionality it shouldn't"""
def test_raise_tuner_sample(self, monkeypatch, tmp_path):
"""Test serialization/de-serialization"""
with monkeypatch.context() as m:
m.setattr(
sys, "argv", ["", "--config", "./tests/conf/yaml/test.yaml"]
)
# Serialize
config = ConfigArgBuilder(
*all_configs,
desc="Test Builder",
)
now = datetime.datetime.now()
curr_int_time = int(f"{now.year}{now.month}{now.day}{now.hour}{now.second}")
with pytest.raises(ValueError):
config_values = config.save(
file_extension=".yaml",
file_name=f"pytest.{curr_int_time}",
user_specified_path=tmp_path,
add_tuner_sample=True
)

def test_raise_save_best(self, monkeypatch, tmp_path):
"""Test serialization/de-serialization"""
with monkeypatch.context() as m:
m.setattr(
sys, "argv", ["", "--config", "./tests/conf/yaml/test.yaml"]
)
# Serialize
config = ConfigArgBuilder(
*all_configs,
desc="Test Builder",
)
now = datetime.datetime.now()
curr_int_time = int(f"{now.year}{now.month}{now.day}{now.hour}{now.second}")
with pytest.raises(ValueError):
config_values = config.save_best(
file_extension=".yaml",
file_name=f"pytest.{curr_int_time}",
user_specified_path=tmp_path,
)

0 comments on commit b639cfd

Please sign in to comment.