-
Notifications
You must be signed in to change notification settings - Fork 13
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
14 changed files
with
694 additions
and
115 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
# -*- coding: utf-8 -*- |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
# -*- coding: utf-8 -*- | ||
|
||
"""A simple example using sklearn and Ax support""" | ||
|
||
# Spock ONLY supports the service style API from Ax | ||
# https://ax.dev/docs/api.html | ||
|
||
|
||
from sklearn.datasets import load_iris | ||
from sklearn.linear_model import LogisticRegression | ||
from sklearn.model_selection import train_test_split | ||
|
||
from spock.addons.tune import ( | ||
AxTunerConfig, | ||
ChoiceHyperParameter, | ||
RangeHyperParameter, | ||
spockTuner, | ||
) | ||
from spock.builder import ConfigArgBuilder | ||
from spock.config import spock | ||
|
||
|
||
@spock | ||
class BasicParams: | ||
n_trials: int | ||
max_iter: int | ||
|
||
|
||
@spockTuner | ||
class LogisticRegressionHP: | ||
c: RangeHyperParameter | ||
solver: ChoiceHyperParameter | ||
|
||
|
||
def main(): | ||
# Load the iris data | ||
X, y = load_iris(return_X_y=True) | ||
|
||
# Split the Iris data | ||
X_train, X_valid, y_train, y_valid = train_test_split(X, y) | ||
|
||
# Ax config -- this will internally spawn the AxClient service API style which will be returned | ||
# by accessing the tuner_status property on the ConfigArgBuilder object | ||
ax_config = AxTunerConfig(objective_name="accuracy", minimize=False) | ||
|
||
# Use the builder to setup | ||
# Call tuner to indicate that we are going to do some HP tuning -- passing in an ax study object | ||
attrs_obj = ( | ||
ConfigArgBuilder( | ||
LogisticRegressionHP, | ||
BasicParams, | ||
desc="Example Logistic Regression Hyper-Parameter Tuning -- Ax Backend", | ||
) | ||
.tuner(tuner_config=ax_config) | ||
.save(user_specified_path="/tmp/ax") | ||
) | ||
|
||
# Here we need some of the fixed parameters first so we can just call the generate fnc to grab all the fixed params | ||
# prior to starting the sampling process | ||
fixed_params = attrs_obj.generate() | ||
|
||
# Now we iterate through a bunch of ax trials | ||
for _ in range(fixed_params.BasicParams.n_trials): | ||
# The crux of spock support -- call save w/ the add_tuner_sample flag to write the current draw to file and | ||
# then call sample to return the composed Spockspace of the fixed parameters and the sampled parameters | ||
# Under the hood spock uses the AxClient Ax interface -- thus it handled the underlying call to get the next | ||
# sample and returns the necessary AxClient object in the return dictionary to call 'complete_trial' with the | ||
# associated metrics | ||
hp_attrs = attrs_obj.save( | ||
add_tuner_sample=True, user_specified_path="/tmp/ax" | ||
).sample() | ||
# Use the currently sampled parameters in a simple LogisticRegression from sklearn | ||
clf = LogisticRegression( | ||
C=hp_attrs.LogisticRegressionHP.c, | ||
solver=hp_attrs.LogisticRegressionHP.solver, | ||
max_iter=hp_attrs.BasicParams.max_iter, | ||
) | ||
clf.fit(X_train, y_train) | ||
val_acc = clf.score(X_valid, y_valid) | ||
# Get the status of the tuner -- this dict will contain all the objects needed to update | ||
tuner_status = attrs_obj.tuner_status | ||
# Pull the AxClient object and trial index out of the return dictionary and call 'complete_trial' on the | ||
# AxClient object with the correct raw_data that contains the objective name | ||
tuner_status["client"].complete_trial( | ||
trial_index=tuner_status["trial_index"], | ||
raw_data={"accuracy": (val_acc, 0.0)}, | ||
) | ||
# Always save the current best set of hyper-parameters | ||
attrs_obj.save_best(user_specified_path="/tmp/ax") | ||
|
||
# Grab the best config and metric | ||
best_config, best_metric = attrs_obj.best | ||
print(f"Best HP Config:\n{best_config}") | ||
print(f"Best Metric: {best_metric}") | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
################ | ||
# tune.yaml | ||
################ | ||
BasicParams: | ||
n_trials: 10 | ||
max_iter: 150 | ||
|
||
LogisticRegressionHP: | ||
c: | ||
type: float | ||
bounds: [1E-07, 10.0] | ||
log_scale: true | ||
solver: | ||
type: str | ||
choices: ["lbfgs", "saga"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
mypy_extensions==0.4.3; python_version < '3.8' | ||
optuna==2.9.1 | ||
#torchvision | ||
#torch | ||
#ax-platform | ||
torchvision==0.9.1 | ||
torch==1.8.1 | ||
ax-platform==0.2.0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,159 @@ | ||
# -*- coding: utf-8 -*- | ||
|
||
# Copyright FMR LLC <opensource@fidelity.com> | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
"""Handles the ax backend""" | ||
|
||
from ax.service.ax_client import AxClient | ||
|
||
from spock.addons.tune.config import AxTunerConfig | ||
from spock.addons.tune.interface import BaseInterface | ||
|
||
try: | ||
from typing import TypedDict | ||
except ImportError: | ||
from mypy_extensions import TypedDict | ||
|
||
|
||
class AxTunerStatus(TypedDict): | ||
"""Tuner status return object for Ax -- supports the service style API from Ax | ||
*Attributes*: | ||
client: current AxClient instance | ||
trial_index: current trial index | ||
""" | ||
|
||
client: AxClient | ||
trial_index: int | ||
|
||
|
||
class AxInterface(BaseInterface): | ||
"""Specific override to support the Ax backend -- supports the service style API from Ax""" | ||
|
||
def __init__(self, tuner_config: AxTunerConfig, tuner_namespace): | ||
"""AxInterface init call that maps variables, creates a map to fnc calls, and constructs the necessary | ||
underlying objects | ||
*Args*: | ||
tuner_config: configuration object for the ax backend | ||
tuner_namespace: tuner namespace that has attr classes that maps to an underlying library types | ||
""" | ||
super(AxInterface, self).__init__(tuner_config, tuner_namespace) | ||
self._tuner_obj = AxClient( | ||
generation_strategy=self._tuner_config.generation_strategy, | ||
enforce_sequential_optimization=self._tuner_config.enforce_sequential_optimization, | ||
random_seed=self._tuner_config.random_seed, | ||
verbose_logging=self._tuner_config.verbose_logging, | ||
) | ||
# Some variables to use later | ||
self._trial_index = None | ||
self._sample_hash = None | ||
# Mapping spock underlying classes to ax distributions (search space) | ||
self._map_type = { | ||
"RangeHyperParameter": { | ||
"int": self._ax_range, | ||
"float": self._ax_range, | ||
}, | ||
"ChoiceHyperParameter": { | ||
"int": self._ax_choice, | ||
"float": self._ax_choice, | ||
"str": self._ax_choice, | ||
"bool": self._ax_choice, | ||
}, | ||
} | ||
# Build the correct underlying dictionary object for Ax client create experiment | ||
self._param_obj = self._construct() | ||
# Create the AxClient experiment | ||
self._tuner_obj.create_experiment( | ||
parameters=self._param_obj, | ||
name=self._tuner_config.name, | ||
objective_name=self._tuner_config.objective_name, | ||
minimize=self._tuner_config.minimize, | ||
parameter_constraints=self._tuner_config.parameter_constraints, | ||
outcome_constraints=self._tuner_config.outcome_constraints, | ||
overwrite_existing_experiment=self._tuner_config.overwrite_existing_experiment, | ||
tracking_metric_names=self._tuner_config.tracking_metric_names, | ||
immutable_search_space_and_opt_config=self._tuner_config.immutable_search_space_and_opt_config, | ||
is_test=self._tuner_config.is_test, | ||
) | ||
|
||
@property | ||
def tuner_status(self) -> AxTunerStatus: | ||
return AxTunerStatus(client=self._tuner_obj, trial_index=self._trial_index) | ||
|
||
@property | ||
def best(self): | ||
best_obj = self._tuner_obj.get_best_parameters() | ||
rollup_dict, _ = self._sample_rollup(best_obj[0]) | ||
return ( | ||
self._gen_spockspace(rollup_dict), | ||
best_obj[1][0][self._tuner_obj.objective_name], | ||
) | ||
|
||
@property | ||
def _get_sample(self): | ||
return self._tuner_obj.get_next_trial() | ||
|
||
def sample(self): | ||
parameters, self._trial_index = self._get_sample | ||
# Roll this back out into a Spockspace so it can be merged into the fixed parameter Spockspace | ||
# Also need to un-dot the param names to rebuild the nested structure | ||
rollup_dict, sample_hash = self._sample_rollup(parameters) | ||
self._sample_hash = sample_hash | ||
return self._gen_spockspace(rollup_dict) | ||
|
||
def _construct(self): | ||
param_list = [] | ||
# These will only be nested one level deep given the tuner syntax | ||
for k, v in vars(self._tuner_namespace).items(): | ||
for ik, iv in vars(v).items(): | ||
param_fn = self._map_type[type(iv).__name__][iv.type] | ||
param_list.append(param_fn(name=f"{k}.{ik}", val=iv)) | ||
return param_list | ||
|
||
def _ax_range(self, name, val): | ||
"""Assemble the dictionary for ax range parameters | ||
*Args*: | ||
name: parameter name | ||
val: current attr val | ||
*Returns*: | ||
dictionary that can be added to a parameter list | ||
""" | ||
low, high = self._try_range_cast(val, type_string="RangeHyperParameter") | ||
return { | ||
"name": name, | ||
"type": "range", | ||
"bounds": [low, high], | ||
"value_type": val.type, | ||
"log_scale": val.log_scale, | ||
} | ||
|
||
def _ax_choice(self, name, val): | ||
"""Assemble the dictionary for ax choice parameters | ||
*Args*: | ||
name: parameter name | ||
val: current attr val | ||
*Returns*: | ||
dictionary that can be added to a parameter list | ||
""" | ||
val = self._try_choice_cast(val, type_string="ChoiceHyperParameter") | ||
return { | ||
"name": name, | ||
"type": "choice", | ||
"values": val.choices, | ||
"value_type": val.type, | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.