# Similarity Analysis

This notebook loads pre-computed similarity matrices (or computes them if missing) and provides visualization and analysis tools for comparing models across different bottleneck dimensions.

Three similarity metrics are computed:
1. **TN Similarity** - Weight-space similarity using symmetric inner product
2. **Logit Cosine Similarity** - Output-space similarity based on model logits
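3. **JS Divergence** - Jensen–Shannon divergence between frequency distributions from eigenvector FFT analysis (a dissimilarity: lower values mean more similar models, so it is negated in the plots below)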

In [None]:
import sys
sys.path.insert(0, '..')

import numpy as np
import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display
from pathlib import Path

from models import load_sweep_results, get_device
from tn_sim import load_or_compute_tn_similarity
from act_sim import load_or_compute_act_similarity
from js_div import load_or_compute_js_divergence

# Configuration: compute device and the sweep-results file that every loader below reads
DEVICE = get_device()
SWEEP_PATH = Path('../comp_diagrams/sweep_results_0401.pkl')
print(f'Device: {DEVICE}')

## Load Data and Similarity Matrices

The `load_or_compute_*` functions will load from cache if available, or compute and save to cache.

In [None]:
# Unpack the sweep results: per-model state, validation accuracies, and P
models_state, val_acc, P = load_sweep_results(SWEEP_PATH)
print(f'Loaded {len(models_state)} models with P={P}')

# One matrix per metric; every loader receives the same sweep path and device
tn_sim_mat = load_or_compute_tn_similarity(SWEEP_PATH, device=DEVICE)
act_sim_mat = load_or_compute_act_similarity(SWEEP_PATH, device=DEVICE)
js_div_mat = load_or_compute_js_divergence(SWEEP_PATH, device=DEVICE)

# Shape report, emitted as one print call with a newline between entries
print(
    f'\nMatrix shapes:',
    f'  TN similarity: {tn_sim_mat.shape}',
    f'  Activation similarity: {act_sim_mat.shape}',
    f'  JS divergence: {js_div_mat.shape}',
    sep='\n',
)

## Validation Accuracy Curve

Reference plot showing model performance across bottleneck dimensions.

In [None]:
# Validation accuracy per bottleneck dimension; skipped when no accuracies were loaded
if len(val_acc) == 0:
    print('No validation accuracies available.')
else:
    # Keys in descending order, with the matching accuracies alongside
    dims_sorted = sorted(val_acc.keys(), reverse=True)
    accs_sorted = [val_acc[dim] for dim in dims_sorted]

    fig, ax = plt.subplots(figsize=(12, 5))
    ax.plot(dims_sorted, accs_sorted, 'o-', linewidth=2, markersize=6)
    ax.set_title('Validation Accuracy vs. Bottleneck Dimension', fontsize=14)
    ax.set_xlabel('Hidden Dimension (Bottleneck)', fontsize=12)
    ax.set_ylabel('Validation Accuracy', fontsize=12)
    # Fixed y-range with a little headroom above 1.0
    ax.set_ylim(0, 1.05)
    ax.grid(True, alpha=0.3)
    fig.tight_layout()
    plt.show()

## Side-by-Side Similarity Heatmaps

Compare all three similarity metrics at once.

In [None]:
fig, axes = plt.subplots(1, 3, figsize=(18, 5))
ax1, ax2, ax3 = axes

# One entry per panel: (axes, matrix to draw, title, colour limits passed to imshow).
# JS divergence is negated for visual consistency - lower divergence = more similar.
# NOTE(review): the 0.7 and -17 limits look hand-picked for this sweep — confirm if the data changes.
heatmap_panels = [
    (ax1, tn_sim_mat, 'TN Similarity\n(weight-space)', {'vmin': 0, 'vmax': 0.7}),
    (ax2, act_sim_mat, 'Logit Cosine Similarity\n(output-space)', {'vmin': 0, 'vmax': 1.0}),
    (ax3, -js_div_mat, 'JS Divergence\n(frequency distribution)', {'vmin': -17}),
]

heatmap_images = []
for panel_ax, panel_mat, panel_title, panel_limits in heatmap_panels:
    n_rows, n_cols = panel_mat.shape
    heatmap_images.append(
        panel_ax.imshow(
            panel_mat,
            cmap='magma',
            origin='lower',
            # Index i is treated as bottleneck dimension i + 1 elsewhere in this notebook
            # (d = dim_idx + 1 in plot_similarity_profile). Shift the pixel grid so tick
            # labels read as dimensions 1..n rather than 0-based indices.
            extent=(0.5, n_cols + 0.5, 0.5, n_rows + 0.5),
            **panel_limits,
        )
    )
    panel_ax.set_xlabel('Bottleneck Dimension', fontsize=11)
    panel_ax.set_ylabel('Bottleneck Dimension', fontsize=11)
    panel_ax.set_title(panel_title, fontsize=12)
im1, im2, im3 = heatmap_images

# Colorbars; only the JS panel carries a label because its values are negated
fig.colorbar(im1, ax=ax1, fraction=0.046, pad=0.04)
fig.colorbar(im2, ax=ax2, fraction=0.046, pad=0.04)
fig.colorbar(im3, ax=ax3, fraction=0.046, pad=0.04, label='-JS')

plt.suptitle('Model Similarity: Three Perspectives', fontsize=14)
plt.tight_layout()
plt.show()

## Individual Heatmap Exploration

Interactive exploration with adjustable color scales.

In [None]:
def plot_similarity_heatmap(metric='TN', vmin=0.0, vmax=1.0):
    """Plot a single similarity heatmap with an adjustable colour scale.

    Args:
        metric: 'TN' draws ``tn_sim_mat`` and 'Logit' draws ``act_sim_mat``; any other
            value draws the negated ``js_div_mat``.
        vmin: Lower colour-scale limit.
        vmax: Upper colour-scale limit. If the two limits arrive in the wrong order
            they are swapped instead of letting matplotlib raise.
    """
    if metric == 'TN':
        data = tn_sim_mat
        title = 'TN Similarity (weight-space)'
    elif metric == 'Logit':
        data = act_sim_mat
        title = 'Logit Cosine Similarity (output-space)'
    else:  # JS
        data = -js_div_mat  # Negate for visual consistency
        title = 'Negative JS Divergence (frequency)'

    # The two sliders move independently and can cross; matplotlib's normalisation
    # rejects vmin > vmax with a ValueError, so put them back in order.
    if vmin > vmax:
        vmin, vmax = vmax, vmin

    n_rows, n_cols = data.shape
    fig, ax = plt.subplots(figsize=(8, 6))
    im = ax.imshow(
        data,
        cmap='magma',
        vmin=vmin,
        vmax=vmax,
        origin='lower',
        # Index i is treated as bottleneck dimension i + 1 elsewhere in this notebook;
        # shift the pixel grid so ticks show dimensions 1..n, not 0-based indices.
        extent=(0.5, n_cols + 0.5, 0.5, n_rows + 0.5),
    )
    ax.set_xlabel('Bottleneck Dimension', fontsize=12)
    ax.set_ylabel('Bottleneck Dimension', fontsize=12)
    ax.set_title(title, fontsize=14)
    fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
    plt.tight_layout()
    plt.show()

# Controls for the interactive heatmap: a metric selector and the two colour-scale limits
metric_dropdown = widgets.Dropdown(options=['TN', 'Logit', 'JS'], value='TN', description='Metric:')
vmin_slider = widgets.FloatSlider(
    min=-20, max=1, step=0.1, value=0.0,
    description='vmin',
)
vmax_slider = widgets.FloatSlider(
    min=0, max=1, step=0.1, value=1.0,
    description='vmax',
)

# Bind each control to the matching keyword argument of plot_similarity_heatmap
out = widgets.interactive_output(
    plot_similarity_heatmap,
    dict(metric=metric_dropdown, vmin=vmin_slider, vmax=vmax_slider),
)

# Controls in a row on top, plot output underneath
heatmap_controls = widgets.HBox([metric_dropdown, vmin_slider, vmax_slider])
display(widgets.VBox([heatmap_controls, out]))

## Correlation Analysis Between Metrics

How do the different similarity metrics relate to each other?

In [None]:
# Flatten the strict upper triangle (k=1 excludes the diagonal) of each matrix so every
# unordered model pair contributes exactly once to the correlations.
# The index set is sized from the matrices themselves rather than from P: if the two ever
# disagree, indexing with P would either fail with an opaque IndexError or silently
# analyse only a sub-block.
matrix_shapes = {
    'TN': tuple(tn_sim_mat.shape),
    'Logit': tuple(act_sim_mat.shape),
    'JS': tuple(js_div_mat.shape),
}
n_models = tn_sim_mat.shape[0]
if any(shape != (n_models, n_models) for shape in matrix_shapes.values()):
    raise ValueError(
        f'Expected three square matrices of equal size, got shapes {matrix_shapes}'
    )
upper_idx = np.triu_indices(n_models, k=1)

tn_flat = tn_sim_mat[upper_idx]
act_flat = act_sim_mat[upper_idx]
js_flat = js_div_mat[upper_idx]

# Compute correlations (Pearson r is the off-diagonal entry of the 2x2 corrcoef matrix)
corr_tn_act = np.corrcoef(tn_flat, act_flat)[0, 1]
corr_tn_js = np.corrcoef(tn_flat, -js_flat)[0, 1]  # Negate JS for positive correlation
corr_act_js = np.corrcoef(act_flat, -js_flat)[0, 1]

print('Pairwise Correlations (upper triangular, excluding diagonal):')
print(f'  TN vs Logit: {corr_tn_act:.4f}')
print(f'  TN vs -JS:   {corr_tn_js:.4f}')
print(f'  Logit vs -JS: {corr_act_js:.4f}')

In [None]:
# Pairwise scatter plots of the flattened metrics, one panel per pair
def _scatter_panel(panel_ax, xs, ys, xlabel, ylabel, title):
    """Draw one metric-vs-metric scatter panel with labels, title, and a light grid."""
    panel_ax.scatter(xs, ys, alpha=0.3, s=10)
    panel_ax.set_xlabel(xlabel)
    panel_ax.set_ylabel(ylabel)
    panel_ax.set_title(title)
    panel_ax.grid(True, alpha=0.3)


fig, axes = plt.subplots(1, 3, figsize=(15, 4))
ax1, ax2, ax3 = axes

# JS divergence is plotted negated, matching the sign used for its correlations
_scatter_panel(
    ax1, tn_flat, act_flat,
    'TN Similarity', 'Logit Similarity',
    f'TN vs Logit (r={corr_tn_act:.3f})',
)
_scatter_panel(
    ax2, tn_flat, -js_flat,
    'TN Similarity', '-JS Divergence',
    f'TN vs -JS (r={corr_tn_js:.3f})',
)
_scatter_panel(
    ax3, act_flat, -js_flat,
    'Logit Similarity', '-JS Divergence',
    f'Logit vs -JS (r={corr_act_js:.3f})',
)

plt.tight_layout()
plt.show()

## Similarity Profile by Bottleneck Dimension

How similar is each model to all others? Compare similarity profiles.

In [None]:
def plot_similarity_profile(dim_idx=P-1):
    """Plot one row of each similarity matrix against bottleneck dimension.

    Args:
        dim_idx: 0-based row index; row ``i`` is shown as bottleneck dimension ``i + 1``.
            Negative values count from the end, as in normal Python indexing.

    Raises:
        IndexError: If ``dim_idx`` falls outside the rows of ``tn_sim_mat``.
    """
    # Resolve the index up front so the marker and title always show the real dimension
    # (a negative index would otherwise select a valid row but be labelled d <= 0).
    n_rows = tn_sim_mat.shape[0]
    if dim_idx < 0:
        dim_idx += n_rows
    if not 0 <= dim_idx < n_rows:
        raise IndexError(f'dim_idx out of range for {n_rows} rows')
    d = dim_idx + 1  # Convert to 1-indexed dimension

    # -JS is divided by 17 to bring it onto a scale comparable with the other two curves.
    # NOTE(review): 17 looks hand-picked (the heatmap cell uses vmin=-17) — confirm.
    profiles = [
        (tn_sim_mat[dim_idx], 'o-', 'TN'),
        (act_sim_mat[dim_idx], 's-', 'Logit'),
        (-js_div_mat[dim_idx] / 17, '^-', '-JS/17'),
    ]

    fig, ax = plt.subplots(figsize=(12, 5))
    for row, fmt, label in profiles:
        # x positions are 1-indexed dimensions derived from the row length, not the global P
        ax.plot(np.arange(1, len(row) + 1), row, fmt, label=label, markersize=4)

    ax.axvline(d, color='red', linestyle='--', alpha=0.5, label=f'd={d}')
    ax.set_xlabel('Comparison Bottleneck Dimension')
    ax.set_ylabel('Similarity')
    ax.set_title(f'Similarity Profile for d_hidden={d}')
    ax.legend()
    ax.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()

# Slider over row indices 0..P-1, starting at the last one; the profile redraws on change
dim_slider = widgets.IntSlider(
    min=0,
    max=P-1,
    step=1,
    value=P-1,
    description='dim_idx',
)
out = widgets.interactive_output(
    plot_similarity_profile,
    dict(dim_idx=dim_slider),
)
display(dim_slider, out)

## Summary Statistics

In [None]:
def _print_summary(heading, values):
    """Print mean, standard deviation, and min/max of ``values`` under ``heading``."""
    print(f'\n{heading}:')
    print(f'  Mean: {values.mean():.4f}')
    print(f'  Std:  {values.std():.4f}')
    print(f'  Range: [{values.min():.4f}, {values.max():.4f}]')


# Statistics over the flattened upper-triangle values of each matrix
print('Summary Statistics (excluding diagonal):')
_print_summary('TN Similarity', tn_flat)
_print_summary('Logit Cosine Similarity', act_flat)
_print_summary('JS Divergence', js_flat)