<a href="https://colab.research.google.com/github/davidgonmar/model-compression-exps/blob/main/weight_act_correlation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# --- RESNET ACTIVATION/WEIGHT INDEPENDENCE TEST ---
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as T
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# --- Load pretrained ResNet ---
# `pretrained=True` is deprecated since torchvision 0.13; the explicit weights
# enum selects the same ImageNet checkpoint without the deprecation warning.
# .eval() switches BatchNorm to its running statistics for inference.
model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1).eval()

# --- Hook for intermediate activations ---
# Maps a layer name to the output most recently captured by its forward hook.
activations = {}


def hook_fn(name):
    """Build a forward hook that stores a module's output under ``name``.

    The returned callable has the ``(module, input, output)`` shape expected
    by ``register_forward_hook``. Every invocation overwrites
    ``activations[name]`` with ``output.detach().cpu()``, so only the most
    recent forward pass is retained.
    """
    def _record_output(_module, _inputs, out):
        activations[name] = out.detach().cpu()

    return _record_output

# --- Pick a few layers to test ---
layer_names = ['layer1.0.conv2', 'layer2.0.conv2']
# Build the name -> module lookup once instead of once per requested layer.
modules_by_name = dict(model.named_modules())
# Keep the handles returned by register_forward_hook so the hooks can be
# detached later with handle.remove(); registering again on the same model
# object would otherwise stack duplicate hooks.
hook_handles = []
for name in layer_names:
    module = modules_by_name[name]
    hook_handles.append(module.register_forward_hook(hook_fn(name)))

# --- Load a small dataset batch ---
# torchvision's ImageNet-pretrained weights expect inputs normalized with the
# ImageNet per-channel mean/std. Without Normalize, the captured activations
# come from an input distribution the network was not trained on.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
transform = T.Compose([
    T.Resize(224),      # resize the shorter side to 224 ...
    T.CenterCrop(224),  # ... then crop to 224x224
    T.ToTensor(),
    T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])
dataset = datasets.CIFAR10(root='./data', train=False, transform=transform, download=True)
loader = DataLoader(dataset, batch_size=64, shuffle=False)

# --- Run one batch through model ---
# Only the forward hooks' side effect (filling `activations`) is needed here;
# the model's output is discarded.
images, _ = next(iter(loader))
with torch.no_grad():
    model(images)

# --- Function to compute correlation matrix ---
def corr_matrix(x, eps=1e-6):
    """Return the Pearson correlation matrix between the features of ``x``.

    ``x`` is reshaped to ``(n_samples, n_features)``: dim 0 indexes samples
    and every remaining dimension is flattened into the feature axis.
    Correlations are computed between feature columns across samples.

    Args:
        x: Tensor with at least 2 entries along dim 0.
        eps: Added to the denominator so zero-variance features produce 0
            instead of NaN/inf.

    Returns:
        ``numpy.ndarray`` of shape ``(n_features, n_features)``.

    Raises:
        ValueError: if ``x`` has fewer than 2 samples, where the ``n - 1``
            sample-covariance divisor would be zero and yield NaN silently.
    """
    n_samples = x.shape[0]
    if n_samples < 2:
        raise ValueError(
            f"corr_matrix needs at least 2 samples along dim 0, got {n_samples}"
        )
    # Detach and move to CPU so the final .numpy() also works for tensors
    # that track gradients or live on an accelerator.
    x = x.detach().cpu().reshape(n_samples, -1)  # flatten all non-sample dims
    x = x - x.mean(dim=0, keepdim=True)
    cov = (x.T @ x) / (n_samples - 1)
    std = torch.sqrt(torch.diag(cov)).unsqueeze(1)
    corr = cov / (std @ std.T + eps)
    return corr.numpy()

# --- Plot heatmap of correlation matrix ---
def plot_corr(corr, title):
    """Render ``corr`` as a 6x5-inch heatmap captioned ``title`` and show it.

    Uses the diverging ``coolwarm`` colormap centred at 0 with square cells;
    axis tick labels are suppressed.
    """
    fig, ax = plt.subplots(figsize=(6, 5))
    sns.heatmap(
        corr,
        cmap='coolwarm',
        center=0,
        square=True,
        xticklabels=False,
        yticklabels=False,
        ax=ax,
    )
    ax.set_title(title)
    fig.tight_layout()
    plt.show()

# --- Check activation independence ---
# For each hooked layer, plot the absolute channel-to-channel correlation of
# its captured activations across the batch.
for name in layer_names:
    act = activations[name]
    # 4-D activations are averaged over their last two (spatial) dims, giving
    # one value per channel per sample (assumes NCHW layout -- confirm);
    # tensors of any other rank are used unchanged.
    if act.ndim == 4:
        act_flat = act.mean(dim=[2, 3])
    else:
        act_flat = act
    corr = np.abs(corr_matrix(act_flat))
    print(f"Activation correlation matrix: {name}")
    plot_corr(corr, f"Activation Correlation: {name}")

# --- Check weight independence ---
# Plots the absolute weight correlation of the first Conv2d/Linear module
# yielded by named_modules(), then stops.
for name, module in model.named_modules():
    if isinstance(module, (nn.Conv2d, nn.Linear)):
        # .detach() is the supported grad-free view of a parameter (.data
        # bypasses autograd's safety checks). corr_matrix builds new tensors
        # rather than mutating its input, so no defensive clone is needed.
        w = module.weight.detach().cpu()
        # One row per output unit/filter, all remaining dims flattened.
        w_flat = w.reshape(w.shape[0], -1)
        # NOTE(review): corr_matrix correlates columns, so this measures the
        # correlation between input-weight positions across filters; pass
        # w_flat.T instead if filter-to-filter correlation is intended.
        corr = np.abs(corr_matrix(w_flat))
        print(f"Weight correlation matrix: {name}")
        plot_corr(corr, f"Weight Correlation: {name}")
        break  # Just show one example for brevity

