# Fine tune on games 80-100

In [1]:
import numpy as np
# import pandas as pd
from time import sleep

# NOTE(review): star import. The cells below rely on it for RangeParameter,
# ParameterType, SearchSpace, SimpleExperiment, Arm, Models, save and load,
# none of which are imported explicitly anywhere else in this notebook.
from ax import *

from ax.plot.contour import plot_contour
from ax.plot.slice import plot_slice
from ax.plot.trace import optimization_trace_single_method
from ax.utils.notebook.plotting import render, init_notebook_plotting
# Prepare the notebook for Ax's Plotly-based figures (see the log line in the
# cell output); must run before any render(...) call.
init_notebook_plotting()

# import evaluation function defined locally as the
# mean logscore of 1000 eval games after training
from finetune_ax import eval_fn

[INFO 07-10 08:36:51] ipy_plotting: Injecting Plotly library into cell. Do not overwrite or delete cell.


2724


## Define parameters
Training parameters
* `lr`: Previously 0.0043 (search range 1e-4 to 1e-2, log scale)
* `decay`: Previously 0.0012 (search range 1e-4 to 1e-2, log scale), same as initial training.
* `epochs`: Previously 10 (search range 0 to 20, integer)

In [2]:
# `lr` and `decay` are both float parameters searched on a log scale over the
# same [1e-4, 1e-2] interval, so they share a single set of range settings.
log_float_range = dict(
    parameter_type=ParameterType.FLOAT,
    lower=1e-4,
    upper=1e-2,
    log_scale=True,
)
lr_param = RangeParameter(name="lr", **log_float_range)
dec_param = RangeParameter(name="decay", **log_float_range)

# Number of fine-tuning epochs: an integer in [0, 20] on a linear scale.
epox_param = RangeParameter(
    name="epochs",
    parameter_type=ParameterType.INT,
    lower=0,
    upper=20,
)

# The three parameters above span the full space explored by the optimizer.
search_space = SearchSpace(parameters=[lr_param, dec_param, epox_param])

## Define experiment and initial arms

In [3]:
# Single-objective experiment over `search_space`: every arm is scored by
# `eval_fn` and the result is tracked under the 'log_eval' objective.
exp = SimpleExperiment(
    name='optimize_finetune',
    objective_name='log_eval',
    # Higher log score is better, so this objective is maximized.
    minimize=False,
    search_space=search_space,
    evaluation_function=eval_fn,
)

In [5]:
# Couldn't get status_quo to work, so the control arms are added manually.
# All three reuse the previous lr/decay and differ only in the epoch count.
previous_params = {'lr': 0.0043, 'decay': 0.0012}
control_arm_epochs = [("zero", 0), ("control1", 5), ("control2", 10)]

# Create all control trials first (same order as before), keeping a handle to
# each one so that evaluation does not depend on hard-coded trial indices.
control_trials = []
for arm_name, n_epochs in control_arm_epochs:
    control_trial = exp.new_trial()
    control_trial.add_arm(
        Arm(name=arm_name, parameters=dict(previous_params, epochs=n_epochs))
    )
    control_trials.append(control_trial)

# Evaluate each control arm. Going through the trial objects (rather than
# exp.trials[0..2]) stays correct even if `exp` already contains trials.
for control_trial in control_trials:
    control_trial.fetch_data()

0it [00:00, ?it/s]
  0%|          | 0/5 [00:00<?, ?it/s]

3.636 / 0.012


100%|██████████| 5/5 [00:07<00:00,  1.54s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

3.767 / 0.014


100%|██████████| 10/10 [00:16<00:00,  1.68s/it]


3.794 / 0.014


<ax.core.data.Data at 0x7efeac559240>

Fine-tuning with the previous parameters already gives a benefit (the score rises from 3.636 with 0 epochs to 3.794 with 10 epochs), although this isn't guaranteed.

In [6]:
# Seed the experiment with quasi-random Sobol points before fitting a GP.
# NOTE(review): no seed is passed to the Sobol generator, so a re-run will
# presumably draw different arms than the ones recorded below - confirm.
sobol = Models.SOBOL(exp.search_space)
n_sobol_trials = 20
for _ in range(n_sobol_trials):
    generator_run = sobol.gen(1)
    print('-' * 10)
    print(generator_run.arms[0])
    # Short pause between printing the arm and starting its evaluation.
    sleep(1)
    # Attach the generated arm as a new trial and evaluate it right away.
    exp.new_trial(generator_run=generator_run).fetch_data()

----------
Arm(parameters={'lr': 0.006719886554586096, 'decay': 0.0025712614367976697, 'epochs': 17})


100%|██████████| 17/17 [00:28<00:00,  1.70s/it]


3.528 / 0.014
----------
Arm(parameters={'lr': 0.0012230710776559032, 'decay': 0.0001245490767846484, 'epochs': 14})


100%|██████████| 14/14 [00:24<00:00,  1.75s/it]


3.722 / 0.014
----------
Arm(parameters={'lr': 0.00015703189811414284, 'decay': 0.004250773039859609, 'epochs': 6})


100%|██████████| 6/6 [00:10<00:00,  1.73s/it]


3.656 / 0.014
----------
Arm(parameters={'lr': 0.00027734698125494284, 'decay': 0.00024503393137916316, 'epochs': 19})


100%|██████████| 19/19 [00:32<00:00,  1.71s/it]


3.720 / 0.014
----------
Arm(parameters={'lr': 0.00215089431162011, 'decay': 0.007181061645086981, 'epochs': 2})


100%|██████████| 2/2 [00:03<00:00,  1.75s/it]


3.672 / 0.013
----------
Arm(parameters={'lr': 0.003669268451658114, 'decay': 0.0004685768597399909, 'epochs': 8})


100%|██████████| 8/8 [00:13<00:00,  1.71s/it]


3.754 / 0.014
----------
Arm(parameters={'lr': 0.0005083524213359429, 'decay': 0.0011885017054668588, 'epochs': 11})


100%|██████████| 11/11 [00:19<00:00,  1.73s/it]


3.726 / 0.014
----------
Arm(parameters={'lr': 0.0003885480703922011, 'decay': 0.00017572911480920575, 'epochs': 7})


100%|██████████| 7/7 [00:12<00:00,  1.73s/it]


3.715 / 0.014
----------
Arm(parameters={'lr': 0.004987883132707978, 'decay': 0.003342016822981182, 'epochs': 14})


100%|██████████| 14/14 [00:24<00:00,  1.74s/it]


3.740 / 0.013
----------
Arm(parameters={'lr': 0.0029168282988909335, 'decay': 0.0009197168866264231, 'epochs': 16})


100%|██████████| 16/16 [00:27<00:00,  1.73s/it]


3.719 / 0.015
----------
Arm(parameters={'lr': 0.00021152642014361876, 'decay': 0.0020209888929120954, 'epochs': 3})


100%|██████████| 3/3 [00:05<00:00,  1.73s/it]


3.638 / 0.014
----------
Arm(parameters={'lr': 0.0001155375773105539, 'decay': 0.0003570341502470353, 'epochs': 11})


100%|██████████| 11/11 [00:19<00:00,  1.75s/it]


3.682 / 0.014
----------
Arm(parameters={'lr': 0.0016000375420588211, 'decay': 0.001624692156473758, 'epochs': 9})


100%|██████████| 9/9 [00:15<00:00,  1.71s/it]


3.732 / 0.014
----------
Arm(parameters={'lr': 0.008811602322531434, 'decay': 0.000186652607756163, 'epochs': 1})


100%|██████████| 1/1 [00:01<00:00,  1.78s/it]


3.623 / 0.014
----------
Arm(parameters={'lr': 0.0006830909901579769, 'decay': 0.009813793275866853, 'epochs': 18})


100%|██████████| 18/18 [00:31<00:00,  1.73s/it]


3.773 / 0.013
----------
Arm(parameters={'lr': 0.0006154371046006404, 'decay': 0.00021817967393688832, 'epochs': 15})


100%|██████████| 15/15 [00:25<00:00,  1.71s/it]


3.716 / 0.014
----------
Arm(parameters={'lr': 0.008526969624032996, 'decay': 0.008521992697599054, 'epochs': 7})


100%|██████████| 7/7 [00:12<00:00,  1.71s/it]


3.729 / 0.015
----------
Arm(parameters={'lr': 0.0014841675115592483, 'decay': 0.00041652190590026186, 'epochs': 3})


100%|██████████| 3/3 [00:05<00:00,  1.70s/it]


3.686 / 0.014
----------
Arm(parameters={'lr': 0.00011512656911212999, 'decay': 0.0014096383116111051, 'epochs': 15})


100%|██████████| 15/15 [00:26<00:00,  1.75s/it]


3.658 / 0.014
----------
Arm(parameters={'lr': 0.00019974886922630847, 'decay': 0.0007873981650974259, 'epochs': 10})


100%|██████████| 10/10 [00:17<00:00,  1.73s/it]


3.685 / 0.014


In [7]:
# Checkpoint the experiment (control + Sobol trials) to disk.
checkpoint_path = 'finetune2.json'
save(exp, checkpoint_path)

In [8]:
# Fit the GP model on everything evaluated so far, then list each trial's
# index, mean score (3 decimals) and arm parameters.
gpei = Models.GPEI(experiment=exp, data=exp.eval())
for trial_index, trial in exp.trials.items():
    mean_score = trial.fetch_data().df.at[0, 'mean']
    arm_parameters = trial.arm.parameters
    print('{0}: {1:.3f} {2}'.format(trial_index, mean_score, arm_parameters))

0: 3.636 {'lr': 0.0043, 'decay': 0.0012, 'epochs': 0}
1: 3.767 {'lr': 0.0043, 'decay': 0.0012, 'epochs': 5}
2: 3.794 {'lr': 0.0043, 'decay': 0.0012, 'epochs': 10}
3: 3.528 {'lr': 0.006719886554586096, 'decay': 0.0025712614367976697, 'epochs': 17}
4: 3.722 {'lr': 0.0012230710776559032, 'decay': 0.0001245490767846484, 'epochs': 14}
5: 3.656 {'lr': 0.00015703189811414284, 'decay': 0.004250773039859609, 'epochs': 6}
6: 3.720 {'lr': 0.00027734698125494284, 'decay': 0.00024503393137916316, 'epochs': 19}
7: 3.672 {'lr': 0.00215089431162011, 'decay': 0.007181061645086981, 'epochs': 2}
8: 3.754 {'lr': 0.003669268451658114, 'decay': 0.0004685768597399909, 'epochs': 8}
9: 3.726 {'lr': 0.0005083524213359429, 'decay': 0.0011885017054668588, 'epochs': 11}
10: 3.715 {'lr': 0.0003885480703922011, 'decay': 0.00017572911480920575, 'epochs': 7}
11: 3.740 {'lr': 0.004987883132707978, 'decay': 0.003342016822981182, 'epochs': 14}
12: 3.719 {'lr': 0.0029168282988909335, 'decay': 0.0009197168866264231, 'epoch

## Iterative optimization
BoTorch backend for Gaussian process expected improvement (GPEI) Bayesian optimization.

In [17]:
# Sequential Bayesian optimization: refit the GP on all available data before
# every step, ask it for one new arm, and evaluate that arm immediately.
n_gpei_trials = 30
for _ in range(n_gpei_trials):
    gpei = Models.GPEI(experiment=exp, data=exp.eval())
    generator_run = gpei.gen(1)
    print('-' * 10)
    print(generator_run.arms[0])
    # Short pause between printing the arm and starting its evaluation.
    sleep(1)
    exp.new_trial(generator_run=generator_run).fetch_data()

----------
Arm(parameters={'lr': 0.003125001755326415, 'decay': 0.002214232914119949, 'epochs': 12})


100%|██████████| 12/12 [00:20<00:00,  1.74s/it]


3.761 / 0.014
----------
Arm(parameters={'lr': 0.006628480497069864, 'decay': 0.0016183755027265353, 'epochs': 11})


100%|██████████| 11/11 [00:18<00:00,  1.71s/it]


3.653 / 0.015
----------
Arm(parameters={'lr': 0.004157984233597562, 'decay': 0.0037367267717074706, 'epochs': 9})


100%|██████████| 9/9 [00:15<00:00,  1.73s/it]


3.738 / 0.015
----------
Arm(parameters={'lr': 0.003984743419781169, 'decay': 0.0012914146741416816, 'epochs': 13})


100%|██████████| 13/13 [00:22<00:00,  1.75s/it]


3.733 / 0.014
----------
Arm(parameters={'lr': 0.0044755527574204075, 'decay': 0.0008770153419057366, 'epochs': 8})


100%|██████████| 8/8 [00:13<00:00,  1.72s/it]


3.726 / 0.015
----------
Arm(parameters={'lr': 0.0032933923143772794, 'decay': 0.0008111133357022278, 'epochs': 10})


100%|██████████| 10/10 [00:17<00:00,  1.72s/it]


3.738 / 0.015
----------
Arm(parameters={'lr': 0.004072644764672069, 'decay': 0.002303923665865714, 'epochs': 11})


100%|██████████| 11/11 [00:18<00:00,  1.73s/it]


3.674 / 0.014
----------
Arm(parameters={'lr': 0.005617095681160452, 'decay': 0.001755380644007118, 'epochs': 10})


100%|██████████| 10/10 [00:17<00:00,  1.72s/it]


3.711 / 0.015
----------
Arm(parameters={'lr': 0.004289613569715087, 'decay': 0.0008474777542033326, 'epochs': 10})


100%|██████████| 10/10 [00:17<00:00,  1.74s/it]


3.712 / 0.014
----------
Arm(parameters={'lr': 0.002904618553551574, 'decay': 0.0012288470288522085, 'epochs': 10})


100%|██████████| 10/10 [00:17<00:00,  1.73s/it]


3.780 / 0.014
----------
Arm(parameters={'lr': 0.0037889689157116543, 'decay': 0.0012146824301619034, 'epochs': 9})


100%|██████████| 9/9 [00:15<00:00,  1.72s/it]


3.721 / 0.015
----------
Arm(parameters={'lr': 0.002438642927999432, 'decay': 0.0025012259718864807, 'epochs': 13})


100%|██████████| 13/13 [00:22<00:00,  1.71s/it]


3.775 / 0.014
----------
Arm(parameters={'lr': 0.0012440216784983, 'decay': 0.0021844963220914236, 'epochs': 12})


100%|██████████| 12/12 [00:20<00:00,  1.73s/it]


3.741 / 0.014
----------
Arm(parameters={'lr': 0.006013150355723934, 'decay': 0.0028060288008963583, 'epochs': 13})


100%|██████████| 13/13 [00:22<00:00,  1.74s/it]


3.615 / 0.015
----------
Arm(parameters={'lr': 0.002683255708419544, 'decay': 0.001773505945887894, 'epochs': 11})


100%|██████████| 11/11 [00:18<00:00,  1.72s/it]


3.752 / 0.014
----------
Arm(parameters={'lr': 0.002245383380378057, 'decay': 0.0017013206868280336, 'epochs': 12})


100%|██████████| 12/12 [00:20<00:00,  1.73s/it]


3.756 / 0.014
----------
Arm(parameters={'lr': 0.002593592510211966, 'decay': 0.0035303285210389617, 'epochs': 14})


100%|██████████| 14/14 [00:24<00:00,  1.72s/it]


3.742 / 0.014
----------
Arm(parameters={'lr': 0.0034429988535537165, 'decay': 0.0012532430884026151, 'epochs': 10})


100%|██████████| 10/10 [00:17<00:00,  1.74s/it]


3.742 / 0.014
----------
Arm(parameters={'lr': 0.004597064302145103, 'decay': 0.00129254493547221, 'epochs': 8})


100%|██████████| 8/8 [00:13<00:00,  1.75s/it]


3.757 / 0.014
----------
Arm(parameters={'lr': 0.0026063679668893846, 'decay': 0.002122253915865578, 'epochs': 16})


100%|██████████| 16/16 [00:27<00:00,  1.74s/it]


3.758 / 0.015
----------
Arm(parameters={'lr': 0.002160904220770353, 'decay': 0.0025059118846846645, 'epochs': 15})


100%|██████████| 15/15 [00:25<00:00,  1.68s/it]


3.751 / 0.014
----------
Arm(parameters={'lr': 0.0027091769723221605, 'decay': 0.001036927347367017, 'epochs': 7})


100%|██████████| 7/7 [00:12<00:00,  1.74s/it]


3.719 / 0.015
----------
Arm(parameters={'lr': 0.0007273873451560561, 'decay': 0.007323840961065265, 'epochs': 17})


100%|██████████| 17/17 [00:29<00:00,  1.74s/it]


3.754 / 0.014
----------
Arm(parameters={'lr': 0.0030079183198763126, 'decay': 0.0014661394936190625, 'epochs': 12})


100%|██████████| 12/12 [00:20<00:00,  1.74s/it]


3.722 / 0.016
----------
Arm(parameters={'lr': 0.0027302036324266493, 'decay': 0.002615729671691876, 'epochs': 10})


100%|██████████| 10/10 [00:17<00:00,  1.75s/it]


3.807 / 0.013
----------
Arm(parameters={'lr': 0.0025249974010564885, 'decay': 0.0025410187000322456, 'epochs': 8})


100%|██████████| 8/8 [00:13<00:00,  1.65s/it]


3.765 / 0.014
----------
Arm(parameters={'lr': 0.002818156640431079, 'decay': 0.002693597623753263, 'epochs': 12})


100%|██████████| 12/12 [00:20<00:00,  1.70s/it]


3.756 / 0.014
----------
Arm(parameters={'lr': 0.0029062328771851335, 'decay': 0.002520316199161385, 'epochs': 8})


100%|██████████| 8/8 [00:13<00:00,  1.73s/it]


3.794 / 0.014
----------
Arm(parameters={'lr': 0.0028739319116947606, 'decay': 0.0031450826461652663, 'epochs': 8})


100%|██████████| 8/8 [00:13<00:00,  1.75s/it]


3.758 / 0.014
----------
Arm(parameters={'lr': 0.002788867139708552, 'decay': 0.002315431049546887, 'epochs': 9})


100%|██████████| 9/9 [00:15<00:00,  1.75s/it]


3.747 / 0.015


In [18]:
# Checkpoint the experiment again, now including the GPEI trials.
checkpoint_path = 'finetune2.json'
save(exp, checkpoint_path)

In [19]:
# Refit the GP on the complete data set (this `gpei` is the model used by the
# plots further down), then list every trial's index, mean score and parameters.
gpei = Models.GPEI(experiment=exp, data=exp.eval())
for trial_index, trial in exp.trials.items():
    mean_score = trial.fetch_data().df.at[0, 'mean']
    arm_parameters = trial.arm.parameters
    print('{0}: {1:.3f} {2}'.format(trial_index, mean_score, arm_parameters))

0: 3.636 {'lr': 0.0043, 'decay': 0.0012, 'epochs': 0}
1: 3.767 {'lr': 0.0043, 'decay': 0.0012, 'epochs': 5}
2: 3.794 {'lr': 0.0043, 'decay': 0.0012, 'epochs': 10}
3: 3.528 {'lr': 0.006719886554586096, 'decay': 0.0025712614367976697, 'epochs': 17}
4: 3.722 {'lr': 0.0012230710776559032, 'decay': 0.0001245490767846484, 'epochs': 14}
5: 3.656 {'lr': 0.00015703189811414284, 'decay': 0.004250773039859609, 'epochs': 6}
6: 3.720 {'lr': 0.00027734698125494284, 'decay': 0.00024503393137916316, 'epochs': 19}
7: 3.672 {'lr': 0.00215089431162011, 'decay': 0.007181061645086981, 'epochs': 2}
8: 3.754 {'lr': 0.003669268451658114, 'decay': 0.0004685768597399909, 'epochs': 8}
9: 3.726 {'lr': 0.0005083524213359429, 'decay': 0.0011885017054668588, 'epochs': 11}
10: 3.715 {'lr': 0.0003885480703922011, 'decay': 0.00017572911480920575, 'epochs': 7}
11: 3.740 {'lr': 0.004987883132707978, 'decay': 0.003342016822981182, 'epochs': 14}
12: 3.719 {'lr': 0.0029168282988909335, 'decay': 0.0009197168866264231, 'epoch

### Load experiment

In [4]:
# Restore a previously saved experiment from disk.
saved_experiment_path = 'finetune2.json'
exp = load(saved_experiment_path)
# Re-attach the evaluation function to the loaded experiment
# (presumably it is not stored in the JSON file - confirm).
exp.evaluation_function = eval_fn

## Plot results

In [20]:
# Per-trial objective means in trial order, wrapped in an outer list so the
# array has a single row.
objective_means = np.array([[trial.objective_mean for trial in exp.trials.values()]])
# Running maximum along the trial axis: best score seen up to each trial.
best_so_far = np.maximum.accumulate(objective_means, axis=1)
best_objective_plot = optimization_trace_single_method(y=best_so_far)
render(best_objective_plot)

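Couldn't really beat the previous parameters: the best trial scored 3.807, versus 3.794 for the control arm (`control2`) that uses the previous settings.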

In [21]:
# Position of the highest objective mean; argmax flattens the single-row array,
# so the result is a position in trial order.
# NOTE(review): using it as a key assumes trial indices run 0..N-1 without gaps.
best_trial_index = int(np.argmax(objective_means))
exp.trials[best_trial_index].arm

Arm(name='47_0', parameters={'lr': 0.0027302036324266493, 'decay': 0.002615729671691876, 'epochs': 10})

In [22]:
# Contour of the model's 'log_eval' over lr (x-axis) and decay (y-axis).
lr_decay_contour = plot_contour(model=gpei, param_x='lr', param_y='decay', metric_name='log_eval')
render(lr_decay_contour)

In [23]:
# Contour of the model's 'log_eval' over epochs (x-axis) and lr (y-axis).
epochs_lr_contour = plot_contour(model=gpei, param_x='epochs', param_y='lr', metric_name='log_eval')
render(epochs_lr_contour)

In [24]:
# Contour of the model's 'log_eval' over epochs (x-axis) and decay (y-axis).
epochs_decay_contour = plot_contour(model=gpei, param_x='epochs', param_y='decay', metric_name='log_eval')
render(epochs_decay_contour)

In [25]:
# One-dimensional slice of the model's 'log_eval' along lr.
lr_slice = plot_slice(model=gpei, param_name='lr', metric_name='log_eval')
render(lr_slice)

In [26]:
# One-dimensional slice of the model's 'log_eval' along decay.
decay_slice = plot_slice(model=gpei, param_name='decay', metric_name='log_eval')
render(decay_slice)

In [27]:
# One-dimensional slice of the model's 'log_eval' along epochs.
epochs_slice = plot_slice(model=gpei, param_name='epochs', metric_name='log_eval')
render(epochs_slice)

In [29]:
# Re-evaluate a rounded version of the best arm found by the search so it can
# be compared against the control arm that uses the previous parameters.
best_trial = exp.new_trial()
best_trial.add_arm(Arm(name="best", parameters={'lr': 0.0027, 'decay': 0.0026, 'epochs': 10}))
# Fetch through the trial handle instead of the hard-coded exp.trials[53],
# which is only correct when exactly 53 trials already exist in `exp`.
best_trial.fetch_data()

100%|██████████| 10/10 [00:16<00:00,  1.69s/it]


3.789 / 0.014


<ax.core.data.Data at 0x7efde00904a8>

The original parameters still seem better (3.794 for the control arm vs. 3.789 for the rounded best arm), so I will stick with them: `epochs=10`, `lr=0.0043`, `decay=0.0012`.