In [1]:
import torch
import torchvision
import torchvision.transforms as T
import torchvision.models as models


import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions.laplace import Laplace
import numpy as np
import matplotlib.pyplot as plt

from tqdm.notebook import  tqdm
import seaborn as sns
import pickle as pkl
from pathlib import Path
from functools import partial
import pandas as pd

In [2]:
# --- Experiment configuration -------------------------------------------
DATA_ROOT = Path('../data')   # directory the datasets are downloaded into
DATA_SPLIT = 0.6              # fraction of the pooled samples used for training
BATCH_SIZE = 12               # mini-batch size for both data loaders

# Prefer GPU #3 (the device this notebook was written for), but fall back to
# the first GPU, or to the CPU, so the notebook still runs on machines that
# expose fewer devices instead of failing on the first `.to(DEVICE)`.
_PREFERRED_GPU = 3
if torch.cuda.device_count() > _PREFERRED_GPU:
    DEVICE = torch.device(f"cuda:{_PREFERRED_GPU}")
elif torch.cuda.is_available():
    DEVICE = torch.device("cuda:0")
else:
    DEVICE = torch.device("cpu")

def ARCH():
    """Build an untrained ResNet-18 whose final layer outputs one logit per class.

    Returns:
        The model with its ``fc`` head replaced by a fresh ``nn.Linear`` that
        has ``len(CLASSES)`` outputs.
    """
    m = models.resnet18(pretrained=False)
    # Read the head's input width from the backbone instead of hard-coding 512,
    # so the function keeps working if the backbone is swapped.
    m.fc = nn.Linear(m.fc.in_features, len(CLASSES))
    return m

In [3]:
# Sanity check: does PyTorch see a usable CUDA device in this kernel?
torch.cuda.is_available()

True

# DATA & MODEL

## Data Prep

In [4]:
# Human-readable label names; the tuple length also sizes the classifier head.
CLASSES = (
    'plane',
    'car',
    'bird',
    'cat',
    'deer',
    'dog',
    'frog',
    'horse',
    'ship',
    'truck',
)

In [5]:
def splitds(train, test):
    """Pool two datasets and re-split them in place according to ``DATA_SPLIT``.

    The samples of ``train`` and ``test`` are concatenated (train first) and the
    first ``int(n * DATA_SPLIT)`` of them are written back to ``train``; the
    remainder is written back to ``test``. No shuffling is performed.

    Args:
        train: object exposing ``data`` (array-like) and ``targets``.
        test: object exposing ``data`` (array-like) and ``targets``.

    Returns:
        None. Both arguments are mutated; ``targets`` end up as Python lists.
    """
    X = np.concatenate((train.data, test.data), axis=0)
    # list(...) guarantees concatenation; with array-like targets a bare `+`
    # would add element-wise (or raise on mismatched lengths).
    Y = list(train.targets) + list(test.targets)

    split_id = int(len(X) * DATA_SPLIT)
    train.data, train.targets = X[:split_id], Y[:split_id]
    test.data, test.targets = X[split_id:], Y[split_id:]


In [6]:
def get_dataset(tfms):
    """Download CIFAR-10, re-split it with ``splitds`` and wrap it in loaders.

    Args:
        tfms: transform applied to every image of both subsets.

    Returns:
        ``(trainloader, holdoutloader)``; only the training loader shuffles.
    """
    cifar_root = DATA_ROOT / 'cifar-10-data'
    trainset, testset = (
        torchvision.datasets.CIFAR10(root=cifar_root, train=is_train,
                                     download=True, transform=tfms)
        for is_train in (True, False)
    )

    # Re-partition the two official subsets in place.
    splitds(trainset, testset)

    make_loader = partial(torch.utils.data.DataLoader,
                          batch_size=BATCH_SIZE, num_workers=2)
    trainloader = make_loader(trainset, shuffle=True)
    holdoutloader = make_loader(testset, shuffle=False)

    return trainloader, holdoutloader

In [7]:
# Channel-wise normalisation; these are the widely used ImageNet statistics.
normalize = T.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)

# Preprocessing pipeline: resize -> centre crop -> tensor -> normalise.
TFMS = T.Compose([
    T.Resize(256),
    T.CenterCrop(224),
    T.ToTensor(),
    normalize,
])

train, holdout = get_dataset(TFMS)

Files already downloaded and verified
Files already downloaded and verified


In [8]:
# Exact sample counts. `len(loader) * BATCH_SIZE` over-counts whenever the last
# batch is only partially filled, so ask the underlying datasets instead.
len(train.dataset), len(holdout.dataset)

(36000, 24000)

## MODEL

In [9]:
# Candidate in-place weight initialisers; one is drawn at random per model.
INIT_METHODS = [
    nn.init.xavier_uniform_,
    nn.init.xavier_normal_,
    nn.init.kaiming_uniform_,
    nn.init.kaiming_normal_,
]

def init_weights(m, init=nn.init.xavier_uniform_):
    """Initialise a single module in place; intended for ``model.apply``.

    Only ``nn.Linear`` modules (including subclasses) are touched; every other
    module is left unchanged.

    Args:
        m: the module visited by ``model.apply``.
        init: in-place initialiser called on the weight tensor. Defaults to the
            underscore variant ``xavier_uniform_`` (the non-underscore name is
            deprecated).
    """
    if isinstance(m, nn.Linear):
        init(m.weight)
        # Layers built with bias=False have `bias is None`.
        if m.bias is not None:
            m.bias.data.fill_(0.01)
        
def init_model(model):
    """Re-initialise ``model`` with an initialiser drawn from ``INIT_METHODS``.

    One scheme is picked at random and applied to every sub-module through
    ``init_weights``. The model is modified in place; nothing is returned.
    """
    chosen_init = np.random.choice(INIT_METHODS)
    per_module_init = partial(init_weights, init=chosen_init)
    model.apply(per_module_init)

In [10]:
class WeightHistory:
    """
    Disk-backed checkpoint store: the idea is to create a folder and keep
    weights there as opposed to memory.

    Each checkpoint is written to ``savedir / model_checkpoint_{kidx}.cpt`` and
    holds the global step plus the model and optimizer state dicts.
    """

    def __init__(self, length, savedir):
        """
        Args:
            length: number of checkpoints the owner intends to keep (stored as
                ``self.len``; not enforced by this class).
            savedir: ``pathlib.Path`` of the checkpoint folder; created if missing.
        """
        self.len = length
        self.savedir = savedir

        # parents=True: works for nested paths whose parent does not exist yet.
        # exist_ok=True: replaces the racy "exists() then mkdir()" check.
        savedir.mkdir(parents=True, exist_ok=True)

    def _checkpoint_path(self, kidx):
        """Return the file path used for checkpoint slot ``kidx``."""
        return self.savedir / f'model_checkpoint_{kidx}.cpt'

    def save_weights(self, individual, glob_step, kidx):
        """Write ``individual``'s model and optimizer state to slot ``kidx``.

        Args:
            individual: object exposing ``model`` and ``opt`` with ``state_dict()``.
            glob_step: training step stored alongside the weights.
            kidx: checkpoint slot; an existing file for the slot is overwritten.
        """
        torch.save({
            'glob_step': glob_step,
            'model_state_dict': individual.model.state_dict(),
            'opt_state_dict': individual.opt.state_dict(),
        }, self._checkpoint_path(kidx))

    def load_weights(self, kidx):
        """Load and return the checkpoint dict stored in slot ``kidx``.

        Returns:
            dict with keys ``'glob_step'``, ``'model_state_dict'`` and
            ``'opt_state_dict'``.
        """
        return torch.load(self._checkpoint_path(kidx))



## THRESHOLDOUT

In [11]:
class Thresholdout:
    """Answer queries from the training set unless they disagree with the holdout.

    A query ``phi`` is evaluated on both data sets. When the absolute gap
    exceeds a noisy threshold, a noise-perturbed holdout value is returned and
    the query is flagged; otherwise the training value is returned as is.
    """

    def __init__(self, train, holdout, tolerance=0.01/4, scale_factor=4, keep_log=True):
        """
        Args:
            train: training data handed to every query.
            holdout: holdout data handed to every query.
            tolerance: base noise scale; the three noise sources use 2x, 4x
                and 8x this value as standard deviation.
            scale_factor: the threshold is ``scale_factor * tolerance``
                (default 4, i.e. the previously hard-coded ``4 * tolerance``).
            keep_log: stored on the instance; not otherwise used by this class.
        """
        self.tolerance = tolerance
        self.scale_factor = scale_factor
        self.keep_log = keep_log
        # Honour the caller's scale_factor instead of a hard-coded 4.
        self.T = scale_factor * tolerance

        # Zero-mean Gaussian noise sources, each drawing one fresh sample per call.
        self.eps = lambda: np.random.normal(0, 2*self.tolerance, 1)[0]
        self.gamma = lambda: np.random.normal(0, 4*self.tolerance, 1)[0]
        self.eta = lambda: np.random.normal(0, 8*self.tolerance, 1)[0]

        self.train = train
        self.holdout = holdout

    def verify(self, phi):
        """Evaluate ``phi`` on both sets and decide which value to release.

        Args:
            phi: callable taking a data set and returning a number.

        Returns:
            ``(value, flagged)``: ``(holdout_val + noise, True)`` when the
            train/holdout gap exceeds the noisy threshold, else
            ``(train_val, False)``.
        """
        train_val = phi(self.train)
        holdout_val = phi(self.holdout)

        delta = abs(train_val - holdout_val)

        if delta > self.T + self.eta():
            return holdout_val + self.eps(), True
        else:
            return train_val, False
        

In [12]:
def accuracy(model, data): # phi
    """Return the fraction of samples in ``data`` that ``model`` classifies correctly.

    The model is switched to eval mode and run without gradient tracking.
    Inputs are moved to ``DEVICE``; predictions are compared on the CPU.

    Args:
        model: classifier returning one score per class for each input.
        data: sized iterable of ``(images, labels)`` batches.

    Returns:
        ``correct / total`` as a float.
    """
    model.eval()

    n_correct, n_seen = 0, 0

    with torch.no_grad():
        for batch in tqdm(data, total=len(data)):
            images = batch[0].to(DEVICE)
            labels = batch[1]
            scores = model(images).cpu()
            # index of the highest score along the class dimension
            predicted = torch.max(scores.data, 1)[1]
            n_seen += labels.size(0)
            n_correct += (predicted == labels).sum().item()

    return n_correct / n_seen

# EXP. ALGORITHM

In [13]:
# Root folder for checkpoints, relative to the notebook's working directory.
WEIGHTS_ROOT = Path('./weightsH/')
# exist_ok=True makes re-running this cell harmless.
WEIGHTS_ROOT.mkdir(exist_ok=True)

In [14]:
class individual:
    """One member of the population: a model, its optimizer and its checkpoints.

    Training runs for ``T`` optimizer steps and stores a checkpoint every
    ``T // K`` steps in slots ``1..K`` of ``self.history``. ``self.hidx``
    records which checkpoint slot the weights were seeded from
    (``0`` meaning a fresh random initialisation).
    """

    def __init__(self, idx, K, T, savedir, lr=3e-4):
        """
        Args:
            idx: position of this individual in the population.
            K: number of checkpoints to keep per training run.
            T: number of optimizer steps per training run.
            savedir: ``pathlib.Path`` of this individual's checkpoint folder.
            lr: Adam learning rate.

        Raises:
            ValueError: if ``K < 1`` or ``T < K`` (the checkpoint interval
                ``T // K`` would be zero).
        """
        if K < 1 or T < K:
            raise ValueError(f'need 1 <= K <= T, got K={K}, T={T}')

        self.idx = idx
        self.T = T
        self.K = K
        self.lr = lr
        self.init()

        self.criterion = nn.CrossEntropyLoss()
        self.history = WeightHistory(length=K, savedir=savedir)

        self.interval = T // K
        self.hidx = [0]

    def init(self):
        """Reset to a freshly built, randomly initialised model and a new optimizer.

        The network is rebuilt from ``ARCH()`` because ``init_model`` only
        re-initialises ``nn.Linear`` layers; re-using the old module would keep
        every other trained weight and the restart would not be fresh.
        """
        self.model = ARCH()
        init_model(self.model)
        self.model.to(DEVICE)
        self.opt = optim.Adam(self.model.parameters(), lr=self.lr)

    def assign_weights(self, history, kidx):
        """Load model and optimizer state from slot ``kidx`` of ``history``.

        Args:
            history: a ``WeightHistory`` (possibly another individual's).
            kidx: checkpoint slot to load; appended to ``self.hidx``.
        """
        ckpt = history.load_weights(kidx)

        self.model.load_state_dict(ckpt['model_state_dict'])
        self.opt.load_state_dict(ckpt['opt_state_dict'])

        self.hidx += [kidx]

    def train(self, data, step=0):
        """Run optimizer steps over ``data`` until ``self.T`` steps are reached.

        ``data`` is cycled as often as needed. This is an explicit loop rather
        than one recursive call per pass, so long runs cannot hit the recursion
        limit.

        Args:
            data: re-iterable of ``(inputs, labels)`` batches.
            step: step counter to start from.

        Raises:
            ValueError: if ``data`` yields no batches (training could never finish).
        """
        self.model.train()

        while True:
            saw_batch = False

            for batch in data:
                saw_batch = True

                inputs, labels = batch[0].to(DEVICE), batch[1].to(DEVICE)
                self.opt.zero_grad()

                outputs = self.model(inputs)
                loss = self.criterion(outputs, labels)
                loss.backward()
                self.opt.step()

                step += 1

                # checkpoint slots are 1-based: step == interval -> slot 1
                if step % self.interval == 0:
                    self.history.save_weights(self, step, step // self.interval)

                if step >= self.T:
                    return

            if not saw_batch:
                raise ValueError('`data` yielded no batches; cannot train')
                

In [15]:
class TrainExperiment:
    """Population-based training experiment scored through ``Thresholdout``.

    Each cycle trains every individual, scores it, picks the top scorer and
    re-seeds the whole population either from scratch or from one of the top
    scorer's checkpoints.
    """

    def __init__(self, train, holdout, weights_root, K=10, T=20_000):
        """
        Args:
            train: training data loader.
            holdout: holdout data loader.
            weights_root: ``pathlib.Path`` under which every individual gets
                its own ``indv_{i}`` checkpoint folder.
            K: population size, also the number of checkpoints per individual.
            T: optimizer steps per individual per cycle.
        """
        self.K = K
        self.T = T

        self.weights_root = weights_root
        self.weights_root.mkdir(exist_ok=True)

        self.train = train
        self.holdout = holdout

        self.tout = Thresholdout(self.train, self.holdout)

        self.population = [individual(i, T=self.T, K=self.K, savedir=weights_root / f'indv_{i}') for i in range(self.K)]

        self.cycleid = 0
        self.log = pd.DataFrame(columns=['CycleNum','individualID', 'score', 'overfit', 'hidx'])

    def do_one_cycle(self):
        """Train, score and re-seed the population once; append K rows to ``self.log``."""
        scores = [None] * self.K
        overfit = [None] * self.K

        for indv in tqdm(self.population):
            indv.train(self.train)
            scores[indv.idx], overfit[indv.idx] = self.tout.verify(partial(accuracy, indv.model))

        top = max(scores)
        print(f'[CYCLE::{self.cycleid}] top performer: [{top:.3f}]')
        topidx = scores.index(top)
        self.topindv = self.population[topidx].history

        # One draw per individual from {0, ..., K}: 0 means "restart from
        # scratch", h > 0 means "continue from the top performer's slot h".
        hs = [np.random.choice(range(self.K+1)) for _ in range(self.K)]
        print(f'[CYCLE::{self.cycleid}] hs: {hs}')

        for idx, h in enumerate(hs):
            member = self.population[idx]
            member.init()
            if h != 0:
                member.assign_weights(self.topindv, h)
            else:
                # Record the fresh start too; otherwise hidx[-1] keeps pointing
                # at a checkpoint from an earlier cycle and the log is wrong.
                member.hidx.append(0)

        # NOTE(review): rows are written after re-seeding, so `hidx` is the slot
        # the individual starts the NEXT cycle from, while `score` belongs to
        # the training that just finished - confirm this pairing is intended.
        for indv in self.population:
            self.log.loc[len(self.log)] = [self.cycleid, indv.idx, scores[indv.idx], overfit[indv.idx], indv.hidx[-1]]

        self.cycleid += 1
            
        
        
        

In [16]:
# Build the experiment: 10 individuals, 9_000 optimizer steps each per cycle
# (the class default is T=20_000).
exp = TrainExperiment(train, holdout, weights_root=WEIGHTS_ROOT, K=10, T=9_000) # T=9_000

In [None]:
# Make sure the output folder exists before the first (expensive) cycle:
# DataFrame.to_csv does not create missing directories, so without this the
# first cycle's results would be lost to an OSError.
LOG_DIR = Path('./logs')
LOG_DIR.mkdir(parents=True, exist_ok=True)

# Persist the log after every cycle so partial results survive an interrupted run.
for _ in range(10):
    exp.do_one_cycle()
    exp.log.to_csv(LOG_DIR / f'cycle_{exp.cycleid}.csv')

HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))