In [9]:
#!L

import time
import numpy as np
import seaborn as sns
from PIL import Image
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, datasets
from torchvision.datasets import FashionMNIST


from linearized_nns.pytorch_impl.estimators import SgdEstimator
from linearized_nns.pytorch_impl.nns import Myrtle5, Myrtle7, Myrtle10
from linearized_nns.pytorch_impl import ClassifierTraining
from linearized_nns.pytorch_impl.matrix_exp import matrix_exp, compute_exp_term
from linearized_nns.pytorch_impl.nns.utils import to_one_hot, print_sizes
from linearized_nns.pytorch_impl.nns.primitives import Conv, Flatten, Normalize, ReLU2
from linearized_nns.from_neural_kernels import to_zca, CustomTensorDataset, get_cifar_zca

In [10]:
#!L
# Use the first CUDA GPU when one is visible to torch; otherwise run on the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda', index=0)

In [11]:
#!L
def compute_kernels(models, X_train, X_test, device):
    """Average the feature Gram matrices of an ensemble of models.

    For each model, ``model.readout`` is applied to the train and test inputs.
    The inner products of the resulting features are accumulated:
    ``F_train @ F_train.T`` for the train kernel and ``F_test @ F_train.T``
    for the test kernel. The returned kernels are the means over all models.

    Args:
        models: iterable of modules exposing ``readout(x)``. Each is moved to
            ``device`` via ``.to(device)``. For ``nn.Module`` this is in-place,
            so the caller's models remain on ``device`` afterwards.
        X_train: train inputs (tensor); moved to ``device``.
        X_test: test inputs (tensor); moved to ``device``.
        device: torch device used for the computation and for the outputs.

    Returns:
        ``(train_kernel, test_kernel)``: float32 tensors on ``device`` of shape
        ``(len(X_train), len(X_train))`` and ``(len(X_test), len(X_train))``.

    Raises:
        ValueError: if ``models`` is empty (the mean would otherwise be 0/0,
            i.e. silently NaN).
    """
    with torch.no_grad():
        X_train = X_train.to(device)
        X_test = X_test.to(device)

        n_train = len(X_train)
        n_test = len(X_test)

        # Accumulate in float64 so summing many per-model Gram matrices does not
        # lose precision; allocate directly on the target device.
        train_kernel = torch.zeros(n_train, n_train, dtype=torch.float64, device=device)
        test_kernel = torch.zeros(n_test, n_train, dtype=torch.float64, device=device)

        n_models = 0
        start_time = time.time()

        for model_i, model in enumerate(models):
            model = model.to(device)
            # Log at model_i == 0 and at every power of two; model_i models are done so far.
            if (model_i & (model_i - 1)) == 0:
                print(f"{model_i} models done. time {time.time() - start_time:.0f}s")

            train_features = model.readout(X_train)
            test_features = model.readout(X_test)

            train_kernel += torch.matmul(train_features, train_features.T).double()
            test_kernel += torch.matmul(test_features, train_features.T).double()
            n_models += 1

        if n_models == 0:
            raise ValueError("compute_kernels: `models` must contain at least one model")

        print(f"{n_models} models done. time {time.time() - start_time:.0f}s")

        train_kernel /= n_models
        test_kernel /= n_models

        return train_kernel.float(), test_kernel.float()

In [19]:
#!L
%%time

# Load the data splits (presumably ZCA-whitened CIFAR, judging by the loader's
# name); the arrays are converted to torch tensors below.
X_train, labels_train, X_test, labels_test = get_cifar_zca()

# Keep only the first N examples of each split so both kernel matrices stay N x N.
N = 1280
num_classes = 10

X_train, X_test = [torch.tensor(split[:N]).float() for split in (X_train, X_test)]
labels_train, labels_test = [
    torch.tensor(split[:N], dtype=torch.long) for split in (labels_train, labels_test)
]

# One-hot regression targets go to the compute device; integer labels stay where they are.
y_train, y_test = [
    to_one_hot(labels, num_classes).to(device) for labels in (labels_train, labels_test)
]

Files already downloaded and verified
Files already downloaded and verified
CPU times: user 3min 9s, sys: 1min 27s, total: 4min 37s
Wall time: 51.9 s


In [25]:
#!L
%state_exclude models

# Seed both RNGs so the ensemble's random initialisation is reproducible.
torch.manual_seed(0)
np.random.seed(0)

n_models = 500

# NOTE(review): 500 * 32 = 16k (not 160k as previously noted). It is unclear what
# the per-model factor 32 refers to; confirm the intended total before relying on it.

# Ensemble of separately constructed narrow Myrtle10 networks; their readout
# features are averaged into a kernel by compute_kernels.
models = [Myrtle10(num_filters=1, groups=50) for _ in range(n_models)]
n_models

500

In [27]:
#!L

# Ensemble-averaged kernels. compute_kernels already returns float32 tensors on
# `device`, so no further dtype/device conversion is needed here.
# NOTE: this is the expensive step (~15 min per 256 models in the recorded run).
train_kernel, test_kernel = compute_kernels(models, X_train, X_test, device)

assert train_kernel.shape == (len(X_train), len(X_train))
assert test_kernel.shape == (len(X_test), len(X_train))

0 models done. time 0s
1 models done. time 0s
2 models done. time 0s
4 models done. time 0s
8 models done. time 0s
16 models done. time 18s
32 models done. time 77s
64 models done. time 194s
128 models done. time 428s
256 models done. time 897s


In [28]:
#!L
# Expect (n_train, n_train) and (n_test, n_train).
train_kernel.shape, test_kernel.shape

(torch.Size([1280, 1280]), torch.Size([1280, 1280]))

In [29]:
#!L
# Learning-rate / time scale of the linearised training dynamics.
lr = 1e5

n = train_kernel.shape[0]
# Ridge term on the kernel diagonal. The 0e-7 factor makes it exactly zero,
# i.e. regularisation is currently disabled — TODO confirm this is intended.
reg = torch.eye(n, device=device) * 0e-7

# NOTE(review): assumes compute_exp_term(A) returns (exp(A) - I) A^{-1}, in which
# case y_pred = K_test K^{-1} (I - exp(-lr K)) y_train, the linearised-network
# prediction after training from zero output — confirm against matrix_exp.
exp_term = -lr * compute_exp_term(-lr * (train_kernel + reg), device)
y_pred = test_kernel @ (exp_term @ -y_train)

# Test accuracy of the arg-max prediction.
(y_pred.argmax(dim=1) == labels_test.to(device)).float().mean()

tensor(0.6672, device='cuda:0')