In [1]:
import numpy as np
import json
from pathlib import Path
from collections import defaultdict

ensemble_dir = Path("/mnt/polished-lake/spd/ensemble/e-5f228e5f")

# Load the merge array. np.load on an .npz returns an NpzFile that keeps the
# underlying zip file open; using it as a context manager closes that handle.
# Indexing the NpzFile reads the member array fully into memory, so
# `merge_array` stays valid after the file is closed.
with np.load(ensemble_dir / "ensemble_merge_array.npz") as merge_data:
    merge_array = merge_data['merge_array']  # shape (10, 300, 442) for this ensemble

# Load metadata to get component labels (JSON is UTF-8 by spec; don't rely on
# the platform's default locale encoding).
with open(ensemble_dir / "ensemble_meta.json", encoding="utf-8") as f:
    meta = json.load(f)

# Later cells map component position -> label, so the last axis of the merge
# array must line up with the label list.
assert merge_array.shape[-1] == len(meta['component_labels']), (
    f"merge_array has {merge_array.shape[-1]} components but meta lists "
    f"{len(meta['component_labels'])} labels"
)

print(f"Merge array shape: {merge_array.shape}")
print(f"Number of component labels: {len(meta['component_labels'])}")
print(f"Sample labels: {meta['component_labels'][:5]}")


Merge array shape: (10, 300, 442)
Number of component labels: 442
Sample labels: ['h.0.mlp.down_proj:1', 'h.0.mlp.down_proj:10', 'h.0.mlp.down_proj:100', 'h.0.mlp.down_proj:101', 'h.0.mlp.down_proj:102']


In [2]:
# Get cluster assignments for one run at its final merge iteration.
# merge_array is indexed [run, iteration, component]; the last iteration is
# shape[1] - 1 (299 for this ensemble), derived rather than hardcoded.
run_idx = 0
iter_idx = merge_array.shape[1] - 1
assignments = merge_array[run_idx, iter_idx, :]

print(f"Run {run_idx}, Iteration {iter_idx}")
print(f"Number of components: {len(assignments)}")
print(f"Number of unique clusters: {len(np.unique(assignments))}")

# Group component positions by the cluster id they were assigned to.
clusters = defaultdict(list)
for comp_idx, cluster_id in enumerate(assignments):
    clusters[cluster_id].append(comp_idx)

# Largest clusters first. sorted() is stable, so clusters of equal size keep
# the order in which their ids first appeared in `assignments`.
sorted_clusters = sorted(clusters.items(), key=lambda item: -len(item[1]))

print("\nCluster sizes (top 10):")
for rank, (_, members) in enumerate(sorted_clusters[:10], start=1):
    print(f"  Cluster {rank}: {len(members)} components")

print(f"\nTotal clusters: {len(sorted_clusters)}")


Run 0, Iteration 299
Number of components: 442
Number of unique clusters: 142

Cluster sizes (top 10):
  Cluster 1: 68 components
  Cluster 2: 16 components
  Cluster 3: 14 components
  Cluster 4: 11 components
  Cluster 5: 10 components
  Cluster 6: 8 components
  Cluster 7: 8 components
  Cluster 8: 7 components
  Cluster 9: 6 components
  Cluster 10: 6 components

Total clusters: 142


In [3]:
# List every cluster, largest first, showing its raw component positions and,
# for reference, the neuron index parsed from each component's label.

component_labels = meta["component_labels"]

def parse_neuron_idx(label):
    """Return the integer neuron index at the end of a component label.

    Labels look like ``'h.0.mlp.down_proj:102'``: a module path, a colon, and
    an integer index. Splitting on the *last* colon keeps this correct even if
    the module path itself were to contain a colon.

    Args:
        label: Component label string.

    Returns:
        The integer following the last ``':'`` in ``label``.

    Raises:
        ValueError: If ``label`` contains no ``':'`` or the suffix after it is
            not a valid integer.
    """
    _, sep, idx_str = label.rpartition(':')
    if not sep:
        raise ValueError(f"component label has no ':<index>' suffix: {label!r}")
    return int(idx_str)

# neuron_indices[i] is the neuron id parsed from the label of component i.
# NOTE(review): all visible labels come from 'h.0.mlp.down_proj'; if labels
# ever span several modules, neuron ids alone would not be unique.
assert len(component_labels) == len(assignments), (
    f"{len(component_labels)} labels for {len(assignments)} components"
)
neuron_indices = [parse_neuron_idx(label) for label in component_labels]

# Header figures are derived from the current run/iteration rather than
# hardcoded, so they stay correct if run_idx / iter_idx change upstream.
n_components = len(assignments)
n_clusters = len(sorted_clusters)

print("=" * 80)
print(f"CLUSTER ASSIGNMENTS FOR RUN {run_idx}, ITERATION {iter_idx}")
print("=" * 80)
print(f"\nTotal: {n_components} components in {n_clusters} clusters")
print("Component indices are 0-indexed positions in the merge array")
print("Neuron indices are the original MLP neuron IDs (from 'h.0.mlp.down_proj:N')")
print("\n")

for rank, (_, members) in enumerate(sorted_clusters):
    neuron_ids = [neuron_indices[m] for m in members]
    print(f"Cluster {rank + 1} (size {len(members)}):")
    # The two lists are sorted independently (labels sort lexicographically,
    # e.g. ':1', ':10', ':100'), so their i-th entries need not match.
    print(f"  Raw component indices: {sorted(members)}")
    print(f"  Neuron indices: {sorted(neuron_ids)}")
    print()


CLUSTER ASSIGNMENTS FOR RUN 0, ITERATION 299

Total: 442 components in 142 clusters
Component indices are 0-indexed positions in the merge array
Neuron indices are the original MLP neuron IDs (from 'h.0.mlp.down_proj:N')


Cluster 1 (size 68):
  Raw component indices: [5, 13, 21, 27, 32, 35, 44, 48, 52, 64, 85, 91, 98, 100, 103, 108, 109, 110, 111, 114, 118, 121, 136, 138, 144, 150, 163, 172, 178, 183, 185, 197, 213, 216, 220, 232, 238, 240, 241, 242, 254, 262, 263, 274, 283, 284, 287, 304, 308, 313, 318, 319, 331, 335, 338, 344, 358, 367, 379, 380, 383, 388, 399, 411, 415, 421, 422, 436]
  Neuron indices: [5, 6, 17, 24, 33, 38, 54, 60, 63, 65, 71, 73, 93, 103, 115, 128, 134, 140, 143, 155, 163, 184, 210, 219, 230, 232, 235, 240, 241, 244, 247, 251, 255, 277, 279, 287, 297, 310, 324, 336, 339, 354, 384, 388, 401, 408, 410, 411, 412, 427, 437, 439, 453, 465, 467, 472, 495, 505, 512, 513, 528, 537, 549, 570, 584, 602, 608, 643]

Cluster 2 (size 16):
  Raw component indices: [49, 53, 55, 

In [4]:
# Compact output: just the raw component indices of each cluster, largest first.
# The index range in the header is derived from the data instead of hardcoded.
max_component_idx = len(assignments) - 1
print(f"CLUSTERS BY SIZE (raw component indices 0-{max_component_idx}):")
print("=" * 60)
for rank, (_, members) in enumerate(sorted_clusters, start=1):
    print(f"Cluster {rank:3d} (n={len(members):2d}): {sorted(members)}")


CLUSTERS BY SIZE (raw component indices 0-441):
Cluster   1 (n=68): [5, 13, 21, 27, 32, 35, 44, 48, 52, 64, 85, 91, 98, 100, 103, 108, 109, 110, 111, 114, 118, 121, 136, 138, 144, 150, 163, 172, 178, 183, 185, 197, 213, 216, 220, 232, 238, 240, 241, 242, 254, 262, 263, 274, 283, 284, 287, 304, 308, 313, 318, 319, 331, 335, 338, 344, 358, 367, 379, 380, 383, 388, 399, 411, 415, 421, 422, 436]
Cluster   2 (n=16): [49, 53, 55, 97, 125, 128, 143, 199, 259, 261, 295, 323, 332, 382, 393, 435]
Cluster   3 (n=14): [24, 54, 65, 83, 139, 147, 173, 251, 300, 336, 365, 402, 403, 437]
Cluster   4 (n=11): [10, 45, 94, 131, 198, 246, 280, 292, 301, 330, 414]
Cluster   5 (n=10): [57, 62, 89, 137, 200, 219, 239, 276, 291, 298]
Cluster   6 (n= 8): [17, 47, 107, 140, 228, 314, 324, 355]
Cluster   7 (n= 8): [86, 176, 207, 249, 385, 410, 426, 430]
Cluster   8 (n= 7): [0, 42, 46, 169, 191, 230, 290]
Cluster   9 (n= 6): [7, 92, 165, 188, 189, 208]
Cluster  10 (n= 6): [19, 171, 193, 269, 376, 416]
Cluster  11

In [5]:
# Summary statistics. Every figure is computed from `sorted_clusters` so the
# summary stays correct for whichever run/iteration was selected upstream.
cluster_sizes = [len(members) for _, members in sorted_clusters]
n_components = sum(cluster_sizes)
n_clusters = len(cluster_sizes)
largest_size = max(cluster_sizes, default=0)
largest_pct = 100 * largest_size / n_components if n_components else 0.0
n_singletons = sum(1 for size in cluster_sizes if size == 1)

print("\nSUMMARY:")
print(f"  Total components: {n_components}")
print(f"  Total clusters: {n_clusters}")
print(f"  Largest cluster: {largest_size} components ({largest_pct:.1f}% of total)")
print(f"  Singletons: {n_singletons} clusters")
print(f"  Multi-component clusters: {n_clusters - n_singletons}")
print("\nCluster size distribution:")

# Histogram: cluster size -> number of clusters with that size.
size_counts = {}
for size in cluster_sizes:
    size_counts[size] = size_counts.get(size, 0) + 1
for size in sorted(size_counts.keys(), reverse=True):
    print(f"  Size {size:2d}: {size_counts[size]} cluster(s)")



SUMMARY:
  Total components: 442
  Total clusters: 142
  Largest cluster: 68 components (15.4% of total)
  Singletons: 47 clusters
  Multi-component clusters: 95

Cluster size distribution:
  Size 68: 1 cluster(s)
  Size 16: 1 cluster(s)
  Size 14: 1 cluster(s)
  Size 11: 1 cluster(s)
  Size 10: 1 cluster(s)
  Size  8: 2 cluster(s)
  Size  7: 1 cluster(s)
  Size  6: 4 cluster(s)
  Size  5: 5 cluster(s)
  Size  4: 12 cluster(s)
  Size  3: 24 cluster(s)
  Size  2: 42 cluster(s)
  Size  1: 47 cluster(s)
