Skip to content

Commit

Permalink
more test coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
ncilfone committed Jul 16, 2021
1 parent 6f4b2e5 commit 54b3f1c
Show file tree
Hide file tree
Showing 5 changed files with 82 additions and 18 deletions.
35 changes: 27 additions & 8 deletions spock/addons/tune/payload.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
"""Handles the tuner payload backend"""

from spock.backend.payload import BasePayload
from spock.backend.utils import get_attr_fields


class TunerPayload(BasePayload):
Expand Down Expand Up @@ -47,17 +48,35 @@ def __call__(self, *args, **kwargs):

@staticmethod
def _update_payload(base_payload, input_classes, ignore_classes, payload):
# Get basic args
attr_fields = get_attr_fields(input_classes=input_classes)
# Get the ignore fields
ignore_fields = {
attr.__name__: [val.name for val in attr.__attrs_attrs__]
for attr in ignore_classes
}
ignore_fields = get_attr_fields(input_classes=ignore_classes)
for k, v in base_payload.items():
if k not in ignore_fields:
for ik, iv in v.items():
if "bounds" in iv:
iv["bounds"] = tuple(iv["bounds"])
return base_payload
if k != "config":
# Dict infers that we are overriding a global setting in a specific config
if isinstance(v, dict):
# we're in a namespace
# Check for incorrect specific override of global def
if k not in attr_fields:
raise TypeError(
f"Referring to a class space {k} that is undefined"
)
for i_keys in v.keys():
if i_keys not in attr_fields[k]:
raise ValueError(
f"Provided an unknown argument named {k}.{i_keys}"
)
if k in payload and isinstance(v, dict):
payload[k].update(v)
else:
payload[k] = v
# Handle tuple conversion here -- lazily
for ik, iv in v.items():
if "bounds" in iv:
iv["bounds"] = tuple(iv["bounds"])
return payload

@staticmethod
def _handle_payload_override(payload, key, value):
Expand Down
12 changes: 3 additions & 9 deletions spock/backend/payload.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from pathlib import Path

from spock.backend.handler import BaseHandler
from spock.backend.utils import convert_to_tuples, deep_update, get_type_fields
from spock.backend.utils import convert_to_tuples, deep_update, get_type_fields, get_attr_fields
from spock.utils import check_path_s3


Expand Down Expand Up @@ -307,15 +307,9 @@ def __call__(self, *args, **kwargs):
@staticmethod
def _update_payload(base_payload, input_classes, ignore_classes, payload):
# Get basic args
attr_fields = {
attr.__name__: [val.name for val in attr.__attrs_attrs__]
for attr in input_classes
}
attr_fields = get_attr_fields(input_classes=input_classes)
# Get the ignore fields
ignore_fields = {
attr.__name__: [val.name for val in attr.__attrs_attrs__]
for attr in ignore_classes
}
ignore_fields = get_attr_fields(input_classes=ignore_classes)
# Class names
class_names = [val.__name__ for val in input_classes]
# Parse out the types if generic
Expand Down
18 changes: 18 additions & 0 deletions spock/backend/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,24 @@
"""Attr utility functions for Spock"""


def get_attr_fields(input_classes):
"""Gets the attribute fields from all classes
*Args*:
input_classes: current list of input classes
*Returns*:
dictionary of all attrs attribute fields
"""
return {
attr.__name__: [val.name for val in attr.__attrs_attrs__]
for attr in input_classes
}


def get_type_fields(input_classes):
"""Creates a dictionary of names and types
Expand Down
8 changes: 8 additions & 0 deletions tests/conf/yaml/test_hp_compose.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
config: [test_hp.yaml]

# Test conf for all hyper-parameters
HPOne:
hp_int:
type: int
bounds: [ 20, 200 ]
log_scale: false
27 changes: 26 additions & 1 deletion tests/tune/test_optuna.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,31 @@ def arg_builder(monkeypatch):
return config


class TestOptunaCompose(AllTypes):
@staticmethod
@pytest.fixture
def arg_builder(monkeypatch):
with monkeypatch.context() as m:
m.setattr(sys, 'argv', ['', '--config',
'./tests/conf/yaml/test_hp_compose.yaml'])
optuna_config = OptunaTunerConfig(study_name="Basic Tests", direction="maximize")
config = ConfigArgBuilder(HPOne, HPTwo).tuner(optuna_config)
return config

def test_hp_one(self, arg_builder):
assert arg_builder._tune_namespace.HPOne.hp_int.bounds == (20, 200)
assert arg_builder._tune_namespace.HPOne.hp_int.type == 'int'
assert arg_builder._tune_namespace.HPOne.hp_int.log_scale is False
assert arg_builder._tune_namespace.HPOne.hp_int_log.bounds == (10, 100)
assert arg_builder._tune_namespace.HPOne.hp_int_log.type == 'int'
assert arg_builder._tune_namespace.HPOne.hp_int_log.log_scale is True
assert arg_builder._tune_namespace.HPOne.hp_float.bounds == (10.0, 100.0)
assert arg_builder._tune_namespace.HPOne.hp_float.type == 'float'
assert arg_builder._tune_namespace.HPOne.hp_float.log_scale is False
assert arg_builder._tune_namespace.HPOne.hp_float_log.bounds == (10.0, 100.0)
assert arg_builder._tune_namespace.HPOne.hp_float_log.type == 'float'
assert arg_builder._tune_namespace.HPOne.hp_float_log.log_scale is True

class TestOptunaSample(SampleTypes):
@staticmethod
@pytest.fixture
Expand Down Expand Up @@ -81,7 +106,7 @@ def test_iris(self, arg_builder):
# Pull the study and trials object out of the return dictionary and pass it to the tell call using the study
# object
tuner_status["study"].tell(tuner_status["trial"], val_acc)

# Verify the sample was written out to file
yaml_regex = re.compile(fr'pytest.{curr_int_time}.hp.sample.[0-9]+.'
fr'[a-fA-F0-9]{{8}}-[a-fA-F0-9]{{4}}-[a-fA-F0-9]{{4}}-'
fr'[a-fA-F0-9]{{4}}-[a-fA-F0-9]{{12}}.spock.cfg.yaml')
Expand Down

0 comments on commit 54b3f1c

Please sign in to comment.