# Fine tune on games 60-80

In [1]:
import numpy as np
# import pandas as pd
from time import sleep

# NOTE(review): star import. Names used later in this notebook that appear to
# come from it: RangeParameter, ParameterType, SearchSpace, SimpleExperiment,
# Arm, Models, save, load. Consider importing them explicitly.
from ax import *

from ax.plot.contour import plot_contour
from ax.plot.slice import plot_slice
from ax.plot.trace import optimization_trace_single_method
from ax.utils.notebook.plotting import render, init_notebook_plotting
# Injects the Plotly library into the notebook so `render(...)` can draw figures.
init_notebook_plotting()

# Local evaluation function (module eval_fn_ax). Per the original note, it
# trains with the given parameters and returns the mean log-score of 1000
# evaluation games.
from eval_fn_ax import eval_fn

[INFO 08-09 07:43:19] ipy_plotting: Injecting Plotly library into cell. Do not overwrite or delete cell.


## Define parameters
Network parameters (not part of the search space below)
* `channels`: Defaults to 16 (256 in paper)
* `num_blocks`: Defaults to 6 (19 in paper)
* `out_c`: Defaults to 4 (2 in paper). Keep this fixed for now.

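Training parameters
* `lr`: Defaults to 0.1. Searched over [1e-3, 1e-1] on a log scale.
* `decay`: Defaults to 1e-4. Searched over [1e-4, 1e-2] on a log scale.
* `momentum`: Defaults to 0.9. Keep this fixed for now.
* `epochs`: Number of training epochs. Searched over the integers 20–100 (linear scale).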

In [2]:
# Search space for this fine-tuning run:
#   - lr and decay are continuous and searched on a log scale,
#   - epochs is an integer searched on a linear scale.
lr_param = RangeParameter(
    name="lr",
    parameter_type=ParameterType.FLOAT,
    lower=1e-3,
    upper=1e-1,
    log_scale=True,
)
dec_param = RangeParameter(
    name="decay",
    parameter_type=ParameterType.FLOAT,
    lower=1e-4,
    upper=1e-2,
    log_scale=True,
)
epox_param = RangeParameter(
    name="epochs",
    parameter_type=ParameterType.INT,
    lower=20,
    upper=100,
    log_scale=False,
)
tuned_params = [lr_param, dec_param, epox_param]
search_space = SearchSpace(parameters=tuned_params)

## Define experiment and initial arms

In [3]:
# Wrap the local evaluation function in an Ax SimpleExperiment.
# The objective is reported under the metric name 'min_move' and is maximized
# (minimize=False); the plot cells refer to the same metric name.
exp = SimpleExperiment(
    name="optimize",
    search_space=search_space,
    evaluation_function=eval_fn,
    objective_name="min_move",
    minimize=False,
)

In [4]:
# Trial 0: a hand-picked control configuration, evaluated before any
# model-driven search so later arms have a baseline to compare against.
control_trial = exp.new_trial()
control_arm = Arm(name="control", parameters={'lr': 0.06, 'decay': 2e-3, 'epochs': 70})
control_trial.add_arm(control_arm)
exp.trials[0].fetch_data()

100%|██████████| 69/69 [01:46<00:00,  1.51s/it]


101.4 / 11 sec


<ax.core.data.Data at 0x7f138814c7f0>

In [5]:
# Quasi-random Sobol initialization: 10 single-arm trials spread over the
# search space, giving the GPEI model below its initial data.
sobol = Models.SOBOL(exp.search_space)
for step in range(10):
    generator_run = sobol.gen(1)
    print('-' * 20)
    print(step, generator_run.arms[0])
    # short pause, presumably so the print flushes before training output starts
    sleep(1)
    trial = exp.new_trial(generator_run=generator_run)
    trial.fetch_data()

--------------------
0 Arm(parameters={'lr': 0.07839425700208687, 'decay': 0.004649107668526686, 'epochs': 71})


100%|██████████| 69/69 [01:46<00:00,  1.58s/it]


64.7 / 8 sec
--------------------
1 Arm(parameters={'lr': 0.016511722904021142, 'decay': 0.00041288715787430073, 'epochs': 83})


100%|██████████| 81/81 [02:05<00:00,  1.53s/it]


54.5 / 7 sec
--------------------
2 Arm(parameters={'lr': 0.007976956609636175, 'decay': 0.0028356278671029336, 'epochs': 45})


100%|██████████| 43/43 [01:06<00:00,  1.53s/it]


64.7 / 7 sec
--------------------
3 Arm(parameters={'lr': 0.003719857998352154, 'decay': 0.0006385193350709473, 'epochs': 61})


100%|██████████| 59/59 [01:31<00:00,  1.54s/it]


61.7 / 8 sec
--------------------
4 Arm(parameters={'lr': 0.018280684548419125, 'decay': 0.0013376593059985793, 'epochs': 28})


100%|██████████| 27/27 [00:41<00:00,  1.55s/it]


85.8 / 10 sec
--------------------
5 Arm(parameters={'lr': 0.036598483632419736, 'decay': 0.00012315357925586642, 'epochs': 57})


100%|██████████| 55/55 [01:26<00:00,  1.57s/it]


57.1 / 7 sec
--------------------
6 Arm(parameters={'lr': 0.0017964235797745013, 'decay': 0.009421615689971193, 'epochs': 94})


100%|██████████| 93/93 [02:24<00:00,  1.55s/it]


62.5 / 8 sec
--------------------
7 Arm(parameters={'lr': 0.0027460956900874065, 'decay': 0.0005059898389563055, 'epochs': 44})


100%|██████████| 43/43 [01:07<00:00,  1.57s/it]


61.5 / 8 sec
--------------------
8 Arm(parameters={'lr': 0.0426709388481696, 'decay': 0.002247218136513755, 'epochs': 87})


100%|██████████| 85/85 [02:12<00:00,  1.51s/it]


87.2 / 9 sec
--------------------
9 Arm(parameters={'lr': 0.02800566925168594, 'decay': 0.0002677039277472593, 'epochs': 78})


100%|██████████| 77/77 [02:00<00:00,  1.57s/it]


58.9 / 7 sec


In [6]:
# Checkpoint the experiment under the same file name the "Load experiment"
# cell reads. The previous name 'opt_0_100.json' did not match that cell and
# appears to belong to a different (games 0-100) optimization run, which this
# save would have overwritten.
save(exp, 'finetune1.json')

In [7]:
# Summary of every trial so far: trial index, observed objective mean, parameters.
# No GP model is fitted here. The optimization loop below refits `gpei`
# before every use, so a fit in this cell was discarded work.
for i, trial in exp.trials.items():
    # positional lookup: the trial's data frame holds a single row for its
    # single arm/metric, regardless of how that frame is indexed
    observed_mean = trial.fetch_data().df['mean'].iloc[0]
    print('{0}: {1:.3f} {2}'.format(i, observed_mean, trial.arm.parameters))

0: 101.400 {'lr': 0.06, 'decay': 0.002, 'epochs': 70}
1: 64.725 {'lr': 0.07839425700208687, 'decay': 0.004649107668526686, 'epochs': 71}
2: 54.450 {'lr': 0.016511722904021142, 'decay': 0.00041288715787430073, 'epochs': 83}
3: 64.700 {'lr': 0.007976956609636175, 'decay': 0.0028356278671029336, 'epochs': 45}
4: 61.675 {'lr': 0.003719857998352154, 'decay': 0.0006385193350709473, 'epochs': 61}
5: 85.850 {'lr': 0.018280684548419125, 'decay': 0.0013376593059985793, 'epochs': 28}
6: 57.125 {'lr': 0.036598483632419736, 'decay': 0.00012315357925586642, 'epochs': 57}
7: 62.500 {'lr': 0.0017964235797745013, 'decay': 0.009421615689971193, 'epochs': 94}
8: 61.475 {'lr': 0.0027460956900874065, 'decay': 0.0005059898389563055, 'epochs': 44}
9: 87.200 {'lr': 0.0426709388481696, 'decay': 0.002247218136513755, 'epochs': 87}
10: 58.850 {'lr': 0.02800566925168594, 'decay': 0.0002677039277472593, 'epochs': 78}


## Iterative optimization
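Bayesian optimization with Gaussian-process expected improvement (GPEI), using the BoTorch backend. Each iteration refits the GP on all trials so far and evaluates one new suggested arm.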

In [15]:
# Bayesian optimization loop. On each iteration:
#   1. refit the GPEI model on all data gathered so far,
#   2. ask it for one new candidate arm,
#   3. evaluate that arm as a new trial.
for iteration in range(20):
    gpei = Models.GPEI(experiment=exp, data=exp.eval())
    candidate = gpei.gen(1)
    print('-' * 10)
    print(iteration, candidate.arms[0])
    sleep(1)
    trial = exp.new_trial(generator_run=candidate)
    trial.fetch_data()

----------
0 Arm(parameters={'lr': 0.0685188376679517, 'decay': 0.0016482494131049618, 'epochs': 60})


100%|██████████| 59/59 [01:32<00:00,  1.61s/it]


80.5 / 9 sec
----------
1 Arm(parameters={'lr': 0.0697169377514626, 'decay': 0.001987605498296668, 'epochs': 75})


100%|██████████| 73/73 [01:54<00:00,  1.55s/it]


87.7 / 9 sec
----------
2 Arm(parameters={'lr': 0.05441266281255867, 'decay': 0.001897592002189751, 'epochs': 73})


100%|██████████| 71/71 [01:51<00:00,  1.59s/it]


74.4 / 9 sec
----------
3 Arm(parameters={'lr': 0.06593370364780642, 'decay': 0.001681425245434636, 'epochs': 69})


100%|██████████| 67/67 [01:44<00:00,  1.55s/it]


105.0 / 12 sec
----------
4 Arm(parameters={'lr': 0.0370041853218101, 'decay': 0.0017477292689809694, 'epochs': 69})


100%|██████████| 67/67 [01:45<00:00,  1.56s/it]


96.7 / 11 sec
----------
5 Arm(parameters={'lr': 0.09938492371943926, 'decay': 0.0015626758209768153, 'epochs': 70})


100%|██████████| 69/69 [01:47<00:00,  1.53s/it]


87.1 / 11 sec
----------
6 Arm(parameters={'lr': 0.05809303501408004, 'decay': 0.0020319116910216093, 'epochs': 69})


100%|██████████| 67/67 [01:36<00:00,  1.27s/it]


98.5 / 11 sec
----------
7 Arm(parameters={'lr': 0.06278539344348018, 'decay': 0.0018074309146501322, 'epochs': 68})


100%|██████████| 67/67 [01:29<00:00,  1.52s/it]


98.6 / 11 sec
----------
8 Arm(parameters={'lr': 0.05581202297348793, 'decay': 0.0014879180519567713, 'epochs': 69})


100%|██████████| 67/67 [01:41<00:00,  1.28s/it]


100.0 / 12 sec
----------
9 Arm(parameters={'lr': 0.06862525836312534, 'decay': 0.0016138097700495336, 'epochs': 70})


100%|██████████| 69/69 [01:45<00:00,  1.54s/it]


101.6 / 11 sec
----------
10 Arm(parameters={'lr': 0.09373144497300952, 'decay': 0.0021325495969080665, 'epochs': 78})


100%|██████████| 77/77 [01:59<00:00,  1.53s/it]


93.5 / 11 sec
----------
11 Arm(parameters={'lr': 0.1, 'decay': 0.0014867706919937198, 'epochs': 77})


100%|██████████| 75/75 [01:56<00:00,  1.52s/it]


91.4 / 10 sec
----------
12 Arm(parameters={'lr': 0.05025995239891894, 'decay': 0.0017713524339547954, 'epochs': 78})


100%|██████████| 77/77 [01:59<00:00,  1.57s/it]


90.3 / 11 sec
----------
13 Arm(parameters={'lr': 0.0689066000284994, 'decay': 0.00198557475386927, 'epochs': 82})


100%|██████████| 81/81 [02:04<00:00,  1.54s/it]


67.7 / 8 sec
----------
14 Arm(parameters={'lr': 0.03293699870473663, 'decay': 0.0024731563662717107, 'epochs': 90})


100%|██████████| 89/89 [02:01<00:00,  1.38s/it]


74.6 / 8 sec
----------
15 Arm(parameters={'lr': 0.05811432304506772, 'decay': 0.0027781964192298, 'epochs': 77})


100%|██████████| 75/75 [01:46<00:00,  1.46s/it]


71.0 / 8 sec
----------
16 Arm(parameters={'lr': 0.045972301814203344, 'decay': 0.0013639257353456765, 'epochs': 65})


100%|██████████| 63/63 [01:39<00:00,  1.57s/it]


95.2 / 11 sec
----------
17 Arm(parameters={'lr': 0.09030546574920537, 'decay': 0.0013564204087068111, 'epochs': 65})


100%|██████████| 63/63 [01:29<00:00,  1.24s/it]


92.0 / 11 sec
----------
18 Arm(parameters={'lr': 0.055816927378238644, 'decay': 0.0010293252620656157, 'epochs': 63})


100%|██████████| 61/61 [01:34<00:00,  1.55s/it]


91.8 / 10 sec
----------
19 Arm(parameters={'lr': 0.05379056153462837, 'decay': 0.0009920240891720343, 'epochs': 66})


100%|██████████| 65/65 [01:41<00:00,  1.60s/it]


99.3 / 11 sec


In [16]:
# Checkpoint the full experiment (control + Sobol + GPEI trials) under the
# same file name the "Load experiment" cell reads. The previous name
# 'opt_0_100.json' did not match that cell and appears to belong to a
# different (games 0-100) optimization run.
save(exp, 'finetune1.json')

In [17]:
# Refit the GP on every trial; this `gpei` is the model the contour and slice
# plots use. Then list each trial's observed objective mean and parameters.
gpei = Models.GPEI(experiment=exp, data=exp.eval())
for idx, tr in exp.trials.items():
    mean_value = tr.fetch_data().df.at[0, 'mean']
    summary_line = '{0}: {1:.3f} {2}'.format(
        idx,
        mean_value,
        tr.arm.parameters,
    )
    print(summary_line)

0: 101.400 {'lr': 0.06, 'decay': 0.002, 'epochs': 70}
1: 64.725 {'lr': 0.07839425700208687, 'decay': 0.004649107668526686, 'epochs': 71}
2: 54.450 {'lr': 0.016511722904021142, 'decay': 0.00041288715787430073, 'epochs': 83}
3: 64.700 {'lr': 0.007976956609636175, 'decay': 0.0028356278671029336, 'epochs': 45}
4: 61.675 {'lr': 0.003719857998352154, 'decay': 0.0006385193350709473, 'epochs': 61}
5: 85.850 {'lr': 0.018280684548419125, 'decay': 0.0013376593059985793, 'epochs': 28}
6: 57.125 {'lr': 0.036598483632419736, 'decay': 0.00012315357925586642, 'epochs': 57}
7: 62.500 {'lr': 0.0017964235797745013, 'decay': 0.009421615689971193, 'epochs': 94}
8: 61.475 {'lr': 0.0027460956900874065, 'decay': 0.0005059898389563055, 'epochs': 44}
9: 87.200 {'lr': 0.0426709388481696, 'decay': 0.002247218136513755, 'epochs': 87}
10: 58.850 {'lr': 0.02800566925168594, 'decay': 0.0002677039277472593, 'epochs': 78}
11: 80.475 {'lr': 0.0685188376679517, 'decay': 0.0016482494131049618, 'epochs': 60}
12: 87.650 {'l

### Load experiment

In [4]:
# Restore a previously saved experiment from disk.
loaded_experiment = load('finetune1.json')
# The evaluation callable is reattached by hand, presumably because `load`
# does not restore it; without it, new trials cannot be evaluated.
loaded_experiment.evaluation_function = eval_fn
exp = loaded_experiment

## Plot results

In [18]:
# `optimization_trace_single_method` expects a 2-D array of means: one row per
# optimization run, which it averages over. This is a single run, so the
# per-trial means are wrapped in an outer list. The trace shows the best
# objective seen so far (a running maximum, since the objective is maximized).
per_trial_means = [trial.objective_mean for trial in exp.trials.values()]
objective_means = np.array([per_trial_means])
running_best = np.maximum.accumulate(objective_means, axis=1)
best_objective_plot = optimization_trace_single_method(y=running_best)
render(best_objective_plot)

In [19]:
# argmax over the flattened (1, n_trials) array gives the position of the best
# observed mean. This matches the trial index, assuming trials are numbered
# 0..n-1 in creation order.
best_trial_index = int(np.argmax(objective_means))
exp.trials[best_trial_index].arm

Arm(name='14_0', parameters={'lr': 0.06593370364780642, 'decay': 0.001681425245434636, 'epochs': 69})

In [20]:
# Model-predicted 'min_move' over the lr x decay plane, from the `gpei` model.
lr_decay_contour = plot_contour(model=gpei, param_x='lr', param_y='decay', metric_name='min_move')
render(lr_decay_contour)

In [21]:
# Model-predicted 'min_move' over the lr x epochs plane, from the `gpei` model.
lr_epochs_contour = plot_contour(model=gpei, param_x='lr', param_y='epochs', metric_name='min_move')
render(lr_epochs_contour)

In [22]:
# One-dimensional slice of the `gpei` model's prediction along lr.
lr_slice = plot_slice(model=gpei, param_name='lr', metric_name='min_move')
render(lr_slice)

In [23]:
# One-dimensional slice of the `gpei` model's prediction along decay.
decay_slice = plot_slice(model=gpei, param_name='decay', metric_name='min_move')
render(decay_slice)

In [24]:
# One-dimensional slice of the `gpei` model's prediction along epochs.
epochs_slice = plot_slice(model=gpei, param_name='epochs', metric_name='min_move')
render(epochs_slice)

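Results:
* After 30 optimization trials (10 Sobol + 20 GPEI, plus the control arm), the best observed arm was trial 14 (score 105.0): lr ≈ 0.066, decay ≈ 1.7e-3, epochs = 69.
* A rounded version of it was re-evaluated below: lr=0.06, decay=1.7e-3, epochs=69.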

In [25]:
# Re-evaluate a rounded version of the best configuration found by the search.
rounded_best = {'lr': 0.06, 'decay': 1.7e-3, 'epochs': 69}
print(eval_fn(rounded_best))

100%|██████████| 67/67 [01:34<00:00,  1.40s/it]


102.5 / 12 sec
102.525
