In [1]:
import numpy as np
import matplotlib.pyplot as plt
from copy import deepcopy
import torch
from sklearn.datasets import load_digits
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms

In [2]:
class Digits(Dataset):
    """Scikit-Learn Digits dataset."""

    def __init__(self, mode='train', transforms=None):
        digits = load_digits()
        if mode == 'train':
            self.data = digits.data[:1000].astype(np.float32)
        else:
            self.data = digits.data[1000:].astype(np.float32)

        self.targets = digits.target
        self.transforms = transforms

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        target = self.targets[idx]
        
        if self.transforms:
            sample = self.transforms(sample)
        return sample, target    
    
def mnist_loader(batch_size: int) -> None:
    train_set = datasets.MNIST('data', train=True,
                               transform=transforms.Compose([
#                                    transforms.RandomCrop(28, padding=4),
#                                    transforms.RandomHorizontalFlip(),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.1307,), (0.3081,))
                               ]), download=True)
    val_set = datasets.MNIST('data', train=False,
                             transform=transforms.Compose([
                                 transforms.ToTensor(),
                                 transforms.Normalize((0.1307,), (0.3081,))
                             ]), download=True)
    
    train_set = torch.utils.data.Subset(train_set, range(5000))
    val_set = torch.utils.data.Subset(val_set, range(1000))
    
    
    train_loader = DataLoader(
        train_set, batch_size=batch_size, num_workers=4,
        pin_memory=True, shuffle= True)

    val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=8, pin_memory= True)

    return train_loader, val_loader    

def digits_loader(batch_size:int):
    transform = transforms.Lambda(lambda x: 2. * (x / 17.) - 1.)
    train_data = Digits(mode='train', transforms=transform)
    test_data = Digits(mode='test', transforms=transform)

    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, pin_memory= True)
    test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False, pin_memory=True)

    return train_loader, test_loader  


class MLP(nn.Module):
    ''' A basic 3 layer MLP '''

    def __init__(self, input_dim: int, num_classes: int, hidden_dim: int = 32) -> None:
        super(MLP, self).__init__()
        self.input_dim = input_dim
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, num_classes,bias = False)

    def forward(self, x):
        x = x.view(-1, self.input_dim)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x  
    
    def features(self, x):
        x = x.view(-1, self.input_dim)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
#         x = self.fc3(x)
        return x  

In [3]:
def get_features(model, loader):
    feats = []
    model.eval()
    
    embeddings = np.zeros((0, hidden_size))
    labels = np.zeros((0))

    with torch.no_grad():
        for data, target in loader:
            data = data.to(device)            
            embs = model.features(data)
            embeddings = np.concatenate((embeddings, embs.cpu().numpy()))
            labels = np.concatenate((labels, target))

    return embeddings, labels        

def train_epoch(epoch, train_loader):    
    model.train()      
    
    for x, y in train_loader:
        x,y = x.to(device),y.to(device)        
        optimizer.zero_grad(set_to_none = True)

        output = model(x)
        loss = loss_function(output,y)
        
        loss.backward()
        optimizer.step()   
    
        
def test_epoch(epoch, test_loader, output_epochs = 10):        
    model.eval()
    loss = 0
    acc = 0
    N = 0
    for data, labels in test_loader:
        data, labels = data.to(device), labels.to(device)
        logits = model(data)
        
        loss_t = loss_function(logits, labels)
        acc_t = 100 * torch.sum(torch.argmax(logits,dim = 1) == labels) / len(labels)


        loss += loss_t.item()
        acc += acc_t
        N += data.shape[0]

    loss = loss / N
    acc = acc / len(test_loader)
    
    if epoch % output_epochs == 0:
        print(f'Epoch: {epoch}, CE = {loss}, ACC = {acc}')
  
    return loss

In [4]:
dataset = 'mnist'
batch_size = 32

if dataset == 'digits':
    train_loader, test_loader = digits_loader(batch_size = batch_size)
    data_width = 8
    data_dim = data_width**2
    num_classes = 10
    output_epochs = 1
elif dataset == 'mnist':
    train_loader, test_loader = mnist_loader(batch_size = batch_size)
    data_width = 28
    data_dim = data_width**2
    num_classes = 10
    output_epochs = 1    

    
## Define hyperparameters
hidden_size = 32 # size of layers in model
lr = 0.01 
num_epochs = 5
loss_function = nn.CrossEntropyLoss()

### Define model and optimizer
# device = 'cuda:1' if torch.cuda.is_available() else 'cpu'
device = 'cpu'
model = MLP(input_dim = data_dim, num_classes = num_classes, hidden_dim = hidden_size).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)


best_loss = 1000
epochs_since_improvement = 0
for epoch in range(num_epochs):    
    train_epoch(epoch, train_loader)
    test_loss = test_epoch(epoch, test_loader, output_epochs = output_epochs)
    if test_loss < best_loss:
        best_model = deepcopy(model.state_dict())


  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


Epoch: 0, CE = 0.013870636656880379, ACC = 85.83984375
Epoch: 1, CE = 0.013568791814148426, ACC = 87.79296875
Epoch: 2, CE = 0.014156046748161316, ACC = 87.40234375
Epoch: 3, CE = 0.012121602348983287, ACC = 90.625
Epoch: 4, CE = 0.013459737211465836, ACC = 87.59765625


In [5]:
print('Train Loader')
for x,y in train_loader:
    output = model(x)
    print(torch.argmax(output,dim = 1))
    print(y)
    print(loss_function(output, y))
    break

print()
print()
print()
print('Test Loader')    
for x,y in test_loader:
    output = model(x)
    print(torch.argmax(output,dim = 1))
    print(y)
    print(loss_function(output, y))
    break    

Train Loader
tensor([7, 2, 2, 0, 7, 3, 5, 8, 2, 4, 3, 3, 4, 7, 0, 7, 0, 3, 8, 2, 5, 8, 9, 3,
        2, 8, 5, 7, 0, 9, 0, 9])
tensor([7, 2, 2, 0, 7, 3, 5, 8, 2, 4, 3, 3, 4, 7, 0, 7, 0, 3, 8, 2, 5, 8, 9, 3,
        2, 8, 5, 7, 0, 9, 0, 9])
tensor(0.0696, grad_fn=<NllLossBackward>)



Test Loader
tensor([7, 2, 1, 0, 4, 1, 4, 9, 6, 9, 0, 6, 9, 0, 1, 5, 9, 7, 3, 4, 9, 6, 6, 5,
        4, 0, 7, 4, 0, 1, 3, 1])
tensor([7, 2, 1, 0, 4, 1, 4, 9, 5, 9, 0, 6, 9, 0, 1, 5, 9, 7, 3, 4, 9, 6, 6, 5,
        4, 0, 7, 4, 0, 1, 3, 1])
tensor(0.2855, grad_fn=<NllLossBackward>)


In [6]:
# Get the feature kernels

# Train set
feats_train, labels_train = get_features(model,train_loader)    
FK = feats_train @ feats_train.transpose()

# Test set
feats_test, labels_test = get_features(model,test_loader)    
FK_test = feats_test @ feats_test.transpose()

# Looking at SVD of feature kernels

In [7]:
# %%time
# # Let's decompose the feature kernel. This is kernel PCA on the data with a kernel learned via features
# U,S,V = np.linalg.svd(FK)


# # Decompose the test kernel just cause we can
# U_test,S_test,V_test = np.linalg.svd(FK_test)

# print(np.sum(S > 0.01))
# print(S[0:33])
# print(S_test[0:33])

# Compute the norm of projected functions

In [8]:
w = best_model['fc3.weight'][:,:].detach().numpy().transpose()

In [9]:
## Now we solve F^T F a = w for a
# N = # samples = 5000
# D = # features = 32
# Need feats_train in R^DxN

feats_train = feats_train.transpose()
feats_train.shape

(32, 5000)

In [10]:
# Solving Fa = w, gives a = (F^T F)^-1 F^T w
FTF = (feats_train.transpose() @ feats_train) + 0.00001 * np.eye(5000)
FTFinv = np.linalg.inv(FTF)

In [11]:
FTw = (feats_train.transpose() @ w)

In [12]:
a = FTFinv @ FTw

In [13]:
a

array([[-4.95756685e-05, -2.45833682e-04, -1.78518228e-04, ...,
        -1.73320903e-04,  2.52998187e-04,  1.27209191e-05],
       [-3.86567847e-05, -6.13405805e-05, -1.57014127e-04, ...,
        -2.00525567e-04,  7.26352636e-05, -6.62693674e-05],
       [ 9.79563447e-05, -1.46635444e-04, -4.16743787e-05, ...,
        -5.39855282e-05, -2.20317143e-05, -3.14728995e-05],
       ...,
       [-5.50958066e-05, -1.62754586e-05,  1.41448181e-05, ...,
         8.03349030e-06, -9.04596527e-06,  1.74404586e-05],
       [-1.50813255e-04,  2.34261155e-04, -7.68830359e-05, ...,
         5.61119959e-05, -2.31735816e-04,  2.60820088e-04],
       [-8.51559453e-05, -4.71062958e-06,  2.90867756e-05, ...,
        -8.92109374e-05,  1.58204930e-05,  7.63103017e-05]])

## Computing the norms of the projected network

In [15]:
np.diag(a.transpose() @ FK @ a)

array([1.56248519, 2.24821384, 0.81616043, 0.7745437 , 1.27478325,
       1.0603234 , 1.55849892, 0.54283823, 0.72376634, 0.65062228])