In [1]:
#!L
import time
import numpy as np
import seaborn as sns
from PIL import Image
import matplotlib.pyplot as plt
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, datasets
from torch import utils

from linearized_nns.estimator import Estimator
from linearized_nns.pytorch_impl.estimators import SgdEstimator
from linearized_nns.pytorch_impl.nns import Myrtle5, Myrtle7, Myrtle10
from linearized_nns.pytorch_impl import ClassifierTraining
from linearized_nns.pytorch_impl.matrix_exp import matrix_exp, compute_exp_term
from linearized_nns.pytorch_impl.nns.utils import to_one_hot, print_sizes
from linearized_nns.from_neural_kernels import to_zca, CustomTensorDataset, get_cifar_zca

In [2]:
#!L
device = torch.device('cuda:0') if (torch.cuda.is_available()) else torch.device('cpu')
device

device(type='cuda', index=0)

https://gist.github.com/bveliqi/5efe7d20c99025d02df87e4c595711c1

In [3]:
#!L

train_dir = 'tiny-imagenet-200/train'
test_dir  = 'tiny-imagenet-200/val'

normalize = transforms.Normalize([0.4802, 0.4481, 0.3975], [0.2302, 0.2265, 0.2262])

train_transforms = transforms.Compose([
    transforms.ToTensor(),
    normalize,
])

test_transforms = transforms.Compose([
    transforms.ToTensor(),
    normalize,
])

np.random.seed(0)
random.seed(0)
torch.manual_seed(0)

trainset = datasets.ImageFolder(train_dir, transform=train_transforms) 
testset  = datasets.ImageFolder(test_dir,  transform=test_transforms)

trainloader = utils.data.DataLoader(trainset, pin_memory=True, shuffle=True, batch_size=1280)
testloader  = utils.data.DataLoader(testset,  pin_memory=True, shuffle=True, batch_size=1000)

In [4]:
#!L
def get_tiny_image_net_zca():
    np.random.seed(0)
    random.seed(0)
    torch.manual_seed(0)
    
    trainset = datasets.ImageFolder(train_dir, transform=transforms.ToTensor()) 
    testset  = datasets.ImageFolder(test_dir, transform=transforms.ToTensor())
    
    trainloader = utils.data.DataLoader(trainset, pin_memory=True, shuffle=True, batch_size=100000)
    testloader  = utils.data.DataLoader(testset,  pin_memory=True, shuffle=True, batch_size=10000)
    
    _, (X_train, y_train)  = next(enumerate(trainloader))
    _, (X_test,  y_test)   = next(enumerate(testloader))
    
    X_train = X_train.numpy().astype(np.float64)
    X_test  = X_test.numpy().astype(np.float64)
    
    (X_train, X_test), global_ZCA = to_zca(X_train, X_test)

    X_train = np.transpose(X_train, (0,3,1,2))
    X_test  = np.transpose(X_test,  (0,3,1,2))

    return torch.tensor(X_train), y_train, torch.tensor(X_test), y_test

In [5]:
#!L
%%time
X_train, y_train, X_test, y_test = get_tiny_image_net_zca()

CPU times: user 2h 23min 37s, sys: 25min 49s, total: 2h 49min 26s
Wall time: 30min 44s


In [6]:
#!L
torch.save(X_train, train_dir + '/X_train.pt')
torch.save(y_train, train_dir + '/y_train.pt')
torch.save(X_test,  test_dir  + '/X_val.pt')
torch.save(y_test,  test_dir  + '/y_val.pt')

In [216]:
#!L
X_train.shape, X_test.shape

(torch.Size([100000, 64, 3, 64]), torch.Size([10000, 64, 3, 64]))

In [217]:
#!L
labels_train_full = y_train
labels_test_full  = y_test

X_train_full = X_train
X_test_full  = X_test

In [218]:
#!L 
X_train_full = X_train_full.permute((0, 2, 1, 3))
X_test_full  =  X_test_full.permute((0, 2, 1, 3))

In [219]:
#!L
def compute_kernels(models, X_train, X_test, device):
    with torch.no_grad():
        X_train = X_train.to(device)
        X_test  = X_test.to(device)

        n_train = len(X_train)
        n_test  = len(X_test)

        train_kernel = torch.zeros([n_train, n_train]).double().to(device)
        test_kernel  = torch.zeros([n_test,  n_train]).double().to(device)

        m = 0
        start_time = time.time()

        for model_i, model in enumerate(models):
            model = model.to(device)
            if model_i & (model_i - 1) == 0:
                print(f"{model_i} models done. time {time.time() - start_time:.0f}s")

            train_features = model.readout(X_train) 
            test_features  = model.readout(X_test)

            m += 1

            train_kernel += torch.matmul(train_features, train_features.T).double()
            test_kernel  += torch.matmul(test_features,  train_features.T).double()

        train_kernel /= m
        test_kernel  /= m

        return train_kernel.float(), test_kernel.float()

In [223]:
#!L
from linearized_nns.pytorch_impl.nns.primitives import *

class Myrtle9(nn.Module):
    def __init__(self, num_classes=1, input_filters=3, num_filters=1, groups=1):
        super(Myrtle9, self).__init__()
        filters = num_filters

        def Activation():
            return ReLU2()

        self.layers = nn.Sequential(
            Conv(input_filters, filters * groups), Activation(),
            Conv(filters, filters * 2, groups),    Activation(),
            nn.AvgPool2d(kernel_size=2, stride=2),

            Conv(filters * 2, filters * 4, groups), Activation(),
            Conv(filters * 4, filters * 8, groups), Activation(),
            nn.AvgPool2d(kernel_size=2, stride=2),

            Conv(filters *  8, filters * 16, groups), Activation(),
            Conv(filters * 16, filters * 32, groups), Activation(),
            nn.AvgPool2d(kernel_size=2, stride=2),
            
            Conv(filters *  32, filters * 32, groups), Activation(),
            Conv(filters *  32, filters * 32, groups), Activation(),
            nn.AvgPool2d(kernel_size=8, stride=8),

            Flatten(),
            Normalize(filters * 32)
        )
        self.classifier = nn.Linear(filters * 32 * groups, num_classes, bias=True)

    def readout(self, x):
        return self.layers(x)

    def forward(self, x):
        x = self.readout(x)
        return self.classifier(x)

In [250]:
#!L
N_train = 1280 * 10
N_test  = 1000

X_train = X_train_full[:N_train].float()
X_test  = X_test_full[:N_test].float()

In [251]:
#!L
labels_train = labels_train_full[:N_train] 
labels_test  = labels_test_full[:N_test]

num_classes = 200

y_train = to_one_hot(labels_train, num_classes).to(device)
y_test  = to_one_hot(labels_test,  num_classes).to(device)

In [253]:
#!L
X_train.shape, X_test.shape

(torch.Size([12800, 3, 64, 64]), torch.Size([1000, 3, 64, 64]))

In [254]:
#!L
train_kernel, test_kernel = compute_kernels(models, X_train, X_test, device)

train_kernel = train_kernel.float().to(device)
test_kernel  = test_kernel.float().to(device)

0 models done. time 0s
1 models done. time 0s
2 models done. time 1s
4 models done. time 1s
8 models done. time 3s
16 models done. time 6s
32 models done. time 12s
64 models done. time 23s
128 models done. time 46s
256 models done. time 92s
512 models done. time 185s
1024 models done. time 370s
2048 models done. time 741s
4096 models done. time 1481s


In [255]:
#!L
train_kernel[:5,:5]

tensor([[1.0000, 0.9841, 0.9956, 0.9929, 0.9942],
        [0.9841, 1.0000, 0.9854, 0.9878, 0.9904],
        [0.9956, 0.9854, 1.0000, 0.9955, 0.9926],
        [0.9929, 0.9878, 0.9955, 1.0000, 0.9921],
        [0.9942, 0.9904, 0.9926, 0.9921, 1.0000]], device='cuda:0')

In [256]:
#!L
test_kernel[:5,:5]

tensor([[0.9930, 0.9905, 0.9945, 0.9935, 0.9939],
        [0.9956, 0.9858, 0.9936, 0.9930, 0.9943],
        [0.9971, 0.9801, 0.9922, 0.9895, 0.9924],
        [0.9936, 0.9853, 0.9940, 0.9947, 0.9915],
        [0.9903, 0.9878, 0.9909, 0.9931, 0.9928]], device='cuda:0')

In [257]:
#!L
y_train.shape

torch.Size([12800, 200])

In [258]:
#!L
lr = 1e5

n = len(train_kernel)

exp_term = - lr * compute_exp_term(- lr * train_kernel, device)


In [259]:
#!L
y_train = y_train.to(device)

In [260]:
#!L
train_kernel.size()

torch.Size([12800, 12800])

In [261]:
#!L
exp_term.size()

torch.Size([12800, 12800])

In [262]:
#!L
labels_test.size()

torch.Size([1000])

In [263]:
#!L
y_pred = torch.matmul(test_kernel, torch.matmul(exp_term, - y_train))
(y_pred.argmax(dim=1) == labels_test.to(device)).float().mean()

tensor(0.2270, device='cuda:0')

In [264]:
#!L
import time
import numpy as np
import seaborn as sns
from PIL import Image
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, datasets
from torchvision.datasets import FashionMNIST

from linearized_nns.estimator import Estimator
from linearized_nns.pytorch_impl.estimators import SgdEstimator
from linearized_nns.pytorch_impl.nns import Myrtle5, Myrtle7, Myrtle10
from linearized_nns.pytorch_impl import ClassifierTraining
from linearized_nns.pytorch_impl.matrix_exp import matrix_exp, compute_exp_term
from linearized_nns.pytorch_impl.nns.utils import to_one_hot, print_sizes
from linearized_nns.from_neural_kernels import to_zca, CustomTensorDataset, get_cifar_zca

In [265]:
#!L
class GpEstimator(Estimator):
    def __init__(self, models, n_classes, learning_rate, x_example, device, groups=1):
        super(GpEstimator, self).__init__()
        self.models    = [model.to(device) for model in models]
        self.lr        = learning_rate
        self.n_classes = n_classes
        self.device    = device
        
        n = len(models)
        X = torch.stack([x_example]).to(device)
        
        model = models[0].to(device)
        readout_size = model.readout(X).size()[1]
    
        # TODO: Assert that models have the same readout size
        
        self.w      = torch.zeros([n, readout_size, n_classes]).to(device)
        self.w_size = n * groups
        
    def get_w_update(self, X, right_vector):
        with torch.no_grad():
            assert len(X) == len(right_vector)

            X            = X.to(self.device)
            right_vector = right_vector.to(self.device)

            n = len(X)
            w_updates = []
            
            for model in self.models:
                features = self.to_model_features(X, model)
                update = torch.matmul(features.T, right_vector)
                w_updates.append(update)
            return torch.stack(w_updates)
                                  
    def to_model_features(self, X, model):
        with torch.no_grad():
            model = model.to(device)
            return model.readout(X) * (1. / np.sqrt(self.w_size))
        
    def calc_kernel(self, X):
        with torch.no_grad():
            X = X.to(device)
            
            res = torch.zeros([len(X), len(X)]).to(device).double()
            for model in self.models:
                features_x = self.to_model_features(X, model)
                
                res += torch.matmul(features_x, features_x.T).double()
            return res.float()
        
    def calc_kernel_pred(self, X):
        with torch.no_grad():
            X = X.to(device)
            
            n = len(X)
            y_pred = torch.zeros([n, self.n_classes]).to(device).double()
            kernel = torch.zeros([len(X), len(X)]).to(device).double()
            
            for model, w in zip(self.models, self.w):
                features = self.to_model_features(X, model)
                
                kernel += torch.matmul(features, features.T).double()
                y_pred += torch.matmul(features, w).double()
            return kernel, y_pred.float()
        
    def calc_kernels(self, X_train, X_test):
        with torch.no_grad():
            X_train = X_train.to(device)
            X_test  = X_test.to(device)
            
            res_train = torch.zeros([len(X_train), len(X_train)]).to(device)
            res_test  = torch.zeros([len(X_test),  len(X_train)]).to(device)
            for model in self.models:
                features_train = self.to_model_features(X_train, model)
                features_test  = self.to_model_features(X_test,  model)
                
                res_train += torch.matmul(features_train, features_train.T)
                res_test  += torch.matmul(features_test,  features_train.T)
            return res_train, res_test
            
    def predict(self, X, cur_w=None):
        X = X.to(self.device)
        if cur_w is None:
            cur_w = self.w
         
        with torch.no_grad():
            n = len(X)
            res = torch.zeros([n, self.n_classes]).double().to(device)

            for model, w in zip(self.models, cur_w):
                features = self.to_model_features(X, model)
                res += torch.matmul(features, w).double()
            return res.float()

In [266]:
#!g1.1
def calc_right_vector(kernel, y, learning_rate=1e5, reg_param=0):
    with torch.no_grad():
        y      = y.to(device)
        kernel = kernel.to(device)
        
        n      = len(kernel)
        reg = torch.eye(n).to(device) * reg_param
        
        exp_term = - learning_rate * compute_exp_term(- learning_rate * (kernel + reg), device)
        right_vector = torch.matmul(exp_term, - y)
        return right_vector

In [289]:
#!g1.1
def boosting(estimator, train_loader, test_loader, learning_rate=1e5, beta=1., noise_rate=0., n_iter=10):
    output_kernel = False
    
    with torch.no_grad():
        batches_num = 0
        for _ in enumerate(train_loader):
            batches_num += 1
        n_batches = (batches_num * 2) // 3 + 1
            
        for iter_num  in range(n_iter):
            print(f"iter {iter_num} ==========================")
            
            w_update = 0
            iter_start = time.time()
            
            for batch_i, (X, y) in enumerate(train_loader):
                if batch_i >= n_batches:
                    break
                
                X = X.to(device)
                y = y.to(device)
                
                kernel, y_pred = estimator.calc_kernel_pred(X)
                if output_kernel:
                    print(f"kernel\n{kernel[:5,:5]}")
                
                y_residual = y_pred - y
                
                train_acc = (y_pred.argmax(dim=1) == y.argmax(dim=1)).float().mean().item()
                train_mse = (y_residual ** 2).mean().item()
                print(f"batch {batch_i}: train_acc {train_acc:.4f}, train_mse {train_mse:.6f}")
                
                right_vector = calc_right_vector(kernel, y_residual, learning_rate=learning_rate)
                
                w_update += estimator.get_w_update(X, right_vector).double()
                
                pred_change = torch.matmul(kernel, right_vector.double())
                

            w_update = (w_update / n_batches).float()
            
            _, (X, y) = next(enumerate(train_loader))
            y = y.to(device)
            
            y_pred      = estimator.predict(X)
            pred_change = estimator.predict(X, w_update)
            
            y_residual = (y_pred - y)
            
            w_std = w_update.std()
            noise = w_std * torch.randn(w_update.size()).to(device)
            estimator.w -= (w_update + noise_rate * noise) * beta 
            
            test_acc = 0
            for i, (X_test, labels) in enumerate(test_loader):
                y_pred = estimator.predict(X_test) 
                cur_acc   = (y_pred.argmax(dim=1) == labels.to(device)).float().mean().item()
                test_acc += (cur_acc - test_acc) / (i + 1)
                
            print(f"iter {iter_num} done. took {time.time() - iter_start:.0f}s. beta {beta:.3f}, test_acc {test_acc:.4f}")
            print()

In [271]:
#!g1.1

labels_train_full.shape, labels_test_full.shape

(torch.Size([100000]), torch.Size([10000]))

In [275]:
#!g1.1

train_set = CustomTensorDataset(X_train_full, to_one_hot(labels_train_full, num_classes))
test_set  = CustomTensorDataset(X_test_full,  labels_test_full)

train_loader = torch.utils.data.DataLoader(train_set, batch_size=1280 * 2)
test_loader  = torch.utils.data.DataLoader(test_set,  batch_size=1000)

In [290]:
#!g1.1

%state_exclude estimator
%state_exclude models

torch.manual_seed(0)
np.random.seed(0)

n_models = 1024
models = [Myrtle9(num_filters=1, groups=10) for _ in range(n_models)]

estimator = GpEstimator(models, num_classes, 0.2, X_train[0], device, groups=50)
boosting(estimator, train_loader, test_loader, learning_rate=1e5, n_iter=30)

batch 0: train_acc 0.0055, train_mse 1.000000
batch 1: train_acc 0.0066, train_mse 1.000000
batch 2: train_acc 0.0031, train_mse 1.000000
batch 3: train_acc 0.0039, train_mse 1.000000
batch 4: train_acc 0.0039, train_mse 1.000000
batch 5: train_acc 0.0055, train_mse 1.000000
batch 6: train_acc 0.0031, train_mse 1.000000
batch 7: train_acc 0.0051, train_mse 1.000000
batch 8: train_acc 0.0070, train_mse 1.000000
batch 9: train_acc 0.0043, train_mse 1.000000
batch 10: train_acc 0.0047, train_mse 1.000000
batch 11: train_acc 0.0051, train_mse 1.000000
batch 12: train_acc 0.0035, train_mse 1.000000
batch 13: train_acc 0.0035, train_mse 1.000000
batch 14: train_acc 0.0051, train_mse 1.000000
batch 15: train_acc 0.0051, train_mse 1.000000
batch 16: train_acc 0.0059, train_mse 1.000000
batch 17: train_acc 0.0074, train_mse 1.000000
batch 18: train_acc 0.0039, train_mse 1.000000
batch 19: train_acc 0.0047, train_mse 1.000000
batch 20: train_acc 0.0059, train_mse 1.000000
batch 21: train_acc 0.0