# Grokking Trajectory Analysis

This notebook analyzes the results of the grokking trajectory experiment.
We visualize how TN similarity, activation similarity, and JS divergence
change throughout training to understand the grokking phenomenon.

In [None]:
import json
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm

# Notebook-wide plotting defaults: figure size (inches) and base font size
plt.rcParams.update({
    'figure.figsize': (12, 8),
    'font.size': 12,
})

## 1. Load Results

In [None]:
# Directory holding the experiment outputs (a path relative to the working directory)
results_dir = Path("../results")

# Load the run configuration. The encoding is pinned to UTF-8 (the standard JSON
# encoding) instead of relying on the platform default used by a bare open().
with open(results_dir / "config.json", encoding="utf-8") as f:
    config = json.load(f)

# Echo every configuration entry so the run settings are visible in the notebook
print("Configuration:")
for k, v in config.items():
    print(f"  {k}: {v}")

In [None]:
# Load training history (encoding pinned to UTF-8 rather than the platform default)
with open(results_dir / "training_history.json", encoding="utf-8") as f:
    history_raw = json.load(f)

# JSON object keys are always strings: convert them back to integer steps and sort
history = {int(k): v for k, v in history_raw.items()}
steps = sorted(history)

# Fail with a clear message instead of an IndexError on steps[0] further down
if not steps:
    raise ValueError("training_history.json contains no checkpoints")

# Per-checkpoint metric series, ordered by training step
train_losses = [history[s]['train_loss'] for s in steps]
val_losses = [history[s]['val_loss'] for s in steps]
val_accs = [history[s]['val_acc'] for s in steps]

print(f"Number of checkpoints: {len(steps)}")
print(f"Step range: {steps[0]} to {steps[-1]}")
print(f"Final val accuracy: {val_accs[-1]:.4f}")

In [None]:
# Load the pairwise checkpoint-comparison matrices and the step of each checkpoint
tn_sim = np.load(results_dir / "tn_similarity.npy")
act_sim = np.load(results_dir / "act_similarity.npy")
js_div = np.load(results_dir / "js_divergence.npy")
checkpoint_steps = np.load(results_dir / "checkpoint_steps.npy")

# Later cells index every matrix by checkpoint on both axes and plot its rows/columns
# against checkpoint_steps, so fail early if the files do not agree in size.
n_checkpoints = len(checkpoint_steps)
for matrix_name, matrix in (
    ("tn_similarity", tn_sim),
    ("act_similarity", act_sim),
    ("js_divergence", js_div),
):
    if matrix.shape != (n_checkpoints, n_checkpoints):
        raise ValueError(
            f"{matrix_name}.npy has shape {matrix.shape}, expected "
            f"({n_checkpoints}, {n_checkpoints}) to match checkpoint_steps.npy"
        )

# Positions found in the training history (e.g. grok_idx) are later used as positions
# in these matrices; that is only valid when both describe the same steps.
if not np.array_equal(np.asarray(steps), checkpoint_steps):
    print("Warning: training-history steps differ from checkpoint_steps; "
          "checkpoint indices derived from the history may not line up with the matrices.")

print(f"Matrix shapes: {tn_sim.shape}")
print(f"TN sim range: [{tn_sim.min():.4f}, {tn_sim.max():.4f}]")
print(f"Act sim range: [{act_sim.min():.4f}, {act_sim.max():.4f}]")
print(f"JS div range: [{js_div.min():.4f}, {js_div.max():.4f}]")

## 2. Training Curves

In [None]:
# Validation accuracy at which the model counts as having grokked. Defined once so the
# reference line in the plot and the detection below cannot drift apart.
grok_threshold = 0.95

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Loss curves (log-scaled y axis)
ax = axes[0]
ax.semilogy(steps, train_losses, label='Train Loss', alpha=0.8)
ax.semilogy(steps, val_losses, label='Val Loss', alpha=0.8)
ax.set_xlabel('Training Steps')
ax.set_ylabel('Loss (log scale)')
ax.set_title('Loss Curves')
ax.legend()
ax.grid(True, alpha=0.3)

# Accuracy curve with the grokking threshold as a dashed reference line
ax = axes[1]
ax.plot(steps, val_accs, 'g-', linewidth=2, label='Val Accuracy')
ax.axhline(y=grok_threshold, color='r', linestyle='--', alpha=0.5,
           label=f'{grok_threshold:.0%} threshold')
ax.set_xlabel('Training Steps')
ax.set_ylabel('Validation Accuracy')
ax.set_title('Validation Accuracy')
ax.legend()
ax.grid(True, alpha=0.3)
ax.set_ylim([0, 1.05])

plt.tight_layout()
plt.show()

# Identify grokking point: first checkpoint whose validation accuracy reaches the
# threshold. grok_idx is a position in `steps`; both stay None if it is never reached.
grok_idx = next((i for i, acc in enumerate(val_accs) if acc >= grok_threshold), None)
grok_step = steps[grok_idx] if grok_idx is not None else None
if grok_idx is not None:
    print(f"Grokking occurred around step {grok_step} (checkpoint index {grok_idx})")
else:
    print(f"Did not reach {grok_threshold*100}% validation accuracy")

## 3. Similarity Heatmaps

In [None]:
def plot_similarity_heatmap(matrix, title, ax, cmap='viridis', vmin=None, vmax=None, tick_stride=20):
    """Draw a checkpoint-by-checkpoint similarity/divergence matrix as a heatmap.

    Reads two notebook-level variables: ``checkpoint_steps`` (source of the tick
    labels) and ``grok_idx`` (position of the optional red marker lines).

    Args:
        matrix: Array handed to ``ax.imshow``. Assumed square, since the same tick
            positions (derived from ``matrix.shape[0]``) are applied to both axes.
        title: Axes title.
        ax: Matplotlib axes to draw on.
        cmap: Colormap name forwarded to ``imshow``.
        vmin, vmax: Optional color-scale limits forwarded to ``imshow``.
        tick_stride: Label every ``tick_stride``-th checkpoint. Defaults to 20;
            values below 1 are treated as 1.

    Returns:
        The image object returned by ``ax.imshow`` (usable with ``plt.colorbar``).
    """
    N = matrix.shape[0]

    # Tick positions are checkpoint indices; labels are the matching training steps
    stride = max(1, int(tick_stride))
    tick_positions = list(range(0, N, stride))
    tick_labels = [str(checkpoint_steps[i]) for i in tick_positions]

    im = ax.imshow(matrix, cmap=cmap, vmin=vmin, vmax=vmax, aspect='auto')
    ax.set_xticks(tick_positions)
    ax.set_xticklabels(tick_labels, rotation=45, ha='right')
    ax.set_yticks(tick_positions)
    ax.set_yticklabels(tick_labels)
    ax.set_xlabel('Step')
    ax.set_ylabel('Step')
    ax.set_title(title)

    # Mark the grokking checkpoint, but only when the index lies inside the matrix:
    # a line drawn outside the image would stretch the axes and distort the heatmap.
    if grok_idx is not None and 0 <= grok_idx < N:
        ax.axhline(y=grok_idx, color='red', linestyle='--', alpha=0.7, linewidth=1)
        ax.axvline(x=grok_idx, color='red', linestyle='--', alpha=0.7, linewidth=1)

    return im

In [None]:
fig, axes = plt.subplots(1, 3, figsize=(18, 6))

# One entry per panel: (matrix, title, colormap, color-scale limits, colorbar label).
# The two similarity panels share a fixed 0..1 scale; JS divergence is auto-scaled.
heatmap_panels = [
    (tn_sim, 'TN Similarity', 'viridis', {'vmin': 0, 'vmax': 1}, 'Similarity'),
    (act_sim, 'Activation Similarity', 'viridis', {'vmin': 0, 'vmax': 1}, 'Similarity'),
    (js_div, 'JS Divergence', 'magma', {}, 'Divergence'),
]
for panel_ax, (panel_matrix, panel_title, colormap, limits, bar_label) in zip(axes, heatmap_panels):
    image = plot_similarity_heatmap(panel_matrix, panel_title, panel_ax, cmap=colormap, **limits)
    plt.colorbar(image, ax=panel_ax, label=bar_label)

plt.suptitle('Similarity Matrices Throughout Training\n(Red lines indicate grokking transition)', fontsize=14)
plt.tight_layout()
plt.show()

## 4. Diagonal Band Analysis

We analyze how similar consecutive checkpoints are by reading the first off-diagonal (offset 1) of each matrix, which holds the value for every pair of adjacent checkpoints.

In [None]:
def extract_diagonal_band(matrix, offset=1):
    """Return the entries of `matrix` lying `offset` places off the main diagonal.

    Thin wrapper around ``np.diag``; `offset` is forwarded as its ``k`` argument
    (positive values select diagonals above the main one, negative values below).
    """
    band = np.diag(matrix, offset)
    return band

# Consecutive checkpoint values: the k=1 diagonal holds entry (i, i+1) of each matrix
tn_consec = extract_diagonal_band(tn_sim, 1)
act_consec = extract_diagonal_band(act_sim, 1)
js_consec = extract_diagonal_band(js_div, 1)

# Each pair (i, i+1) is plotted at the step of its earlier checkpoint
consec_steps = checkpoint_steps[:-1]

fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# One entry per panel: (axes, [(values, line format, line label or None), ...], y-label, title).
# The last panel shows 1 - similarity, i.e. how much changed between adjacent checkpoints.
consec_panels = [
    (axes[0, 0], [(tn_consec, 'b-', None)],
     'TN Similarity', 'TN Similarity (Consecutive Checkpoints)'),
    (axes[0, 1], [(act_consec, 'g-', None)],
     'Activation Similarity', 'Activation Similarity (Consecutive Checkpoints)'),
    (axes[1, 0], [(js_consec, 'm-', None)],
     'JS Divergence', 'JS Divergence (Consecutive Checkpoints)'),
    (axes[1, 1], [(1 - tn_consec, 'b-', 'TN change'), (1 - act_consec, 'g-', 'Act change')],
     '1 - Similarity (Change)', 'Rate of Change Between Checkpoints'),
]

for ax, series, ylabel, title in consec_panels:
    for values, fmt, line_label in series:
        label_kwargs = {'label': line_label} if line_label else {}
        ax.plot(consec_steps, values, fmt, linewidth=1.5, **label_kwargs)
    # `is not None` rather than truthiness, so a grokking step of 0 is still marked
    if grok_step is not None:
        ax.axvline(x=grok_step, color='red', linestyle='--', alpha=0.7, label=f'Grok @ {grok_step}')
    ax.set_xlabel('Step')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.grid(True, alpha=0.3)
    # Only draw a legend when something is labelled; matplotlib warns on an empty one
    if ax.get_legend_handles_labels()[0]:
        ax.legend()

plt.tight_layout()
plt.show()

## 5. Similarity to Final Model

How similar is each checkpoint to the final trained model?

In [None]:
# Last column of each matrix: every checkpoint compared against the final checkpoint
tn_to_final = tn_sim[:, -1]
act_to_final = act_sim[:, -1]
js_to_final = js_div[:, -1]

fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# One entry per panel: (values, line format, fill_between color kwargs, y-label, title).
# The first panel passes no fill color, leaving it to matplotlib's default.
to_final_panels = [
    (tn_to_final, 'b-', {}, 'TN Similarity', 'TN Similarity to Final Model'),
    (act_to_final, 'g-', {'color': 'green'}, 'Activation Similarity', 'Activation Similarity to Final Model'),
    (js_to_final, 'm-', {'color': 'magenta'}, 'JS Divergence', 'JS Divergence from Final Model'),
]

for ax, (values, fmt, fill_kwargs, ylabel, title) in zip(axes, to_final_panels):
    ax.plot(checkpoint_steps, values, fmt, linewidth=2)
    ax.fill_between(checkpoint_steps, values, alpha=0.3, **fill_kwargs)
    # `is not None` rather than truthiness, so a grokking step of 0 is still marked
    if grok_step is not None:
        ax.axvline(x=grok_step, color='red', linestyle='--', alpha=0.7, label=f'Grok @ {grok_step}')
    ax.set_xlabel('Step')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.grid(True, alpha=0.3)
    # Only draw a legend when something is labelled; matplotlib warns on an empty one
    if ax.get_legend_handles_labels()[0]:
        ax.legend()

plt.suptitle('Similarity/Divergence to Final Trained Model', fontsize=14)
plt.tight_layout()
plt.show()

## 6. Combined Analysis with Validation Accuracy

In [None]:
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# Top-left: validation accuracy on its own. The dashed line reuses grok_threshold
# (set in the training-curves cell) so it always matches the detection threshold.
ax = axes[0, 0]
ax.plot(steps, val_accs, 'k-', linewidth=2, label='Val Accuracy')
ax.axhline(y=grok_threshold, color='r', linestyle='--', alpha=0.5)
ax.set_ylabel('Validation Accuracy')
ax.set_title('Validation Accuracy')
ax.legend()
ax.grid(True, alpha=0.3)
ax.set_ylim([0, 1.05])

# Remaining panels: one to-final metric on the left axis, validation accuracy on a
# twin right axis. Entry: ((row, col), values, line format, label color, line label,
# y-label, title).
overlay_panels = [
    ((0, 1), tn_to_final, 'b-', 'blue', 'TN Sim to Final',
     'TN Similarity', 'TN Similarity to Final vs Val Accuracy'),
    ((1, 0), act_to_final, 'g-', 'green', 'Act Sim to Final',
     'Activation Similarity', 'Activation Similarity to Final vs Val Accuracy'),
    ((1, 1), js_to_final, 'm-', 'magenta', 'JS Div to Final',
     'JS Divergence', 'JS Divergence from Final vs Val Accuracy'),
]
bottom_row = axes.shape[0] - 1

for (row, col), values, fmt, label_color, line_label, ylabel, title in overlay_panels:
    ax = axes[row, col]
    ax2 = ax.twinx()
    ln1 = ax.plot(checkpoint_steps, values, fmt, linewidth=2, label=line_label)
    ln2 = ax2.plot(steps, val_accs, 'k--', alpha=0.5, label='Val Acc')
    # Only the bottom row carries an x-axis label
    if row == bottom_row:
        ax.set_xlabel('Step')
    ax.set_ylabel(ylabel, color=label_color)
    ax2.set_ylabel('Val Acc', color='black')
    ax.set_title(title)
    # The two lines live on different axes, so build one combined legend by hand
    lns = ln1 + ln2
    labs = [line.get_label() for line in lns]
    ax.legend(lns, labs)
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## 7. Summary Statistics

In [None]:
# Indices of the entries strictly above the main diagonal (row < column)
upper_tri_idx = np.triu_indices(tn_sim.shape[0], k=1)

# Flatten each matrix to those entries so the three metrics can be compared pairwise
tn_upper, act_upper, js_upper = (mat[upper_tri_idx] for mat in (tn_sim, act_sim, js_div))

print("Correlations between metrics (upper triangular values):")
metric_pairs = [
    ("TN vs Activation similarity", tn_upper, act_upper),
    ("TN similarity vs JS divergence", tn_upper, js_upper),
    ("Activation similarity vs JS divergence", act_upper, js_upper),
]
for pair_name, first, second in metric_pairs:
    # Off-diagonal element of the 2x2 correlation matrix is the Pearson coefficient
    pearson_r = np.corrcoef(first, second)[0, 1]
    print(f"  {pair_name}: {pearson_r:.4f}")

print("\nMetric statistics:")
metric_values = [
    ("TN similarity", tn_upper),
    ("Activation similarity", act_upper),
    ("JS divergence", js_upper),
]
for metric_name, flat_values in metric_values:
    print(f"  {metric_name}: mean={flat_values.mean():.4f}, std={flat_values.std():.4f}")

In [None]:
# One scatter panel per metric pair, with the Pearson correlation in the title
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# (x values, y values, x-label, y-label, title prefix) for each panel
scatter_panels = [
    (tn_upper, act_upper, 'TN Similarity', 'Activation Similarity', 'TN vs Activation'),
    (tn_upper, js_upper, 'TN Similarity', 'JS Divergence', 'TN Sim vs JS Div'),
    (act_upper, js_upper, 'Activation Similarity', 'JS Divergence', 'Act Sim vs JS Div'),
]
for ax, (xs, ys, x_name, y_name, heading) in zip(axes, scatter_panels):
    pearson_r = np.corrcoef(xs, ys)[0, 1]
    ax.scatter(xs, ys, alpha=0.3, s=5)
    ax.set_xlabel(x_name)
    ax.set_ylabel(y_name)
    ax.set_title(f'{heading} (r={pearson_r:.3f})')
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## 8. Grokking Phase Identification

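Identify training phases from validation-accuracy thresholds, then compare each metric's average value relative to the final model before and after generalization begins.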

In [None]:
# Identify phases based on validation accuracy
# Phase 1: Memorization (high train acc, low val acc)
# Phase 2: Grokking transition
# Phase 3: Generalization (high val acc)

val_acc_array = np.array(val_accs)

# Thresholds separating the phases
low_acc_threshold = 0.2  # Random chance would be ~1/P
high_acc_threshold = 0.9

# Phase boundaries: index of the first checkpoint whose validation accuracy exceeds
# each threshold, or None if the threshold is never exceeded.
above_low = np.flatnonzero(val_acc_array > low_acc_threshold)
above_high = np.flatnonzero(val_acc_array > high_acc_threshold)
memorization_end = int(above_low[0]) if above_low.size else None
generalization_start = int(above_high[0]) if above_high.size else None

print("Phase identification:")
if memorization_end is not None:
    # Report the first recorded step instead of assuming training history starts at 0
    print(f"  Memorization phase: steps {steps[0]} - {steps[memorization_end]}")
    print(f"  Transition begins at checkpoint {memorization_end}")
if generalization_start is not None:
    print(f"  Generalization phase starts: step {steps[generalization_start]}")
    print(f"  Generalization starts at checkpoint {generalization_start}")


def _phase_mean(values):
    """Mean of `values`, or NaN (without a NumPy warning) when `values` is empty."""
    return values.mean() if len(values) else float('nan')


# Analyze metric behavior in each phase
if generalization_start is not None:
    print("\nMetric analysis by phase:")

    # NOTE(review): generalization_start is a position in the training history and is
    # used here as a position in the *_to_final arrays; this assumes both cover the
    # same checkpoints in the same order. Confirm against the experiment script.

    # Pre-grokking (empty when generalization already holds at the first checkpoint)
    pre_tn = _phase_mean(tn_to_final[:generalization_start])
    pre_act = _phase_mean(act_to_final[:generalization_start])
    pre_js = _phase_mean(js_to_final[:generalization_start])

    # Post-grokking
    post_tn = _phase_mean(tn_to_final[generalization_start:])
    post_act = _phase_mean(act_to_final[generalization_start:])
    post_js = _phase_mean(js_to_final[generalization_start:])

    print(f"  Pre-grokking (avg similarity to final):")
    print(f"    TN: {pre_tn:.4f}, Act: {pre_act:.4f}, JS: {pre_js:.4f}")
    print(f"  Post-grokking (avg similarity to final):")
    print(f"    TN: {post_tn:.4f}, Act: {post_act:.4f}, JS: {post_js:.4f}")
    if generalization_start == 0:
        print("  Note: no pre-grokking checkpoints; pre-grokking averages are undefined (nan)")