# Fine tune on games 80-100

In [1]:
import numpy as np
# import pandas as pd
from time import sleep

from ax import *

from ax.plot.contour import plot_contour
from ax.plot.slice import plot_slice
from ax.plot.trace import optimization_trace_single_method
from ax.utils.notebook.plotting import render, init_notebook_plotting
init_notebook_plotting()

# import evaluation function defined locally as the
# mean logscore of 1000 eval games after training
from finetune_ax import eval_fn

[INFO 07-15 10:31:52] ipy_plotting: Injecting Plotly library into cell. Do not overwrite or delete cell.


## Define parameters
Training parameters
* `lr`: Previously 0.0043 (1e-3 to 1e-2), same as previous fine tuning.
* `decay`: Previously 0.0012 (1e-3 to 1e-1), same as initial training.
    -Adjusted due to initial results
* `epochs`: Previously 10 (0 to 20), same as previous fine tuning.

In [2]:
lr_param = RangeParameter(name="lr", parameter_type=ParameterType.FLOAT,
                             lower=8e-4,
                             upper=8e-3,
                             log_scale=True,
                             )
dec_param = RangeParameter(name="decay", parameter_type=ParameterType.FLOAT,
                             lower=1e-3,
                             upper=1e-1,
                             log_scale=True,
                             )
epox_param = RangeParameter(name="epochs", parameter_type=ParameterType.INT,
                             lower=0,
                             upper=15,
                             )
search_space = SearchSpace(
    parameters=[lr_param, dec_param, epox_param],
)

## Define experiment and initial arms

In [3]:
exp = SimpleExperiment(
    name='optimize_finetune',
    search_space=search_space,
    evaluation_function=eval_fn,
    objective_name='log_eval',
    minimize=False,  # MAXIMIZE!!
)

In [4]:
# Couldn't get status_quo to work. Adding control arm manually:
exp.new_trial().add_arm(Arm(name="zero",parameters={'lr': 0.0043, 'decay': 0.0012, 'epochs': 0}))
exp.new_trial().add_arm(Arm(name="control1",parameters={'lr': 0.0043, 'decay': 0.0012, 'epochs': 5}))
exp.new_trial().add_arm(Arm(name="control2",parameters={'lr': 0.0043, 'decay': 0.0012, 'epochs': 10}))
exp.new_trial().add_arm(Arm(name="previous",parameters={'lr': 0.00715, 'decay': 0.01, 'epochs': 6}))
exp.trials[0].fetch_data()
exp.trials[1].fetch_data()
exp.trials[2].fetch_data()

0it [00:00, ?it/s]
  0%|          | 0/5 [00:00<?, ?it/s]

3.934 / 0.009 / 40 sec


100%|██████████| 5/5 [00:10<00:00,  2.16s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

3.997 / 0.009 / 43 sec


100%|██████████| 10/10 [00:21<00:00,  2.21s/it]


3.902 / 0.010 / 38 sec


<ax.core.data.Data at 0x7ff983793a58>

In [5]:
exp.trials[3].fetch_data()

100%|██████████| 6/6 [00:13<00:00,  2.20s/it]


3.995 / 0.008 / 42 sec


<ax.core.data.Data at 0x7ff8ffe32278>

Fine tuning with previous parameters already gives benefit (this isn't guaranteed).

In [6]:
sobol = Models.SOBOL(exp.search_space)
for i in range(20):
    gen = sobol.gen(1)
    print('-'*10)
    print(i, gen.arms[0])
    sleep(1)
    trial = exp.new_trial(generator_run=gen)
    trial.fetch_data()

----------
0 Arm(parameters={'lr': 0.0011934390082424268, 'decay': 0.06686069314247507, 'epochs': 1})


100%|██████████| 1/1 [00:02<00:00,  2.17s/it]


3.968 / 0.008 / 40 sec
----------
1 Arm(parameters={'lr': 0.0019558212574087366, 'decay': 0.0011824514174953505, 'epochs': 5})


100%|██████████| 5/5 [00:10<00:00,  2.13s/it]


3.983 / 0.009 / 42 sec
----------
2 Arm(parameters={'lr': 0.005551494394407805, 'decay': 0.010283260945966472, 'epochs': 8})


100%|██████████| 8/8 [00:17<00:00,  2.18s/it]


3.938 / 0.010 / 40 sec
----------
3 Arm(parameters={'lr': 0.007662923912231201, 'decay': 0.002680738706964375, 'epochs': 3})


100%|██████████| 3/3 [00:06<00:00,  2.18s/it]


4.011 / 0.009 / 44 sec
----------
4 Arm(parameters={'lr': 0.0016311940380312328, 'decay': 0.030532060346985843, 'epochs': 13})


100%|██████████| 13/13 [00:28<00:00,  2.14s/it]


4.033 / 0.008 / 45 sec
----------
5 Arm(parameters={'lr': 0.0008013008433305626, 'decay': 0.005499813737312483, 'epochs': 11})


100%|██████████| 11/11 [00:24<00:00,  2.19s/it]


3.996 / 0.009 / 45 sec
----------
6 Arm(parameters={'lr': 0.004044864308475928, 'decay': 0.04707769638039004, 'epochs': 6})


100%|██████████| 6/6 [00:13<00:00,  2.21s/it]


4.020 / 0.008 / 44 sec
----------
7 Arm(parameters={'lr': 0.0037989071059585247, 'decay': 0.0013841435750271724, 'epochs': 9})


100%|██████████| 9/9 [00:19<00:00,  2.19s/it]


3.921 / 0.010 / 39 sec
----------
8 Arm(parameters={'lr': 0.0010034284778267387, 'decay': 0.01619763305775833, 'epochs': 4})


100%|██████████| 4/4 [00:08<00:00,  2.17s/it]


3.978 / 0.008 / 42 sec
----------
9 Arm(parameters={'lr': 0.0017365582947129383, 'decay': 0.008977215067036219, 'epochs': 1})


100%|██████████| 1/1 [00:02<00:00,  2.11s/it]


3.976 / 0.009 / 45 sec
----------
10 Arm(parameters={'lr': 0.006116741263131368, 'decay': 0.07895861915890333, 'epochs': 11})


100%|██████████| 11/11 [00:23<00:00,  2.19s/it]


3.614 / 0.006 / 23 sec
----------
11 Arm(parameters={'lr': 0.00512507486326056, 'decay': 0.003491511312417426, 'epochs': 7})


100%|██████████| 7/7 [00:15<00:00,  2.17s/it]


3.998 / 0.009 / 43 sec
----------
12 Arm(parameters={'lr': 0.0024074724487850477, 'decay': 0.04021334838836546, 'epochs': 9})


100%|██████████| 9/9 [00:19<00:00,  2.19s/it]


4.026 / 0.008 / 45 sec
----------
13 Arm(parameters={'lr': 0.0012932799645160857, 'decay': 0.002270241551766799, 'epochs': 15})


100%|██████████| 15/15 [00:32<00:00,  2.18s/it]


3.965 / 0.009 / 41 sec
----------
14 Arm(parameters={'lr': 0.0025618335265802736, 'decay': 0.019564295308741823, 'epochs': 2})


100%|██████████| 2/2 [00:04<00:00,  2.14s/it]


4.014 / 0.008 / 42 sec
----------
15 Arm(parameters={'lr': 0.002774570008448762, 'decay': 0.00197962435500844, 'epochs': 4})


100%|██████████| 4/4 [00:08<00:00,  2.16s/it]


3.969 / 0.009 / 41 sec
----------
16 Arm(parameters={'lr': 0.0014007643909409292, 'decay': 0.02253605918351418, 'epochs': 9})


100%|██████████| 9/9 [00:19<00:00,  2.20s/it]


4.012 / 0.009 / 43 sec
----------
17 Arm(parameters={'lr': 0.0022177364780740904, 'decay': 0.004060196864519902, 'epochs': 12})


100%|██████████| 12/12 [00:26<00:00,  2.19s/it]


4.000 / 0.009 / 45 sec
----------
18 Arm(parameters={'lr': 0.004721522756697272, 'decay': 0.03473938680518261, 'epochs': 2})


100%|██████████| 2/2 [00:04<00:00,  2.21s/it]


4.016 / 0.008 / 43 sec
----------
19 Arm(parameters={'lr': 0.006507217278079819, 'decay': 0.007829091848627883, 'epochs': 10})


100%|██████████| 10/10 [00:20<00:00,  2.18s/it]


3.970 / 0.008 / 42 sec


In [7]:
save(exp, 'finetune3.json')

In [8]:
gpei = Models.GPEI(experiment=exp, data=exp.eval())
for i, trial in exp.trials.items():
    print('{0}: {1:.3f} {2}'.format(i,
        trial.fetch_data().df.at[0, 'mean'],
        trial.arm.parameters
    ))

0: 3.934 {'lr': 0.0043, 'decay': 0.0012, 'epochs': 0}
1: 3.997 {'lr': 0.0043, 'decay': 0.0012, 'epochs': 5}
2: 3.902 {'lr': 0.0043, 'decay': 0.0012, 'epochs': 10}
3: 3.995 {'lr': 0.00715, 'decay': 0.01, 'epochs': 6}
4: 3.968 {'lr': 0.0011934390082424268, 'decay': 0.06686069314247507, 'epochs': 1}
5: 3.983 {'lr': 0.0019558212574087366, 'decay': 0.0011824514174953505, 'epochs': 5}
6: 3.938 {'lr': 0.005551494394407805, 'decay': 0.010283260945966472, 'epochs': 8}
7: 4.011 {'lr': 0.007662923912231201, 'decay': 0.002680738706964375, 'epochs': 3}
8: 4.033 {'lr': 0.0016311940380312328, 'decay': 0.030532060346985843, 'epochs': 13}
9: 3.996 {'lr': 0.0008013008433305626, 'decay': 0.005499813737312483, 'epochs': 11}
10: 4.020 {'lr': 0.004044864308475928, 'decay': 0.04707769638039004, 'epochs': 6}
11: 3.921 {'lr': 0.0037989071059585247, 'decay': 0.0013841435750271724, 'epochs': 9}
12: 3.978 {'lr': 0.0010034284778267387, 'decay': 0.01619763305775833, 'epochs': 4}
13: 3.976 {'lr': 0.00173655829471293

## Iterative optimization
Botorch backend for gaussian process expected improvement (GPEI) bayesian optimization

In [17]:
for i in range(30):
    gpei = Models.GPEI(experiment=exp, data=exp.eval())
    gen = gpei.gen(1)
    print('-'*10)
    print(i, gen.arms[0])
    sleep(1)
    trial = exp.new_trial(generator_run=gen)
    trial.fetch_data()

----------
0 Arm(parameters={'lr': 0.0027287701230472123, 'decay': 0.047467561333178455, 'epochs': 5})


100%|██████████| 5/5 [00:10<00:00,  2.11s/it]


4.008 / 0.009 / 46 sec
----------
1 Arm(parameters={'lr': 0.0018620520236365425, 'decay': 0.014504572431582102, 'epochs': 11})


100%|██████████| 11/11 [00:23<00:00,  2.14s/it]


4.018 / 0.009 / 47 sec
----------
2 Arm(parameters={'lr': 0.0012156310523465015, 'decay': 0.08578899579530594, 'epochs': 11})


100%|██████████| 11/11 [00:24<00:00,  2.19s/it]


4.034 / 0.008 / 45 sec
----------
3 Arm(parameters={'lr': 0.0009034861576172048, 'decay': 0.043318245856435324, 'epochs': 14})


100%|██████████| 14/14 [00:30<00:00,  2.19s/it]


4.037 / 0.008 / 45 sec
----------
4 Arm(parameters={'lr': 0.0011657755539881888, 'decay': 0.1, 'epochs': 15})


100%|██████████| 15/15 [00:32<00:00,  2.19s/it]


4.031 / 0.008 / 45 sec
----------
5 Arm(parameters={'lr': 0.008000000000000004, 'decay': 0.002288829527071239, 'epochs': 7})


100%|██████████| 7/7 [00:15<00:00,  2.20s/it]


3.922 / 0.010 / 40 sec
----------
6 Arm(parameters={'lr': 0.005128081369884944, 'decay': 0.009714020504312918, 'epochs': 3})


100%|██████████| 3/3 [00:06<00:00,  2.20s/it]


4.003 / 0.009 / 44 sec
----------
7 Arm(parameters={'lr': 0.0012034966711127215, 'decay': 0.05331576453873375, 'epochs': 13})


100%|██████████| 13/13 [00:28<00:00,  2.13s/it]


4.035 / 0.008 / 46 sec
----------
8 Arm(parameters={'lr': 0.0011924465660416662, 'decay': 0.018078350513876462, 'epochs': 15})


100%|██████████| 15/15 [00:32<00:00,  2.16s/it]


4.026 / 0.008 / 45 sec
----------
9 Arm(parameters={'lr': 0.0008000000000000004, 'decay': 0.1, 'epochs': 13})


100%|██████████| 13/13 [00:28<00:00,  2.19s/it]


4.040 / 0.007 / 44 sec
----------
10 Arm(parameters={'lr': 0.0008000000000000004, 'decay': 0.09999999999999949, 'epochs': 9})


100%|██████████| 9/9 [00:19<00:00,  2.16s/it]


4.027 / 0.007 / 43 sec
----------
11 Arm(parameters={'lr': 0.0015388220487877905, 'decay': 0.1, 'epochs': 8})


100%|██████████| 8/8 [00:17<00:00,  2.20s/it]


4.041 / 0.007 / 45 sec
----------
12 Arm(parameters={'lr': 0.0016447015306304208, 'decay': 0.05740378606914854, 'epochs': 9})


100%|██████████| 9/9 [00:19<00:00,  2.14s/it]


4.034 / 0.008 / 44 sec
----------
13 Arm(parameters={'lr': 0.0008000000000000011, 'decay': 0.001, 'epochs': 7})


100%|██████████| 7/7 [00:15<00:00,  2.19s/it]


3.978 / 0.009 / 42 sec
----------
14 Arm(parameters={'lr': 0.008000000000000004, 'decay': 0.1, 'epochs': 0})


0it [00:00, ?it/s]


3.913 / 0.009 / 38 sec
----------
15 Arm(parameters={'lr': 0.0008000000000000004, 'decay': 0.034683012185556, 'epochs': 11})


100%|██████████| 11/11 [00:23<00:00,  2.15s/it]


4.045 / 0.008 / 46 sec
----------
16 Arm(parameters={'lr': 0.0008000000000000004, 'decay': 0.019416227604450564, 'epochs': 13})


100%|██████████| 13/13 [00:28<00:00,  2.20s/it]


4.006 / 0.009 / 43 sec
----------
17 Arm(parameters={'lr': 0.00212708645449816, 'decay': 0.01127820018448251, 'epochs': 15})


100%|██████████| 15/15 [00:32<00:00,  2.15s/it]


3.984 / 0.009 / 42 sec
----------
18 Arm(parameters={'lr': 0.0008000000000000004, 'decay': 0.1, 'epochs': 15})


100%|██████████| 15/15 [00:32<00:00,  2.16s/it]


4.028 / 0.008 / 45 sec
----------
19 Arm(parameters={'lr': 0.0008000000000000004, 'decay': 0.03401669027082794, 'epochs': 8})


100%|██████████| 8/8 [00:17<00:00,  2.17s/it]


4.023 / 0.008 / 44 sec



A not p.d., added jitter of 1e-08 to the diagonal



----------
20 Arm(parameters={'lr': 0.003633636002786291, 'decay': 0.028776448733151633, 'epochs': 3})


100%|██████████| 3/3 [00:06<00:00,  2.17s/it]


4.037 / 0.008 / 44 sec
----------
21 Arm(parameters={'lr': 0.008000000000000004, 'decay': 0.007147212988528556, 'epochs': 0})


0it [00:00, ?it/s]


3.924 / 0.009 / 38 sec
----------
22 Arm(parameters={'lr': 0.0008000000000000004, 'decay': 0.062132377301891986, 'epochs': 11})


100%|██████████| 11/11 [00:23<00:00,  2.17s/it]


4.036 / 0.008 / 46 sec
----------
23 Arm(parameters={'lr': 0.0008000000000000004, 'decay': 0.001, 'epochs': 0})


0it [00:00, ?it/s]


3.940 / 0.009 / 39 sec
----------
24 Arm(parameters={'lr': 0.004308203330778961, 'decay': 0.028771024643772, 'epochs': 4})


100%|██████████| 4/4 [00:08<00:00,  2.19s/it]


4.036 / 0.008 / 47 sec
----------
25 Arm(parameters={'lr': 0.0036967447563253904, 'decay': 0.1, 'epochs': 3})


100%|██████████| 3/3 [00:06<00:00,  2.21s/it]


3.940 / 0.007 / 37 sec
----------
26 Arm(parameters={'lr': 0.0010861821066179202, 'decay': 0.1, 'epochs': 6})


100%|██████████| 6/6 [00:12<00:00,  2.14s/it]


4.019 / 0.007 / 42 sec
----------
27 Arm(parameters={'lr': 0.0029168908105427667, 'decay': 0.02052006286485383, 'epochs': 6})


100%|██████████| 6/6 [00:13<00:00,  2.20s/it]


4.046 / 0.008 / 44 sec
----------
28 Arm(parameters={'lr': 0.0014333221894011758, 'decay': 0.003932877349377023, 'epochs': 8})


100%|██████████| 8/8 [00:17<00:00,  2.18s/it]


3.985 / 0.009 / 45 sec
----------
29 Arm(parameters={'lr': 0.0034402691960054464, 'decay': 0.024358234364913824, 'epochs': 5})


100%|██████████| 5/5 [00:10<00:00,  2.20s/it]


4.056 / 0.007 / 45 sec


In [18]:
save(exp, 'finetune3.json')

In [19]:
gpei = Models.GPEI(experiment=exp, data=exp.eval())
for i, trial in exp.trials.items():
    print('{0}: {1:.3f} {2}'.format(
        i,
        trial.fetch_data().df.at[0, 'mean'],
        trial.arm.parameters
    ))

0: 3.934 {'lr': 0.0043, 'decay': 0.0012, 'epochs': 0}
1: 3.997 {'lr': 0.0043, 'decay': 0.0012, 'epochs': 5}
2: 3.902 {'lr': 0.0043, 'decay': 0.0012, 'epochs': 10}
3: 3.995 {'lr': 0.00715, 'decay': 0.01, 'epochs': 6}
4: 3.968 {'lr': 0.0011934390082424268, 'decay': 0.06686069314247507, 'epochs': 1}
5: 3.983 {'lr': 0.0019558212574087366, 'decay': 0.0011824514174953505, 'epochs': 5}
6: 3.938 {'lr': 0.005551494394407805, 'decay': 0.010283260945966472, 'epochs': 8}
7: 4.011 {'lr': 0.007662923912231201, 'decay': 0.002680738706964375, 'epochs': 3}
8: 4.033 {'lr': 0.0016311940380312328, 'decay': 0.030532060346985843, 'epochs': 13}
9: 3.996 {'lr': 0.0008013008433305626, 'decay': 0.005499813737312483, 'epochs': 11}
10: 4.020 {'lr': 0.004044864308475928, 'decay': 0.04707769638039004, 'epochs': 6}
11: 3.921 {'lr': 0.0037989071059585247, 'decay': 0.0013841435750271724, 'epochs': 9}
12: 3.978 {'lr': 0.0010034284778267387, 'decay': 0.01619763305775833, 'epochs': 4}
13: 3.976 {'lr': 0.00173655829471293

### Load experiment

In [28]:
exp = load('finetune3.json')
exp.evaluation_function = eval_fn

## Plot results

In [20]:
objective_means = np.array([[trial.objective_mean for trial in exp.trials.values()]])
best_objective_plot = optimization_trace_single_method(
        y=np.maximum.accumulate(objective_means, axis=1),
)
render(best_objective_plot)

In [21]:
exp.trials[np.argmax(objective_means)].arm

Arm(name='53_0', parameters={'lr': 0.0034402691960054464, 'decay': 0.024358234364913824, 'epochs': 5})

In [22]:
render(plot_contour(model=gpei, param_x='lr', param_y='decay', metric_name='log_eval'))

In [23]:
render(plot_contour(model=gpei, param_x='epochs', param_y='lr', metric_name='log_eval'))

Epochs and lr are inversely correlated

In [24]:
render(plot_contour(model=gpei, param_x='epochs', param_y='decay', metric_name='log_eval'))

In [25]:
render(plot_slice(model=gpei, param_name='lr', metric_name='log_eval'))

In [26]:
render(plot_slice(model=gpei, param_name='decay', metric_name='log_eval'))

In [27]:
render(plot_slice(model=gpei, param_name='epochs', metric_name=4'log_eval'))

New parameters found: lr => 0.0034 (coming down), decay => 0.024 (much higher), epochs=>5

In [30]:
exp.trials[54].fetch_data()

100%|██████████| 5/5 [00:10<00:00,  2.15s/it]


4.036 / 0.009 / 46 sec


<ax.core.data.Data at 0x7ff8dcf40eb8>

In [31]:
exp.new_trial().add_arm(Arm(name="best",parameters={'lr': 0.0034, 'decay': 0.024, 'epochs': 6}))
exp.trials[55].fetch_data()

100%|██████████| 6/6 [00:12<00:00,  2.13s/it]


4.029 / 0.009 / 46 sec


<ax.core.data.Data at 0x7ff8dcf89dd8>

In [32]:
exp.new_trial().add_arm(Arm(name="best",parameters={'lr': 0.0027, 'decay': 0.027, 'epochs': 7}))
exp.trials[56].fetch_data()

100%|██████████| 7/7 [00:15<00:00,  2.17s/it]


4.025 / 0.008 / 43 sec


<ax.core.data.Data at 0x7ff8f80ef6d8>