In [29]:
#!L
import time
import numpy as np
import seaborn as sns
from PIL import Image
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, datasets
from torchvision.datasets import FashionMNIST

from linearized_nns.estimator import Estimator
from linearized_nns.pytorch_impl.estimators import SgdEstimator
from linearized_nns.pytorch_impl.nns import Myrtle5, Myrtle7, Myrtle10
from linearized_nns.pytorch_impl import ClassifierTraining
from linearized_nns.pytorch_impl.matrix_exp import matrix_exp, compute_exp_term
from linearized_nns.pytorch_impl.nns.utils import to_one_hot, print_sizes
from linearized_nns.from_neural_kernels import to_zca, CustomTensorDataset, get_cifar_zca

In [3]:
#!L
# Pick the first CUDA GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda', index=0)

In [4]:
#!L
class GpEstimator(Estimator):
    """Ensemble of networks whose readout features define a finite kernel.

    Each model ``m`` contributes features ``m.readout(X) / sqrt(w_size)``.
    The kernel is the sum over models of the feature Gram matrices.
    Predictions are ``sum_m features_m(X) @ w[m]``.

    Args:
        models: non-empty sequence of modules exposing ``readout(X)``.
        n_classes: number of output columns of each linear head.
        learning_rate: stored as ``self.lr``; not read by this class itself.
        x_example: one input without a batch dimension, used only to infer
            the readout width.
        device: torch device holding the models, weights and kernels. Every
            method uses this device, never a notebook-global one.
        groups: multiplier in the feature normalisation ``w_size = n * groups``.

    Raises:
        ValueError: if ``models`` is empty.
    """

    def __init__(self, models, n_classes, learning_rate, x_example, device, groups=1):
        super().__init__()
        self.models    = [model.to(device) for model in models]
        if not self.models:
            raise ValueError("GpEstimator needs at least one model")
        self.lr        = learning_rate
        self.n_classes = n_classes
        self.device    = device

        n = len(self.models)
        with torch.no_grad():
            # Add a batch dimension and probe the readout width once (no autograd graph).
            X = x_example.unsqueeze(0).to(device)
            readout_size = self.models[0].readout(X).size(1)

        # TODO: Assert that models have the same readout size

        # One zero-initialised (readout_size x n_classes) head per model.
        self.w      = torch.zeros([n, readout_size, n_classes], device=device)
        # Normalisation constant used by to_model_features (features scaled by 1/sqrt(w_size)).
        self.w_size = n * groups

    def get_w_update(self, X, right_vector):
        """Return ``features_m(X).T @ right_vector`` stacked over models.

        ``X`` and ``right_vector`` must have the same length (asserted).
        """
        with torch.no_grad():
            assert len(X) == len(right_vector)

            X            = X.to(self.device)
            right_vector = right_vector.to(self.device)

            w_updates = [
                torch.matmul(self.to_model_features(X, model).T, right_vector)
                for model in self.models
            ]
            return torch.stack(w_updates)

    def to_model_features(self, X, model):
        """Readout features of ``model`` on ``X`` scaled by ``1/sqrt(w_size)``.

        ``X`` is expected to already be on ``self.device``.
        """
        with torch.no_grad():
            model = model.to(self.device)
            return model.readout(X) * (1. / np.sqrt(self.w_size))

    def calc_kernel(self, X):
        """Return the summed feature Gram matrix of ``X`` (accumulated in float64, returned as float32)."""
        with torch.no_grad():
            X = X.to(self.device)

            res = torch.zeros([len(X), len(X)], device=self.device, dtype=torch.float64)
            for model in self.models:
                features_x = self.to_model_features(X, model)

                res += torch.matmul(features_x, features_x.T).double()
            return res.float()

    def calc_kernel_pred(self, X):
        """Return ``(kernel, y_pred)`` for ``X`` in a single pass over the models.

        ``kernel`` is returned in float64 and ``y_pred`` (using ``self.w``) in float32.
        """
        with torch.no_grad():
            X = X.to(self.device)

            n = len(X)
            y_pred = torch.zeros([n, self.n_classes], device=self.device, dtype=torch.float64)
            kernel = torch.zeros([n, n], device=self.device, dtype=torch.float64)

            for model, w in zip(self.models, self.w):
                features = self.to_model_features(X, model)

                kernel += torch.matmul(features, features.T).double()
                y_pred += torch.matmul(features, w).double()
            return kernel, y_pred.float()

    def calc_kernels(self, X_train, X_test):
        """Return ``(K(train, train), K(test, train))`` summed over models.

        Accumulates in float64 like ``calc_kernel`` and returns float32 tensors.
        """
        with torch.no_grad():
            X_train = X_train.to(self.device)
            X_test  = X_test.to(self.device)

            res_train = torch.zeros([len(X_train), len(X_train)], device=self.device, dtype=torch.float64)
            res_test  = torch.zeros([len(X_test),  len(X_train)], device=self.device, dtype=torch.float64)
            for model in self.models:
                features_train = self.to_model_features(X_train, model)
                features_test  = self.to_model_features(X_test,  model)

                res_train += torch.matmul(features_train, features_train.T).double()
                res_test  += torch.matmul(features_test,  features_train.T).double()
            return res_train.float(), res_test.float()

    def predict(self, X, cur_w=None):
        """Return ``sum_m features_m(X) @ cur_w[m]`` as float32.

        ``cur_w`` defaults to ``self.w``; it is iterated in step with ``self.models``.
        """
        X = X.to(self.device)
        if cur_w is None:
            cur_w = self.w

        with torch.no_grad():
            res = torch.zeros([len(X), self.n_classes], device=self.device, dtype=torch.float64)

            for model, w in zip(self.models, cur_w):
                features = self.to_model_features(X, model)
                res += torch.matmul(features, w).double()
            return res.float()

In [5]:
#!L
def calc_right_vector(kernel, y, learning_rate=1e5, reg_param=0):
    """Return ``(-lr * E) @ (-y)`` with ``E = compute_exp_term(-lr * (kernel + reg_param * I), device)``.

    Both inputs are first moved to the notebook-level ``device``. ``reg_param``
    adds a ridge term to the kernel diagonal before the exponential term is built.
    """
    with torch.no_grad():
        kernel = kernel.to(device)
        y = y.to(device)

        ridge = torch.eye(len(kernel)).to(device) * reg_param
        scaled_exp = -learning_rate * compute_exp_term(-learning_rate * (kernel + ridge), device)
        return torch.matmul(scaled_exp, -y)

In [37]:
#!L
%%time

X_train, labels_train, X_test, labels_test = get_cifar_zca()

N_train = 50000
N_test  = 1000

X_train      = torch.tensor(X_train[:N_train]).float()
labels_train = torch.tensor(labels_train[:N_train], dtype=torch.long)

X_test       = torch.tensor(X_test[:N_test]).float()
labels_test  = torch.tensor(labels_test[:N_test],  dtype=torch.long)

num_classes = 10

y_train = to_one_hot(labels_train, num_classes).to(device)
y_test  = to_one_hot(labels_test,  num_classes).to(device)

Files already downloaded and verified
Files already downloaded and verified
CPU times: user 3min 12s, sys: 1min 26s, total: 4min 39s
Wall time: 51.2 s


In [38]:
#!L
# X_train, y_train, X_test and labels_test are already tensors. Copy them with
# detach().clone() instead of re-wrapping with torch.tensor(): the result is the
# same copy, but torch.tensor() on a tensor emits a copy-construct UserWarning.
cifar_train = CustomTensorDataset(X_train.detach().clone(), y_train.detach().clone().float(), transform='flips')
cifar_test  = CustomTensorDataset(X_test.detach().clone(),  labels_test.detach().clone().long())

  import sys
  import sys
  
  


In [43]:
#!L
def boosting(estimator, train_loader, test_loader, learning_rate=1e5, beta=1., n_iter=10,
             output_kernel=False):
    """Fit ``estimator.w`` with repeated, batch-averaged kernel updates.

    Each iteration works as follows:
    - Visit the first ``(2 * len(train_loader)) // 3 + 1`` training batches.
    - For each batch, compute the kernel and current predictions and turn the
      residual into a right vector with ``calc_right_vector``.
    - Accumulate the per-model update from ``estimator.get_w_update`` in float64.
    - Subtract the averaged update, scaled by ``beta``, from ``estimator.w``.
    - Print the accuracy on ``test_loader``.

    Args:
        estimator: object exposing ``calc_kernel_pred``, ``get_w_update``,
            ``predict``, ``w`` and ``device`` (e.g. ``GpEstimator``).
        train_loader: map-style DataLoader yielding ``(X, y)``. ``y`` is compared
            via ``argmax(dim=1)``, so it is presumably one-hot; confirm against
            the dataset cells.
        test_loader: DataLoader yielding ``(X, labels)`` with integer labels.
        learning_rate: forwarded to ``calc_right_vector``.
        beta: step size applied to the averaged update.
        n_iter: number of boosting iterations.
        output_kernel: if True, print the top-left 5x5 corner of each batch kernel.

    Raises:
        ValueError: if ``train_loader`` yields no batches.
    """
    target_device = estimator.device

    with torch.no_grad():
        # len() of a map-style DataLoader is its batch count. This avoids a full
        # (augmented) pass over the data just to count batches.
        total_batches = len(train_loader)
        if total_batches == 0:
            raise ValueError("boosting needs a non-empty train_loader")
        n_batches = (total_batches * 2) // 3 + 1

        for iter_num in range(n_iter):
            print(f"iter {iter_num} ==========================")

            iter_start = time.time()
            w_update = None
            n_used = 0

            for batch_i, (X, y) in enumerate(train_loader):
                if batch_i >= n_batches:
                    break

                X = X.to(target_device)
                y = y.to(target_device)

                kernel, y_pred = estimator.calc_kernel_pred(X)
                if output_kernel:
                    print(f"kernel\n{kernel[:5,:5]}")

                y_residual = y_pred - y

                train_acc = (y_pred.argmax(dim=1) == y.argmax(dim=1)).float().mean().item()
                train_mse = (y_residual ** 2).mean().item()
                print(f"batch {batch_i}: train_acc {train_acc:.4f}, train_mse {train_mse:.6f}")

                right_vector = calc_right_vector(kernel, y_residual, learning_rate=learning_rate)

                batch_update = estimator.get_w_update(X, right_vector).double()
                w_update = batch_update if w_update is None else w_update + batch_update
                n_used += 1

            # Average in float64, then cast to float32 to match estimator.w.
            w_update = (w_update / n_used).float()
            estimator.w -= w_update * beta

            # Count correct predictions and seen examples in the same pass. This
            # needs no separate pass to size the test set.
            n_correct = 0
            n_seen = 0
            for X_batch, labels in test_loader:
                y_pred = estimator.predict(X_batch)
                n_correct += (y_pred.argmax(dim=1) == labels.to(target_device)).sum().item()
                n_seen += len(labels)
            test_acc = n_correct / n_seen if n_seen else 0.

            print(f"iter {iter_num} done. took {time.time() - iter_start:.0f}s. beta {beta:.3f}, test_acc {test_acc:.4f}")
            print()

In [40]:
#!L
# X_train, y_train, X_test and labels_test are already tensors. Copy them with
# detach().clone() instead of re-wrapping with torch.tensor(), which emits a
# copy-construct UserWarning for tensor inputs.
cifar_train = CustomTensorDataset(X_train.detach().clone(), y_train.detach().clone().float(), transform='flips')
cifar_test  = CustomTensorDataset(X_test.detach().clone(),  labels_test.detach().clone().long())

TRAIN_BATCH_SIZE = 1280 * 2
TEST_BATCH_SIZE  = 1000

train_loader = torch.utils.data.DataLoader(cifar_train, batch_size=TRAIN_BATCH_SIZE)
test_loader  = torch.utils.data.DataLoader(cifar_test, batch_size=TEST_BATCH_SIZE)

  import sys
  import sys
  
  


In [27]:
#!L
%state_exclude models
n_models = 500

# 500 * 50 * 32  = 800k

models = [Myrtle7(num_filters=1, groups=50) for _ in range(n_models)]
n_models

estimator = GpEstimator(models, num_classes, 0.2, X_train[0], device, groups=50)

In [28]:
#!L
boosting(
    estimator,
    train_loader=train_loader,
    test_loader=test_loader,
    n_iter=30,
    learning_rate=1e6,
)

batch 0: train_acc 0.1008, train_mse 1.000000
batch 1: train_acc 0.0945, train_mse 1.000000
batch 2: train_acc 0.1059, train_mse 1.000000
batch 3: train_acc 0.1016, train_mse 1.000000
iter 0 done. took 361s. beta -1.046, test_acc 0.7520

batch 0: train_acc 0.9000, train_mse 0.142131
batch 1: train_acc 0.8918, train_mse 0.143286
batch 2: train_acc 0.8859, train_mse 0.145053
batch 3: train_acc 0.8977, train_mse 0.142880
iter 1 done. took 361s. beta -2.028, test_acc 0.7710

batch 0: train_acc 0.9555, train_mse 0.091109
batch 1: train_acc 0.9535, train_mse 0.092889
batch 2: train_acc 0.9543, train_mse 0.091222
batch 3: train_acc 0.9547, train_mse 0.093278
iter 2 done. took 361s. beta -1.610, test_acc 0.7900

batch 0: train_acc 0.9832, train_mse 0.064314
batch 1: train_acc 0.9805, train_mse 0.066563
batch 2: train_acc 0.9820, train_mse 0.063889
batch 3: train_acc 0.9789, train_mse 0.065481
iter 3 done. took 361s. beta -1.669, test_acc 0.7990

batch 0: train_acc 0.9883, train_mse 0.049136
ba

In [34]:
#!L
# NOTE(review): leftover namespace inspection; it produces no result used later and can be dropped from the final notebook.
%whos

Variable              Type                   Data/Info
------------------------------------------------------
ClassifierTraining    LazyVariable           Lazy variable
CustomTensorDataset   type                   <class 'linearized_nns.fr<...>set.CustomTensorDataset'>
Estimator             LazyVariable           Lazy variable
F                     LazyVariable           Lazy variable
FashionMNIST          LazyVariable           Lazy variable
GpEstimator           LazyVariable           Lazy variable
Image                 LazyVariable           Lazy variable
Myrtle10              LazyVariable           Lazy variable
Myrtle5               LazyVariable           Lazy variable
Myrtle7               LazyVariable           Lazy variable
N_test                LazyVariable           Lazy variable of type int
N_train               LazyVariable           Lazy variable of type int
SgdEstimator          LazyVariable           Lazy variable
X_test                Tensor                 tensor([[[[ 

In [35]:
#!L
%state_exclude models
%state_exclude estimator
%state_exclude myrtle10_estimator

torch.manual_seed(0)
np.random.seed(0)

n_models = 500

# 500 * 50 * 32  = 800k

models = [Myrtle10(num_filters=1, groups=50) for _ in range(n_models)]
n_models

myrtle10_estimator = GpEstimator(models, num_classes, 0.2, X_train[0], device, groups=50)

In [36]:
#!L

boosting(
    myrtle10_estimator,
    train_loader=train_loader,
    test_loader=test_loader,
    n_iter=100,
    learning_rate=1e5,
)

batch 0: train_acc 0.1008, train_mse 1.000000
batch 1: train_acc 0.0945, train_mse 1.000000
batch 2: train_acc 0.1059, train_mse 1.000000
batch 3: train_acc 0.1016, train_mse 1.000000
iter 0 done. took 1231s. beta -1.046, test_acc 0.7600

batch 0: train_acc 0.8922, train_mse 0.140524
batch 1: train_acc 0.8930, train_mse 0.142831
batch 2: train_acc 0.8898, train_mse 0.143399
batch 3: train_acc 0.8824, train_mse 0.141551
iter 1 done. took 1230s. beta -2.035, test_acc 0.7860

batch 0: train_acc 0.9598, train_mse 0.088664
batch 1: train_acc 0.9574, train_mse 0.090821
batch 2: train_acc 0.9570, train_mse 0.090587
batch 3: train_acc 0.9520, train_mse 0.092746
iter 2 done. took 1230s. beta -1.618, test_acc 0.7990

batch 0: train_acc 0.9777, train_mse 0.065684
batch 1: train_acc 0.9805, train_mse 0.065301
batch 2: train_acc 0.9816, train_mse 0.064903
batch 3: train_acc 0.9770, train_mse 0.064836
iter 3 done. took 1231s. beta -1.674, test_acc 0.8070

batch 0: train_acc 0.9906, train_mse 0.04993



KeyboardInterrupt: 

In [46]:
#!L
%state_exclude models
%state_exclude estimator
%state_exclude myrtle10_estimator

torch.manual_seed(0)
np.random.seed(0)

n_models = 10

# 500 * 50 * 32  = 800k

models = [Myrtle10(num_filters=1, groups=50) for _ in range(n_models)]
n_models

myrtle10_estimator = GpEstimator(models, num_classes, 0.2, X_train[0], device, groups=50)

In [None]:
#!L
boosting(
    myrtle10_estimator,
    train_loader=train_loader,
    test_loader=test_loader,
    beta=1.,
    n_iter=100,
    learning_rate=1e5,
)

batch 0: train_acc 0.1008, train_mse 1.000000
batch 1: train_acc 0.0945, train_mse 1.000000
batch 2: train_acc 0.1059, train_mse 1.000000
batch 3: train_acc 0.1016, train_mse 1.000000
batch 4: train_acc 0.0977, train_mse 1.000000
batch 5: train_acc 0.1086, train_mse 1.000000
batch 6: train_acc 0.0988, train_mse 1.000000
batch 7: train_acc 0.1023, train_mse 1.000000
batch 8: train_acc 0.1000, train_mse 1.000000
batch 9: train_acc 0.1031, train_mse 1.000000
batch 10: train_acc 0.1008, train_mse 1.000000
batch 11: train_acc 0.1012, train_mse 1.000000
batch 12: train_acc 0.0973, train_mse 1.000000
batch 13: train_acc 0.0957, train_mse 1.000000
iter 0 done. took 3602s. beta 1.000, test_acc 0.7740

batch 0: train_acc 0.8445, train_mse 0.167174
batch 1: train_acc 0.8371, train_mse 0.168976
batch 2: train_acc 0.8289, train_mse 0.168668
batch 3: train_acc 0.8316, train_mse 0.168418
batch 4: train_acc 0.8301, train_mse 0.169411
batch 5: train_acc 0.8281, train_mse 0.170728
batch 6: train_acc 0.8

In [None]:
#!L
