# Longer Context Cloze Explorer

Analysis of contrastive activations from Llama-3.1-70B-Instruct on the **contextual_cloze_tests_100** dataset.

**Dataset format:** `"A_completed B_completed C_truncated"` — two context sentences (A, B), each completed with either straight or funny words, followed by a target sentence (C) truncated at the blank.

- 50 pairs (100 prompts total)
- 3-sentence context format

**Contents:**
1. Data overview and separation metrics
2. Static visualizations
3. Interactive 3D layer explorer
4. Holdout analysis (full-data vs. cross-validated Cohen's d)

In [None]:
from pathlib import Path
import json
import numpy as np
from IPython.display import HTML, display
import matplotlib.pyplot as plt
%matplotlib inline

from analyze_activations import (
    load_activations, get_pair_indices, analyze_all_layers,
    contrastive_projection, contrastive_direction, cohens_d,
    stable_contrastive_projections, holdout_analysis,
    load_detailed_predictions, pun_boost_per_pair,
)
from puns_viz import make_layer_viz

In [None]:
RAW_DIR = Path("results/raw_activations")

# Activation metadata and detailed per-prompt predictions for the original
# longer-context (3-sentence) dataset, collected with `--position pred_c`.
META_FILE = RAW_DIR / "llama31_70b_instruct_pred_c_meta.json"
PRED_FILE = RAW_DIR / "llama31_70b_instruct_pred_c_detailed_preds.json"

# Every later cell uses `meta`, `layer_data` and `layer_indices`, so stop here
# with an actionable message instead of letting the next cell fail with a
# NameError.
if not META_FILE.exists():
    raise FileNotFoundError(
        f"Meta file not found: {META_FILE}\n"
        "Run: python3 collect_activations.py --position pred_c"
    )

meta, layer_data, layer_indices = load_activations(META_FILE)
pair_ids, is_funny, is_straight = get_pair_indices(meta)

print(f"Model: {meta['model']}")
print("Dataset: contextual_cloze_tests_100.json (3-sentence format)")
print(f"Position: {meta['position']}")
# Report the actual first collected layer rather than assuming it is layer 0.
print(f"Layers: {len(layer_indices)} ({layer_indices[0]}-{layer_indices[-1]})")
print(f"Prompts: {meta['n_prompts']} ({is_straight.sum()} straight, {is_funny.sum()} funny)")
print(f"Hidden dim: {meta['hidden_dim']}")

---
## 1. Separation Metrics Across Layers

In [None]:
# Per-layer straight-vs-funny separation scores.
layer_results = analyze_all_layers(layer_data, meta)

# Peak layers for each metric, plus their positions in `layer_indices`; the
# metric sequences are looked up by position, not by layer number.
peak_fisher = layer_results['peak_fisher_layer']
peak_fisher_pos = layer_indices.index(peak_fisher)
peak_cd = layer_results['peak_cohens_d_layer']
peak_cd_idx = layer_indices.index(peak_cd)

fisher_at_peak = layer_results['fisher'][peak_fisher_pos]
cd_at_peak = layer_results['cohens_d'][peak_cd_idx]
print(f"Fisher peak: layer {peak_fisher} (score={fisher_at_peak:.3f})")
print(f"Cohen's d peak: layer {peak_cd} (d={cd_at_peak:.2f})")

In [None]:
# Dual-axis plot: Fisher separation on the left y-axis, Cohen's d on the right.
FISHER_COLOR = '#E85D75'
CD_COLOR = '#2EAD6B'

fig, ax_fisher = plt.subplots(figsize=(12, 4))

ax_fisher.plot(layer_indices, layer_results['fisher'], color=FISHER_COLOR, lw=2, label='Fisher separation')
ax_fisher.set_xlabel('Layer', fontsize=11)
ax_fisher.set_ylabel('Fisher separation', fontsize=11, color=FISHER_COLOR)
ax_fisher.tick_params(axis='y', labelcolor=FISHER_COLOR)
for side in ('top', 'right'):
    ax_fisher.spines[side].set_visible(False)

# Second y-axis sharing the layer x-axis.
ax_cd = ax_fisher.twinx()
ax_cd.plot(layer_indices, layer_results['cohens_d'], color=CD_COLOR, lw=2, ls='--', label="Cohen's d")
ax_cd.set_ylabel("Cohen's d", fontsize=11, color=CD_COLOR)
ax_cd.tick_params(axis='y', labelcolor=CD_COLOR)
ax_cd.spines['top'].set_visible(False)

# Merge both axes' entries into a single legend.
legend_handles, legend_labels = [], []
for axis in (ax_fisher, ax_cd):
    handles, labels = axis.get_legend_handles_labels()
    legend_handles.extend(handles)
    legend_labels.extend(labels)
ax_fisher.legend(legend_handles, legend_labels, fontsize=9, loc='upper left')

ax_fisher.set_title('Longer Context (3-sentence): Straight vs. Funny Separation by Layer', fontsize=13, fontweight='bold')
fig.tight_layout()
plt.show()

---
## 2. Static Visualizations

In [None]:
# Pairs whose pun-boost ratio is at least this value are drawn as stars.
BOOST_THRESHOLD = 2.0

# Per-prompt boost flag aligned with meta['samples']; all False when the
# detailed-prediction file is absent. dtype=bool is forced so `~has_boost`
# stays a logical NOT even if the sample list is empty.
if PRED_FILE.exists():
    detailed_preds = load_detailed_predictions(PRED_FILE)
    boost_ratios = pun_boost_per_pair(detailed_preds)
    # Pairs missing from boost_ratios default to a ratio of 1.0 (no boost).
    has_boost = np.array(
        [boost_ratios.get(s['pair_id'], 1.0) >= BOOST_THRESHOLD for s in meta['samples']],
        dtype=bool,
    )
else:
    has_boost = np.zeros(len(meta['samples']), dtype=bool)

# Contrastive scatter at the Cohen's-d peak layer
X_peak = layer_data[peak_cd]
X_proj, _, var_ratios = contrastive_projection(X_peak, meta, n_components=2)

fig, ax = plt.subplots(figsize=(10, 8))

# Pair lines: join the two prompts of each complete pair. np.asarray makes the
# `==` comparison element-wise even if pair_ids is a plain sequence.
pair_ids_arr = np.asarray(pair_ids)
for pid in np.unique(pair_ids_arr):
    mask = pair_ids_arr == pid
    if mask.sum() == 2:
        pts = X_proj[mask]
        ax.plot(pts[:, 0], pts[:, 1], color='#888', alpha=0.4, lw=0.8, zorder=1)

# Points with boost markers; the label text follows BOOST_THRESHOLD
# (e.g. "2x+ boost" for 2.0).
boost_label = f'{BOOST_THRESHOLD:g}x+ boost'
groups = [
    (is_straight & ~has_boost, '#4A90D9', 'o', 40, 'Straight'),
    (is_straight & has_boost, '#4A90D9', '*', 120, f'Straight, {boost_label}'),
    (is_funny & ~has_boost, '#E85D75', 'o', 40, 'Funny'),
    (is_funny & has_boost, '#E85D75', '*', 120, f'Funny, {boost_label}'),
]
for mask, color, marker, size, label in groups:
    if mask.any():  # skip empty groups so they don't add legend entries
        ax.scatter(X_proj[mask, 0], X_proj[mask, 1], c=color, marker=marker,
                   s=size, alpha=0.7, label=label, edgecolors='white', lw=0.5, zorder=2)

ax.set_xlabel(f'Contrastive direction ({var_ratios[0]:.1%} var)', fontsize=11)
ax.set_ylabel(f'Residual PC1 ({var_ratios[1]:.1%} var)', fontsize=11)
ax.set_title(f'Longer Context: Contrastive Projection at Layer {peak_cd}', fontsize=13, fontweight='bold')
ax.legend(fontsize=9, loc='best')
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
fig.tight_layout()
plt.show()

In [None]:
# 1D projection histogram
direction = contrastive_direction(X_peak, meta)
projections = X_peak @ direction

fig, ax = plt.subplots(figsize=(10, 4))
ax.hist(projections[is_straight], bins=15, alpha=0.6, color='#4A90D9', label='Straight ctx', edgecolor='white')
ax.hist(projections[is_funny], bins=15, alpha=0.6, color='#E85D75', label='Funny ctx', edgecolor='white')
ax.set_xlabel('Projection onto contrastive direction', fontsize=11)
ax.set_ylabel('Count', fontsize=11)
ax.set_title(f'Longer Context: 1D Projections at Layer {peak_cd} (Cohen\'s d = {layer_results["cohens_d"][peak_cd_idx]:.2f})',
             fontsize=13, fontweight='bold')
ax.legend(fontsize=10)
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
fig.tight_layout()
plt.show()

---
## 3. Interactive 3D Layer Explorer

- **Drag** to rotate
- **Scroll/pinch** to zoom
- **Shift-drag** to pan
- **Layer slider** to move through all layers

In [None]:
# Pass the detailed-prediction file path (as a string) only when it exists;
# None is passed otherwise.
pred_path = None
if PRED_FILE.exists():
    pred_path = str(PRED_FILE)
html = make_layer_viz(
    META_FILE,
    pred_file=pred_path,
    width=900,
    height=600,
)
HTML(html)

---
## 4. Holdout Analysis

In [None]:
# Per-layer Cohen's d computed on the full data vs. the cross-validated estimate.
holdout = holdout_analysis(layer_data, meta, n_splits=2, seed=42)
layers_x = holdout['layer_indices']

# (result key, plot styling) for each curve.
curves = [
    ('cohens_d_full', dict(color='#2EAD6B', label='Full data')),
    ('cohens_d_cv', dict(color='#E85D75', ls='--', label='Cross-validated')),
]

fig, ax = plt.subplots(figsize=(12, 4))
for key, line_style in curves:
    ax.plot(layers_x, holdout[key], lw=2, **line_style)

ax.set_xlabel('Layer', fontsize=11)
ax.set_ylabel("Cohen's d", fontsize=11)
ax.set_title("Longer Context: Full Data vs. Cross-Validated Cohen's d", fontsize=13, fontweight='bold')
ax.legend(fontsize=10)
for side in ('top', 'right'):
    ax.spines[side].set_visible(False)
fig.tight_layout()
plt.show()