# Imports and setup

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.nn.utils.parametrizations as parametrizations
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from scipy.special import digamma

# Only necessary to work with LLMs
from collections import defaultdict
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from torch.utils.data import DataLoader

# Seed both RNGs used in this notebook (torch first, then numpy) for reproducibility.
for _seed_fn in (torch.manual_seed, np.random.seed):
    _seed_fn(42)

# Prefer the GPU when one is visible; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Running on device: {device}")

# Toy Model
## Data

In [None]:
# Configuration
# Configuration of the toy superposition experiment.
CONFIG = dict(
    n_groups=2,             # Number of feature groups
    features_per_group=20,
    d=12,                   # Hidden dimension (bottleneck)
    sparsity=0.1,           # Probability a group is completely inactive
    batch_size=128,
    toy_lr=1e-3,
    toy_steps=5000,
    ndm_lr=1e-3,
    ndm_steps=5000,         # Usually converges faster than the paper's 60k for simple cases
)
# Derived sizes: total number of features, and subspace width (d / n_groups).
CONFIG['z'] = CONFIG['n_groups'] * CONFIG['features_per_group']   # 40
CONFIG['unit_size'] = CONFIG['d'] // CONFIG['n_groups']           # 6

In [None]:
def generate_batch(batch_size, config):
    """
    Generate a batch of sparse features divided into groups.

    Group ``g`` owns the columns
    ``[g * features_per_group, (g + 1) * features_per_group)``. Independently
    for every sample and every group:
      - with probability ``config['sparsity']`` the whole group is zero;
      - otherwise exactly one feature of the group, chosen uniformly, is set to
        a value drawn uniformly from [0, 1).

    Args:
        batch_size: Number of samples (rows) to generate.
        config: Dict providing 'z', 'n_groups', 'features_per_group' and
            'sparsity'.

    Returns:
        Tensor of shape (batch_size, config['z']) moved to the global ``device``.
    """
    n_groups = config['n_groups']
    per_group = config['features_per_group']

    # One "is this group active?" draw per (sample, group).
    active = torch.rand(batch_size, n_groups) > config['sparsity']
    # Which feature inside each group fires, and with what magnitude
    # (inactive groups get magnitude 0).
    local_idx = torch.randint(0, per_group, (batch_size, n_groups))
    values = torch.rand(batch_size, n_groups) * active
    # Convert group-local indices into absolute column indices.
    feat_idx = torch.arange(n_groups) * per_group + local_idx

    features = torch.zeros((batch_size, config['z']))
    # Groups own disjoint columns, so no two scattered writes in a row collide.
    features.scatter_(1, feat_idx, values)
    return features.to(device)

# Show one small batch: each row should have at most one active feature per group.
sample_batch = generate_batch(32, CONFIG)
fig, ax = plt.subplots(figsize=(10, 4))
sns.heatmap(sample_batch.cpu().numpy(), cmap="viridis", cbar=False, ax=ax)
ax.set_title("Sample Feature Activations (Rows=Samples, Cols=Features)")
ax.set_xlabel("Feature Index")
ax.set_ylabel("Sample Index")
plt.show()

## The actual Toy Model
$x' = \mathrm{ReLU}(W^\top W x + b)$

In [None]:
class ToyModel(nn.Module):
    """Tied-weight autoencoder ``x' = ReLU(W^T W x + b)``.

    ``W`` is the weight of a bias-free ``nn.Linear(z, d)`` and therefore has
    shape (d, z). The decoder reuses the same matrix, transposed.
    """

    def __init__(self, config):
        super().__init__()
        z, d = config['z'], config['d']
        # Encoder weight W, shape (d, z); also serves as the (tied) decoder.
        self.encoder = nn.Linear(z, d, bias=False)
        # Per-feature output bias b, shape (z,).
        self.bias = nn.Parameter(torch.zeros(z))
        self.relu = nn.ReLU()

    def forward(self, x):
        """Encode ``x`` of shape (B, z) and reconstruct it.

        Returns:
            Tuple ``(h, x_recon)`` with ``h = x @ W^T`` of shape (B, d) and
            ``x_recon = ReLU(h @ W + b)`` of shape (B, z).
        """
        hidden = self.encoder(x)                       # (B, d)
        decoded = hidden @ self.encoder.weight          # (B, z), tied weights
        return hidden, self.relu(decoded + self.bias)

# Build the toy model, its optimizer and the reconstruction loss.
toy_model = ToyModel(CONFIG).to(device)
optimizer = optim.Adam(toy_model.parameters(), lr=CONFIG['toy_lr'])
criterion = nn.MSELoss()

# Train on freshly sampled batches, logging (and recording) the loss every 1000 steps.
losses = []
print("Training Toy Model...")
for step in range(CONFIG['toy_steps']):
    batch = generate_batch(CONFIG['batch_size'], CONFIG)
    h, x_recon = toy_model(batch)
    loss = criterion(x_recon, batch)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if step % 1000 != 0:
        continue
    print(f"Step {step}/{CONFIG['toy_steps']}: Loss {loss.item():.6f}")
    losses.append(loss.item())

print("Toy Model Trained.")

## Visualization of superposition / interference

In [None]:
def plot_matrices(model, title="Model Weights Analysis", show_abs=False):
    """
    Plot the feature-interference matrix ``W^T W`` of a toy model.

    Entry (i, j) is the dot product between the hidden-space embeddings of
    features i and j, so off-diagonal mass indicates interference.

    Args:
        model: Model exposing ``encoder.weight`` of shape (d, z), e.g. ``ToyModel``.
        title: Figure-level title shown above the heatmap.
        show_abs: If True, plot ``|W^T W|`` instead of the signed matrix, which
            makes interference magnitudes easier to compare.
    """
    W = model.encoder.weight.detach().cpu()  # (d, z)

    interference = (W.t() @ W).numpy()  # (z, z)
    if show_abs:
        interference = np.abs(interference)

    # A single axes: only one matrix is drawn.
    fig, ax = plt.subplots(figsize=(6, 5))
    sns.heatmap(interference, center=0, cmap="coolwarm", square=True, ax=ax)
    matrix_label = "$|W^T W|$" if show_abs else "$W^T W$"
    ax.set_title(f"{matrix_label} (Feature Interference)\nTarget: Block Diagonal")
    # The caller-provided title labels the whole figure; the axes title names the matrix.
    fig.suptitle(title)
    fig.tight_layout()
    plt.show()


plot_matrices(toy_model, show_abs=True)

This figure shows that features within the same group interfere with each other, while there is almost no interference between groups.

## NDM - Neighbor Distance Minimization

In [None]:
class NDM(nn.Module):
    """Learnable orthogonal rotation ``R`` for Neighbor Distance Minimization.

    ``ndm_loss`` splits the rotated activations into contiguous subspaces of
    width ``unit_size``; ``n_subspaces = d // unit_size``.
    """

    def __init__(self, config):
        super().__init__()
        self.d = config['d']
        self.unit_size = config['unit_size']  # Dimension of subspaces
        self.n_subspaces = self.d // self.unit_size

        # Learnable orthogonal matrix R [cite: 109]
        self.rotation = nn.Linear(self.d, self.d, bias=False)
        # Initialize as identity BEFORE registering the parametrization. Once it
        # is registered, `rotation.weight` is recomputed on every access, so an
        # in-place copy into it would only modify a temporary tensor.
        nn.init.eye_(self.rotation.weight)
        # Keep R orthogonal throughout training.
        parametrizations.orthogonal(self.rotation, 'weight')

    def get_rotated_activations(self, h):
        """Return ``h @ R^T`` (i.e. ``R h`` applied to every row of ``h``)."""
        return self.rotation(h)

def ndm_loss(activations, unit_size):
    """
    Neighbor Distance Minimization loss (Eq. 5 in the paper).

    The columns of ``activations`` are split into ``d // unit_size`` contiguous
    subspaces of width ``unit_size``. Trailing columns that do not fill a whole
    subspace are ignored. In every subspace, each sample's Euclidean distance
    to its nearest *other* sample in the batch is computed. The loss is the
    batch mean of these distances, averaged over subspaces.

    Args:
        activations: Tensor of shape (batch, d).
        unit_size: Width of each subspace.

    Returns:
        Scalar tensor.

    Raises:
        ValueError: If ``unit_size`` is not in ``[1, d]``.
    """
    batch_size, d = activations.shape
    n_subspaces = d // unit_size if unit_size > 0 else 0
    if n_subspaces == 0:
        raise ValueError(f"unit_size={unit_size} must be between 1 and d={d}")

    # Large offset on the diagonal so a point is never its own nearest neighbor.
    # It does not depend on the subspace, so it is built once.
    self_mask = torch.eye(batch_size, device=activations.device, dtype=activations.dtype) * 1e9

    total_loss = 0
    for s in range(n_subspaces):
        subspace_acts = activations[:, s * unit_size:(s + 1) * unit_size]  # (B, unit_size)
        dists = torch.cdist(subspace_acts, subspace_acts, p=2) + self_mask  # (B, B)
        # Distance to nearest neighbor for each point [cite: 116]
        min_dists, _ = torch.min(dists, dim=1)
        total_loss += torch.mean(min_dists)

    return total_loss / n_subspaces

## Train NDM

In [None]:
ndms = {}
# Sweep NDM learning rates; a wider sweep is [0.0001, 0.0005, 0.001, 0.005, 0.01].
CONFIG['ndm_steps'] = 1_000
for lr in [0.001]:
    print(f"lr={lr}")
    ndm_model = NDM(CONFIG).to(device)
    CONFIG['ndm_lr'] = lr
    ndm_optimizer = optim.Adam(ndm_model.parameters(), lr=CONFIG['ndm_lr'])
    losses = []

    print("Training NDM (Finding subspaces)...")

    # Hidden states are produced on the fly: every step samples a new batch
    # and passes it through the frozen toy model.
    for step in range(CONFIG['ndm_steps']):
        with torch.no_grad():
            batch = generate_batch(CONFIG['batch_size'], CONFIG)
            h_frozen, _ = toy_model(batch)

        # Rotate, then measure how tightly each subspace clusters.
        h_rotated = ndm_model.get_rotated_activations(h_frozen)
        loss = ndm_loss(h_rotated, CONFIG['unit_size'])

        ndm_optimizer.zero_grad()
        loss.backward()
        ndm_optimizer.step()

        if step % 500 == 0:
            print(f"Step {step}/{CONFIG['ndm_steps']}: NDM Loss {loss.item():.6f}")
            losses.append(loss.item())

    ndms[lr] = {"model": ndm_model, "losses": losses}
print("NDM Training Complete.")

In my runs, a learning rate of 0.001 with about 1000 steps works well.


## Evaluate NDM

In [None]:
def visualize_results(toy_model, ndm_model, config):
    """
    Compare the toy model's weights ``W`` with the NDM-rotated weights ``RW``.

    Since ``h = W x`` and the NDM computes ``h_rot = R h``, the effective
    feature-to-rotated-dimension map is ``R @ W``, of shape (d, z). On the
    right-hand plot, dashed lines mark every internal subspace boundary (every
    ``config['unit_size']`` rows) and every feature-group boundary (every
    ``config['features_per_group']`` columns). The figure is saved to
    ``w_images.png`` and shown.

    Args:
        toy_model: Model exposing ``encoder.weight`` of shape (d, z).
        ndm_model: Model exposing ``rotation.weight`` of shape (d, d).
        config: Dict with 'unit_size' and 'features_per_group'.
    """
    W = toy_model.encoder.weight.detach()   # (d, z)
    R = ndm_model.rotation.weight.detach()  # (d, d)
    RW = torch.matmul(R, W).cpu()            # [cite: 145]
    d, z = RW.shape

    plt.figure(figsize=(14, 6))

    # Plot 1: the raw weights W
    plt.subplot(1, 2, 1)
    sns.heatmap(W.cpu().numpy(), center=0, cmap="coolwarm")
    plt.title("Original W (Entangled)\nRows=Hidden Dims, Cols=Features")
    plt.xlabel("Feature Index")
    plt.ylabel("Hidden Dimension")

    # Plot 2: the rotated weights RW
    plt.subplot(1, 2, 2)
    sns.heatmap(RW.numpy(), center=0, cmap="coolwarm")
    # Boundaries for any number of subspaces / groups (not just two of each).
    for y in range(config['unit_size'], d, config['unit_size']):
        plt.axhline(y=y, color='black', linewidth=2, linestyle='--')
    for x in range(config['features_per_group'], z, config['features_per_group']):
        plt.axvline(x=x, color='black', linewidth=2, linestyle='--')

    # NDM is unsupervised, so which subspace captures which group is arbitrary.
    plt.title("Rotated RW (Disentangled by NDM)\n"
              "Each subspace (row band) should capture a single feature group")
    plt.xlabel("Feature Index")
    plt.ylabel("Rotated Hidden Dimension")

    plt.tight_layout()
    plt.savefig("w_images.png")
    plt.show()


visualize_results(toy_model, ndms[next(iter(ndms))]["model"], CONFIG)

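The right-hand plot should show a block structure in which each subspace (band of rows) captures a single feature group. The blocks need not appear in the top-left and bottom-right positions. Because the method is unsupervised, which subspace ends up holding which group is arbitrary.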

# Language models
## Load language model

In [None]:
model_name = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# GPT-2 has no padding token; reuse EOS so batched padding works.
tokenizer.pad_token = tokenizer.eos_token
lm_model = AutoModelForCausalLM.from_pretrained(model_name)
lm_model = lm_model.to(device)
lm_model.eval()

# NDM settings for GPT-2 Small: 768-dim residual stream split into 32-dim units.
LM_NDM_CONFIG = {"unit_size": 32, "d_model": 768}
LM_NDM_CONFIG["n_subspaces"] = LM_NDM_CONFIG["d_model"] // LM_NDM_CONFIG["unit_size"]

## Gathering activations

In [None]:
# Target: the output of transformer block h[6] (GPT-2's 7th block, 0-indexed),
# i.e. the residual stream after that block's MLP. In `output_hidden_states`
# terms this is hidden_states[7], because index 0 holds the embeddings.
activations_buffer = []


def hook_fn(module, input, output):
    # A GPT2Block may return a tuple whose FIRST element is the hidden states
    # (later elements can be the KV cache or attention weights, depending on
    # the transformers version and flags), so take element 0, not element -1.
    hidden = output[0] if isinstance(output, (tuple, list)) else output
    # Stored as (batch, seq_len, d_model); padding is removed after the forward pass.
    activations_buffer.append(hidden.detach())


handle = lm_model.transformer.h[6].register_forward_hook(hook_fn)

# Load dataset and run forward passes
dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
# Keep only lines longer than 50 characters (drops empty lines and short headings)
dataset = dataset.filter(lambda x: len(x['text']) > 50)


def tokenize_function(examples):
    return tokenizer(examples["text"], truncation=True, max_length=128, padding="max_length")


tokenized_dataset = dataset.map(tokenize_function, batched=True)
tokenized_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask'])
dataloader = DataLoader(tokenized_dataset, batch_size=8)

N_ACTIVATIONS = 20000  # size of the NDM search buffer
print("Collecting activations from Layer 6...")
n_collected = 0
try:
    with torch.no_grad():
        for batch in dataloader:
            if n_collected >= N_ACTIVATIONS:
                break
            attention_mask = batch['attention_mask'].to(device)
            lm_model(batch['input_ids'].to(device), attention_mask=attention_mask)
            # Keep only rows of real tokens: right-padded positions are EOS
            # filler and would pollute the activation pool.
            activations_buffer[-1] = activations_buffer[-1][attention_mask.bool()]
            n_collected += activations_buffer[-1].shape[0]
finally:
    # Always detach the hook, even if a forward pass fails.
    handle.remove()

all_activations = torch.cat(activations_buffer, dim=0)  # (n_collected, d_model)
# Shuffle and limit to a fixed size for the NDM search buffer
indices = torch.randperm(all_activations.size(0))[:N_ACTIVATIONS]
all_activations = all_activations[indices].to(device)

print(f"Collected {all_activations.shape[0]} vectors of dimension {all_activations.shape[1]}")

## Train NDM

In [None]:
class ActivationBuffer:
    """
    Simulates the buffer B described in Algorithm 1.

    Wraps a fixed (N, d) activation tensor. It serves sequential query batches
    (``pop``) and random blocks for neighbor search (``sample_block``).
    """

    def __init__(self, activations, device):
        # activations: tensor of shape (Total_N, d)
        self.device = device
        self.data = activations.to(self.device)
        self.current_idx = 0

    def pop(self, batch_size):
        """
        Return the next ``batch_size`` consecutive rows.

        Rows are not actually removed. When fewer than ``batch_size`` rows
        remain, reading restarts from row 0 and the tail is skipped. The result
        is a view into the buffer's storage.

        Raises:
            ValueError: If ``batch_size`` exceeds the number of stored rows.
        """
        n_rows = len(self.data)
        if batch_size > n_rows:
            raise ValueError(f"batch_size={batch_size} exceeds buffer size {n_rows}")
        if self.current_idx + batch_size > n_rows:
            self.current_idx = 0

        start = self.current_idx
        self.current_idx += batch_size
        return self.data[start:self.current_idx]

    def sample_block(self, block_size):
        """Return up to ``block_size`` distinct rows chosen uniformly at random (B.next())."""
        indices = torch.randperm(len(self.data), device=self.device)[:block_size]
        return self.data[indices]

class NDM(nn.Module):
    """
    Orthogonal transform ``R`` plus a configuration ``config``: a list of
    subspace widths partitioning the transformed dimensions into contiguous
    blocks, as used by Algorithm 1.
    """

    def __init__(self, input_dim, initial_subspace_dim, device):
        super().__init__()
        self.d = input_dim
        self.device = device

        # Initialize R as identity, then constrain it to stay orthogonal.
        # (eye_ must run before the parametrization is registered: afterwards
        # `weight` is recomputed on each access and in-place writes are lost.)
        self.transform = nn.Linear(input_dim, input_dim, bias=False, device=device)
        nn.init.eye_(self.transform.weight)
        parametrizations.orthogonal(self.transform, 'weight')

        # Initial configuration c: equal-sized subspaces derived from the
        # constructor arguments. For GPT-2 Small the paper uses unit size 32
        # [cite: 755]. Leftover dimensions form one final, smaller subspace.
        num_subspaces = input_dim // initial_subspace_dim
        self.config = [initial_subspace_dim] * num_subspaces
        if sum(self.config) < input_dim:
            self.config.append(input_dim - sum(self.config))

    def get_subspace_views(self, h_transformed):
        """Split (N, d) transformed activations into per-subspace views following ``self.config``."""
        return torch.split(h_transformed, self.config, dim=1)

    def forward(self, x):
        """Return ``x @ R^T``."""
        return self.transform(x)

def ksg_estimator(x, y, k=3):
    """
    Estimate the mutual information I(X;Y), in nats, with the KSG estimator.

    Implements algorithm 1 of Kraskov, Stoegbauer & Grassberger (2004).
    Distances use the max-norm. The joint-space distance to the k-th neighbor
    sets a radius ``eps_i`` per point, and ``n_x(i)`` / ``n_y(i)`` count the
    other points strictly within ``eps_i`` in each marginal space:

        I = psi(k) + psi(N) - < psi(n_x + 1) + psi(n_y + 1) >

    Args:
        x: Tensor of shape (N, d_x).
        y: Tensor of shape (N, d_y), rows paired with ``x``.
        k: Number of neighbors.

    Returns:
        Python float, clipped at 0.

    Raises:
        ValueError: If there are not more than ``k`` samples.
    """
    N = x.shape[0]
    if N <= k:
        raise ValueError(f"need more than k={k} samples, got {N}")

    # Max-norm marginal distances. The formula above is derived for the
    # max-norm. With Euclidean distances an extra log(c_dx * c_dy / c_{dx+dy})
    # unit-ball-volume term would be needed, which is large (tens of nats) for
    # wide subspaces. Direct (non-matmul) evaluation makes self-distances exactly 0.
    dist_x = torch.cdist(x, x, p=float('inf'))
    dist_y = torch.cdist(y, y, p=float('inf'))
    # The max-norm in the joint space Z = (X, Y) is the max of the marginal ones.
    dist_z = torch.maximum(dist_x, dist_y)

    # Distance to the k-th neighbor (the smallest entry is the point itself, 0).
    radius = torch.topk(dist_z, k + 1, largest=False).values[:, k].unsqueeze(1)  # (N, 1)

    def count_within(dist):
        # Points strictly inside the radius, minus the point itself. When the
        # radius is 0 (exact duplicates), the duplicates are counted instead, so
        # the count never drops to -1 (digamma(0) is a pole).
        inside = (dist < radius) | ((radius == 0) & (dist == 0))
        return inside.sum(dim=1).double() - 1

    # scipy's digamma needs NumPy arrays on the CPU.
    nx = count_within(dist_x).cpu().numpy()
    ny = count_within(dist_y).cpu().numpy()

    avg_psi_n = np.mean(digamma(nx + 1) + digamma(ny + 1))
    mi = digamma(k) + digamma(N) - avg_psi_n
    return max(0.0, float(mi))

In [None]:
def train_ndm_algorithm(
    buffer: ActivationBuffer,
    model: NDM,
    device,
    steps_K=10000,
    batch_size_b=128,
    search_steps_n=4,
    block_size_m=128,
    merge_start_delay_p=2000,
    merge_interval_t=1000,
    merge_threshold_tau=0.04,
    lr=0.0005
):
    """
    Train an NDM rotation with the buffered neighbor search of Algorithm 1.

    Each step k = 1..steps_K does the following:
      1. Take a query batch ``H_q`` from ``buffer.pop``.
      2. Without gradients, search ``search_steps_n`` random blocks of
         ``block_size_m`` rows for each query's nearest neighbor, separately in
         every subspace of ``model.config``. Exact matches (distance 0, e.g. the
         query row itself re-sampled into a block) are skipped.
      3. Minimize the mean query-to-neighbor distance in every subspace,
         averaged over subspaces.
      4. If ``k > merge_start_delay_p`` and ``k % merge_interval_t == 0``, call
         ``merge_subspaces`` and rebuild the optimizer.
    The loss is printed every 50 steps.
    """
    from torch.nn.utils import parametrize

    model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)

    for k in range(1, steps_K + 1):
        # --- 1. Data loading ---
        H_q = buffer.pop(batch_size_b)
        n_q = H_q.shape[0]

        # `cached()` evaluates R's orthogonal parametrization once per context
        # instead of on every `model(...)` call.
        with torch.no_grad(), parametrize.cached():
            subspaces_q = model.get_subspace_views(model(H_q))
            D_nearest = [torch.full((n_q,), float('inf'), device=device) for _ in model.config]
            # Raw (unrotated) neighbors; re-rotated with gradients in step 3.
            H_nearest = [torch.zeros((n_q, model.d), device=device) for _ in model.config]

            # --- 2. Neighbor search ---
            for _ in range(search_steps_n):
                H_k = buffer.sample_block(block_size_m)
                subspaces_k = model.get_subspace_views(model(H_k))
                for s_idx, (sub_q, sub_k) in enumerate(zip(subspaces_q, subspaces_k)):
                    # Exact (non-matmul) distances so a self-match is exactly 0
                    # and can be excluded.
                    dists = torch.cdist(sub_q, sub_k, p=2,
                                        compute_mode="donot_use_mm_for_euclid_dist")
                    dists.masked_fill_(dists == 0, float('inf'))
                    min_dists, min_indices = dists.min(dim=1)  # (b,)
                    improved = min_dists < D_nearest[s_idx]
                    D_nearest[s_idx][improved] = min_dists[improved]
                    H_nearest[s_idx][improved] = H_k[min_indices[improved]]

        # --- 3. Gradient step ---
        optimizer.zero_grad()
        with parametrize.cached():
            subspaces_q_grad = model.get_subspace_views(model(H_q))
            # Rotate all stored neighbors in one pass: (S * b, d) -> S chunks of (b, d).
            neighbors_rot = model(torch.cat(H_nearest, dim=0)).split(n_q, dim=0)

            loss = torch.tensor(0.0, device=device)
            start_dim = 0
            for s_idx, (sub_q_grad, nb_rot) in enumerate(zip(subspaces_q_grad, neighbors_rot)):
                end_dim = start_dim + model.config[s_idx]
                sub_k_grad = nb_rot[:, start_dim:end_dim]
                start_dim = end_dim

                d = torch.norm(sub_q_grad - sub_k_grad, p=2, dim=1)
                # Queries with no neighbor found keep a zero placeholder; leave them out.
                found = torch.isfinite(D_nearest[s_idx])
                loss = loss + (d * found).sum() / found.sum().clamp_min(1)

            loss = loss / len(model.config)
            loss.backward()
        optimizer.step()

        # --- 4. Merging step ---
        if k > merge_start_delay_p and k % merge_interval_t == 0:
            merge_subspaces(model, buffer, block_size_m * 4, merge_threshold_tau, device)
            # The subspace layout (and possibly R) changed, so accumulated Adam
            # moments are stale: start a fresh optimizer.
            optimizer = optim.Adam(model.parameters(), lr=lr)
            print(f"Step {k}: Config updated to {model.config}")

        if k % 50 == 0:
            print(f"Step {k}/{steps_K}, Loss: {loss.item():.4f}")

def merge_subspaces(model, buffer, sample_size, threshold, device):
    """
    Merge strongly dependent subspaces of ``model`` in place.

    Steps:
      1. Rotate ``sample_size`` random buffer rows and split them per ``model.config``.
      2. Divide each subspace by ``sqrt(sum of per-dim variances) + 1e-6``.
      3. For every pair (i, j), estimate MI with ``ksg_estimator`` and divide it
         by ``config[i] + config[j]``. Pairs above ``threshold`` are candidates.
      4. Greedily take the highest-scoring pairs that share no subspace, at most
         ``max(1, S // 8)`` of them.
      5. Permute the rows of R so each merged pair is adjacent. Merged pairs come
         first, then untouched subspaces in their original order. Update
         ``model.config`` to match.
    Does nothing if no pair exceeds the threshold.
    """
    with torch.no_grad():
        batch = buffer.sample_block(sample_size)
        subspaces = model.get_subspace_views(model(batch))
        S = len(subspaces)

        # Normalize variance
        norm_subspaces = []
        for s in subspaces:
            variances = torch.var(s, dim=0, unbiased=False)
            norm_factor = torch.sqrt(variances.sum()) + 1e-6
            norm_subspaces.append(s / norm_factor)

        pairs = []
        for i in range(S):
            for j in range(i + 1, S):
                mi_val = ksg_estimator(norm_subspaces[i], norm_subspaces[j])
                # MI per combined dimension
                normalized_mi = mi_val / (model.config[i] + model.config[j])
                if normalized_mi > threshold:
                    pairs.append((normalized_mi, i, j))

        pairs.sort(key=lambda p: p[0], reverse=True)

        # Merge top non-intersecting pairs
        max_merges = max(1, S // 8)
        merged_indices = set()
        new_merges = []
        for _, i, j in pairs:
            if len(new_merges) >= max_merges:
                break
            if i not in merged_indices and j not in merged_indices:
                new_merges.append((i, j))
                merged_indices.update((i, j))

        if not new_merges:
            return

        # Row offset of each subspace inside R.
        starts = [0]
        for size in model.config[:-1]:
            starts.append(starts[-1] + size)

        def rows(idx):
            return range(starts[idx], starts[idx] + model.config[idx])

        perm_indices = []
        final_config_list = []
        for i, j in new_merges:
            perm_indices.extend(rows(i))
            perm_indices.extend(rows(j))
            final_config_list.append(model.config[i] + model.config[j])
        for idx in range(S):
            if idx not in merged_indices:
                perm_indices.extend(rows(idx))
                final_config_list.append(model.config[idx])

        # Permute the rows of R. `transform.weight` is produced by an orthogonal
        # parametrization and recomputed on every access, so writing to
        # `weight.data` would only change a temporary. Assigning the attribute
        # routes the matrix through the parametrization's right_inverse. A row
        # permutation of an orthogonal matrix is still orthogonal.
        perm_tensor = torch.tensor(perm_indices, device=device)
        model.transform.weight = model.transform.weight[perm_tensor, :]

        model.config = final_config_list

buffer = ActivationBuffer(all_activations, device=device)
ndm_model = NDM(input_dim=LM_NDM_CONFIG['d_model'],
                initial_subspace_dim=LM_NDM_CONFIG['unit_size'], device=device)
train_ndm_algorithm(
    buffer=buffer,
    model=ndm_model,
    device=device,
    steps_K=500,             # Reduced for demo
    batch_size_b=128,
    search_steps_n=8,        # Increase for better accuracy if compute allows
    block_size_m=128,
    # NOTE: merges only run at steps k > merge_start_delay_p with
    # k % merge_interval_t == 0, so with steps_K=500 no merge happens in this demo.
    merge_start_delay_p=500,  # Reduced for demo
    merge_interval_t=500,     # Reduced for demo
    merge_threshold_tau=0.04,
    # Each Adam step moves parameters by roughly `lr`, so a value like 1e-9
    # would leave R at the identity. 5e-4 is train_ndm_algorithm's default.
    lr=5e-4,
)

## Analysis

In [None]:
def analyze_subspaces(model, tokenizer, rotation_matrix, text_list, config, layer_state, top_k=5):
    """
    For every subspace, find the tokens whose rotated activation has the
    largest norm in that subspace.

    Args:
        model: Hugging Face causal LM, called with ``output_hidden_states=True``.
        tokenizer: Matching tokenizer.
        rotation_matrix: NDM rotation ``R`` of shape (d_model, d_model);
            activations are rotated as ``h @ R^T``.
        text_list: Strings to scan. Texts with fewer than 10 tokens are
            skipped, and longer texts are truncated to the tokenizer's maximum
            length.
        config: Dict with 'unit_size' and 'd_model'. Subspace ``s`` covers
            rotated dims ``[s * unit_size, (s + 1) * unit_size)``.
        layer_state: Index into ``outputs.hidden_states``. For Hugging Face
            GPT-2, index 0 is the embeddings and index ``i + 1`` is the output
            of transformer block ``i``.
        top_k: Number of entries kept per subspace.

    Returns:
        Dict mapping every subspace index to a list of up to ``top_k`` dicts
        ``{'norm', 'token', 'context'}``, sorted by decreasing norm. Ties keep
        the earliest occurrence. ``context`` is the token plus up to 5
        preceding tokens.
    """
    import heapq

    model.eval()
    device = next(model.parameters()).device
    unit_size = config['unit_size']
    n_subspaces = config['d_model'] // unit_size
    used_dims = n_subspaces * unit_size

    # One bounded min-heap of (norm, -arrival_order, entry) per subspace. The
    # root is the weakest kept entry; among equal norms, later arrivals go first.
    heaps = [[] for _ in range(n_subspaces)]
    order = 0

    print(f"Scanning {len(text_list)} texts for interpretable patterns...")

    with torch.no_grad():
        for text in text_list:
            inputs = tokenizer(text, return_tensors="pt", truncation=True).to(device)
            input_ids = inputs['input_ids'][0]
            if input_ids.shape[0] < 10:
                continue

            outputs = model(**inputs, output_hidden_states=True)
            h = outputs.hidden_states[layer_state]          # (1, seq, d_model)
            h_rot = torch.matmul(h, rotation_matrix.t())[0]  # (seq, d_model)

            # Norm of every (token, subspace) slice at once: (seq, n_subspaces).
            norms = (h_rot[:, :used_dims]
                     .reshape(-1, n_subspaces, unit_size)
                     .norm(dim=-1)
                     .tolist())

            for i, token_norms in enumerate(norms):
                text_info = None  # decoded lazily, only if some subspace keeps this token
                for s, norm in enumerate(token_norms):
                    heap = heaps[s]
                    if len(heap) >= top_k and (not heap or norm <= heap[0][0]):
                        continue
                    if text_info is None:
                        start_ctx = max(0, i - 5)
                        text_info = {
                            'token': tokenizer.decode(input_ids[i]),
                            'context': tokenizer.decode(input_ids[start_ctx:i + 1]),
                        }
                    item = (norm, -order, {'norm': norm, **text_info})
                    order += 1
                    if len(heap) < top_k:
                        heapq.heappush(heap, item)
                    else:
                        heapq.heapreplace(heap, item)

    # Sort each heap by decreasing norm, then by arrival order.
    return {
        s: [item[2] for item in sorted(heap, key=lambda it: (-it[0], -it[1]))]
        for s, heap in enumerate(heaps)
    }


In [None]:
# Get rotation matrix R from trained NDM
R_matrix = ndm_model.transform.weight.detach()

# Use the first 50 texts from the dataset for quick analysis
sample_texts = dataset['text'][:50]

# The NDM was trained on the output of block lm_model.transformer.h[6]. In
# `hidden_states` that tensor is at index 6 + 1, because index 0 is the embeddings.
HOOKED_BLOCK = 6
results = analyze_subspaces(lm_model, tokenizer, R_matrix, sample_texts,
                            LM_NDM_CONFIG, HOOKED_BLOCK + 1)

# Display a few subspaces
for s in range(5):  # Print first 5 subspaces
    print(f"\n=== Subspace {s} ===")
    for item in results.get(s, []):
        print(f"[{item['norm']:.2f}] Token: {repr(item['token']):<15} | Context: ...{repr(item['context'])}")

## Validation / Gini coefficient

In [None]:
def gini(x):
    """
    Gini coefficient of the magnitudes in a 1-D tensor (e.g. one norm per subspace).

    0 means all magnitudes are equal; values near 1 mean the mass sits in a
    single entry. 1e-8 is added to every magnitude for numerical stability.
    """
    values, _ = torch.sort(x.abs() + 1e-8)
    n = values.shape[0]
    ranks = torch.arange(1, n + 1, device=values.device)
    # Rank weights of the sorted-values Gini formula.
    weights = 2 * ranks - n - 1
    numerator = (weights * values).sum()
    denominator = n * values.sum()
    return numerator / denominator

print("Calculating Gini Coefficient (Sparsity)...")

gini_scores = []
batch_size = 1000

with torch.no_grad():
    # Take a test batch (it may hold fewer than batch_size rows if fewer
    # activations were collected).
    test_batch = all_activations[:batch_size]
    # Rotate
    h_rot = ndm_model.forward(test_batch)

    # Norm of each subspace for each sample, following the model's own
    # (possibly merged) subspace configuration: (n_samples, n_subspaces).
    subspace_norms = torch.stack(
        [torch.norm(view, dim=1) for view in ndm_model.get_subspace_views(h_rot)],
        dim=1,
    )

    # Gini over subspace norms, one value per activation vector
    for row in subspace_norms:
        gini_scores.append(gini(row).item())

avg_gini = sum(gini_scores) / len(gini_scores)
print(f"Average Gini Coefficient: {avg_gini:.4f}")

if avg_gini > 0.5:
    print("SUCCESS: High sparsity detected. Subspaces likely encode distinct features.")
else:
    print("NOTE: Gini is low. Try training longer or checking if the correct layer is hooked.")