In [17]:
#!/usr/bin/env python3
"""
Check whether the true output-neuron ordering minimises the Frobenius-norm
distance to a reference geometry, compared with N_RANDOM_PERMS (currently
100 000) randomly sampled permutations.
"""

import itertools
import random
from pathlib import Path

import numpy as np
import torch

# ----------------------------------------------------------------------
# Configuration
# ----------------------------------------------------------------------
MODEL_TYPE = "fully_connected_dropout"
DATASET_TYPE = "mnist"
HIDDEN_DIM = "[50,50]"    # appears verbatim in the checkpoint folder name
UNTRAINED = False         # True -> read the "-untrained" checkpoint folders

# Reference seeds define the trusted geometry; test seeds are compared to it.
REFERENCE_SEEDS = range(0, 800)
TEST_SEEDS = range(800, 810)
N_RANDOM_PERMS = 100_000  # random permutations sampled per test seed

BASEDIR = Path("saved_models")
UNTRAINED_SUFFIX = "-untrained" if UNTRAINED else ""
# Folder template; "{seed}" is left as a placeholder for str.format later.
MODEL_FMT = (
    f"{MODEL_TYPE}-{DATASET_TYPE}{UNTRAINED_SUFFIX}-hidden_dim_{HIDDEN_DIM}"
    + "/seed-{seed}"
)

# Fixed seed so the sampled permutations are reproducible between runs.
rng = np.random.default_rng(seed=42)

# ----------------------------------------------------------------------
# Helper functions
# ----------------------------------------------------------------------
def load_last_layer(seed: int, layer_key: str = "layers.2.weight") -> np.ndarray:
    """Load one checkpoint's output-layer weights with unit-norm rows.

    Parameters
    ----------
    seed : int
        Seed of the run to load; substituted into ``MODEL_FMT`` and resolved
        relative to ``BASEDIR``.
    layer_key : str, optional
        Key of the output-layer weight inside the loaded checkpoint, which is
        indexed like a mapping (presumably a state dict). Defaults to
        ``"layers.2.weight"``, the key this script has always used.

    Returns
    -------
    np.ndarray
        The weight matrix (expected shape ``[num_classes, hidden]``) with every
        row scaled to unit L2 norm. Cosine similarity between rows is then
        just their dot product.
    """
    path = BASEDIR / MODEL_FMT.format(seed=seed)
    ckpt = torch.load(path, map_location="cpu")
    w = ckpt[layer_key].detach().cpu().numpy()
    # Return a new array rather than dividing in place. Tensor.numpy() shares
    # memory with the tensor, so "/=" would overwrite the loaded weights.
    # The 1e-12 keeps an all-zero row from causing a division by zero.
    row_norms = np.linalg.norm(w, axis=1, keepdims=True)
    return w / (row_norms + 1e-12)

def gram(w: np.ndarray) -> np.ndarray:
    """Compute the matrix of pairwise inner products between the rows of *w*.

    For a ``C × H`` input this is ``C × C``. When the rows have unit norm,
    as produced by ``load_last_layer``, each entry is a cosine similarity.
    """
    transposed = np.transpose(w)
    return np.matmul(w, transposed)

def random_permutation(
    c: int, generator: "np.random.Generator | None" = None
) -> np.ndarray:
    """Return a uniformly random permutation of ``0 … c-1``.

    Parameters
    ----------
    c : int
        Length of the permutation.
    generator : numpy.random.Generator, optional
        Source of randomness. When omitted, the module-level ``rng`` is used,
        so existing calls keep consuming the same reproducible stream.
    """
    source = rng if generator is None else generator
    return source.permutation(c)

def frob(A: np.ndarray, B: np.ndarray) -> float:
    """Return the Frobenius norm of the element-wise difference ``A − B``."""
    difference = A - B
    return np.linalg.norm(difference, ord="fro")

# ----------------------------------------------------------------------
# 1) Build reference Gram matrix (mean over reference seeds)
# ----------------------------------------------------------------------
print("Building reference geometry …")
ref_grams = [gram(load_last_layer(ref_seed)) for ref_seed in REFERENCE_SEEDS]
# Element-wise mean across seeds; an average of PSD Gram matrices stays PSD.
G_ref = np.asarray(ref_grams).mean(axis=0)

# Side length of the Gram matrix = number of classes / output neurons.
C = G_ref.shape[0]
print(f"  → Averaged over {len(REFERENCE_SEEDS)} seeds, C = {C}\n")

# ----------------------------------------------------------------------
# 2) Evaluate each test seed
# ----------------------------------------------------------------------
# Random permutations are scored in batches: one fancy-indexing call plus one
# norm call per batch replaces a Python-level loop over every permutation.
# Peak extra memory is about PERM_BATCH * C * C floats.
PERM_BATCH = 10_000
DIST_TOL = 1e-12        # a permutation must beat d_true by more than this

hit_counter = 0         # test seeds where no sampled permutation did better

for seed in TEST_SEEDS:
    W_test = load_last_layer(seed)
    G_test = gram(W_test)

    # Distance with the *true* ordering
    d_true = frob(G_ref, G_test)

    # Count sampled permutations that are strictly closer to G_ref.
    better_count = 0
    for start in range(0, N_RANDOM_PERMS, PERM_BATCH):
        n_batch = min(PERM_BATCH, N_RANDOM_PERMS - start)
        # Draw one permutation at a time, in order, so the RNG stream is the
        # same as sampling inside a per-permutation loop.
        perms = np.array([random_permutation(C) for _ in range(n_batch)])
        # G_perm[k] == G_test[p][:, p] with p = perms[k]: rows and columns
        # are permuted simultaneously.
        G_perm = G_test[perms[:, :, None], perms[:, None, :]]
        d_perm = np.linalg.norm(G_perm - G_ref, ord="fro", axis=(1, 2))
        better_count += int(np.count_nonzero(d_perm < d_true - DIST_TOL))

    is_best = better_count == 0
    hit_counter += int(is_best)

    print(f"Seed {seed:3d}: "
          f"d_true = {d_true:8.4f}  |  "
          f"random perms better = {better_count:4d} / {N_RANDOM_PERMS}  "
          f"→ {'✓ best' if is_best else '✗ not best'}")

# ----------------------------------------------------------------------
# 3) Summary
# ----------------------------------------------------------------------
# A hit only means that none of the N_RANDOM_PERMS sampled permutations was
# strictly closer to G_ref (beyond a small tolerance). Ties are not excluded,
# and random sampling does not enumerate all C! permutations. So this is not
# proof that the true ordering is the unique optimum.
print(f"\nNo sampled permutation beat the original ordering in "
      f"{hit_counter} / {len(TEST_SEEDS)} test seeds "
      f"({N_RANDOM_PERMS} random permutations per seed).")


Building reference geometry …
  → Averaged over 800 seeds, C = 10

Seed 800: d_true =   0.6638  |  random perms better =    0 / 100000  → ✓ best
Seed 801: d_true =   0.7098  |  random perms better =    0 / 100000  → ✓ best
Seed 802: d_true =   0.6543  |  random perms better =    0 / 100000  → ✓ best
Seed 803: d_true =   0.7917  |  random perms better =    0 / 100000  → ✓ best
Seed 804: d_true =   0.7398  |  random perms better =    0 / 100000  → ✓ best
Seed 805: d_true =   0.6107  |  random perms better =    0 / 100000  → ✓ best
Seed 806: d_true =   0.6630  |  random perms better =    0 / 100000  → ✓ best
Seed 807: d_true =   0.6766  |  random perms better =    0 / 100000  → ✓ best
Seed 808: d_true =   0.7665  |  random perms better =    0 / 100000  → ✓ best
Seed 809: d_true =   0.8484  |  random perms better =    0 / 100000  → ✓ best

Original ordering was the unique best in 10 / 10 test seeds.
