# Optimize using Ax/BoTorch

In [1]:
import numpy as np
# import pandas as pd
from time import sleep

# NOTE(review): this star import is what brings RangeParameter, ParameterType,
# SearchSpace, SimpleExperiment, Models, save and load into scope for the
# cells below. Importing those names explicitly would make the dependency
# on Ax visible to readers.
from ax import *

from ax.plot.contour import plot_contour
from ax.plot.slice import plot_slice
from ax.plot.trace import optimization_trace_single_method
from ax.utils.notebook.plotting import render, init_notebook_plotting
# Injects the Plotly library into the notebook (see the INFO log line in this
# cell's output).
init_notebook_plotting()

# import evaluation function defined locally as the
# minimum validation loss training a NN for 60 epochs
from ax_utils import eval_fn

[INFO 06-24 09:27:48] ipy_plotting: Injecting Plotly library into cell. Do not overwrite or delete cell.


## Define parameters
Network parameters
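* `channels`: Defaults to 16 (256 in paper). Searched below as an exponent, i.e. the network uses `2**channels` channels.
* `num_blocks`: Defaults to 6 (19 in paper). Named `blocks` in the search space below.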
* `out_c`: Defaults to 4 (2 in paper). Keep this fixed for now.

Training parameters
* `lr`: Defaults to 0.1
* `decay`: Defaults to 1e-4.
* `momentum`: Defaults to 0.9. Keep this fixed for now.

In [2]:
# Search space: two integer architecture parameters and two log-scaled float
# training parameters.

# Author's note: channels = 2**ch_param, so this parameter is the exponent
# rather than the channel count itself.
ch_param = RangeParameter(
    name="channels",
    parameter_type=ParameterType.INT,
    lower=3,
    upper=7,
)

# Number of blocks, searched directly as an integer.
blk_param = RangeParameter(
    name="blocks",
    parameter_type=ParameterType.INT,
    lower=3,
    upper=7,
)

# Learning rate, sampled on a log scale between 1e-2 and 1.
lr_param = RangeParameter(
    name="lr",
    parameter_type=ParameterType.FLOAT,
    lower=1e-2,
    upper=1.0,
    log_scale=True,
)

# Decay, sampled on a log scale between 1e-4 and 1e-2.
dec_param = RangeParameter(
    name="decay",
    parameter_type=ParameterType.FLOAT,
    lower=1e-4,
    upper=1e-2,
    log_scale=True,
)

search_space = SearchSpace(parameters=[ch_param, blk_param, lr_param, dec_param])

## Define experiment and initial arms

In [3]:
# Single-objective experiment: `eval_fn` is evaluated on arms drawn from
# `search_space`, and the `min_val` objective is minimised.
exp = SimpleExperiment(
    name="optimize_network",
    search_space=search_space,
    evaluation_function=eval_fn,
    objective_name="min_val",
    minimize=True,
)

In [4]:
# Initial exploration: 20 Sobol-generated arms, one trial per arm.
sobol = Models.SOBOL(exp.search_space)
for i in range(20):
    gen = sobol.gen(1)
    # Separator line followed by the arm about to be evaluated.
    print('-' * 10, gen.arms[0], sep='\n')
    # NOTE(review): presumably gives the printed arm time to appear before the
    # training progress bar starts — confirm whether this pause is needed.
    sleep(1)
    trial = exp.new_trial(generator_run=gen)
    # The training progress bar and loss in the output appear per arm, so the
    # evaluation runs inside this loop — presumably triggered by fetch_data();
    # confirm against SimpleExperiment.
    trial.fetch_data()

----------
Arm(parameters={'channels': 4, 'blocks': 5, 'lr': 0.21603543371097486, 'decay': 0.0029580901153477192})


100%|██████████| 60/60 [02:38<00:00,  2.69s/it]


1.0704485004146893
----------
Arm(parameters={'channels': 4, 'blocks': 6, 'lr': 0.9848772287629501, 'decay': 0.00014802888526512398})


100%|██████████| 60/60 [02:54<00:00,  2.90s/it]


1.2168725331624348
----------
Arm(parameters={'channels': 5, 'blocks': 3, 'lr': 0.012161340519149005, 'decay': 0.008839650491027187})


100%|██████████| 60/60 [02:08<00:00,  2.16s/it]


0.9645098547140757
----------
Arm(parameters={'channels': 6, 'blocks': 7, 'lr': 0.138898276680442, 'decay': 0.0005999404984214526})


100%|██████████| 60/60 [03:07<00:00,  3.06s/it]


0.9698945209383965
----------
Arm(parameters={'channels': 3, 'blocks': 4, 'lr': 0.0643767241744673, 'decay': 0.0013201121650438425})


100%|██████████| 60/60 [02:26<00:00,  2.40s/it]


1.0368678122758865
----------
Arm(parameters={'channels': 5, 'blocks': 5, 'lr': 0.02478188700653261, 'decay': 0.00020157616898544504})


100%|██████████| 60/60 [02:40<00:00,  2.68s/it]


1.1157233516375225
----------
Arm(parameters={'channels': 6, 'blocks': 4, 'lr': 0.3632666147287835, 'decay': 0.003793766245015651})


100%|██████████| 60/60 [02:27<00:00,  2.49s/it]


1.1517122238874435
----------
Arm(parameters={'channels': 6, 'blocks': 6, 'lr': 0.015073313733104839, 'decay': 0.001600272527865882})


100%|██████████| 60/60 [02:55<00:00,  2.88s/it]


1.1239951699972153
----------
Arm(parameters={'channels': 5, 'blocks': 3, 'lr': 0.5938569986663507, 'decay': 0.000847313876941212})


100%|██████████| 60/60 [02:08<00:00,  2.11s/it]


1.0458907708525658
----------
Arm(parameters={'channels': 3, 'blocks': 6, 'lr': 0.2689328593529833, 'decay': 0.005553453025899038})


100%|██████████| 60/60 [02:52<00:00,  2.90s/it]


1.1609033147494
----------
Arm(parameters={'channels': 6, 'blocks': 5, 'lr': 0.033510423975188344, 'decay': 0.00025320470296377273})


100%|██████████| 60/60 [02:41<00:00,  2.69s/it]


1.1047308991352718
----------
Arm(parameters={'channels': 5, 'blocks': 5, 'lr': 0.5195733450716649, 'decay': 0.0020945049926299492})


100%|██████████| 60/60 [02:39<00:00,  2.66s/it]


1.1240117500225704
----------
Arm(parameters={'channels': 4, 'blocks': 4, 'lr': 0.023027093393896025, 'decay': 0.0003518754630674335})


100%|██████████| 60/60 [02:22<00:00,  2.43s/it]


1.133261541525523
----------
Arm(parameters={'channels': 4, 'blocks': 7, 'lr': 0.09250491903100018, 'decay': 0.007037341918464})


100%|██████████| 60/60 [03:06<00:00,  3.14s/it]


1.061123659213384
----------
Arm(parameters={'channels': 7, 'blocks': 4, 'lr': 0.1296266042558986, 'decay': 0.00010112183396864997})


100%|██████████| 60/60 [02:30<00:00,  2.56s/it]


1.048171396056811
----------
Arm(parameters={'channels': 7, 'blocks': 7, 'lr': 0.6746489361349897, 'decay': 0.0002097686712755458})


100%|██████████| 60/60 [03:19<00:00,  3.34s/it]


0.9731405079364777
----------
Arm(parameters={'channels': 4, 'blocks': 4, 'lr': 0.01773373941067203, 'decay': 0.003508841498129898})


100%|██████████| 60/60 [02:26<00:00,  2.45s/it]


1.0611062447230022
----------
Arm(parameters={'channels': 4, 'blocks': 5, 'lr': 0.03797808310726527, 'decay': 0.0007009816376283641})


100%|██████████| 60/60 [02:40<00:00,  2.67s/it]


1.077457735935847
----------
Arm(parameters={'channels': 5, 'blocks': 4, 'lr': 0.31573191265112316, 'decay': 0.0010119629753065657})


100%|██████████| 60/60 [02:26<00:00,  2.45s/it]


0.9955700685580572
----------
Arm(parameters={'channels': 6, 'blocks': 6, 'lr': 0.02027555876095451, 'decay': 0.0001598692747843798})


100%|██████████| 60/60 [02:54<00:00,  2.84s/it]

1.1282283266385396





In [5]:
# Persist the experiment to experiment2.json (the same file is reloaded with
# load() further down).
save(exp, 'experiment2.json')

In [6]:
# One line per trial: index, objective mean (3 decimals) and arm parameters.
for i, trial in exp.trials.items():
    mean = trial.fetch_data().df.at[0, 'mean']
    print(f'{i}: {mean:.3f} {trial.arm.parameters}')

0: 1.070 {'channels': 4, 'blocks': 5, 'lr': 0.21603543371097486, 'decay': 0.0029580901153477192}
1: 1.217 {'channels': 4, 'blocks': 6, 'lr': 0.9848772287629501, 'decay': 0.00014802888526512398}
2: 0.965 {'channels': 5, 'blocks': 3, 'lr': 0.012161340519149005, 'decay': 0.008839650491027187}
3: 0.970 {'channels': 6, 'blocks': 7, 'lr': 0.138898276680442, 'decay': 0.0005999404984214526}
4: 1.037 {'channels': 3, 'blocks': 4, 'lr': 0.0643767241744673, 'decay': 0.0013201121650438425}
5: 1.116 {'channels': 5, 'blocks': 5, 'lr': 0.02478188700653261, 'decay': 0.00020157616898544504}
6: 1.152 {'channels': 6, 'blocks': 4, 'lr': 0.3632666147287835, 'decay': 0.003793766245015651}
7: 1.124 {'channels': 6, 'blocks': 6, 'lr': 0.015073313733104839, 'decay': 0.001600272527865882}
8: 1.046 {'channels': 5, 'blocks': 3, 'lr': 0.5938569986663507, 'decay': 0.000847313876941212}
9: 1.161 {'channels': 3, 'blocks': 6, 'lr': 0.2689328593529833, 'decay': 0.005553453025899038}
10: 1.105 {'channels': 6, 'blocks': 5,

## Iterative optimization
BoTorch backend for Gaussian-process expected-improvement (GPEI) Bayesian optimization

In [20]:
# Sequential Bayesian optimisation: 20 GPEI-generated trials.
for i in range(20):
    # Build a fresh GPEI model from the experiment's current data, then ask it
    # for a single new candidate.
    gpei = Models.GPEI(experiment=exp, data=exp.eval())
    gen = gpei.gen(1)
    # Separator line followed by the arm about to be evaluated.
    print('-' * 10, gen.arms[0], sep='\n')
    # NOTE(review): presumably lets the printed arm appear before the training
    # progress bar starts — confirm whether this pause is needed.
    sleep(1)
    trial = exp.new_trial(generator_run=gen)
    # NOTE(review): presumably what triggers the evaluation of the new arm
    # (see the progress bars in the output) — confirm.
    trial.fetch_data()

----------
Arm(parameters={'channels': 7, 'blocks': 7, 'lr': 0.28507060267689194, 'decay': 0.0003563685171423693})


100%|██████████| 60/60 [03:19<00:00,  3.31s/it]


0.973848765095075
----------
Arm(parameters={'channels': 6, 'blocks': 5, 'lr': 0.19145894409214131, 'decay': 0.0005460280671159752})


100%|██████████| 60/60 [02:40<00:00,  2.70s/it]


0.9660492936770121
----------
Arm(parameters={'channels': 5, 'blocks': 5, 'lr': 0.16113289089395605, 'decay': 0.0007045694868138177})


100%|██████████| 60/60 [02:39<00:00,  2.66s/it]


0.9583090369900068
----------
Arm(parameters={'channels': 6, 'blocks': 3, 'lr': 0.018174435662806708, 'decay': 0.01})


100%|██████████| 60/60 [02:09<00:00,  2.16s/it]


0.972822847465674
----------
Arm(parameters={'channels': 6, 'blocks': 3, 'lr': 0.010000000000007097, 'decay': 0.009999999999885332})


100%|██████████| 60/60 [02:09<00:00,  2.16s/it]


1.0137784754236538
----------
Arm(parameters={'channels': 5, 'blocks': 3, 'lr': 0.02163336923793108, 'decay': 0.00999999999999998})


100%|██████████| 60/60 [02:01<00:00,  2.13s/it]


0.9678780163327853
----------
Arm(parameters={'channels': 5, 'blocks': 6, 'lr': 0.20475248088323109, 'decay': 0.0005434495951457596})


100%|██████████| 60/60 [02:52<00:00,  2.86s/it]


0.9581736698746681
----------
Arm(parameters={'channels': 5, 'blocks': 6, 'lr': 0.19046554206537228, 'decay': 0.0006972503970746154})


100%|██████████| 60/60 [02:51<00:00,  2.89s/it]


0.9594393670558929
----------
Arm(parameters={'channels': 6, 'blocks': 6, 'lr': 0.18238651974382744, 'decay': 0.0005037035863202343})


100%|██████████| 60/60 [02:55<00:00,  2.89s/it]


0.9589748084545135
----------
Arm(parameters={'channels': 5, 'blocks': 6, 'lr': 0.18290972156255758, 'decay': 0.0005308518316406664})


100%|██████████| 60/60 [02:52<00:00,  2.88s/it]


0.9675099030137062
----------
Arm(parameters={'channels': 5, 'blocks': 6, 'lr': 0.2637790362606676, 'decay': 0.0005653884360499069})


100%|██████████| 60/60 [02:53<00:00,  2.89s/it]


0.9615141227841377
----------
Arm(parameters={'channels': 5, 'blocks': 5, 'lr': 0.21770774902915452, 'decay': 0.0006386783589966075})


100%|██████████| 60/60 [02:38<00:00,  2.64s/it]


0.9618698532382647
----------
Arm(parameters={'channels': 6, 'blocks': 7, 'lr': 0.2286035516399889, 'decay': 0.0005652911021648586})


100%|██████████| 60/60 [03:08<00:00,  3.18s/it]


0.967430422703425
----------
Arm(parameters={'channels': 5, 'blocks': 5, 'lr': 0.13691204871662926, 'decay': 0.0010775611852428116})


100%|██████████| 60/60 [02:40<00:00,  2.67s/it]


0.9562675828735033
----------
Arm(parameters={'channels': 5, 'blocks': 4, 'lr': 0.14251595177924326, 'decay': 0.0009907473128231874})


100%|██████████| 60/60 [02:25<00:00,  2.41s/it]


0.953724722067515
----------
Arm(parameters={'channels': 6, 'blocks': 5, 'lr': 0.13373415852416706, 'decay': 0.0009499481553360297})


100%|██████████| 60/60 [02:42<00:00,  2.71s/it]


0.9557004272937775
----------
Arm(parameters={'channels': 7, 'blocks': 7, 'lr': 0.3886999425816134, 'decay': 0.0001535861335959724})


100%|██████████| 60/60 [03:19<00:00,  3.34s/it]


0.9888438284397125
----------
Arm(parameters={'channels': 5, 'blocks': 5, 'lr': 0.1553201102775512, 'decay': 0.0009284352198305864})


100%|██████████| 60/60 [02:40<00:00,  2.67s/it]


0.9582739199201266
----------
Arm(parameters={'channels': 5, 'blocks': 4, 'lr': 0.10518893141412945, 'decay': 0.0009977501976736452})


100%|██████████| 60/60 [02:25<00:00,  2.43s/it]


0.963455950220426
----------
Arm(parameters={'channels': 6, 'blocks': 6, 'lr': 0.09837457375786532, 'decay': 0.0013958835395082492})


100%|██████████| 60/60 [02:55<00:00,  2.94s/it]

0.980708102385203





In [21]:
# Overwrite experiment2.json with the experiment as it stands after the GPEI
# trials above.
save(exp, 'experiment2.json')

In [22]:
# Rebuild the GPEI model from the experiment's current data, then list every
# trial: index, objective mean (3 decimals) and arm parameters.
gpei = Models.GPEI(experiment=exp, data=exp.eval())
for i, trial in exp.trials.items():
    mean = trial.fetch_data().df.at[0, 'mean']
    print(f'{i}: {mean:.3f} {trial.arm.parameters}')

0: 1.070 {'channels': 4, 'blocks': 5, 'lr': 0.21603543371097486, 'decay': 0.0029580901153477192}
1: 1.217 {'channels': 4, 'blocks': 6, 'lr': 0.9848772287629501, 'decay': 0.00014802888526512398}
2: 0.965 {'channels': 5, 'blocks': 3, 'lr': 0.012161340519149005, 'decay': 0.008839650491027187}
3: 0.970 {'channels': 6, 'blocks': 7, 'lr': 0.138898276680442, 'decay': 0.0005999404984214526}
4: 1.037 {'channels': 3, 'blocks': 4, 'lr': 0.0643767241744673, 'decay': 0.0013201121650438425}
5: 1.116 {'channels': 5, 'blocks': 5, 'lr': 0.02478188700653261, 'decay': 0.00020157616898544504}
6: 1.152 {'channels': 6, 'blocks': 4, 'lr': 0.3632666147287835, 'decay': 0.003793766245015651}
7: 1.124 {'channels': 6, 'blocks': 6, 'lr': 0.015073313733104839, 'decay': 0.001600272527865882}
8: 1.046 {'channels': 5, 'blocks': 3, 'lr': 0.5938569986663507, 'decay': 0.000847313876941212}
9: 1.161 {'channels': 3, 'blocks': 6, 'lr': 0.2689328593529833, 'decay': 0.005553453025899038}
10: 1.105 {'channels': 6, 'blocks': 5,

### Load experiment and continue testing

In [32]:
# Reload the experiment previously written to experiment2.json by save().
exp = load('experiment2.json')
# Re-attach the evaluation function to the loaded experiment.
# NOTE(review): presumably needed because the function itself is not stored
# in the JSON file — confirm against the Ax storage documentation.
exp.evaluation_function = eval_fn

In [33]:
# Continue the optimisation on the reloaded experiment: 10 more GPEI trials.
for i in range(10):
    # Build a fresh GPEI model from the experiment's current data, then ask it
    # for a single new candidate.
    gpei = Models.GPEI(experiment=exp, data=exp.eval())
    gen = gpei.gen(1)
    # Separator line followed by the arm about to be evaluated.
    print('-' * 10, gen.arms[0], sep='\n')
    # NOTE(review): presumably lets the printed arm appear before the training
    # progress bar starts — confirm whether this pause is needed.
    sleep(1)
    trial = exp.new_trial(generator_run=gen)
    # NOTE(review): presumably what triggers the evaluation of the new arm
    # (see the progress bars in the output) — confirm.
    trial.fetch_data()

----------
Arm(parameters={'channels': 5, 'blocks': 5, 'lr': 0.13086122111363585, 'decay': 0.0009049475747719021})


100%|██████████| 60/60 [02:39<00:00,  2.63s/it]


0.9641616021593412
----------
Arm(parameters={'channels': 5, 'blocks': 4, 'lr': 0.12519093492400113, 'decay': 0.0014152385195784405})


100%|██████████| 60/60 [02:26<00:00,  2.43s/it]


0.9608743538459142
----------
Arm(parameters={'channels': 5, 'blocks': 5, 'lr': 0.13813912759409627, 'decay': 0.0011335584641535524})


100%|██████████| 60/60 [02:39<00:00,  2.65s/it]


0.9584685117006302
----------
Arm(parameters={'channels': 6, 'blocks': 4, 'lr': 0.1255246184247268, 'decay': 0.001135876619905847})


100%|██████████| 60/60 [02:27<00:00,  2.48s/it]


0.9767742032806078
----------
Arm(parameters={'channels': 6, 'blocks': 6, 'lr': 0.17296412214287352, 'decay': 0.0007201157947105164})


100%|██████████| 60/60 [02:55<00:00,  2.96s/it]


0.973016602297624
----------
Arm(parameters={'channels': 6, 'blocks': 7, 'lr': 0.2583095775478646, 'decay': 0.0003785685082468959})


100%|██████████| 60/60 [03:08<00:00,  3.15s/it]


0.9708575780193011
----------
Arm(parameters={'channels': 4, 'blocks': 3, 'lr': 0.14141740355859383, 'decay': 0.0012275978702929854})


100%|██████████| 60/60 [02:07<00:00,  2.12s/it]


0.9711310217777888
----------
Arm(parameters={'channels': 5, 'blocks': 5, 'lr': 0.017240123221315098, 'decay': 0.008579645880263323})


100%|██████████| 60/60 [02:38<00:00,  2.63s/it]


0.9518526891867319
----------
Arm(parameters={'channels': 5, 'blocks': 4, 'lr': 0.018310713392683586, 'decay': 0.0076383871869869675})


100%|██████████| 60/60 [02:26<00:00,  2.41s/it]


0.9620380947987238
----------
Arm(parameters={'channels': 5, 'blocks': 4, 'lr': 0.01495705515053852, 'decay': 0.01})


100%|██████████| 60/60 [02:25<00:00,  2.43s/it]

0.9663989543914795





In [34]:
# Overwrite experiment2.json again, now including the trials run after the
# reload.
save(exp, 'experiment2.json')

In [35]:
# Rebuild the GPEI model from the experiment's current data; `gpei` is the
# model the plotting cells below read.
gpei = Models.GPEI(experiment=exp, data=exp.eval())
# Indices 40-49 are the ten trials run after the reload (20 Sobol + 20 GPEI
# trials came before them).
for i in range(40, 50):
    trial = exp.trials[i]
    mean = trial.fetch_data().df.at[0, 'mean']
    print(f'{i}: {mean:.3f} {trial.arm.parameters}')

40: 0.964 {'channels': 5, 'blocks': 5, 'lr': 0.13086122111363585, 'decay': 0.0009049475747719021}
41: 0.961 {'channels': 5, 'blocks': 4, 'lr': 0.12519093492400113, 'decay': 0.0014152385195784405}
42: 0.958 {'channels': 5, 'blocks': 5, 'lr': 0.13813912759409627, 'decay': 0.0011335584641535524}
43: 0.977 {'channels': 6, 'blocks': 4, 'lr': 0.1255246184247268, 'decay': 0.001135876619905847}
44: 0.973 {'channels': 6, 'blocks': 6, 'lr': 0.17296412214287352, 'decay': 0.0007201157947105164}
45: 0.971 {'channels': 6, 'blocks': 7, 'lr': 0.2583095775478646, 'decay': 0.0003785685082468959}
46: 0.971 {'channels': 4, 'blocks': 3, 'lr': 0.14141740355859383, 'decay': 0.0012275978702929854}
47: 0.952 {'channels': 5, 'blocks': 5, 'lr': 0.017240123221315098, 'decay': 0.008579645880263323}
48: 0.962 {'channels': 5, 'blocks': 4, 'lr': 0.018310713392683586, 'decay': 0.0076383871869869675}
49: 0.966 {'channels': 5, 'blocks': 4, 'lr': 0.01495705515053852, 'decay': 0.01}


## Plot results

In [36]:
# Contour plot of `min_val` from the fitted model over channels (x) and blocks (y).
render(
    plot_contour(
        model=gpei,
        param_x='channels',
        param_y='blocks',
        metric_name='min_val',
    )
)

In [37]:
# Contour plot of `min_val` from the fitted model over lr (x) and decay (y).
render(
    plot_contour(
        model=gpei,
        param_x='lr',
        param_y='decay',
        metric_name='min_val',
    )
)

In [38]:
# Contour plot of `min_val` from the fitted model over channels (x) and decay (y).
render(
    plot_contour(
        model=gpei,
        param_x='channels',
        param_y='decay',
        metric_name='min_val',
    )
)

In [39]:
# Contour plot of `min_val` from the fitted model over lr (x) and blocks (y).
render(
    plot_contour(
        model=gpei,
        param_x='lr',
        param_y='blocks',
        metric_name='min_val',
    )
)

In [40]:
# Slice plot of `min_val` from the fitted model along the blocks parameter.
render(
    plot_slice(
        model=gpei,
        param_name='blocks',
        metric_name='min_val',
    )
)

In [41]:
# Slice plot of `min_val` from the fitted model along the lr parameter.
render(
    plot_slice(
        model=gpei,
        param_name='lr',
        metric_name='min_val',
    )
)

In [42]:
# Slice plot of `min_val` from the fitted model along the decay parameter.
render(
    plot_slice(
        model=gpei,
        param_name='decay',
        metric_name='min_val',
    )
)

In [43]:
# `optimization_trace_single_method` expects a 2-d array of means, because it
# averages means over multiple optimization runs. There is only one run here,
# so the per-trial objective means are wrapped in an outer array, giving a
# single row with one entry per trial.
objective_means = np.array([[trial.objective_mean for trial in exp.trials.values()]])
best_objective_plot = optimization_trace_single_method(
    # Running minimum along the trial axis: the best (lowest) objective so far.
    y=np.minimum.accumulate(objective_means, axis=1),
    # Label the figure so it can be read on its own.
    title="Best objective vs. # of trials",
    ylabel="min_val",
)
render(best_objective_plot)

In [44]:
# Arm of the trial with the lowest objective mean.
# `objective_means` holds a single row with one entry per trial, so the position
# of the minimum in that row is used as the trial index. This assumes the trial
# indices run 0..n-1 in the same order as the row.
best_trial_index = int(np.argmin(objective_means[0]))
exp.trials[best_trial_index].arm

Arm(name='47_0', parameters={'channels': 5, 'blocks': 5, 'lr': 0.017240123221315098, 'decay': 0.008579645880263323})