# CENTER KERNEL ALIGNEMENT TEST

*By: Cameron Kaminski*

*04.30.2023*

This notebook has beed used / written for the purpose of testing the computational efficiecy of my Centered Kernel Alignment (to be referred as cka) function. 

In [1]:
import torch
import resources as rs

Resources Loaded


## LOADING MODEL DATA

Ignore the following...

In [7]:
# Unpacking the params
model_path = 'ex_models.pt'
state_dict = torch.load(model_path, map_location=torch.device('cpu')) # MODEL

In [10]:
# Unpacking the features params
import torch.nn as nn

# get the weights and biases of the quantized model (for the features layer)
f_weights_quant = state_dict['features.hidden_layer._packed_params._packed_params'][0]
f_bias_quant = state_dict['features.hidden_layer._packed_params._packed_params'][1]

# dequantize the weights and biases
f_weights_float = torch.dequantize(f_weights_quant)
f_bias_float = torch.dequantize(f_bias_quant)

In [11]:
# get the weights and biases of the quantized model (for the readout layer)
r_weights_quant = state_dict['readout._packed_params._packed_params'][0]
r_bias_quant = state_dict['readout._packed_params._packed_params'][1]

# dequantize the weights and bises
r_weights_float = torch.dequantize(r_weights_quant)
r_bias_float = torch.dequantize(r_bias_quant)

In [None]:
# Manually updating the model
model = rs.NN()
params = list(model.parameters())
params[0].data = f_weights_float
params[1].data = f_bias_float
params[2].data = r_weights_float
params[3].data = r_bias_float

## Loading MNIST dataset

We now are going to get input data.

In [13]:
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
def mnist_dataset(batch_size, train=True, values=list(range(10))):
    # Initializing MNIST data set.
    dataset = datasets.MNIST(root='dataset/', train=train, transform=transforms.ToTensor(), download=True)

    targets_list = dataset.targets.tolist()
    values_index = [i for i in range(len(dataset)) if targets_list[i] in values]

    # Creating a subset of ### MNIST targets.
    subset = torch.utils.data.Subset(dataset, values_index)
    loader = DataLoader(dataset=subset, shuffle=True)

    return loader

In [14]:
MNIST = rs.mnist_dataset(batch_size=0, train=True, values=[0,1])
data, targets = next(iter(MNIST))

In [15]:
data = torch.squeeze(data, dim=1)
data = data.view(data.size(0), -1)

In [22]:
phi = model.features(data)
y = torch.unsqueeze(targets.T, -1)

  y = torch.unsqueeze(targets.T, -1)


## C.K.A. Calculation

Now we can start calculating the C.K.A. calc.

The first CKA calc comes directly from my old repo.

In [33]:
def old_cka(y, phi):
    y = y.T
    
    start = time.time()
    y = vector_centering(y)
    K1c = torch.matmul(torch.t(y), y)

    K2 = torch.mm(phi, torch.t(phi))
    
    K2c = kernel_centering(K2)
    end = time.time()

    return kernel_alignment(K1c, K2c)


def frobenius_product(K1, K2):
    # For a HUGE speed increase.
    # return torch.sum(K1 * K2)
    return torch.trace(torch.mm(K2, torch.t(K1)))


def kernel_alignment(K1, K2):
    inner = frobenius_product(K1, K2) 
    K1_norm = torch.norm(K1, p='fro')
    K2_norm = torch.norm(K2, p='fro')
    mag_norm = (K1_norm * K2_norm)
    print(f"K1 NORM = {K1_norm}")
    print(f"K2 NORM = {K2_norm}")
    return inner / mag_norm


def kernel_centering(K):
    row_means = K.mean(dim=1, keepdim=True)
    col_means = K.mean(dim=0, keepdim=True)
    total_mean = K.mean()
    
    return K - row_means - col_means + total_mean

def vector_centering(v):
    mean = torch.mean(v.float())
    centered_v = v - mean
    return centered_v

In [34]:
def new_cka(y, phi):
    y = vector_centering(y)
    K1c = y.T @ y
    phi = kernel_centering(phi)
    v = phi.T @ y
    inner = (v.T @ v) / (torch.norm(y @ y.T) * torch.norm(phi @ phi.T))
    return inner

In [41]:
def estimate_cka(y, phi):
    y = vector_centering(y)
    phic = kernel_centering(phi)
    v = phi.T @ y
    inner = (v.T @ v) / (y.T @ y * torch.norm(phic.T @ phic))
    return inner

## Testing CKA

In [42]:
import time

Old CKA

In [37]:
start = time.time()
test = old_cka(y, phi)
end = time.time()
print(f"OLD CKA : {test.item()} |TOTAL TIME: {end - start}s")

K1 NORM = 3031.20654296875
K2 NORM = 40470.0703125
OLD CKA : 0.9505689740180969 |TOTAL TIME: 60.00688886642456s


New CKA

In [39]:
start = time.time()
test = new_cka(y, phi)
end = time.time()
print(f"NEW CKA : {test.item()} |TOTAL TIME: {end - start}s")

OLD CKA : 0.95118647813797 |TOTAL TIME: 5.5530829429626465s


Estimated CKA

In [44]:
start = time.time()
test = estimate_cka(y, phi)
end = time.time()
print(f"EST CKA : {test.item()} |TOTAL TIME: {end - start}s")

EST CKA : 0.9036058187484741 |TOTAL TIME: 0.3930377960205078s
