-
Notifications
You must be signed in to change notification settings - Fork 13
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
13 changed files
with
263 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
scikit-learn |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
# Test conf for all hyper-parameters | ||
HPOne: | ||
hp_int: | ||
type: int | ||
bounds: [ 10, 100 ] | ||
log_scale: false | ||
hp_float: | ||
type: float | ||
bounds: [ 10.0, 100.0 ] | ||
log_scale: false | ||
hp_int_log: | ||
type: int | ||
bounds: [ 10, 100 ] | ||
log_scale: true | ||
hp_float_log: | ||
type: float | ||
bounds: [ 10.0, 100.0 ] | ||
log_scale: true | ||
|
||
HPTwo: | ||
hp_choice_int: | ||
type: int | ||
choices: ["hello", "ciao", "bonjour" ] | ||
hp_choice_float: | ||
type: float | ||
choices: [ 10.0, 20.0, 40.0, 80.0 ] | ||
hp_choice_bool: | ||
type: bool | ||
choices: [ true, false ] | ||
hp_choice_str: | ||
type: str | ||
choices: [ "hello", "ciao", "bonjour" ] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
# Test conf for all hyper-parameters | ||
HPOne: | ||
hp_int: | ||
type: int | ||
bounds: [ 10, 100 ] | ||
log_scale: false | ||
hp_float: | ||
type: float | ||
bounds: [ 10.0, 100.0 ] | ||
log_scale: false | ||
hp_int_log: | ||
type: int | ||
bounds: [ 'foo', 'bar' ] | ||
log_scale: true | ||
hp_float_log: | ||
type: float | ||
bounds: [ 10.0, 100.0 ] | ||
log_scale: true | ||
|
||
HPTwo: | ||
hp_choice_int: | ||
type: int | ||
choices: [10, 20, 40, 80 ] | ||
hp_choice_float: | ||
type: float | ||
choices: [ 10.0, 20.0, 40.0, 80.0 ] | ||
hp_choice_bool: | ||
type: bool | ||
choices: [ true, false ] | ||
hp_choice_str: | ||
type: str | ||
choices: [ "hello", "ciao", "bonjour" ] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
############################### | ||
# optuna simple sklearn example | ||
############################### | ||
|
||
LogisticRegressionHP: | ||
c: | ||
type: float | ||
bounds: [1E-07, 10.0] | ||
log_scale: true | ||
solver: | ||
type: str | ||
choices: ["lbfgs", "saga"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
# -*- coding: utf-8 -*- | ||
# -*- coding: utf-8 -*- | ||
from tests.tune.attr_configs_test import * | ||
import pytest | ||
import sys | ||
from spock.builder import ConfigArgBuilder | ||
from spock.addons.tune import OptunaTunerConfig | ||
|
||
|
||
class TestOptunaCmdLineOverride: | ||
@staticmethod | ||
@pytest.fixture | ||
def arg_builder(monkeypatch): | ||
with monkeypatch.context() as m: | ||
m.setattr(sys, 'argv', ['', '--config', | ||
'./tests/conf/yaml/test_hp.yaml', | ||
'--HPOne.hp_int.bounds', '(1, 1000)', | ||
'--HPOne.hp_int_log.bounds', '(1, 1000)', | ||
'--HPOne.hp_float.bounds', '(1.0, 1000.0)', | ||
'--HPOne.hp_float_log.bounds', '(1.0, 1000.0)', | ||
'--HPTwo.hp_choice_int.choices', '[1, 2, 4, 8]', | ||
'--HPTwo.hp_choice_float.choices', '[1.0, 2.0, 4.0, 8.0]', | ||
'--HPTwo.hp_choice_str.choices', "['is', 'it ', 'me', 'youre', 'looking', 'for']" | ||
]) | ||
optuna_config = OptunaTunerConfig(study_name="Tests", direction="maximize") | ||
config = ConfigArgBuilder(HPOne, HPTwo).tuner(optuna_config) | ||
return config | ||
|
||
def test_hp_one(self, arg_builder): | ||
assert arg_builder._tune_namespace.HPOne.hp_int.bounds == (1, 1000) | ||
assert arg_builder._tune_namespace.HPOne.hp_int.type == 'int' | ||
assert arg_builder._tune_namespace.HPOne.hp_int.log_scale is False | ||
assert arg_builder._tune_namespace.HPOne.hp_int_log.bounds == (1, 1000) | ||
assert arg_builder._tune_namespace.HPOne.hp_int_log.type == 'int' | ||
assert arg_builder._tune_namespace.HPOne.hp_int_log.log_scale is True | ||
assert arg_builder._tune_namespace.HPOne.hp_float.bounds == (1.0, 1000.0) | ||
assert arg_builder._tune_namespace.HPOne.hp_float.type == 'float' | ||
assert arg_builder._tune_namespace.HPOne.hp_float.log_scale is False | ||
assert arg_builder._tune_namespace.HPOne.hp_float_log.bounds == (1.0, 1000.0) | ||
assert arg_builder._tune_namespace.HPOne.hp_float_log.type == 'float' | ||
assert arg_builder._tune_namespace.HPOne.hp_float_log.log_scale is True | ||
|
||
def test_hp_two(self, arg_builder): | ||
assert arg_builder._tune_namespace.HPTwo.hp_choice_int.type == 'int' | ||
assert arg_builder._tune_namespace.HPTwo.hp_choice_int.choices == [1, 2, 4, 8] | ||
assert arg_builder._tune_namespace.HPTwo.hp_choice_float.type == 'float' | ||
assert arg_builder._tune_namespace.HPTwo.hp_choice_float.choices == [1.0, 2.0, 4.0, 8.0] | ||
assert arg_builder._tune_namespace.HPTwo.hp_choice_bool.type == 'bool' | ||
assert arg_builder._tune_namespace.HPTwo.hp_choice_bool.choices == [True, False] | ||
assert arg_builder._tune_namespace.HPTwo.hp_choice_str.type == 'str' | ||
assert arg_builder._tune_namespace.HPTwo.hp_choice_str.choices == ['is', 'it ', 'me', 'youre', 'looking', 'for'] | ||
|
||
def test_sampling(self, arg_builder): | ||
# Draw 100 random samples and make sure all fall within all of the bounds or sets | ||
for _ in range(100): | ||
hp_attrs = arg_builder.sample() | ||
assert 1 <= hp_attrs.HPOne.hp_int <= 1000 | ||
assert isinstance(hp_attrs.HPOne.hp_int, int) is True | ||
assert 1 <= hp_attrs.HPOne.hp_int_log <= 1000 | ||
assert isinstance(hp_attrs.HPOne.hp_int_log, int) is True | ||
assert 1.0 <= hp_attrs.HPOne.hp_float <= 1000.0 | ||
assert isinstance(hp_attrs.HPOne.hp_float, float) is True | ||
assert 1.0 <= hp_attrs.HPOne.hp_float_log <= 1000.0 | ||
assert isinstance(hp_attrs.HPOne.hp_float_log, float) is True | ||
assert hp_attrs.HPTwo.hp_choice_int in [1, 2, 4, 8] | ||
assert isinstance(hp_attrs.HPTwo.hp_choice_int, int) is True | ||
assert hp_attrs.HPTwo.hp_choice_float in [1.0, 2.0, 4.0, 8.0] | ||
assert isinstance(hp_attrs.HPTwo.hp_choice_float, float) is True | ||
assert hp_attrs.HPTwo.hp_choice_bool in [True, False] | ||
assert isinstance(hp_attrs.HPTwo.hp_choice_bool, bool) is True | ||
assert hp_attrs.HPTwo.hp_choice_str in ['is', 'it ', 'me', 'youre', 'looking', 'for'] | ||
assert isinstance(hp_attrs.HPTwo.hp_choice_str, str) is True |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
# -*- coding: utf-8 -*- | ||
from tests.tune.attr_configs_test import * | ||
import pytest | ||
import sys | ||
from spock.builder import ConfigArgBuilder | ||
import optuna | ||
|
||
|
||
class TestIncorrectTunerConfig: | ||
def test_incorrect_tuner_config(self, monkeypatch): | ||
with monkeypatch.context() as m: | ||
m.setattr(sys, 'argv', ['', '--config', | ||
'./tests/conf/yaml/test_hp.yaml']) | ||
optuna_config = optuna.create_study(study_name="Tests", direction='minimize') | ||
with pytest.raises(TypeError): | ||
config = ConfigArgBuilder(HPOne, HPTwo).tuner(optuna_config) | ||
|
||
|
||
class TestInvalidCastChoice: | ||
def test_invalid_cast_choice(self, monkeypatch): | ||
with monkeypatch.context() as m: | ||
m.setattr(sys, 'argv', ['', '--config', | ||
'./tests/conf/yaml/test_hp_cast.yaml']) | ||
optuna_config = optuna.create_study(study_name="Tests", direction='minimize') | ||
with pytest.raises(TypeError): | ||
config = ConfigArgBuilder(HPOne, HPTwo).tuner(optuna_config) | ||
|
||
|
||
class TestInvalidCastRange: | ||
def test_invalid_cast_range(self, monkeypatch): | ||
with monkeypatch.context() as m: | ||
m.setattr(sys, 'argv', ['', '--config', | ||
'./tests/conf/yaml/test_hp_cast_bounds.yaml']) | ||
optuna_config = optuna.create_study(study_name="Tests", direction='minimize') | ||
with pytest.raises(ValueError): | ||
config = ConfigArgBuilder(HPOne, HPTwo).tuner(optuna_config) |