In [50]:
import numpy as np
import matplotlib.pyplot as plt
from copy import deepcopy
import torch
from sklearn.datasets import load_digits
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms

In [36]:
class Digits(Dataset):
    """Scikit-Learn Digits dataset."""

    def __init__(self, mode='train', transforms=None):
        digits = load_digits()
        if mode == 'train':
            self.data = digits.data[:1000].astype(np.float32)
        else:
            self.data = digits.data[1000:].astype(np.float32)

        self.targets = digits.target
        self.transforms = transforms

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        target = self.targets[idx]
        
        if self.transforms:
            sample = self.transforms(sample)
        return sample, target    
    
def mnist_loader(batch_size: int) -> None:
    train_set = datasets.MNIST('data', train=True,
                               transform=transforms.Compose([
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.1307,), (0.3081,))
                               ]), download=True)
    val_set = datasets.MNIST('data', train=False,
                             transform=transforms.Compose([
                                 transforms.ToTensor(),
                                 transforms.Normalize((0.1307,), (0.3081,))
                             ]), download=True)
    
    train_set = torch.utils.data.Subset(train_set, range(5000))
    val_set = torch.utils.data.Subset(val_set, range(1000))
    
    
    train_loader = DataLoader(
        train_set, batch_size=batch_size, num_workers=8,
        pin_memory=True, shuffle= True)

    val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=8, pin_memory= True)

    return train_loader, val_loader    

def digits_loader(batch_size:int):
    transform = transforms.Lambda(lambda x: 2. * (x / 17.) - 1.)
    train_data = Digits(mode='train', transforms=transform)
    test_data = Digits(mode='test', transforms=transform)

    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, pin_memory= True)
    test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False, pin_memory=True)

    return train_loader, test_loader  


class MLP(nn.Module):
    ''' A basic 3 layer MLP '''

    def __init__(self, input_dim: int, num_classes: int, hidden_dim: int = 32) -> None:
        super(MLP, self).__init__()
        self.input_dim = input_dim
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        x = x.view(-1, self.input_dim)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x  
    
    def features(self, x):
        x = x.view(-1, self.input_dim)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
#         x = self.fc3(x)
        return x  

In [37]:
def get_features(model, loader):
    feats = []
    model.eval()
    
    embeddings = np.zeros((0, hidden_size))
    labels = np.zeros((0))

    with torch.no_grad():
        for data, target in loader:
            data = data.to(device)            
            embs = model.features(data)
            embeddings = np.concatenate((embeddings, embs.cpu().numpy()))
            labels = np.concatenate((labels, target))

    return embeddings, labels        

def train_epoch(epoch, train_loader):    
    model.train()      
    
    for x, y in train_loader:
        x,y = x.to(device),y.to(device)        
        optimizer.zero_grad(set_to_none = True)

        output = model(x)
        loss = loss_function(output,y)
        
        loss.backward()
        optimizer.step()   
    
        
def test_epoch(epoch, test_loader, output_epochs = 10):        
    model.eval()
    loss = 0
    acc = 0
    N = 0
    for data, labels in test_loader:
        data, labels = data.to(device), labels.to(device)
        logits = model(data)
        
        loss_t = loss_function(logits, labels)
        acc_t = 100 * torch.sum(torch.argmax(logits,dim = 1) == labels) / len(labels)


        loss += loss_t.item()
        acc += acc_t
        N += data.shape[0]

    loss = loss / N
    acc = acc / len(test_loader)
    
    if epoch % output_epochs == 0:
        print(f'Epoch: {epoch}, CE = {loss}, ACC = {acc}')
  
    return loss

In [38]:
dataset = 'mnist'
batch_size = 32

if dataset == 'digits':
    train_loader, test_loader = digits_loader(batch_size = batch_size)
    data_width = 8
    data_dim = data_width**2
    num_classes = 10
    output_epochs = 1
elif dataset == 'mnist':
    train_loader, test_loader = mnist_loader(batch_size = batch_size)
    data_width = 28
    data_dim = data_width**2
    num_classes = 10
    output_epochs = 1    

    
## Define hyperparameters
hidden_size = 32 # size of layers in model
lr = 0.01 
num_epochs = 5
loss_function = nn.CrossEntropyLoss()

### Define model and optimizer
# device = 'cuda:1' if torch.cuda.is_available() else 'cpu'
device = 'cpu'
model = MLP(input_dim = data_dim, num_classes = num_classes, hidden_dim = hidden_size).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)


best_loss = 1000
epochs_since_improvement = 0
for epoch in range(num_epochs):    
    train_epoch(epoch, train_loader)
    test_loss = test_epoch(epoch, test_loader, output_epochs = output_epochs)
    if test_loss < best_loss:
        best_model = deepcopy(model.state_dict())


Epoch: 0, CE = 0.0168315339833498, ACC = 82.6171875
Epoch: 1, CE = 0.013591308519244194, ACC = 85.64453125
Epoch: 2, CE = 0.012158217549324035, ACC = 88.8671875
Epoch: 3, CE = 0.011766072049736977, ACC = 90.0390625
Epoch: 4, CE = 0.014020623974502086, ACC = 86.42578125


In [39]:
print('Train Loader')
for x,y in train_loader:
    output = model(x)
    print(torch.argmax(output,dim = 1))
    print(y)
    print(loss_function(output, y))
    break

print()
print()
print()
print('Test Loader')    
for x,y in test_loader:
    output = model(x)
    print(torch.argmax(output,dim = 1))
    print(y)
    print(loss_function(output, y))
    break    

Train Loader
tensor([0, 3, 6, 3, 6, 0, 4, 8, 1, 0, 3, 7, 1, 0, 8, 8, 9, 8, 0, 4, 0, 7, 8, 6,
        0, 3, 9, 4, 4, 3, 2, 7])
tensor([0, 3, 6, 3, 6, 0, 4, 8, 1, 0, 3, 7, 1, 0, 8, 8, 9, 8, 0, 4, 0, 7, 8, 6,
        0, 3, 9, 4, 4, 3, 2, 3])
tensor(0.1948, grad_fn=<NllLossBackward0>)



Test Loader
tensor([7, 2, 1, 0, 4, 1, 4, 9, 6, 9, 0, 6, 9, 0, 1, 5, 9, 7, 8, 4, 9, 6, 6, 5,
        4, 0, 7, 4, 0, 1, 3, 1])
tensor([7, 2, 1, 0, 4, 1, 4, 9, 5, 9, 0, 6, 9, 0, 1, 5, 9, 7, 3, 4, 9, 6, 6, 5,
        4, 0, 7, 4, 0, 1, 3, 1])
tensor(0.3761, grad_fn=<NllLossBackward0>)


In [75]:
# Get the feature kernels

# Train set
feats_train, labels_train = get_features(model,train_loader)    
FK = feats_train @ feats_train.transpose()

# Test set
feats_test, labels_test = get_features(model,test_loader)    
FK_test = feats_test @ feats_test.transpose()

In [79]:
%%time
# Let's decompose the feature kernel. This is kernel PCA on the data with a kernel learned via features
U,S,V = np.linalg.svd(FK)

CPU times: user 4min 29s, sys: 2min 22s, total: 6min 52s
Wall time: 3min 16s


In [76]:
%%time
# Decompose the test kernel just cause we can
U_test,S_test,V_test = np.linalg.svd(FK_test)

CPU times: user 3.31 s, sys: 3.57 s, total: 6.89 s
Wall time: 3.27 s


In [80]:
S[0:28]

array([1.16530919e+06, 5.11760562e+05, 4.65160720e+05, 4.28413869e+05,
       3.65938303e+05, 2.73957573e+05, 2.27790166e+05, 1.60870160e+05,
       1.46898906e+05, 6.55960637e+04, 4.09559756e+04, 3.99963090e+04,
       3.00615172e+04, 2.69785148e+04, 2.40713948e+04, 2.11203608e+04,
       1.53656209e+04, 1.39918386e+04, 1.05963567e+04, 8.74369220e+03,
       7.01546876e+03, 6.35805428e+03, 4.54111243e+03, 4.11854409e+03,
       3.30611619e+03, 2.51226504e+02, 3.53999143e-01, 1.09344492e-03])

In [78]:
S_test[0:28]

array([2.14603530e+05, 8.54345974e+04, 8.10318090e+04, 6.86575938e+04,
       5.16320201e+04, 4.57206065e+04, 3.97376332e+04, 2.36400533e+04,
       2.02252872e+04, 9.89092232e+03, 5.55541989e+03, 5.25442143e+03,
       4.97344186e+03, 4.34181484e+03, 3.87960943e+03, 3.68633002e+03,
       2.84022135e+03, 2.38076956e+03, 1.63306511e+03, 1.35243086e+03,
       1.21016942e+03, 1.08468327e+03, 7.74765737e+02, 5.19531785e+02,
       1.08372090e+02, 2.32772694e+00, 7.25967818e-02, 1.84893981e-10])