In [None]:
import os
import pickle
import numpy as np
import pandas as pd
from pathlib import Path
from tqdm import tqdm
import sys
import torch
import torch.nn as nn

import seaborn as sns
import matplotlib.pyplot as plt

from scipy.spatial.distance import cosine
from sklearn.decomposition import PCA

# The notebook is expected to run from a direct sub-directory of the repository,
# so the repository root is taken to be the parent of the working directory.
current_dir = os.getcwd()
repo_root = os.path.abspath(os.path.join(current_dir, os.pardir))

# Expose project-local packages to `import`; the membership test keeps repeated
# executions of this cell from stacking duplicate sys.path entries.
if sys.path.count(repo_root) == 0:
    sys.path.append(repo_root)

print(f"Repository Root: {repo_root}")
from models.mit_b2 import MIT_B2   

In [None]:
class WeightSpaceAnalyzer:
    """Compare the encoder weights of several model variants by L2 distance.

    Typical workflow: register models with ``load_imagenet_base`` /
    ``load_checkpoint``, flatten their encoder parameters with
    ``compute_weight_vectors``, then build and plot the pairwise distance
    table with ``compute_l2_matrix`` / ``plot_l2_matrix``.

    Extracted vectors are cached on disk as ``<cache_dir>/<name>.pkl``. The
    cache is keyed by name only, so delete the file (or pass
    ``use_cache=False``) after a checkpoint behind a given name has changed.
    """

    def __init__(self, encoder_name='mit_b2', num_phases=3, in_channels=1,
                 cache_dir='./weight_cache'):
        """Store the model configuration and create the cache directory.

        Args:
            encoder_name: Forwarded to ``MIT_B2`` as ``encoder_name``.
            num_phases: Forwarded to ``MIT_B2`` as ``num_phases``.
            in_channels: Forwarded to ``MIT_B2`` as ``in_channels``.
            cache_dir: Directory for cached weight vectors; created
                (including parents) if it does not exist.
        """
        self.encoder_name = encoder_name
        self.num_phases = num_phases
        self.in_channels = in_channels

        self.models = {}           # name -> loaded model awaiting extraction
        self.weight_vectors = {}   # name -> flat encoder weight vector
        self._load_order = []      # names in the order they were registered

        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(exist_ok=True, parents=True)

    # -------------------------------
    # Ordering helpers
    # -------------------------------
    def _register(self, name):
        """Remember the order in which ``name`` was first requested."""
        # setdefault keeps this usable on instances created without __init__.
        order = self.__dict__.setdefault("_load_order", [])
        if name not in order:
            order.append(name)

    def _ordered_names(self):
        """Return weight-vector names in registration order.

        Names that were put into ``weight_vectors`` directly (never
        registered) follow in dict insertion order.
        """
        order = getattr(self, "_load_order", [])
        ranked = [n for n in order if n in self.weight_vectors]
        extras = [n for n in self.weight_vectors if n not in ranked]
        return ranked + extras

    # -------------------------------
    # Cache helpers
    # -------------------------------
    def _cache_path(self, model_name):
        """Return the cache file path for ``model_name``."""
        return self.cache_dir / f"{model_name}.pkl"

    def save_vec(self, name, v):
        """Pickle ``v`` into the cache file for ``name``.

        The data is written to a temporary sibling file and then moved into
        place, so an interrupted write cannot leave a truncated cache file
        under the final name.
        """
        target = self._cache_path(name)
        tmp_path = target.with_name(target.name + ".tmp")
        with open(tmp_path, 'wb') as f:
            pickle.dump(v, f)
        tmp_path.replace(target)

    def load_vec(self, name):
        """Return the cached object for ``name``, or ``None`` on a cache miss.

        A cache file that cannot be unpickled (empty or truncated) is
        reported and treated as a miss so the vector is simply recomputed.

        SECURITY: ``pickle.load`` can execute arbitrary code. Only point
        ``cache_dir`` at directories whose contents you trust.
        """
        p = self._cache_path(name)
        if not p.exists():
            return None
        try:
            with open(p, 'rb') as f:
                return pickle.load(f)
        except (pickle.UnpicklingError, EOFError) as exc:
            print(f"Ignoring unreadable cache file {p}: {exc}")
            return None

    # -------------------------------
    # Loading models
    # -------------------------------
    def load_imagenet_base(self, use_cache=True):
        """Register the ImageNet-pretrained baseline under the name ``W_img``.

        On a cache hit the cached vector is stored in ``weight_vectors`` and
        no model is built. Otherwise an ``MIT_B2`` is constructed with
        ``pretrained='imagenet'`` and kept in ``models`` until
        ``compute_weight_vectors`` runs.
        """
        name = "W_img"
        self._register(name)
        cached = self.load_vec(name) if use_cache else None
        if cached is not None:
            self.weight_vectors[name] = cached
            print(f"Loaded {name} from cache.")
            return

        print("Loading ImageNet pretrained model...")
        model = MIT_B2(
            encoder_name=self.encoder_name,
            num_phases=self.num_phases,
            in_channels=self.in_channels,
            pretrained='imagenet'
        )
        self.models[name] = model

    def load_checkpoint(self, ckpt_path, name, use_cache=True):
        """Register the model stored at ``ckpt_path`` under ``name``.

        On a cache hit the cached vector is used and ``ckpt_path`` is never
        opened. Otherwise the model is restored with
        ``MIT_B2.load_from_checkpoint``.
        """
        self._register(name)
        cached = self.load_vec(name) if use_cache else None
        if cached is not None:
            self.weight_vectors[name] = cached
            print(f"Loaded {name} from cache.")
            return

        print(f"Loading {name} from:\n  {ckpt_path}")
        # NOTE(review): presumably the PyTorch Lightning classmethod; if so it
        # may restore tensors onto the device they were saved from - confirm.
        model = MIT_B2.load_from_checkpoint(ckpt_path)
        self.models[name] = model

    # -------------------------------
    # Extract encoder weights
    # -------------------------------
    def extract_encoder(self, model):
        """Flatten every parameter whose name contains ``"encoder"``.

        Args:
            model: Object exposing ``named_parameters()`` like ``torch.nn.Module``.

        Returns:
            1-D NumPy array with the selected parameters concatenated in
            ``named_parameters()`` order.

        Raises:
            RuntimeError: If no parameter name contains ``"encoder"``.
        """
        vecs = [
            param.detach().cpu().flatten()
            for p_name, param in model.named_parameters()
            if "encoder" in p_name
        ]
        if not vecs:
            # torch.cat([]) also raises RuntimeError, but with a message that
            # does not point at the actual cause.
            raise RuntimeError(
                "No parameter with 'encoder' in its name was found on the model; "
                "cannot build an encoder weight vector."
            )
        return torch.cat(vecs).numpy()

    def compute_weight_vectors(self, use_cache=True):
        """Extract a weight vector from every pending model, then release it.

        Args:
            use_cache: When true, each extracted vector is also written to
                the on-disk cache.
        """
        # Iterate over a snapshot because entries are deleted inside the loop.
        for name, model in list(self.models.items()):
            print(f"Extracting encoder vector for {name}...")
            w = self.extract_encoder(model)
            self.weight_vectors[name] = w

            if use_cache:
                self.save_vec(name, w)

            del self.models[name]   # free memory

        print(f"✓ Extracted {len(self.weight_vectors)} weight vectors.")

    # -------------------------------
    # L2 Matrix
    # -------------------------------
    def compute_l2_matrix(self):
        """Return the symmetric pairwise L2 distance table as a DataFrame.

        Rows and columns follow the order in which the models were
        registered, independent of which ones came from the cache. Only the
        upper triangle is computed; it is mirrored and the diagonal is zero.

        Raises:
            ValueError: If the weight vectors do not all have the same shape
                (NumPy would otherwise broadcast a length-1 vector silently).
        """
        names = self._ordered_names()
        n = len(names)
        vectors = [np.asarray(self.weight_vectors[k]) for k in names]

        if len({v.shape for v in vectors}) > 1:
            detail = ", ".join(f"{k}: {v.shape}" for k, v in zip(names, vectors))
            raise ValueError(f"Weight vectors have different shapes ({detail}).")

        M = np.zeros((n, n))

        print("Computing L2 pairwise distances...")
        for i in range(n):
            for j in range(i + 1, n):
                M[i, j] = M[j, i] = np.linalg.norm(vectors[i] - vectors[j])
        return pd.DataFrame(M, index=names, columns=names)

    # -------------------------------
    # Plot Heatmap
    # -------------------------------
    def plot_l2_matrix(self, df, save_path=None):
        """Draw ``df`` as an annotated heatmap and optionally save it.

        Args:
            df: Distance table as returned by ``compute_l2_matrix``.
            save_path: If given, the figure is written there at 300 dpi
                before being shown.
        """
        fig, ax = plt.subplots(figsize=(10, 8))
        sns.heatmap(df, annot=True, fmt=".1f", cmap="RdYlBu_r",
                    square=True, cbar_kws={"label": "L2 Distance"}, ax=ax)

        ax.set_title("Pairwise L2 Distance Matrix (Encoder Weights Only)", fontsize=16)
        plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
        fig.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches='tight')

        plt.show()

In [None]:
# Weight-vector name -> checkpoint file under <repo_root>/output/checkpoints.
checkpoint_paths = {
    name: os.path.join(repo_root, "output", "checkpoints", filename)
    for name, filename in (
        ("W_syn", "Synthetic-PreTrained.ckpt"),
        ("W_real_direct_1", "Dataset1-FineTuned-ImageNet-Initialization.ckpt"),
        ("W_real_via_syn_1", "Dataset1-FineTuned-Synthetic-PreTrained.ckpt"),
        ("W_real_direct_2", "Dataset2-FineTuned-ImageNet-Initialization.ckpt"),
        ("W_real_via_syn_2", "Dataset2-FineTuned-Synthetic-PreTrained.ckpt"),
    )
}

# Display the resolved paths as the cell output.
checkpoint_paths

In [None]:
# Analyzer whose extracted vectors are cached under ./l2_cache
an = WeightSpaceAnalyzer(cache_dir="./l2_cache")

# Register the ImageNet-pretrained baseline first, then every checkpoint
an.load_imagenet_base()
for name, path in checkpoint_paths.items():
    an.load_checkpoint(ckpt_path=path, name=name)

# Flatten the encoder weights of every model that was not served from cache
an.compute_weight_vectors()

# Pairwise L2 distance table, shown as the cell output
df_l2 = an.compute_l2_matrix()
df_l2

In [None]:
# Render the distance table as a heatmap and write it to l2_matrix.png
an.plot_l2_matrix(df=df_l2, save_path="l2_matrix.png")