In [None]:
! pip install weightwatcher  transformers 

## Model training in PyTorch

This notebook shows how to train the models used in the empirical section of the paper, using the PyTorch implementation. For the Keras implementation, see `WW_MLP3_Training_pytorch.ipynb` (**TODO**: this filename names a PyTorch notebook; confirm the correct name of the Keras notebook).

In [1]:
import os

# Switch to the sibling ``pytorch`` directory. The project modules imported in
# the next cell (utils, trainer, pildataset, models) are presumably resolved
# from there via the working directory -- confirm against the repo layout.
os.chdir(os.path.join(os.pardir, "pytorch"))

In [2]:
import sys
from pathlib import Path
from time import time

import numpy as np
import pandas as pd

import torch
from torch.utils.data import DataLoader, Dataset

from weightwatcher import WeightWatcher 

from matplotlib import pyplot as plt

from utils import last_epoch
from trainer import Trainer, PreLoader
from pildataset import PILDataSet
from models import MLP2

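### Set up the random seeds for reproducibility

Set `DETERMINISTIC = False` to skip seeding and allow nondeterministic kernels. Restart the kernel after changing it.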

In [3]:
#NOTE: changing this value requires resetting the notebook kernel.
DETERMINISTIC = True

if DETERMINISTIC:
    import random
    import os

    # PyTorch's reproducibility notes require this cuBLAS workspace setting
    # (CUDA >= 10.2) when deterministic algorithms are enabled; it is set here,
    # before any CUDA work in this notebook.
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

    # Use deterministic torch kernels, and raise instead of silently running
    # an operation that has no deterministic implementation.
    torch.use_deterministic_algorithms(True)

    def reset_random_seeds(seed_value=1, torch_seed=0):
        """Re-seed the RNGs used during training so a run can be replicated.

        Args:
            seed_value: seed for Python's ``random`` module and NumPy's global
                RNG; also exported as ``PYTHONHASHSEED``.
            torch_seed: seed passed to ``torch.manual_seed``. It defaults to 0
                rather than ``seed_value`` because the original runs seeded
                torch with 0; changing it changes all torch-side randomness.
        """
        # PYTHONHASHSEED only affects Python processes started after this
        # point; the hash seed of the running kernel is fixed at startup.
        os.environ['PYTHONHASHSEED'] = str(seed_value)
        torch.manual_seed(torch_seed)
        np.random.seed(seed_value)
        random.seed(seed_value)

    reset_random_seeds()
else:
    def reset_random_seeds(seed_value=1, torch_seed=0):
        """No-op used when DETERMINISTIC is False.

        Accepts the same arguments as the deterministic version, so calls that
        pass explicit seeds work in either mode.
        """
        return None

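### Train short test runs

Five 5-epoch runs on MNIST:
- Runs 0–2 use `LR=[0.01, 0, 0]` (first layer only). Run 2 starts from a freshly initialized model.
- Runs 3–4 use `LR=[0, 0.01, 0]` (second layer only).
- Runs 1, 3 and 4 first reload checkpoint `(0, 0)` of `SETOL/TEST` and reseed.

With `DETERMINISTIC = True`, runs 0 and 1 report identical metrics in the output below.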

In [6]:
TRAIN = PILDataSet(True,  DS="MNIST")
TEST  = PILDataSet(False, DS="MNIST")
loader = PreLoader("MNIST", TRAIN, TEST, batch_size=16)

print(f"DETERMINISTIC set to {DETERMINISTIC}")

# LR vectors passed to Trainer.train_loop: only the first, or only the
# second, entry is non-zero (presumably per-layer learning rates).
FIRST_LAYER_LR = (0.01, 0, 0)
SECOND_LAYER_LR = (0, 0.01, 0)


def new_model_and_trainer():
    """Reseed, then build a fresh MLP2 and a Trainer wrapping it."""
    reset_random_seeds()
    model = MLP2(widths=(300, 100), H=28, W=28, C=1)
    return model, Trainer(model)


def train_run(trainer, run, layer_lr):
    """Train ``run`` of SETOL/TEST for 5 epochs with CCE loss."""
    # A fresh list per call, exactly as the literal lists were passed before.
    trainer.train_loop("SETOL/TEST", run, 5, loader, LR=list(layer_lr), loss="CCE")


def retrain_from_checkpoint(trainer, run, layer_lr):
    """Load checkpoint (0, 0) of SETOL/TEST, reseed, then train ``run``."""
    trainer.load(0, 0, "SETOL/TEST")
    reset_random_seeds()
    train_run(trainer, run, layer_lr)


# Run 0: freshly initialized model, first layer only.
m, t = new_model_and_trainer()
train_run(t, 0, FIRST_LAYER_LR)
print("\n")

# Run 1: reload checkpoint (0, 0) and reseed; the recorded output shows it
# reproducing run 0 when DETERMINISTIC is True.
retrain_from_checkpoint(t, 1, FIRST_LAYER_LR)
print("\n")

# Run 2: this time try initializing a new model.
m, t = new_model_and_trainer()
train_run(t, 2, FIRST_LAYER_LR)
print("\n")

# Runs 3 and 4: now try training the second layer, twice from the same checkpoint.
retrain_from_checkpoint(t, 3, SECOND_LAYER_LR)
print("\n")

retrain_from_checkpoint(t, 4, SECOND_LAYER_LR)

DETERMINISTIC set to True
SETOL/TEST run 0 epoch 1 loss 0.6368 train accuracy 0.8505 test accuracy 0.9118 22.39 seconds	 alpha 1 4.570	 alpha 2 14.347
SETOL/TEST run 0 epoch 2 loss 0.3296 train accuracy 0.9141 test accuracy 0.9256 22.49 seconds	 alpha 1 3.860	 alpha 2 14.347
SETOL/TEST run 0 epoch 3 loss 0.2774 train accuracy 0.9257 test accuracy 0.9336 19.78 seconds	 alpha 1 3.611	 alpha 2 14.347
SETOL/TEST run 0 epoch 4 loss 0.2484 train accuracy 0.9333 test accuracy 0.9368 22.19 seconds	 alpha 1 3.500	 alpha 2 14.347
SETOL/TEST run 0 epoch 5 loss 0.2284 train accuracy 0.9384 test accuracy 0.9421 22.05 seconds	 alpha 1 3.415	 alpha 2 14.347


SETOL/TEST run 1 epoch 1 loss 0.6368 train accuracy 0.8505 test accuracy 0.9118 22.53 seconds	 alpha 1 4.570	 alpha 2 14.347
SETOL/TEST run 1 epoch 2 loss 0.3296 train accuracy 0.9141 test accuracy 0.9256 22.08 seconds	 alpha 1 3.860	 alpha 2 14.347
SETOL/TEST run 1 epoch 3 loss 0.2774 train accuracy 0.9257 test accuracy 0.9336 22.03 seconds	 al

### Training the whole set

It is recommended to run `train_models` in a screen/tmux session, because a full run can take a long time and several runs can be executed concurrently. Nevertheless, here is an example of how to call its `main` function.

In [None]:
from train_models import main

# Train the whole set of models on MNIST; this can take a long time (see the
# note above about running it in screen/tmux instead).
# NOTE(review): the meaning of the positional arguments (5, True, 6) is
# defined by train_models.main and is not visible here -- consider passing
# them by keyword once confirmed.
main("MNIST", 5, True, 6, WHITEN=False)