In [2]:
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
import torchvision.transforms.functional as TF
import torchvision
import os
import numpy as np
import torch
from scipy.stats import entropy, sem
from dataclasses import dataclass
import random

def exclude_bias_and_norm(name):
    """Dummy function to fix projection head initialization."""
    pass 

In [None]:
# =================================
# DINO (ViT and ResNet)
# =================================
class DINOHead_ViT(nn.Module):
    """ViT Head: Uses Weight Norm"""
    def __init__(self, in_dim, out_dim, hidden_dim, bottleneck_dim):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, bottleneck_dim),
        )
        self.last_layer = nn.utils.weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
        self.last_layer.weight_g.requires_grad = False

    def forward(self, x):
        x = self.mlp(x)
        x = nn.functional.normalize(x, dim=-1, p=2)
        x = self.last_layer(x)
        return x


class DINOHead_ResNet(nn.Module):
    """ResNet Head: Uses BatchNorm"""
    def __init__(self, in_dim, out_dim, hidden_dim, bottleneck_dim):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),       
            nn.BatchNorm1d(hidden_dim),          
            nn.GELU(),                           
            nn.Linear(hidden_dim, bottleneck_dim)
        )
        self.last_layer = nn.Linear(bottleneck_dim, out_dim, bias=False)

    def forward(self, x):
        x = self.mlp(x)
        x = nn.functional.normalize(x, dim=-1, p=2)
        x = self.last_layer(x)
        return x


class DINO_Local(nn.Module):
    def __init__(self, arch, ckpt_path):
        super().__init__()
        
        # Load Backbone
        if arch == 'resnet50':
            self.backbone = torch.hub.load('facebookresearch/dino:main', 'dino_resnet50')
        elif arch == 'vit_s':
            self.backbone = torch.hub.load('facebookresearch/dino:main', 'dino_vits16')
            
        # Inspect Checkpoint
        if not os.path.exists(ckpt_path):
            raise FileNotFoundError(f"Missing: {ckpt_path}")
            
        print(f"Inspecting {ckpt_path}...")
        
        def exclude_fns(name): return False
        with torch.serialization.safe_globals({"exclude_bias_and_norm": exclude_fns}):
            full_checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=False)
            
        state = full_checkpoint['teacher'] if 'teacher' in full_checkpoint else full_checkpoint
        
        # Auto-Detect Dimensions 
        # Find MLP Input/Hidden dims
        mlp_key = next(k for k in state.keys() if "head.mlp.0.weight" in k)
        in_dim = state[mlp_key].shape[1]      
        hidden_dim = state[mlp_key].shape[0]  
        
        # Find Output/Bottleneck dims
        if any("head.last_layer.weight_v" in k for k in state.keys()):
            last_key = next(k for k in state.keys() if "head.last_layer.weight_v" in k)
            has_wn = True
        else:
            last_key = next(k for k in state.keys() if "head.last_layer.weight" in k)
            has_wn = False
            
        # Shape is [Output, Input] -> [out_dim, bottleneck_dim]
        bottleneck_dim = state[last_key].shape[1] 
        out_dim = state[last_key].shape[0]        
        
        print(f" -> Detected: In={in_dim}, Hidden={hidden_dim}, Bottle={bottleneck_dim}, Out={out_dim}")

        # Init Head
        if has_wn:
            print(" -> Type: ViT Head (WeightNorm)")
            self.head = DINOHead_ViT(in_dim, out_dim, hidden_dim, bottleneck_dim)
        else:
            print(" -> Type: ResNet Head (BatchNorm)")
            self.head = DINOHead_ResNet(in_dim, out_dim, hidden_dim, bottleneck_dim)

        # Clean & Load
        clean_state = {}
        for k, v in state.items():
            k = k.replace("teacher.", "")
            clean_state[k] = v
            
        self.load_state_dict(clean_state, strict=True)
        print(f" -> Load Status: Success")

    def forward(self, x):
        if "ViT" in str(type(self.backbone)):
            z = self.backbone(x)
        else:
            z = self.backbone(x)
            if len(z.shape) == 4: z = torch.flatten(z, 1)
        return z, self.head(z)


# =================================
# VICReg (ResNet)
# =================================
class VICReg_Local(nn.Module):
    def __init__(self, ckpt_path):
        super().__init__()
        self.backbone = torchvision.models.resnet50()
        self.backbone.fc = nn.Identity()
        self.projector = nn.Sequential(
            nn.Linear(2048, 8192), nn.BatchNorm1d(8192), nn.ReLU(True),
            nn.Linear(8192, 8192), nn.BatchNorm1d(8192), nn.ReLU(True),
            nn.Linear(8192, 8192, bias=False) 
        )
        print(f"Loading {ckpt_path}...")
        def exclude_fns(name): return False
        with torch.serialization.safe_globals({"exclude_bias_and_norm": exclude_fns}):
            ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
        state_dict = ckpt['model'] if 'model' in ckpt else ckpt
        backbone_state = {k.replace('module.backbone.', ''): v for k, v in state_dict.items() if 'module.backbone.' in k}
        projector_state = {k.replace('module.projector.', ''): v for k, v in state_dict.items() if 'module.projector.' in k}
        self.backbone.load_state_dict(backbone_state)
        self.projector.load_state_dict(projector_state)
        print(" -> Load Status: Success")

    def forward(self, x):
        z = self.backbone(x)
        z = torch.flatten(z, 1)
        h = self.projector(z)
        return z, h
    
    
# =================================
# Barlow Twins (ResNet)
# =================================
class BarlowTwins_Local(nn.Module):
    def __init__(self, ckpt_path):
        super().__init__()
        
        # Backbone (ResNet50)
        self.backbone = torchvision.models.resnet50()
        self.backbone.fc = nn.Identity()
        
        # Projector (3 Layers: 8192 -> 8192 -> 8192)
        sizes = [2048, 8192, 8192, 8192]
        layers = []
        for i in range(len(sizes) - 2):
            layers.append(nn.Linear(sizes[i], sizes[i + 1], bias=False))
            layers.append(nn.BatchNorm1d(sizes[i + 1]))
            layers.append(nn.ReLU(inplace=True))
        layers.append(nn.Linear(sizes[-2], sizes[-1], bias=False))
        self.projector = nn.Sequential(*layers)
        
        # Load Weights
        print(f"Loading {ckpt_path}...")
        if not os.path.exists(ckpt_path):
            raise FileNotFoundError(f"Missing: {ckpt_path}")

        def exclude_fns(name): return False
        with torch.serialization.safe_globals({"exclude_bias_and_norm": exclude_fns}):
            ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)

        # Handle 'model' key if present (standard in Barlow checkpoints)
        state_dict = ckpt['model'] if 'model' in ckpt else ckpt
        
        # Clean keys (Remove 'module.')
        clean_state = {}
        for k, v in state_dict.items():
            k = k.replace("module.", "")
            clean_state[k] = v
            
        # Separate into backbone and projector
        # Official keys are usually 'backbone.xxx' and 'projector.xxx'
        backbone_keys = {k.replace("backbone.", ""): v for k, v in clean_state.items() if "backbone." in k}
        projector_keys = {k.replace("projector.", ""): v for k, v in clean_state.items() if "projector." in k}
        
        msg_b = self.backbone.load_state_dict(backbone_keys, strict=False)
        msg_p = self.projector.load_state_dict(projector_keys, strict=True)
        print(f" -> Load Status: Backbone={msg_b}, Projector={msg_p}")

    def forward(self, x):
        z = self.backbone(x)
        h = self.projector(z)
        return z, h

In [None]:
# We assume the checkpoint files are in the 'pretrained/' directory 
# and have been renamed to the following filenames.
checkpoints = {
    "DINO ViT-S": {
        "type": "dino",
        "arch": "vit_s",
        "path": "pretrained/dino_deitsmall16_fullckpt.pth"
    },
    "DINO ResNet50": {
        "type": "dino",
        "arch": "resnet50",
        "path": "pretrained/dino_resnet50_fullckpt.pth" 
    },
    "VICReg ResNet50": {
        "type": "vicreg",
        "arch": "resnet50",
        "path": "pretrained/vicreg_resnet50_fullckpt.pth"
    },
    "Barlow Twins ResNet50": {
        "type": "barlow",
        "arch": "resnet50",
        "path": "pretrained/barlowtwins_resnet50_fullckpt.pth" 
    }
}

print(f"{'Model':<25} | {'Status':<30}")
print("-" * 60)

loaded_models = {} 

for name, cfg in checkpoints.items():
    if not os.path.exists(cfg['path']):
        print(f"{name:<25} | File Not Found")
        continue
        
    try:
        # Instantiate
        if cfg['type'] == 'dino':
            model = DINO_Local(cfg['arch'], cfg['path'])
        elif cfg['type'] == 'vicreg':
            model = VICReg_Local(cfg['path'])
        elif cfg['type'] == 'barlow':
            model = BarlowTwins_Local(cfg['path'])
            
        model.cuda().eval()
        
        # Test
        z, h = model(torch.randn(2, 3, 224, 224).cuda())
        print(f"{name:<25} | Ready (z={z.shape[1]}, h={h.shape[1]})")
        # Save
        loaded_models[name] = model 
        
    except Exception as e:
        print(f"{name:<25} | Error: {str(e)[:100]}...")

In [None]:
# ============================================================================
# Configuration & Reproducibility
# ============================================================================
@dataclass
class AnalysisConfig:
    batch_size: int = 64
    num_samples_orbit: int = 50   
    steps: int = 12               
    seed: int = 0


def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    os.environ['PYTHONHASHSEED'] = str(seed)


config = AnalysisConfig()
set_seed(config.seed)


# ============================================================================
# Data Loading (CIFAR-10)
# ============================================================================
transform = T.Compose([
    T.Resize(224), T.ToTensor(),
    T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])

dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
loader = DataLoader(dataset, batch_size=config.batch_size, shuffle=True)


# ============================================================================
# Geometric Metrics
# ============================================================================

def calculate_effective_rank(traj):
    """Computes Effective Rank via Shannon Entropy of Singular Values."""
    traj = traj - np.mean(traj, axis=0)
    _, S, _ = np.linalg.svd(traj, full_matrices=False)
    eig_vals = S**2
    total_energy = np.sum(eig_vals) + 1e-10
    probs = eig_vals / total_energy
    ent = entropy(probs)
    return np.exp(ent)

def calculate_curvature(traj):
    """
    Computes Local Curvature using Central Finite Differences.
    Formula: k_t = || z_{t+1} - 2z_t + z_{t-1} ||
    """
    curvatures = []
    
    # Iterate from t=1 to T-1 (skipping endpoints)
    for t in range(1, len(traj) - 1):
        z_prev = traj[t-1]
        z_curr = traj[t]
        z_next = traj[t+1]
        
        # Central Second Derivative (Acceleration Vector)
        acc_vec = z_next - 2*z_curr + z_prev
        
        # Curvature = Magnitude (Norm) of this vector
        k = np.linalg.norm(acc_vec)
        curvatures.append(k)
        
    return np.mean(curvatures)


# ============================================================================
# Main Analysis Loop
# ============================================================================

def get_orbit_generator(orbit_type, steps):
    if orbit_type == 'rotation':
        values = np.linspace(0, 45, steps)
        func = lambda img, v: TF.rotate(img.unsqueeze(0), float(v))
    elif orbit_type == 'hue':
        values = np.linspace(-0.4, 0.4, steps)
        func = lambda img, v: TF.adjust_hue(img.unsqueeze(0), float(v))
    elif orbit_type == 'saturation':
        values = np.linspace(0.0, 2.0, steps)
        func = lambda img, v: TF.adjust_saturation(img.unsqueeze(0), float(v))
    elif orbit_type == 'blur':
        values = np.linspace(0.1, 3.0, steps)
        func = lambda img, v: T.GaussianBlur(kernel_size=23, sigma=float(v))(img.unsqueeze(0))
    else:
        raise ValueError(f"Unknown orbit: {orbit_type}")
    return values, func


def analyze_full_spectrum(model, dataloader, config, orbit_type):
    model.eval()
    values, transform_func = get_orbit_generator(orbit_type, config.steps)
    
    stats = {
        'var_backbone': [], 'var_head': [], 'ratio_collapse': [],
        'rank_backbone': [], 'rank_head': [],
        'curv_backbone': [], 'curv_head': [], 'ratio_curv': [],
        'cos_backbone': [], 'cos_head': [], 'gain': []
    }
    
    with torch.no_grad():
        for i, (images, _) in enumerate(dataloader):
            if i * config.batch_size >= config.num_samples_orbit: break
            images = images.cuda()
            
            for img in images:
                # Anchors
                anchor_b, anchor_h = model(img.unsqueeze(0))
                anchor_b = F.normalize(anchor_b.flatten(1), dim=1)
                anchor_h = F.normalize(anchor_h.flatten(1), dim=1)
                
                # Trajectories
                traj_b_raw, traj_h_raw = [], []
                traj_cos_b, traj_cos_h = [], []
                
                for val in values:
                    aug = transform_func(img, val)
                    backbone, head = model(aug)
                    
                    # Raw
                    traj_b_raw.append(backbone.cpu().numpy().flatten())
                    traj_h_raw.append(head.cpu().numpy().flatten())
                    
                    # Cosine
                    b_n = F.normalize(backbone.flatten(1), dim=1)
                    h_n = F.normalize(head.flatten(1), dim=1)
                    traj_cos_b.append(torch.mm(anchor_b, b_n.T).item())
                    traj_cos_h.append(torch.mm(anchor_h, h_n.T).item())

                # Compute metrics
                tb, th = np.stack(traj_b_raw), np.stack(traj_h_raw)
                
                # Variance
                vb = np.mean(np.var(tb, axis=0))
                vh = np.mean(np.var(th, axis=0))
                
                # Rank
                rb = calculate_effective_rank(tb)
                rh = calculate_effective_rank(th)
                
                # Curvature
                cb = calculate_curvature(tb)
                ch = calculate_curvature(th)
                
                # Alignment
                mb, mh = np.mean(traj_cos_b), np.mean(traj_cos_h)
                
                # Store
                stats['var_backbone'].append(vb)
                stats['var_head'].append(vh)
                stats['ratio_collapse'].append(vb / (vh + 1e-12))
                
                stats['rank_backbone'].append(rb)
                stats['rank_head'].append(rh)
                
                stats['curv_backbone'].append(cb)
                stats['curv_head'].append(ch)
                stats['ratio_curv'].append(ch / (cb + 1e-12))
                
                stats['cos_backbone'].append(mb)
                stats['cos_head'].append(mh)
                stats['gain'].append(mh - mb)

    # Aggregate
    results = {k: np.mean(v) for k, v in stats.items()}
    results['sem_ratio_collapse'] = sem(stats['ratio_collapse'])
    return results


# ============================================================================
# Execution 
# ============================================================================

print(f"{'='*160}")
print(f"Summary Table (Seed={config.seed})")
print(f"{'='*160}")

print(f"{'Model':<22} | {'Orbit':<10} || {'DIMENSIONALITY (Rank)':<19} || {'VARIANCE':<19} | {'COLLAPSE':<24} || {'CURVATURE':<28} || {'ALIGNMENT':<27}")
print(f"{'':<22} | {'':<10} || {'Backbone':<9} {'Head':<9} || {'Backbone':<9} {'Head':<9} | {'Ratio (95% CI)':<24} || {'Backbone':<9} {'Head':<9} {'Ratio':<8} || {'Cos(B)':<9} {'Cos(H)':<9} {'Gain':<8}")
print("-" * 160)

orbit_types = ['rotation', 'hue', 'saturation', 'blur']

if 'loaded_models' not in locals():
    print("Error: Run model loading first")
else:
    for model_name, model in loaded_models.items():
        print(f"--- {model_name} ---")
        for orbit in orbit_types:
            try:
                res = analyze_full_spectrum(model, loader, config, orbit)
                
                ci = res['sem_ratio_collapse'] * 1.96
                coll_str = f"{res['ratio_collapse']:.4f}±{ci:.4f}"
                
                print(f"{model_name:<22} | {orbit:<10} || "
                      f"{res['rank_backbone']:<9.6f} {res['rank_head']:<9.6f} || "
                      f"{res['var_backbone']:<9.6f} {res['var_head']:<9.6f} | "
                      f"{coll_str:<24} || "
                      f"{res['curv_backbone']:<9.6f} {res['curv_head']:<9.6f} {res['ratio_curv']:<8.4f} || "
                      f"{res['cos_backbone']:<9.6f} {res['cos_head']:<9.6f} {res['gain']:<+8.6f}")
                      
            except Exception as e:
                print(f"Error {orbit}: {e}")
        print("-" * 160)

Summary Table (Seed=0)
Model                  | Orbit      || DIMENSIONALITY (Rank) || VARIANCE            | COLLAPSE                 || CURVATURE                    || ALIGNMENT                  
                       |            || Backbone  Head      || Backbone  Head      | Ratio (95% CI)           || Backbone  Head      Ratio    || Cos(B)    Cos(H)    Gain    
----------------------------------------------------------------------------------------------------------------------------------------------------------------
--- DINO ViT-S ---
DINO ViT-S             | rotation   || 3.648814  2.722318  || 5.121819  0.004150  | 1284.3635±61.9525        || 39.095955 13.976097 0.3581   || 0.418339  0.756372  +0.338033
DINO ViT-S             | hue        || 6.496621  3.072013  || 2.247406  0.001261  | 2498.6184±321.2806       || 43.918747 12.802624 0.2903   || 0.681305  0.911384  +0.230080
DINO ViT-S             | saturation || 2.454316  2.363897  || 2.001809  0.000882  | 2898.0669±435.6995