In [375]:
#!g1.1
import random
import time

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from sklearn.utils.extmath import randomized_svd
from torch import utils
from torch.optim import SGD
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
# ResNet/BasicBlock were referenced without an import; torchvision's signature matches the calls.
from torchvision.models.resnet import BasicBlock, ResNet

from linearized_nns.estimator import Estimator
from linearized_nns.from_neural_kernels import CustomTensorDataset, get_cifar_zca, to_zca
from linearized_nns.pytorch_impl import ClassifierTraining
from linearized_nns.pytorch_impl.estimators import SgdEstimator
from linearized_nns.pytorch_impl.matrix_exp import compute_exp_term, matrix_exp
from linearized_nns.pytorch_impl.nns import Myrtle5, Myrtle7, Myrtle10
from linearized_nns.pytorch_impl.nns.utils import print_sizes, to_one_hot

In [475]:
#!g1.1

# --- Experiment configuration ------------------------------------------------
SEED = 42                  # shared by numpy, random and torch seeding

DATA_DIR = 'imagenet-r'    # root folder read by datasets.ImageFolder
N = 30000                  # total image count (presumably len(ImageFolder(DATA_DIR)))
NUM_CLASSES = 200          # output width of the classifier head

DEVICE = 'cuda'

In [476]:
#!g1.1

# Seed every RNG the pipeline touches so the random split below is reproducible.
np.random.seed(SEED)
random.seed(SEED)
torch.manual_seed(SEED)

# random_split returns Subsets of one dataset, so this transform serves BOTH train and test.
# NOTE(review): RandomCrop(22) takes a random 22x22 patch of the 224x224 centre crop, and
# being random it also makes test-set evaluation nondeterministic -- confirm this is intended.
val_transforms = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.RandomCrop(22),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

dataset = datasets.ImageFolder(DATA_DIR, transform=val_transforms)

# Hold out 10% for testing. Sizes are derived from the dataset because random_split
# requires them to sum to len(dataset); for N = 30000 images this is exactly 27000/3000.
n_test = len(dataset) // 10
trainset, testset = random_split(dataset, lengths=[len(dataset) - n_test, n_test])

NameError: name 'np' is not defined

In [None]:
#!g1.1
class ZcaTorch:
    """Per-sample normalisation followed by ZCA whitening, computed on ``device``.

    Args:
        V: whitening basis; flattened samples are projected with ``V.T`` and mapped
            back with ``V`` (presumably shape (k, d), d = features per sample -- TODO confirm).
            numpy arrays and torch tensors are both accepted.
        inv_sqrt_zca_eigs: inverse square roots of the ZCA eigenvalues, either as a
            (k, k) matrix or as a length-k vector (treated as the matrix diagonal).
        device: torch device on which the whitening matrices live and the work is done.
    """

    # Lower bound on the per-sample norm; avoids 0/0 = NaN for constant samples.
    _NORM_EPS = 1e-12

    def __init__(self, V, inv_sqrt_zca_eigs, device):
        # as_tensor accepts ndarrays and tensors (torch.tensor(tensor) emits a copy warning).
        self.V = torch.as_tensor(V, dtype=torch.float, device=device)
        inv_sqrt = torch.as_tensor(inv_sqrt_zca_eigs, dtype=torch.float, device=device)
        if inv_sqrt.dim() == 1:
            inv_sqrt = torch.diag(inv_sqrt)
        self.inv_sqrt_zca_eigs = inv_sqrt
        self.device = device

    def apply(self, batch):
        """Return a whitened copy of ``batch`` with the same shape.

        Every sample (first dimension) is flattened, mean-centred, scaled to unit
        L2 norm and mapped through ``V.T @ inv_sqrt_zca_eigs @ V``. The result is a
        float32 tensor on ``self.device``; ``batch`` itself is left untouched.
        """
        orig_shape = batch.shape
        flat = batch.reshape(orig_shape[0], -1).to(device=self.device, dtype=torch.float)

        # normalize -- out-of-place on purpose: reshape/.to may return a view of the
        # caller's tensor, so in-place -= and /= would overwrite the input batch.
        flat = flat - flat.mean(dim=1, keepdim=True)
        flat = flat / flat.norm(dim=1, keepdim=True).clamp_min(self._NORM_EPS)

        # apply zca
        whitened = flat @ self.V.T @ self.inv_sqrt_zca_eigs @ self.V
        return whitened.reshape(orig_shape)

In [None]:
#!g1.1
# Whitening statistics (presumably fitted on the training split -- see file names).
# torch.load unpickles its input: only load files from a trusted source.
V, inv_sqrt_zca_eigs = (
    torch.load(fname)
    for fname in ('train_zca_V.pt', 'train_zca_inv_sqrt_zca_eigs.pt')
)

In [None]:
#!g1.1
# Sanity check: display the shapes of the loaded whitening matrices.
(V.shape, inv_sqrt_zca_eigs.shape)

In [None]:
#!g1.1
# NOTE(review): nothing later in this notebook calls zca.apply on the batches --
# confirm whether whitening was meant to be part of the training pipeline.
zca = ZcaTorch(V=V, inv_sqrt_zca_eigs=inv_sqrt_zca_eigs, device=DEVICE)

In [None]:
#!g1.1
# ResNet-18 layout: two BasicBlocks per stage; nn.Module.to returns the module itself.
resnet18 = ResNet(BasicBlock, [2, 2, 2, 2], num_classes=NUM_CLASSES).to(DEVICE)
resnet18

In [None]:
#!g1.1
# Only the training data is shuffled; evaluation iterates in a fixed order.
train_loader = DataLoader(trainset, batch_size=128, shuffle=True, pin_memory=False)
test_loader = DataLoader(testset, batch_size=100, shuffle=False, pin_memory=False)

In [None]:
#!g1.1
def _train_epoch(model, loader, optimizer, criterion, device, log_every=10):
    """Run one optimisation pass over ``loader``; return sample-weighted (accuracy, loss)."""
    model.train()
    n_seen, n_correct, loss_sum = 0, 0, 0.0
    for step, (X, y) in enumerate(loader, start=1):
        X, y = X.to(device), y.to(device)

        optimizer.zero_grad()
        logits = model(X)  # __call__ rather than .forward so module hooks run
        loss = criterion(logits, y)
        loss.backward()
        optimizer.step()

        batch_size = y.size(0)
        n_seen += batch_size
        n_correct += (logits.argmax(dim=-1) == y).sum().item()
        # criterion averages over the batch; rescale so a short last batch is weighted correctly
        loss_sum += loss.item() * batch_size
        if step % log_every == 0:
            print(f'\rtrain_acc {n_correct / n_seen:.4f} train_loss {loss_sum / n_seen:.4f}', end='')
    denom = max(n_seen, 1)  # empty loader -> metrics of 0 instead of ZeroDivisionError
    return n_correct / denom, loss_sum / denom


@torch.no_grad()
def _evaluate(model, loader, criterion, device):
    """Evaluate without building autograd graphs; return sample-weighted (accuracy, loss)."""
    model.eval()
    n_seen, n_correct, loss_sum = 0, 0, 0.0
    for X, y in loader:
        X, y = X.to(device), y.to(device)
        logits = model(X)
        batch_size = y.size(0)
        n_seen += batch_size
        n_correct += (logits.argmax(dim=-1) == y).sum().item()
        loss_sum += criterion(logits, y).item() * batch_size
    denom = max(n_seen, 1)
    return n_correct / denom, loss_sum / denom


def train(model, train_loader, test_loader, num_epochs=30, learning_rate=0.1, device=None):
    """Train ``model`` with plain SGD and cross-entropy, evaluating after every epoch.

    Args:
        model: classifier expected to return class logits for each input batch.
        train_loader: iterable of ``(X, y)`` batches used for optimisation.
        test_loader: iterable of ``(X, y)`` batches evaluated after each epoch.
        num_epochs: number of passes over ``train_loader``.
        learning_rate: SGD step size.
        device: where batches are moved; defaults to the device of the model's
            first parameter.

    Returns:
        One dict per epoch with ``train_acc``, ``train_loss``, ``test_acc`` and
        ``test_loss``, each averaged over samples (not over batches).
    """
    if device is None:
        device = next(model.parameters()).device
    optimizer = SGD(model.parameters(), lr=learning_rate)
    criterion = nn.CrossEntropyLoss()

    history = []
    for epoch in range(1, num_epochs + 1):
        print(f'Epoch {epoch}:')
        train_acc, train_loss = _train_epoch(model, train_loader, optimizer, criterion, device)
        # always print the final epoch figures, even if the batch count is not a multiple of 10
        print(f'\rtrain_acc {train_acc:.4f} train_loss {train_loss:.4f}')
        test_acc, test_loss = _evaluate(model, test_loader, criterion, device)
        print(f'test_acc {test_acc:.4f} test_loss {test_loss:.4f}')
        print()
        history.append({'train_acc': train_acc, 'train_loss': train_loss,
                        'test_acc': test_acc, 'test_loss': test_loss})
    return history

# Start from a freshly initialised network (rebinds resnet18; nn.Module.to returns self).
model = resnet18 = ResNet(BasicBlock, [2, 2, 2, 2], num_classes=NUM_CLASSES).to(DEVICE)
train(model, train_loader, test_loader);