In [49]:
import torch
import torchvision
import torchvision.transforms as T
import torchvision.models as models


import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions.laplace import Laplace
import numpy as np
import matplotlib.pyplot as plt

from tqdm.notebook import  tqdm
import seaborn as sns
import pickle as pkl
from pathlib import Path
from functools import partial

In [50]:
# Root folder under which datasets are downloaded/cached.
DATA_ROOT = Path('../data')
# Fraction of the pooled samples that goes to the training split (see splitds).
DATA_SPLIT = 0.6

# Preferred GPU index. Fall back gracefully so the notebook still runs on
# machines with fewer GPUs (first GPU) or with no GPU at all (CPU).
_PREFERRED_GPU = 3
if torch.cuda.is_available() and torch.cuda.device_count() > _PREFERRED_GPU:
    DEVICE = torch.device(f"cuda:{_PREFERRED_GPU}")
elif torch.cuda.is_available():
    DEVICE = torch.device("cuda:0")
else:
    DEVICE = torch.device("cpu")

# Mini-batch size used by every DataLoader built in this notebook.
BATCH_SIZE = 12

def ARCH():
    """Build an untrained ResNet-18 with one output logit per entry in CLASSES.

    Returns:
        A torchvision ResNet-18 whose final fully connected layer has been
        replaced by a fresh ``nn.Linear`` with ``len(CLASSES)`` outputs.
    """
    # Calling without arguments gives randomly initialised weights on every
    # torchvision version and avoids the deprecated `pretrained=` keyword.
    m = models.resnet18()
    # Read the feature width from the model instead of hard-coding 512.
    m.fc = nn.Linear(m.fc.in_features, len(CLASSES))
    return m

In [8]:
# Sanity check: report whether PyTorch can see a CUDA device in this environment.
torch.cuda.is_available()

True

# DATA & MODEL

## Data Prep

In [5]:
# Human-readable class names; len(CLASSES) sizes the classifier head built by ARCH().
# NOTE(review): presumably listed in CIFAR-10 label-index order — confirm before
# using these names to decode predictions.
CLASSES = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

In [18]:
def splitds(train, test, split=None):
    """Pool two datasets and re-split them, in place, at a fixed fraction.

    The samples of ``train`` and ``test`` are concatenated (train first) and
    then cut at ``split``: the leading part is written back to ``train`` and
    the remainder to ``test``. No shuffling is performed.

    Args:
        train: dataset object exposing ``.data`` (array-like) and ``.targets``.
        test: dataset object with the same two attributes.
        split: fraction in [0, 1] assigned to ``train``. Defaults to the
            module-level ``DATA_SPLIT`` when omitted.

    Raises:
        ValueError: if ``split`` is outside [0, 1].
    """
    if split is None:
        split = DATA_SPLIT
    if not 0.0 <= split <= 1.0:
        raise ValueError(f"split must be within [0, 1], got {split!r}")

    X = np.concatenate((train.data, test.data), axis=0)
    # Normalise targets to plain lists before joining: `+` only concatenates
    # for lists, while arrays would be added elementwise (or fail to broadcast).
    Y = np.asarray(train.targets).tolist() + np.asarray(test.targets).tolist()

    split_id = int(len(X) * split)
    train.data, train.targets = X[:split_id], Y[:split_id]
    test.data, test.targets = X[split_id:], Y[split_id:]


In [19]:
def get_dataset(tfms):
    """Return ``(trainloader, holdoutloader)`` for the re-split CIFAR-10 data.

    Both official CIFAR-10 parts are loaded (downloading if needed) with the
    same transform ``tfms``, re-partitioned in place by ``splitds``, and then
    wrapped in DataLoaders. Only the training loader shuffles.
    """
    cifar_root = DATA_ROOT / 'cifar-10-data'

    # Official train part first, then the official test part.
    trainset, testset = (
        torchvision.datasets.CIFAR10(root=cifar_root, train=is_train,
                                     download=True, transform=tfms)
        for is_train in (True, False)
    )

    splitds(trainset, testset)

    shared_kwargs = dict(batch_size=BATCH_SIZE, num_workers=2)
    trainloader = torch.utils.data.DataLoader(trainset, shuffle=True, **shared_kwargs)
    holdoutloader = torch.utils.data.DataLoader(testset, shuffle=False, **shared_kwargs)

    return trainloader, holdoutloader

In [20]:
# Per-channel normalisation statistics.
# NOTE(review): these are the values commonly used for ImageNet preprocessing,
# not statistics computed from this dataset — confirm this is intended.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]
normalize = T.Normalize(mean=_NORM_MEAN, std=_NORM_STD)

# Resize, centre-crop to 224x224, convert to a tensor, then normalise.
TFMS = T.Compose([
    T.Resize(256),
    T.CenterCrop(224),
    T.ToTensor(),
    normalize,
])

# DataLoaders for the training split and the holdout split.
train, holdout = get_dataset(TFMS)

Files already downloaded and verified
Files already downloaded and verified


In [24]:
# Approximate sample counts per split: number of batches times BATCH_SIZE.
# This over-counts whenever the final batch of a loader is only partially filled.
len(train)*BATCH_SIZE,len(holdout)*BATCH_SIZE

(36000, 24000)

## MODEL

In [46]:
# Candidate in-place weight initialisers; init_model picks one of them at random.
INIT_METHODS = [
    nn.init.xavier_uniform_,
    nn.init.xavier_normal_,
    nn.init.kaiming_uniform_,
    nn.init.kaiming_normal_,
]

def init_weights(m, init=nn.init.xavier_uniform_):
    """Initialise a single module in place; intended for ``model.apply``.

    Only modules whose type is exactly ``nn.Linear`` are touched: the weight
    is initialised with ``init`` and the bias (if present) is set to 0.01.
    Every other module is left unchanged.

    Args:
        m: the module visited by ``model.apply``.
        init: in-place initialiser called as ``init(m.weight)``. Defaults to
            ``nn.init.xavier_uniform_`` (the non-deprecated spelling of the
            previous default, with identical behaviour).
    """
    # Exact type match on purpose: subclasses of nn.Linear are not re-initialised.
    if type(m) is not nn.Linear:
        return

    init(m.weight)
    # Layers created with bias=False have m.bias set to None.
    if m.bias is not None:
        nn.init.constant_(m.bias, 0.01)
        
def init_model(model):
    """Re-initialise ``model`` in place with one randomly drawn scheme.

    A single initialiser is sampled from ``INIT_METHODS`` via
    ``np.random.choice`` and applied to every sub-module through
    ``init_weights``.
    """
    chosen_init = np.random.choice(INIT_METHODS)

    def _apply_to_module(module):
        init_weights(module, init=chosen_init)

    model.apply(_apply_to_module)

In [95]:
class WeightHistory:
    """
    Keeps model/optimizer checkpoints in a folder on disk rather than in memory.

    Each checkpoint is stored as ``<savedir>/model_checkpoint_<idx>.cpt`` and
    holds the global step, the model state dict and the optimizer state dict.
    """

    def __init__(self, length, savedir):
        """
        Args:
            length: stored as ``self.len``; not otherwise used by this class.
            savedir: directory for checkpoint files; created (including any
                missing parents) if it does not exist yet.
        """
        self.len = length
        self.savedir = Path(savedir)
        # parents/exist_ok make creation safe for nested or pre-existing folders.
        self.savedir.mkdir(parents=True, exist_ok=True)

    def _checkpoint_path(self, idx):
        """Return the checkpoint file path for slot ``idx``."""
        return self.savedir / f'model_checkpoint_{idx}.cpt'

    def save_weights(self, individual, glob_step, idx=None):
        """Write the model and optimizer state of ``individual`` to disk.

        Args:
            individual: object exposing ``.model``, ``.opt`` and ``.idx``.
            glob_step: step counter stored alongside the weights.
            idx: checkpoint slot to write. Defaults to ``individual.idx``
                (the name was previously unbound here, raising NameError).
        """
        if idx is None:
            idx = individual.idx
        torch.save({
            'glob_step': glob_step,
            'model_state_dict': individual.model.state_dict(),
            'opt_state_dict': individual.opt.state_dict(),
        }, self._checkpoint_path(idx))

    def restore_weights(self, individual, idx):
        """Load checkpoint slot ``idx`` into ``individual`` and return its step.

        Args:
            individual: object exposing ``.model`` and ``.opt`` to be overwritten.
            idx: checkpoint slot to read.

        Returns:
            The ``glob_step`` value that was stored with the checkpoint.
        """
        # NOTE: torch.load unpickles data; only load checkpoints from trusted folders.
        checkpoint = torch.load(self._checkpoint_path(idx))

        individual.model.load_state_dict(checkpoint['model_state_dict'])
        individual.opt.load_state_dict(checkpoint['opt_state_dict'])
        return checkpoint['glob_step']
    
    



## THRESHOLDOUT

In [39]:
class Thresholdout:
    """Answer queries from the training set unless it disagrees with the holdout.

    A query ``phi`` is evaluated on both datasets. If the two values differ by
    more than a noisy threshold, a noised holdout value is returned; otherwise
    the training value is returned unchanged.

    NOTE(review): ``self.gamma`` is defined but never used by ``verify`` —
    confirm whether the threshold was meant to be perturbed with it.
    """

    def __init__(self, train, holdout, tolerance=0.01/4, scale_factor=4, keep_log=True):
        """
        Args:
            train: training data handed to every query.
            holdout: holdout data handed to every query.
            tolerance: base noise scale.
            scale_factor: the threshold is ``scale_factor * tolerance``
                (the default of 4 reproduces the previously hard-coded value).
            keep_log: when True, every query is recorded in ``self.log`` as
                ``(train_val, holdout_val, used_holdout)``.
        """
        self.tolerance = tolerance
        self.scale_factor = scale_factor
        self.T = scale_factor * tolerance

        # Gaussian noise samplers; each call returns one scalar so that both
        # branches of verify() yield a scalar rather than a 1-element array.
        self.eps = lambda: np.random.normal(0, 2*self.tolerance)
        self.gamma = lambda: np.random.normal(0, 4*self.tolerance)
        self.eta = lambda: np.random.normal(0, 8*self.tolerance)

        self.train = train
        self.holdout = holdout

        self.keep_log = keep_log
        self.log = []

    def verify(self, phi):
        """Evaluate ``phi`` under the thresholdout rule.

        Args:
            phi: callable taking a dataset and returning a numeric value.

        Returns:
            ``(value, used_holdout)``. ``used_holdout`` is True when the
            train/holdout gap exceeded the noisy threshold and the noised
            holdout value was returned; False when the training value was
            returned. (Both branches previously reported False.)
        """
        train_val = phi(self.train)
        holdout_val = phi(self.holdout)

        delta = abs(train_val - holdout_val)

        used_holdout = bool(delta > self.T + self.eta())
        if used_holdout:
            answer = holdout_val + self.eps()
        else:
            answer = train_val

        if self.keep_log:
            self.log.append((train_val, holdout_val, used_holdout))
        return answer, used_holdout
        

# EXP. ALGORITHM

In [None]:
# Folder handed to TrainExperiment as `weights_root`; checkpoints are written under it.
WEIGHTS_ROOT = Path('./weightsh/')

In [91]:
class individual:
    """One population member: a model, its optimizer and its on-disk checkpoints."""

    def __init__(self, idx, K, T, savedir, lr=3e-4):
        """
        Args:
            idx: identifier of this individual within the population.
            K: number of checkpoints to take over a budget of ``T``.
            T: training budget, counted in samples.
            savedir: directory handed to ``WeightHistory`` for checkpoints.
            lr: Adam learning rate.

        Raises:
            ValueError: if ``K`` is not positive.
        """
        if K <= 0:
            raise ValueError(f"K must be positive, got {K!r}")

        self.idx = idx
        self.K = K
        self.T = T

        self.model = ARCH()
        init_model(self.model)
        self.model.to(DEVICE)

        self.criterion = nn.CrossEntropyLoss()
        self.opt = optim.Adam(self.model.parameters(), lr=lr)
        self.history = WeightHistory(length=K, savedir=savedir)

        # At least 1 so the checkpoint arithmetic in train() never divides by zero.
        self.interval = max(1, T // K)

    def train(self, data, steps=None):
        """Train on ``data`` until ``steps`` samples have been processed.

        Args:
            data: iterable of batches; ``batch[0]`` holds inputs and
                ``batch[1]`` holds labels.
            steps: sample budget for this call. Defaults to ``self.T``.

        Returns:
            The number of samples actually processed.
        """
        if steps is None:
            steps = self.T

        # Move batches to wherever the model lives.
        device = next(self.model.parameters()).device

        self.model.train()
        step = 0
        while step < steps:
            step_at_epoch_start = step
            for batch in data:
                inputs, labels = batch[0].to(device), batch[1].to(device)
                self.opt.zero_grad()

                outputs = self.model(inputs)
                loss = self.criterion(outputs, labels)
                loss.backward()
                self.opt.step()

                # Count the samples of this batch only (the last batch of an
                # epoch may be smaller than the nominal batch size).
                previous = step
                step += labels.size(0)

                # Checkpoint whenever an interval boundary is crossed; an exact
                # `step % interval == 0` test misses boundaries that the batch
                # size does not divide.
                if step // self.interval > previous // self.interval:
                    self.history.save_weights(self, step)

                if step >= steps:
                    break

            if step == step_at_epoch_start:
                # Empty loader: nothing was processed, so stop instead of spinning.
                break

        return step
                

In [92]:
class TrainExperiment:
    """Bundle the data, a Thresholdout instance and a population of individuals."""

    def __init__(self, train, holdout, weights_root, K=10, T=20_000):
        """
        Args:
            train: training data, passed to ``Thresholdout``.
            holdout: holdout data, passed to ``Thresholdout``.
            weights_root: directory for checkpoints; created if missing.
            K: population size, also forwarded to every individual as ``K``.
            T: training budget forwarded to every individual as ``T``.
        """
        self.K = K
        self.T = T
        self.weights_root = Path(weights_root)
        self.weights_root.mkdir(parents=True, exist_ok=True)

        self.train = train
        self.holdout = holdout

        self.tout = Thresholdout(self.train, self.holdout)

        # individual() takes `K` and a required `T`; the earlier call passed a
        # lowercase `k` and no `T`, which raised TypeError.
        self.population = [
            individual(i, K=self.K, T=self.T, savedir=self.weights_root)
            for i in range(self.K)
        ]

    def do_one_cycle(self):
        """Placeholder for one cycle of the experiment. TODO: not implemented yet."""
        pass

In [93]:
# Build the experiment on the train/holdout loaders; this constructs the whole
# population (K individuals by default), each with its own model.
exp = TrainExperiment(train, holdout, weights_root=WEIGHTS_ROOT)