In [31]:
#!g1.1
import time
import numpy as np
import seaborn as sns
from PIL import Image
import matplotlib.pyplot as plt
import random
from sklearn.utils.extmath import randomized_svd


import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, datasets
from torch import utils
from torch.utils.data import TensorDataset, DataLoader

from linearized_nns.estimator import Estimator
from linearized_nns.pytorch_impl.estimators import SgdEstimator
from linearized_nns.pytorch_impl.nns import Myrtle5, Myrtle7, Myrtle10
from linearized_nns.pytorch_impl import ClassifierTraining
from linearized_nns.pytorch_impl.matrix_exp import matrix_exp, compute_exp_term
from linearized_nns.pytorch_impl.nns.utils import to_one_hot, print_sizes
from linearized_nns.from_neural_kernels import to_zca, CustomTensorDataset, get_cifar_zca
from sklearn.utils.extmath import randomized_svd

from sklearn.linear_model import LogisticRegression, SGDClassifier, RidgeClassifier

import tqdm

In [2]:
#!g1.1
DEVICE = 'cuda'
NUM_CLASSES = 200
TRAIN_DIR = 'tiny-imagenet-200/train'
TEST_DIR  = 'tiny-imagenet-200/val'

In [3]:
#!g1.1
def to_zca_faster(train, test, zca_bias=0.01, n_components=1000):
    print('faster zca')
    orig_train_shape = train.shape
    orig_test_shape = test.shape

    train = np.ascontiguousarray(train, dtype=np.float32).reshape(train.shape[0], -1).astype(np.float64)
    test  = np.ascontiguousarray(test, dtype=np.float32).reshape(test.shape[0], -1).astype(np.float64)

    n_train = train.shape[0]
    d       = train.shape[1]

    # Zero mean every feature
    train = train - np.mean(train, axis=1)[:,np.newaxis]
    test  =  test - np.mean(test, axis=1)[:,np.newaxis]

    # Normalize
    train_norms = np.linalg.norm(train, axis=1)
    test_norms  = np.linalg.norm(test, axis=1)

    # Make features unit norm
    train = train/train_norms[:,np.newaxis]
    test  = test/test_norms[:,np.newaxis]
    
    print(train[0])
    
    print('calculating svd...')
    
    _, S, V = randomized_svd(train, n_components=n_components)
    print('svd calculated. applying svd matricies.')
    S = (S ** 2) / n_train + zca_bias
    inv_sqrt_zca_eigs = np.diag(np.power(S, -1))

    train = (train).dot(V.T).dot(inv_sqrt_zca_eigs).dot(V)
    test  =  (test).dot(V.T).dot(inv_sqrt_zca_eigs).dot(V)
    
    print('faster zca done.')

    return (train.reshape(orig_train_shape).astype(np.float64), test.reshape(orig_test_shape).astype(np.float64))


In [4]:
#!g1.1

def get_tiny_imagenet_zca(train_dir, test_dir):
    print('getting tiny-imagenet zca:')
    np.random.seed(0)
    random.seed(0)
    torch.manual_seed(0)
    
    print('loading datasets...')
    
    trainset = datasets.ImageFolder(train_dir, transform=transforms.ToTensor()) 
    testset  = datasets.ImageFolder(test_dir, transform=transforms.ToTensor())
    
    trainloader = utils.data.DataLoader(trainset, pin_memory=True, shuffle=True, batch_size=100000)
    testloader  = utils.data.DataLoader(testset,  pin_memory=True, shuffle=False, batch_size=10000)
    
    _, (X_train, y_train)  = next(enumerate(trainloader))
    _, (X_test,  y_test)   = next(enumerate(testloader))
    
    print('datasets loaded.')
    
    X_train = X_train.numpy().astype(np.float64)
    X_test  = X_test.numpy().astype(np.float64)
    
    print('calculating zca...')
    X_train, X_test = to_zca_faster(X_train, X_test, n_components=3000)
    print('zca_calulated.')
    return torch.tensor(X_train), y_train, torch.tensor(X_test), y_test

In [5]:
#!g1.1

X_train_full, labels_train_full, X_test_full, labels_test_full = get_tiny_imagenet_zca(TRAIN_DIR, TEST_DIR)

getting tiny-imagenet zca:
loading datasets...
datasets loaded.
calculating zca...
faster zca
[-0.00336678 -0.00447899 -0.00429362 ... -0.01170834 -0.01263518
 -0.01393276]
calculating svd...
getting tiny-imagenet zca:
loading datasets...
datasets loaded.
calculating zca...
faster zca
[-0.00336678 -0.00447899 -0.00429362 ... -0.01170834 -0.01263518
 -0.01393276]
calculating svd...
svd calculated. applying svd matricies.
faster zca done.
zca_calulated.


In [6]:
#!g1.1

from linearized_nns.pytorch_impl.nns.primitives import *

class Myrtle9(nn.Module):
    def __init__(self, num_classes=1, input_filters=3, num_filters=1, groups=1):
        super(Myrtle9, self).__init__()
        filters = num_filters

        def Activation():
            return ReLU2()

        self.layers = nn.Sequential(
            Conv(input_filters, filters * groups), Activation(),
            Conv(filters, filters * 2, groups),    Activation(),
            nn.AvgPool2d(kernel_size=2, stride=2),

            Conv(filters * 2, filters * 4, groups), Activation(),
            Conv(filters * 4, filters * 8, groups), Activation(),
            nn.AvgPool2d(kernel_size=2, stride=2),

            Conv(filters *  8, filters * 16, groups), Activation(),
            Conv(filters * 16, filters * 32, groups), Activation(),
            nn.AvgPool2d(kernel_size=2, stride=2),
            
            Conv(filters *  32, filters * 32, groups), Activation(),
            Conv(filters *  32, filters * 32, groups), Activation(),
            nn.AvgPool2d(kernel_size=8, stride=8),

            Flatten(),
            Normalize(filters * 32)
        )
        self.classifier = nn.Linear(filters * 32 * groups, num_classes, bias=True)

    def readout(self, x):
        return self.layers(x)

    def forward(self, x):
        x = self.readout(x)
        return self.classifier(x)

In [7]:
#!g1.1

N_train = 1280 * 5
N_test  = 1000

X_train = X_train_full[:N_train].float()
X_test  = X_test_full[:N_test].float()

labels_train = labels_train_full[:N_train] 
labels_test  = labels_test_full[:N_test]

y_train = to_one_hot(labels_train, NUM_CLASSES).to(DEVICE)
y_test  = to_one_hot(labels_test,  NUM_CLASSES).to(DEVICE)

In [19]:
#!g1.1
np.random.seed(0)
random.seed(0)
torch.manual_seed(0)

n_models = 50
models = [Myrtle9(num_filters=1, groups=10) for _ in range(n_models)]

def to_rf(X, models):
    X_rf = []
    X = X.to(DEVICE)

    for model_i, model in enumerate(models):
        with torch.no_grad():
            torch.cuda.empty_cache() 
            if model_i & (model_i - 1) == 0:
                print(f"{model_i} models done.")
            model = model.to(DEVICE)

            features = model.readout(X).cpu()
            X_rf.append(features)
            
    return torch.cat(X_rf, dim=1)

X_train_rf = to_rf(X_train, models)
X_test_rf  = to_rf(X_test, models)

0 models done.
1 models done.
2 models done.
4 models done.
8 models done.
16 models done.
32 models done.
64 models done.
128 models done.
256 models done.
0 models done.
1 models done.
2 models done.
4 models done.
8 models done.
16 models done.
32 models done.
64 models done.
128 models done.
256 models done.


In [20]:
#!g1.1
train_kernel = torch.mm(X_train_rf, X_train_rf.T)

In [21]:
#!g1.1
n = len(X_train_rf)

U, D, Ut = randomized_svd(train_kernel.cpu().numpy(), n_components=n)
print(U.shape, D.shape, Ut.shape)

(6400, 6400) (6400,) (6400, 6400)


In [22]:
#!g1.1
U = torch.tensor(U)

In [23]:
#!g1.1

VT = torch.mm(U.T, X_train_rf)
VT.shape

torch.Size([6400, 160000])

In [62]:
#!g1.1
eps = 0.01

VT /= VT.norm(dim=1, keepdim=True)
V = VT.T

In [63]:
#!g1.1

V.shape

torch.Size([160000, 6400])

In [119]:
#!g1.1
N = len(X_train_full)

def to_pca(X, V_pca, models, batch_size = 1000):
    N = len(X)
    batch_size = 1000
    X_pca = []
    for l in tqdm.tqdm(range(0, N, batch_size)):
        r = min(N, l + batch_size)
        X_batch = X[l:r]
        X_batch_rf = to_rf(X_batch.float(), models)
        X_batch_pca = torch.mm(X_batch_rf, V_pca)
        print(X_batch_pca.shape)
        X_pca.append(X_batch_pca)
    return torch.cat(X_pca, dim=0)

X_train_full_pca = to_pca(X_train_full, V, models)
X_test_full_pca  = to_pca(X_test_full, V, models)
X_train_full_pca.shape, X_test_full_pca.shape

 88%|████████▊ | 88/100 [34:10<04:38, 23.20s/it]

0 models done.
1 models done.
2 models done.
4 models done.
8 models done.
16 models done.
32 models done.
64 models done.
128 models done.
256 models done.
torch.Size([1000, 6400])
0 models done.
1 models done.
2 models done.
4 models done.
8 models done.
16 models done.
32 models done.
64 models done.
128 models done.
256 models done.
torch.Size([1000, 6400])
0 models done.
1 models done.
2 models done.
4 models done.
8 models done.
16 models done.
32 models done.
64 models done.
128 models done.
256 models done.
torch.Size([1000, 6400])
0 models done.
1 models done.
2 models done.
4 models done.
8 models done.
16 models done.
32 models done.
64 models done.
128 models done.
256 models done.
torch.Size([1000, 6400])
0 models done.
1 models done.
2 models done.
4 models done.
8 models done.
16 models done.
32 models done.
64 models done.
128 models done.
256 models done.
torch.Size([1000, 6400])
0 models done.
1 models done.
2 models done.
4 models done.
8 models done.
16 models done.

 89%|████████▉ | 89/100 [34:33<04:15, 23.23s/it]

torch.Size([1000, 6400])
0 models done.
1 models done.
2 models done.
4 models done.
8 models done.
16 models done.
32 models done.
64 models done.
128 models done.
256 models done.


 90%|█████████ | 90/100 [34:56<03:52, 23.26s/it]

torch.Size([1000, 6400])
0 models done.
1 models done.
2 models done.
4 models done.
8 models done.
16 models done.
32 models done.
64 models done.
128 models done.
256 models done.


 91%|█████████ | 91/100 [35:20<03:29, 23.32s/it]

torch.Size([1000, 6400])
0 models done.
1 models done.
2 models done.
4 models done.
8 models done.
16 models done.
32 models done.
64 models done.
128 models done.
256 models done.


 92%|█████████▏| 92/100 [35:43<03:06, 23.27s/it]

torch.Size([1000, 6400])
0 models done.
1 models done.
2 models done.
4 models done.
8 models done.
16 models done.
32 models done.
64 models done.
128 models done.
256 models done.


 93%|█████████▎| 93/100 [36:06<02:42, 23.26s/it]

torch.Size([1000, 6400])
0 models done.
1 models done.
2 models done.
4 models done.
8 models done.
16 models done.
32 models done.
64 models done.
128 models done.
256 models done.


 94%|█████████▍| 94/100 [36:29<02:19, 23.22s/it]

torch.Size([1000, 6400])
0 models done.
1 models done.
2 models done.
4 models done.
8 models done.
16 models done.
32 models done.
64 models done.
128 models done.
256 models done.


 95%|█████████▌| 95/100 [36:53<01:57, 23.44s/it]

torch.Size([1000, 6400])
0 models done.
1 models done.
2 models done.
4 models done.
8 models done.
16 models done.
32 models done.
64 models done.
128 models done.
256 models done.


 96%|█████████▌| 96/100 [37:16<01:33, 23.36s/it]

torch.Size([1000, 6400])
0 models done.
1 models done.
2 models done.
4 models done.
8 models done.
16 models done.
32 models done.
64 models done.
128 models done.
256 models done.


 97%|█████████▋| 97/100 [37:40<01:10, 23.36s/it]

torch.Size([1000, 6400])
0 models done.
1 models done.
2 models done.
4 models done.
8 models done.
16 models done.
32 models done.
64 models done.
128 models done.
256 models done.


 98%|█████████▊| 98/100 [38:03<00:46, 23.37s/it]

torch.Size([1000, 6400])
0 models done.
1 models done.
2 models done.
4 models done.
8 models done.
16 models done.
32 models done.
64 models done.
128 models done.
256 models done.


 99%|█████████▉| 99/100 [38:26<00:23, 23.35s/it]

torch.Size([1000, 6400])
0 models done.
1 models done.
2 models done.
4 models done.
8 models done.
16 models done.
32 models done.
64 models done.
128 models done.
256 models done.


100%|██████████| 100/100 [38:50<00:00, 23.30s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

torch.Size([1000, 6400])
0 models done.
1 models done.
2 models done.
4 models done.
8 models done.
16 models done.
32 models done.
64 models done.
128 models done.
256 models done.


 10%|█         | 1/10 [00:23<03:28, 23.20s/it]

torch.Size([1000, 6400])
0 models done.
1 models done.
2 models done.
4 models done.
8 models done.
16 models done.
32 models done.
64 models done.
128 models done.
256 models done.


 20%|██        | 2/10 [00:46<03:05, 23.25s/it]

torch.Size([1000, 6400])
0 models done.
1 models done.
2 models done.
4 models done.
8 models done.
16 models done.
32 models done.
64 models done.
128 models done.
256 models done.


 30%|███       | 3/10 [01:09<02:42, 23.22s/it]

torch.Size([1000, 6400])
0 models done.
1 models done.
2 models done.
4 models done.
8 models done.
16 models done.
32 models done.
64 models done.
128 models done.
256 models done.


 40%|████      | 4/10 [01:33<02:19, 23.25s/it]

torch.Size([1000, 6400])
0 models done.
1 models done.
2 models done.
4 models done.
8 models done.
16 models done.
32 models done.
64 models done.
128 models done.
256 models done.


 50%|█████     | 5/10 [01:56<01:56, 23.30s/it]

torch.Size([1000, 6400])
0 models done.
1 models done.
2 models done.
4 models done.
8 models done.
16 models done.
32 models done.
64 models done.
128 models done.
256 models done.


 60%|██████    | 6/10 [02:20<01:34, 23.54s/it]

torch.Size([1000, 6400])
0 models done.
1 models done.
2 models done.
4 models done.
8 models done.
16 models done.
32 models done.
64 models done.
128 models done.
256 models done.


 70%|███████   | 7/10 [02:43<01:10, 23.47s/it]

torch.Size([1000, 6400])
0 models done.
1 models done.
2 models done.
4 models done.
8 models done.
16 models done.
32 models done.
64 models done.
128 models done.
256 models done.


 80%|████████  | 8/10 [03:07<00:46, 23.44s/it]

torch.Size([1000, 6400])
0 models done.
1 models done.
2 models done.
4 models done.
8 models done.
16 models done.
32 models done.
64 models done.
128 models done.
256 models done.


 90%|█████████ | 9/10 [03:30<00:23, 23.39s/it]

torch.Size([1000, 6400])
0 models done.
1 models done.
2 models done.
4 models done.
8 models done.
16 models done.
32 models done.
64 models done.
128 models done.
256 models done.
torch.Size([1000, 6400])


100%|██████████| 10/10 [03:53<00:00, 23.39s/it]


(torch.Size([100000, 6400]), torch.Size([10000, 6400]))

In [122]:
#!g1.1
from sklearn.linear_model import LogisticRegression, SGDClassifier, RidgeClassifier

for alpha in [1.,   1e-1, 1e-2, 
              1e-3, 1e-4, 1e-5, 
              1e-6, 1e-7, 1e-8]:
    classifier = RidgeClassifier(alpha=alpha)
    classifier.fit(X_train_full_pca.numpy(), labels_train_full.numpy())
    y_pred = classifier.predict(X_test_full_pca.numpy())
    acc = np.average(abs(y_pred - labels_test_full.numpy()) == 0)
    print(f'alpha={alpha} test acc: {acc}')

alpha=1.0 test acc: 0.3707
alpha=0.1 test acc: 0.3716
alpha=0.01 test acc: 0.3717
alpha=0.001 test acc: 0.3717
alpha=0.0001 test acc: 0.3717
alpha=1e-05 test acc: 0.3718
alpha=1e-06 test acc: 0.3718
alpha=1e-07 test acc: 0.3718
alpha=1e-08 test acc: 0.3718
