# Optimizing Neural Networks

To remind, necessary ingredients to train NN:
    * model
    * objective
    * optimizer
    
Today we will try to understand basics of optimization of neural networks, giving context for the last two lectures. Goal is to:
* Understand basics of generalization, and the difference between optimization and generalization (more on that in "Understanding generalization" lab)
* Understand impact of hyperparameters in SGD on:

  - generalization (lr, batch size)
  - speed of optimization (lr, momentum, batch size) 

References:
* Deep Learning book chapter on optimization: http://www.deeplearningbook.org/contents/optimization.html

# Setup

In [4]:
# Boilerplate code to get started

%load_ext autoreload
%autoreload 
%matplotlib inline

import json
import matplotlib as mpl
from src import fmnist_utils
from src.fmnist_utils import *

def plot(H):
    plt.title(max(H['test_acc']))
    plt.plot(H['acc'], label="acc")
    plt.plot(H['test_acc'], label="test_acc")
    plt.legend()

mpl.rcParams['lines.linewidth'] = 2
mpl.rcParams['figure.figsize'] = (7, 7)
mpl.rcParams['axes.titlesize'] = 12
mpl.rcParams['axes.labelsize'] = 12

(x_train, y_train), (x_test, y_test) = fmnist_utils.get_data()

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


  from ._conv import register_converters as _register_converters
Using TensorFlow backend.


# Exercise 1: optimization speed

Assuming fixed number of *epochs*, it is usually better to use either smaller batch size, or larger learning rate. Theoretical reason for it is not completely clear, so let's focus in this exercise on an empirical investigation.

Assume you are allowed to train the given network for 10 epochs. Answer the following questions:

* a) What was the optimal $\eta$ (assuming $S$=128 and $\mu$=0.9) for the final training accuracy?
* b) Did it also provide the best test accuracy? If yes, why (hint: consider if model is under or over-fitting)?
* c) What is the optimal $S$ (assuming $\eta$=0.1 and $\mu$=0.9) for the final training accuracy?
* d) Why is higher learning rate, or smaller batch size, optimizing faster? Give your best explanation (it can be hypothetical, there is no obvious theoretical answer)?

In [18]:
for lr in [0.0001, 0.0005, 0.001, 0.005, 0.01, 0.05, 0.1, 0.5]:
    model = build_mlp(784, 10, hidden_dims=[512])
    loss = torch.nn.CrossEntropyLoss(size_average=True)
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    H = train(loss=loss, model=model, x_train=x_train, y_train=y_train,
              x_test=x_test, y_test=y_test,
              optim=optimizer, batch_size=128, n_epochs=10)
    print("lr: ", lr, " train_acc: ", H['acc'][-1], " test_acc: ", H['test_acc'][-1])

100%|██████████| 10/10 [00:01<00:00,  8.83it/s]
 20%|██        | 2/10 [00:00<00:00, 11.36it/s]

lr:  0.0001  train_acc:  0.103  test_acc:  0.121


100%|██████████| 10/10 [00:00<00:00, 11.76it/s]
 20%|██        | 2/10 [00:00<00:00, 12.42it/s]

lr:  0.0005  train_acc:  0.474  test_acc:  0.475


100%|██████████| 10/10 [00:00<00:00, 10.69it/s]
 20%|██        | 2/10 [00:00<00:00, 12.53it/s]

lr:  0.001  train_acc:  0.568  test_acc:  0.541


100%|██████████| 10/10 [00:00<00:00, 11.90it/s]
 20%|██        | 2/10 [00:00<00:00, 11.79it/s]

lr:  0.005  train_acc:  0.684  test_acc:  0.666


100%|██████████| 10/10 [00:00<00:00, 11.84it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

lr:  0.01  train_acc:  0.738  test_acc:  0.704


100%|██████████| 10/10 [00:00<00:00, 10.97it/s]
 10%|█         | 1/10 [00:00<00:01,  8.76it/s]

lr:  0.05  train_acc:  0.869  test_acc:  0.769


100%|██████████| 10/10 [00:00<00:00, 11.18it/s]
 20%|██        | 2/10 [00:00<00:00, 12.64it/s]

lr:  0.1  train_acc:  0.876  test_acc:  0.775


100%|██████████| 10/10 [00:00<00:00, 11.86it/s]

lr:  0.5  train_acc:  0.675  test_acc:  0.591





In [19]:
for bs in [2, 4, 8, 16, 32, 64, 128, 256, 512]:
    model = build_mlp(784, 10, hidden_dims=[512])
    loss = torch.nn.CrossEntropyLoss(size_average=True)
    optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
    H = train(loss=loss, model=model, x_train=x_train, y_train=y_train,
              x_test=x_test, y_test=y_test,
              optim=optimizer, batch_size=bs, n_epochs=10)
    print("bs: ", bs, " train_acc: ", H['acc'][-1], " test_acc: ", H['test_acc'][-1])

100%|██████████| 10/10 [00:26<00:00,  2.81s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

bs:  2  train_acc:  0.104  test_acc:  0.105


100%|██████████| 10/10 [00:12<00:00,  1.40s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

bs:  4  train_acc:  0.134  test_acc:  0.133


100%|██████████| 10/10 [00:05<00:00,  1.54it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

bs:  8  train_acc:  0.115  test_acc:  0.116


100%|██████████| 10/10 [00:02<00:00,  4.79it/s]
 10%|█         | 1/10 [00:00<00:01,  6.54it/s]

bs:  16  train_acc:  0.664  test_acc:  0.56


100%|██████████| 10/10 [00:01<00:00,  6.99it/s]
 10%|█         | 1/10 [00:00<00:01,  8.25it/s]

bs:  32  train_acc:  0.832  test_acc:  0.702


100%|██████████| 10/10 [00:01<00:00,  9.72it/s]
 20%|██        | 2/10 [00:00<00:00, 12.28it/s]

bs:  64  train_acc:  0.919  test_acc:  0.789


100%|██████████| 10/10 [00:00<00:00, 12.92it/s]
 20%|██        | 2/10 [00:00<00:00, 16.18it/s]

bs:  128  train_acc:  0.884  test_acc:  0.776


100%|██████████| 10/10 [00:00<00:00, 15.41it/s]
 20%|██        | 2/10 [00:00<00:00, 16.33it/s]

bs:  256  train_acc:  0.787  test_acc:  0.693


100%|██████████| 10/10 [00:00<00:00, 16.52it/s]

bs:  512  train_acc:  0.691  test_acc:  0.656





In [23]:
answers = {"a": "0.1", "b": "Yes. Might be underfitting?", "c": "128", "d": "Higher lr / lower batchsize is more noisy. Maybe it allows you to break out of flat areas / past saddle-points faster."}
json.dump(answers, open("7b_ex1.json", "w"))

# Exercise 2: generalization

Story with generalization is also unclear, but it is generally accepted that higher noise levels in SGD lead to better generalization. Think of noise in optimization (leading to low fidelity, as seen in lab 7a, for instance) as a close analog of typical regularizations (like dropout or batch normalization, that we will discuss next time).

Your task is to:

a) Check a range of LR and BS and find the best generalizing combination of LR and BS. What test accuracy were you able to achieve? What is the best LR and BS combination?

b) Answer the following question: Is stability correlated with using large LR or small BS. If yes, what is the intuitive reason for it? Feel free to give a hypothesis.

Hints:

* Make sure you achieve 100% training accuracy with each run, discard hyperparameters that are not achieving convergence.

Notes:

* Do not change the model in the starting code. It is on purpose a bit more complex MLP.

* You can measure stability by computing margin. This is implemented for you (using DeepFool method, https://arxiv.org/abs/1511.04599). Measuring margin is expensive, so recommended approach would be to compute it only on few final runs. 

In [21]:
for lr in [0.0001, 0.0005, 0.001, 0.005, 0.01, 0.05, 0.1, 0.5]:
    for bs in [2, 4, 8, 16, 32, 64, 128, 256, 512]:
        model = build_mlp(784, 10, hidden_dims=[512])
        loss = torch.nn.CrossEntropyLoss(size_average=True)
        optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
        H = train(loss=loss, model=model, x_train=x_train, y_train=y_train,
                  x_test=x_test, y_test=y_test,
                  optim=optimizer, batch_size=bs, n_epochs=20)
        print("lr: ", lr, " bs: ", bs, " train_acc: ", H['acc'][-1], " test_acc: ", H['test_acc'][-1])

100%|██████████| 20/20 [00:27<00:00,  1.40s/it]
  0%|          | 0/20 [00:00<?, ?it/s]

lr:  0.0001  bs:  2  train_acc:  0.803  test_acc:  0.745


100%|██████████| 20/20 [00:14<00:00,  1.23it/s]
  0%|          | 0/20 [00:00<?, ?it/s]

lr:  0.0001  bs:  4  train_acc:  0.75  test_acc:  0.712


100%|██████████| 20/20 [00:07<00:00,  2.85it/s]
  0%|          | 0/20 [00:00<?, ?it/s]

lr:  0.0001  bs:  8  train_acc:  0.685  test_acc:  0.662


100%|██████████| 20/20 [00:04<00:00,  4.43it/s]
  5%|▌         | 1/20 [00:00<00:02,  6.95it/s]

lr:  0.0001  bs:  16  train_acc:  0.649  test_acc:  0.631


100%|██████████| 20/20 [00:03<00:00,  6.81it/s]
  5%|▌         | 1/20 [00:00<00:02,  7.75it/s]

lr:  0.0001  bs:  32  train_acc:  0.593  test_acc:  0.583


100%|██████████| 20/20 [00:02<00:00,  9.24it/s]
 10%|█         | 2/20 [00:00<00:01, 12.22it/s]

lr:  0.0001  bs:  64  train_acc:  0.466  test_acc:  0.473


100%|██████████| 20/20 [00:01<00:00, 11.85it/s]
 10%|█         | 2/20 [00:00<00:01, 12.27it/s]

lr:  0.0001  bs:  128  train_acc:  0.124  test_acc:  0.13


100%|██████████| 20/20 [00:01<00:00, 13.14it/s]
 10%|█         | 2/20 [00:00<00:00, 18.40it/s]

lr:  0.0001  bs:  256  train_acc:  0.103  test_acc:  0.122


100%|██████████| 20/20 [00:01<00:00, 15.96it/s]
  0%|          | 0/20 [00:00<?, ?it/s]

lr:  0.0001  bs:  512  train_acc:  0.104  test_acc:  0.124


100%|██████████| 20/20 [00:27<00:00,  1.30s/it]
  0%|          | 0/20 [00:00<?, ?it/s]

lr:  0.0005  bs:  2  train_acc:  0.919  test_acc:  0.8


100%|██████████| 20/20 [00:12<00:00,  1.59it/s]
  0%|          | 0/20 [00:00<?, ?it/s]

lr:  0.0005  bs:  4  train_acc:  0.878  test_acc:  0.789


100%|██████████| 20/20 [00:07<00:00,  2.91it/s]
  0%|          | 0/20 [00:00<?, ?it/s]

lr:  0.0005  bs:  8  train_acc:  0.813  test_acc:  0.754


100%|██████████| 20/20 [00:04<00:00,  4.77it/s]
  5%|▌         | 1/20 [00:00<00:02,  6.44it/s]

lr:  0.0005  bs:  16  train_acc:  0.752  test_acc:  0.716


100%|██████████| 20/20 [00:02<00:00,  6.76it/s]
  5%|▌         | 1/20 [00:00<00:02,  8.04it/s]

lr:  0.0005  bs:  32  train_acc:  0.7  test_acc:  0.682


100%|██████████| 20/20 [00:02<00:00,  9.14it/s]
 10%|█         | 2/20 [00:00<00:01, 12.54it/s]

lr:  0.0005  bs:  64  train_acc:  0.644  test_acc:  0.624


100%|██████████| 20/20 [00:01<00:00, 12.14it/s]
 10%|█         | 2/20 [00:00<00:01, 13.06it/s]

lr:  0.0005  bs:  128  train_acc:  0.593  test_acc:  0.581


100%|██████████| 20/20 [00:01<00:00, 14.67it/s]
 10%|█         | 2/20 [00:00<00:01, 14.16it/s]

lr:  0.0005  bs:  256  train_acc:  0.438  test_acc:  0.432


100%|██████████| 20/20 [00:01<00:00, 18.36it/s]
  0%|          | 0/20 [00:00<?, ?it/s]

lr:  0.0005  bs:  512  train_acc:  0.103  test_acc:  0.121


100%|██████████| 20/20 [00:25<00:00,  1.26s/it]
  0%|          | 0/20 [00:00<?, ?it/s]

lr:  0.001  bs:  2  train_acc:  0.941  test_acc:  0.791


100%|██████████| 20/20 [00:12<00:00,  1.47it/s]
  0%|          | 0/20 [00:00<?, ?it/s]

lr:  0.001  bs:  4  train_acc:  0.913  test_acc:  0.783


100%|██████████| 20/20 [00:07<00:00,  2.71it/s]
  0%|          | 0/20 [00:00<?, ?it/s]

lr:  0.001  bs:  8  train_acc:  0.874  test_acc:  0.783


100%|██████████| 20/20 [00:04<00:00,  3.31it/s]
  0%|          | 0/20 [00:00<?, ?it/s]

lr:  0.001  bs:  16  train_acc:  0.81  test_acc:  0.756


100%|██████████| 20/20 [00:03<00:00,  6.76it/s]
  5%|▌         | 1/20 [00:00<00:02,  8.81it/s]

lr:  0.001  bs:  32  train_acc:  0.766  test_acc:  0.72


100%|██████████| 20/20 [00:02<00:00,  8.76it/s]
  5%|▌         | 1/20 [00:00<00:01,  9.92it/s]

lr:  0.001  bs:  64  train_acc:  0.695  test_acc:  0.675


100%|██████████| 20/20 [00:01<00:00, 12.31it/s]
 10%|█         | 2/20 [00:00<00:01, 13.90it/s]

lr:  0.001  bs:  128  train_acc:  0.64  test_acc:  0.622


100%|██████████| 20/20 [00:01<00:00, 14.92it/s]
 10%|█         | 2/20 [00:00<00:01, 14.78it/s]

lr:  0.001  bs:  256  train_acc:  0.573  test_acc:  0.546


100%|██████████| 20/20 [00:01<00:00, 16.76it/s]
  0%|          | 0/20 [00:00<?, ?it/s]

lr:  0.001  bs:  512  train_acc:  0.115  test_acc:  0.126


100%|██████████| 20/20 [00:31<00:00,  1.62s/it]
  0%|          | 0/20 [00:00<?, ?it/s]

lr:  0.005  bs:  2  train_acc:  0.901  test_acc:  0.762


100%|██████████| 20/20 [00:13<00:00,  1.50it/s]
  0%|          | 0/20 [00:00<?, ?it/s]

lr:  0.005  bs:  4  train_acc:  0.918  test_acc:  0.779


100%|██████████| 20/20 [00:07<00:00,  2.75it/s]
  0%|          | 0/20 [00:00<?, ?it/s]

lr:  0.005  bs:  8  train_acc:  0.953  test_acc:  0.8


100%|██████████| 20/20 [00:04<00:00,  4.35it/s]
  5%|▌         | 1/20 [00:00<00:03,  5.46it/s]

lr:  0.005  bs:  16  train_acc:  0.935  test_acc:  0.799


100%|██████████| 20/20 [00:03<00:00,  6.56it/s]
  5%|▌         | 1/20 [00:00<00:01,  9.76it/s]

lr:  0.005  bs:  32  train_acc:  0.891  test_acc:  0.782


100%|██████████| 20/20 [00:02<00:00,  8.73it/s]
 10%|█         | 2/20 [00:00<00:01, 12.66it/s]

lr:  0.005  bs:  64  train_acc:  0.829  test_acc:  0.756


100%|██████████| 20/20 [00:01<00:00, 11.32it/s]
 10%|█         | 2/20 [00:00<00:01, 12.97it/s]

lr:  0.005  bs:  128  train_acc:  0.765  test_acc:  0.733


100%|██████████| 20/20 [00:01<00:00, 13.85it/s]
 15%|█▌        | 3/20 [00:00<00:00, 19.99it/s]

lr:  0.005  bs:  256  train_acc:  0.691  test_acc:  0.662


100%|██████████| 20/20 [00:01<00:00, 16.64it/s]
  0%|          | 0/20 [00:00<?, ?it/s]

lr:  0.005  bs:  512  train_acc:  0.53  test_acc:  0.511


100%|██████████| 20/20 [00:51<00:00,  2.71s/it]
  0%|          | 0/20 [00:00<?, ?it/s]

lr:  0.01  bs:  2  train_acc:  0.848  test_acc:  0.737


100%|██████████| 20/20 [00:17<00:00,  1.04s/it]
  0%|          | 0/20 [00:00<?, ?it/s]

lr:  0.01  bs:  4  train_acc:  0.925  test_acc:  0.787


100%|██████████| 20/20 [00:07<00:00,  2.60it/s]
  0%|          | 0/20 [00:00<?, ?it/s]

lr:  0.01  bs:  8  train_acc:  0.957  test_acc:  0.799


100%|██████████| 20/20 [00:04<00:00,  3.80it/s]
  5%|▌         | 1/20 [00:00<00:02,  7.43it/s]

lr:  0.01  bs:  16  train_acc:  0.944  test_acc:  0.785


100%|██████████| 20/20 [00:02<00:00,  7.10it/s]
  5%|▌         | 1/20 [00:00<00:01,  9.86it/s]

lr:  0.01  bs:  32  train_acc:  0.928  test_acc:  0.803


100%|██████████| 20/20 [00:02<00:00,  8.83it/s]
 10%|█         | 2/20 [00:00<00:01, 10.93it/s]

lr:  0.01  bs:  64  train_acc:  0.855  test_acc:  0.765


100%|██████████| 20/20 [00:01<00:00, 13.45it/s]
 10%|█         | 2/20 [00:00<00:01, 14.16it/s]

lr:  0.01  bs:  128  train_acc:  0.817  test_acc:  0.768


100%|██████████| 20/20 [00:01<00:00, 14.66it/s]
 10%|█         | 2/20 [00:00<00:00, 19.66it/s]

lr:  0.01  bs:  256  train_acc:  0.716  test_acc:  0.693


100%|██████████| 20/20 [00:01<00:00, 19.52it/s]
  0%|          | 0/20 [00:00<?, ?it/s]

lr:  0.01  bs:  512  train_acc:  0.648  test_acc:  0.631


100%|██████████| 20/20 [00:57<00:00,  2.93s/it]
  0%|          | 0/20 [00:00<?, ?it/s]

lr:  0.05  bs:  2  train_acc:  0.115  test_acc:  0.095


100%|██████████| 20/20 [00:27<00:00,  1.53s/it]
  0%|          | 0/20 [00:00<?, ?it/s]

lr:  0.05  bs:  4  train_acc:  0.101  test_acc:  0.1


100%|██████████| 20/20 [00:11<00:00,  1.39it/s]
  0%|          | 0/20 [00:00<?, ?it/s]

lr:  0.05  bs:  8  train_acc:  0.668  test_acc:  0.559


100%|██████████| 20/20 [00:04<00:00,  3.26it/s]
  5%|▌         | 1/20 [00:00<00:02,  7.31it/s]

lr:  0.05  bs:  16  train_acc:  0.843  test_acc:  0.721


100%|██████████| 20/20 [00:02<00:00,  6.86it/s]
  5%|▌         | 1/20 [00:00<00:02,  7.79it/s]

lr:  0.05  bs:  32  train_acc:  0.938  test_acc:  0.794


100%|██████████| 20/20 [00:02<00:00,  9.43it/s]
 10%|█         | 2/20 [00:00<00:01, 13.09it/s]

lr:  0.05  bs:  64  train_acc:  0.952  test_acc:  0.789


100%|██████████| 20/20 [00:01<00:00, 12.52it/s]
 10%|█         | 2/20 [00:00<00:01, 13.05it/s]

lr:  0.05  bs:  128  train_acc:  0.899  test_acc:  0.778


100%|██████████| 20/20 [00:01<00:00, 14.81it/s]
 10%|█         | 2/20 [00:00<00:01, 16.35it/s]

lr:  0.05  bs:  256  train_acc:  0.866  test_acc:  0.778


100%|██████████| 20/20 [00:01<00:00, 19.60it/s]
  0%|          | 0/20 [00:00<?, ?it/s]

lr:  0.05  bs:  512  train_acc:  0.71  test_acc:  0.682


100%|██████████| 20/20 [00:55<00:00,  2.89s/it]
  0%|          | 0/20 [00:00<?, ?it/s]

lr:  0.1  bs:  2  train_acc:  0.095  test_acc:  0.115


100%|██████████| 20/20 [00:26<00:00,  1.48s/it]
  0%|          | 0/20 [00:00<?, ?it/s]

lr:  0.1  bs:  4  train_acc:  0.107  test_acc:  0.108


100%|██████████| 20/20 [00:12<00:00,  1.28it/s]
  0%|          | 0/20 [00:00<?, ?it/s]

lr:  0.1  bs:  8  train_acc:  0.187  test_acc:  0.186


100%|██████████| 20/20 [00:05<00:00,  2.64it/s]
  5%|▌         | 1/20 [00:00<00:02,  7.28it/s]

lr:  0.1  bs:  16  train_acc:  0.614  test_acc:  0.533


100%|██████████| 20/20 [00:02<00:00,  7.11it/s]
 10%|█         | 2/20 [00:00<00:01, 10.42it/s]

lr:  0.1  bs:  32  train_acc:  0.8  test_acc:  0.711


100%|██████████| 20/20 [00:02<00:00,  9.43it/s]
 10%|█         | 2/20 [00:00<00:01, 11.84it/s]

lr:  0.1  bs:  64  train_acc:  0.893  test_acc:  0.768


100%|██████████| 20/20 [00:01<00:00, 12.51it/s]
 10%|█         | 2/20 [00:00<00:01, 16.48it/s]

lr:  0.1  bs:  128  train_acc:  0.948  test_acc:  0.785


100%|██████████| 20/20 [00:01<00:00, 15.98it/s]
 15%|█▌        | 3/20 [00:00<00:00, 20.67it/s]

lr:  0.1  bs:  256  train_acc:  0.91  test_acc:  0.791


100%|██████████| 20/20 [00:01<00:00, 18.31it/s]
  0%|          | 0/20 [00:00<?, ?it/s]

lr:  0.1  bs:  512  train_acc:  0.74  test_acc:  0.715


100%|██████████| 20/20 [00:52<00:00,  2.76s/it]
  0%|          | 0/20 [00:00<?, ?it/s]

lr:  0.5  bs:  2  train_acc:  0.1  test_acc:  0.097


100%|██████████| 20/20 [00:26<00:00,  1.43s/it]
  0%|          | 0/20 [00:00<?, ?it/s]

lr:  0.5  bs:  4  train_acc:  0.094  test_acc:  0.095


100%|██████████| 20/20 [00:13<00:00,  1.24it/s]
  0%|          | 0/20 [00:00<?, ?it/s]

lr:  0.5  bs:  8  train_acc:  0.086  test_acc:  0.111


100%|██████████| 20/20 [00:05<00:00,  2.46it/s]
  5%|▌         | 1/20 [00:00<00:03,  5.70it/s]

lr:  0.5  bs:  16  train_acc:  0.115  test_acc:  0.095


100%|██████████| 20/20 [00:03<00:00,  5.49it/s]
  5%|▌         | 1/20 [00:00<00:02,  7.80it/s]

lr:  0.5  bs:  32  train_acc:  0.115  test_acc:  0.107


100%|██████████| 20/20 [00:02<00:00,  9.21it/s]
 10%|█         | 2/20 [00:00<00:01, 13.54it/s]

lr:  0.5  bs:  64  train_acc:  0.181  test_acc:  0.142


100%|██████████| 20/20 [00:01<00:00, 12.53it/s]
 10%|█         | 2/20 [00:00<00:01, 14.17it/s]

lr:  0.5  bs:  128  train_acc:  0.71  test_acc:  0.634


100%|██████████| 20/20 [00:01<00:00, 15.15it/s]
 10%|█         | 2/20 [00:00<00:01, 17.42it/s]

lr:  0.5  bs:  256  train_acc:  0.815  test_acc:  0.735


100%|██████████| 20/20 [00:01<00:00, 18.94it/s]

lr:  0.5  bs:  512  train_acc:  0.682  test_acc:  0.645





In [39]:
answers = {"a": "Training till 100% training accuracy, the combination with the highest test accuracy was lr: 0.05, bs: 128 (with test acc of 0.793).", "b": "It does seem like it is (though with lr mattering more). If higher lr regularizes more, it makes sense that stability is higher as well.", "c": "There is no c?"}
json.dump(answers, open("7b_ex2.json", "w"))

## Stability measure

In 7a lab we discussed bias/variance view. Here, we will take a stability based view. To estimate stability, 
we will record maximum change in prediction when adding gaussian noise to examples. This is a very rudimentary
way to estimate geometric margin of the network, and we will talk more about this later.

In [8]:
from src.deepfool import measure_stability_deepfool

## Finding optimal $\eta$ and $S$

In [36]:
## Starting code

Hs = []
Lrs = [0.001, 0.005, 0.01, 0.05]
Margins = []
bss = [16, 32, 64, 128]

for lr in Lrs:
    for bs in bss:
        model = build_mlp(784, 10, hidden_dims=[100, 100, 100])
        loss = torch.nn.CrossEntropyLoss(size_average=True)
        optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
        H = train(loss=loss, model=model, x_train=x_train, y_train=y_train,
                  x_test=x_test, y_test=y_test,
                  optim=optimizer, batch_size=100, n_epochs=400)
        Margins.append(measure_stability_deepfool(model=model, 
                    x_train=x_train, y_train=y_train, loss=loss, N=1000))
        Hs.append(H)

100%|██████████| 400/400 [00:16<00:00, 23.57it/s]
100%|██████████| 400/400 [00:18<00:00, 21.07it/s]
100%|██████████| 400/400 [00:18<00:00, 21.50it/s]
100%|██████████| 400/400 [00:20<00:00, 19.85it/s]
100%|██████████| 400/400 [00:18<00:00, 20.71it/s]
100%|██████████| 400/400 [00:18<00:00, 21.44it/s]
100%|██████████| 400/400 [00:19<00:00, 16.73it/s]
100%|██████████| 400/400 [00:19<00:00, 20.10it/s]
100%|██████████| 400/400 [00:19<00:00, 20.22it/s]
100%|██████████| 400/400 [00:28<00:00, 11.17it/s]
100%|██████████| 400/400 [00:36<00:00, 11.24it/s]
100%|██████████| 400/400 [00:27<00:00, 14.77it/s]
100%|██████████| 400/400 [00:34<00:00, 11.99it/s]
100%|██████████| 400/400 [00:29<00:00, 15.34it/s]
100%|██████████| 400/400 [00:28<00:00, 15.10it/s]
100%|██████████| 400/400 [00:31<00:00, 13.08it/s]


In [37]:
i = 0
for lr in Lrs:
    for bs in bss:
        print("lr: ", lr, " bs: ", bs, " train_acc: ", Hs[i]['acc'][-1], " test_acc: ", Hs[i]['test_acc'][-1])
        print(Margins[i])
        i += 1

lr:  0.001  bs:  16  train_acc:  0.909  test_acc:  0.765
0.6658068
lr:  0.001  bs:  32  train_acc:  0.914  test_acc:  0.776
0.6743854
lr:  0.001  bs:  64  train_acc:  0.916  test_acc:  0.771
0.66958976
lr:  0.001  bs:  128  train_acc:  0.913  test_acc:  0.774
0.67680544
lr:  0.005  bs:  16  train_acc:  1.0  test_acc:  0.765
0.43906567
lr:  0.005  bs:  32  train_acc:  1.0  test_acc:  0.763
0.44927588
lr:  0.005  bs:  64  train_acc:  1.0  test_acc:  0.767
0.4437598
lr:  0.005  bs:  128  train_acc:  1.0  test_acc:  0.761
0.43897283
lr:  0.01  bs:  16  train_acc:  1.0  test_acc:  0.775
0.47367164
lr:  0.01  bs:  32  train_acc:  1.0  test_acc:  0.775
0.46811914
lr:  0.01  bs:  64  train_acc:  1.0  test_acc:  0.773
0.4699875
lr:  0.01  bs:  128  train_acc:  1.0  test_acc:  0.773
0.46465495
lr:  0.05  bs:  16  train_acc:  1.0  test_acc:  0.791
0.8310167
lr:  0.05  bs:  32  train_acc:  1.0  test_acc:  0.789
0.7979353
lr:  0.05  bs:  64  train_acc:  1.0  test_acc:  0.79
0.9257687
lr:  0.05  bs: