# Visualizing Hyperparameter Relationships in a Jupyter Notebook

This notebook demonstrates one of Optuna's visualization utilities.
After optimizing the hyperparameters of a neural network, `plot_contour()` plots the relationships between hyperparameters across the completed trials of a study.

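**Note:** A trial that has no value for one of the plotted parameters is not shown in that plot. For example, `n_units_l1` is only suggested when `n_layers` is at least 2, so single-layer trials are omitted from any plot involving `n_units_l1`.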

## Setting Up the MNIST Dataset

In [1]:
import chainer
import numpy as np

N_TRAIN_EXAMPLES = 3000  # number of training images kept in the random subset
N_TEST_EXAMPLES = 1000  # number of test images kept in the random subset

# A fixed seed makes the sampled subsets identical on every run.
rng = np.random.RandomState(0)
train, test = chainer.datasets.get_mnist()

# Draw the training permutation first and the test permutation second, so the
# RandomState is consumed in a fixed order. Each permutation is taken over the
# full split before that split is replaced by its subset.
train_order = rng.permutation(len(train))
train = chainer.datasets.SubDataset(train, 0, N_TRAIN_EXAMPLES, order=train_order)

test_order = rng.permutation(len(test))
test = chainer.datasets.SubDataset(test, 0, N_TEST_EXAMPLES, order=test_order)

## Defining the Objective Function

In [2]:
import chainer.functions as F
import chainer.links as L

# Training-loop settings shared by every trial.
BATCHSIZE = 128  # mini-batch size of both the training and the test iterator
EPOCH = 10  # length of each trial's training run, in epochs
PRUNER_INTERVAL = 3  # epochs between successive reports to the pruner

def create_model(trial):
    """Build a fully connected classifier whose shape is sampled from ``trial``.

    The depth (``n_layers``, 1 to 3) is suggested first. Then one width
    ``n_units_l<i>`` is suggested per hidden layer, log-uniformly in [32, 256]
    and truncated to ``int``. Each hidden ``Linear`` layer is followed by a ReLU,
    and a final ``Linear`` layer produces 10 outputs.

    Args:
        trial: Optuna trial used to suggest the architecture hyperparameters.

    Returns:
        A ``chainer.Sequential`` holding the sampled layers.
    """
    depth = trial.suggest_int('n_layers', 1, 3)

    hidden = []
    for idx in range(depth):
        # Only layers that exist get a width suggestion, so deeper-layer
        # parameters are missing for shallower trials.
        width = int(trial.suggest_loguniform('n_units_l{}'.format(idx), 32, 256))
        hidden += [L.Linear(None, width), F.relu]

    return chainer.Sequential(*hidden, L.Linear(None, 10))


def objective(trial):
    """Train one sampled network on the MNIST subsets and return its final validation loss.

    The network comes from ``create_model(trial)``, wrapped in ``L.Classifier``
    and trained with Adam for ``EPOCH`` epochs on ``train``. It is evaluated on
    ``test``. Every ``PRUNER_INTERVAL`` epochs, ``'validation/main/loss'`` is
    handed to ``ChainerPruningExtension`` so the study's pruner can stop the
    trial early.

    Args:
        trial: The Optuna trial used to sample hyperparameters and to report
            intermediate values for pruning.

    Returns:
        The ``'validation/main/loss'`` entry of the last ``LogReport`` record.

    Raises:
        TrialPruned: propagated out of ``trainer.run`` when the pruning
            extension stops the trial. In that case no user attributes are
            recorded.
    """
    # Imported here as well, so this definition cell does not depend on the
    # later "Running the Optimization" cell having been executed first.
    import optuna

    model = L.Classifier(create_model(trial))
    optimizer = chainer.optimizers.Adam()
    optimizer.setup(model)

    train_iter = chainer.iterators.SerialIterator(train, BATCHSIZE)
    test_iter = chainer.iterators.SerialIterator(test, BATCHSIZE, repeat=False, shuffle=False)

    # Set up the trainer.
    updater = chainer.training.StandardUpdater(train_iter, optimizer)
    trainer = chainer.training.Trainer(updater, (EPOCH, 'epoch'))

    # Report the validation loss to Optuna every PRUNER_INTERVAL epochs so the
    # pruner can stop unpromising trials early.
    trainer.extend(
        optuna.integration.ChainerPruningExtension(trial, 'validation/main/loss',
                                                   (PRUNER_INTERVAL, 'epoch')))

    trainer.extend(chainer.training.extensions.Evaluator(test_iter, model))
    trainer.extend(
        chainer.training.extensions.PrintReport([
            'epoch', 'main/loss', 'validation/main/loss', 'main/accuracy',
            'validation/main/accuracy'
        ]))
    # Keep a handle on the LogReport so its in-memory log can be read after training.
    log_report_extension = chainer.training.extensions.LogReport(log_name=None)
    trainer.extend(log_report_extension)

    # Run training.
    # ChainerPruningExtension stops training by raising TrialPruned, and the
    # trainer prints a message every time it receives that exception.
    # Passing show_loop_exception_msg=False suppresses those messages.
    trainer.run(show_loop_exception_msg=False)

    # Save the final losses and accuracies to the trial's user attributes.
    log_last = log_report_extension.log[-1]
    for key, value in log_last.items():
        trial.set_user_attr(key, value)

    return log_last['validation/main/loss']

## Running the Optimization

In [3]:
import optuna

# This verbosity change is just to simplify the notebook output.
optuna.logging.set_verbosity(optuna.logging.WARNING)

N_TRIALS = 100  # number of trials the study runs
SAMPLER_SEED = 0  # seed for the hyperparameter sampler

# TPESampler is Optuna's default sampler; passing it explicitly lets us seed it.
# With a seed, the sampler's own randomness is reproducible across runs.
# However, nothing in this notebook seeds Chainer's weight initialization, so
# objective values (and therefore later suggestions and pruning decisions) can
# still differ between runs.
study = optuna.create_study(
    sampler=optuna.samplers.TPESampler(seed=SAMPLER_SEED),
    pruner=optuna.pruners.MedianPruner(),
)
study.optimize(objective, n_trials=N_TRIALS)

epoch       main/loss   validation/main/loss  main/accuracy  validation/main/accuracy
[J1           1.77157     1.06736               0.509766       0.781701                  
[J2           0.766593    0.514403              0.815897       0.869817                  
[J3           0.449314    0.388877              0.882487       0.889047                  
[J4           0.341216    0.336289              0.911345       0.907602                  
[J5           0.281409    0.298026              0.923828       0.914663                  
[J6           0.24313     0.302383              0.934783       0.907076                  
[J7           0.211282    0.28121               0.945638       0.916842                  
[J8           0.186808    0.274703              0.94803        0.915865                  
[J9           0.162612    0.263513              0.956522       0.917819                  
epoch       main/loss   validation/main/loss  main/accuracy  validation/main/accuracy
[J1     

[J1           1.5649      0.649063              0.595703       0.829627                  
[J2           0.486411    0.380697              0.856658       0.881535                  
[J3           0.332694    0.33267               0.904948       0.900766                  
[J4           0.264645    0.290026              0.924592       0.908353                  
[J5           0.198097    0.283475              0.948568       0.914663                  
[J6           0.169443    0.276528              0.951087       0.915865                  
[J7           0.135681    0.262226              0.960612       0.923227                  
[J8           0.113231    0.271088              0.96841        0.929537                  
[J9           0.081532    0.25621               0.981318       0.930514                  
epoch       main/loss   validation/main/loss  main/accuracy  validation/main/accuracy
[J1           1.54892     0.715241              0.569661       0.839093                  
[J2

[J5           0.17197     0.271247              0.951497       0.91902                   
epoch       main/loss   validation/main/loss  main/accuracy  validation/main/accuracy
[J1           1.29496     0.489511              0.655599       0.860577                  
[J2           0.403324    0.350153              0.880774       0.896184                  
[J3           0.284569    0.31942               0.920247       0.914964                  
[J4           0.209359    0.309094              0.93716        0.904447                  
[J5           0.167112    0.289071              0.955404       0.913912                  
[J6           0.128041    0.269414              0.960938       0.926382                  
[J7           0.0971749   0.257706              0.97526        0.932016                  
[J8           0.0781927   0.266693              0.980978       0.924204                  
epoch       main/loss   validation/main/loss  main/accuracy  validation/main/accuracy
[J1     

[J2           0.515038    0.411065              0.853601       0.876653                  
[J3           0.334854    0.343264              0.903971       0.905649                  
[J4           0.257077    0.302215              0.932745       0.914889                  
[J5           0.22132     0.314684              0.935547       0.902269                  
epoch       main/loss   validation/main/loss  main/accuracy  validation/main/accuracy
[J1           1.52473     0.658611              0.589193       0.80942                   
[J2           0.491621    0.360638              0.868546       0.900541                  
[J3           0.321218    0.318811              0.907227       0.907828                  
[J4           0.233615    0.311335              0.933084       0.908353                  
[J5           0.195371    0.27991               0.941406       0.914213                  
[J6           0.150552    0.273807              0.956861       0.917593                  
[J7

[J1           1.36579     0.564558              0.629557       0.852314                  
[J2           0.449293    0.375811              0.877378       0.887169                  
[J3           0.300911    0.31607               0.916667       0.900541                  
[J4           0.24265     0.309094              0.932065       0.902494                  
[J5           0.203152    0.288434              0.944336       0.914889                  
[J6           0.157285    0.271429              0.960258       0.910757                  
[J7           0.122717    0.275285              0.964844       0.923453                  
[J8           0.0995494   0.270786              0.974185       0.927359                  
[J9           0.0797445   0.249295              0.980978       0.932692                  
epoch       main/loss   validation/main/loss  main/accuracy  validation/main/accuracy
[J1           1.83451     1.11604               0.498698       0.803561                  
[J2

[J1           1.28255     0.475821              0.682943       0.870793                  
[J2           0.402854    0.36334               0.888587       0.900316                  
epoch       main/loss   validation/main/loss  main/accuracy  validation/main/accuracy
[J1           1.26425     0.479812              0.686523       0.882512                  
[J2           0.403292    0.356473              0.883152       0.894456                  
[J3           0.276308    0.291124              0.921875       0.918344                  
[J4           0.213631    0.301742              0.939538       0.898813                  
[J5           0.168726    0.263713              0.956706       0.923678                  
epoch       main/loss   validation/main/loss  main/accuracy  validation/main/accuracy
[J1           1.29854     0.499221              0.66862        0.860577                  
[J2           0.436917    0.401632              0.874321       0.880334                  
[J3     

[J7           0.104944    0.254458              0.972982       0.923678                  
[J8           0.078585    0.24779               0.980978       0.929312                  
epoch       main/loss   validation/main/loss  main/accuracy  validation/main/accuracy
[J1           1.31786     0.524048              0.662435       0.865234                  
[J2           0.441133    0.336268              0.87534        0.903921                  
[J3           0.294877    0.304635              0.91862        0.913687                  
[J4           0.233541    0.282229              0.939878       0.917593                  
[J5           0.195839    0.314405              0.947591       0.900015                  
[J6           0.161086    0.265596              0.955163       0.925406                  
[J7           0.128023    0.24915               0.966797       0.925856                  
[J8           0.0970712   0.260332              0.977582       0.922251                  
epoc

epoch       main/loss   validation/main/loss  main/accuracy  validation/main/accuracy
[J1           1.34992     0.548518              0.655599       0.855469                  
[J2           0.464023    0.36857               0.863451       0.896184                  
[J3           0.304237    0.31187               0.914714       0.911734                  
[J4           0.228476    0.288256              0.93784        0.91271                   
[J5           0.184758    0.280293              0.948242       0.918795                  
[J6           0.151008    0.262881              0.957541       0.922927                  
[J7           0.118196    0.251928              0.970052       0.926833                  
[J8           0.0926157   0.239585              0.97894        0.925856                  
[J9           0.070933    0.242431              0.986753       0.930739                  


## Plotting Hyperparameter Relationships of Trials

In [4]:
from optuna.visualization import plot_contour

# Contour plot of the relationships between the hyperparameters suggested in
# this study, built from its completed trials.
plot_contour(study=study)

## Selecting Parameters to Visualize

In [5]:
# Restrict the plot to the widths of the first two hidden layers. Trials that
# never suggested n_units_l1 (single-layer models) are left out of this plot.
selected_params = ["n_units_l0", "n_units_l1"]
plot_contour(study, params=selected_params)