# Kernel Calculation Test
This notebook tests the kernel calculation methods that (if they work) will be rewritten as a Python script.

In [1]:
import torch
import resources as rs

ModuleNotFoundError: No module named 'torchvision'

### Loading Model Data

In [None]:
# Path of the serialized model checkpoint.
model_path = 'ex_models.pt'
# Deserialize onto the CPU so no GPU is required to inspect the weights.
cpu_device = torch.device('cpu')
state_dict = torch.load(model_path, map_location=cpu_device)  # MODEL

### Unpacking the params

**NOTE:** 
The model was trained on the MNIST digits 0 and 1.

Unpacking the features params

In [None]:
import torch.nn as nn

# Packed parameters of the quantized hidden layer: index 0 holds the weights,
# index 1 holds the bias.
features_packed = state_dict['features.hidden_layer._packed_params._packed_params']
f_weights_quant = features_packed[0]
f_bias_quant = features_packed[1]

# Convert both tensors back to floating point.
f_weights_float = torch.dequantize(f_weights_quant)
f_bias_float = torch.dequantize(f_bias_quant)

# Show the float weights followed by the float bias.
print(f_weights_float, f_bias_float, sep='\n')

Unpacking the readout params.

In [None]:
# Packed parameters of the quantized readout layer: index 0 holds the weights,
# index 1 holds the bias.
readout_packed = state_dict['readout._packed_params._packed_params']
r_weights_quant = readout_packed[0]
r_bias_quant = readout_packed[1]

# Convert both tensors back to floating point.
r_weights_float = torch.dequantize(r_weights_quant)
r_bias_float = torch.dequantize(r_bias_quant)

# Show the float weights followed by the float bias.
print(r_weights_float, r_bias_float, sep='\n')

### Manually updating the model

In [None]:
# Instantiate the network class `NN` from the `resources` module (imported as `rs`).
model = rs.NN()
# Bare expression: the notebook displays the model's repr.
model

In [None]:
# Materialize the model's parameters as a list so they can be addressed by index.
params = list(model.parameters())
# Bare expression: the notebook displays the parameter list.
params

In [None]:
# Overwrite the first four parameters, in order, with the dequantized tensors:
# features weights, features bias, readout weights, readout bias.
# NOTE(review): this relies on model.parameters() yielding them in that order — confirm.
dequantized_tensors = (f_weights_float, f_bias_float, r_weights_float, r_bias_float)
for position, tensor in enumerate(dequantized_tensors):
    params[position].data = tensor

### Loading MNIST dataset

In [None]:
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
def mnist_dataset(batch_size, train=True, values=list(range(10))):
    """Build a shuffled DataLoader over the MNIST samples whose label is in `values`.

    Args:
        batch_size: Number of samples per batch. A falsy value (0 or None)
            means "the whole filtered subset as a single batch", which is how
            this notebook requests the full (12665, 784) data tensor.
        train: If True use the MNIST training split, otherwise the test split.
        values: Iterable of digit labels to keep.

    Returns:
        A `DataLoader` over the filtered subset, with shuffling enabled.
    """
    # Initializing MNIST data set (downloaded into 'dataset/' when missing).
    dataset = datasets.MNIST(root='dataset/', train=train, transform=transforms.ToTensor(), download=True)

    # A set gives O(1) label membership tests while scanning every target.
    wanted_labels = set(values)
    values_index = [i for i, label in enumerate(dataset.targets.tolist()) if label in wanted_labels]

    # Creating a subset restricted to the requested MNIST targets.
    subset = torch.utils.data.Subset(dataset, values_index)

    # Previously `batch_size` was accepted but never forwarded, so the loader
    # silently fell back to DataLoader's default batch size of 1.
    effective_batch_size = batch_size if batch_size else len(subset)
    loader = DataLoader(dataset=subset, batch_size=effective_batch_size, shuffle=True)

    return loader

In [None]:
# Loader over the training samples labelled 0 or 1, built by the `resources` helper.
# NOTE(review): batch_size=0 presumably means "one batch containing everything" — confirm in rs.mnist_dataset.
MNIST = rs.mnist_dataset(batch_size=0, train=True, values=[0,1])
# Pull the first batch from the loader.
data, targets = next(iter(MNIST))

Next, in order to perform the CKA calculation, we need to reshape the data into a (batch_size × features) tensor of shape (12665, 784).

In [None]:
# Drop dimension 1 when it has size 1, then flatten everything after the
# batch dimension into a single feature axis.
data = data.squeeze(dim=1)
batch_dim = data.size(0)
data = data.view(batch_dim, -1)
data.shape

### CKA Calc.
At this point we can now calculate the CKA for the model state.

In [None]:
# Sanity check: shape of the output of the model's `features` module for this batch.
model.features(data).shape

In [None]:
# Sanity check: shape of the target tensor for this batch.
targets.shape

In [None]:
# Kernel calculation from the `resources` module, applied to the targets and the feature activations.
rs.kernel_calc(targets, model.features(data))

#TODO using this notebook create a python script that will do this for each of the model states (135 * 512)...

### Accuracy

In [None]:
import torch.optim as optim

In [None]:
# Run everything on the CPU.
device = torch.device('cpu')
# Mean-squared-error criterion handed to the training helper.
loss = nn.MSELoss()
# Put the model in evaluation mode before the pass below.
model.eval()
# NOTE(review): backwards=False presumably skips the backward pass and record_loss=True
# presumably makes rs.train return the recorded losses — confirm in the resources module.
losses = rs.train(MNIST, device, model, loss, values=[0, 1], backwards=False, record_loss=True)

In [None]:
# Bare expression: the notebook displays the value returned by rs.train.
losses

In [None]:
import time

### CKA Test.
Defining the Centered Kernel Alignment (CKA) functions.

In [None]:
def kernel_calc(y, phi):
    """Compute the alignment between the centered target and feature kernels.

    Builds the target kernel from `y` and the feature kernel `phi @ phi.T`,
    centers both with `kernel_centering`, and returns their
    `kernel_alignment`.

    Args:
        y: Target tensor. It is unsqueezed on the last axis and transposed
            with `torch.t`, so it is presumably 1-D of length N — confirm
            against the caller.
        phi: 2-D feature matrix (required by `torch.mm`) with one row per sample.

    Returns:
        The value returned by `kernel_alignment` for the two centered kernels.
    """
    # Row-vector view of the targets: unsqueeze to a column, then transpose.
    y_row = torch.t(torch.unsqueeze(y, -1))

    # Target kernel: outer product of the targets with themselves.
    K1 = torch.matmul(torch.t(y_row), y_row)

    # Cast to float before centering; means are not defined for integer tensors.
    K1c = kernel_centering(K1.float())

    # Feature kernel (Gram matrix of the rows of phi).
    K2 = torch.mm(phi, torch.t(phi))

    K2c = kernel_centering(K2)

    return kernel_alignment(K1c, K2c)


def frobenius_product(K1, K2):
    """Return the sum of the element-wise products of K1 and K2 (Frobenius inner product)."""
    elementwise = K1 * K2
    return elementwise.sum()


def kernel_alignment(K1, K2):
    """Return the Frobenius inner product of K1 and K2 divided by the product of their Frobenius norms."""
    norm_k1 = torch.norm(K1, p='fro')
    norm_k2 = torch.norm(K2, p='fro')
    numerator = frobenius_product(K1, K2)
    return numerator / (norm_k1 * norm_k2)


def kernel_centering(K):
    """Center K by subtracting its row means and column means and adding back its overall mean."""
    centered = K - K.mean(dim=1, keepdim=True)
    centered = centered - K.mean(dim=0, keepdim=True)
    return centered + K.mean()

In [None]:
# Time one full CKA evaluation on the loaded batch.
# perf_counter is the monotonic, high-resolution clock intended for measuring intervals.
start = time.perf_counter()
# No gradients are needed here; disabling autograd avoids retaining the
# intermediate kernel matrices for a backward pass.
with torch.no_grad():
    # Named `cka_value` so the result does not share a name with the `cka` function defined later.
    cka_value = kernel_calc(targets, model.features(data))
end = time.perf_counter()
print(f"CKA : {cka_value} |TOTAL TIME: {end - start}s")

### New CKA Function

In [None]:
# Feature activations for the batch.
phi = model.features(data)
# Column vector of targets. The former `targets.T` was dropped: on a 1-D
# tensor `.T` is a no-op, and its use on non-2-D tensors is deprecated.
y = torch.unsqueeze(targets, -1)

The new method potentially allows us to skip the centering calculation.

In [None]:
# New Method
start = time.time()
v = phi.T.matmul(y.float())
inner = v.T @ v
end = time.time()
print("NEW METHOD")
print(f"INNER = {inner} | TIME = {end - start}s")

In [None]:
# Old Method
start = time.time()
K1 = y @ y.T
K1 = K1.float()
K2 = phi @ phi.T
inner = torch.trace(torch.mm(K2, torch.t(K1)))
end = time.time()
print("OLD METHOD")
print(f"INNER = {inner} | TIME = {end - start}s")

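According to this finding, we can reduce the computation time of the Frobenius product by a factor of about 3307, because $\langle yy^\top, \Phi\Phi^\top \rangle_F = \lVert \Phi^\top y \rVert^2$ never requires forming the $N \times N$ kernel matrices.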

### Using this result to produce a new CKA method

In [None]:
def cka(phi, y):
    """Centered kernel alignment between features and targets, without N x N kernels.

    Centering a linear kernel ``X @ X.T`` is the same as building the kernel
    from the column-centered ``X``. With ``y_c`` and ``phi_c`` the centered
    targets and features:

        <K_y, K_phi>_F = ||phi_c.T @ y_c||^2
        ||K_y||_F      = ||y_c||^2
        ||K_phi||_F    = ||phi_c.T @ phi_c||_F

    so only (features x features) matrices are ever formed.

    Args:
        phi: 2-D feature matrix with one row per sample.
        y: Targets, either 1-D of length N or a column vector of shape (N, 1).
            Any dtype; it is cast to the dtype of ``phi``.

    Returns:
        A 0-dim tensor holding the alignment. The value is NaN when the
        centered targets or the centered features are all zero, because the
        denominator is then zero.

    Raises:
        ValueError: If ``phi`` and ``y`` do not describe the same number of samples.
    """
    # Means are only defined for floating-point tensors.
    if not phi.is_floating_point():
        phi = phi.float()

    # Targets vector [y] as a column in the same dtype as the features.
    y = y.reshape(-1, 1).to(phi.dtype)
    if y.shape[0] != phi.shape[0]:
        raise ValueError(
            f"phi has {phi.shape[0]} samples but y has {y.shape[0]}"
        )

    # Subtracting the mean is (I - 11^T / m) @ y without building the m x m matrix.
    y_c = y - y.mean(dim=0, keepdim=True)

    # Features matrix [PHI]: center every feature column across the samples.
    phi_c = phi - phi.mean(dim=0, keepdim=True)

    # Frobenius inner product of the two centered kernels.
    projection = phi_c.T @ y_c
    inner = (projection * projection).sum()

    # Frobenius norms of the centered kernels.
    norm_y = (y_c * y_c).sum()
    norm_phi = torch.norm(phi_c.T @ phi_c, p='fro')

    return inner / (norm_y * norm_phi)

In [None]:
# Pass the targets as floats; `y.flo` was a truncated `y.float()` and raised AttributeError.
cka(phi, y.float())