## Download repo

In [1]:
import os
import shutil
import zipfile
import urllib.request

def download_repo(url, save_to):
    """Download a zip archive of a repository and unpack it so that `save_to` exists.

    The archive is saved as ``save_to + '.zip'`` (and left on disk) and extracted into
    the parent directory of `save_to` (the current directory when `save_to` has no
    directory part). The archive is expected to contain a single top-level folder whose
    name equals the last path component of `save_to`, as GitHub's ``<repo>-master``
    archives do. Any existing `save_to` directory is removed first so the result is a
    clean copy.

    Args:
        url: URL of the zip archive.
        save_to: path of the directory the archive's top-level folder unpacks to.

    Raises:
        FileNotFoundError: if `save_to` still does not exist after extraction, i.e. the
            archive's top-level folder has a different name.
    """
    zip_filename = save_to + '.zip'
    urllib.request.urlretrieve(url, zip_filename)

    if os.path.exists(save_to):
        shutil.rmtree(save_to)

    # Extract next to `save_to` rather than into the CWD so nested target paths work.
    extract_dir = os.path.dirname(save_to) or '.'
    with zipfile.ZipFile(zip_filename, 'r') as zip_ref:
        zip_ref.extractall(extract_dir)

    # Explicit check instead of `assert`, which is stripped when Python runs with -O.
    if not os.path.exists(save_to):
        raise FileNotFoundError(
            f"{save_to!r} not found after extracting {zip_filename!r} into {extract_dir!r}"
        )

In [2]:
# Directory name of the repository's GitHub "master" zip archive; its `src/` folder is
# appended to sys.path in the Imports section.
REPO_PATH = 'LinearizedNNs-master'

# Uncomment to (re)download the repository; `download_repo` deletes an existing REPO_PATH first.
# download_repo(url='https://github.com/maxkvant/LinearizedNNs/archive/master.zip', save_to=REPO_PATH)

## Imports

In [3]:
import sys
# Make the repository's `src/` folder importable; presumably this is where the project
# modules imported below (estimator, pytorch_impl, from_neural_kernels) live — confirm.
sys.path.append(f"{REPO_PATH}/src")

In [4]:
import time
import numpy as np
import seaborn as sns
from PIL import Image
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, datasets
from torchvision.datasets import FashionMNIST

from estimator import Estimator
from pytorch_impl.estimators import SgdEstimator
from pytorch_impl.nns import Myrtle5, Myrtle7, Myrtle10
from pytorch_impl import ClassifierTraining
from pytorch_impl.matrix_exp import matrix_exp, compute_exp_term
from pytorch_impl.nns.utils import to_one_hot, print_sizes
from from_neural_kernels import to_zca, CustomTensorDataset, get_cifar_zca

In [5]:
# Use the first GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda', index=0)

## CIFAR-10 subset (12,800 train / 1,000 test images)

In [28]:
class GpEstimator(Estimator):
    """Random-features estimator built from an ensemble of networks.

    Every model contributes its ``readout`` features scaled by
    ``1 / sqrt(n_models * readout_size)``. Predictions are ``sum_i features_i(X) @ w[i]``.
    Kernels are the sum of the per-model feature Gram matrices. ``w`` therefore has shape
    ``(n_models, readout_size, n_classes)``.

    All methods use ``self.models`` and ``self.device`` only. They never read the notebook
    globals ``models`` / ``device``, so reassigning those cannot corrupt an existing
    estimator.

    Args:
        models: networks exposing ``readout(X)``; each is moved to ``device``.
        n_classes: number of output columns of ``w`` and of predictions.
        learning_rate: stored as ``self.lr``; not used by this class's methods.
        x_example: a single input without batch dimension, used to infer
            ``readout_size`` from the first model.
        device: torch device for the models, ``w`` and all computations.
    """

    def __init__(self, models, n_classes, learning_rate, x_example, device):
        super(GpEstimator, self).__init__()
        self.models    = [model.to(device) for model in models]
        self.lr        = learning_rate
        self.n_classes = n_classes
        self.device    = device

        n_models = len(self.models)
        X = torch.stack([x_example]).to(device)  # batch containing the single example

        # Only the size is needed, so don't build an autograd graph.
        with torch.no_grad():
            readout_size = self.models[0].readout(X).size()[1]

        # TODO: Assert that models have the same readout size

        self.w      = torch.zeros([n_models, readout_size, n_classes], device=device)
        self.w_size = n_models * readout_size

    def get_w_update(self, X, right_vector):
        """Map dual coefficients to a weight update shaped like ``self.w``.

        Returns ``stack_i(features_i(X).T @ right_vector)`` over ``self.models``.

        Args:
            X: batch of inputs.
            right_vector: ``(len(X), n_classes)`` coefficients, e.g. from ``calc_right_vector``.
        """
        with torch.no_grad():
            assert len(X) == len(right_vector), "X and right_vector must have the same length"

            X            = X.to(self.device)
            right_vector = right_vector.to(self.device)

            return torch.stack([
                torch.matmul(self.to_model_features(X, model).T, right_vector)
                for model in self.models
            ])

    def to_model_features(self, X, model):
        """Return ``model.readout(X)`` scaled by ``1 / sqrt(self.w_size)``.

        With this scaling, the summed per-model Gram matrices average over all
        ``n_models * readout_size`` features instead of growing with their count.
        """
        with torch.no_grad():
            model = model.to(self.device)
            return model.readout(X) * (1. / np.sqrt(self.w_size))

    def calc_kernel(self, X):
        """Return the ``(len(X), len(X))`` ensemble kernel on ``X``."""
        with torch.no_grad():
            X = X.to(self.device)

            res = torch.zeros([len(X), len(X)], device=self.device)
            for model in self.models:
                features_x = self.to_model_features(X, model)
                res += torch.matmul(features_x, features_x.T)
            return res

    def calc_kernel_pred(self, X):
        """Return ``(kernel, y_pred)`` for ``X`` in a single pass over the models.

        ``kernel`` is the ``(len(X), len(X))`` ensemble kernel. ``y_pred`` is the current
        ``(len(X), n_classes)`` prediction, equal to ``predict(X)``.
        """
        with torch.no_grad():
            X = X.to(self.device)

            n = len(X)
            y_pred = torch.zeros([n, self.n_classes], device=self.device)
            kernel = torch.zeros([n, n], device=self.device)

            for model, w in zip(self.models, self.w):
                features = self.to_model_features(X, model)

                kernel += torch.matmul(features, features.T)
                y_pred += torch.matmul(features, w)
            return kernel, y_pred

    def calc_kernels(self, X_train, X_test):
        """Return ``(train_kernel, test_kernel)``.

        ``train_kernel`` is train-vs-train with shape ``(len(X_train), len(X_train))``.
        ``test_kernel`` is test-vs-train with shape ``(len(X_test), len(X_train))``.
        """
        with torch.no_grad():
            X_train = X_train.to(self.device)
            X_test  = X_test.to(self.device)

            res_train = torch.zeros([len(X_train), len(X_train)], device=self.device)
            res_test  = torch.zeros([len(X_test),  len(X_train)], device=self.device)
            for model in self.models:
                features_train = self.to_model_features(X_train, model)
                features_test  = self.to_model_features(X_test,  model)

                res_train += torch.matmul(features_train, features_train.T)
                res_test  += torch.matmul(features_test,  features_train.T)
            return res_train, res_test

    def predict(self, X):
        """Return the ``(len(X), n_classes)`` prediction ``sum_i features_i(X) @ w[i]``."""
        X = X.to(self.device)
        with torch.no_grad():
            res = torch.zeros([len(X), self.n_classes], device=self.device)

            for model, w in zip(self.models, self.w):
                features = self.to_model_features(X, model)
                res += torch.matmul(features, w)
            return res

In [29]:
def calc_right_vector(kernel, y, learning_rate=1e5, reg_param=0, device=None):
    """Turn targets/residuals into dual coefficients for a kernel gradient-flow step.

    Computes
    ``(learning_rate * compute_exp_term(-learning_rate * (kernel + reg_param * I))) @ y``.
    This is bit-for-bit what the previous ``(-learning_rate * E) @ (-y)`` form gave, since
    the two sign flips cancel exactly. ``compute_exp_term`` comes from
    ``pytorch_impl.matrix_exp``. Presumably it evaluates the matrix-exponential term of the
    linearised training dynamics — TODO confirm.

    Args:
        kernel: square ``(n, n)`` kernel matrix over a batch.
        y: ``(n, k)`` targets or residuals.
        learning_rate: multiplier applied to the kernel (training "time").
        reg_param: ridge term added to the kernel diagonal. When 0, no identity matrix is
            allocated at all.
        device: where to compute. It defaults to ``kernel.device`` rather than a notebook
            global; all kernels in this notebook are already on the GPU device.

    Returns:
        ``(n, k)`` tensor of coefficients on ``device``.
    """
    with torch.no_grad():
        if device is None:
            device = kernel.device
        y      = y.to(device)
        kernel = kernel.to(device)

        regularized = kernel
        if reg_param != 0:
            n = len(kernel)
            regularized = kernel + reg_param * torch.eye(n, device=device, dtype=kernel.dtype)

        exp_term = compute_exp_term(-learning_rate * regularized, device)
        return torch.matmul(learning_rate * exp_term, y)

In [8]:
%%time

# Load CIFAR-10 via `get_cifar_zca` (presumably ZCA-whitened, per its name — confirm) and keep a subset.
X_train, labels_train, X_test, labels_test = get_cifar_zca()

N_train = 12800  # number of training images kept
N_test  = 1000   # number of test images kept

X_train      = torch.tensor(X_train[:N_train]).float()
labels_train = torch.tensor(labels_train[:N_train], dtype=torch.long)

X_test       = torch.tensor(X_test[:N_test]).float()
labels_test  = torch.tensor(labels_test[:N_test],  dtype=torch.long)

num_classes = 10

# One-hot targets on `device`.
# NOTE(review): y_train is overwritten by the next cell with targets loaded from an .npz file.
y_train = to_one_hot(labels_train, num_classes).to(device)
y_test  = to_one_hot(labels_test,  num_classes).to(device)

Files already downloaded and verified
Files already downloaded and verified
CPU times: user 4min 2s, sys: 1min 11s, total: 5min 13s
Wall time: 55.4 s


In [9]:
# Replace the one-hot y_train with precomputed training targets. Use the NpzFile as a
# context manager so its file handle is closed; the indexed array is already in memory.
with np.load('../data/cifar10_targets_2.npz') as targets_file:
    y_train = targets_file['targets'][:N_train]

In [10]:
# X_train is already a tensor: clone it instead of re-wrapping it with torch.tensor().
# torch.tensor() would copy anyway, but it emits a UserWarning. y_train is the NumPy
# array loaded from the .npz file.
cifar_train = CustomTensorDataset(X_train.detach().clone(), torch.tensor(y_train).float(), transform='all')
train_loader = torch.utils.data.DataLoader(cifar_train, batch_size=1280 * 4)

  """Entry point for launching an IPython kernel.


## Kernel method

In [36]:
def _count_batches(loader):
    """Number of batches `loader` yields; uses len() so the data is not loaded just to count."""
    try:
        return len(loader)
    except TypeError:  # e.g. a DataLoader over an IterableDataset without __len__
        return sum(1 for _ in loader)


def boosting(estimator, train_loader, X_test, labels_test, learning_rate=1e5, n_iter=10,
             drop_last_batch=True):
    """Fit ``estimator.w`` with averaged kernel gradient-flow steps over mini-batches.

    Each iteration processes every used batch:
      * compute the batch kernel and the current predictions;
      * turn the residual into dual coefficients with ``calc_right_vector``;
      * map those to a weight update;
      * compute the least-squares step ``beta = <-residual, K @ rv> / ||K @ rv||^2``.

    The summed updates are applied once per iteration, scaled by the mean beta and divided
    by the number of batches. Train metrics per batch and test accuracy per iteration are
    printed.

    Args:
        estimator: ``GpEstimator``-like object (needs ``calc_kernel_pred``, ``get_w_update``,
            ``predict``, ``w`` and ``device``).
        train_loader: yields ``(X, y)`` with 2-D targets ``y`` whose argmax is the class.
        X_test: evaluation inputs.
        labels_test: integer class labels for ``X_test``.
        learning_rate: passed to ``calc_right_vector``.
        n_iter: number of outer iterations.
        drop_last_batch: skip the loader's final batch. This is the default, matching the
            original behaviour; presumably because that batch is usually smaller.

    Raises:
        ValueError: if no batches remain to train on.
    """
    with torch.no_grad():
        target_device = estimator.device

        n_batches = _count_batches(train_loader) - (1 if drop_last_batch else 0)
        if n_batches < 1:
            raise ValueError(
                f"boosting needs at least one training batch, got n_batches={n_batches}")

        labels_test = labels_test.to(target_device)  # loop-invariant

        for iter_num in range(n_iter):
            print(f"iter {iter_num} ==========================")

            betas = []
            w_update = 0
            iter_start = time.time()

            for batch_i, (X, y) in enumerate(train_loader):
                if batch_i >= n_batches:
                    break

                X = X.to(target_device)
                y = y.to(target_device)

                kernel, y_pred = estimator.calc_kernel_pred(X)
                y_residual = y_pred - y

                train_acc = (y_pred.argmax(dim=1) == y.argmax(dim=1)).float().mean().item()
                train_mse = (y_residual ** 2).mean().item()
                print(f"batch {batch_i}: train_acc {train_acc:.4f}, train_mse {train_mse:.6f}")

                right_vector = calc_right_vector(kernel, y_residual, learning_rate=learning_rate)
                w_update += estimator.get_w_update(X, right_vector)

                # Scalar step minimising ||y_residual + beta * pred_change||^2 for this batch.
                pred_change = torch.matmul(kernel, right_vector)
                cur_beta    = (- y_residual * pred_change).sum() / (pred_change ** 2).sum()
                betas.append(cur_beta.item())

            beta = np.average(betas)
            estimator.w += w_update * beta / n_batches

            y_pred = estimator.predict(X_test)
            test_acc = (y_pred.argmax(dim=1) == labels_test).float().mean().item()

            print(f"iter {iter_num} done. took {time.time() - iter_start:.0f}s test_acc {test_acc:.4f}")
            print()

In [31]:
n_models = 500

models = [Myrtle5(num_filters=32) for _ in range(n_models)]

estimator = GpEstimator(models, num_classes, 0.2, X_train[0], device)


In [32]:
boosting(estimator, train_loader, X_test, labels_test, learning_rate=1e5)

batch 0: train_acc 0.0977, train_mse 7.183765
batch 1: train_acc 0.1037, train_mse 7.173288
batch 2: train_acc 0.0973, train_mse 7.196385
iter 0 done. took 99s test_acc 0.6870

batch 0: train_acc 0.7100, train_mse 3.731628
batch 1: train_acc 0.6975, train_mse 3.724422
batch 2: train_acc 0.7102, train_mse 3.665908
iter 1 done. took 102s test_acc 0.7150

batch 0: train_acc 0.7180, train_mse 3.717055
batch 1: train_acc 0.7289, train_mse 3.651026
batch 2: train_acc 0.7379, train_mse 3.604275
iter 2 done. took 98s test_acc 0.7050

batch 0: train_acc 0.7320, train_mse 3.682441
batch 1: train_acc 0.7352, train_mse 3.648643
batch 2: train_acc 0.7445, train_mse 3.613023
iter 3 done. took 103s test_acc 0.7030

batch 0: train_acc 0.7506, train_mse 3.654448
batch 1: train_acc 0.7479, train_mse 3.588690
batch 2: train_acc 0.7602, train_mse 3.535364
iter 4 done. took 97s test_acc 0.7140

batch 0: train_acc 0.7625, train_mse 3.591476
batch 1: train_acc 0.7438, train_mse 3.613229
batch 2: train_acc 0.

In [33]:
estimator = GpEstimator(models, num_classes, 0.2, X_train[0], device)
boosting(estimator, train_loader, X_test, labels_test, learning_rate=1e4)

batch 0: train_acc 0.0977, train_mse 7.183765
batch 1: train_acc 0.1037, train_mse 7.173288
batch 2: train_acc 0.0973, train_mse 7.196385
iter 0 done. took 100s test_acc 0.6620

batch 0: train_acc 0.6648, train_mse 3.893153
batch 1: train_acc 0.6594, train_mse 3.899528
batch 2: train_acc 0.6680, train_mse 3.843274
iter 1 done. took 102s test_acc 0.6920

batch 0: train_acc 0.7027, train_mse 3.694586
batch 1: train_acc 0.7041, train_mse 3.693863
batch 2: train_acc 0.7086, train_mse 3.654493
iter 2 done. took 99s test_acc 0.6800

batch 0: train_acc 0.7195, train_mse 3.575631
batch 1: train_acc 0.7086, train_mse 3.580958
batch 2: train_acc 0.7141, train_mse 3.561848
iter 3 done. took 102s test_acc 0.7100

batch 0: train_acc 0.7357, train_mse 3.459083
batch 1: train_acc 0.7262, train_mse 3.465120
batch 2: train_acc 0.7371, train_mse 3.448418
iter 4 done. took 96s test_acc 0.6890

batch 0: train_acc 0.7461, train_mse 3.431979
batch 1: train_acc 0.7318, train_mse 3.432266
batch 2: train_acc 0

In [34]:
estimator = GpEstimator(models, num_classes, 0.2, X_train[0], device)
boosting(estimator, train_loader, X_test, labels_test, learning_rate=1e6)

batch 0: train_acc 0.0977, train_mse 7.183765
batch 1: train_acc 0.1037, train_mse 7.173288
batch 2: train_acc 0.0973, train_mse 7.196385
iter 0 done. took 99s test_acc 0.6670

batch 0: train_acc 0.6859, train_mse 4.011244
batch 1: train_acc 0.6986, train_mse 3.945497
batch 2: train_acc 0.7074, train_mse 3.875035
iter 1 done. took 103s test_acc 0.6880

batch 0: train_acc 0.7076, train_mse 3.973394
batch 1: train_acc 0.7105, train_mse 3.957006
batch 2: train_acc 0.7211, train_mse 3.942075
iter 2 done. took 98s test_acc 0.6740

batch 0: train_acc 0.7281, train_mse 3.969994
batch 1: train_acc 0.7188, train_mse 3.965328
batch 2: train_acc 0.7355, train_mse 3.896935
iter 3 done. took 103s test_acc 0.6790

batch 0: train_acc 0.7303, train_mse 3.971407
batch 1: train_acc 0.7201, train_mse 3.975674
batch 2: train_acc 0.7227, train_mse 3.989176
iter 4 done. took 100s test_acc 0.6820

batch 0: train_acc 0.7400, train_mse 3.962332
batch 1: train_acc 0.7293, train_mse 3.992778
batch 2: train_acc 0

In [None]:
n_models = 5000

models = [Myrtle5(num_filters=32) for _ in range(n_models)]

estimator = GpEstimator(models, num_classes, 0.2, X_train[0], device)

In [30]:
boosting(estimator, train_loader, X_test, labels_test)

batch 0: train_acc 0.0992, train_mse 1.000000
batch 1: train_acc 0.1023, train_mse 1.000000
batch 2: train_acc 0.0867, train_mse 1.000000
batch 3: train_acc 0.1023, train_mse 1.000000
batch 4: train_acc 0.1063, train_mse 1.000000
batch 5: train_acc 0.1055, train_mse 1.000000
batch 6: train_acc 0.1086, train_mse 1.000000
iter 0 done. test_acc 0.6610
batch 0: train_acc 0.6898, train_mse 0.228092
batch 1: train_acc 0.6898, train_mse 0.229243
batch 2: train_acc 0.6852, train_mse 0.231006
batch 3: train_acc 0.6664, train_mse 0.232013
batch 4: train_acc 0.6750, train_mse 0.231813
batch 5: train_acc 0.6703, train_mse 0.232756
batch 6: train_acc 0.6609, train_mse 0.234007
iter 1 done. test_acc 0.6930
batch 0: train_acc 0.7344, train_mse 0.212709
batch 1: train_acc 0.7438, train_mse 0.211246
batch 2: train_acc 0.7203, train_mse 0.213644
batch 3: train_acc 0.7297, train_mse 0.215032
batch 4: train_acc 0.7312, train_mse 0.209461
batch 5: train_acc 0.7203, train_mse 0.212636
batch 6: train_acc 0.7

In [11]:
# X_train and labels_train are already tensors: clone them rather than re-wrapping them
# with torch.tensor(), which copies anyway but emits a UserWarning.
# NOTE(review): train_loader is not rebuilt from this dataset, so later cells keep
# iterating the loader created earlier — rebuild it here if that is not intended.
cifar_train = CustomTensorDataset(X_train.detach().clone(), labels_train.detach().clone().to(torch.long), transform=None)

  """Entry point for launching an IPython kernel.


In [12]:
n_models = 5000

models = [Myrtle5(num_filters=32) for _ in range(n_models)]

estimator = GpEstimator(models, num_classes, 0.2, X_train[0], device)

In [13]:
boosting(estimator, train_loader, X_test, labels_test)

batch 0: train_acc 0.0992, train_mse 1.000000
batch 1: train_acc 0.1023, train_mse 1.000000
batch 2: train_acc 0.0867, train_mse 1.000000
batch 3: train_acc 0.1023, train_mse 1.000000
batch 4: train_acc 0.1063, train_mse 1.000000
batch 5: train_acc 0.1055, train_mse 1.000000
batch 6: train_acc 0.1086, train_mse 1.000000
iter 0 done. took 630s test_acc 0.6570

batch 0: train_acc 0.6945, train_mse 0.230021
batch 1: train_acc 0.6820, train_mse 0.229634
batch 2: train_acc 0.6617, train_mse 0.232192
batch 3: train_acc 0.6844, train_mse 0.231658
batch 4: train_acc 0.6617, train_mse 0.233073
batch 5: train_acc 0.6609, train_mse 0.234637
batch 6: train_acc 0.6672, train_mse 0.234089
iter 1 done. took 627s test_acc 0.6890

batch 0: train_acc 0.7398, train_mse 0.210167
batch 1: train_acc 0.7594, train_mse 0.208906
batch 2: train_acc 0.7469, train_mse 0.209019
batch 3: train_acc 0.7258, train_mse 0.212323
batch 4: train_acc 0.7227, train_mse 0.214496
batch 5: train_acc 0.7148, train_mse 0.212550


In [14]:
boosting(estimator, train_loader, X_test, labels_test, n_iter=5)

batch 0: train_acc 0.8664, train_mse 0.160979
batch 1: train_acc 0.8602, train_mse 0.162237
batch 2: train_acc 0.8539, train_mse 0.163230
batch 3: train_acc 0.8453, train_mse 0.167403
batch 4: train_acc 0.8555, train_mse 0.163886
batch 5: train_acc 0.8516, train_mse 0.164142
batch 6: train_acc 0.8484, train_mse 0.163351
iter 0 done. took 628s test_acc 0.7280

batch 0: train_acc 0.8641, train_mse 0.158279
batch 1: train_acc 0.8680, train_mse 0.157631
batch 2: train_acc 0.8727, train_mse 0.158165
batch 3: train_acc 0.8672, train_mse 0.160398
batch 4: train_acc 0.8695, train_mse 0.161058
batch 5: train_acc 0.8625, train_mse 0.159203
batch 6: train_acc 0.8688, train_mse 0.160781
iter 1 done. took 627s test_acc 0.7360

batch 0: train_acc 0.8781, train_mse 0.157369
batch 1: train_acc 0.8758, train_mse 0.155880
batch 2: train_acc 0.8727, train_mse 0.156660
batch 3: train_acc 0.8602, train_mse 0.157308
batch 4: train_acc 0.8578, train_mse 0.160059
batch 5: train_acc 0.8625, train_mse 0.159050


## Full CIFAR-10

In [37]:
%%time

X_train, labels_train, X_test, labels_test = get_cifar_zca()

X_train      = torch.tensor(X_train).float()
labels_train = torch.tensor(labels_train, dtype=torch.long)

X_test       = torch.tensor(X_test).float()
labels_test  = torch.tensor(labels_test,  dtype=torch.long)

num_classes = 10

y_train = np.load('../data/cifar10_targets_2.npz')['targets']

Files already downloaded and verified
Files already downloaded and verified
CPU times: user 4min 24s, sys: 1min 35s, total: 5min 59s
Wall time: 1min 4s


In [38]:
# X_train is already a tensor: clone it instead of re-wrapping it with torch.tensor().
# torch.tensor() would copy anyway, but it emits a UserWarning. y_train is the NumPy
# array loaded from the .npz file.
cifar_train = CustomTensorDataset(X_train.detach().clone(), torch.tensor(y_train).float(), transform='all')
train_loader = torch.utils.data.DataLoader(cifar_train, batch_size=1280 * 4)

  """Entry point for launching an IPython kernel.


In [39]:
n_models = 5000

models = [Myrtle5(num_filters=32) for _ in range(n_models)]

estimator = GpEstimator(models, num_classes, 0.2, X_train[0], device)

In [None]:
boosting(estimator, train_loader, X_test, labels_test, n_iter=32)

batch 0: train_acc 0.0977, train_mse 7.183765
batch 1: train_acc 0.1037, train_mse 7.173288
batch 2: train_acc 0.1029, train_mse 7.183653
batch 3: train_acc 0.1004, train_mse 7.181955
batch 4: train_acc 0.1016, train_mse 7.197064
batch 5: train_acc 0.1010, train_mse 7.201131
batch 6: train_acc 0.0965, train_mse 7.186081
batch 7: train_acc 0.1010, train_mse 7.178968
batch 8: train_acc 0.0977, train_mse 7.189648
iter 0 done. took 2568s test_acc 0.7608

batch 0: train_acc 0.7449, train_mse 3.208964
batch 1: train_acc 0.7357, train_mse 3.204618
batch 2: train_acc 0.7324, train_mse 3.207517
batch 3: train_acc 0.7365, train_mse 3.232342
batch 4: train_acc 0.7402, train_mse 3.166216
batch 5: train_acc 0.7441, train_mse 3.196871
batch 6: train_acc 0.7322, train_mse 3.264790
batch 7: train_acc 0.7424, train_mse 3.195215
batch 8: train_acc 0.7387, train_mse 3.228606
iter 1 done. took 2411s test_acc 0.7885

batch 0: train_acc 0.7838, train_mse 2.925961
batch 1: train_acc 0.7742, train_mse 2.94123

## Logistic regression on ensemble features

In [32]:
n_models = 5000

models = [Myrtle5(num_filters=32) for _ in range(n_models)]

learning_rate = 0.2
estimator = GpEstimator(models, num_classes, learning_rate, X_train[0], device)

n_iter = 20
cross_entropy = nn.CrossEntropyLoss()
labels_test_device = labels_test.to(device)  # loop-invariant

for iter_ in range(n_iter):
    print(f"iter {iter_} ======================= ")

    for batch_i, (X, targets) in enumerate(train_loader):
        batch_start = time.time()

        X       = X.to(device)
        targets = targets.to(device)
        # The current train_loader yields 2-D target rows (the boosting code takes their
        # argmax as the class); reduce them to class indices for accuracy and
        # CrossEntropyLoss. 1-D integer labels pass through unchanged.
        labels = targets.argmax(dim=1) if targets.dim() > 1 else targets

        y_pred = estimator.predict(X)
        train_acc = (y_pred.argmax(dim=1) == labels).float().mean().item()
        print(f"train_acc {train_acc}")

        # predict() runs under no_grad, so y_pred is a leaf. Enable grad on it to obtain
        # dLoss/dlogits, which get_w_update maps back onto the ensemble weights.
        y_pred.requires_grad_(True)

        loss = cross_entropy(y_pred, labels)
        loss.backward()
        grad = y_pred.grad

        estimator.w += -learning_rate * estimator.get_w_update(X, grad)

        print(f"batch_{batch_i} took {time.time() - batch_start:.1f}s")
        print()

    y_pred = estimator.predict(X_test)
    test_acc = (y_pred.argmax(dim=1) == labels_test_device).float().mean().item()
    print(f"iter {iter_} done, test_acc = {test_acc}")

train_acc 0.09921874850988388
batch_0 took 57.80506658554077s

train_acc 0.09140624850988388
batch_1 took 55.33869671821594s

train_acc 0.1328125
batch_2 took 55.32820272445679s

train_acc 0.09296875447034836
batch_3 took 55.33759117126465s

train_acc 0.09375
batch_4 took 55.329872846603394s

train_acc 0.1171875
batch_5 took 55.33920168876648s

train_acc 0.09765625
batch_6 took 55.32954454421997s

train_acc 0.17734375596046448
batch_7 took 55.33659553527832s

train_acc 0.09375
batch_8 took 55.32691836357117s

train_acc 0.17499999701976776
batch_9 took 55.34224081039429s

iter 0 done, test_acc = 0.1120000034570694
train_acc 0.11015625298023224
batch_0 took 55.32840847969055s



KeyboardInterrupt: 

## Matrix exponential

In [None]:
%%time

X_train, labels_train, X_test, labels_test = get_cifar_zca()

N = 1280

X_train      = torch.tensor(X_train[:N]).float()
X_test       = torch.tensor(X_test[:N]).float()
labels_train = torch.tensor(labels_train[:N], dtype=torch.long)
labels_test  = torch.tensor(labels_test[:N],  dtype=torch.long)

num_classes = 10

y_train = to_one_hot(labels_train, num_classes).to(device)
y_test  = to_one_hot(labels_test,  num_classes).to(device)

In [9]:
n_models = 5000

models = [Myrtle5(num_filters=32) for _ in range(n_models)]

estimator = GpEstimator(models, num_classes, 1e4, X_train[0], device)

In [10]:
%%time
train_kernel, test_kernel = estimator.calc_kernels(X_train, X_test)

right_vector = calc_right_vector(train_kernel, y_train, learning_rate=1e5, reg_param=0)
y_pred = torch.matmul(test_kernel, right_vector).argmax(dim=1)

test_acc = (y_pred.cpu() == labels_test).float().mean()
print(test_acc)
print()

tensor(0.6102)

CPU times: user 1min 16s, sys: 25.5 s, total: 1min 42s
Wall time: 1min 42s


In [11]:
%%time
w_update = estimator.get_w_update(X_train, right_vector)

print(w_update.size())
print()

torch.Size([5000, 32, 10])

CPU times: user 34.7 s, sys: 14.7 s, total: 49.4 s
Wall time: 49.4 s


In [12]:
estimator.w = w_update

In [13]:
%%time

y_pred = estimator.predict(X_test).argmax(dim=1)

test_acc = (y_pred.cpu() == labels_test).float().mean()
print(test_acc)
print()

tensor(0.6102)

CPU times: user 35.2 s, sys: 14.9 s, total: 50.1 s
Wall time: 50 s
