# CENTER KERNEL ALIGNEMENT TEST

*By: Cameron Kaminski*

*04.30.2023*

This notebook has beed used / written for the purpose of testing the computational efficiecy of my Centered Kernel Alignment (to be referred as cka) function. 

In [1]:
import torch
import resources as rs

Resources Loaded


## LOADING MODEL DATA

Ignore the following...

In [2]:
# Unpacking the params
model_path = 'ex_models.pt'
state_dict = torch.load(model_path, map_location=torch.device('cpu')) # MODEL

  device=storage.device,


In [3]:
# Unpacking the features params
import torch.nn as nn

# get the weights and biases of the quantized model (for the features layer)
f_weights_quant = state_dict['features.hidden_layer._packed_params._packed_params'][0]
f_bias_quant = state_dict['features.hidden_layer._packed_params._packed_params'][1]

# dequantize the weights and biases
f_weights_float = torch.dequantize(f_weights_quant)
f_bias_float = torch.dequantize(f_bias_quant)

In [4]:
# get the weights and biases of the quantized model (for the readout layer)
r_weights_quant = state_dict['readout._packed_params._packed_params'][0]
r_bias_quant = state_dict['readout._packed_params._packed_params'][1]

# dequantize the weights and bises
r_weights_float = torch.dequantize(r_weights_quant)
r_bias_float = torch.dequantize(r_bias_quant)

In [5]:
# Manually updating the model
model = rs.NN()
params = list(model.parameters())
params[0].data = f_weights_float
params[1].data = f_bias_float
params[2].data = r_weights_float
params[3].data = r_bias_float

## Loading MNIST dataset

We now are going to get input data.

In [6]:
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
def mnist_dataset(batch_size, train=True, values=list(range(10))):
    # Initializing MNIST data set.
    dataset = datasets.MNIST(root='dataset/', train=train, transform=transforms.ToTensor(), download=True)

    targets_list = dataset.targets.tolist()
    values_index = [i for i in range(len(dataset)) if targets_list[i] in values]

    # Creating a subset of ### MNIST targets.
    subset = torch.utils.data.Subset(dataset, values_index)
    loader = DataLoader(dataset=subset, shuffle=True)

    return loader

In [58]:
MNIST = rs.mnist_dataset(batch_size=0, train=True, values=[0,1])
data, targets = next(iter(MNIST))

In [59]:
data = torch.squeeze(data, dim=1)
data = data.view(data.size(0), -1)

In [100]:
phi = model.features(data).double()
y = torch.unsqueeze(targets.T, -1).double()

## C.K.A. Calculation

Now we can start calculating the C.K.A. calc.

The first CKA calc comes directly from my old repo.

In [101]:
def old_cka(y, phi):
    y = y.T
    
    start = time.time()
    y = vector_centering(y)
    K1c = y.T @ y

    K2 = phi @ phi.T
    
    K2c = kernel_centering(K2)
    end = time.time()

    return kernel_alignment(K1c, K2c)


def frobenius_product(K1, K2):
    # For a HUGE speed increase.
    return torch.sum(K1 * K2)
    #return torch.trace(torch.mm(K2, torch.t(K1)))


def kernel_alignment(K1, K2):
    inner = frobenius_product(K1, K2) 
    K1_norm = torch.norm(K1, p='fro')
    K2_norm = torch.norm(K2, p='fro')
    mag_norm = (K1_norm * K2_norm)
    print(f"K1 NORM = {K1_norm}")
    print(f"K2 NORM = {K2_norm}")
    return inner / mag_norm


def kernel_centering(K):
    row_means = K.mean(dim=1, keepdim=True)
    col_means = K.mean(dim=0, keepdim=True)
    total_mean = K.mean()
    
    return K - row_means - col_means + total_mean

def vector_centering(v):
    mean = torch.mean(v.double())
    centered_v = v - mean
    return centered_v

In [102]:
def new_cka(y, phi):
    yc = vector_centering(y)
    #K1c = y.T @ y
    phic = kernel_centering(phi)
    v = phic.T @ yc
    inner = (v.T @ v) / (torch.norm(yc @ yc.T) * torch.norm(phic.T @ phic))
    return inner

In [103]:
def estimate_cka(y, phi):
    yc = vector_centering(y)
    phic = kernel_centering(phi)
    v = phic.T @ yc
    inner = (v.T @ v) / (torch.norm(yc) ** 2 * torch.norm(phic.T @ phic))
    return inner

## Testing CKA

In [23]:
import time

Old CKA

In [104]:
start = time.time()
test = old_cka(y, phi)
end = time.time()
print(f"OLD CKA : {test.item()} |TOTAL TIME: {end - start}s")

K1 NORM = 3153.0095542666504
K2 NORM = 40956.472252104475
OLD CKA : 0.9029949111074519 |TOTAL TIME: 15.536466121673584s


New CKA

In [105]:
start = time.time()
test = new_cka(y, phi)
end = time.time()
print(f"NEW CKA : {test.item()} |TOTAL TIME: {end - start}s")

NEW CKA : 0.9035078157156612 |TOTAL TIME: 1.9259819984436035s


Estimated CKA

In [106]:
start = time.time()
test = estimate_cka(y, phi)
end = time.time()
print(f"EST CKA : {test.item()} |TOTAL TIME: {end - start}s")

EST CKA : 0.9035078158240055 |TOTAL TIME: 0.7852292060852051s


In [17]:
phi.shape

torch.Size([12665, 1024])

In [18]:
y.shape

torch.Size([12665, 1])

In [19]:
1024 / 12665

0.08085274378207659

In [32]:
yc = vector_centering(y)
torch.norm(yc) ** 2

tensor(3153.0725)

In [30]:
torch.norm(yc @ yc.T, p = 'fro')

tensor(3025.9299)

In [33]:
yc

tensor([[-0.5323],
        [ 0.4677],
        [ 0.4677],
        ...,
        [-0.5323],
        [ 0.4677],
        [-0.5323]])

In [36]:
import numpy as np
y2=np.array(yc)

In [41]:
np.linalg.norm(y2) ** 2

3153.0068819681183

In [39]:
np.linalg.norm(y2 @ y2.T, 'fro')

3174.2888

In [64]:
phic = kernel_centering(phi)

In [72]:
torch.norm(phic.T @ phic, 'fro')

tensor(40926.1641, grad_fn=<LinalgVectorNormBackward0>)

In [73]:
torch.norm(phic @ phic.T, 'fro')

tensor(40445.0273, grad_fn=<LinalgVectorNormBackward0>)

In [74]:
A = phic @ phic.T

In [75]:
torch.norm(A, 'fro')

tensor(40445.0273, grad_fn=<LinalgVectorNormBackward0>)

In [76]:
torch.norm(A.T, 'fro')

tensor(40445.0273, grad_fn=<LinalgVectorNormBackward0>)

In [77]:
phic.min()

tensor(-0.4213, grad_fn=<MinBackward1>)

In [110]:
yc[0]

tensor([-0.5323])

In [111]:
vector_centering(y)[0]

tensor([-0.5323], dtype=torch.float64)

In [99]:
data.type()

'torch.FloatTensor'