In [None]:
# export
from fastai2.basics import *
from fastai2.callback.tracker import TrackerCallback

import optuna
from optuna.trial import Trial

In [None]:
from fastai2.test_utils import *

In [None]:
# default_exp optuna

# Optuna
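> Optuna integration for fastai2: a callback that reports a metric to an Optuna trial after each epoch and stops training when the trial should be pruned.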

## OptunaPruningCallback

In [None]:
# export
def get_latest_metric_val(recorder: Recorder, monitor='valid_loss'):
    "Return the value of `monitor` from the most recent epoch stored in `recorder`."
    # The first entry of `metric_names` (the 'epoch' column of the printed table) has no
    # slot in each `values` row, so drop it to align names with recorded values.
    names = list(recorder.metric_names[1:])
    if monitor not in names:
        raise ValueError(f"Metric {monitor!r} not found in recorder; available metrics: {names}")
    if len(recorder.values) == 0:
        raise IndexError("Recorder has no recorded values yet; call this after at least one epoch.")
    return recorder.values[-1][names.index(monitor)]

In [None]:
# export
class OptunaPruningCallback(Callback):
    "Report `monitor` to an Optuna `trial` after every epoch and stop training if the trial should be pruned."
    # `after_epoch` reads the latest row of `recorder.values`, so it must run after `Recorder`
    # has appended the current epoch; declare that explicitly instead of relying on list order.
    run_after = Recorder
    def __init__(self, trial: Trial, monitor='valid_loss'):
        # trial: the Optuna trial being evaluated; monitor: a name from `recorder.metric_names`.
        self.trial,self.monitor = trial,monitor
    def after_epoch(self):
        "Report this epoch's `monitor` value at step `self.epoch`; raise `TrialPruned` if Optuna says to prune."
        val = get_latest_metric_val(self.recorder, self.monitor)
        # Convert to a plain Python float before reporting; `val` may be a tensor.
        self.trial.report(float(val), step=self.epoch)
        if self.trial.should_prune():
            raise optuna.exceptions.TrialPruned(f'Trial was pruned at epoch {self.epoch}.')

In [None]:
def objective(trial):
    "Train a synthetic learner with a sampled learning rate and return its final `mae` as a float."
    # Sample the learning rate log-uniformly between 1e-5 and 1e-1.
    lr = trial.suggest_loguniform('lr', 1e-5, 1e-1)
    # The pruning callback reports `mae` each epoch so Optuna can stop bad trials early.
    learn = synth_learner(metrics=[mae], cbs=[OptunaPruningCallback(trial, monitor='mae')])
    learn.fit(10, lr=lr)
    # Return a plain float, matching the `float(...)` value the pruning callback reports.
    return float(get_latest_metric_val(learn.recorder, monitor='mae'))

In [None]:
# Seed the sampler so the sequence of suggested learning rates is reproducible across runs
# (model training itself is still stochastic, so final values may vary).
sampler = optuna.samplers.TPESampler(seed=42)
# Prune trials whose reported values lag behind earlier trials (see Optuna's MedianPruner docs).
pruner = optuna.pruners.MedianPruner()
# The objective returns an error metric (mae), so lower is better.
study = optuna.create_study(direction='minimize', sampler=sampler, pruner=pruner)

In [None]:
# Run at most 10 trials; `timeout` (seconds, per Optuna's `Study.optimize`) caps total search time.
study.optimize(
    objective,
    timeout=600,
    n_trials=10,
)

epoch,train_loss,valid_loss,mae,time
0,1.637815,1.656043,1.274559,00:00
1,1.608574,1.591213,1.248966,00:00
2,1.572296,1.518652,1.219556,00:00
3,1.531279,1.445926,1.189448,00:00
4,1.485566,1.375029,1.159439,00:00
5,1.438598,1.307544,1.130156,00:00
6,1.390133,1.243132,1.101429,00:00
7,1.341516,1.182062,1.073486,00:00
8,1.293058,1.124138,1.046248,00:00
9,1.244529,1.06905,1.019774,00:00


[I 2020-02-12 06:32:06,147] Finished trial#0 resulted in value: 1.0197744369506836. Current best value is 1.0197744369506836 with parameters: {'lr': 0.00012794600961090692}.


epoch,train_loss,valid_loss,mae,time
0,15.789175,17.256245,3.589227,00:00
1,12.829947,9.464449,2.661068,00:00
2,9.803594,4.432116,1.824718,00:00
3,7.363648,1.766372,1.161035,00:00
4,5.505668,0.666084,0.715211,00:00
5,4.159391,0.230255,0.417638,00:00
6,3.182611,0.075911,0.242488,00:00
7,2.465482,0.028719,0.150401,00:00
8,1.931119,0.015501,0.105461,00:00
9,1.526555,0.012239,0.083911,00:00


[I 2020-02-12 06:32:07,534] Finished trial#1 resulted in value: 0.08391130715608597. Current best value is 0.08391130715608597 with parameters: {'lr': 0.0015643689297346257}.


epoch,train_loss,valid_loss,mae,time
0,11.797703,2.936768,1.67774,00:00
1,5.762062,0.639909,0.690653,00:00
2,3.909071,0.737906,0.801102,00:00
3,2.742692,0.044243,0.163628,00:00
4,1.974672,0.095877,0.260683,00:00
5,1.478591,0.031447,0.155098,00:00
6,1.125162,0.008943,0.077742,00:00
7,0.872022,0.007271,0.069367,00:00
8,0.684309,0.007061,0.070259,00:00
9,0.542195,0.0079,0.074678,00:00


[I 2020-02-12 06:32:08,896] Finished trial#2 resulted in value: 0.07467806339263916. Current best value is 0.07467806339263916 with parameters: {'lr': 0.007709428142656766}.


epoch,train_loss,valid_loss,mae,time
0,15.669979,14.285971,3.649371,00:00
1,14.118828,11.007122,3.197709,00:00
2,12.321001,8.009432,2.721652,00:00
3,10.554479,5.70624,2.286328,00:00
4,8.940463,4.000061,1.905309,00:00
5,7.516487,2.773105,1.57998,00:00
6,6.284949,1.926517,1.307925,00:00
7,5.236779,1.328461,1.080087,00:00
8,4.350972,0.916054,0.890552,00:00
9,3.607452,0.630851,0.733182,00:00


[I 2020-02-12 06:32:10,293] Finished trial#3 resulted in value: 0.7331821918487549. Current best value is 0.07467806339263916 with parameters: {'lr': 0.007709428142656766}.


epoch,train_loss,valid_loss,mae,time
0,11.497913,8.971736,2.958515,00:00
1,8.327281,3.316296,1.802006,00:00
2,5.698774,0.717773,0.839866,00:00
3,3.926574,0.070004,0.245621,00:00
4,2.805215,0.013914,0.095343,00:00
5,2.079527,0.02771,0.138601,00:00
6,1.582646,0.024546,0.130341,00:00
7,1.225652,0.016083,0.104618,00:00
8,0.960965,0.010809,0.085508,00:00
9,0.760585,0.009075,0.07867,00:00


[I 2020-02-12 06:32:11,638] Finished trial#4 resulted in value: 0.07867028564214706. Current best value is 0.07467806339263916 with parameters: {'lr': 0.007709428142656766}.


epoch,train_loss,valid_loss,mae,time
0,9.949349,9.068306,3.005578,00:00
1,9.069424,7.22947,2.68338,00:00


[I 2020-02-12 06:32:11,889] Setting status of trial#5 as TrialState.PRUNED. Trial was pruned at epoch 1.


epoch,train_loss,valid_loss,mae,time
0,7.5067,7.182096,2.193205,00:00
1,7.510965,7.107755,2.181266,00:00


[I 2020-02-12 06:32:12,230] Setting status of trial#6 as TrialState.PRUNED. Trial was pruned at epoch 1.


epoch,train_loss,valid_loss,mae,time
0,34.832165,49.574306,6.002351,00:00
1,34.498737,48.996902,5.967054,00:00


[I 2020-02-12 06:32:12,578] Setting status of trial#7 as TrialState.PRUNED. Trial was pruned at epoch 1.


epoch,train_loss,valid_loss,mae,time
0,14.700743,15.213393,3.700629,00:00
1,14.567232,14.954995,3.668853,00:00


[I 2020-02-12 06:32:12,914] Setting status of trial#8 as TrialState.PRUNED. Trial was pruned at epoch 1.


epoch,train_loss,valid_loss,mae,time
0,3.317702,3.035458,1.626831,00:00
1,2.085379,0.990242,0.951244,00:00
2,1.401456,0.330082,0.535652,00:00
3,0.982392,0.094028,0.285634,00:00
4,0.712013,0.025687,0.139004,00:00
5,0.531641,0.012934,0.092764,00:00
6,0.406463,0.008834,0.07719,00:00
7,0.316081,0.008302,0.074111,00:00
8,0.24927,0.007729,0.070687,00:00
9,0.1987,0.007735,0.072065,00:00


[I 2020-02-12 06:32:14,321] Finished trial#9 resulted in value: 0.07206450402736664. Current best value is 0.07206450402736664 with parameters: {'lr': 0.05228637126518103}.


## Export -

In [None]:
# hide
# Convert the project's notebooks into library modules (the "Converted ..." lines below
# list each notebook processed).
from nbdev.export import notebook2script
notebook2script()

Converted 01_data.core.ipynb.
Converted 02_pytorch.transformer.ipynb.
Converted 03_pytorch.model.ipynb.
Converted 04_optuna.ipynb.
Converted index.ipynb.
