# Kernel Calculation Test
This notebook is used to test the functionality of the kernel calc methods that (if working) will be rewritten in Python script.

In [None]:
import torch
import resources as rs

### Loading Model Data

In [None]:
model_path = 'ex_models.pt'
state_dict = torch.load(model_path, map_location=torch.device('cpu')) # MODEL

### Unpacking the params

**NOTE:** 
The model was trained on the MNIST digits of 0 1

Unpacking the features params

In [None]:
import torch.nn as nn

# get the weights and biases of the quantized model (for the features layer)
f_weights_quant = state_dict['features.hidden_layer._packed_params._packed_params'][0]
f_bias_quant = state_dict['features.hidden_layer._packed_params._packed_params'][1]

# dequantize the weights and biases
f_weights_float = torch.dequantize(f_weights_quant)
f_bias_float = torch.dequantize(f_bias_quant)
# print the float values of weights and biases
print(f_weights_float)
print(f_bias_float)

Unpacking the readout params.

In [None]:
# get the weights and biases of the quantized model (for the readout layer)
r_weights_quant = state_dict['readout._packed_params._packed_params'][0]
r_bias_quant = state_dict['readout._packed_params._packed_params'][1]

# dequantize the weights and bises
r_weights_float = torch.dequantize(r_weights_quant)
r_bias_float = torch.dequantize(r_bias_quant)

print(r_weights_float)
print(r_bias_float)

### Manually updating the model

In [None]:
model = rs.NN()
model

In [None]:
params = list(model.parameters())
params

In [None]:
params[0].data = f_weights_float
params[1].data = f_bias_float
params[2].data = r_weights_float
params[3].data = r_bias_float

### Loading MNIST dataset

In [None]:
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
def mnist_dataset(batch_size, train=True, values=list(range(10))):
    # Initializing MNIST data set.
    dataset = datasets.MNIST(root='dataset/', train=train, transform=transforms.ToTensor(), download=True)

    targets_list = dataset.targets.tolist()
    values_index = [i for i in range(len(dataset)) if targets_list[i] in values]

    # Creating a subset of ### MNIST targets.
    subset = torch.utils.data.Subset(dataset, values_index)
    loader = DataLoader(dataset=subset, shuffle=True)

    return loader

In [None]:
MNIST = rs.mnist_dataset(batch_size=0, train=True, values=[0,1])
data, targets = next(iter(MNIST))

Next, in order to perform the CKA calc. we will need to reshape the data into a batch_size X features tenso (12665, 784).

In [598]:
data = torch.squeeze(data, dim=1)
data = data.view(data.size(0), -1)
data.shape

torch.Size([12665, 784])

### CKA Calc.
At this point we can now calculate the CKA for the model state.

In [None]:
model.features(data).shape

In [None]:
targets.shape

#TODO using this notebook create a python script that will do this for each of the model states (135 * 512)...

### Accuracy

In [None]:
import torch.optim as optim

In [None]:
device = torch.device('cpu')
loss = nn.MSELoss()
model.eval()
losses = rs.train(MNIST, device, model, loss, values=[0, 1], backwards=False, record_loss=True)

In [None]:
losses

In [None]:
import time

### CKA Test.
Defining Center Kernel Alignment functions.

In [None]:
def kernel_calc(y, phi):
    y = torch.t(torch.unsqueeze(y, -1))
    
    start = time.time()
    y = vector_centering(y)
    K1c = torch.matmul(torch.t(y), y)

    K2 = torch.mm(phi, torch.t(phi))
    
    K2c = kernel_centering(K2)
    end = time.time()

    return kernel_alignment(K1c, K2c)


def frobenius_product(K1, K2):
    return torch.sum(K1 * K2)
    #return torch.trace(torch.mm(K2, torch.t(K1)))


def kernel_alignment(K1, K2):
    inner = frobenius_product(K1, K2) 
    K1_norm = torch.norm(K1, p='fro')
    K2_norm = torch.norm(K2, p='fro')
    mag_norm = (K1_norm * K2_norm)
    print(f"K1 NORM = {K1_norm}")
    print(f"K2 NORM = {K2_norm}")
    return inner / mag_norm


def kernel_centering(K):
    row_means = K.mean(dim=1, keepdim=True)
    col_means = K.mean(dim=0, keepdim=True)
    total_mean = K.mean()
    
    return K - row_means - col_means + total_mean

def vector_centering(v):
    mean = torch.mean(v.float())
    centered_v = v - mean
    return centered_v

In [590]:
def cka(y, phi):
    y = vector_centering(y)
    K1c = y.T @ y
    phi = kernel_centering(phi)
    v = phi.T @ y
    inner = (v.T @ v) / (torch.norm(y @ y.T) * torch.norm(phi @ phi.T))
    return inner
    

In [596]:
def cka_estimate(y, phi):
    
    y = vector_centering(y)
    phi = kernel_centering(phi)
    
    v = phi.T @ y
    inner = (v.T @ v) / (y.T @ y * torch.norm(phic.T @ phic))
    return inner

In [597]:
start = time.time()
test = cka_estimate(y, phi)
end = time.time()
print(f"CKA : {test.item()} |TOTAL TIME: {end - start}s")

CKA : 0.9035245776176453 |TOTAL TIME: 0.38608384132385254s


In [593]:
start = time.time()
test = cka(y, phi)
end = time.time()
print(f"CKA : {test} |TOTAL TIME: {end - start}s")

CKA : tensor([[0.9511]], grad_fn=<DivBackward0>) |TOTAL TIME: 5.037832736968994s


In [563]:
start = time.time()
cka = kernel_calc(targets, model.features(data))
end = time.time()
print(f"CKA : {cka} |TOTAL TIME: {end - start}s")

K1 NORM = 3030.962158203125
K2 NORM = 40475.64453125
CKA : 0.9505146741867065 |TOTAL TIME: 7.848598003387451s


In [486]:
torch.sum(torch.trace(yc @ yc.T))

tensor(3153.0095)

In [494]:
torch.norm(yc) ** 2

tensor(3153.0725)

In [497]:
torch.norm(yc @ yc.T)

tensor(3030.9622)

In [502]:
t = torch.randint(low=0, high=1000, size=(10000, 1), dtype=torch.float)

In [503]:
torch.norm(t @ t.T)

tensor(3.2637e+09)

In [505]:
import torch

# Assume X is a tensor of size (100, 1000)
batch_size, height_width = X.size()

# Reshape X to be a matrix of size (batch_size, height_width)
X_reshaped = X.view(batch_size, height_width)

# Compute the norm of the Gram matrix in terms of the trace
norm = torch.trace(torch.mm(torch.mm(X_reshaped, X_reshaped.t()), torch.mm(X_reshaped, X_reshaped.t())))

print(norm)


tensor(1.2238e+22)


### New CKA Function

In [496]:
phi = model.features(data)
y = torch.unsqueeze(targets.T, -1)

The new methods potentially allows us to ignore the centering calculation.

In [399]:
# New Method
start = time.time()
v = phi.T.matmul(y.float())
inner = v.T @ v
end = time.time()
print("NEW METHOD")
print(f"INNER = {inner.item()} | TIME = {end - start}s")

NEW METHOD
INNER = 406006560.0 | TIME = 0.003192901611328125s


In [402]:
# Old Method
start = time.time()
K1 = y @ y.T
K1 = K1.float()
K2 = phi @ phi.T
inner = torch.trace(torch.mm(K2, torch.t(K1)))
end = time.time()
print("OLD METHOD")
print(f"INNER = {inner} | TIME = {end - start}s")

OLD METHOD
INNER = 406006848.0 | TIME = 57.78049302101135s


In [403]:
inner

tensor(4.0601e+08, grad_fn=<TraceBackward0>)

In [404]:
y = y.float()
torch.trace(torch.mm(phi @ phi.T, (y @ y.T).T))

tensor(4.0601e+08, grad_fn=<TraceBackward0>)

According to this new finding we can decrease computation time for the frobenius product by 3307.0 multiplier!!

### Translating this into HSIC

In [405]:
def HSIC(K1, K2):
    num = torch.trace(torch.mm(K2, K1.T))
    den = (len(K2) - 1) * (len(K1) - 1)
    return num / den

In [406]:
start = time.time()
test = HSIC(K1, K2)
end = time.time()
print("CLASSIC HSIC")
print(f"HSCI = {test} | TIME = {end - start}s")

CLASSIC HSIC
HSCI = 2.5315794944763184 | TIME = 51.769274950027466s


In [None]:
start = time.time()
v = phi.T.matmul(y.float())
test = v.T @ v / (((len(y) - 1) * (len(phi) - 1)))
end = time.time()
print("NEW HSIC")
print(f"NEW HSIC = {test.item()} | TIME = {end - start}s")

It's of no suprize here that the old methods is much much slower to the same scale as previous.

### Implementing This Findining Into CKA

In [None]:
def cka(y, phi):
    
    # Centering y
    n_y = len(y)
    ones_y = torch.ones(n_y, 1)
    yc = torch.eye(n_y) - (ones_y @ ones_y.T @ y) / n_y
    
    # Centering phi
    n_phi = len(phi)
    ones_phi = torch.ones(n_phi, 1)
    phic = torch.eye(n_phi) - (ones_phi @ ones_phi @ y) / n_phi
    
    # CKA 
    # Numerator HSIC(K_y, K_phi)
    upper = (phic.T @ yc).T @ (phic.T @ yc) / torch.sqrt((n_y - 1)*(n_phi -1))
    # Demonerator sqrt(HSIC(K_y, K_y) * HSIC(K_phi, K_phi))
    lower = torch.sqrt()

### Creating CKA function with this new method

In [None]:
import math

In [519]:
start = time.time()
phic = kernel_centering(phi)
yc = vector_centering(y)
v = phic.T.matmul(yc)
u = (v.T @ v)
d = torch.norm(yc @ yc.T, ) * torch.norm(phic.T @ phic)
test = u / d
end = time.time()
print("NEW CKA")
print(f"CKA = {test.item()} | TIME = {end - start}s")

NEW CKA
CKA = 0.9399687647819519 | TIME = 3.0141968727111816s


In [461]:
torch.norm(yc @ yc.T)

tensor(3030.9622)

In [470]:
torch.norm(phic @ phic.T)

tensor(40446.1602, grad_fn=<LinalgVectorNormBackward0>)

In [None]:
40926.1641

In [449]:
torch.norm(phic.T @ phic)

tensor(40926.1641, grad_fn=<LinalgVectorNormBackward0>)

### Norm Trick?

For the case of the 'y' vector, finding a faseter method for finding the kernel norn is trivial, for the sake of time I simply state that it is the dot product.

However it gets more complicated for the case of phi...

In [411]:
start = time.time()
meany =  torch.sqrt(torch.sum(torch.norm(yc.T @ yc, dim=1)**2))
meanphi = torch.sqrt(torch.sum(torch.norm(phic @ phic.T, dim=1)**2))
mnorm = meany * meanphi
end = time.time()
print(f"NORM: {mnorm} |TOTAL TIME: {end - start}s")

NORM: 129059912.0 |TOTAL TIME: 4.514939069747925s


In [412]:
# Mean of PHI
meanphi.item()

40929.58984375

In [413]:
start = time.time()
meany = yc.T @ yc
meanphi = torch.sqrt(torch.sum((phic.T @ phic)**2))
# Compute the product of the norms
mnorm = meany * meanphi
end = time.time()
print(f"NORM: {mnorm.item()} |TOTAL TIME: {end - start}s")

NORM: 129059912.0 |TOTAL TIME: 0.37685585021972656s


In [414]:
K1_NORM = 3029.86279296875
K2_NORM = 40491.578125

In [415]:
K1_NORM * K2_NORM

122683925.98952484

In [416]:
torch.norm(vector_centering(y) @ vector_centering(y).T)

tensor(3030.9622)

In [417]:
torch.norm(yc.T) ** 2

tensor(3153.0725)

### Testing Norm Calc For Vector

In [418]:
t = torch.tensor([1, 2, 3, 4])

In [419]:
t = t.unsqueeze(-1).float()

In [420]:
t.T @ t

tensor([[30.]])

In [421]:
torch.norm(t @ t.T)

tensor(30.)

In [422]:
torch.norm(t) ** 2

tensor(30.0000)

So this means that all the above a are equal??

### Testing Norm Calc For Matrix

In [443]:
T = torch.randint(low=-1000, high=1000, size=(100000, 10000), dtype=torch.float)

In [None]:
T.T @ T

In [None]:
start = time.time()
torch.norm(phic.T @ phic)
end = time.time()
print(end - start)

In [436]:

torch.norm(torch.mm(T.T, T))

tensor(1.1062e+11)

In [437]:
import torch

X = T
# Assume X is a tensor of size (100, 1000)
batch_size, height_width = X.size()

# Reshape X to be a matrix of size (batch_size, height_width)
X_reshaped = X.view(batch_size, height_width)

# Compute the norms of the rows of X
row_norms = torch.norm(X_reshaped, dim=1)

# Compute the norm of the Gram matrix
norm = torch.norm(torch.mm(X_reshaped, X_reshaped.t())) ** 2 / (batch_size ** 2) / 2

print(norm)


tensor(5.9769e+13)
