In [1]:
import os
import shutil
import zipfile
import urllib.request

def download_repo(url, save_to):
    """Download a zip archive and unpack it so that ``save_to`` exists.

    The archive is stored as ``save_to + '.zip'`` and extracted into the
    parent directory of ``save_to`` (the current directory when ``save_to``
    has no directory component). The archive is expected to contain a
    top-level folder named like the last component of ``save_to``.

    Args:
        url: Location of the zip archive (anything ``urlretrieve`` accepts).
        save_to: Path of the directory the archive should unpack to.

    Raises:
        FileNotFoundError: If ``save_to`` does not exist after extraction,
            i.e. the archive's top-level folder has a different name.
    """
    zip_filename = save_to + '.zip'
    # Extract next to the archive instead of always into the cwd, so a
    # ``save_to`` that contains a directory component works as well.
    parent_dir = os.path.dirname(save_to) or '.'
    os.makedirs(parent_dir, exist_ok=True)

    # Download first: a failed download must not wipe an existing copy.
    urllib.request.urlretrieve(url, zip_filename)

    if os.path.exists(save_to):
        shutil.rmtree(save_to)
    with zipfile.ZipFile(zip_filename, 'r') as zip_ref:
        zip_ref.extractall(parent_dir)

    # Explicit check instead of ``assert`` so it survives ``python -O``.
    if not os.path.exists(save_to):
        raise FileNotFoundError(
            f"Extracting {zip_filename!r} did not create {save_to!r}"
        )

In [2]:
REPO_PATH = 'LinearizedNNs-master'

# Fetch the repository only when it is not already unpacked locally, so the
# notebook can run from a clean environment without re-downloading each time.
if not os.path.exists(REPO_PATH):
    download_repo(url='https://github.com/maxkvant/LinearizedNNs/archive/master.zip', save_to=REPO_PATH)

In [3]:
import sys

# Make the repository sources importable. The membership check keeps the cell
# idempotent: re-running it does not pile up duplicate sys.path entries.
SRC_PATH = f"{REPO_PATH}/src"
if SRC_PATH not in sys.path:
    sys.path.append(SRC_PATH)

In [4]:
import time
import numpy as np
import seaborn as sns
from PIL import Image
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, datasets
from torchvision.datasets import FashionMNIST

from estimator import Estimator
from pytorch_impl.estimators import SgdEstimator
from pytorch_impl.nns import Myrtle5, Myrtle7, Myrtle10
from pytorch_impl import ClassifierTraining
from pytorch_impl.matrix_exp import matrix_exp, compute_exp_term
from pytorch_impl.nns.utils import to_one_hot, print_sizes
from from_neural_kernels import to_zca, CustomTensorDataset, get_cifar_zca

In [5]:
# Use the first CUDA device when one is available, otherwise run on the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda', index=0)

In [6]:
# Open the .npz archive and list the names of the arrays stored in it.
kernels_12k = np.load('../data/kernels_12k.npz')
kernels_12k.files

['train_kernel',
 'test_kernel',
 'test_1_kernel',
 'labels_train',
 'labels_test',
 'labels_test_1']

In [7]:
train_kernel = kernels_12k['train_kernel']

# Number of rows in the stored train kernel.
N = train_kernel.shape[0]
N

12800

In [81]:
class GpEstimator(Estimator):
    """Linear estimator on top of the readout features of an ensemble of models.

    Each model contributes ``model.readout(X)`` scaled by
    ``1 / sqrt(len(models) * readout_size)``; predictions are the sum over
    models of ``features @ w[i]``. All methods use only the state stored on
    the instance (``self.models``, ``self.device``) and their own arguments,
    never notebook-level globals.

    Attributes:
        models: Sequence of models exposing a ``readout(X)`` method.
        lr: Learning rate used to scale the kernel in ``get_w_update``.
        n_classes: Number of output classes.
        device: Device on which all computation is performed.
        w: Per-model weights, shape ``(len(models), readout_size, n_classes)``,
            initialised to zeros on ``device``.
        w_size: ``len(models) * readout_size``; used to normalise features.
    """

    def __init__(self, models, n_classes, learning_rate, x_example, device):
        """Store the configuration and allocate zero weights.

        Args:
            models: Sequence of models exposing ``readout(X)``.
            n_classes: Number of output classes.
            learning_rate: Scalar multiplied into the kernel in ``get_w_update``.
            x_example: A single input example; it is only used to discover the
                width of the readout produced by ``models[0]``.
            device: Torch device to run on.
        """
        super(GpEstimator, self).__init__()
        self.models    = models
        self.lr        = learning_rate
        self.n_classes = n_classes
        self.device    = device

        n = len(models)
        # A batch containing just the example is enough to read off the width.
        X = torch.stack([x_example]).to(device)

        with torch.no_grad():  # shape probe only, no autograd graph needed
            model = models[0].to(device)
            readout_size = model.readout(X).size()[1]

        # Allocate on ``device`` so ``predict`` works before ``w`` is replaced.
        self.w      = torch.zeros([n, readout_size, n_classes], device=device)
        self.w_size = n * readout_size

    def get_w_update(self, X, kernel, y):
        """Compute one weight update per model from a kernel and targets.

        Computes ``exp_term = -lr * compute_exp_term(-lr * kernel, device)``
        and, for every model, ``features.T @ (exp_term @ (-y))``.
        ``compute_exp_term`` is a project helper whose definition is not
        visible here; see ``pytorch_impl.matrix_exp`` for its exact meaning.

        Args:
            X: Batch of inputs.
            kernel: Kernel matrix with ``len(kernel) == len(X)``.
            y: Targets for ``X`` (presumably one-hot encoded, as at the call
                site in this notebook; verify for other callers).

        Returns:
            Tensor stacking the per-model updates along dimension 0.
        """
        with torch.no_grad():
            assert len(X) == len(kernel)

            X      = X.to(self.device)
            y      = y.to(self.device)
            kernel = kernel.to(self.device)

            exp_term = - self.lr * compute_exp_term(- self.lr * kernel, self.device)
            # Use the targets passed in by the caller, not a notebook global.
            right_vector = torch.matmul(exp_term, - y)

            w_updates = []
            for model in self.models:
                features = self.to_model_features(X, model)
                w_updates.append(torch.matmul(features.T, right_vector))
            return torch.stack(w_updates)

    def to_model_features(self, X, model):
        """Return ``model.readout(X)`` scaled by ``1 / sqrt(self.w_size)``.

        The model and ``X`` are moved to ``self.device`` before the call.
        """
        with torch.no_grad():
            model = model.to(self.device)
            X = X.to(self.device)
            return model.readout(X) * (1. / np.sqrt(self.w_size))

    def calc_kernel(self, X):
        """Return the sum over models of ``features @ features.T`` for ``X``.

        Returns:
            Tensor of shape ``(len(X), len(X))`` on ``self.device``.
        """
        with torch.no_grad():
            X = X.to(self.device)
            n = len(X)
            res = torch.zeros([n, n], device=self.device)
            for model in self.models:
                features = self.to_model_features(X, model)
                res += torch.matmul(features, features.T)
            return res

    def predict(self, X):
        """Return the summed per-model outputs ``features @ w[i]`` for ``X``.

        Returns:
            Tensor of shape ``(len(X), n_classes)`` on ``self.device``.
        """
        X = X.to(self.device)
        with torch.no_grad():
            n = len(X)
            res = torch.zeros([n, self.n_classes], device=self.device)
            # ``w`` may have been replaced from outside; make sure it is on
            # the same device as the features.
            weights = self.w.to(self.device)

            for model, w in zip(self.models, weights):
                features = self.to_model_features(X, model)
                res += torch.matmul(features, w)
            return res

In [53]:
%%time

# Load the CIFAR train/test inputs and labels through the project helper
# get_cifar_zca (presumably ZCA-preprocessed, going by its name; confirm in
# from_neural_kernels). %%time reports how long the loading takes.
X_train, labels_train, X_test, labels_test = get_cifar_zca()

Files already downloaded and verified
Files already downloaded and verified
CPU times: user 3min 51s, sys: 1min 2s, total: 4min 53s
Wall time: 43.9 s


In [54]:
# Keep only the first N examples of every split.
N = 1280
num_classes = 10

# Inputs become float32 tensors.
X_train = torch.tensor(X_train[:N], dtype=torch.float32)
X_test  = torch.tensor(X_test[:N],  dtype=torch.float32)

# Labels become int64 tensors.
labels_train = torch.tensor(labels_train[:N]).long()
labels_test  = torch.tensor(labels_test[:N]).long()

# One-hot targets, placed on the compute device.
y_train = to_one_hot(labels_train, num_classes).to(device)
y_test  = to_one_hot(labels_test,  num_classes).to(device)

In [55]:
# Top-left N x N block of the stored train kernel, converted to a tensor.
train_kernel = torch.tensor(kernels_12k['train_kernel'][:N, :N])
train_kernel.shape

torch.Size([1280, 1280])

In [89]:
n_models = 5000

# Seed torch's global RNG so the randomly initialised ensemble is the same on
# every run (assuming Myrtle10 draws its initial weights from that RNG).
SEED = 0
torch.manual_seed(SEED)

models = [Myrtle10(num_filters=32) for _ in range(n_models)]

In [90]:
# Build the estimator; the first training example is passed as x_example.
estimator = GpEstimator(
    models=models,
    n_classes=num_classes,
    learning_rate=1e4,
    x_example=X_train[0],
    device=device,
)

In [91]:
# Weight update computed from the kernel the estimator builds on X_train.
w_update = estimator.get_w_update(
    X=X_train,
    kernel=estimator.calc_kernel(X_train),
    y=y_train,
)

In [92]:
# Show the dimensions of the stacked per-model updates.
w_update.shape

torch.Size([5000, 32, 10])

In [93]:
# Replace the estimator's zero-initialised weights with the computed update,
# so that the following predict() call uses them.
estimator.w = w_update

In [94]:
# Predicted class = index of the largest output for each test example.
y_pred = torch.argmax(estimator.predict(X_test), dim=1)

In [95]:
# Fraction of test examples whose predicted class matches the label.
torch.eq(y_pred.cpu(), labels_test).float().mean()

tensor(0.5898)