In [1]:
import torch
import torchvision
import torchvision.transforms as T
import torchvision.models as models


import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions.laplace import Laplace
import numpy as np
import matplotlib.pyplot as plt

from tqdm.notebook import  tqdm
import seaborn as sns
import pickle as pkl
from pathlib import Path
import pandas as pd
from functools import partial 


In [2]:
DEVICE = torch.device("cuda:3")
DATA_ROOT = Path('../data')

In [3]:
torch.cuda.is_available()

True

# Data & Model Prep

## Data

In [4]:
def make_data(transforms):

    trainset = torchvision.datasets.CIFAR10(root=DATA_ROOT / 'cifar-10-data', train=True,
                                            download=True, transform=transforms)

    testset = torchvision.datasets.CIFAR10(root=DATA_ROOT / 'cifar-10-data', train=False,
                                           download=True, transform=transforms)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=16,
                                              shuffle=True, num_workers=2)

    testloader = torch.utils.data.DataLoader(testset, batch_size=16,
                                             shuffle=False, num_workers=2)
    return trainloader, testloader


CLASSES = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

## Models

class ConvNet(nn.Module):
    def __init__(self):
        super(ConvNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


convnet = ConvNet()
convnet.to(DEVICE)

In [5]:
resnet = models.resnet18(pretrained=False)
resnet.fc = nn.Linear(512, len(CLASSES))
resnet.to(DEVICE)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

## Thresholdout

In [6]:
class Thresholdout:
    def __init__(self, train, holdout, tolerance=0.01/4, scale_factor=4, keep_log=True):
        self.tolerance = tolerance
        
        self.laplace_eps = Laplace(torch.tensor([0.0]), torch.tensor([2*self.tolerance]))
        self.laplace_gamma = Laplace(torch.tensor([0.0]), torch.tensor([4*self.tolerance]))
        self.laplace_eta = Laplace(torch.tensor([0.0]), torch.tensor([8*self.tolerance]))

        self.train = train
        self.holdout = holdout
        
        self.T = 4*tolerance + self.noise(self.laplace_gamma)
        # self.budget = ???
        
        self.keep_log = keep_log
        if keep_log:
            self.log = pd.DataFrame(columns=['GlobStep', 'threshold', 'delta', 'phi_train', \
                                             'phi_holdout', 'estimate', 'overfit'])
        
        
    def noise(self, dist):
        return dist.sample().item()
        
    def verify_statistic(self, phi, glob_step=None):
        """
            - phi(dataset) -> statistic: 
              function returns the average of some statistic
        """
        
        train_val = phi(self.train)
        holdout_val = phi(self.holdout)
                
        delta = abs(train_val - holdout_val)
        thresh = self.T + self.noise(self.laplace_eta)
        
        if delta > thresh:
            self.T += self.noise(self.laplace_gamma)
            estimate = holdout_val + self.noise(self.laplace_eps)
        else:
            estimate = train_val
            
        if self.keep_log:
            if glob_step is None: 
                raise ValueException('please provide glob step if logging is on')
            self.log.loc[len(self.log)] = [glob_step, thresh, delta, train_val, holdout_val, estimate, delta > thresh]
            
        return estimate
            
        

In [7]:
def test_accuracy(model, data_loader): 
    correct = 0
    total = 0
    with torch.no_grad():
        for data in data_loader:
            images, labels = data[0].to(DEVICE), data[1]
            outputs = model(images).cpu()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            
    return correct / total

# Training

In [8]:
def train(model, data, epochs=20, sample_step=1000):
    
    trainloader, testloader = data
    tout = Thresholdout(trainloader, testloader, keep_log=True)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=3e-4)
    
    loss_history = []
    glob_step = 0
    
    for epoch in tqdm(range(epochs)):  # loop over the dataset multiple times

        running_loss = 0.0

        for i, data in enumerate(trainloader, 0):
            # get the inputs; data is a list of [inputs, labels]
            inputs, labels = data[0].to(DEVICE), data[1].to(DEVICE)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            loss_history += [loss.item()]
            glob_step += 1
            
            if i % sample_step == 0 and i:
                acc_val = tout.verify_statistic(partial(test_accuracy, model), glob_step)
                print(f'[{epoch+1}, step::{i}] loss [{running_loss / sample_step :.3f}] accuracy [{acc_val:.3f}]')
                running_loss = 0.0
                
    return loss_history, tout
    
    
    
    

In [None]:
resnet_normalize = T.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
resnet_transform = T.Compose([T.Resize(256), T.CenterCrop(224), T.ToTensor(), resnet_normalize])

resnet_data = make_data(resnet_transform)
resnet_history, resnet_tout = train(resnet, resnet_data)

Files already downloaded and verified
Files already downloaded and verified


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))

[1, step::1000] loss [1.600] accuracy [0.500]
[1, step::2000] loss [1.198] accuracy [0.626]
[1, step::3000] loss [1.002] accuracy [0.664]
[2, step::1000] loss [0.850] accuracy [0.732]
[2, step::2000] loss [0.791] accuracy [0.724]
[2, step::3000] loss [0.737] accuracy [0.741]
[3, step::1000] loss [0.636] accuracy [0.757]
[3, step::2000] loss [0.620] accuracy [0.780]
[3, step::3000] loss [0.610] accuracy [0.792]
[4, step::1000] loss [0.498] accuracy [0.801]
[4, step::2000] loss [0.495] accuracy [0.799]
[4, step::3000] loss [0.482] accuracy [0.809]
[5, step::1000] loss [0.381] accuracy [0.809]
[5, step::2000] loss [0.390] accuracy [0.813]
[5, step::3000] loss [0.403] accuracy [0.793]
[6, step::1000] loss [0.290] accuracy [0.816]
[6, step::2000] loss [0.306] accuracy [0.817]
[6, step::3000] loss [0.308] accuracy [0.816]
[7, step::1000] loss [0.201] accuracy [0.812]
[7, step::2000] loss [0.244] accuracy [0.810]
[7, step::3000] loss [0.242] accuracy [0.830]
[8, step::1000] loss [0.147] accur

In [None]:
with open('resnet-outcome.pkl','wb') as fp:
    pkl.dump((resnet_history, resnet_tout), fp)