In [1]:
import torch
import torchvision
import torchvision.transforms as T
import torchvision.models as models

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions.laplace import Laplace
import numpy as np
import matplotlib.pyplot as plt

from tqdm.notebook import  tqdm
import seaborn as sns; sns.set()
import pickle as pkl
from pathlib import Path
import pandas as pd
from functools import partial 

# Override the default figure size (width 15, height 6) for every plot drawn afterwards.
sns.set(rc={'figure.figsize':(15, 6)})

In [2]:
# Compute device used by every cell below.
# The notebook was written for GPU index 3; constructing torch.device("cuda:3")
# never fails on its own, so a machine without that GPU would only crash later
# at the first `.to(DEVICE)`.  Fall back gracefully instead.
_PREFERRED_GPU_INDEX = 3
if torch.cuda.is_available() and torch.cuda.device_count() > _PREFERRED_GPU_INDEX:
    DEVICE = torch.device(f"cuda:{_PREFERRED_GPU_INDEX}")
elif torch.cuda.is_available():
    DEVICE = torch.device("cuda:0")
else:
    DEVICE = torch.device("cpu")

# Root directory under which datasets are downloaded / cached.
DATA_ROOT = Path('../data')

In [3]:
# Sanity check: report whether PyTorch can see a CUDA device.
torch.cuda.is_available()

True

# Data & Model Prep

## Data

In [4]:
def make_data(transforms, batch_size=12, num_workers=2):
    """Build CIFAR-10 train and test data loaders.

    The dataset is stored under ``DATA_ROOT / 'cifar-10-data'`` and is
    downloaded on first use.

    Args:
        transforms: transform applied to every image of both splits.
        batch_size: number of samples per batch (default 12, the value that
            was previously hard-coded).
        num_workers: number of loader worker processes (default 2, the value
            that was previously hard-coded).

    Returns:
        ``(trainloader, testloader)``; only the training loader shuffles.
    """
    root = DATA_ROOT / 'cifar-10-data'

    def _build_loader(train):
        # One code path for both splits: the training split is shuffled,
        # the test split keeps its fixed order.
        dataset = torchvision.datasets.CIFAR10(root=root, train=train,
                                               download=True, transform=transforms)
        return torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                           shuffle=train, num_workers=num_workers)

    trainloader = _build_loader(True)
    testloader = _build_loader(False)
    return trainloader, testloader


# Human-readable names for the 10 CIFAR-10 classes.
# NOTE(review): presumably listed in label-index order (0 = plane ... 9 = truck) -- confirm against the dataset.
CLASSES = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

## Models

In [5]:
# Learning rates to sweep: 10 evenly spaced values from 3e-4 to 3e-3 (both ends included).
lrs = np.linspace(3e-4, 3e-3, num=10)

In [6]:
class StrongerConvNet(nn.Module):
    """Small CNN: three 5x5 convolutions followed by five linear layers.

    The flattened feature width (32 * 6 * 6) corresponds to 3-channel 32x32
    inputs: 32 -> conv1 -> 28 -> pool -> 14 -> conv2 -> 10 -> conv3 -> 6.
    The forward pass returns 10 raw (un-normalised) scores per image.
    """

    def __init__(self):
        super().__init__()

        # Convolutional feature extractor (pooling is applied after conv1 only).
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        self.conv3 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5)

        # Fully connected head, halving the width at every layer down to 10 outputs.
        self.fc1 = nn.Linear(in_features=32 * 6 * 6, out_features=512)
        self.fc2 = nn.Linear(in_features=512, out_features=256)
        self.fc3 = nn.Linear(in_features=256, out_features=128)
        self.fc4 = nn.Linear(in_features=128, out_features=64)
        self.fc5 = nn.Linear(in_features=64, out_features=10)

    def forward(self, x):
        """Map a batch of images to a batch of 10 class scores."""
        features = self.pool(F.relu(self.conv1(x)))
        for conv in (self.conv2, self.conv3):
            features = F.relu(conv(features))

        # Collapse channel and spatial dimensions into one feature vector per sample.
        hidden = features.view(-1, self.fc1.in_features)
        for dense in (self.fc1, self.fc2, self.fc3, self.fc4):
            hidden = F.relu(dense(hidden))

        # Final layer has no activation: raw scores are returned.
        return self.fc5(hidden)

## Thresholdout

In [7]:
class Thresholdout:
    """Reusable-holdout estimator in the style of the Thresholdout algorithm.

    Every statistic is evaluated on both the training set and the holdout
    set.  While the two values agree up to a noisy threshold the training
    value is reported; once they disagree the holdout value plus Laplace
    noise is reported and the noisy threshold is re-drawn.

    Attributes:
        tolerance: base noise unit; the Laplace scales are 2x, 4x and 8x it.
        base_threshold: fixed (noise-free) threshold, ``4 * tolerance``.
        T: current noisy threshold, ``base_threshold`` plus Laplace noise.
        train, holdout: the datasets handed to every statistic.
        keep_log: whether each query is recorded in ``log``.
        log: DataFrame with one row per query (only when ``keep_log``).
    """

    def __init__(self, train, holdout, tolerance=0.01/4, scale_factor=4, keep_log=True):
        """
            - train / holdout: datasets passed unchanged to each statistic
            - tolerance: base noise unit for the three Laplace distributions
            - scale_factor: accepted for backward compatibility; not used
            - keep_log: record every query in ``self.log`` when True
        """
        self.tolerance = tolerance

        # Zero-mean Laplace noise sources: answer noise, threshold noise and
        # per-query comparison noise.
        self.laplace_eps = Laplace(torch.tensor([0.0]), torch.tensor([2*self.tolerance]))
        self.laplace_gamma = Laplace(torch.tensor([0.0]), torch.tensor([4*self.tolerance]))
        self.laplace_eta = Laplace(torch.tensor([0.0]), torch.tensor([8*self.tolerance]))

        self.train = train
        self.holdout = holdout

        # Keep the noise-free threshold so the noisy one can be re-drawn
        # around it instead of drifting away from it.
        self.base_threshold = 4*tolerance
        self.T = self.base_threshold + self.noise(self.laplace_gamma)
        # TODO: query budget is not implemented yet.

        self.keep_log = keep_log
        if keep_log:
            self.log = pd.DataFrame(columns=['GlobStep', 'threshold', 'delta', 'phi_train', 'phi_holdout', 'estimate', 'overfit'])

    def noise(self, dist):
        """Draw one sample from ``dist`` and return it as a Python float."""
        return dist.sample().item()

    def verify_statistic(self, phi, glob_step=None):
        """
            - phi(dataset) -> statistic:
              function returns the average of some statistic
            - glob_step: step identifier stored in the log; required when
              logging is enabled

            Returns the training value of ``phi`` when it is within the noisy
            threshold of the holdout value, otherwise the holdout value plus
            Laplace noise.

            Raises ValueError when logging is on and ``glob_step`` is None.
        """
        # Validate before evaluating phi twice, which may be expensive.
        # (The previous code raised an undefined name here, i.e. a NameError.)
        if self.keep_log and glob_step is None:
            raise ValueError('please provide glob step if logging is on')

        train_val = phi(self.train)
        holdout_val = phi(self.holdout)

        delta = abs(train_val - holdout_val)
        thresh = self.T + self.noise(self.laplace_eta)
        overfit = delta > thresh

        if overfit:
            # Re-draw the noisy threshold around the fixed base.  Accumulating
            # the noise (``self.T += ...``) would turn the threshold into a
            # random walk that can drift arbitrarily far from 4 * tolerance.
            self.T = self.base_threshold + self.noise(self.laplace_gamma)
            estimate = holdout_val + self.noise(self.laplace_eps)
        else:
            estimate = train_val

        if self.keep_log:
            self.log.loc[len(self.log)] = [glob_step, thresh, delta, train_val, holdout_val, estimate, overfit]

        return estimate
            
        

In [8]:
def test_accuracy(model, data_loader): 
    correct = 0
    total = 0
    with torch.no_grad():
        for data in data_loader:
            images, labels = data[0].to(DEVICE), data[1]
            outputs = model(images).cpu()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            
    return correct / total

# Training

In [9]:
def train(model, data, lr, epochs=10, sample_step=500 ):
    """Train ``model`` with Adam and cross-entropy, sampling accuracy via Thresholdout.

    Args:
        model: network to optimise; expected to live on ``DEVICE`` already.
        data: ``(trainloader, testloader)`` pair; the test loader serves as
            the Thresholdout holdout set.
        lr: Adam learning rate.
        epochs: number of passes over the training loader.
        sample_step: an accuracy estimate is requested at every non-zero
            batch index that is a multiple of this value.

    Returns:
        ``(loss_history, tout)``: the per-batch loss values and the
        ``Thresholdout`` instance holding the query log.
    """
    trainloader, testloader = data
    tout = Thresholdout(trainloader, testloader, keep_log=True)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # Statistic handed to Thresholdout: accuracy of the current model on a loader.
    accuracy_of_model = partial(test_accuracy, model)

    loss_history = []
    glob_step = 0

    for epoch in tqdm(range(epochs)):
        # Loss accumulated since the last accuracy sample (or the epoch start).
        running_loss = 0.0

        for i, batch in enumerate(trainloader):
            inputs = batch[0].to(DEVICE)
            labels = batch[1].to(DEVICE)

            # One optimisation step: reset gradients, forward, backward, update.
            optimizer.zero_grad()
            loss = criterion(model(inputs), labels)
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            running_loss += batch_loss
            loss_history.append(batch_loss)
            glob_step += 1

            # Sample only at non-zero multiples of sample_step.
            if i % sample_step != 0 or not i:
                continue

            acc_val = tout.verify_statistic(accuracy_of_model, glob_step)
            # The first window of an epoch spans batches 0..sample_step,
            # i.e. one more loss than the divisor used below.
            print(f'[{epoch+1}, step::{i}] loss [{running_loss / sample_step :.3f}] accuracy [{acc_val:.3f}]')
            running_loss = 0.0

    return loss_history, tout
    
    
    
    

In [10]:
# Convert images to tensors, then normalise each of the 3 channels with mean 0.5 and std 0.5.
conv_transform = T.Compose([T.ToTensor(), T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
# (trainloader, testloader) pair shared by every run of the sweep below.
conv_data = make_data(conv_transform)


Files already downloaded and verified
Files already downloaded and verified


In [None]:
# Learning-rate sweep: train a freshly initialised network for every rate and
# keep its (loss_history, Thresholdout) pair keyed by that rate.
results = {}
for lr in lrs:
    print(f'--------- [LR: [{lr}]] ------------')
    convnet = StrongerConvNet().to(DEVICE)
    results[lr] = train(convnet, conv_data, lr)

--------- [LR: [0.0003]] ------------


HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))

[1, step::500] loss [2.104] accuracy [0.182]
[1, step::1000] loss [1.958] accuracy [0.218]
[1, step::1500] loss [1.845] accuracy [0.318]
[1, step::2000] loss [1.763] accuracy [0.364]
[1, step::2500] loss [1.671] accuracy [0.361]
[1, step::3000] loss [1.652] accuracy [0.401]
[1, step::3500] loss [1.592] accuracy [0.419]
[1, step::4000] loss [1.580] accuracy [0.419]
[2, step::500] loss [1.534] accuracy [0.421]
[2, step::1000] loss [1.521] accuracy [0.435]
[2, step::1500] loss [1.477] accuracy [0.452]
[2, step::2000] loss [1.445] accuracy [0.466]
[2, step::2500] loss [1.487] accuracy [0.470]
[2, step::3000] loss [1.422] accuracy [0.465]
[2, step::3500] loss [1.434] accuracy [0.487]
[2, step::4000] loss [1.424] accuracy [0.455]
[3, step::500] loss [1.383] accuracy [0.497]
[3, step::1000] loss [1.346] accuracy [0.506]
[3, step::1500] loss [1.340] accuracy [0.501]
[3, step::2000] loss [1.335] accuracy [0.501]
[3, step::2500] loss [1.303] accuracy [0.519]
[3, step::3000] loss [1.309] accuracy

HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))

[1, step::500] loss [2.067] accuracy [0.216]
[1, step::1000] loss [1.889] accuracy [0.280]
[1, step::1500] loss [1.778] accuracy [0.332]
[1, step::2000] loss [1.708] accuracy [0.353]
[1, step::2500] loss [1.649] accuracy [0.382]
[1, step::3000] loss [1.589] accuracy [0.418]
[1, step::3500] loss [1.544] accuracy [0.440]
[1, step::4000] loss [1.531] accuracy [0.446]
[2, step::500] loss [1.483] accuracy [0.448]
[2, step::1000] loss [1.470] accuracy [0.470]
[2, step::1500] loss [1.450] accuracy [0.446]
[2, step::2000] loss [1.421] accuracy [0.494]
[2, step::2500] loss [1.433] accuracy [0.511]
[2, step::3000] loss [1.406] accuracy [0.511]
[2, step::3500] loss [1.359] accuracy [0.522]
[2, step::4000] loss [1.345] accuracy [0.535]
[3, step::500] loss [1.305] accuracy [0.549]
[3, step::1000] loss [1.282] accuracy [0.535]
[3, step::1500] loss [1.295] accuracy [0.560]
[3, step::2000] loss [1.273] accuracy [0.559]
[3, step::2500] loss [1.284] accuracy [0.567]
[3, step::3000] loss [1.266] accuracy

HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))

[1, step::500] loss [2.098] accuracy [0.194]
[1, step::1000] loss [1.948] accuracy [0.233]
[1, step::1500] loss [1.884] accuracy [0.268]
[1, step::2000] loss [1.819] accuracy [0.304]
[1, step::2500] loss [1.708] accuracy [0.394]
[1, step::3000] loss [1.661] accuracy [0.396]
[1, step::3500] loss [1.587] accuracy [0.408]
[1, step::4000] loss [1.557] accuracy [0.410]
[2, step::500] loss [1.506] accuracy [0.444]
[2, step::1000] loss [1.487] accuracy [0.478]
[2, step::1500] loss [1.426] accuracy [0.452]
[2, step::2000] loss [1.424] accuracy [0.489]
[2, step::2500] loss [1.419] accuracy [0.485]
[2, step::3000] loss [1.423] accuracy [0.481]
[2, step::3500] loss [1.395] accuracy [0.477]
[2, step::4000] loss [1.384] accuracy [0.496]
[3, step::500] loss [1.336] accuracy [0.512]
[3, step::1000] loss [1.317] accuracy [0.496]
[3, step::1500] loss [1.308] accuracy [0.520]
[3, step::2000] loss [1.319] accuracy [0.516]
[3, step::2500] loss [1.275] accuracy [0.521]
[3, step::3000] loss [1.304] accuracy

In [None]:
# Persist the sweep outcome.
# NOTE(review): each stored Thresholdout keeps references to the train/test
# loaders, so those objects are serialised along with the logs -- confirm this is intended.
with open('convnet-lr-outcome.pkl', mode='wb') as outcome_file:
    pkl.dump(results, outcome_file)