In [1]:
import numpy as np
import matplotlib.pyplot as plt
from copy import deepcopy
import torch
from sklearn.datasets import load_digits
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms

In [2]:
class Digits(Dataset):
    """Scikit-Learn Digits dataset wrapped as a PyTorch ``Dataset``.

    The arrays returned by ``load_digits`` are split at row ``split``:
    ``mode='train'`` keeps rows ``[:split]``; any other mode keeps ``[split:]``.
    Data and targets are sliced with the same index so labels stay aligned.

    Args:
        mode: ``'train'`` for the first ``split`` rows; anything else selects
            the remaining rows.
        transforms: optional callable applied to each sample in ``__getitem__``.
        split: row index at which the train/test split is made (default 1000).
    """

    def __init__(self, mode='train', transforms=None, split=1000):
        digits = load_digits()
        # Use one slice for both arrays; previously only the data was sliced,
        # so in test mode targets[idx] labelled a different sample.
        rows = slice(None, split) if mode == 'train' else slice(split, None)
        self.data = digits.data[rows].astype(np.float32)
        self.targets = digits.target[rows]
        self.transforms = transforms

    def __len__(self):
        """Number of samples in the selected split."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(sample, target)``, applying ``transforms`` to the sample if set."""
        sample = self.data[idx]
        target = self.targets[idx]

        if self.transforms:
            sample = self.transforms(sample)
        return sample, target
    
def mnist_loader(batch_size: int, n_train: int = 5000,
                 n_val: int = 1000) -> "tuple[DataLoader, DataLoader]":
    """Build MNIST train/validation loaders over small fixed subsets.

    Both splits are downloaded to ``./data`` if missing, converted to tensors
    and normalised with mean 0.1307 / std 0.3081. Only the first ``n_train``
    training images and the first ``n_val`` test images are kept.

    Args:
        batch_size: batch size for both loaders.
        n_train: number of leading training images to keep (default 5000).
        n_val: number of leading test images to keep (default 1000).

    Returns:
        ``(train_loader, val_loader)``; the train loader shuffles, the
        validation loader does not.
    """
    # Identical preprocessing for both splits; the Compose is stateless, so it is shared.
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])
    train_set = datasets.MNIST('data', train=True, transform=preprocess, download=True)
    val_set = datasets.MNIST('data', train=False, transform=preprocess, download=True)

    # Deterministic subsets (leading indices) keep the experiment small.
    train_set = torch.utils.data.Subset(train_set, range(n_train))
    val_set = torch.utils.data.Subset(val_set, range(n_val))

    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True,
                              num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False,
                            num_workers=8, pin_memory=True)
    return train_loader, val_loader

def digits_loader(batch_size: int):
    """Build train/test loaders for the scikit-learn Digits dataset.

    Each flat pixel vector is rescaled by ``2 * (x / 16) - 1``. ``load_digits``
    pixel values are integers in 0..16, so this maps them onto [-1, 1].
    The previous divisor of 17 only reached about 0.88.

    Args:
        batch_size: batch size for both loaders.

    Returns:
        ``(train_loader, test_loader)``; only the train loader shuffles.
    """
    transform = transforms.Lambda(lambda x: 2. * (x / 16.) - 1.)
    train_data = Digits(mode='train', transforms=transform)
    test_data = Digits(mode='test', transforms=transform)

    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, pin_memory=True)
    test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False, pin_memory=True)

    return train_loader, test_loader


class MLP(nn.Module):
    """Three-layer fully connected classifier.

    Two ReLU hidden layers of width ``hidden_dim`` are followed by a bias-free
    linear read-out to ``num_classes`` logits. Inputs are reshaped to
    ``(-1, input_dim)`` before the first layer.
    """

    def __init__(self, input_dim: int, num_classes: int, hidden_dim: int = 32) -> None:
        super().__init__()
        self.input_dim = input_dim
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        # No bias on the read-out: logits are a purely linear map of features().
        self.fc3 = nn.Linear(hidden_dim, num_classes, bias=False)

    def features(self, x):
        """Return the penultimate activations (post-ReLU, width ``hidden_dim``)."""
        flat = x.view(-1, self.input_dim)
        hidden = F.relu(self.fc1(flat))
        return F.relu(self.fc2(hidden))

    def forward(self, x):
        """Return class logits: the read-out layer applied to ``features(x)``."""
        return self.fc3(self.features(x))

In [3]:
def get_features(model, loader, device=None):
    """Collect ``model.features`` outputs and labels for every sample in ``loader``.

    Puts ``model`` in eval mode (and leaves it there) and runs it under
    ``torch.no_grad()``.

    Args:
        model: module exposing a ``features(x)`` method.
        loader: iterable of ``(data, target)`` batches.
        device: device each batch is moved to. Defaults to the device of the
            model's first parameter, or CPU if it has none.

    Returns:
        ``(embeddings, labels)`` as float64 NumPy arrays of shape ``(N, D)`` and
        ``(N,)``, in loader order. ``D`` is taken from the features themselves.
        An empty loader yields arrays of shape ``(0, 0)`` and ``(0,)``.
    """
    model.eval()
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    emb_chunks, label_chunks = [], []
    with torch.no_grad():
        for data, target in loader:
            embs = model.features(data.to(device))
            emb_chunks.append(embs.cpu().numpy())
            label_chunks.append(np.asarray(target))

    if not emb_chunks:
        return np.zeros((0, 0)), np.zeros(0)

    # Concatenate once rather than growing arrays per batch (quadratic copying).
    # float64 matches the dtype the previous zero-initialised accumulators produced.
    embeddings = np.concatenate(emb_chunks).astype(np.float64)
    labels = np.concatenate(label_chunks).astype(np.float64)
    return embeddings, labels

def train_epoch(epoch, train_loader, model=None, optimizer=None,
                loss_function=None, device=None):
    """Run one optimisation pass over ``train_loader``.

    ``model``, ``optimizer``, ``loss_function`` and ``device`` default to the
    notebook-level globals of the same names, so ``train_epoch(epoch, loader)``
    keeps working. Passing them explicitly removes the hidden-state dependency.

    Args:
        epoch: epoch index (unused; kept for call compatibility).
        train_loader: iterable of ``(x, y)`` batches.
        model, optimizer, loss_function, device: optional overrides of the globals.

    Returns:
        Mean per-sample training loss over the epoch (0.0 for an empty loader).
        This assumes ``loss_function`` averages over the batch, as the
        ``nn.CrossEntropyLoss()`` default used in this notebook does.
    """
    g = globals()
    model = g['model'] if model is None else model
    optimizer = g['optimizer'] if optimizer is None else optimizer
    loss_function = g['loss_function'] if loss_function is None else loss_function
    device = g['device'] if device is None else device

    model.train()
    total_loss, n_seen = 0.0, 0
    for x, y in train_loader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad(set_to_none=True)

        loss = loss_function(model(x), y)
        loss.backward()
        optimizer.step()

        # loss is a batch mean; weight by batch size for a per-sample average.
        total_loss += loss.item() * x.shape[0]
        n_seen += x.shape[0]

    return total_loss / n_seen if n_seen else 0.0
    
        
def test_epoch(epoch, test_loader, output_epochs = 10):        
    model.eval()
    loss = 0
    acc = 0
    N = 0
    for data, labels in test_loader:
        data, labels = data.to(device), labels.to(device)
        logits = model(data)
        
        loss_t = loss_function(logits, labels)
        acc_t = 100 * torch.sum(torch.argmax(logits,dim = 1) == labels) / len(labels)


        loss += loss_t.item()
        acc += acc_t
        N += data.shape[0]

    loss = loss / N
    acc = acc / len(test_loader)
    
    if epoch % output_epochs == 0:
        print(f'Epoch: {epoch}, CE = {loss}, ACC = {acc}')
  
    return loss

In [4]:
## Reproducibility
SEED = 0
torch.manual_seed(SEED)  # fixes weight initialisation and loader shuffling

## Select dataset and build loaders
dataset = 'mnist'
batch_size = 32

if dataset == 'digits':
    train_loader, test_loader = digits_loader(batch_size=batch_size)
    data_width = 8
    output_epochs = 1
elif dataset == 'mnist':
    train_loader, test_loader = mnist_loader(batch_size=batch_size)
    data_width = 28
    output_epochs = 1
else:
    raise ValueError(f"Unknown dataset {dataset!r}; expected 'digits' or 'mnist'")
data_dim = data_width ** 2  # square images, flattened for the MLP
num_classes = 10

## Define hyperparameters
hidden_size = 32  # width of both hidden layers in the model
lr = 0.01
num_epochs = 5
loss_function = nn.CrossEntropyLoss()

### Define model and optimizer
# Kept on CPU: some later cells convert model tensors to NumPy directly.
device = 'cpu'
model = MLP(input_dim=data_dim, num_classes=num_classes, hidden_dim=hidden_size).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

## Train, keeping a copy of the weights with the lowest test loss
best_loss = float('inf')
best_model = None
epochs_since_improvement = 0
for epoch in range(num_epochs):
    train_epoch(epoch, train_loader)
    test_loss = test_epoch(epoch, test_loader, output_epochs=output_epochs)
    if test_loss < best_loss:
        # Track the new best; previously best_loss was never updated, so
        # best_model always ended up holding the last epoch's weights.
        best_loss = test_loss
        best_model = deepcopy(model.state_dict())
        epochs_since_improvement = 0
    else:
        epochs_since_improvement += 1


  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


Epoch: 0, CE = 0.013887106120586395, ACC = 85.83984375
Epoch: 1, CE = 0.013362858057022095, ACC = 87.20703125
Epoch: 2, CE = 0.015732753075659276, ACC = 85.7421875
Epoch: 3, CE = 0.013607118770480157, ACC = 87.20703125
Epoch: 4, CE = 0.013322155103087424, ACC = 88.76953125


In [11]:
def show_first_batch(title, loader):
    """Print the title, then argmax predictions, labels and loss for the first batch of ``loader``."""
    print(title)
    batch = next(iter(loader), None)
    if batch is None:  # empty loader: nothing to show
        return
    x, y = batch
    with torch.no_grad():  # inspection only; no autograd graph needed
        x, y = x.to(device), y.to(device)
        output = model(x)
        print(torch.argmax(output, dim=1))
        print(y)
        print(loss_function(output, y))


show_first_batch('Train Loader', train_loader)
print('\n\n')  # three blank lines between the two reports
show_first_batch('Test Loader', test_loader)

Train Loader
tensor([9, 7, 5, 4, 7, 6, 9, 6, 2, 2, 1, 5, 9, 4, 4, 1, 0, 1, 7, 4, 0, 9, 7, 3,
        4, 9, 6, 5, 0, 4, 7, 3])
tensor([9, 7, 5, 4, 7, 6, 9, 6, 2, 2, 1, 5, 9, 4, 4, 1, 0, 1, 7, 4, 0, 9, 7, 8,
        9, 9, 6, 5, 0, 4, 7, 3])
tensor(0.1903, grad_fn=<NllLossBackward>)



Test Loader
tensor([7, 2, 1, 0, 4, 1, 4, 9, 6, 9, 0, 6, 9, 0, 1, 5, 9, 7, 3, 4, 9, 6, 6, 5,
        4, 0, 7, 4, 0, 1, 3, 1])
tensor([7, 2, 1, 0, 4, 1, 4, 9, 5, 9, 0, 6, 9, 0, 1, 5, 9, 7, 3, 4, 9, 6, 6, 5,
        4, 0, 7, 4, 0, 1, 3, 1])
tensor(0.3076, grad_fn=<NllLossBackward>)


In [12]:
# Feature kernels: Gram matrices of the penultimate-layer features, one row per sample.

def _gram(feats):
    """Return ``feats @ feats.T``, the matrix of pairwise inner products between rows."""
    return feats @ feats.T

# Train set
feats_train, labels_train = get_features(model, train_loader)
FK = _gram(feats_train)

# Test set
feats_test, labels_test = get_features(model, test_loader)
FK_test = _gram(feats_test)

In [13]:
%%time
# Let's decompose the feature kernel. This is kernel PCA on the data with a kernel learned via features
U,S,V = np.linalg.svd(FK)

CPU times: user 1min 14s, sys: 855 ms, total: 1min 15s
Wall time: 38.8 s


In [14]:
%%time
# Decompose the test kernel just cause we can
U_test,S_test,V_test = np.linalg.svd(FK_test)

CPU times: user 565 ms, sys: 36.4 ms, total: 601 ms
Wall time: 324 ms


In [17]:
# FK is the Gram matrix of hidden_size-dim features, so its rank is at most
# hidden_size; show one value past that to make the drop-off to ~0 visible.
S[:hidden_size + 1]

array([1.06309123e+06, 4.96130412e+05, 4.57817873e+05, 3.61928555e+05,
       3.37971388e+05, 2.19157719e+05, 1.64447879e+05, 9.73574035e+04,
       8.12054531e+04, 7.22451590e+04, 4.26400838e+04, 3.75154036e+04,
       2.61840257e+04, 2.49301464e+04, 2.16674233e+04, 1.76575909e+04,
       1.32244020e+04, 1.10738526e+04, 1.02820863e+04, 9.60930014e+03,
       7.49083800e+03, 7.08403254e+03, 6.00716364e+03, 5.19599241e+03,
       4.71248047e+03, 4.03342693e+03, 3.54149657e+03, 2.93436721e+03,
       2.61389763e+03, 1.66502641e+03, 6.79993712e+02, 2.16706289e+01,
       1.11286951e-09])

In [18]:
# Same view for the test kernel: rank is at most hidden_size, plus one extra value.
S_test[:hidden_size + 1]

array([1.70562037e+05, 9.29771256e+04, 6.32432741e+04, 5.51877225e+04,
       5.20200517e+04, 3.47431033e+04, 2.64476912e+04, 1.52996854e+04,
       1.16152366e+04, 9.84608900e+03, 6.92901826e+03, 5.44786048e+03,
       4.37500968e+03, 3.53152385e+03, 3.31906758e+03, 2.83114999e+03,
       2.53105908e+03, 2.06661173e+03, 1.85521915e+03, 1.60679748e+03,
       1.48327526e+03, 1.20402944e+03, 9.60334307e+02, 7.50718381e+02,
       6.79124394e+02, 5.85725289e+02, 4.62735475e+02, 4.13054296e+02,
       3.47637571e+02, 4.62672652e+01, 4.32914034e+01, 1.24818012e-10,
       6.42016869e-11])

In [119]:
# Read-out weights of the best checkpoint as a (hidden_size, num_classes) array.
# fc3 has no bias, so these weights alone map features to logits.
# .cpu() keeps this working if the model was trained on a GPU.
w = best_model['fc3.weight'].detach().cpu().numpy().T

In [116]:
## Now we solve F a = w for a  (F: D x N features, w: D x C read-out weights)
# N = # samples (5000 for the MNIST subset)
# D = # features = hidden_size (32)
# Need feats_train in R^{D x N}, but get_features returns it as (N, D).
# Transpose only while it is still (N, D), so re-running this cell does not
# silently flip it back.
if feats_train.shape[0] != hidden_size:
    feats_train = feats_train.transpose()
assert feats_train.shape[0] == hidden_size, feats_train.shape

In [117]:
# Least-squares on F a = w gives the normal equations F^T F a = F^T w.
# F is (D, N) with D << N, so F^T F is (N, N) with rank <= D and singular;
# a small ridge term makes it invertible: a = (F^T F + ridge*I)^-1 F^T w.
ridge = 1e-5
n_samples = feats_train.shape[1]  # N, instead of a hard-coded 5000
FTF = feats_train.transpose() @ feats_train + ridge * np.eye(n_samples)
FTFinv = np.linalg.inv(FTF)

In [120]:
# Right-hand side of the normal equations: F^T w, (N, D) @ (D, C) -> (N, C).
FTw = feats_train.T @ w

In [121]:
# Solve (F^T F + ridge*I) a = F^T w directly. This is mathematically equal to
# FTFinv @ FTw, but a direct solve is more accurate than multiplying by an
# explicit inverse. That matters here: the tiny ridge sits beside kernel
# eigenvalues of order 1e6, so the matrix is badly conditioned.
a = np.linalg.solve(FTF, FTw)

In [123]:
# Dual coefficients: one column per class, chosen so that feats_train @ a ≈ w (up to the ridge).
a

array([[ 4.01362504e-02,  4.44235688e-02, -3.37799092e-02, ...,
         2.62050346e-03,  3.74421704e-02,  2.03334803e-02],
       [ 7.51732377e-04,  2.20430398e-03, -1.89899087e-03, ...,
         3.24858899e-03,  2.75285600e-03,  1.30534717e-03],
       [-8.19433284e-03,  6.72474991e-04,  4.18309508e-03, ...,
         1.07090239e-02,  2.96160493e-03, -1.69968026e-04],
       ...,
       [-1.54445879e-05, -4.63149045e-05, -7.33952947e-06, ...,
        -1.35463866e-04, -1.54428417e-04, -2.07060948e-05],
       [ 1.41069293e-04,  2.26218777e-04,  4.33604437e-05, ...,
        -7.66542507e-06, -1.43289275e-04, -2.76804494e-05],
       [ 1.02874619e-05,  3.96743417e-05,  1.11117879e-05, ...,
         2.09433492e-05,  1.14605267e-04, -1.21545745e-05]])

In [134]:
# Per-class kernel (RKHS) norm sqrt(a_c^T FK a_c). FK equals F^T F for the (D, N)
# feature matrix F, and F a ≈ w, so these should be close to the column norms of w.
np.sqrt(np.diag(
    a.T @ FK @ a
))

array([1.14795407, 1.50111468, 1.04685615, 0.92225615, 1.11397796,
       0.99123258, 1.15805069, 1.12715211, 1.08622707, 0.97514829])