# PC-Label Alignment — Analysis & Visualisation

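This notebook loads the pre-computed linear-probe results and produces all statistical analyses and figures. Throughout, Δ denotes the gain of the probe on a head's top-$k$ principal components over the random-direction baseline (Δ = PC accuracy − random accuracy); most analyses use $k=1$.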
It is fully self-contained: no model, no dataset, no GPU required.

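**Prerequisites**: run the extraction + probe notebook first and save its outputs to `probe_results/`. This notebook reads `probe_<dataset>.csv`, `block_info.csv` and `meta_<dataset>.json` from there.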

## 0. Imports & Style

In [None]:
import numpy as np
import pandas as pd
import json
import os
import warnings
# Hide only library deprecation noise. RuntimeWarnings (e.g. scipy's
# "precision loss" warning from t-tests on near-constant data) and matplotlib
# UserWarnings stay visible, because they flag unreliable statistics or figures.
warnings.filterwarnings('ignore', category=FutureWarning)
warnings.filterwarnings('ignore', category=DeprecationWarning)

from scipy import stats
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.gridspec import GridSpec
import seaborn as sns

# ── Palette & style ───────────────────────────────────────────────────────────
# One colour per HTS-AT stage (layer0..layer3) plus accent/neutral helpers.
COLORS = {
    'layer0': '#2d77a6',
    'layer1': '#bf7b04',
    'layer2': '#6ea66d',
    'layer3': '#808080',
    'accent':  '#d62728',
    'neutral': '#7f7f7f'
}
LAYER_COLORS = [COLORS['layer0'], COLORS['layer1'], COLORS['layer2'], COLORS['layer3']]

# 'seaborn-v0_8-*' style names require matplotlib >= 3.6.
plt.style.use('seaborn-v0_8-paper')
sns.set_context('paper', font_scale=1.2)
plt.rcParams.update({
    'font.family': 'serif',
    # Times New Roman is often not installed (e.g. on Linux). With no other
    # serif listed, matplotlib falls back to DejaVu Sans, so the figures would
    # silently lose their serif look. List Times-compatible substitutes first,
    # then DejaVu Serif, which ships with matplotlib, as a guaranteed fallback.
    'font.serif': ['Times New Roman', 'Times', 'Nimbus Roman',
                   'Liberation Serif', 'DejaVu Serif'],
    'mathtext.fontset': 'stix',
    'axes.labelsize': 11,
    'axes.titlesize': 12,
    'xtick.labelsize': 9,
    'ytick.labelsize': 9,
    'legend.fontsize': 9,
    'figure.titlesize': 13,
    'axes.grid': True,
    'grid.alpha': 0.3,
    'axes.axisbelow': True
})

RESULTS_DIR = 'probe_results'
FIGURES_DIR = 'figures_probe'
os.makedirs(FIGURES_DIR, exist_ok=True)
print('✅ Imports OK')

## 1. Load Results

In [None]:
DATASET_NAME = 'esc50'  # ← change as needed

results_df    = pd.read_csv(f'{RESULTS_DIR}/probe_{DATASET_NAME}.csv')
block_info_df = pd.read_csv(f'{RESULTS_DIR}/block_info.csv')

with open(f'{RESULTS_DIR}/meta_{DATASET_NAME}.json') as f:
    meta = json.load(f)

N_SAMPLES    = meta['n_samples']
N_CLASSES    = meta['n_classes']
HEAD_DIM     = meta['head_dim']
K_VALUES     = meta['k_values']
HTSAT_DEPTHS = meta['htsat_depths']   # blocks per stage
HTSAT_HEADS  = meta['htsat_heads']    # attention heads per block, per stage
N_BLOCKS     = sum(HTSAT_DEPTHS)
CHANCE       = 1.0 / N_CLASSES        # uniform-guess accuracy (= chance for balanced classes)

# ── Sanity checks ─────────────────────────────────────────────────────────────
# Fail here with a clear message instead of deep inside a statistics/figure cell.
REQUIRED_COLUMNS = {'head_id', 'layer', 'global_block', 'head', 'k',
                    'acc_pc', 'acc_rand', 'std_rand', 'delta'}
missing_cols = REQUIRED_COLUMNS - set(results_df.columns)
assert not missing_cols, (
    f'probe_{DATASET_NAME}.csv is missing columns: {sorted(missing_cols)}')
assert len(HTSAT_DEPTHS) == len(HTSAT_HEADS), (
    'meta: htsat_depths and htsat_heads must have one entry per stage')

# Convenience: k=1 slice used in most analyses
k1_df = results_df[results_df['k'] == 1].copy()

# Every stage needs k=1 rows: the per-stage tests, violins and summaries index them.
missing_stages = set(range(len(HTSAT_DEPTHS))) - set(k1_df['layer'].unique())
assert not missing_stages, (
    f'No k=1 probe results for stage(s) {sorted(missing_stages)}; '
    'most analyses below require k=1 for every stage')

print(f'✅ Loaded  {DATASET_NAME.upper()}')
print(f'   Rows in results_df : {len(results_df)}')
print(f'   k values           : {K_VALUES}')
print(f'   N classes / chance : {N_CLASSES} / {CHANCE:.4f}')
display(results_df.head())

## 2. Summary Statistics by Layer

In [None]:
# Per-stage (layer) mean / spread of the k=1 probe metrics, one row per layer.
layer_summary = (
    k1_df
    .groupby('layer')
    .agg(
        acc_pc_mean=pd.NamedAgg(column='acc_pc', aggfunc='mean'),
        acc_pc_std=pd.NamedAgg(column='acc_pc', aggfunc='std'),
        acc_rand_mean=pd.NamedAgg(column='acc_rand', aggfunc='mean'),
        acc_rand_std=pd.NamedAgg(column='acc_rand', aggfunc='std'),
        delta_mean=pd.NamedAgg(column='delta', aggfunc='mean'),
        delta_std=pd.NamedAgg(column='delta', aggfunc='std'),
        delta_max=pd.NamedAgg(column='delta', aggfunc='max'),
        n_heads=pd.NamedAgg(column='head_id', aggfunc='count'),
    )
    .reset_index()
)

print('Layer-wise summary (k=1):')
display(layer_summary.round(4))

## 3. Statistical Tests

In [None]:
def sig_stars(p):
    """Return the significance marker for p-value ``p``.

    '***' if p < 0.001, '**' if p < 0.01, '*' if p < 0.05, otherwise 'ns'.
    A NaN p-value (e.g. a group with fewer than two heads) fails every
    comparison and therefore maps to 'ns'.
    """
    if p < 0.001:
        return '***'
    if p < 0.01:
        return '**'
    if p < 0.05:
        return '*'
    return 'ns'


def bonferroni(p, n_tests):
    """Bonferroni-adjust ``p`` for a family of ``n_tests`` tests (capped at 1; NaN propagates)."""
    return float(np.minimum(p * n_tests, 1.0))


n_stages = len(HTSAT_DEPTHS)
# k=1 Δ values of every head, grouped by stage (index = layer id).
deltas_by_stage = [k1_df.loc[k1_df['layer'] == l, 'delta'].to_numpy()
                   for l in range(n_stages)]

print('=' * 65)
print('One-sample t-test: H₀: mean(Δ) = 0  (k=1, per layer)')
print(f'Stars use Bonferroni-adjusted p (× {n_stages} layers)')
print('=' * 65)

stat_rows = []
for layer_idx, deltas in enumerate(deltas_by_stage):
    t_stat, p_val = stats.ttest_1samp(deltas, popmean=0)
    # Four stages are tested simultaneously, so correct for multiplicity.
    p_adj = bonferroni(p_val, n_stages)
    sig = sig_stars(p_adj)
    print(f'  Stage {layer_idx+1} (L{layer_idx}):  '
          f'mean Δ = {deltas.mean():+.4f}  '
          f't = {t_stat:6.3f}  p = {p_val:.4f}  p_bonf = {p_adj:.4f}  {sig}')
    stat_rows.append({
        'layer': layer_idx,
        'mean_delta': deltas.mean(),
        't_stat': t_stat,
        'p_value': p_val,      # raw (uncorrected) p-value
        'p_bonf': p_adj,       # Bonferroni-adjusted p-value
        'significant': sig     # stars based on p_bonf
    })

stats_df = pd.DataFrame(stat_rows)

print()
pairs = [(i, j) for i in range(n_stages) for j in range(i + 1, n_stages)]
print(f'── Pairwise layer comparison (Δ, Welch t-test, Bonferroni × {len(pairs)}) ──')
for i, j in pairs:
    t, p = stats.ttest_ind(deltas_by_stage[i], deltas_by_stage[j], equal_var=False)
    p_adj = bonferroni(p, len(pairs))
    print(f'  L{i} vs L{j}: t={t:6.3f}  p={p:.4f}  p_bonf={p_adj:.4f}  {sig_stars(p_adj)}')

print()
print('── Specialised heads (Δ > μ + 2σ across all heads) ──')
global_mean = k1_df['delta'].mean()
global_std  = k1_df['delta'].std()
threshold   = global_mean + 2 * global_std
top_heads   = k1_df[k1_df['delta'] > threshold].sort_values('delta', ascending=False)
display(top_heads[['head_id','layer','global_block','head',
                   'acc_pc','acc_rand','delta']].round(4))
print(f'Threshold (μ+2σ) = {threshold:.4f} | '
      f'Specialised: {len(top_heads)}/{len(k1_df)} '
      f'({100*len(top_heads)/len(k1_df):.1f}%)')

## 4. Figures

### Figure A — PC₁ accuracy vs. random baseline, per layer

In [None]:
fig, axes = plt.subplots(1, 4, figsize=(14, 4), sharey=True)
fig.suptitle(
    r'Linear Probe Accuracy: PC$_1$ vs. Random Baseline ($k=1$)',
    fontsize=13, y=1.02
)
layer_labels = ['Stage 1 (L0)', 'Stage 2 (L1)', 'Stage 3 (L2)', 'Stage 4 (L3)']

# Shared y-limit: headroom above both the PC₁ bars and the top of the
# random-baseline error bars, capped just above perfect accuracy.
y_top = max(k1_df['acc_pc'].max(), (k1_df['acc_rand'] + k1_df['std_rand']).max())
y_max = min(1.05, y_top * 1.25)

for layer_idx, ax in enumerate(axes):
    # A stage spans several blocks, and each block has its own heads 0..H-1.
    # Sort by (block, head) so bars are grouped by block; sorting on 'head'
    # alone interleaves heads of different blocks in an unspecified order.
    sub    = (k1_df[k1_df['layer'] == layer_idx]
              .sort_values(['global_block', 'head']))
    color  = LAYER_COLORS[layer_idx]
    x      = np.arange(len(sub))
    blocks = sub['global_block'].to_numpy()

    ax.bar(x, sub['acc_pc'].values,   color=color,             alpha=0.85, label=r'PC$_1$', zorder=3)
    ax.bar(x, sub['acc_rand'].values, color=COLORS['neutral'], alpha=0.45, label='Random',  zorder=2)
    ax.errorbar(x, sub['acc_rand'].values, yerr=sub['std_rand'].values,
                fmt='none', color='#333333', capsize=2, linewidth=0.8, zorder=4)
    ax.axhline(CHANCE, color=COLORS['accent'], linewidth=1.0,
               linestyle='--', label=f'Chance ({CHANCE:.2f})', zorder=5)

    # Dotted separators wherever the block index changes between adjacent bars.
    for boundary in np.flatnonzero(np.diff(blocks)) + 0.5:
        ax.axvline(boundary, color='#999999', linewidth=0.6, linestyle=':', zorder=1)

    # One tick per block, centred on its heads (per-bar head ticks would repeat
    # the same indices once per block and become unreadable for large stages).
    block_ids = pd.unique(blocks)
    ax.set_xticks([x[blocks == b].mean() for b in block_ids])
    ax.set_xticklabels([f'B{int(b)}' for b in block_ids], fontsize=7)

    ax.set_title(layer_labels[layer_idx], color=color, fontweight='bold')
    ax.set_xlabel('Heads (grouped by block)')
    if layer_idx == 0:
        ax.set_ylabel('5-fold CV accuracy')
    ax.set_ylim(0, y_max)

handles = [
    mpatches.Patch(color=LAYER_COLORS[0], alpha=0.85, label=r'PC$_1$ accuracy'),
    mpatches.Patch(color=COLORS['neutral'], alpha=0.45, label='Random baseline'),
    plt.Line2D([0],[0], color=COLORS['accent'], linestyle='--', label='Chance level'),
]
fig.legend(handles=handles, loc='lower center', ncol=3,
           bbox_to_anchor=(0.5, -0.08), frameon=True, edgecolor='#cccccc')
plt.tight_layout()
fig.savefig(f'{FIGURES_DIR}/figA_pc1_probe_by_layer.pdf', bbox_inches='tight', dpi=300)
fig.savefig(f'{FIGURES_DIR}/figA_pc1_probe_by_layer.png', bbox_inches='tight', dpi=300)
plt.show()
print('✅ Figure A saved')

### Figure B — Δ accuracy heatmap: blocks × heads

In [None]:
max_heads    = max(HTSAT_HEADS)
delta_matrix = np.full((N_BLOCKS, max_heads), np.nan)
acc_matrix   = np.full((N_BLOCKS, max_heads), np.nan)

# Scatter the k=1 results into (block, head) grids. Cells for heads that a
# block does not have stay NaN, so they render blank and are not annotated.
block_idx = k1_df['global_block'].to_numpy(dtype=int)
head_idx  = k1_df['head'].to_numpy(dtype=int)
delta_matrix[block_idx, head_idx] = k1_df['delta'].to_numpy()
acc_matrix[block_idx, head_idx]   = k1_df['acc_pc'].to_numpy()

stage_boundaries = np.cumsum(HTSAT_DEPTHS)[:-1] - 0.5
stage_centers    = [sum(HTSAT_DEPTHS[:i]) + HTSAT_DEPTHS[i]/2 - 0.5 for i in range(4)]
stage_names      = ['S1', 'S2', 'S3', 'S4']

fig, axes = plt.subplots(1, 2, figsize=(14, 5.5))
fig.suptitle(
    r'Head-level Semantic Alignment — ' + DATASET_NAME.upper(),
    fontsize=13
)

panels = [
    (delta_matrix, 'RdYlGn', -0.05,
     np.nanmax(delta_matrix),
     r'$\Delta$ accuracy (PC$_1$ $-$ random)',
     r'(a) $\Delta$ accuracy'),
    (acc_matrix, 'Blues', CHANCE,
     np.nanmax(acc_matrix),
     r'PC$_1$ probe accuracy',
     r'(b) Absolute PC$_1$ accuracy'),
]

for ax, (matrix, cmap, vmin, vmax, label, panel_title) in zip(axes, panels):
    im = ax.imshow(matrix, aspect='auto', cmap=cmap, vmin=vmin, vmax=vmax)

    for boundary in stage_boundaries:
        ax.axhline(boundary, color='black', linewidth=1.5, linestyle='--', alpha=0.7)

    # Annotate every populated cell. Text turns white on the upper 40% of the
    # colour range, i.e. above vmin + 0.6 * (vmax - vmin).
    thresh_light = vmin + 0.6 * (vmax - vmin)
    for b, h in zip(*np.nonzero(~np.isnan(matrix))):
        val = matrix[b, h]
        ax.text(h, b, f'{val:.2f}', ha='center', va='center',
                fontsize=6,
                color='white' if val > thresh_light else 'black')

    cb = plt.colorbar(im, ax=ax, fraction=0.03, pad=0.03)
    cb.set_label(label, fontsize=9)
    ax.set_xlabel('Head index within block')
    ax.set_ylabel('Global block index')
    ax.set_title(panel_title)
    ax.set_yticks(range(N_BLOCKS))
    ax.set_yticklabels([f'B{b}' for b in range(N_BLOCKS)], fontsize=8)

    # Right-hand axis with one stage label per block group.
    ax2 = ax.twinx()
    ax2.set_ylim(ax.get_ylim())
    ax2.set_yticks(stage_centers)
    ax2.set_yticklabels(stage_names, fontsize=9, fontweight='bold')
    # Text colour must be a single colour per label; passing the whole
    # LAYER_COLORS list as `color=` raises ValueError, so colour each label.
    for tick_label, stage_color in zip(ax2.get_yticklabels(), LAYER_COLORS):
        tick_label.set_color(stage_color)
    ax2.tick_params(length=0)

plt.tight_layout()
fig.savefig(f'{FIGURES_DIR}/figB_alignment_heatmap.pdf', bbox_inches='tight', dpi=300)
fig.savefig(f'{FIGURES_DIR}/figB_alignment_heatmap.png', bbox_inches='tight', dpi=300)
plt.show()
print('✅ Figure B saved')

### Figure C — Accuracy vs. number of PCs retained, per stage

In [None]:
fig, axes = plt.subplots(1, 4, figsize=(14, 4), sharey=True)
fig.suptitle('Probe Accuracy vs. Number of PCs Retained', fontsize=13, y=1.02)

# One panel per stage: mean ± std (across the stage's heads) of probe accuracy
# as a function of k, for PC directions and for random directions.
for stage, panel in enumerate(axes):
    stage_color = LAYER_COLORS[stage]
    by_k = (
        results_df[results_df['layer'] == stage]
        .groupby('k')
        .agg(
            acc_pc_mean=('acc_pc', 'mean'),
            acc_pc_std=('acc_pc', 'std'),
            acc_rand_mean=('acc_rand', 'mean'),
            acc_rand_std=('acc_rand', 'std'),
        )
        .reset_index()
    )
    ks = by_k['k']
    pc_mu, pc_sd = by_k['acc_pc_mean'], by_k['acc_pc_std']
    rnd_mu, rnd_sd = by_k['acc_rand_mean'], by_k['acc_rand_std']

    # PC-direction curve with a ±1 std band.
    panel.plot(ks, pc_mu, color=stage_color, marker='o',
               linewidth=1.8, markersize=5, label='PC directions', zorder=4)
    panel.fill_between(ks, pc_mu - pc_sd, pc_mu + pc_sd,
                       alpha=0.15, color=stage_color)

    # Random-direction baseline with a ±1 std band.
    panel.plot(ks, rnd_mu, color=COLORS['neutral'],
               marker='s', linewidth=1.4, markersize=4, linestyle='--',
               label='Random dirs', zorder=3)
    panel.fill_between(ks, rnd_mu - rnd_sd, rnd_mu + rnd_sd,
                       alpha=0.10, color=COLORS['neutral'])

    panel.axhline(CHANCE, color=COLORS['accent'], linewidth=1.0,
                  linestyle=':', label='Chance', zorder=2)

    panel.set_title(f'Stage {stage+1} (L{stage})', color=stage_color, fontweight='bold')
    panel.set_xlabel('Number of PCs ($k$)')
    if stage == 0:
        panel.set_ylabel('Mean 5-fold CV accuracy')
    panel.set_xticks(K_VALUES)

handles, hlabels = axes[0].get_legend_handles_labels()
fig.legend(handles, hlabels, loc='lower center', ncol=3,
           bbox_to_anchor=(0.5, -0.10), frameon=True, edgecolor='#cccccc')
plt.tight_layout()
fig.savefig(f'{FIGURES_DIR}/figC_probe_vs_k.pdf', bbox_inches='tight', dpi=300)
fig.savefig(f'{FIGURES_DIR}/figC_probe_vs_k.png', bbox_inches='tight', dpi=300)
plt.show()
print('✅ Figure C saved')

### Figure D — Distribution of Δ per stage (violin + scatter)

In [None]:
fig, ax = plt.subplots(figsize=(7, 4))

layer_names = ['Stage 1\n(L0)', 'Stage 2\n(L1)', 'Stage 3\n(L2)', 'Stage 4\n(L3)']
# k=1 Δ values of every head, grouped by stage.
data_by_layer = [k1_df.loc[k1_df['layer'] == stage, 'delta'].values for stage in range(4)]

parts = ax.violinplot(data_by_layer, positions=range(4),
                      showmedians=True, showextrema=True)
for body, face in zip(parts['bodies'], LAYER_COLORS):
    body.set_facecolor(face)
    body.set_alpha(0.6)
# Medians drawn thicker than the extrema/bar lines.
line_widths = {'cmedians': 1.5, 'cmaxes': 0.8, 'cmins': 0.8, 'cbars': 0.8}
for part_name, width in line_widths.items():
    parts[part_name].set_color('#333333')
    parts[part_name].set_linewidth(width)

# Overlay individual heads with horizontal jitter. The generator is re-seeded
# per stage, so every stage receives the same jitter sequence.
for stage, stage_deltas in enumerate(data_by_layer):
    offsets = np.random.default_rng(42).uniform(-0.08, 0.08, len(stage_deltas))
    ax.scatter(stage + offsets, stage_deltas, s=18, color=LAYER_COLORS[stage],
               alpha=0.5, zorder=3, edgecolors='none')

# Significance markers just above the highest head of each significant stage.
for stage, marker in zip(stats_df.index, stats_df['significant']):
    if marker == 'ns':
        continue
    ax.text(stage, data_by_layer[stage].max() + 0.005, marker,
            ha='center', va='bottom',
            fontsize=10, color=LAYER_COLORS[stage], fontweight='bold')

ax.axhline(0, color=COLORS['accent'], linewidth=1.0,
           linestyle='--', label='No gain over random ($\\Delta=0$)')
ax.set_xticks(range(4))
ax.set_xticklabels(layer_names)
ax.set_ylabel(r'$\Delta$ accuracy $=$ PC$_1$ probe $-$ random')
ax.set_title(r'Semantic Gain Distribution per Stage ($k=1$)')
ax.legend(frameon=True, edgecolor='#cccccc')

plt.tight_layout()
fig.savefig(f'{FIGURES_DIR}/figD_delta_violin.pdf', bbox_inches='tight', dpi=300)
fig.savefig(f'{FIGURES_DIR}/figD_delta_violin.png', bbox_inches='tight', dpi=300)
plt.show()
print('✅ Figure D saved')

### Figure E — Top-5 specialised heads: full PC accuracy profile

In [None]:
# Heads with the five largest k=1 gains over the random baseline.
top5_ids = k1_df.nlargest(5, 'delta')['head_id'].tolist()

fig, axes = plt.subplots(1, 5, figsize=(14, 3.5), sharey=True)
fig.suptitle(
    r'Top-5 Specialised Heads: Accuracy Profile Across $k$',
    fontsize=13, y=1.02
)

# One panel per head: accuracy vs. k for PC directions and random directions.
for panel_idx, (ax, hid) in enumerate(zip(axes, top5_ids)):
    profile    = results_df.loc[results_df['head_id'] == hid].sort_values('k')
    head_color = LAYER_COLORS[profile['layer'].iloc[0]]
    ks         = profile['k']
    rand_mean  = profile['acc_rand']
    rand_std   = profile['std_rand']

    ax.plot(ks, profile['acc_pc'], color=head_color, marker='o',
            linewidth=1.8, markersize=5, label='PCs', zorder=4)
    ax.plot(ks, rand_mean, color=COLORS['neutral'], marker='s',
            linewidth=1.2, markersize=4, linestyle='--', label='Random', zorder=3)
    ax.fill_between(ks, rand_mean - rand_std, rand_mean + rand_std,
                    alpha=0.15, color=COLORS['neutral'])
    ax.axhline(CHANCE, color=COLORS['accent'], linewidth=0.9,
               linestyle=':', label='Chance')

    delta_k1 = profile.loc[profile['k'] == 1, 'delta'].iloc[0]
    ax.set_title(f'{hid}\n$\\Delta_{{k=1}}$={delta_k1:.3f}',
                 fontsize=9, color=head_color)
    ax.set_xlabel('$k$')
    if panel_idx == 0:
        ax.set_ylabel('CV accuracy')
    ax.set_xticks(K_VALUES)

handles, hlabels = axes[0].get_legend_handles_labels()
fig.legend(handles, hlabels, loc='lower center', ncol=3,
           bbox_to_anchor=(0.5, -0.12), frameon=True, edgecolor='#cccccc')
plt.tight_layout()
fig.savefig(f'{FIGURES_DIR}/figE_top5_heads.pdf', bbox_inches='tight', dpi=300)
fig.savefig(f'{FIGURES_DIR}/figE_top5_heads.png', bbox_inches='tight', dpi=300)
plt.show()
print('✅ Figure E saved')

## 5. Final Summary

In [None]:
banner = '=' * 65
print(banner)
print(f'SUMMARY — {DATASET_NAME.upper()}')
print(banner)
print(f'Samples: {N_SAMPLES}  |  Classes: {N_CLASSES}  |  Chance: {CHANCE:.4f}')
print()
# One line per stage: mean k=1 accuracies, mean Δ with its significance marker,
# and the head with the largest Δ in that stage.
for stage in range(4):
    stage_rows = k1_df[k1_df['layer'] == stage]
    best_row   = stage_rows.loc[stage_rows['delta'].idxmax()]
    stage_sig  = stats_df.loc[stats_df['layer'] == stage, 'significant'].values[0]
    print(f'  Stage {stage+1}:  mean PC₁={stage_rows["acc_pc"].mean():.4f}  '
          f'rand={stage_rows["acc_rand"].mean():.4f}  '
          f'Δ={stage_rows["delta"].mean():+.4f} ({stage_sig})  '
          f'best head={best_row["head_id"]} Δ={best_row["delta"]:+.4f}')
print()
share_pct = 100 * len(top_heads) / len(k1_df)
print(f'Specialised heads (Δ > μ+2σ = {threshold:.4f}): '
      f'{len(top_heads)}/{len(k1_df)} ({share_pct:.1f}%)')
print()
print(f'Figures saved to: {FIGURES_DIR}/')