In [1]:
import os
import shutil
import zipfile
import urllib.request


def download_repo(url, save_to):
    """Download a zip archive and unpack it into the current directory.

    The archive is written to ``save_to + '.zip'`` and extracted into ``'.'``.
    An already existing ``save_to`` path is removed first, so the extracted
    tree never mixes with files from an earlier run.

    Args:
        url: Location of the zip archive to fetch.
        save_to: Path expected to exist once the archive has been extracted
            (for a GitHub branch archive this is the archive's top-level
            folder).

    Raises:
        AssertionError: if ``save_to`` does not exist after extraction.
    """
    archive_path = save_to + '.zip'
    urllib.request.urlretrieve(url, archive_path)

    # Remove leftovers from a previous extraction before unpacking again.
    already_there = os.path.exists(save_to)
    if already_there:
        shutil.rmtree(save_to)

    with zipfile.ZipFile(archive_path, 'r') as archive:
        archive.extractall('.')

    assert os.path.exists(save_to)

In [2]:
# Path that must exist after the archive is unpacked; download_repo asserts on it.
REPO_PATH = 'LinearizedNNs-master'

download_repo(
    'https://github.com/maxkvant/LinearizedNNs/archive/master.zip',
    REPO_PATH,
)

In [3]:
import sys

# Put the repository's `src` folder on the module search path.
sys.path.append(REPO_PATH + "/src")

In [4]:
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torchvision import transforms, datasets
from torchvision.datasets import FashionMNIST

from sklearn.ensemble import RandomForestClassifier
from sklearn.svm import SVC
from sklearn.linear_model import RidgeClassifier
from sklearn.decomposition import PCA

from pytorch_impl.nns import ResNet, FCN, CNN
from pytorch_impl.nns import warm_up_batch_norm
from pytorch_impl.estimators import LinearizedSgdEstimator, SgdEstimator, MatrixExpEstimator, GradientBoostingEstimator
from pytorch_impl import ClassifierTraining
from pytorch_impl.matrix_exp import matrix_exp, compute_exp_term
from pytorch_impl.nns.utils import to_one_hot

In [5]:
num_classes = 10
# Use the first CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda', index=0)

In [6]:
# compute M^-1 * (exp(M) - E)
def compute_exp_term(M, device, n_iter=3):
    """Return ``M^-1 * (exp(M) - E)`` for a square matrix ``M`` (E = identity).

    The matrix is halved until its Frobenius norm is at most ``1e-9``, a short
    power series is evaluated on the scaled matrix, and the result is then
    doubled back ``steps`` times using

        phi(2X) = (phi(X) * exp(X) + phi(X)) / 2,    exp(2X) = exp(X) * exp(X)

    where ``phi(X) = X^-1 (exp(X) - E)``.

    Note: this definition shadows the ``compute_exp_term`` imported from
    ``pytorch_impl.matrix_exp`` earlier in the notebook.

    Args:
        M: square 2-D tensor. It is never modified; the computation works on a
            private float64 copy.
        device: device on which the computation runs.
        n_iter: controls the series length; the truncated series keeps
            ``n_iter - 1`` terms (``E + M/2`` for the default of 3).

    Returns:
        A float64 tensor on ``device`` with the same shape as ``M``.

    Raises:
        ValueError: if the Frobenius norm of ``M`` is infinite (the scaling
            loop could never terminate).
    """
    with torch.no_grad():
        # copy=True guarantees a private buffer even when M is already a
        # float64 tensor on `device`, so the in-place scaling below can never
        # leak into the caller's matrix.
        M = M.to(device=device, dtype=torch.float64, copy=True)

        n = M.size()[0]

        # Count the halvings on a Python float: one device sync instead of one
        # per loop iteration, and no repeated passes over the whole matrix.
        norm = torch.sqrt((M ** 2).sum()).item()
        if norm == float('inf'):
            raise ValueError("compute_exp_term: matrix norm is infinite, cannot rescale")

        steps = 0
        while norm > 1e-9:
            norm /= 2.
            steps += 1

        # Scaling by a power of two is exact, so a single multiplication is
        # equivalent to `steps` successive halvings.
        M.mul_(0.5 ** steps)

        identity = torch.eye(n, dtype=torch.float64, device=device)
        series_sum = torch.zeros_like(identity)
        prod = identity

        # series_sum: E + M / 2 + M^2 / 6 + ...
        for i in range(1, n_iter):
            series_sum = series_sum + prod
            prod = torch.matmul(prod, M) / (i + 1)

        # (exp 0) (exp 0) = (exp^2           0)
        # (sum E) (sum E) = (sum * exp + sum E)
        exp = torch.matmul(M, series_sum) + identity
        for _ in range(steps):
            series_sum = (torch.matmul(series_sum, exp) + series_sum) / 2.
            exp = torch.matmul(exp, exp)

        return series_sum

In [15]:
kernels_12k = np.load('../data/kernels_12k.npz')

# Pull the four arrays out of the archive in one go.
train_kernel, test_kernel, labels_train, labels_test = (
    kernels_12k[key]
    for key in ('train_kernel', 'test_kernel', 'labels_train', 'labels_test')
)

# Display the shapes as a quick sanity check.
tuple(array.shape for array in (train_kernel, test_kernel, labels_train, labels_test))

((12800, 12800), (10000, 12800), (12800,), (10000,))

In [16]:
# Kernel matrices become float32 tensors on the compute device.
train_kernel, test_kernel = (
    torch.from_numpy(kernel).float().to(device)
    for kernel in (train_kernel, test_kernel)
)

# Labels keep their stored dtype and are only moved to the device.
labels_test, labels_train = (
    torch.from_numpy(labels).to(device)
    for labels in (labels_test, labels_train)
)

In [22]:
# Regression targets for the kernel solve below: presumably one row per training
# sample with `num_classes` columns (to_one_hot comes from pytorch_impl — confirm).
y_train = to_one_hot(labels_train, num_classes).to(device)

In [23]:
lr = 1e4

n = train_kernel.shape[0]
# Regularisation term added to the kernel; the coefficient is currently zero.
reg = 0e-4 * torch.eye(n).to(device)
exp_term = - lr * compute_exp_term(- lr * (train_kernel + reg), device).float()

# Test-set scores: test_kernel @ exp_term @ (-y_train).
y_pred = test_kernel @ (exp_term @ (- y_train))

# Release the large matrix as soon as the predictions are computed.
del exp_term

# Test accuracy.
y_pred.argmax(dim=1).eq(labels_test).float().mean()

tensor(0.8012, device='cuda:0')

In [24]:
# Replace the targets with the first 12800 rows stored in cifar10_targets_2.npz.
y_train = torch.tensor(
    np.load('../data/cifar10_targets_2.npz')['targets'][:12800]
).to(device)

In [25]:
lr = 1e4

n = train_kernel.shape[0]
# Regularisation term added to the kernel; the coefficient is currently zero.
reg = 0e-4 * torch.eye(n).to(device)
exp_term = - lr * compute_exp_term(- lr * (train_kernel + reg), device).float()

# Same prediction as before, now using the targets currently held in y_train.
y_pred = test_kernel @ (exp_term @ (- y_train))

# Release the large matrix as soon as the predictions are computed.
del exp_term

# Test accuracy.
y_pred.argmax(dim=1).eq(labels_test).float().mean()

tensor(0.8044, device='cuda:0')

In [26]:
# Drop the kernel tensors so their memory can be reclaimed.
del train_kernel, test_kernel

In [27]:
def matmul_via_torch(numpy_matrix, torch_matrix, step=2048):
    """Multiply a NumPy matrix by a torch matrix, streaming the rows in blocks.

    Only ``step`` rows of ``numpy_matrix`` are converted and moved to the
    device at a time, so the left operand never has to fit on the device as a
    whole.

    Args:
        numpy_matrix: 2-D NumPy array of shape ``(n, m)``.
        torch_matrix: 2-D tensor of shape ``(m, k)``. Its device is the one the
            computation runs on and the one the result lives on.
        step: number of rows of ``numpy_matrix`` processed per block.

    Returns:
        Tensor of shape ``(n, k)`` on ``torch_matrix.device``. Each block
        product is computed in float64, but the result buffer uses torch's
        default dtype, so values are cast down when stored.

    Raises:
        AssertionError: if the inner dimensions do not match.
    """
    with torch.no_grad():
        n,  m = numpy_matrix.shape
        m2, k = torch_matrix.size()
        assert m2 == m

        # Follow the right operand's device rather than a notebook-level
        # global; the product is only defined when both operands share it.
        target_device = torch_matrix.device

        # The right operand is the same for every block: convert it once.
        right = torch_matrix.double()

        result = torch.zeros([n, k]).to(target_device)
        for start in range(0, n, step):
            stop = min(start + step, n)
            left_block = torch.from_numpy(numpy_matrix[start:stop]).double().to(target_device)
            result[start:stop] = torch.matmul(left_block, right)
        return result

In [28]:
def _boosting_report(train_kernel, test_kernel, right_vector, y_train, labels_train, labels_test):
    """Score ``right_vector``; returns ``(train_acc, test_acc, train_mse, y_residual)``."""
    y_pred_train = matmul_via_torch(train_kernel, right_vector)
    y_pred_test  = matmul_via_torch(test_kernel, right_vector)
    train_acc = (y_pred_train.argmax(dim=1) == labels_train).float().mean().item()
    test_acc  = (y_pred_test.argmax(dim=1)  == labels_test).float().mean().item()

    y_residual = y_pred_train - y_train
    train_mse = (y_residual ** 2).sum(dim=1).mean().item()
    return train_acc, test_acc, train_mse, y_residual


def boosting(train_kernel, y_train, labels_train, test_kernel, labels_test, n_iter=24, lr=1e4, flips=False):
    """Fit kernel-regression coefficients by boosting over random kernel blocks.

    Each iteration:
      1. evaluates the current coefficients and prints train/test accuracy and
         the train MSE against ``y_train``;
      2. draws a random permutation of the samples and, for each of the first
         ``n_blocks`` chunks of ``block_size`` indices, computes an update for
         that chunk's coefficients from the chunk's kernel sub-matrix via
         ``compute_exp_term``;
      3. scales the combined update by ``beta``, the least-squares step length
         that best cancels the current residual, and adds it to the
         coefficients.

    After the last iteration the metrics are printed once more. Nothing is
    returned; the function only prints.

    Args:
        train_kernel: square NumPy array; it is indexed with NumPy index
            arrays and streamed through ``matmul_via_torch``.
        y_train: target tensor subtracted from the train predictions
            (presumably shape ``(n, num_classes)`` on ``device`` — confirm).
        labels_train: integer class labels compared with the argmax of the
            train predictions.
        test_kernel: NumPy array multiplied with the coefficients to obtain
            test predictions.
        labels_test: integer class labels for the test predictions.
        n_iter: number of boosting iterations.
        lr: factor applied to each kernel block before ``compute_exp_term``.
        flips: when true only ``n // 2`` base indices are permuted and each is
            shifted by ``n // 2`` according to ``rand_bool``.

    Relies on the notebook globals ``device`` and ``num_classes``.
    """
    with torch.no_grad():
        n = len(train_kernel)

        block_size = 1280 * 2

        right_vector = torch.zeros([n, num_classes]).double().to(device)

        n_actual = (n // 2) if flips else n

        # About two thirds of the samples are visited per iteration.
        n_blocks = (2 * n_actual) // (3 * block_size) + 1
        print(n_blocks)

        for iter_num in range(n_iter):
            index = torch.randperm(n_actual).to(device)

            if flips:
                # NOTE(review): `rand_bool` is not defined in the visible part
                # of this notebook; it must exist before calling with flips=True.
                index += n_actual * rand_bool(n_actual)

            train_acc, test_acc, train_mse, y_residual = _boosting_report(
                train_kernel, test_kernel, right_vector, y_train, labels_train, labels_test)

            print(f"iteration {iter_num} train_acc {train_acc} test_acc {test_acc} train_mse {train_mse}")

            d_right_vector = torch.zeros([n, num_classes]).double().to(device)

            for i in range(n_blocks):
                batch_index = index[i * block_size: (i + 1) * block_size]
                batch_index_np = batch_index.cpu().numpy()

                # np.ix_ selects the sub-matrix directly instead of first
                # copying `block_size` full rows of the kernel.
                K = train_kernel[np.ix_(batch_index_np, batch_index_np)]
                K = torch.from_numpy(K).double().to(device)

                # Small diagonal shift added to the block before exponentiation.
                K = K + 1e-4 * torch.eye(len(K)).double().to(device)

                exp_term = - lr * compute_exp_term(- lr * K, device)
                d_right_vector[batch_index] = torch.matmul(exp_term, y_residual[batch_index].double()) / n_blocks

            print(f"batches {0}-{n_blocks - 1} done")

            # Least-squares step length along the proposed update.
            pred_change = matmul_via_torch(train_kernel, d_right_vector)
            denominator = (pred_change ** 2).sum()
            if denominator.item() == 0:
                # The update does not change the train predictions; a 0/0 step
                # would turn every coefficient into NaN, so take no step.
                beta = torch.zeros_like(denominator)
            else:
                beta = (- y_residual * pred_change).sum() / denominator
            print(f"beta = {beta}")

            right_vector += d_right_vector * beta

            print()

        # Final report uses the same residual definition (against y_train) as
        # the per-iteration reports, so the printed train_mse values are comparable.
        train_acc, test_acc, train_mse, _ = _boosting_report(
            train_kernel, test_kernel, right_vector, y_train, labels_train, labels_test)

        print(f"iteration {n_iter} train_acc {train_acc} test_acc {test_acc} train_mse {train_mse}")

In [29]:
%%time

kernels_50k = np.load('../data/kernels_50k.npz')

# The kernel matrices are kept as NumPy arrays; only the labels are turned
# into tensors on the compute device.
train_kernel = kernels_50k['train_kernel']
test_kernel  = kernels_50k['test_kernel']
labels_train = torch.from_numpy(kernels_50k['labels_train']).to(device)
labels_test  = torch.from_numpy(kernels_50k['labels_test']).to(device)

# Targets for the full training set.
y_train = torch.tensor(
    np.load('../data/cifar10_targets.npz')['targets']
).to(device)

CPU times: user 35.5 s, sys: 35.5 s, total: 1min 10s
Wall time: 16min 57s


In [30]:
# Run the block-boosting fit on the 50k kernels.
boosting(
    train_kernel=train_kernel,
    y_train=y_train,
    labels_train=labels_train,
    test_kernel=test_kernel,
    labels_test=labels_test,
    n_iter=128,
    lr=1e5,
)

14
iteration 0 train_acc 0.09999999403953552 test_acc 0.09999999403953552 train_mse 66.65372467041016
batches 0-13 done
beta = 1.1174983978271484

iteration 1 train_acc 0.8289799690246582 test_acc 0.7734000086784363 train_mse 22.691205978393555
batches 0-13 done
beta = 2.3555779457092285

iteration 2 train_acc 0.9208199977874756 test_acc 0.8316999673843384 train_mse 16.630090713500977
batches 0-13 done
beta = 1.6244648694992065

iteration 3 train_acc 0.9400999546051025 test_acc 0.8342999815940857 train_mse 13.46646499633789
batches 0-13 done
beta = 1.941909670829773

iteration 4 train_acc 0.9642799496650696 test_acc 0.8499000072479248 train_mse 11.432799339294434
batches 0-13 done
beta = 1.7602756023406982

iteration 5 train_acc 0.9706999659538269 test_acc 0.849299967288971 train_mse 9.808686256408691
batches 0-13 done
beta = 1.893420934677124

iteration 6 train_acc 0.9807199835777283 test_acc 0.8572999835014343 train_mse 8.517805099487305
batches 0-13 done
beta = 1.81419038772583

ite