Imports...

In [12]:
import torch.nn as nn   # Nueral network modules.
from collections import OrderedDict
import torch  # Base torch library
from torch.utils.data import DataLoader  # Minibathces
import torchvision.datasets as datasets  # MNIST dataset
import torchvision.transforms as transforms
import numpy as np
import torch.nn as nn  # Neural network modules
import torch.optim as optim  # Optimization algorithms
import pandas as pd

Model Class:

In [13]:
class NN(nn.Module):

    def __init__(self, input_size, middle_width, num_classes):


        super(NN, self).__init__()
        self.features = nn.Sequential(OrderedDict([
            ('hidden_layer', nn.Linear(input_size, middle_width)),
            ('hidden_activation', nn.ReLU()),
        ]))
        self.readout = nn.Linear(middle_width, num_classes)

    def forward(self, x):
        x = self.features(x)
        x = self.readout(x)

        return x

Network Functions:

In [14]:
def set_device():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    return device


def mnist_dataset(batch_size, train=True, values=list(range(10))):
    # Initializing MNIST data set.
    dataset = datasets.MNIST(root='dataset/', train=train, transform=transforms.ToTensor(), download=True)

    targets_list = dataset.targets.tolist()
    values_index = [i for i in range(len(dataset)) if targets_list[i] in values]

    # Creating a subset of ### MNIST targets.
    subset = torch.utils.data.Subset(dataset, values_index)
    loader = DataLoader(dataset=subset, batch_size=batch_size, shuffle=True)

    return loader


def train(loader, device, model, loss_function, optimizer_function, values=list(range(10))):
    # Training on each data point.
    for batch_idx, (data, targets) in enumerate(loader):
        data = data.reshape(data.shape[0], -1).to(device=device)
        targets = targets.to(device=device)

        # Forwards.
        scores = model(data)
        loss = loss_function(scores, classify_targets(targets, values))

        # Backwards.
        optimizer_function.zero_grad()
        loss.backward()

        optimizer_function.step()
        
        
        phi = slow_model.features(data)
        
        return targets, phi


def record_accuracy(device, model, train_loader, test_loader, epoch, values=list(range(10))):
    epoch_accuracy = np.array([[
        epoch + 1,
        check_accuracy(device, model, train_loader, values).cpu(),
        check_accuracy(device, model, test_loader, values).cpu()
    ]])

    return epoch_accuracy


def check_accuracy(device, model, loader, values=list(range(10))):
    num_correct = 0
    num_samples = 0
    model.eval()

    with torch.no_grad():
        for x, y in loader:
            x = x.to(device=device)
            y = classify_targets(y, values).to(device=device)
            x = x.reshape(x.shape[0], -1)

            scores = model(x)
            # 64images x 10,

            predictions = scores.argmax(1)
            num_correct += (predictions == y).sum()
            num_samples += predictions.size(0)

    return 100 - 100. * num_correct / num_samples


def classify_targets(targets, values):
    new_targets = targets.clone()

    # Changing targets to a classifiable number.
    for key, element in enumerate(values):
        new_targets[targets == element] = key
    return new_targets

Main Program:

In [15]:
# Checking & Setting Device Allocation
device = set_device()
print(f"Running on {device}")

# Hyper Parameters
hp = {
    "Input Size": 784,
    "Middle Layer Width": 2000,
    "Num Classes": 2,
    "Regular Learning Rate": 0.01,
    "Slow Learning Rate": 0.001,
    "Batch Size": 200,
    "Epochs": 1
}
print(f"Hyper Parameters: {hp}")

# Initializing Model
slow_model = NN(input_size=hp["Input Size"],
                middle_width=hp["Middle Layer Width"],
                num_classes=hp["Num Classes"]).to(device=device)

reg_model = NN(input_size=hp["Input Size"],
               middle_width=hp["Middle Layer Width"],
               num_classes=hp["Num Classes"]).to(device=device)

# Loading MNIST Dataset
mnist_values = [8, 9]
print(f"MNIST digits {mnist_values}")
train_loader = mnist_dataset(hp["Batch Size"], values=mnist_values)
validate_loader = mnist_dataset(hp["Batch Size"], train=False, values=mnist_values)

# Loss function
loss_function = nn.CrossEntropyLoss()

# Optimizers
sl_optimizer = optim.SGD([{'params': slow_model.features.hidden_layer.parameters()},
                          {'params': slow_model.readout.parameters(),
                           'lr': hp["Regular Learning Rate"]}],
                         lr=hp["Slow Learning Rate"])
r_optimizer = optim.SGD(reg_model.parameters(), lr=hp["Regular Learning Rate"])

# Creating 'empty' arrays for future storing of accuracy metrics
slow_accuracy = np.zeros((1, 3))
regular_accuracy = np.zeros((1, 3))

print("Training models...")
for epoch in range(hp["Epochs"]):

    # Slow Model
    sl_targets, sl_phi = train(train_loader, device, slow_model, loss_function, sl_optimizer, values=mnist_values)
    slow_accuracy_epoch = record_accuracy(device, slow_model, train_loader, validate_loader, epoch, mnist_values)
    slow_accuracy = np.concatenate((slow_accuracy, slow_accuracy_epoch))
    print("Slow: ")
    print(slow_accuracy_epoch)
    # Regular Model
    reg_targets, reg_phi = train(train_loader, device, reg_model, loss_function, r_optimizer, values=mnist_values)
    regular_accuracy_epoch = record_accuracy(device, reg_model, train_loader, validate_loader, epoch, mnist_values)
    regular_accuracy = np.concatenate((regular_accuracy, regular_accuracy_epoch))
    print("Reg: ")
    print(regular_accuracy_epoch)
    print(f"-Finished epoch {epoch + 1}/{hp['Epochs']}")

Running on cpu
Hyper Parameters: {'Input Size': 784, 'Middle Layer Width': 2000, 'Num Classes': 2, 'Regular Learning Rate': 0.01, 'Slow Learning Rate': 0.001, 'Batch Size': 200, 'Epochs': 1}
MNIST digits [8, 9]
Training models...
Slow: 
[[ 1.         50.26271057 50.78164291]]
Reg: 
[[ 1.         34.16949463 31.46747589]]
-Finished epoch 1/1


***
Kernel Alignment Calc:
***

Kernel Matrix: $K_{1}$

In [16]:
sl_targets = torch.t(torch.unsqueeze(sl_targets, -1))

In [17]:
sl_targets

tensor([[8, 8, 9, 8, 9, 9, 8, 9, 9, 9, 8, 9, 8, 9, 8, 8, 8, 9, 9, 8, 8, 8, 8, 8,
         9, 8, 8, 9, 9, 9, 8, 9, 9, 8, 9, 8, 9, 8, 8, 9, 8, 8, 9, 8, 9, 9, 9, 9,
         8, 8, 8, 9, 9, 8, 8, 9, 9, 9, 8, 9, 8, 8, 8, 9, 8, 8, 9, 8, 9, 8, 9, 9,
         9, 8, 9, 9, 8, 9, 8, 8, 9, 8, 9, 8, 8, 8, 8, 9, 8, 9, 9, 8, 9, 8, 9, 9,
         8, 8, 9, 8, 8, 9, 8, 9, 8, 8, 8, 8, 9, 9, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9,
         8, 8, 8, 8, 8, 8, 8, 9, 9, 8, 8, 9, 8, 9, 9, 8, 8, 9, 8, 8, 8, 8, 8, 9,
         8, 9, 8, 8, 8, 8, 9, 9, 9, 8, 8, 8, 9, 9, 9, 8, 8, 8, 8, 8, 8, 8, 9, 8,
         8, 8, 8, 9, 8, 8, 8, 8, 8, 9, 9, 8, 9, 9, 9, 8, 8, 8, 8, 9, 8, 8, 9, 8,
         8, 9, 8, 8, 9, 8, 9, 9]])

In [18]:
for i in range(sl_targets.size()[0]):
    if sl_targets[0][i] == 9:
        sl_targets[i] = int(1)
    else:
        sl_targets[i] = int(1)

In [19]:
K1 = torch.matmul(torch.t(sl_targets), sl_targets)

In [20]:
K1.size()

torch.Size([200, 200])

***

Kernel Matrix: $K_{2}$

In [21]:
slow_model.features.hidden_activation.eval()

ReLU()

In [22]:
phi = sl_phi

In [23]:
K2 = torch.matmul(phi, torch.t(phi))
K2.size()

torch.Size([200, 200])

***

Kernel Centering: 

$K_{c} = \left[ I - \frac{11^{T}}{m} \right] K \left[ I - \frac{11^{T}}{m} \right]$

*Note: let 1 denote the vector with all enteries equal to one and I being the identity matrix*

In [24]:
def kernel_centering(K):
    # Lemmna 1
    
    m = K.size()[0]
    I = torch.eye(m)
    l = torch.ones(m, 1)
    
    # I - ll^T / m
    mat = I - torch.matmul(l, torch.t(l))/m
    
    
    return torch.matmul(torch.matmul(mat, K), mat)
    
    

Centering Kernel $K_{1}$

In [25]:
Kc1 = kernel_centering(K1.float())
Kc1, Kc1.size()

(tensor([[ 2.9992e-10,  2.9992e-10,  2.9986e-10,  ..., -7.1516e-09,
          -7.1516e-09, -7.1516e-09],
         [ 2.9992e-10,  2.9992e-10,  2.9986e-10,  ..., -7.1516e-09,
          -7.1516e-09, -7.1516e-09],
         [ 2.9992e-10,  2.9992e-10,  2.9986e-10,  ..., -7.1516e-09,
          -7.1516e-09, -7.1516e-09],
         ...,
         [ 3.2056e-09,  3.2056e-09,  3.2055e-09,  ..., -7.6889e-08,
          -7.6889e-08, -7.6889e-08],
         [ 3.3889e-09,  3.3889e-09,  3.3890e-09,  ..., -8.1361e-08,
          -8.1361e-08, -8.1361e-08],
         [ 3.5778e-09,  3.5778e-09,  3.5778e-09,  ..., -8.5831e-08,
          -8.5831e-08, -8.5831e-08]]),
 torch.Size([200, 200]))

Centering Kernel $K_{2}$

In [26]:
Kc2 = kernel_centering(K2)
Kc2, Kc2.size()

(tensor([[10.5177,  1.9631, -2.7707,  ...,  0.6664, -1.9902, -2.7594],
         [ 1.9631, 11.3822, -3.6263,  ..., -0.9451, -2.3427, -2.7076],
         [-2.7707, -3.6263, 21.7066,  ...,  5.0532,  8.1675, 10.0201],
         ...,
         [ 0.6664, -0.9450,  5.0532,  ..., 13.3286,  1.3345,  1.7989],
         [-1.9902, -2.3427,  8.1675,  ...,  1.3345, 20.5117,  8.7552],
         [-2.7594, -2.7076, 10.0201,  ...,  1.7989,  8.7552, 25.7633]],
        grad_fn=<MmBackward0>),
 torch.Size([200, 200]))

***

Kernel Aligment Function: $\hat{p}(K, K') = \frac{\langle K_{c}, K_{c}' \rangle}{\| K_{c} \| \| K_{c}' \|} $

In [27]:
def kernel_alignment(K1, K2):
    return torch.mm(K1, K2)/(torch.norm(K1, p='fro')*torch.norm(K2, p='fro'))

In [28]:
kernel_alignment(Kc1, Kc2)

tensor([[ 4.5761e-05,  8.4785e-05, -1.5062e-04,  ..., -1.1316e-04,
         -2.3778e-04, -2.6700e-04],
        [ 4.5761e-05,  8.4785e-05, -1.5062e-04,  ..., -1.1316e-04,
         -2.3778e-04, -2.6700e-04],
        [ 4.5761e-05,  8.4785e-05, -1.5062e-04,  ..., -1.1316e-04,
         -2.3778e-04, -2.6700e-04],
        ...,
        [ 4.9179e-04,  9.1129e-04, -1.6191e-03,  ..., -1.2164e-03,
         -2.5561e-03, -2.8701e-03],
        [ 5.2035e-04,  9.6424e-04, -1.7133e-03,  ..., -1.2871e-03,
         -2.7047e-03, -3.0369e-03],
        [ 5.4898e-04,  1.0173e-03, -1.8074e-03,  ..., -1.3579e-03,
         -2.8533e-03, -3.2038e-03]], grad_fn=<DivBackward0>)

In [34]:
torch.norm(Kc1, p='fro'), torch.norm(Kc2, p='fro')

(tensor(1.6182e-06), tensor(567.4170, grad_fn=<CopyBackwards>))

In [38]:
torch.mm(Kc1, Kc2, p ='fro')

TypeError: mm() got an unexpected keyword argument 'p'