In [160]:
#!g1.1
import time
import numpy as np
import seaborn as sns
from PIL import Image
import matplotlib.pyplot as plt
import random
import ipyplot
from sklearn.utils.extmath import randomized_svd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, datasets
from torch import utils

from linearized_nns.estimator import Estimator
from linearized_nns.pytorch_impl.estimators import SgdEstimator
from linearized_nns.pytorch_impl.nns import Myrtle5, Myrtle7, Myrtle10
from linearized_nns.pytorch_impl import ClassifierTraining
from linearized_nns.pytorch_impl.matrix_exp import matrix_exp, compute_exp_term
from linearized_nns.pytorch_impl.nns.utils import to_one_hot, print_sizes
from linearized_nns.from_neural_kernels import to_zca, CustomTensorDataset, get_cifar_zca

In [208]:
#!g1.1
# Device string used by the cells below when moving tensors and models.
DEVICE = 'cuda'
# Number of target classes; passed to to_one_hot and GpEstimator further down.
NUM_CLASSES = 200

In [182]:
#!g1.1
def to_zca(train, test, zca_bias=0.0001):
    """Normalise every sample and apply ZCA whitening fitted on ``train``.

    Each sample is flattened, shifted so that its own mean is zero and scaled
    to unit L2 norm. The whitening matrix
    ``V diag((E + zca_bias) ** -0.5) V^T`` is then built from the
    eigendecomposition ``(E, V)`` of the training covariance and applied to
    both splits.

    Args:
        train: array of training samples; the first axis indexes samples.
        test: array of test samples with the same per-sample size as ``train``.
        zca_bias: value added to every eigenvalue before the inverse square
            root is taken.

    Returns:
        ``((train_whitened, test_whitened), global_ZCA)`` where the whitened
        arrays are float64 with the shapes of the inputs and ``global_ZCA`` is
        the ``(d, d)`` whitening matrix (``d`` = flattened sample size).
    """
    orig_train_shape = train.shape
    orig_test_shape = test.shape

    train = np.ascontiguousarray(train, dtype=np.float32).reshape(train.shape[0], -1).astype(np.float64)
    test = np.ascontiguousarray(test, dtype=np.float32).reshape(test.shape[0], -1).astype(np.float64)

    n_train = train.shape[0]

    # Remove each sample's own mean (axis=1 averages over the features of one sample).
    train = train - np.mean(train, axis=1)[:, np.newaxis]
    test = test - np.mean(test, axis=1)[:, np.newaxis]

    # Scale every sample to unit L2 norm.
    train = train / np.linalg.norm(train, axis=1)[:, np.newaxis]
    test = test / np.linalg.norm(test, axis=1)[:, np.newaxis]

    # The d x d covariance is the expensive product: run it on DEVICE when CUDA
    # is present, otherwise on the CPU. `train` itself stays a NumPy array so
    # the NumPy code below keeps working.
    device = DEVICE if torch.cuda.is_available() else 'cpu'
    train_t = torch.tensor(train, dtype=torch.double).to(device)
    train_cov_mat = (1.0 / n_train * torch.matmul(train_t.T, train_t)).cpu().numpy()
    del train_t

    # The covariance is symmetric, so eigh applies: it guarantees real
    # eigenvalues and orthonormal eigenvectors (eig may return complex output).
    E, V = np.linalg.eigh(train_cov_mat)
    # Tiny negative eigenvalues are rounding noise of a PSD matrix.
    E = np.clip(E, 0.0, None)

    inv_sqrt_zca_eigs = np.diag(np.power(E + zca_bias, -0.5))
    global_ZCA = V.dot(inv_sqrt_zca_eigs).dot(V.T)

    train = train.dot(global_ZCA)
    test = test.dot(global_ZCA)

    return (train.reshape(orig_train_shape).astype(np.float64), test.reshape(orig_test_shape).astype(np.float64)), global_ZCA

In [183]:
#!g1.1
def get_cifar_zca():
    """Load CIFAR-10 from ``./data`` (downloading if needed) and whiten it with ``to_zca``.

    Returns:
        ``(X_train, y_train, X_test, y_test)``. The image arrays are the output
        of ``to_zca`` with the last axis moved to position 1 (NHWC -> NCHW);
        the label arrays are built from each dataset's ``targets``.
    """
    def _load_split(is_train):
        # Raw pixel data as float64 plus the integer class labels.
        split = datasets.CIFAR10(root='./data', train=is_train, download=True, transform=None)
        return np.asarray(split.data).astype(np.float64), np.asarray(split.targets)

    train_images, y_train = _load_split(True)
    test_images, y_test = _load_split(False)

    # The whitening matrix returned by to_zca is not needed here.
    (train_images, test_images), _ = to_zca(train_images, test_images)

    # Channels-last -> channels-first.
    nchw = (0, 3, 1, 2)
    return np.transpose(train_images, nchw), y_train, np.transpose(test_images, nchw), y_test

In [184]:
#!g1.1
%%time
# Build the whitened CIFAR-10 image arrays; both label arrays are discarded.
X_train, _, X_test, _ = get_cifar_zca()

In [185]:
#!g1.1
# Shape of a single whitened training image.
X_train[0].shape

In [186]:
# Inspect the raw values of the first whitened training image.
X_train[0]

In [187]:
#!g1.1

# Display the first ten whitened training images, one figure each.
for i in range(10):
    # Move the channel axis last (CHW -> HWC) before handing the image to imshow.
    image = np.transpose(X_train[i], (1,2,0))
    plt.imshow(image)
    plt.show()

In [None]:
#!g1.1
from sklearn.utils.extmath import randomized_svd

def to_zca_fast(train, test, zca_bias=0.01, n_components=1000, random_state=0):
    """Low-rank ZCA whitening built from a randomized SVD of the covariance.

    Each sample is flattened, shifted to zero mean and scaled to unit L2 norm.
    The ``(d, d)`` training covariance is decomposed with ``randomized_svd``
    keeping ``n_components`` components, and the whitening matrix
    ``V diag((S + zca_bias) ** -0.5) V^T`` is applied to both splits.

    Args:
        train: array of training samples; the first axis indexes samples.
        test: array of test samples with the same per-sample size as ``train``.
        zca_bias: value added to every retained eigenvalue before the inverse
            square root is taken.
        n_components: number of components requested from ``randomized_svd``.
        random_state: seed forwarded to ``randomized_svd`` so that repeated
            runs give the same decomposition.

    Returns:
        ``((train_whitened, test_whitened), global_ZCA)`` where the whitened
        arrays are float64 with the shapes of the inputs and ``global_ZCA`` is
        the ``(d, d)`` matrix that was applied.
    """
    print('fast zca')
    orig_train_shape = train.shape
    orig_test_shape = test.shape

    train = np.ascontiguousarray(train, dtype=np.float32).reshape(train.shape[0], -1).astype(np.float64)
    test = np.ascontiguousarray(test, dtype=np.float32).reshape(test.shape[0], -1).astype(np.float64)

    n_train = train.shape[0]

    # Remove each sample's own mean (axis=1 averages over the features of one sample).
    train = train - np.mean(train, axis=1)[:, np.newaxis]
    test = test - np.mean(test, axis=1)[:, np.newaxis]

    # Scale every sample to unit L2 norm.
    train = train / np.linalg.norm(train, axis=1)[:, np.newaxis]
    test = test / np.linalg.norm(test, axis=1)[:, np.newaxis]

    # Covariance in float32 on DEVICE when CUDA is present, otherwise on the CPU.
    device = DEVICE if torch.cuda.is_available() else 'cpu'
    train_torch = torch.tensor(train, dtype=torch.float).to(device)
    train_cov_mat = 1.0 / n_train * torch.matmul(train_torch.T, train_torch).cpu().numpy()
    del train_torch

    _, S, Vt = randomized_svd(train_cov_mat, n_components=n_components, random_state=random_state)
    V = Vt.T

    # The covariance is symmetric PSD, so its singular values S already are its
    # eigenvalues; whitening needs their inverse *square root*.
    inv_sqrt_zca_eigs = np.diag(np.power(S + zca_bias, -0.5))
    global_ZCA = V.dot(inv_sqrt_zca_eigs).dot(V.T)

    train = train.dot(global_ZCA)
    test = test.dot(global_ZCA)

    return (train.reshape(orig_train_shape).astype(np.float64), test.reshape(orig_test_shape).astype(np.float64)), global_ZCA

In [316]:
#!g1.1

def to_zca_faster(train, test, zca_bias=0.01, n_components=1000, random_state=0):
    """Low-rank ZCA whitening from a randomized SVD of the data matrix itself.

    Each sample is flattened, shifted to zero mean and scaled to unit L2 norm.
    ``randomized_svd`` of the ``(n_train, d)`` training matrix yields singular
    values ``S`` and right singular vectors ``V`` (rows). The covariance
    eigenvalues are ``S ** 2 / n_train``, so both splits are mapped through
    ``V^T diag((S ** 2 / n_train + zca_bias) ** -0.5) V`` without ever forming
    a ``(d, d)`` matrix.

    Args:
        train: array of training samples; the first axis indexes samples.
        test: array of test samples with the same per-sample size as ``train``.
        zca_bias: value added to every retained eigenvalue before the inverse
            square root is taken.
        n_components: number of components requested from ``randomized_svd``.
        random_state: seed forwarded to ``randomized_svd`` so that repeated
            runs give the same decomposition.

    Returns:
        ``(train_whitened, test_whitened)`` as float64 arrays with the shapes
        of the inputs.
    """
    print('faster zca')
    orig_train_shape = train.shape
    orig_test_shape = test.shape

    train = np.ascontiguousarray(train, dtype=np.float32).reshape(train.shape[0], -1).astype(np.float64)
    test = np.ascontiguousarray(test, dtype=np.float32).reshape(test.shape[0], -1).astype(np.float64)

    n_train = train.shape[0]

    # Remove each sample's own mean (axis=1 averages over the features of one sample).
    train = train - np.mean(train, axis=1)[:, np.newaxis]
    test = test - np.mean(test, axis=1)[:, np.newaxis]

    # Scale every sample to unit L2 norm.
    train = train / np.linalg.norm(train, axis=1)[:, np.newaxis]
    test = test / np.linalg.norm(test, axis=1)[:, np.newaxis]

    print('calculating svd...')

    _, S, V = randomized_svd(train, n_components=n_components, random_state=random_state)
    print('svd calculated. applying svd matricies.')

    # Eigenvalues of the covariance are S^2 / n; whitening scales each retained
    # direction by the inverse *square root* of (eigenvalue + bias).
    inv_sqrt_zca_eigs = np.power((S ** 2) / n_train + zca_bias, -0.5)

    # Project onto the components, rescale column-wise (broadcasting replaces a
    # multiplication by a k x k diagonal matrix), and map back to feature space.
    train = (train.dot(V.T) * inv_sqrt_zca_eigs).dot(V)
    test = (test.dot(V.T) * inv_sqrt_zca_eigs).dot(V)

    print('faster zca done.')

    return (train.reshape(orig_train_shape).astype(np.float64), test.reshape(orig_test_shape).astype(np.float64))

def get_cifar_zca_fast():
    """Load CIFAR-10 from ``./data`` (downloading if needed) and whiten it with ``to_zca_faster``.

    Returns:
        ``(X_train, y_train, X_test, y_test)``. The image arrays are the output
        of ``to_zca_faster`` (called with ``n_components=1000``) with the last
        axis moved to position 1 (NHWC -> NCHW); the label arrays are built
        from each dataset's ``targets``.
    """
    def _load_split(is_train):
        # Raw pixel data as float64 plus the integer class labels.
        split = datasets.CIFAR10(root='./data', train=is_train, download=True, transform=None)
        return np.asarray(split.data).astype(np.float64), np.asarray(split.targets)

    train_images, y_train = _load_split(True)
    test_images, y_test = _load_split(False)

    train_images, test_images = to_zca_faster(train_images, test_images, n_components=1000)

    # Channels-last -> channels-first.
    nchw = (0, 3, 1, 2)
    return np.transpose(train_images, nchw), y_train, np.transpose(test_images, nchw), y_test

In [317]:
#!g1.1
%%time

# Rebuild the whitened CIFAR-10 arrays with the low-rank variant; labels are discarded.
X_train, _, X_test, _ = get_cifar_zca_fast()

In [318]:
#!g1.1
# Inspect the raw values of the first image produced by the low-rank whitening.
X_train[0]

In [319]:
#!g1.1
# Display the first ten images produced by the low-rank whitening, one figure each.
for i in range(10):
    # Move the channel axis last (CHW -> HWC) before handing the image to imshow.
    image = np.transpose(X_train[i], (1,2,0))
    plt.imshow(image)
    plt.show()

In [292]:
#!g1.1
def get_tiny_imagenet_zca(train_dir, test_dir):
    """Read two image folders, whiten them with ``to_zca_faster`` and return tensors.

    NumPy, ``random`` and torch are all seeded with 0 first. Each folder is
    read through ``ImageFolder`` + ``ToTensor`` and only the first batch of
    each loader is kept (batch size 100000 for the shuffled training loader,
    10000 for the unshuffled test loader).

    Args:
        train_dir: root directory passed to ``ImageFolder`` for training data.
        test_dir: root directory passed to ``ImageFolder`` for test data.

    Returns:
        ``(X_train, y_train, X_test, y_test)`` where the images are tensors
        built from the float64 output of ``to_zca_faster`` (``n_components=3000``)
        and the labels are the label tensors yielded by the loaders.
    """
    print('getting tiny-imagenet zca:')
    for seed_everything in (np.random.seed, random.seed, torch.manual_seed):
        seed_everything(0)

    print('loading datasets...')

    train_folder = datasets.ImageFolder(train_dir, transform=transforms.ToTensor())
    test_folder = datasets.ImageFolder(test_dir, transform=transforms.ToTensor())

    train_batches = utils.data.DataLoader(train_folder, pin_memory=True, shuffle=True, batch_size=100000)
    test_batches = utils.data.DataLoader(test_folder, pin_memory=True, shuffle=False, batch_size=10000)

    # Only the first batch of each loader is used.
    train_images, y_train = next(iter(train_batches))
    test_images, y_test = next(iter(test_batches))

    print('datasets loaded.')

    train_images = train_images.numpy().astype(np.float64)
    test_images = test_images.numpy().astype(np.float64)

    print('calculating zca...')
    train_images, test_images = to_zca_faster(train_images, test_images, n_components=3000)
    print('zca_calulated.')
    return torch.tensor(train_images), y_train, torch.tensor(test_images), y_test

# Image folder roots handed to get_tiny_imagenet_zca (paths are relative to the working directory).
TRAIN_DIR = 'tiny-imagenet-200/train'
TEST_DIR  = 'tiny-imagenet-200/val'

# Load and whiten both splits; the *_full names hold everything the loader returned.
X_train_full, labels_train_full, X_test_full, labels_test_full = get_tiny_imagenet_zca(TRAIN_DIR, TEST_DIR)

In [293]:
#!g1.1
# Display the first ten whitened training images, one figure each.
for i in range(10):
    # Move the channel axis last (CHW -> HWC) before handing the image to imshow.
    image = np.transpose(X_train_full[i], (1,2,0))
    plt.imshow(image)
    plt.show()

In [294]:
#!g1.1
from linearized_nns.pytorch_impl.nns.primitives import *

class Myrtle9(nn.Module):
    """Feature extractor of eight ``Conv`` + ``ReLU2`` pairs followed by a linear classifier.

    The convolutions come in four pairs, each pair followed by average pooling
    (2x2 after the first three pairs, 8x8 after the last). The result is
    flattened and passed through ``Normalize``; ``readout`` returns these
    features and ``forward`` applies the linear classifier on top.

    ``Conv``, ``ReLU2``, ``Flatten`` and ``Normalize`` come from
    ``linearized_nns.pytorch_impl.nns.primitives``; the third positional
    argument of ``Conv`` is passed ``groups`` here (presumably a group count
    that also scales the channel widths -- confirm against the primitive).

    Args:
        num_classes: output size of the linear classifier.
        input_filters: number of input channels of the first convolution.
        num_filters: base channel width that every layer width is a multiple of.
        groups: forwarded to every ``Conv`` except the first and used to scale
            the first layer's output width and the classifier's input width.
    """

    def __init__(self, num_classes=1, input_filters=3, num_filters=1, groups=1):
        super().__init__()
        width = num_filters

        def conv_relu(*conv_args):
            # One convolution followed by its activation, built in that order.
            return [Conv(*conv_args), ReLU2()]

        def pool(size):
            return nn.AvgPool2d(kernel_size=size, stride=size)

        stack = [
            *conv_relu(input_filters, width * groups),
            *conv_relu(width, width * 2, groups),
            pool(2),

            *conv_relu(width * 2, width * 4, groups),
            *conv_relu(width * 4, width * 8, groups),
            pool(2),

            *conv_relu(width * 8, width * 16, groups),
            *conv_relu(width * 16, width * 32, groups),
            pool(2),

            *conv_relu(width * 32, width * 32, groups),
            *conv_relu(width * 32, width * 32, groups),
            pool(8),

            Flatten(),
            Normalize(width * 32),
        ]
        self.layers = nn.Sequential(*stack)
        self.classifier = nn.Linear(width * 32 * groups, num_classes, bias=True)

    def readout(self, x):
        """Return the normalised feature vector produced by the convolutional stack."""
        return self.layers(x)

    def forward(self, x):
        """Return classifier outputs computed from the readout features of ``x``."""
        features = self.readout(x)
        return self.classifier(features)

In [295]:
#!g1.1
# Sizes of the subsets used for the kernel experiment below.
N_train = 1280 * 5
N_test  = 1000

# Leading slices of the whitened data, cast to float32.
# Note: this rebinds X_train / X_test, which held the CIFAR-10 arrays in earlier cells.
X_train = X_train_full[:N_train].float()
X_test  = X_test_full[:N_test].float()

In [305]:
#!g1.1
# Labels matching the image subsets taken above.
labels_train = labels_train_full[:N_train] 
labels_test  = labels_test_full[:N_test]

# One-hot targets with NUM_CLASSES columns, moved to DEVICE.
y_train = to_one_hot(labels_train, NUM_CLASSES).to(DEVICE)
y_test  = to_one_hot(labels_test,  NUM_CLASSES).to(DEVICE)

In [306]:
#!g1.1
def compute_kernels(models, X_train, X_test, device=None):
    """Average the Gram matrices of the models' readout features.

    For every model the readout features of the training and test inputs are
    computed, and ``train_features @ train_features.T`` and
    ``test_features @ train_features.T`` are accumulated in float64. The sums
    are divided by the number of models.

    Args:
        models: iterable of objects exposing ``.to(device)`` and ``.readout(X)``.
        X_train: training inputs; ``len(X_train)`` gives ``n_train``.
        X_test: test inputs; ``len(X_test)`` gives ``n_test``.
        device: device used for inputs, models and accumulators. Defaults to
            the notebook-level ``DEVICE``.

    Returns:
        ``(train_kernel, test_kernel)`` as float32 tensors of shape
        ``(n_train, n_train)`` and ``(n_test, n_train)``.

    Raises:
        ValueError: if ``models`` yields no model (the average is undefined).
    """
    if device is None:
        device = DEVICE

    with torch.no_grad():
        X_train = X_train.to(device)
        X_test = X_test.to(device)

        n_train = len(X_train)
        n_test = len(X_test)

        train_kernel = torch.zeros([n_train, n_train]).double().to(device)
        test_kernel = torch.zeros([n_test, n_train]).double().to(device)

        n_done = 0
        start_time = time.time()

        for model_i, model in enumerate(models):
            model = model.to(device)
            # Report progress at 0 and at every power of two.
            if (model_i & (model_i - 1)) == 0:
                print(f"{model_i} models done. time {time.time() - start_time:.0f}s")

            train_features = model.readout(X_train)
            test_features = model.readout(X_test)

            train_kernel += torch.matmul(train_features, train_features.T).double()
            test_kernel += torch.matmul(test_features, train_features.T).double()
            n_done += 1

        if n_done == 0:
            raise ValueError("models must contain at least one model")

        train_kernel /= n_done
        test_kernel /= n_done

        return train_kernel.float(), test_kernel.float()

In [307]:
#!g1.1
# Shapes of the image subsets that go into the kernel computation.
X_train.shape, X_test.shape

In [308]:
#!g1.1

# Number of freshly initialised networks whose feature kernels are averaged.
n_models = 500
models = [Myrtle9(num_filters=1, groups=10) for _ in range(n_models)]

train_kernel, test_kernel = compute_kernels(models, X_train, X_test)

# compute_kernels ends with .float() on tensors it created on DEVICE, so these
# two lines are a safeguard rather than a conversion.
train_kernel = train_kernel.float().to(DEVICE)
test_kernel  = test_kernel.float().to(DEVICE)

In [309]:
#!g1.1
# Top-left 5x5 corner of the train/train kernel.
train_kernel[:5,:5]

In [310]:
#!g1.1
# Top-left 5x5 corner of the test/train kernel.
test_kernel[:5,:5]

In [311]:
#!g1.1
# Scale applied to the kernel inside compute_exp_term and to its result.
lr = 1e5

# Number of training points (not used further in this cell).
n = len(train_kernel)

# compute_exp_term is imported from linearized_nns.pytorch_impl.matrix_exp; its
# definition is not visible in this notebook.
exp_term = - lr * compute_exp_term(- lr * train_kernel, DEVICE)
# Test predictions: test/train kernel times (exp_term @ -y_train).
y_pred = torch.matmul(test_kernel, torch.matmul(exp_term, - y_train))
# Fraction of test samples whose arg-max prediction equals the label.
(y_pred.argmax(dim=1) == labels_test.to(DEVICE)).float().mean()

In [99]:
#!g1.1
import time
import numpy as np
import seaborn as sns
from PIL import Image
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, datasets
from torchvision.datasets import FashionMNIST

from linearized_nns.estimator import Estimator
from linearized_nns.pytorch_impl.estimators import SgdEstimator
from linearized_nns.pytorch_impl.nns import Myrtle5, Myrtle7, Myrtle10
from linearized_nns.pytorch_impl import ClassifierTraining
from linearized_nns.pytorch_impl.matrix_exp import matrix_exp, compute_exp_term
from linearized_nns.pytorch_impl.nns.utils import to_one_hot, print_sizes
from linearized_nns.from_neural_kernels import to_zca, CustomTensorDataset, get_cifar_zca

In [116]:
#!g1.1
class GpEstimator(Estimator):
    """Linear predictor on top of the readout features of a list of models.

    Every model contributes the features ``model.readout(X) / sqrt(w_size)``
    and owns one ``(readout_size, n_classes)`` slice of the weight tensor
    ``self.w``. Predictions are the sum over models of ``features @ w``; the
    kernel methods return the sum over models of ``features @ features.T``.
    All feature computations run under ``torch.no_grad()``.

    Args:
        models: list of models exposing ``.to(device)`` and ``.readout(X)``.
        n_classes: number of output columns.
        learning_rate: stored as ``self.lr``; not read by any method of this class.
        x_example: a single input sample, used once to probe the readout size
            of the first model.
        device: device on which the models and ``self.w`` are placed.
        groups: multiplier in the feature scaling, ``w_size = len(models) * groups``.
    """

    def __init__(self, models, n_classes, learning_rate, x_example, device, groups=1):
        super(GpEstimator, self).__init__()
        self.models    = [model.to(device) for model in models]
        self.lr        = learning_rate
        self.n_classes = n_classes
        self.device    = device

        n = len(self.models)
        X = torch.stack([x_example]).to(self.device)

        # Probe the feature width with one forward pass; no autograd graph is needed.
        with torch.no_grad():
            readout_size = self.models[0].readout(X).size()[1]

        # TODO: Assert that models have the same readout size

        self.w      = torch.zeros([n, readout_size, n_classes]).to(self.device)
        self.w_size = n * groups

    def get_w_update(self, X, right_vector):
        """Return the stacked per-model products ``features.T @ right_vector``.

        ``X`` and ``right_vector`` must have the same length; the result has
        one leading entry per model, in the order of ``self.models``.
        """
        with torch.no_grad():
            assert len(X) == len(right_vector)

            X            = X.to(self.device)
            right_vector = right_vector.to(self.device)

            return torch.stack([
                torch.matmul(self.to_model_features(X, model).T, right_vector)
                for model in self.models
            ])

    def to_model_features(self, X, model):
        """Return ``model.readout(X)`` scaled by ``1 / sqrt(self.w_size)``."""
        with torch.no_grad():
            model = model.to(self.device)
            return model.readout(X) * (1. / np.sqrt(self.w_size))

    def calc_kernel(self, X):
        """Return the float32 kernel of ``X`` with itself, accumulated in float64."""
        with torch.no_grad():
            X = X.to(self.device)

            res = torch.zeros([len(X), len(X)]).to(self.device).double()
            for model in self.models:
                features_x = self.to_model_features(X, model)

                res += torch.matmul(features_x, features_x.T).double()
            return res.float()

    def calc_kernel_pred(self, X):
        """Return ``(kernel, y_pred)`` for ``X`` from a single pass over the models.

        The kernel is returned in float64; the predictions (made with
        ``self.w``) are accumulated in float64 and returned as float32.
        """
        with torch.no_grad():
            X = X.to(self.device)

            n = len(X)
            y_pred = torch.zeros([n, self.n_classes]).to(self.device).double()
            kernel = torch.zeros([n, n]).to(self.device).double()

            for model, w in zip(self.models, self.w):
                features = self.to_model_features(X, model)

                kernel += torch.matmul(features, features.T).double()
                y_pred += torch.matmul(features, w).double()
            return kernel, y_pred.float()

    def calc_kernels(self, X_train, X_test):
        """Return the float32 ``(train/train, test/train)`` kernels.

        The sums over models are accumulated in float64, like in
        ``calc_kernel`` and ``calc_kernel_pred``, and cast to float32 on return.
        """
        with torch.no_grad():
            X_train = X_train.to(self.device)
            X_test  = X_test.to(self.device)

            res_train = torch.zeros([len(X_train), len(X_train)]).to(self.device).double()
            res_test  = torch.zeros([len(X_test),  len(X_train)]).to(self.device).double()
            for model in self.models:
                features_train = self.to_model_features(X_train, model)
                features_test  = self.to_model_features(X_test,  model)

                res_train += torch.matmul(features_train, features_train.T).double()
                res_test  += torch.matmul(features_test,  features_train.T).double()
            return res_train.float(), res_test.float()

    def predict(self, X, cur_w=None):
        """Return float32 predictions for ``X``.

        ``cur_w`` supplies one weight slice per model and defaults to ``self.w``.
        """
        X = X.to(self.device)
        if cur_w is None:
            cur_w = self.w

        with torch.no_grad():
            res = torch.zeros([len(X), self.n_classes]).double().to(self.device)

            for model, w in zip(self.models, cur_w):
                features = self.to_model_features(X, model)
                res += torch.matmul(features, w).double()
            return res.float()

In [101]:
#!g1.1

def calc_right_vector(kernel, y, learning_rate=1e5, reg_param=0):
    """Return ``exp_term @ (-y)`` for the regularised kernel.

    ``exp_term`` is ``-learning_rate * compute_exp_term(-learning_rate * (kernel + reg_param * I), DEVICE)``.
    ``compute_exp_term`` is imported from ``linearized_nns.pytorch_impl.matrix_exp``;
    its definition is not visible in this notebook.

    Args:
        kernel: square kernel matrix; ``len(kernel)`` sets the identity size.
        y: matrix multiplied (negated) by the exponential term.
        learning_rate: scale applied inside and outside ``compute_exp_term``.
        reg_param: coefficient of the identity added to ``kernel``.

    Returns:
        The product computed on ``DEVICE`` under ``torch.no_grad()``.
    """
    with torch.no_grad():
        targets = y.to(DEVICE)
        gram = kernel.to(DEVICE)

        ridge = torch.eye(len(gram)).to(DEVICE) * reg_param
        scaled_gram = - learning_rate * (gram + ridge)

        exp_term = - learning_rate * compute_exp_term(scaled_gram, DEVICE)
        return torch.matmul(exp_term, - targets)

In [102]:
#!g1.1

# Shapes of the full training and test label tensors.
labels_train_full.shape, labels_test_full.shape

In [119]:
#!g1.1
def boosting(estimator, train_loader, test_loader, learning_rate=1e5, beta=1., noise_rate=0., n_iter=10,
             output_kernel=False):
    """Iteratively update ``estimator.w`` from kernel-based residual fits.

    Every iteration walks over the first ``(2 * n_total_batches) // 3 + 1``
    batches of ``train_loader``. For each batch it computes the kernel and the
    current prediction, fits the residual with ``calc_right_vector`` and
    accumulates ``estimator.get_w_update``. The averaged update (optionally
    perturbed with Gaussian noise) is scaled by ``beta`` and subtracted from
    ``estimator.w``; afterwards the accuracy on ``test_loader`` is printed.

    Args:
        estimator: object exposing ``calc_kernel_pred``, ``get_w_update``,
            ``predict`` and a tensor attribute ``w``.
        train_loader: yields ``(X, y)`` with ``y`` holding one row per sample
            (its arg-max along dim 1 is used as the label).
        test_loader: yields ``(X, labels)`` with ``labels`` compared directly
            to the arg-max of the prediction.
        learning_rate: forwarded to ``calc_right_vector``.
        beta: step size applied to the averaged update.
        noise_rate: scale of the noise added to the update, relative to the
            standard deviation of the update; ``0`` disables the noise.
        n_iter: number of iterations.
        output_kernel: if true, print the top-left 5x5 corner of every batch kernel.

    Raises:
        ValueError: if ``train_loader`` yields no batches.
    """
    with torch.no_grad():
        # Prefer len(): counting by iteration loads every batch once for nothing.
        try:
            batches_num = len(train_loader)
        except TypeError:
            batches_num = sum(1 for _ in train_loader)
        if batches_num == 0:
            raise ValueError("train_loader yields no batches")
        n_batches = (batches_num * 2) // 3 + 1

        for iter_num in range(n_iter):
            print(f"iter {iter_num} ==========================")

            w_update = 0
            iter_start = time.time()

            for batch_i, (X, y) in enumerate(train_loader):
                if batch_i >= n_batches:
                    break

                X = X.to(DEVICE)
                y = y.to(DEVICE)

                kernel, y_pred = estimator.calc_kernel_pred(X)
                if output_kernel:
                    print(f"kernel\n{kernel[:5,:5]}")

                y_residual = y_pred - y

                train_acc = (y_pred.argmax(dim=1) == y.argmax(dim=1)).float().mean().item()
                train_mse = (y_residual ** 2).mean().item()
                print(f"batch {batch_i}: train_acc {train_acc:.4f}, train_mse {train_mse:.6f}")

                right_vector = calc_right_vector(kernel, y_residual, learning_rate=learning_rate)

                # Accumulate in float64; the average is cast back below.
                w_update += estimator.get_w_update(X, right_vector).double()

            w_update = (w_update / n_batches).float()

            step = w_update
            if noise_rate:
                # Noise is scaled to the spread of the update itself.
                w_std = w_update.std()
                noise = w_std * torch.randn(w_update.size()).to(DEVICE)
                step = w_update + noise_rate * noise
            estimator.w -= step * beta

            # Sample-weighted accuracy, so a smaller last batch is not over-counted.
            n_correct = 0
            n_seen = 0
            for X_test, labels in test_loader:
                y_pred = estimator.predict(X_test)
                n_correct += (y_pred.argmax(dim=1) == labels.to(DEVICE)).sum().item()
                n_seen += len(labels)
            test_acc = n_correct / n_seen if n_seen else 0.

            print(f"iter {iter_num} done. took {time.time() - iter_start:.0f}s. beta {beta:.3f}, test_acc {test_acc:.4f}")
            print()

In [122]:
#!g1.1
# get_tiny_imagenet_zca returns the images as float64 tensors, while the models
# are fed float32 everywhere else in this notebook (see the X_train / X_test
# subset cell). Cast here so the batches match; .float() is a no-op on tensors
# that are already float32.
train_set = CustomTensorDataset(X_train_full.float(), to_one_hot(labels_train_full, NUM_CLASSES))
test_set  = CustomTensorDataset(X_test_full.float(),  labels_test_full)

# Training targets are one-hot rows; test targets stay as class indices.
train_loader = torch.utils.data.DataLoader(train_set, batch_size=1280 * 4)
test_loader  = torch.utils.data.DataLoader(test_set,  batch_size=2000)

In [123]:
#!g1.1

%state_exclude estimator
%state_exclude models

# Seed torch and NumPy before the models are created.
torch.manual_seed(0)
np.random.seed(0)

# Size of the model ensemble for this run.
n_models = 256
models = [Myrtle9(num_filters=1, groups=10) for _ in range(n_models)]

# X_train[0] is only used to probe the readout size; 0.2 is stored as the estimator's lr.
estimator = GpEstimator(models, NUM_CLASSES, 0.2, X_train[0], DEVICE, groups=10)
boosting(estimator, train_loader, test_loader, learning_rate=1e5, n_iter=30)

In [None]:
#!g1.1

%state_exclude estimator
%state_exclude models

# Seed torch and NumPy before the models are created.
torch.manual_seed(0)
np.random.seed(0)

# Same run as the previous training cell, with a four times larger ensemble.
n_models = 1024
models = [Myrtle9(num_filters=1, groups=10) for _ in range(n_models)]

# X_train[0] is only used to probe the readout size; 0.2 is stored as the estimator's lr.
estimator = GpEstimator(models, NUM_CLASSES, 0.2, X_train[0], DEVICE, groups=10)
boosting(estimator, train_loader, test_loader, learning_rate=1e5, n_iter=30)

In [None]:
#!g1.1
