
Add Ax Backend Support -- Hyperparameter Tuning (#83)
* adding in Ax backend support. some refactors of Optuna backend due to overlapping functionality

* update example

* linters

* adding ax tests

* rounding out comparable unit tests for ax backend

* docs for ax
ncilfone committed Aug 17, 2021
1 parent b95a587 commit 7c51cb7
Showing 26 changed files with 824 additions and 134 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/python-docs.yaml
@@ -22,7 +22,7 @@ jobs:
      - uses: actions/cache@v2
        with:
          path: ${{ env.pythonLocation }}
-          key: ${{ env.pythonLocation }}-${{ hashFiles('setup.py') }}-${{ hashFiles('REQUIREMENTS.txt') }}-${{ hashFiles('./requirements/DEV_REQUIREMENTS.txt') }}-${{ hashFiles('./requirements/S3_REQUIREMENTS.txt') }}
+          key: ${{ env.pythonLocation }}-${{ hashFiles('setup.py') }}-${{ hashFiles('REQUIREMENTS.txt') }}-${{ hashFiles('./requirements/DEV_REQUIREMENTS.txt') }}-${{ hashFiles('./requirements/S3_REQUIREMENTS.txt') }}-${{ hashFiles('./requirements/TUNE_REQUIREMENTS.txt') }}-${{ hashFiles('./requirements/TEST_EXTRAS_REQUIREMENTS.txt') }}

      - name: Install dependencies and dev dependencies
        run: |
4 changes: 3 additions & 1 deletion .github/workflows/python-lint.yaml
@@ -24,14 +24,16 @@ jobs:
      - uses: actions/cache@v2
        with:
          path: ${{ env.pythonLocation }}
-          key: ${{ env.pythonLocation }}-${{ hashFiles('setup.py') }}-${{ hashFiles('REQUIREMENTS.txt') }}-${{ hashFiles('./requirements/DEV_REQUIREMENTS.txt') }}-${{ hashFiles('./requirements/S3_REQUIREMENTS.txt') }}
+          key: ${{ env.pythonLocation }}-${{ hashFiles('setup.py') }}-${{ hashFiles('REQUIREMENTS.txt') }}-${{ hashFiles('./requirements/DEV_REQUIREMENTS.txt') }}-${{ hashFiles('./requirements/S3_REQUIREMENTS.txt') }}-${{ hashFiles('./requirements/TUNE_REQUIREMENTS.txt') }}-${{ hashFiles('./requirements/TEST_EXTRAS_REQUIREMENTS.txt') }}

      - name: Install dependencies and dev dependencies
        run: |
          python -m pip install --upgrade pip
          pip install -e .
          pip install -r ./requirements/DEV_REQUIREMENTS.txt
          pip install -r ./requirements/S3_REQUIREMENTS.txt
+          pip install -r ./requirements/TUNE_REQUIREMENTS.txt
+          pip install -r ./requirements/TEST_EXTRAS_REQUIREMENTS.txt
      - name: Run isort linter
        run: |
2 changes: 1 addition & 1 deletion .github/workflows/python-pytest-tune.yaml
@@ -14,7 +14,7 @@ jobs:
    runs-on: ubuntu-latest
    strategy:
      matrix:
-        python-version: [3.6, 3.7, 3.8, 3.9]
+        python-version: [3.7, 3.8, 3.9]

    steps:
      - uses: actions/checkout@v2
2 changes: 1 addition & 1 deletion .github/workflows/python-pytest.yml
@@ -26,7 +26,7 @@ jobs:
      - uses: actions/cache@v2
        with:
          path: ${{ env.pythonLocation }}
-          key: ${{ env.pythonLocation }}-${{ hashFiles('setup.py') }}-${{ hashFiles('REQUIREMENTS.txt') }}-${{ hashFiles('./requirements/DEV_REQUIREMENTS.txt') }}-${{ hashFiles('./requirements/S3_REQUIREMENTS.txt') }}
+          key: ${{ env.pythonLocation }}-${{ hashFiles('setup.py') }}-${{ hashFiles('REQUIREMENTS.txt') }}-${{ hashFiles('./requirements/DEV_REQUIREMENTS.txt') }}

      - name: Install dependencies
        run: |
6 changes: 5 additions & 1 deletion NOTICE.txt
@@ -17,8 +17,12 @@ This product relies on the following works (and the dependencies thereof), insta


Optional extensions rely on the following works (and the dependencies thereof), installed separately:
- ax-platform | https://github.com/facebook/Ax | MIT License
- boto3 | https://github.com/boto/boto3 | Apache License 2.0
- botocore | https://github.com/boto/botocore | Apache License 2.0
- hurry.filesize | https://pypi.org/project/hurry.filesize/ | ZPL 2.1
- mypy_extensions | https://github.com/python/mypy_extensions | MIT License
- optuna | https://optuna.org/ | MIT License
- s3transfer | https://github.com/boto/s3transfer | Apache License 2.0
- torch | https://github.com/pytorch/pytorch | BSD License
- torchvision | https://github.com/pytorch/vision | BSD 3-Clause License
12 changes: 5 additions & 7 deletions README.md
@@ -48,7 +48,7 @@ hierarchical configuration by composition.

## Quick Install

-Requires Python 3.6+
+The basic install and `[s3]` extension require Python 3.6+, while the `[tune]` extension requires Python 3.7+.

| Base | w/ S3 Extension | w/ Hyper-Parameter Tuner |
|------|-----------------|--------------------------|
@@ -66,6 +66,9 @@ Example `spock` usage is located [here](https://github.com/fidelity/spock/blob/m

See [Releases](https://github.com/fidelity/spock/releases) for more information.

+#### August 17, 2021
+* Added hyper-parameter tuning backend support for Ax via Service API

#### July 21, 2021
* Added hyper-parameter tuning support with `pip install spock-config[tune]`
* Hyper-parameter tuning backend support for Optuna define-and-run API (WIP for Ax)
@@ -75,13 +78,8 @@ See [Releases](https://github.com/fidelity/spock/releases) for more information.
* S3 addon supports automatically handling loading/saving from paths defined with `s3://` URI(s) by passing in an
active `boto3.Session`

-#### March 18th, 2021
-
-* Support for Google docstring style annotation of `spock` class (and Enums) and attributes
-* Added in ability to print docstring annotated help information to command line with `--help` argument

-### Original Implementation
+## Original Implementation

[Nicholas Cilfone](https://github.com/ncilfone), [Siddharth Narayanan](https://github.com/sidnarayanan)
___
8 changes: 5 additions & 3 deletions docs/Installation.md
@@ -2,9 +2,9 @@

### Requirements

-* Python: 3.6+
+* Python: 3.6+ (`[tune]` extension requires 3.7+)
* Base Dependencies: attrs, GitPython, PyYAML, toml
-* Tested OS: Unix (Ubuntu 16.04, Ubuntu 18.04), OSX (10.14.6)
+* Tested OS: Ubuntu (16.04, 18.04), OSX (10.14.6, 11.3.1)

### Install/Upgrade

@@ -23,7 +23,9 @@ pip install spock-config[s3]

#### w/ Hyper-Parameter Tuner Extension

-Extra Dependencies: optuna
+Requires Python 3.7+
+
+Extra Dependencies: optuna, ax-platform, torch, torchvision, mypy_extensions (Python < 3.8)

```bash
pip install spock-config[tune]
```
4 changes: 1 addition & 3 deletions docs/addons/tuner/About.md
@@ -9,14 +9,12 @@ All examples can be found [here](https://github.com/fidelity/spock/blob/master/e

### Installing

-Install `spock` with the extra hyper-parameter tuning related dependencies.
+Install `spock` with the extra hyper-parameter tuning related dependencies. Requires Python 3.7+ due to ax-platform.

```bash
pip install spock-config[tune]
```

### Supported Backends
* [Optuna](https://optuna.readthedocs.io/en/stable/index.html)

-### WIP/Planned Backends
-* [Ax](https://ax.dev/)
103 changes: 103 additions & 0 deletions docs/addons/tuner/Ax.md
@@ -0,0 +1,103 @@
# Ax Support

`spock` integrates with the Ax optimization framework through the provided Service API. See
[docs](https://ax.dev/api/service.html#module-ax.service.ax_client) for `AxClient` info.

All examples can be found [here](https://github.com/fidelity/spock/blob/master/examples).

### Defining the Backend

So let's continue with our Ax-specific version of `tune.py`.

It's important to note that you can still use the `@spock` decorator to define any non-hyper-parameters! For
posterity, let's add some fixed parameters (those that are not part of hyper-parameter tuning) that we will use
elsewhere in our code.

```python
from spock.config import spock

@spock
class BasicParams:
    n_trials: int
    max_iter: int
```
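
The hyper-parameters themselves are declared with the `@spockTuner` decorator. For reference, this is the class used
throughout this example (it mirrors the full `examples/tune/ax/tune.py` example file):

```python
from spock.addons.tune import ChoiceHyperParameter, RangeHyperParameter, spockTuner

@spockTuner
class LogisticRegressionHP:
    c: RangeHyperParameter
    solver: ChoiceHyperParameter
```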

Now we need to tell `spock` that we intend to do hyper-parameter tuning and which backend we would like to use. We
do this by calling the `tuner` method on the `ConfigArgBuilder` object, passing in a configuration object for the
backend of choice (just like in basic functionality this is a chained command, thus the builder object will still be
returned). For Ax one uses `AxTunerConfig`. This config mirrors all options that would be passed into
the `AxClient` constructor and the `AxClient.create_experiment` function call so that `spock` can set up the
Service API. (Note: `@spockTuner`-decorated classes are passed to the `ConfigArgBuilder` in the exact same
way as basic `@spock`-decorated classes.)

```python
from spock.addons.tune import AxTunerConfig

# Ax config -- this will internally spawn the AxClient (Service API style), which can be accessed
# via the tuner_status property on the ConfigArgBuilder object -- note here that we need to define the
# objective name that the client will expect to be within the data dictionary when completing trials
ax_config = AxTunerConfig(objective_name="accuracy", minimize=False)

# Use the builder to set everything up
# Call tuner to indicate that we are going to do some HP tuning -- passing in the Ax tuner config
attrs_obj = ConfigArgBuilder(
    LogisticRegressionHP,
    BasicParams,
    desc="Example Logistic Regression Hyper-Parameter Tuning -- Ax Backend",
).tuner(tuner_config=ax_config)

```

### Generate Functionality Still Exists

To get the set of fixed parameters (those that are not hyper-parameters) one simply calls the `generate()` function,
just as in normal `spock` usage, to get the fixed-parameter `Spockspace`.

Continuing in `tune.py`:

```python

# Here we need some of the fixed parameters first, so we call the generate function to grab all the
# fixed parameters prior to starting the sampling process
fixed_params = attrs_obj.generate()
```

### Sample as an Alternative to Generate

The `sample()` call is the crux of `spock` hyper-parameter tuning support. It draws a hyper-parameter sample from the
underlying backend sampler, combines it with the fixed parameters, and returns a single `Spockspace` with all
usable parameters (accessible with dot notation). For Ax, under the hood `spock` uses the Service API (with
an `AxClient`) and thus handles the underlying call to get the next trial. The `spock` builder object has a
`@property` called `tuner_status` that returns any necessary backend objects, in a dictionary, that the user needs to
interface with. In the case of Ax, this contains both the `AxClient` and the `trial_index` (as dictionary keys). We use
the return of `tuner_status` to handle trial completion via the `complete_trial` call based on the metric of interest
(here just the simple validation accuracy -- remember that during `AxTunerConfig` instantiation we set the
`objective_name` to 'accuracy' -- we also set the SEM to 0.0 since we are not using it for this example).

See [here](https://ax.dev/api/service.html#ax.service.ax_client.AxClient.complete_trial) for Ax documentation on
completing trials.

Continuing in `tune.py`:

```python
# Iterate through a bunch of ax trials
for _ in range(fixed_params.BasicParams.n_trials):
    # Call sample on the spock object
    hp_attrs = attrs_obj.sample()
    # Use the currently sampled parameters in a simple LogisticRegression from sklearn
    clf = LogisticRegression(
        C=hp_attrs.LogisticRegressionHP.c,
        solver=hp_attrs.LogisticRegressionHP.solver,
        max_iter=hp_attrs.BasicParams.max_iter
    )
    clf.fit(X_train, y_train)
    val_acc = clf.score(X_valid, y_valid)
    # Get the status of the tuner -- this dict will contain all the objects needed to update
    tuner_status = attrs_obj.tuner_status
    # Pull the AxClient object and trial index out of the return dictionary and call 'complete_trial' on the
    # AxClient object with the correct raw_data that contains the objective name
    tuner_status["client"].complete_trial(
        trial_index=tuner_status["trial_index"],
        raw_data={"accuracy": (val_acc, 0.0)},
    )
```
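
Just like the full example in `examples/tune/ax/tune.py`, the current best set of hyper-parameters can be persisted
after each trial with `save_best()` and retrieved at the end via the `best` property:

```python
# Always save the current best set of hyper-parameters
attrs_obj.save_best(user_specified_path="/tmp/ax")

# Grab the best config and metric
best_config, best_metric = attrs_obj.best
print(f"Best HP Config:\n{best_config}")
print(f"Best Metric: {best_metric}")
```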
2 changes: 1 addition & 1 deletion docs/addons/tuner/Optuna.md
@@ -44,7 +44,7 @@ optuna_config = OptunaTunerConfig(
attrs_obj = ConfigArgBuilder(
    LogisticRegressionHP,
    BasicParams,
-    desc="Example Logistic Regression Hyper-Parameter Tuning",
+    desc="Example Logistic Regression Hyper-Parameter Tuning -- Optuna Backend",
).tuner(tuner_config=optuna_config)

```
Expand Down
1 change: 1 addition & 0 deletions examples/tune/ax/__init__.py
@@ -0,0 +1 @@
# -*- coding: utf-8 -*-
98 changes: 98 additions & 0 deletions examples/tune/ax/tune.py
@@ -0,0 +1,98 @@
# -*- coding: utf-8 -*-

"""A simple example using sklearn and Ax support"""

# Spock ONLY supports the service style API from Ax
# https://ax.dev/docs/api.html


from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

from spock.addons.tune import (
    AxTunerConfig,
    ChoiceHyperParameter,
    RangeHyperParameter,
    spockTuner,
)
from spock.builder import ConfigArgBuilder
from spock.config import spock


@spock
class BasicParams:
    n_trials: int
    max_iter: int


@spockTuner
class LogisticRegressionHP:
    c: RangeHyperParameter
    solver: ChoiceHyperParameter


def main():
    # Load the iris data
    X, y = load_iris(return_X_y=True)

    # Split the Iris data
    X_train, X_valid, y_train, y_valid = train_test_split(X, y)

    # Ax config -- this will internally spawn the AxClient (Service API style), which can be accessed
    # via the tuner_status property on the ConfigArgBuilder object
    ax_config = AxTunerConfig(objective_name="accuracy", minimize=False)

    # Use the builder to set everything up
    # Call tuner to indicate that we are going to do some HP tuning -- passing in the Ax tuner config
    attrs_obj = (
        ConfigArgBuilder(
            LogisticRegressionHP,
            BasicParams,
            desc="Example Logistic Regression Hyper-Parameter Tuning -- Ax Backend",
        )
        .tuner(tuner_config=ax_config)
        .save(user_specified_path="/tmp/ax")
    )

    # Here we need some of the fixed parameters first, so we call the generate function to grab all the
    # fixed parameters prior to starting the sampling process
    fixed_params = attrs_obj.generate()

    # Now we iterate through a bunch of ax trials
    for _ in range(fixed_params.BasicParams.n_trials):
        # The crux of spock support -- call save w/ the add_tuner_sample flag to write the current draw to file and
        # then call sample to return the composed Spockspace of the fixed parameters and the sampled parameters
        # Under the hood spock uses the AxClient Ax interface -- thus it handles the underlying call to get the next
        # sample and returns the necessary AxClient object in the return dictionary to call 'complete_trial' with
        # the associated metrics
        hp_attrs = attrs_obj.save(
            add_tuner_sample=True, user_specified_path="/tmp/ax"
        ).sample()
        # Use the currently sampled parameters in a simple LogisticRegression from sklearn
        clf = LogisticRegression(
            C=hp_attrs.LogisticRegressionHP.c,
            solver=hp_attrs.LogisticRegressionHP.solver,
            max_iter=hp_attrs.BasicParams.max_iter,
        )
        clf.fit(X_train, y_train)
        val_acc = clf.score(X_valid, y_valid)
        # Get the status of the tuner -- this dict will contain all the objects needed to update
        tuner_status = attrs_obj.tuner_status
        # Pull the AxClient object and trial index out of the return dictionary and call 'complete_trial' on the
        # AxClient object with the correct raw_data that contains the objective name
        tuner_status["client"].complete_trial(
            trial_index=tuner_status["trial_index"],
            raw_data={"accuracy": (val_acc, 0.0)},
        )
        # Always save the current best set of hyper-parameters
        attrs_obj.save_best(user_specified_path="/tmp/ax")

    # Grab the best config and metric
    best_config, best_metric = attrs_obj.best
    print(f"Best HP Config:\n{best_config}")
    print(f"Best Metric: {best_metric}")


if __name__ == "__main__":
    main()
15 changes: 15 additions & 0 deletions examples/tune/ax/tune.yaml
@@ -0,0 +1,15 @@
################
# tune.yaml
################
BasicParams:
  n_trials: 10
  max_iter: 150

LogisticRegressionHP:
  c:
    type: float
    bounds: [1E-07, 10.0]
    log_scale: true
  solver:
    type: str
    choices: ["lbfgs", "saga"]
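
With the YAML above saved alongside the example script, a run would then point `spock` at it via the `--config`
command-line argument (a usage sketch -- the file names and locations here are assumptions):

```bash
python tune.py --config tune.yaml
```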
8 changes: 4 additions & 4 deletions examples/tune/optuna/tune.py
@@ -51,10 +51,10 @@ def main():
        ConfigArgBuilder(
            LogisticRegressionHP,
            BasicParams,
-            desc="Example Logistic Regression Hyper-Parameter Tuning",
+            desc="Example Logistic Regression Hyper-Parameter Tuning -- Optuna Backend",
        )
        .tuner(tuner_config=optuna_config)
-        .save(user_specified_path="/tmp")
+        .save(user_specified_path="/tmp/optuna")
    )

    # Here we need some of the fixed parameters first so we can just call the generate function to grab all the fixed params
@@ -68,7 +68,7 @@ def main():
        # Under the hood spock uses the define-and-run Optuna interface -- thus it handles the underlying 'ask' call
        # and returns the necessary trial object in the return dictionary to call 'tell' with the study object
        hp_attrs = attrs_obj.save(
-            add_tuner_sample=True, user_specified_path="/tmp"
+            add_tuner_sample=True, user_specified_path="/tmp/optuna"
        ).sample()
        # Use the currently sampled parameters in a simple LogisticRegression from sklearn
        clf = LogisticRegression(
@@ -84,7 +84,7 @@ def main():
        # object
        tuner_status["study"].tell(tuner_status["trial"], val_acc)
        # Always save the current best set of hyper-parameters
-        attrs_obj.save_best(user_specified_path="/tmp")
+        attrs_obj.save_best(user_specified_path="/tmp/optuna")

    # Grab the best config and metric
    best_config, best_metric = attrs_obj.best
Expand Down
