In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import os
import numpy as np

import copy
from transformers import BertForSequenceClassification, BertTokenizer, Trainer, TrainingArguments
from datasets import load_dataset
from collections import defaultdict

folder = "track_structure"
os.makedirs(folder, exist_ok=True)

In [2]:
def get_linear_mask_per_column(module:nn.Module) -> torch.Tensor:
    """
    Build a keep-mask for ``module.weight`` with one N_eff budget per column.

    Every column of |W| is L1-normalised into a distribution p; the column's
    budget is floor(1 / sum(p^2)) and its largest entries, up to that budget,
    are marked True.

    Returns a bool tensor with the same shape as ``module.weight``.
    """
    weights = module.weight.data
    n_out, n_in = weights.shape

    magnitude = torch.abs(weights)
    share = magnitude / magnitude.sum(dim=0, keepdim=True)
    budget = torch.floor(1 / (share ** 2).sum(dim=0))  # one budget per column, shape (n_in,)

    # rank[r, c] == r: position of an entry inside its column once sorted descending
    order = torch.sort(share, dim=0, descending=True)[1]
    rank = torch.arange(n_out, device=weights.device).unsqueeze(1).expand(n_out, n_in)
    keep_in_sorted_order = rank < budget

    # Send the "keep" flags back from sorted positions to the original positions
    keep = torch.zeros_like(weights, dtype=torch.bool)
    keep.scatter_(0, order, keep_in_sorted_order)
    return keep


def get_linear_mask_per_row(module:nn.Module) -> torch.Tensor:
    """
    Build a keep-mask for ``module.weight`` with one N_eff budget per row.

    Every row of |W| is L1-normalised into a distribution p; the row's budget
    is floor(1 / sum(p^2)) and its largest entries, up to that budget, are
    marked True.

    Returns a bool tensor with the same shape as ``module.weight``.
    """
    weights = module.weight.data
    n_out, n_in = weights.shape

    magnitude = torch.abs(weights)
    share = magnitude / magnitude.sum(dim=1, keepdim=True)
    budget = torch.floor(1 / (share ** 2).sum(dim=1, keepdim=True))  # one budget per row, shape (n_out, 1)

    # rank[r, c] == c: position of an entry inside its row once sorted descending
    order = torch.sort(share, dim=1, descending=True)[1]
    rank = torch.arange(n_in, device=weights.device).expand(n_out, n_in)
    keep_in_sorted_order = rank < budget

    # Send the "keep" flags back from sorted positions to the original positions
    keep = torch.zeros_like(weights, dtype=torch.bool)
    keep.scatter_(1, order, keep_in_sorted_order)
    return keep

In [25]:
def prune_model_neff_per_column_structure(model):
    """
    Column-wise N_eff pruning of every nn.Linear in a deep copy of ``model``.

    Returns ``(pruned_copy, layer_masks)``. Despite its name, ``layer_masks``
    maps each Linear module name to a clone of that layer's *pruned weight
    tensor* (zeros where entries were dropped), not to the boolean mask.
    The input model is left untouched.
    """
    pruned = copy.deepcopy(model)
    layer_masks = {}

    linear_layers = (
        (layer_name, layer)
        for layer_name, layer in pruned.named_modules()
        if isinstance(layer, nn.Linear)
    )
    for layer_name, layer in linear_layers:
        keep = get_linear_mask_per_column(layer).to(layer.weight.device)
        with torch.no_grad():
            layer.weight *= keep  # zero out everything outside the mask, in place
            layer_masks[layer_name] = layer.weight.clone()
    return pruned, layer_masks

def prune_model_neff_per_row_structure(model):
    """
    Row-wise N_eff pruning of every nn.Linear in a deep copy of ``model``.

    Returns ``(pruned_copy, layer_masks)``. Despite its name, ``layer_masks``
    maps each Linear module name to a clone of that layer's *pruned weight
    tensor* (zeros where entries were dropped), not to the boolean mask.
    The input model is left untouched.
    """
    pruned = copy.deepcopy(model)
    layer_masks = {}

    linear_layers = (
        (layer_name, layer)
        for layer_name, layer in pruned.named_modules()
        if isinstance(layer, nn.Linear)
    )
    for layer_name, layer in linear_layers:
        keep = get_linear_mask_per_row(layer).to(layer.weight.device)
        with torch.no_grad():
            layer.weight *= keep  # zero out everything outside the mask, in place
            layer_masks[layer_name] = layer.weight.clone()
    return pruned, layer_masks


# load model

In [26]:
# Fully connected classifier with optional dropout
class LinearModel(nn.Module):
    """
    MLP: a stack of Linear+ReLU(+dropout) hidden layers followed by a Linear
    output layer whose logits are passed through log-softmax.

    Args:
        input_size: number of features after flattening one sample.
        output_size: number of classes.
        hidden_size: iterable with the width of each hidden layer.
        dropout_rate: dropout probability applied after every hidden ReLU.
    """

    def __init__(self, input_size, output_size, hidden_size=[512, 512, 512], dropout_rate=0.0):
        super().__init__()
        widths = [input_size, *hidden_size]
        # Consecutive width pairs give (in_features, out_features) of each hidden layer
        self.layers = nn.ModuleList(
            [nn.Linear(fan_in, fan_out) for fan_in, fan_out in zip(widths[:-1], widths[1:])]
        )
        self.dropout = nn.Dropout(dropout_rate)
        self.output = nn.Linear(widths[-1], output_size)

    def forward(self, x):
        """Flatten each sample, run the hidden stack and return log-probabilities."""
        hidden = x.view(x.size(0), -1)
        for fc in self.layers:
            hidden = self.dropout(F.relu(fc(hidden)))  # dropout follows the activation
        return F.log_softmax(self.output(hidden), dim=1)

In [27]:
# Model configurations, keyed by model name.
# Every entry uses three hidden layers; the entries differ in layer widths,
# learning rate, epoch count and dropout.
# NOTE(review): some 'description' strings (e.g. "1 layer, 32 units") do not match
# the 'hidden_size' next to them; they are runtime data and are left unchanged here.
model_configs = {
    'Model_1_Underfit': {
        'hidden_size': [64, 32, 16],  # Three narrow hidden layers (64 -> 32 -> 16)
        'lr': 1e-4,  # Lower learning rate
        'epochs': 5,  # Fewer epochs
        'dropout': 0.0,
        'description': 'Underfitted: Too simple (1 layer, 32 units)'
    },
    'Model_2_Slight_Underfit': {
        'hidden_size': [256, 128, 64],  # Three small hidden layers
        'lr': 5e-4,
        'epochs': 8,
        'dropout': 0.0,
        'description': 'Slightly underfitted: Simple architecture'
    },
    'Model_3_Well_Trained': {
        'hidden_size': [512, 256, 128],  # Moderate width, three hidden layers
        'lr': 3e-4,
        'epochs': 15,
        'dropout': 0.2,  # Some regularization
        'description': 'Well-trained: Balanced architecture with dropout'
    },
    'Model_4_Well_Trained_Deep': {
        'hidden_size': [1024, 512, 256],  # Wider (same depth of three) but with dropout
        'lr': 3e-4,
        'epochs': 20,
        'dropout': 0.3,  # More dropout for regularization
        'description': 'Well-trained: Deeper with good regularization'
    },
    'Model_5_Overfit': {
        'hidden_size': [2048, 1024, 1024],  # Very wide, still three hidden layers
        'lr': 1e-3,  # Higher learning rate
        'epochs': 30,  # Many epochs
        'dropout': 0.0,  # No regularization
        'description': 'Overfitted: Very complex without regularization'
    },
    'Model_6_Extra_Overfit': {
        'hidden_size': [4096, 2048, 1024],  # Extremely wide, still three hidden layers
        'lr': 1e-3,
        'epochs': 50,
        'dropout': 0.0,
        'description': 'Extra Overfitted: Very complex without regularization'
    },
    'Model_7_Extra_Overfit': {
        'hidden_size': [8192, 4096, 2048],  # Widest configuration, still three hidden layers
        'lr': 1e-3,
        'epochs': 100,
        'dropout': 0.0,
        'description': 'Extra Overfitted: Very complex without regularization'
    }
}

In [28]:
# Row-prune every configured model and collect the per-layer results by model name.
# NOTE: despite the name, each row_masks[model_name] maps a Linear layer name to the
# pruned weight tensor returned by prune_model_neff_per_row_structure, not to a bool mask.
row_masks = {}

for model_name, config in model_configs.items():
    model = LinearModel(input_size=784, output_size=10, hidden_size=config['hidden_size'], dropout_rate=config['dropout'])
    # load from models folder; map_location="cpu" lets checkpoints that were saved
    # from a GPU load on machines without CUDA (the fresh model lives on the CPU)
    model.load_state_dict(torch.load(f"models/{model_name}.pth", map_location="cpu"))
    row_model, row_mask = prune_model_neff_per_row_structure(model)
    row_masks[model_name] = row_mask

In [29]:
# Pull the 'layers.1' entry of two models for side-by-side inspection.
# Each entry is the cloned, already-pruned weight tensor stored by
# prune_model_neff_per_row_structure (not a boolean mask).
test1 = row_masks['Model_1_Underfit']['layers.1']
test2 = row_masks['Model_2_Slight_Underfit']['layers.1']

In [37]:
# row_prune_and_compare.py
# ------------------------------------------------------------
# N_eff-based ROW pruning with recording, visualization, and
# model similarity utilities tailored for ROW-pruned analysis.
#
# Dependencies: torch, numpy, matplotlib
# Place your checkpoints (.pt/.pth with state_dict) under ./models
# Outputs under ./prune_outputs/
# ------------------------------------------------------------

import os
import glob
import copy
import json
from dataclasses import dataclass
from typing import Dict, List, Tuple, Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages


# ----------------------------
# Your MLP definition (as given)
# ----------------------------
class LinearModel(nn.Module):
    """
    MLP classifier: Linear+ReLU(+dropout) hidden layers, then a Linear output
    layer followed by log-softmax.

    Args:
        input_size: number of features after flattening one sample.
        output_size: number of classes.
        hidden_size: iterable with the width of each hidden layer.
        dropout_rate: dropout probability applied after every hidden ReLU.
    """

    def __init__(self, input_size, output_size, hidden_size=[512, 512, 512], dropout_rate=0.0):
        super().__init__()
        hidden_stack = nn.ModuleList()
        self.layers = hidden_stack
        self.dropout = nn.Dropout(dropout_rate)

        fan_in = input_size
        for fan_out in hidden_size:
            hidden_stack.append(nn.Linear(fan_in, fan_out))
            fan_in = fan_out  # the next layer consumes this layer's output

        self.output = nn.Linear(fan_in, output_size)

    def forward(self, x):
        """Flatten each sample, run the hidden stack and return log-probabilities."""
        batch = x.size(0)
        activations = x.view(batch, -1)
        for fc in self.layers:
            activations = F.relu(fc(activations))
            activations = self.dropout(activations)  # dropout follows the activation
        logits = self.output(activations)
        return F.log_softmax(logits, dim=1)


# ----------------------------
# Utils: State dict introspection
# ----------------------------
def strip_module_prefix(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """
    Drop a leading 'module.' (as added by DataParallel) from every key.

    The prefix is only removed when *all* keys carry it; otherwise, and for an
    empty mapping, the input object itself is returned unchanged.
    """
    prefix = 'module.'
    if state_dict and all(key.startswith(prefix) for key in state_dict):
        return {key[len(prefix):]: value for key, value in state_dict.items()}
    return state_dict


def infer_mlp_arch_from_state_dict(state_dict: Dict[str, torch.Tensor]) -> Tuple[int, int, List[int]]:
    """
    Recover (input_size, output_size, hidden_sizes) from a LinearModel state_dict.

    Hidden layers are discovered by probing 'layers.0.weight', 'layers.1.weight', ...
    until a key is missing; the output layer must be stored as 'output.weight'.

    Raises:
        ValueError: if no 'layers.i.weight' key or no 'output.weight' key exists.
    """
    sd = strip_module_prefix(state_dict)

    # Count consecutive hidden layers starting at index 0
    depth = 0
    while f"layers.{depth}.weight" in sd:
        depth += 1
    if depth == 0:
        raise ValueError("Could not find any 'layers.i.weight' in state_dict. Is this a LinearModel?")

    hidden_weights = [sd[f"layers.{idx}.weight"] for idx in range(depth)]
    input_size = hidden_weights[0].shape[1]            # in_features of the first hidden layer
    hidden_sizes = [w.shape[0] for w in hidden_weights]  # out_features of each hidden layer

    if "output.weight" not in sd:
        raise ValueError("Could not find 'output.weight' in state_dict.")
    return input_size, sd["output.weight"].shape[0], hidden_sizes


def build_model_from_state_dict(state_dict_path: str, dropout_rate: float = 0.0, device: str = "cpu") -> Tuple[str, LinearModel]:
    """
    Rebuild a LinearModel from a checkpoint file.

    The file is loaded onto the CPU; if it is a dict with a "state_dict" entry
    that entry is used, otherwise the loaded object itself. The architecture is
    inferred from the weight shapes, the weights are loaded non-strictly (a
    warning is printed for missing/unexpected keys) and the model is moved to
    ``device`` and put in eval mode.

    Returns:
        (file name without extension, model)
    """
    model_name, _ext = os.path.splitext(os.path.basename(state_dict_path))

    checkpoint = torch.load(state_dict_path, map_location="cpu")
    is_wrapped = isinstance(checkpoint, dict) and "state_dict" in checkpoint
    sd = strip_module_prefix(checkpoint["state_dict"] if is_wrapped else checkpoint)

    in_sz, out_sz, hsz = infer_mlp_arch_from_state_dict(sd)
    model = LinearModel(in_sz, out_sz, hidden_size=hsz, dropout_rate=dropout_rate)

    load_report = model.load_state_dict(sd, strict=False)
    missing, unexpected = load_report
    if missing or unexpected:
        print(f"[WARN] While loading {model_name}: missing={missing}, unexpected={unexpected}")

    model.to(device)
    model.eval()
    return model_name, model


def load_all_models(models_dir: str, device: str = "cpu") -> Dict[str, LinearModel]:
    """
    Load every '*.pt' and '*.pth' checkpoint found directly in ``models_dir``.

    Files that fail to load are reported with a "[Skip]" line and ignored.
    Returns a dict mapping the file name (without extension) to the model.
    """
    loaded = {}

    checkpoint_paths = []
    for pattern in ("*.pt", "*.pth"):
        checkpoint_paths.extend(glob.glob(os.path.join(models_dir, pattern)))

    for path in checkpoint_paths:
        try:
            name, model = build_model_from_state_dict(path, device=device)
            loaded[name] = model
            print(f"Loaded: {name}")
        except Exception as e:
            # Best effort: one broken checkpoint must not abort the whole scan
            print(f"[Skip] {path}: {e}")
    return loaded


# ----------------------------
# N_eff ROW pruning core
# ----------------------------
def neff_per_row(W: torch.Tensor) -> torch.Tensor:
    """
    Keep-count N_eff for every ROW (out_features) of W.

    Each row of |W| is L1-normalised into p and N_eff = floor(1 / sum(p^2)).
    The result is clamped to [1, in_features].

    Returns a 1D int64 tensor of length out_features.
    """
    tiny = 1e-12  # guards both divisions against all-zero rows
    magnitude = torch.abs(W)
    row_mass = magnitude.sum(dim=1, keepdim=True)
    prob = magnitude / (row_mass + tiny)              # each row sums to ~1
    concentration = (prob ** 2).sum(dim=1)            # (out_features,)
    counts = torch.floor(1.0 / (concentration + tiny)).to(torch.int64)
    return torch.clamp(counts, 1, W.shape[1])         # at least 1, at most in_features


def mask_top_by_neff_rows(W: torch.Tensor, neff_rows: torch.Tensor) -> torch.Tensor:
    """
    Boolean mask, same shape as W, that keeps the largest-|W| entries of each ROW.

    neff_rows: shape (out_features,); entry r is how many columns to keep in row r.
    """
    magnitude = W.abs()
    n_cols = magnitude.shape[1]

    order = torch.sort(magnitude, dim=1, descending=True).indices      # (out, in)
    rank = torch.arange(n_cols, device=W.device).unsqueeze(0).expand_as(order)
    quota = neff_rows.view(-1, 1).expand_as(order)
    keep_in_sorted_order = rank < quota

    # Route the flags from sorted positions back to the original column positions
    keep = torch.zeros_like(W, dtype=torch.bool)
    keep.scatter_(1, order, keep_in_sorted_order)
    return keep


@dataclass
class LayerPruneRecord:
    """
    Snapshot of one Linear layer taken during row-wise N_eff pruning.

    prune_model_neff_rows creates one record per nn.Linear it visits and stores
    the tensors detached and on the CPU.
    """
    layer_name: str           # module name as reported by named_modules()
    shape: Tuple[int, int]    # shape of the layer's weight matrix
    neff_1d: torch.Tensor     # 1D tensor of keep-counts (per row)
    mask: torch.Tensor        # bool mask same shape as weight
    weight_before: torch.Tensor   # weights before pruning
    weight_after: torch.Tensor    # weights after masking (and optional renormalization)
    kept_fraction: float      # fraction of weights kept (nonzero in mask)


def prune_linear_by_neff_row(module: nn.Linear, renormalize: bool = False
                            ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Row-prune one Linear layer using the per-row N_eff keep-count.

    The module itself is not modified; all work happens on a clone of its weight.

    Args:
        module: layer whose weight matrix is analysed.
        renormalize: when True, every pruned row is divided by its own L1 norm,
            i.e. rescaled to unit L1 norm.
            NOTE(review): this does not restore the row's pre-pruning L1 mass;
            confirm that unit-norm rows are what is wanted.

    Returns: mask, neff, weight_before, weight_after
    """
    original = module.weight.data.clone()
    keep_counts = neff_per_row(original)
    keep = mask_top_by_neff_rows(original, keep_counts)
    pruned = original * keep

    if not renormalize:
        return keep, keep_counts, original, pruned

    # |.| keeps positive and negative weights from cancelling in the row sum
    row_l1 = pruned.abs().sum(dim=1, keepdim=True).clamp_min(1e-12)
    return keep, keep_counts, original, pruned / row_l1


def prune_model_neff_rows(model: nn.Module, renormalize: bool = False
                         ) -> Tuple[nn.Module, List[LayerPruneRecord]]:
    """
    Row-prune every nn.Linear of a deep copy of ``model`` by N_eff.

    Returns (pruned_model, records) with one LayerPruneRecord per Linear layer,
    in named_modules() order. The input model is not modified.
    """
    pruned = copy.deepcopy(model)
    records: List[LayerPruneRecord] = []

    for layer_name, layer in pruned.named_modules():
        if not isinstance(layer, nn.Linear):
            continue

        keep, keep_counts, before, after = prune_linear_by_neff_row(layer, renormalize=renormalize)
        with torch.no_grad():
            layer.weight.copy_(after)

        records.append(LayerPruneRecord(
            layer_name=layer_name,
            shape=tuple(before.shape),
            neff_1d=keep_counts.detach().cpu(),
            mask=keep.detach().cpu(),
            weight_before=before.detach().cpu(),
            weight_after=after.detach().cpu(),
            kept_fraction=float(keep.float().mean().item()),  # share of True entries in the mask
        ))
    return pruned, records


# ----------------------------
# Visualization helpers
# ----------------------------
def to_display_array(W: torch.Tensor, max_side: int = 512) -> np.ndarray:
    """
    Convert a 2D tensor to a 2D numpy array of magnitudes, downscaled if very large.

    Args:
        W: 2D tensor (e.g. a Linear weight matrix or a mask cast to float).
        max_side: largest height/width allowed in the returned array.

    Returns:
        |W| as a float32 array. If neither side exceeds ``max_side`` the array has
        W's shape; otherwise each side is capped at ``max_side`` using area
        interpolation. The result is always 2D, including for 1xN and Nx1 inputs.
    """
    X = W.detach().cpu().abs().float()  # magnitude for visualization
    H, Wd = X.shape
    if max(H, Wd) <= max_side:
        return X.numpy()
    target = (min(max_side, H), min(max_side, Wd))
    X_small = F.interpolate(X[None, None], size=target, mode="area")  # (1, 1, h, w)
    # Index the two leading singleton dims explicitly: a bare squeeze() would also
    # drop a height or width of 1 and hand a 1D array to imshow.
    return X_small[0, 0].numpy()


def plot_heatmap(arr2d: np.ndarray, title: str, out_path: Optional[str] = None):
    """
    Draw ``arr2d`` as a heatmap with a colorbar and the given title.

    When ``out_path`` is given the figure is written there at 200 dpi and then
    closed; otherwise the figure is left open.
    """
    fig = plt.figure()
    ax = fig.gca()
    image = ax.imshow(arr2d, aspect='auto')
    ax.set_title(title)
    fig.colorbar(image, ax=ax)
    fig.tight_layout()
    if not out_path:
        return
    fig.savefig(out_path, dpi=200)
    plt.close(fig)


def visualize_prune_records(records: List[LayerPruneRecord], out_dir: str, pdf_name: str = "prune_report_row.pdf"):
    """
    Render every prune record as three PNG heatmaps plus one page of a PDF report.

    For each record the original |W|, the keep-mask and the pruned |W| are saved
    as '<layer>_per_row_{orig,mask,pruned}.png' in ``out_dir``; the same three
    views are laid out on a single page of ``pdf_name``.
    """
    os.makedirs(out_dir, exist_ok=True)
    pdf_path = os.path.join(out_dir, pdf_name)

    with PdfPages(pdf_path) as pdf:
        for rec in records:
            before_img = to_display_array(rec.weight_before)
            mask_img = to_display_array(rec.mask.float())
            after_img = to_display_array(rec.weight_after)

            # Stand-alone PNG per view
            base = f"{rec.layer_name.replace('.', '_')}_per_row"
            png_views = (
                (before_img, "Original |W|", "_orig.png"),
                (mask_img, "Mask (1=keep)", "_mask.png"),
                (after_img, "Pruned |W|", "_pruned.png"),
            )
            for img, caption, suffix in png_views:
                plot_heatmap(img, f"{rec.layer_name} (per_row) - {caption}", os.path.join(out_dir, base + suffix))

            # One PDF page: original and mask on top, pruned weights across the bottom
            fig = plt.figure(figsize=(9, 7))
            page_panels = (
                ((2, 2, 1), before_img, "Original |W|"),
                ((2, 2, 2), mask_img, "Mask"),
                ((2, 1, 2), after_img, f"Pruned |W| (kept {rec.kept_fraction:.2%})"),
            )
            for grid_pos, img, caption in page_panels:
                ax = fig.add_subplot(*grid_pos)
                im = ax.imshow(img, aspect='auto')
                ax.set_title(caption)
                fig.colorbar(im, ax=ax)
            fig.suptitle(f"{rec.layer_name} - per_row - shape={rec.shape}")
            fig.tight_layout(rect=[0, 0.03, 1, 0.95])  # leave room for the suptitle
            pdf.savefig(fig)
            plt.close(fig)
    print(f"[Saved] PDF report: {pdf_path}")


def save_prune_records(records: List[LayerPruneRecord], out_dir: str, model_name: str):
    """
    Persist prune records for one model.

    Writes '<model_name>_row_prune_records.pt' (all fields, tensors included, via
    torch.save) and '<model_name>_row_prune_meta.json' (tensor-free summary) into
    ``out_dir``, creating the directory if needed.
    """
    os.makedirs(out_dir, exist_ok=True)

    # Full dump, tensors included
    stored_fields = (
        "layer_name", "shape", "kept_fraction",
        "neff_1d", "mask", "weight_before", "weight_after",
    )
    pack = [{field: getattr(rec, field) for field in stored_fields} for rec in records]
    torch.save({"records": pack}, os.path.join(out_dir, f"{model_name}_row_prune_records.pt"))

    # Small JSON summary without tensors
    meta = []
    for rec in records:
        meta.append({
            "layer_name": rec.layer_name,
            "shape": rec.shape,
            "kept_fraction": rec.kept_fraction,
            "neff_1d_len": int(rec.neff_1d.numel()),
        })
    meta_path = os.path.join(out_dir, f"{model_name}_row_prune_meta.json")
    with open(meta_path, "w") as f:
        json.dump(meta, f, indent=2)


# ----------------------------
# Pattern / similarity utilities (row-pruning centric)
# ----------------------------
def canonical_reorder(W: torch.Tensor, axis: str = "columns", method: str = "pc1") -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Put the columns or rows of W into a canonical order to mitigate permutations.

    - axis: 'columns' or 'rows'
    - method: 'pc1' sorts ascending by the entries of the leading singular vector
      of |W| (sign flipped so the vector sums to a non-negative value);
      'l2' sorts descending by L2 norm.

    Returns (W_reordered, order_idx), both on the CPU.

    Raises:
        ValueError: for an unknown axis, or an unknown method on a valid axis.
    """
    mat = W.detach().cpu()

    if axis not in ("columns", "rows"):
        raise ValueError("axis must be 'columns' or 'rows'")
    if method not in ("pc1", "l2"):
        raise ValueError("method must be 'pc1' or 'l2'")
    by_columns = axis == "columns"

    if method == "pc1":
        left, _singular, right_t = torch.linalg.svd(mat.abs(), full_matrices=False)
        # Right singular vector scores the columns, left singular vector the rows
        lead = right_t[0] if by_columns else left[:, 0]
        if lead.sum() < 0:
            lead = -lead
        order = torch.argsort(lead)
    else:
        norms = torch.linalg.norm(mat, dim=0 if by_columns else 1)
        order = torch.argsort(norms, descending=True)

    reordered = mat[:, order] if by_columns else mat[order, :]
    return reordered, order


def canonicalize_both_axes(W: torch.Tensor, method: str = "pc1") -> torch.Tensor:
    """Canonically reorder the columns of W, then the rows of that result."""
    by_columns = canonical_reorder(W, axis="columns", method=method)[0]
    return canonical_reorder(by_columns, axis="rows", method=method)[0]


def resize_map_abs(W: torch.Tensor, size: Tuple[int, int] = (256, 256)) -> torch.Tensor:
    """Return |W| resampled to ``size`` with area interpolation, as a 2D CPU float tensor."""
    magnitude = W.detach().cpu().abs().float()
    as_image = magnitude[None, None]                       # (1, 1, H, W) for F.interpolate
    return F.interpolate(as_image, size=size, mode="area")[0, 0]


def normalize_01(X: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    """Min-max scale X into [0, 1]; ``eps`` keeps a constant tensor from dividing by zero."""
    lowest = X.min()
    span = X.max() - lowest + eps
    return (X - lowest) / span


def cosine_similarity_flat(A: torch.Tensor, B: torch.Tensor, eps: float = 1e-12) -> float:
    """Cosine similarity of A and B after flattening both to float vectors."""
    def _unit(t: torch.Tensor) -> torch.Tensor:
        vec = t.flatten().float()
        return vec / (vec.norm() + eps)  # eps avoids 0/0 for an all-zero input

    u, v = _unit(A), _unit(B)
    return float((u * v).sum().item())


def mse_flat(A: torch.Tensor, B: torch.Tensor) -> float:
    """Mean squared error between A and B after flattening both to float vectors."""
    flat_a, flat_b = (t.flatten().float() for t in (A, B))
    return float(F.mse_loss(flat_a, flat_b).item())


def hist_emd_1d(a: torch.Tensor, b: torch.Tensor, bins: int = 64) -> float:
    """
    Earth Mover's Distance between the value histograms of a and b.

    Both inputs are binned on the same edges, [0, max(a.max(), b.max())] (upper
    edge 1.0 when that maximum is not positive). The distance is the L1 gap
    between the two normalised CDFs divided by ``bins``.
    """
    xs = a.flatten().float()
    ys = b.flatten().float()

    peak = max(xs.max().item(), ys.max().item())
    upper = peak if peak > 0 else 1.0

    def _cdf(values: torch.Tensor) -> torch.Tensor:
        counts = torch.histc(values, bins=bins, min=0.0, max=upper)
        return torch.cumsum(counts / (counts.sum() + 1e-12), dim=0)

    gap = torch.sum(torch.abs(_cdf(xs) - _cdf(ys))).item()
    return float(gap / bins)


def spearman_rank_corr(A: torch.Tensor, B: torch.Tensor, eps: float = 1e-12) -> float:
    """
    Spearman rank correlation between flattened arrays (no SciPy).

    Args:
        A, B: tensors with the same number of elements.
        eps: added to the denominator so constant inputs give 0.0 instead of NaN.

    Returns:
        Pearson correlation of the ranks, in [-1, 1]. Tied values share the average
        of the ranks they span, so large groups of equal entries (e.g. the zeros of
        a pruned matrix) no longer receive arbitrary ordinal ranks.
    """
    def _average_ranks(values: torch.Tensor) -> torch.Tensor:
        # One tie group per distinct value; `last` is the 1-based rank of the
        # final member of each group in sorted order.
        _, group_of, group_size = torch.unique(
            values, sorted=True, return_inverse=True, return_counts=True
        )
        size = group_size.double()
        last = torch.cumsum(size, dim=0)
        mean_rank = last - (size - 1.0) / 2.0
        return mean_rank[group_of]

    ra = _average_ranks(A.flatten())
    rb = _average_ranks(B.flatten())
    ra = ra - ra.mean()
    rb = rb - rb.mean()
    # Normalising by the vector norms (not std() followed by mean()) keeps a
    # perfect monotone relation at exactly +/-1 for any sample size.
    denom = ra.norm() * rb.norm() + eps
    return float(((ra * rb).sum() / denom).item())


def greedy_row_alignment_cosine(W1: torch.Tensor, W2: torch.Tensor) -> Tuple[torch.Tensor, List[int]]:
    """
    Permutation-aware greedy matching of ROWS by absolute cosine similarity.

    The best remaining (row of W1, row of W2) pair is matched repeatedly, with
    ties resolved in row-major order, until every row is assigned.

    Args:
        W1: reference matrix; its row order is kept.
        W2: matrix whose rows are permuted. Must have the same shape as W1.

    Returns:
        (W2_aligned, perm): ``perm[i]`` is the index of the W2 row matched to row
        i of W1 and ``W2_aligned == W2[perm]`` (on the CPU).

    Raises:
        AssertionError: if the two shapes differ.
    """
    A = W1.detach().cpu()
    B = W2.detach().cpu()
    assert A.shape == B.shape
    n_rows = A.shape[0]

    def _unit_rows(X: torch.Tensor) -> torch.Tensor:
        return X / (X.norm(dim=1, keepdim=True) + 1e-12)

    # Fresh (n_rows, n_rows) tensor: matched rows/columns are struck out in place,
    # so no full copy of the matrix is made on each iteration.
    S = torch.abs(_unit_rows(A) @ _unit_rows(B).T)

    perm = [-1] * n_rows
    for _ in range(n_rows):
        # argmax returns the first maximum of the flattened (row-major) matrix
        i0, j0 = divmod(int(torch.argmax(S).item()), n_rows)
        perm[i0] = j0
        S[i0, :] = -1e9   # row i0 of W1 is taken
        S[:, j0] = -1e9   # row j0 of W2 is taken

    perm_idx = torch.tensor(perm, dtype=torch.long)
    return B[perm_idx, :], perm


def compare_weight_patterns_row_focus(
    W1: torch.Tensor,
    W2: torch.Tensor,
    canonical_axis: str = "columns",     # columns canonicalization is most important for row-pruned analysis
    canonical_method: str = "pc1",
    resize_to: Tuple[int, int] = (256, 256),
    try_row_alignment_if_same_shape: bool = True
) -> Dict[str, float]:
    """
    Similarity metrics between two weight matrices, geared to ROW-pruned layers.

    Pipeline:
      1) Canonically reorder both matrices along ``canonical_axis``
         ('columns', 'rows' or 'both').
      2) If enabled and the shapes match, greedily align the rows of W2 to W1.
      3) Resize |W| of both to ``resize_to`` and min-max normalise to [0, 1].
      4) Return {"cosine", "ncc", "mse", "emd_hist", "spearman"}.

    Raises:
        ValueError: if ``canonical_axis`` is not one of the three accepted values.
    """
    left = W1.detach().cpu()
    right = W2.detach().cpu()

    # (1) pick the canonicalisation routine, then apply it to both matrices
    if canonical_axis == "both":
        def _canon(M):
            return canonicalize_both_axes(M, method=canonical_method)
    elif canonical_axis in ("columns", "rows"):
        def _canon(M):
            return canonical_reorder(M, axis=canonical_axis, method=canonical_method)[0]
    else:
        raise ValueError("canonical_axis must be 'columns', 'rows', or 'both'.")
    left = _canon(left)
    right = _canon(right)

    # (2) optional row permutation alignment, only possible for equal shapes
    if try_row_alignment_if_same_shape and left.shape == right.shape:
        right = greedy_row_alignment_cosine(left, right)[0]

    # (3) common-size magnitude maps, scaled to [0, 1] for stable metrics
    map_a = normalize_01(resize_map_abs(left, size=resize_to))
    map_b = normalize_01(resize_map_abs(right, size=resize_to))

    # (4) metrics
    covariance = torch.mean((map_a - map_a.mean()) * (map_b - map_b.mean())).item()
    spread = (map_a.std() + 1e-12) * (map_b.std() + 1e-12)
    return {
        "cosine": cosine_similarity_flat(map_a, map_b),
        "ncc": float(covariance / spread),
        "mse": mse_flat(map_a, map_b),
        "emd_hist": hist_emd_1d(map_a, map_b, bins=64),
        "spearman": spearman_rank_corr(map_a, map_b),
    }


def visualize_two_layers_side_by_side(
    W1: torch.Tensor, W2: torch.Tensor, title1: str, title2: str, out_path: str, resize_to=(256, 256)
):
    """
    Save a two-panel figure with the resized |W| heatmaps of W1 (left) and W2 (right).

    The figure is written to ``out_path`` at 200 dpi and closed afterwards.
    """
    fig = plt.figure(figsize=(9, 4))
    panels = ((W1, title1), (W2, title2))
    for position, (weights, caption) in enumerate(panels, start=1):
        ax = fig.add_subplot(1, 2, position)
        image = ax.imshow(resize_map_abs(weights, size=resize_to).numpy(), aspect='auto')
        ax.set_title(caption)
        fig.colorbar(image, ax=ax)
    fig.tight_layout()
    fig.savefig(out_path, dpi=200)
    plt.close(fig)


def verdict_from_metrics(m, w_cos=0.45, w_ncc=0.2, w_spear=0.15, w_emd=0.1, w_mse=0.1):
    """
    Fold the similarity metrics into one heuristic score and a label.

    ``m`` must provide "cosine", "ncc", "spearman", "emd_hist" and "mse".
    Returns (score, label) with label "SIMILAR" for score >= 0.75,
    "BORDERLINE" for score >= 0.6 and "DIFFERENT" otherwise.
    """
    def _clip01(value):
        return max(0.0, min(1.0, value))

    cos = _clip01(m["cosine"])                 # negative cosine counts as 0
    ncc = (m["ncc"] + 1.0) / 2.0               # [-1,1] -> [0,1]
    spear = (m["spearman"] + 1.0) / 2.0        # [-1,1] -> [0,1]
    emd = 1.0 - _clip01(m["emd_hist"])         # lower EMD is better
    mse = 1.0 - min(m["mse"] / 0.1, 1.0)       # an MSE of 0.1 or more scores 0

    score = w_cos*cos + w_ncc*ncc + w_spear*spear + w_emd*emd + w_mse*mse

    for threshold, label in ((0.75, "SIMILAR"), (0.6, "BORDERLINE")):
        if score >= threshold:
            return float(score), label
    return float(score), "DIFFERENT"


# ----------------------------
# Mask similarity for ROW-pruned models
# ----------------------------
def jaccard_on_resized_masks(M1: torch.Tensor, M2: torch.Tensor, size=(256, 256)) -> float:
    """
    Jaccard index |A∩B| / |A∪B| of two binary masks that may differ in shape.

    Both masks are area-resized to ``size`` and re-binarised at 0.5 before the
    overlap is counted. Two empty masks give 0.0.
    """
    binary_a, binary_b = (
        resize_map_abs(mask.float(), size=size) > 0.5   # area interp yields values in [0, 1]
        for mask in (M1, M2)
    )
    overlap = (binary_a & binary_b).sum().item()
    covered = (binary_a | binary_b).sum().item()
    return float(overlap / (covered + 1e-12))


def compare_prune_masks_row_focus(recordsA, recordsB, out_dir: str, size=(256, 256)):
    """
    Compare the prune masks of two models layer by layer and write a JSON summary.

    Layers are paired by position up to the shorter record list. For every pair:
    - the columns of both masks are canonically reordered ('pc1');
    - if the shapes match, rows of the second mask are greedily aligned to the first;
    - the Jaccard index of the resized masks is recorded;
    - a side-by-side heatmap of the two column-canonicalised masks is saved as
      'mask_compare_<i>.png'.
    The summary goes to 'mask_compare_summary.json' in ``out_dir``.
    """
    os.makedirs(out_dir, exist_ok=True)
    summary = {}

    for i, (rA, rB) in enumerate(zip(recordsA, recordsB)):
        # Canonicalize columns to mitigate feature permutations
        MA_c = canonical_reorder(rA.mask.float(), axis="columns", method="pc1")[0]
        MB_c = canonical_reorder(rB.mask.float(), axis="columns", method="pc1")[0]

        # Rows can only be matched one-to-one when both masks have the same shape;
        # the greedy aligner treats the float masks like weight matrices.
        if MA_c.shape == MB_c.shape:
            MB_for_score = greedy_row_alignment_cosine(MA_c, MB_c)[0]
        else:
            MB_for_score = MB_c
        jac = jaccard_on_resized_masks(MA_c.bool(), MB_for_score.bool(), size=size)

        summary[f"layer_{i}:{rA.layer_name} vs {rB.layer_name}"] = {
            "shapeA": list(rA.shape), "shapeB": list(rB.shape),
            "kept_frac_A": rA.kept_fraction, "kept_frac_B": rB.kept_fraction,
            "jaccard_mask": jac
        }

        # Side-by-side heatmaps (row alignment is not applied to the picture)
        visualize_two_layers_side_by_side(
            MA_c, MB_c,
            f"Mask {rA.layer_name}", f"Mask {rB.layer_name}",
            os.path.join(out_dir, f"mask_compare_{i}.png"),
            resize_to=size
        )

    with open(os.path.join(out_dir, "mask_compare_summary.json"), "w") as f:
        json.dump(summary, f, indent=2)
    print(f"[Saved] Mask comparison → {os.path.join(out_dir, 'mask_compare_summary.json')}")


# ----------------------------
# End-to-end pipeline helpers (ROW focus)
# ----------------------------
def run_row_prune_and_save(model_name: str, model: nn.Module, out_root: str, renormalize: bool = False):
    """
    Row-prune ``model`` and export its records and visual report.

    Artifacts are written to '<out_root>/<model_name>/'.
    Returns (pruned_model, records).
    """
    print(f"Pruning {model_name} [per_row] ...")
    result = prune_model_neff_rows(model, renormalize=renormalize)
    records = result[1]

    model_dir = os.path.join(out_root, model_name)
    save_prune_records(records, model_dir, model_name)
    visualize_prune_records(records, model_dir, pdf_name=f"{model_name}_row_prune_report.pdf")
    return result[0], records


def pick_linear_layers(model: nn.Module) -> List[Tuple[str, nn.Linear]]:
    """Return (name, module) for every nn.Linear in ``model``, in named_modules() order."""
    return [
        (name, module)
        for name, module in model.named_modules()
        if isinstance(module, nn.Linear)
    ]


def compare_two_models_row_focus(
    name_a: str, model_a: nn.Module,
    name_b: str, model_b: nn.Module,
    out_dir: str,
    canonical_axis: str = "columns",
    canonical_method: str = "pc1"
) -> Dict[str, Dict[str, float]]:
    """
    Compare the Linear layers of two models pairwise, by position (ROW-focused).

    For each pair the weight matrices go through compare_weight_patterns_row_focus
    (canonical reorder, optional row alignment, resize to 256x256); the metrics are
    extended with "combo_score" and "verdict". A side-by-side heatmap is saved per
    pair and all results are written to '<name_a>_VS_<name_b>_row_metrics.json'
    in ``out_dir``.

    Returns the results dict keyed by 'layer_<i>:<name1> vs <name2>'.
    """
    os.makedirs(out_dir, exist_ok=True)
    layer_pairs = zip(pick_linear_layers(model_a), pick_linear_layers(model_b))

    results = {}
    for i, ((name1, la), (name2, lb)) in enumerate(layer_pairs):
        W1, W2 = la.weight.data, lb.weight.data

        metrics = compare_weight_patterns_row_focus(
            W1, W2,
            canonical_axis=canonical_axis,
            canonical_method=canonical_method,
            resize_to=(256, 256),
            try_row_alignment_if_same_shape=True
        )
        score, label = verdict_from_metrics(metrics)
        metrics.update(combo_score=round(score, 4), verdict=label)
        results[f"layer_{i}:{name1} vs {name2}"] = metrics

        png_name = f"compare_{i}_{name1.replace('.','_')}_VS_{name2.replace('.','_')}.png"
        visualize_two_layers_side_by_side(
            W1, W2, f"{name_a}:{name1}", f"{name_b}:{name2}", os.path.join(out_dir, png_name)
        )
        print(f"Compared layer {i}: {name1} vs {name2} -> {metrics}")

    json_path = os.path.join(out_dir, f"{name_a}_VS_{name_b}_row_metrics.json")
    with open(json_path, "w") as f:
        json.dump(results, f, indent=2)
    return results


def model_level_similarity_row_focus(models: Dict[str, nn.Module], out_dir: str, canonical_axis="columns"):
    """
    All-pairs model similarity (ROW-focused).

    Entry (i, j) of the NxN matrix is the cosine metric of
    compare_weight_patterns_row_focus averaged over the position-paired Linear
    layers of models i and j (names sorted alphabetically); the diagonal is set
    to 1.0. The matrix is saved as a heatmap PNG and as JSON in ``out_dir``.
    """
    os.makedirs(out_dir, exist_ok=True)
    names = sorted(models.keys())
    N = len(names)

    def layerwise_avg_cos(mA, mB):
        """Mean cosine over position-paired Linear layers; 0.0 if there is no pair."""
        cosines = [
            compare_weight_patterns_row_focus(
                layer_a.weight.data, layer_b.weight.data,
                canonical_axis=canonical_axis, canonical_method="pc1",
                resize_to=(256, 256), try_row_alignment_if_same_shape=True
            )["cosine"]
            for (_, layer_a), (_, layer_b) in zip(pick_linear_layers(mA), pick_linear_layers(mB))
        ]
        if not cosines:
            return 0.0
        return float(np.mean(cosines))

    mat = np.zeros((N, N), dtype=np.float32)
    for i, name_i in enumerate(names):
        for j, name_j in enumerate(names):
            mat[i, j] = 1.0 if i == j else layerwise_avg_cos(models[name_i], models[name_j])

    # save heatmap; the figure grows with the number of models
    side = 1+0.4*N
    plt.figure(figsize=(side, side))
    plt.imshow(mat, vmin=0, vmax=1, aspect='equal')
    plt.xticks(range(N), names, rotation=45, ha='right', fontsize=8)
    plt.yticks(range(N), names, fontsize=8)
    plt.colorbar(label="Avg. cosine over layers (row-focus)")
    plt.tight_layout()
    out_path = os.path.join(out_dir, "row_model_similarity_matrix.png")
    plt.savefig(out_path, dpi=200)
    plt.close()
    print(f"[Saved] {out_path}")

    # also dump numeric matrix
    json_path = os.path.join(out_dir, "row_model_similarity_matrix.json")
    with open(json_path, "w") as f:
        json.dump({"names": names, "matrix": mat.tolist()}, f, indent=2)

# ----------------------------
# Main: end-to-end (ROW pruning only)
# ----------------------------
if __name__ == "__main__":
    # Inference-only analysis: no autograd bookkeeping needed anywhere below.
    torch.set_grad_enabled(False)
    # Use the GPU when one is present, otherwise fall back to the CPU so the
    # pipeline also runs on machines without CUDA.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # 1) Load all models from ./models
    models_dir = "./models"
    models = load_all_models(models_dir, device=device)
    if not models:
        print("[INFO] No models found in ./models. Place your .pt/.pth state_dicts there.")

    # 2) Row-prune all models and export artifacts
    out_root = "./prune_outputs"
    os.makedirs(out_root, exist_ok=True)
    pruned_models = {}
    records_map: Dict[str, List[LayerPruneRecord]] = {}
    for name, mdl in models.items():
        pruned_mdl, recs = run_row_prune_and_save(name, mdl, out_root, renormalize=False)
        pruned_models[name] = pruned_mdl
        records_map[name] = recs

    # 3) Compare two pruned models (the first two names in sorted order;
    #    edit A_name / B_name to pick a different pair)
    names = sorted(list(pruned_models.keys()))
    if len(names) >= 2:
        A_name = names[0]
        B_name = names[1]
        print(f"\n[COMPARE ROW-PRUNED] {A_name}  VS  {B_name}")
        cmp_out_dir = os.path.join(out_root, f"compare_{A_name}_VS_{B_name}_ROWPRUNE")

        # Compare pruned weights (ROW-focused)
        _ = compare_two_models_row_focus(
            A_name, pruned_models[A_name],
            B_name, pruned_models[B_name],
            out_dir=cmp_out_dir,
            canonical_axis="columns",   # robust to input-feature permutations
            canonical_method="pc1"
        )

        # Also compare masks derived during row pruning
        compare_prune_masks_row_focus(records_map[A_name], records_map[B_name], out_dir=cmp_out_dir, size=(256, 256))
    else:
        print("[INFO] Need at least two models in ./models to run comparison.")

    # 4) All-vs-all similarity (ROW-focused) across pruned models
    if len(names) >= 2:
        model_level_similarity_row_focus(pruned_models, out_dir=os.path.join(out_root, "pairwise_row_model_similarity"), canonical_axis="columns")


Loaded: Model_1_Underfit
Loaded: Model_2_Slight_Underfit
Loaded: Model_3_Well_Trained
Loaded: Model_4_Well_Trained_Deep
Loaded: Model_5_Overfit
Loaded: Model_6_Extra_Overfit
Loaded: Model_7_Extra_Overfit
Pruning Model_1_Underfit [per_row] ...
[Saved] PDF report: ./prune_outputs\Model_1_Underfit\Model_1_Underfit_row_prune_report.pdf
Pruning Model_2_Slight_Underfit [per_row] ...
[Saved] PDF report: ./prune_outputs\Model_2_Slight_Underfit\Model_2_Slight_Underfit_row_prune_report.pdf
Pruning Model_3_Well_Trained [per_row] ...
[Saved] PDF report: ./prune_outputs\Model_3_Well_Trained\Model_3_Well_Trained_row_prune_report.pdf
Pruning Model_4_Well_Trained_Deep [per_row] ...
[Saved] PDF report: ./prune_outputs\Model_4_Well_Trained_Deep\Model_4_Well_Trained_Deep_row_prune_report.pdf
Pruning Model_5_Overfit [per_row] ...
[Saved] PDF report: ./prune_outputs\Model_5_Overfit\Model_5_Overfit_row_prune_report.pdf
Pruning Model_6_Extra_Overfit [per_row] ...
[Saved] PDF report: ./prune_outputs\Model_6_E