# Fine tune on games 60-80

In [1]:
import numpy as np
# import pandas as pd
from time import sleep

from ax import *

from ax.plot.contour import plot_contour
from ax.plot.slice import plot_slice
from ax.plot.trace import optimization_trace_single_method
from ax.utils.notebook.plotting import render, init_notebook_plotting
# Inject the Plotly library into the notebook so that render() can display Ax plots.
init_notebook_plotting()

# import evaluation function defined locally as the
# mean logscore of 1000 eval games after training
from finetune_ax import eval_fn

[INFO 07-01 10:44:53] ipy_plotting: Injecting Plotly library into cell. Do not overwrite or delete cell.


2724


## Define parameters
Network parameters
* `channels`: Defaults to 16 (256 in paper)
* `num_blocks`: Defaults to 6 (19 in paper)
* `out_c`: Defaults to 4 (2 in paper). Keep this fixed for now.

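Training parameters
* `lr`: Defaults to 0.1.
* `decay`: Defaults to 1e-4.
* `momentum`: Defaults to 0.9. Keep this fixed for now.
* `epochs`: Number of fine-tuning epochs, searched from 0 to 50. An `epochs` value of 0 means no fine-tuning (the "zero" baseline arm below).

Only `lr`, `decay` and `epochs` are searched below. The network parameters and `momentum` are not part of the search space.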

In [2]:
# Search space for fine-tuning:
#   - lr and decay are continuous and searched on a log scale (they span orders of magnitude);
#   - epochs is an integer on a linear scale; the "zero" arm below uses 0 as a no-fine-tuning baseline.
lr_param = RangeParameter(
    name="lr",
    parameter_type=ParameterType.FLOAT,
    lower=1e-4,
    upper=0.125,
    log_scale=True,
)
dec_param = RangeParameter(
    name="decay",
    parameter_type=ParameterType.FLOAT,
    lower=1e-5,
    upper=1e-2,
    log_scale=True,
)
epox_param = RangeParameter(
    name="epochs",
    parameter_type=ParameterType.INT,
    lower=0,
    upper=50,
)
search_space = SearchSpace(parameters=[lr_param, dec_param, epox_param])

## Define experiment and initial arms

In [3]:
# Wrap the locally defined eval_fn in an Ax experiment.
# Its single objective, 'log_eval', is maximized.
exp = SimpleExperiment(
    name='optimize_finetune',
    search_space=search_space,
    evaluation_function=eval_fn,
    objective_name='log_eval',
    minimize=False,
)

In [4]:
# Couldn't get status_quo to work, so the two reference arms are added manually as trials 0 and 1:
#   "zero"    - same lr/decay as control but 0 fine-tuning epochs
#   "control" - lr=0.125, decay=0.0012, 20 epochs
baseline_arms = [
    Arm(name="zero", parameters={'lr': 0.125, 'decay': 0.0012, 'epochs': 0}),
    Arm(name="control", parameters={'lr': 0.125, 'decay': 0.0012, 'epochs': 20}),
]
for baseline_arm in baseline_arms:
    exp.new_trial().add_arm(baseline_arm)

# Evaluate both reference trials (the last expression is displayed by the notebook).
exp.trials[0].fetch_data()
exp.trials[1].fetch_data()

0it [00:00, ?it/s]
  0%|          | 0/20 [00:00<?, ?it/s]

3.340 / 0.011


100%|██████████| 20/20 [00:26<00:00,  1.32s/it]


3.155 / 0.012


<ax.core.data.Data at 0x7f3e06175400>

In [5]:
# Quasi-random (Sobol) initialization: space-filling trials that give the GPEI model
# below its initial data. The generator is seeded so a Restart & Run All regenerates
# the same arms instead of a fresh random draw.
N_SOBOL_TRIALS = 20
SOBOL_SEED = 0
sobol = Models.SOBOL(exp.search_space, seed=SOBOL_SEED)
for step in range(N_SOBOL_TRIALS):
    # One arm at a time: each trial is evaluated (and logged) before the next arm is drawn.
    gen = sobol.gen(1)
    print('-'*10)
    print(gen.arms[0])
    sleep(1)  # NOTE(review): presumably lets stdout flush before eval_fn's progress bar starts - confirm
    trial = exp.new_trial(generator_run=gen)
    trial.fetch_data()

----------
Arm(parameters={'lr': 0.08222266530738154, 'decay': 0.000787696268305935, 'epochs': 28})


100%|██████████| 28/28 [00:37<00:00,  1.34s/it]


3.177 / 0.013
----------
Arm(parameters={'lr': 0.017995068324623793, 'decay': 2.0432724797792104e-05, 'epochs': 44})


100%|██████████| 44/44 [00:59<00:00,  1.35s/it]


3.235 / 0.012
----------
Arm(parameters={'lr': 0.00024507982427034327, 'decay': 0.0017904906868920903, 'epochs': 3})


100%|██████████| 3/3 [00:04<00:00,  1.34s/it]


3.394 / 0.011
----------
Arm(parameters={'lr': 0.00017137166229077094, 'decay': 3.5115084834239404e-05, 'epochs': 34})


100%|██████████| 34/34 [00:45<00:00,  1.36s/it]


3.454 / 0.011
----------
Arm(parameters={'lr': 0.004551821441836866, 'decay': 0.007375404959949084, 'epochs': 19})


100%|██████████| 19/19 [00:25<00:00,  1.37s/it]


3.632 / 0.013
----------
Arm(parameters={'lr': 0.023212195123507594, 'decay': 8.455812526476041e-05, 'epochs': 9})


100%|██████████| 9/9 [00:12<00:00,  1.36s/it]


3.580 / 0.014
----------
Arm(parameters={'lr': 0.0012195918445462937, 'decay': 0.0005595295912680022, 'epochs': 44})


100%|██████████| 44/44 [01:00<00:00,  1.38s/it]


3.656 / 0.012
----------
Arm(parameters={'lr': 0.0006781060687356291, 'decay': 1.0946176570293072e-05, 'epochs': 0})


0it [00:00, ?it/s]


3.319 / 0.011
----------
Arm(parameters={'lr': 0.04978427747990496, 'decay': 0.003527565121958421, 'epochs': 47})


100%|██████████| 47/47 [01:03<00:00,  1.35s/it]


3.112 / 0.012
----------
Arm(parameters={'lr': 0.006247040048381291, 'decay': 0.0001438024293565146, 'epochs': 25})


100%|██████████| 25/25 [00:33<00:00,  1.38s/it]


3.424 / 0.014
----------
Arm(parameters={'lr': 0.00014838731399692169, 'decay': 0.0014697324908619514, 'epochs': 22})


100%|██████████| 22/22 [00:30<00:00,  1.36s/it]


3.416 / 0.011
----------
Arm(parameters={'lr': 0.0005519527802397052, 'decay': 0.00010246759190872975, 'epochs': 41})


100%|██████████| 41/41 [00:55<00:00,  1.36s/it]


3.546 / 0.013
----------
Arm(parameters={'lr': 0.010504083634447867, 'decay': 0.00043748352890575796, 'epochs': 12})


100%|██████████| 12/12 [00:16<00:00,  1.36s/it]


3.634 / 0.013
----------
Arm(parameters={'lr': 0.07475959993275257, 'decay': 4.4930071267683065e-05, 'epochs': 16})


100%|██████████| 16/16 [00:21<00:00,  1.36s/it]


3.361 / 0.013
----------
Arm(parameters={'lr': 0.0028143370815851038, 'decay': 0.006083735575886048, 'epochs': 37})


100%|██████████| 37/37 [00:50<00:00,  1.35s/it]


3.528 / 0.012
----------
Arm(parameters={'lr': 0.0031722460970575166, 'decay': 5.153458431222948e-05, 'epochs': 48})


100%|██████████| 48/48 [01:05<00:00,  1.38s/it]


3.357 / 0.012
----------
Arm(parameters={'lr': 0.061217137158285154, 'decay': 0.005093566161426102, 'epochs': 2})


100%|██████████| 2/2 [00:02<00:00,  1.37s/it]


3.168 / 0.012
----------
Arm(parameters={'lr': 0.012302214003413503, 'decay': 0.0001241083471404148, 'epochs': 24})


100%|██████████| 24/24 [00:32<00:00,  1.36s/it]


3.311 / 0.013
----------
Arm(parameters={'lr': 0.00045671365891730595, 'decay': 0.0003864138666601675, 'epochs': 27})


100%|██████████| 27/27 [00:36<00:00,  1.36s/it]


3.493 / 0.012
----------
Arm(parameters={'lr': 0.00011818864786637608, 'decay': 0.00017291991110745218, 'epochs': 11})


100%|██████████| 11/11 [00:15<00:00,  1.38s/it]


3.413 / 0.011


In [6]:
# Checkpoint the experiment to JSON; it can be restored later with load() (see "Load experiment").
checkpoint_path = 'finetune1.json'
save(exp, checkpoint_path)

In [7]:
# Summarize every trial so far, one line each: trial index, objective mean, arm parameters.
# No GPEI model is fit here: nothing in this cell uses it, and the optimization loop
# below refits the model itself on every iteration.
for i, trial in exp.trials.items():
    mean = trial.fetch_data().df.at[0, 'mean']
    print('{0}: {1:.3f} {2}'.format(i, mean, trial.arm.parameters))

0: 3.340 {'lr': 0.125, 'decay': 0.0012, 'epochs': 0}
1: 3.155 {'lr': 0.125, 'decay': 0.0012, 'epochs': 20}
2: 3.177 {'lr': 0.08222266530738154, 'decay': 0.000787696268305935, 'epochs': 28}
3: 3.235 {'lr': 0.017995068324623793, 'decay': 2.0432724797792104e-05, 'epochs': 44}
4: 3.394 {'lr': 0.00024507982427034327, 'decay': 0.0017904906868920903, 'epochs': 3}
5: 3.454 {'lr': 0.00017137166229077094, 'decay': 3.5115084834239404e-05, 'epochs': 34}
6: 3.632 {'lr': 0.004551821441836866, 'decay': 0.007375404959949084, 'epochs': 19}
7: 3.580 {'lr': 0.023212195123507594, 'decay': 8.455812526476041e-05, 'epochs': 9}
8: 3.656 {'lr': 0.0012195918445462937, 'decay': 0.0005595295912680022, 'epochs': 44}
9: 3.319 {'lr': 0.0006781060687356291, 'decay': 1.0946176570293072e-05, 'epochs': 0}
10: 3.112 {'lr': 0.04978427747990496, 'decay': 0.003527565121958421, 'epochs': 47}
11: 3.424 {'lr': 0.006247040048381291, 'decay': 0.0001438024293565146, 'epochs': 25}
12: 3.416 {'lr': 0.00014838731399692169, 'decay': 

## Iterative optimization
Bayesian optimization with Gaussian-process expected improvement (GPEI), using the BoTorch backend. The model is refit on all data collected so far before each new trial is generated.

In [16]:
# Sequential Bayesian optimization: on every step, refit GPEI on all data
# collected so far, generate one new arm, and evaluate it before the next refit.
for step in range(30):
    gpei = Models.GPEI(experiment=exp, data=exp.eval())
    generator_run = gpei.gen(1)
    next_arm = generator_run.arms[0]
    print('-' * 10)
    print(next_arm)
    sleep(1)  # NOTE(review): presumably lets stdout flush before eval_fn's progress bar starts - confirm
    trial = exp.new_trial(generator_run=generator_run)
    trial.fetch_data()

----------
Arm(parameters={'lr': 0.0042958675936845365, 'decay': 0.0010471843418864235, 'epochs': 10})


100%|██████████| 10/10 [00:13<00:00,  1.36s/it]


3.678 / 0.013
----------
Arm(parameters={'lr': 0.008370803588215753, 'decay': 0.00022333890726881868, 'epochs': 2})


100%|██████████| 2/2 [00:02<00:00,  1.36s/it]


3.519 / 0.012
----------
Arm(parameters={'lr': 0.003423667093075686, 'decay': 0.0013399111889739398, 'epochs': 15})


100%|██████████| 15/15 [00:20<00:00,  1.35s/it]


3.650 / 0.012
----------
Arm(parameters={'lr': 0.0006869113033726057, 'decay': 0.0018229361036326858, 'epochs': 45})


100%|██████████| 45/45 [00:59<00:00,  1.32s/it]


3.598 / 0.013
----------
Arm(parameters={'lr': 0.00558261633492821, 'decay': 0.0020691025936057034, 'epochs': 12})


100%|██████████| 12/12 [00:15<00:00,  1.32s/it]


3.594 / 0.014
----------
Arm(parameters={'lr': 0.0038535517531817546, 'decay': 0.00039530815909391445, 'epochs': 11})


100%|██████████| 11/11 [00:14<00:00,  1.33s/it]


3.604 / 0.012
----------
Arm(parameters={'lr': 0.001831685554727031, 'decay': 0.0062428424383467575, 'epochs': 21})


100%|██████████| 21/21 [00:28<00:00,  1.35s/it]


3.612 / 0.012
----------
Arm(parameters={'lr': 0.0013184302548125534, 'decay': 0.0008283421858509831, 'epochs': 37})


100%|██████████| 37/37 [00:50<00:00,  1.34s/it]


3.617 / 0.013
----------
Arm(parameters={'lr': 0.0005830066799163658, 'decay': 0.0004613016523184007, 'epochs': 49})


100%|██████████| 49/49 [01:06<00:00,  1.37s/it]


3.599 / 0.011
----------
Arm(parameters={'lr': 0.0023223184893015874, 'decay': 0.0016108860028932247, 'epochs': 9})


100%|██████████| 9/9 [00:12<00:00,  1.34s/it]


3.559 / 0.012
----------
Arm(parameters={'lr': 0.005391083307693113, 'decay': 0.0008003947035889143, 'epochs': 12})


100%|██████████| 12/12 [00:16<00:00,  1.36s/it]


3.625 / 0.012
----------
Arm(parameters={'lr': 0.003103488748160297, 'decay': 0.0035133869602214274, 'epochs': 22})


100%|██████████| 22/22 [00:29<00:00,  1.36s/it]


3.642 / 0.015
----------
Arm(parameters={'lr': 0.0012152946378868432, 'decay': 0.0011010052894948998, 'epochs': 50})


100%|██████████| 50/50 [01:07<00:00,  1.33s/it]


3.668 / 0.013
----------
Arm(parameters={'lr': 0.0030470682447321057, 'decay': 0.01, 'epochs': 17})


100%|██████████| 17/17 [00:22<00:00,  1.37s/it]


3.695 / 0.014
----------
Arm(parameters={'lr': 0.003323542918335707, 'decay': 0.01, 'epochs': 12})


100%|██████████| 12/12 [00:16<00:00,  1.34s/it]


3.568 / 0.013
----------
Arm(parameters={'lr': 0.0029671541634396707, 'decay': 0.00999999999999998, 'epochs': 20})


100%|██████████| 20/20 [00:26<00:00,  1.34s/it]


3.647 / 0.013
----------
Arm(parameters={'lr': 0.0028992131921908664, 'decay': 0.00484521577137573, 'epochs': 17})


100%|██████████| 17/17 [00:23<00:00,  1.35s/it]


3.631 / 0.013
----------
Arm(parameters={'lr': 0.0015930102849367376, 'decay': 0.0012974881820559708, 'epochs': 46})


100%|██████████| 46/46 [01:02<00:00,  1.37s/it]


3.624 / 0.013
----------
Arm(parameters={'lr': 0.0011762941184528637, 'decay': 0.0005560489900774055, 'epochs': 50})


100%|██████████| 50/50 [01:07<00:00,  1.36s/it]


3.632 / 0.013
----------
Arm(parameters={'lr': 0.002157058720142707, 'decay': 0.0008397548096370134, 'epochs': 24})


100%|██████████| 24/24 [00:32<00:00,  1.35s/it]


3.616 / 0.013
----------
Arm(parameters={'lr': 0.0011339380519406353, 'decay': 0.0036838855554993935, 'epochs': 50})


100%|██████████| 50/50 [01:07<00:00,  1.36s/it]


3.639 / 0.014
----------
Arm(parameters={'lr': 0.000966822231346963, 'decay': 0.0014159222420817687, 'epochs': 50})


100%|██████████| 50/50 [01:07<00:00,  1.35s/it]


3.667 / 0.013
----------
Arm(parameters={'lr': 0.04044537228104681, 'decay': 1.0773867568563373e-05, 'epochs': 0})


0it [00:00, ?it/s]


3.308 / 0.012
----------
Arm(parameters={'lr': 0.016038689473977068, 'decay': 0.00033224162007139364, 'epochs': 8})


100%|██████████| 8/8 [00:10<00:00,  1.37s/it]


3.620 / 0.013
----------
Arm(parameters={'lr': 0.0014196493120919636, 'decay': 0.0001894540113910872, 'epochs': 34})


100%|██████████| 34/34 [00:45<00:00,  1.37s/it]


3.614 / 0.013
----------
Arm(parameters={'lr': 0.00034276999307719663, 'decay': 0.00999999999999998, 'epochs': 50})


100%|██████████| 50/50 [01:08<00:00,  1.35s/it]


3.527 / 0.012
----------
Arm(parameters={'lr': 0.005938148149707883, 'decay': 0.0007840983805538117, 'epochs': 7})


100%|██████████| 7/7 [00:09<00:00,  1.37s/it]


3.648 / 0.012
----------
Arm(parameters={'lr': 0.001031454624264051, 'decay': 0.0009518252724083013, 'epochs': 48})


100%|██████████| 48/48 [01:05<00:00,  1.39s/it]


3.622 / 0.013
----------
Arm(parameters={'lr': 0.0015862531604133104, 'decay': 0.00034996139778338385, 'epochs': 40})


100%|██████████| 40/40 [00:54<00:00,  1.33s/it]


3.637 / 0.013
----------
Arm(parameters={'lr': 0.0001, 'decay': 0.0007786274174831826, 'epochs': 50})


100%|██████████| 50/50 [01:08<00:00,  1.34s/it]


3.438 / 0.012


In [17]:
# Overwrite the checkpoint with the experiment including all GPEI trials.
checkpoint_path = 'finetune1.json'
save(exp, checkpoint_path)

In [18]:
# Fit GPEI on all trials; this model is the one used by the contour/slice plots below.
gpei = Models.GPEI(experiment=exp, data=exp.eval())

# One line per trial: trial index, objective mean, arm parameters.
for i, trial in exp.trials.items():
    objective = trial.fetch_data().df.at[0, 'mean']
    print('{0}: {1:.3f} {2}'.format(i, objective, trial.arm.parameters))

0: 3.340 {'lr': 0.125, 'decay': 0.0012, 'epochs': 0}
1: 3.155 {'lr': 0.125, 'decay': 0.0012, 'epochs': 20}
2: 3.177 {'lr': 0.08222266530738154, 'decay': 0.000787696268305935, 'epochs': 28}
3: 3.235 {'lr': 0.017995068324623793, 'decay': 2.0432724797792104e-05, 'epochs': 44}
4: 3.394 {'lr': 0.00024507982427034327, 'decay': 0.0017904906868920903, 'epochs': 3}
5: 3.454 {'lr': 0.00017137166229077094, 'decay': 3.5115084834239404e-05, 'epochs': 34}
6: 3.632 {'lr': 0.004551821441836866, 'decay': 0.007375404959949084, 'epochs': 19}
7: 3.580 {'lr': 0.023212195123507594, 'decay': 8.455812526476041e-05, 'epochs': 9}
8: 3.656 {'lr': 0.0012195918445462937, 'decay': 0.0005595295912680022, 'epochs': 44}
9: 3.319 {'lr': 0.0006781060687356291, 'decay': 1.0946176570293072e-05, 'epochs': 0}
10: 3.112 {'lr': 0.04978427747990496, 'decay': 0.003527565121958421, 'epochs': 47}
11: 3.424 {'lr': 0.006247040048381291, 'decay': 0.0001438024293565146, 'epochs': 25}
12: 3.416 {'lr': 0.00014838731399692169, 'decay': 

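### Load experiment
Resume from the saved `finetune1.json` checkpoint (e.g. in a fresh kernel). The evaluation function is re-attached by hand after loading.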

In [4]:
# Restore the checkpointed experiment and re-attach the evaluation function.
# NOTE(review): presumably eval_fn is not stored in the JSON checkpoint - confirm.
restored_exp = load('finetune1.json')
restored_exp.evaluation_function = eval_fn
exp = restored_exp

## Plot results

In [19]:
# `optimization_trace_single_method` expects a 2-D array of means (one row per
# optimization run, so several runs can be averaged). There is a single run here,
# so the per-trial objective means are wrapped in an extra outer list.
objective_means = np.array([[trial.objective_mean for trial in exp.trials.values()]])

# Best value seen so far after each trial (running maximum along the trial axis).
best_so_far = np.maximum.accumulate(objective_means, axis=1)

best_objective_plot = optimization_trace_single_method(y=best_so_far)
render(best_objective_plot)

In [20]:
# Arm of the best trial so far. Taking the max over the trials themselves avoids
# depending on `objective_means` from the previous cell, and on trial indices being
# exactly 0..n-1. Ties resolve to the first trial, as np.argmax does.
max(exp.trials.values(), key=lambda t: t.objective_mean).arm

Arm(name='35_0', parameters={'lr': 0.0030470682447321057, 'decay': 0.01, 'epochs': 17})

In [21]:
# GPEI model's prediction of log_eval over the (lr, decay) plane.
lr_decay_contour = plot_contour(model=gpei, metric_name='log_eval', param_x='lr', param_y='decay')
render(lr_decay_contour)

The best `decay` is about the same as in regular training.

In [22]:
# GPEI model's prediction of log_eval over the (epochs, lr) plane.
epochs_lr_contour = plot_contour(model=gpei, metric_name='log_eval', param_x='epochs', param_y='lr')
render(epochs_lr_contour)

Fine-tuning favors a much lower `lr` than regular training, and fewer epochs. The more epochs, the lower the best `lr`.

In [23]:
# GPEI model's prediction of log_eval over the (epochs, decay) plane.
epochs_decay_contour = plot_contour(model=gpei, metric_name='log_eval', param_x='epochs', param_y='decay')
render(epochs_decay_contour)

In [24]:
# One-dimensional slice of the model's log_eval prediction along lr.
lr_slice = plot_slice(model=gpei, metric_name='log_eval', param_name='lr')
render(lr_slice)

In [25]:
# One-dimensional slice of the model's log_eval prediction along decay.
decay_slice = plot_slice(model=gpei, metric_name='log_eval', param_name='decay')
render(decay_slice)

In [26]:
# One-dimensional slice of the model's log_eval prediction along epochs.
epochs_slice = plot_slice(model=gpei, metric_name='log_eval', param_name='epochs')
render(epochs_slice)

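Decided to go with trial 22 (`lr` ≈ 0.0043, `decay` ≈ 0.0010, `epochs` = 10; score 3.678). It has a short schedule and a decay similar to the control arm's. Settings carried forward: `epochs` = 10, `lr` = 0.0043, `decay` = 0.0012.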