# Kernel Calculation Test
This notebook tests the kernel-calculation methods that (if they work) will be rewritten as a Python script.

In [1]:
import torch
import resources as rs

Resources Loaded


### Loading Model Data

In [2]:
# Read the saved model states, mapping every tensor onto the CPU.
# NOTE(review): torch.load unpickles the file, so only open checkpoints from a trusted source.
model_path = 'ex_models.pt'
state_dict = torch.load(model_path, map_location='cpu')  # MODEL

  device=storage.device,


### Unpacking the params

**NOTE:** 
The model was trained on the MNIST digits 0 and 1.

Unpacking the features params

In [3]:
import torch.nn as nn

# The quantized hidden layer keeps its weight (index 0) and bias (index 1)
# together under a single packed entry of the state dict.
features_key = 'features.hidden_layer._packed_params._packed_params'
f_packed_params = state_dict[features_key]
f_weights_quant = f_packed_params[0]
f_bias_quant = f_packed_params[1]

# Convert both tensors back to floating point.
f_weights_float = f_weights_quant.dequantize()
f_bias_float = f_bias_quant.dequantize()

# Show the float weights first, then the float bias.
print(f_weights_float)
print(f_bias_float)

tensor([[ 0.0212, -0.0311, -0.0050,  ..., -0.0343, -0.0193,  0.0000],
        [ 0.0112, -0.0050,  0.0149,  ..., -0.0093, -0.0106,  0.0336],
        [-0.0324, -0.0006, -0.0062,  ...,  0.0311,  0.0137, -0.0324],
        ...,
        [-0.0305,  0.0168, -0.0293,  ...,  0.0237, -0.0280, -0.0112],
        [-0.0349,  0.0056, -0.0174,  ..., -0.0093, -0.0031,  0.0212],
        [ 0.0349,  0.0168, -0.0206,  ..., -0.0149,  0.0044, -0.0187]])
Parameter containing:
tensor([ 0.0235, -0.0103, -0.0210,  ..., -0.0244,  0.0028, -0.0313],
       grad_fn=<NotImplemented>)


Unpacking the readout params.

In [4]:
# The quantized readout layer keeps its weight (index 0) and bias (index 1)
# together under a single packed entry of the state dict.
readout_key = 'readout._packed_params._packed_params'
r_packed_params = state_dict[readout_key]
r_weights_quant = r_packed_params[0]
r_bias_quant = r_packed_params[1]

# Convert the weights and biases back to floating point.
r_weights_float = r_weights_quant.dequantize()
r_bias_float = r_bias_quant.dequantize()

# Show the float weights first, then the float bias.
print(r_weights_float)
print(r_bias_float)

tensor([[ 0.0219,  0.0389,  0.0024,  ..., -0.0073,  0.0182, -0.0158],
        [-0.0195, -0.0219, -0.0024,  ...,  0.0073, -0.0170,  0.0134]])
Parameter containing:
tensor([0.4623, 0.4824], grad_fn=<NotImplemented>)


### Manually updating the model

In [5]:
# Instantiate the network class from the `resources` module (imported as `rs`)
# and display its layer layout.
model = rs.NN()
model

NN(
  (features): Sequential(
    (hidden_layer): Linear(in_features=784, out_features=2048, bias=True)
    (hidden_activation): ReLU()
  )
  (readout): Linear(in_features=2048, out_features=2, bias=True)
)

In [6]:
# Materialise the model's parameters as a list so they can be addressed by
# position in the next step, then display them.
params = list(model.parameters())
params

[Parameter containing:
 tensor([[-0.0135, -0.0237,  0.0054,  ...,  0.0342, -0.0247, -0.0317],
         [ 0.0325, -0.0190,  0.0209,  ...,  0.0345,  0.0137, -0.0097],
         [ 0.0105,  0.0267, -0.0238,  ..., -0.0092,  0.0330,  0.0343],
         ...,
         [-0.0160,  0.0044, -0.0138,  ..., -0.0175, -0.0025, -0.0057],
         [-0.0206,  0.0357,  0.0309,  ..., -0.0236,  0.0310, -0.0297],
         [ 0.0049, -0.0330,  0.0098,  ..., -0.0004, -0.0245, -0.0281]],
        requires_grad=True),
 Parameter containing:
 tensor([ 0.0127, -0.0324, -0.0059,  ...,  0.0179,  0.0112, -0.0176],
        requires_grad=True),
 Parameter containing:
 tensor([[ 0.0174,  0.0033,  0.0001,  ...,  0.0201, -0.0122, -0.0172],
         [ 0.0101,  0.0214,  0.0112,  ..., -0.0177,  0.0167, -0.0195]],
        requires_grad=True),
 Parameter containing:
 tensor([0.0045, 0.0077], requires_grad=True)]

In [7]:
# Swap the tensor behind each of the first four parameters for the matching
# dequantized tensor: features weight, features bias, readout weight, readout bias.
# NOTE(review): assigning `.data` replaces the tensor wholesale; the recorded outputs
# (2048 units in the model repr, 1024 feature columns later on) suggest the loaded
# tensors have a different shape from the freshly built model - confirm.
loaded_tensors = (f_weights_float, f_bias_float, r_weights_float, r_bias_float)
for position, loaded_tensor in enumerate(loaded_tensors):
    params[position].data = loaded_tensor

### Loading MNIST dataset

In [8]:
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
def mnist_dataset(batch_size, train=True, values=tuple(range(10))):
    """Build a shuffled DataLoader over the MNIST samples whose label is in `values`.

    Args:
        batch_size: Number of samples per batch. `None` or a value <= 0 is treated
            as "one batch holding every selected sample", which is how this
            notebook calls the `resources` version (`batch_size=0`).
        train: Forwarded to `datasets.MNIST` to choose the split.
        values: Iterable of digit labels to keep. Defaults to all ten digits.

    Returns:
        A `DataLoader` with `shuffle=True` over the selected subset.
    """
    # download=True lets torchvision fetch the data into 'dataset/' when it is missing.
    dataset = datasets.MNIST(root='dataset/', train=train, transform=transforms.ToTensor(), download=True)

    # Keep only the indices of the samples whose label was requested.
    wanted = set(values)
    values_index = [i for i, label in enumerate(dataset.targets.tolist()) if label in wanted]

    # Creating a subset restricted to the requested MNIST targets.
    subset = torch.utils.data.Subset(dataset, values_index)

    # The batch size used to be accepted but never handed to the loader, so every
    # batch silently held a single sample.
    if batch_size is None or batch_size <= 0:
        batch_size = max(len(subset), 1)
    loader = DataLoader(dataset=subset, batch_size=batch_size, shuffle=True)

    return loader

In [9]:
# Build the loader for digits 0 and 1 through the `resources` helper and take its first batch.
# NOTE(review): the shape shown further down (12665 rows) suggests that batch_size=0 makes
# rs.mnist_dataset return every selected sample in one batch - confirm in `resources`.
MNIST = rs.mnist_dataset(batch_size=0, train=True, values=[0,1])
data, targets = next(iter(MNIST))

Next, in order to perform the CKA calculation, we need to reshape the data into a (batch_size × features) tensor, i.e. (12665, 784).

In [10]:
# Drop dim 1 when it has size 1 (presumably the single image channel), then
# flatten everything after the batch axis into one feature axis.
data = data.squeeze(dim=1)
data = data.view(len(data), -1)
data.shape

torch.Size([12665, 784])

### CKA Calc.
At this point we can now calculate the CKA for the model state.

In [11]:
# Run the data through the feature layers only and inspect the shape of the result.
# NOTE(review): the recorded output reports 1024 columns while the model repr above
# reports out_features=2048 - presumably the loaded weights changed the layer's
# shape; re-run to confirm.
model.features(data).shape

torch.Size([12665, 1024])

In [12]:
# Shape of the label tensor that came with the batch.
targets.shape

torch.Size([12665])

**TODO:** Using this notebook, create a Python script that will do this for each of the model states (135 * 512).

### Accuracy

In [13]:
import torch.optim as optim

In [14]:
# Measure the model's MSE loss on the 0/1 loader, on the CPU.
device = torch.device('cpu')
loss = nn.MSELoss()
# Put the module into evaluation mode before measuring.
model.eval()
# NOTE(review): rs.train lives in `resources`; backwards=False presumably skips the
# optimisation step so this call only records the loss - confirm against that module.
losses = rs.train(MNIST, device, model, loss, values=[0, 1], backwards=False, record_loss=True)

In [15]:
# Display the value returned by rs.train in the previous cell.
losses

8.250704013335053e-06

In [16]:
import time

### CKA Test.
Defining the Centered Kernel Alignment (CKA) functions.

In [17]:
def kernel_calc(y, phi):
    """Return the centered kernel alignment between targets `y` and features `phi`.

    Builds the two linear kernels explicitly (`y y^T` and `phi phi^T`), centers
    both with `kernel_centering`, and aligns them with `kernel_alignment`.

    Args:
        y: 1-D tensor with one target value per sample.
        phi: 2-D tensor with one row of features per sample.

    Returns:
        The value returned by `kernel_alignment` for the two centered kernels.
    """
    # Cast before the outer product so it is not computed in the integer dtype of
    # the labels. The result used to be cast to float right after the product, so
    # for small integer labels the values are unchanged.
    y_col = torch.unsqueeze(y, -1).float()

    # Target kernel: outer product of the label column with itself (samples x samples).
    K1c = kernel_centering(torch.matmul(y_col, torch.t(y_col)))

    # Feature kernel: Gram matrix of the feature rows (samples x samples).
    K2c = kernel_centering(torch.mm(phi, torch.t(phi)))

    return kernel_alignment(K1c, K2c)


def frobenius_product(K1, K2):
    """Return the Frobenius inner product: the sum of the element-wise product of K1 and K2."""
    elementwise = torch.mul(K1, K2)
    return elementwise.sum()


def kernel_alignment(K1, K2):
    """Return the Frobenius inner product of K1 and K2 divided by the product of their Frobenius norms."""
    norm_product = K1.norm(p='fro') * K2.norm(p='fro')
    return frobenius_product(K1, K2) / norm_product


def kernel_centering(K):
    """Return K with its row means and column means subtracted and its overall mean added back."""
    grand_mean = torch.mean(K)
    per_column = torch.mean(K, dim=0, keepdim=True)
    per_row = torch.mean(K, dim=1, keepdim=True)

    # Same left-to-right order as K - rows - columns + total.
    return K.sub(per_row).sub(per_column).add(grand_mean)

In [18]:
# Time one full CKA evaluation (feature forward pass plus the kernel algebra).
# The result is only printed, so autograd is switched off to avoid retaining a
# graph over the samples-by-samples kernel matrices; perf_counter is the clock
# intended for measuring elapsed time.
start = time.perf_counter()
with torch.no_grad():
    cka = kernel_calc(targets, model.features(data))
end = time.perf_counter()
print(f"CKA : {cka} |TOTAL TIME: {end - start}s")

CKA : 0.9493693709373474 |TOTAL TIME: 7.698458194732666s


### New CKA Function

In [19]:
# Feature activations for every sample.
phi = model.features(data)
# Targets as a column vector (n x 1). `targets` is 1-D, so the former `.T` changed
# nothing and only triggered PyTorch's deprecation warning for transposing a
# tensor that is not 2-D.
y = torch.unsqueeze(targets, -1)

  y = torch.unsqueeze(targets.T, -1)


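The new method potentially allows us to ignore the centering calculation. It relies on the identity $\mathrm{tr}(K_\phi K_y^\top) = \lVert \phi^\top y \rVert_F^2$ for the linear kernels $K_y = y y^\top$ and $K_\phi = \phi \phi^\top$, so the $n \times n$ kernel matrices never have to be formed.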

In [20]:
# New Method
start = time.time()
v = phi.T.matmul(y.float())
inner = v.T @ v
end = time.time()
print("NEW METHOD")
print(f"INNER = {inner.item()} | TIME = {end - start}s")

NEW METHOD
INNER = 406006720.0 | TIME = 0.0041730403900146484s


In [21]:
# Old Method
start = time.time()
K1 = y @ y.T
K1 = K1.float()
K2 = phi @ phi.T
inner = torch.trace(torch.mm(K2, torch.t(K1)))
end = time.time()
print("OLD METHOD")
print(f"INNER = {inner} | TIME = {end - start}s")

OLD METHOD
INNER = 406006848.0 | TIME = 58.009095191955566s


According to this finding, we can cut the computation time of the Frobenius product by a factor of roughly 13,900 (58.0 s vs. 0.0042 s in the run above)!

### Translating this into HSIC

In [22]:
def HSIC(K1, K2):
    """Return trace(K2 @ K1^T) divided by (rows of K2 - 1) * (rows of K1 - 1).

    The kernels are used as given; no centering is applied inside this function.
    """
    cross_trace = torch.mm(K2, K1.T).trace()
    scale = (K2.shape[0] - 1) * (K1.shape[0] - 1)
    return cross_trace / scale

In [23]:
# Time the explicit-kernel HSIC on the K1/K2 matrices built in the "Old Method" cell.
start = time.time()
test = HSIC(K1, K2)
end = time.time()
print("CLASSIC HSIC")
# NOTE(review): the label below reads "HSCI"; it is presumably meant to be "HSIC".
print(f"HSCI = {test} | TIME = {end - start}s")

CLASSIC HSIC
HSCI = 2.5315794944763184 | TIME = 51.93221092224121s


In [24]:
# Same normalisation as HSIC(K1, K2), but the numerator comes from the small
# phi^T y product instead of the explicit kernel matrices.
start = time.time()
v = torch.matmul(phi.T, y.float())
normaliser = (len(y) - 1) * (len(phi) - 1)
test = torch.matmul(v.T, v) / normaliser
end = time.time()
print("NEW HSIC")
print(f"NEW HSIC = {test.item()} | TIME = {end - start}s")

NEW HSIC
NEW HSIC = 2.531578540802002 | TIME = 0.004106044769287109s


Unsurprisingly, the old method is again far slower, by roughly the same order of magnitude as before.

### Implementing This Finding Into CKA

In [25]:
def cka(y, phi):
    """Linear centered kernel alignment (CKA) between targets `y` and features `phi`.

    Both inputs are centered over the sample axis and the alignment is computed
    from the small Gram products (columns x columns), so no samples-by-samples
    kernel matrix is ever built. For a 1-D `y` this is the quantity
    `kernel_calc(y, phi)` obtains from explicit kernels, up to floating-point
    rounding.

    Args:
        y: Tensor with one row per sample, shape (n,) or (n, k). It is cast to
            the dtype of `phi`.
        phi: Tensor with one row per sample, shape (n,) or (n, d). Non-floating
            tensors are cast to float.

    Returns:
        A 0-dim tensor holding the alignment. The value is NaN when either input
        is constant across samples, because its centered kernel is all zeros.

    Raises:
        ValueError: If `y` and `phi` do not have the same number of samples.
    """
    # Treat 1-D inputs as a single column.
    if y.dim() == 1:
        y = y.unsqueeze(-1)
    if phi.dim() == 1:
        phi = phi.unsqueeze(-1)

    n_y, n_phi = len(y), len(phi)
    if n_y != n_phi:
        raise ValueError(f"y has {n_y} samples but phi has {n_phi}")

    if not phi.is_floating_point():
        phi = phi.float()
    y = y.to(phi.dtype)

    # Centering: applying H = I - 11^T / n is the same as subtracting the column
    # means, which avoids materialising any n x n matrix.
    yc = y - y.mean(dim=0, keepdim=True)
    phic = phi - phi.mean(dim=0, keepdim=True)

    # CKA
    # Numerator HSIC(K_y, K_phi): tr(K_y K_phi) = ||phic^T yc||_F^2
    upper = torch.norm(phic.T @ yc, p='fro') ** 2
    # Denominator sqrt(HSIC(K_y, K_y) * HSIC(K_phi, K_phi)):
    # tr(K K) = ||xc^T xc||_F^2 for a linear kernel K = xc xc^T
    lower = torch.norm(yc.T @ yc, p='fro') * torch.norm(phic.T @ phic, p='fro')

    # The 1 / (n - 1)^2 factor of HSIC appears in both numerator and denominator
    # and cancels, so it is left out.
    return upper / lower

In [27]:
# Centering y
y = y.float()
n_y = len(y)
ones_y = torch.ones(n_y, 1)
yc = torch.eye(n_y) - (ones_y @ ones_y.T @ y) / n_y

In [29]:
# Centering phi
n_phi = len(phi)
ones_phi = torch.ones(n_phi, 1)
phic = torch.eye(n_phi) - (ones_phi @ ones_phi.T @ y) / n_phi

In [None]:
# Numerator HSIC(K_y, K_phi), assuming yc and phic hold the column-centered y and phi.
# The normaliser is the plain product (n_y - 1) * (n_phi - 1), matching HSIC() and the
# "NEW HSIC" cell above; torch.sqrt cannot be applied to a Python integer.
cross = phic.T @ yc
upper = cross.T @ cross / ((n_y - 1) * (n_phi - 1))

In [None]:
# Denominator sqrt(HSIC(K_y, K_y) * HSIC(K_phi, K_phi)), assuming yc and phic hold
# the column-centered y and phi. For a linear kernel K = xc xc^T,
# tr(K K) = ||xc^T xc||_F^2, so only the small Gram products are needed.
hsic_yy = torch.norm(yc.T @ yc, p='fro') ** 2 / ((n_y - 1) * (n_y - 1))
hsic_phiphi = torch.norm(phic.T @ phic, p='fro') ** 2 / ((n_phi - 1) * (n_phi - 1))
lower = torch.sqrt(hsic_yy * hsic_phiphi)