# KL Amplitude Report (per chunk)

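For each chunk, compute the KL curve over a range of steering amplitudes (latent steering with the per-chunk vector), fit a quadratic to it, and take the magnitude of the quadratic coefficient as the chunk's KL amplitude. The amplitudes are then tabulated alongside the counterfactual metrics: `counterfactual_importance_kl`, `counterfactual_accuracies`, `different_trajectories_fraction`, `overdeterminedness`.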

In [1]:
import os, json, math, gc, re
from pathlib import Path
import numpy as np
import torch
import matplotlib.pyplot as plt

# Resolve repo root (match other notebooks): the root is taken to be the parent of the
# current working directory. The previous `if Path.cwd().exists()` guard was always true,
# so its fallback branch could never run; the expression is reduced to what actually executed.
repo_root = Path.cwd().resolve().parent
import sys
# Guarded so that re-running this cell does not keep appending duplicate sys.path entries.
if str(repo_root) not in sys.path:
    sys.path.append(str(repo_root))

# Config
model_name = 'deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B'
model_tag = model_name.replace('/', '-')  # file-name-safe tag used in the data paths below
betas = np.linspace(-10, 10, 21)  # steering amplitudes (× RMS)
max_examples = 2  # set None for all; kept small to limit runtime

# Paths
annotated_path = repo_root / 'generated_data' / f'generated_data_annotated_{model_tag}.json'
anchors_path = repo_root / 'generated_data' / f'steering_anchors_{model_tag}.json'

# Load model/tokenizer
from utils import load_model_and_vectors, split_solution_into_chunks
# Prefer CUDA, then Apple MPS, then CPU.
device = 'cuda' if torch.cuda.is_available() else ('mps' if torch.backends.mps.is_available() else 'cpu')
model, tokenizer, _ = load_model_and_vectors(model_name=model_name, compute_features=False, device=device)
# NOTE(review): `model` is a project wrapper; `.model.eval()` presumably switches the
# wrapped module to inference mode — confirm against utils.load_model_and_vectors.
model.model.eval()

# Load datasets
# Explicit UTF-8: the annotated text contains non-ASCII math symbols, and the platform
# default encoding is not guaranteed to be UTF-8.
with open(annotated_path, 'r', encoding='utf-8') as f:
    annotated = json.load(f)
with open(anchors_path, 'r', encoding='utf-8') as f:
    anchors_payload = json.load(f)
examples_anchors = anchors_payload.get('examples', [])
# Cell output: number of annotated examples and number of anchor entries.
len(annotated), len(examples_anchors)

  from .autonotebook import tqdm as notebook_tqdm
Sliding Window Attention is enabled but not implemented for `sdpa`; unexpected results may be encountered.


No mean vectors found for deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B. You can save to generated_data/mean_vectors_deepseek-r1-distill-qwen-1.5b.pt.


(2, 2)

In [2]:
from pathlib import Path

# Pre-flight check: verify that the input files exist and that the project helpers and
# torch can be imported, then summarise the outcome in SANITY_OK.
print('--- Sanity check ---')
# Only path variables that are actually defined in this kernel are checked, so the cell
# does not raise NameError when one of them has not been set.
names = [
    (nm, str(globals()[nm]))
    for nm in ('annotated_path', 'anchors_path', 'vectors_path')
    if nm in globals()
]
all_ok = True
for nm, p in names:
    ok = Path(p).exists()
    print(f'{nm}:', 'OK' if ok else 'MISSING', p)
    all_ok = all_ok and ok
try:
    from utils import forward_with_logits, kl_from_logits
    print('utils import: OK')
except Exception as e:
    # Any exception raised while importing the helpers is reported and counted as a failure.
    print('utils import failed:', e)
    all_ok = False
try:
    import torch
    dev = 'cuda' if torch.cuda.is_available() else ('mps' if torch.backends.mps.is_available() else 'cpu')
    print('device:', dev)
except Exception as e:
    print('torch not available:', e)
    # Previously this failure left SANITY_OK true; a missing torch must fail the check,
    # consistent with the utils import failure above.
    all_ok = False
SANITY_OK = all_ok
print('SANITY_OK =', SANITY_OK)


--- Sanity check ---
annotated_path: OK /home/cutterdawes/SteeringThoughtAnchors/generated_data/generated_data_annotated_deepseek-ai-DeepSeek-R1-Distill-Qwen-1.5B.json
anchors_path: OK /home/cutterdawes/SteeringThoughtAnchors/generated_data/steering_anchors_deepseek-ai-DeepSeek-R1-Distill-Qwen-1.5B.json
utils import: OK
device: cuda
SANITY_OK = True


In [3]:
from utils import compute_kl_curve_for_chunk
import numpy as np

def kl_amplitude(ys: list, xs: np.ndarray) -> float:
    """Return the magnitude of the quadratic coefficient of a degree-2 fit of ``ys`` on ``xs``.

    A second-order polynomial ``a*x**2 + b*x + c`` is least-squares fitted to the points
    ``(xs[i], ys[i])`` and ``abs(a)`` is returned.

    Args:
        ys: Curve values (in this notebook, the KL curve of one chunk). ``None`` entries
            are treated as NaN.
        xs: Matching x-coordinates (in this notebook, the steering amplitudes ``betas``).
            If the two inputs differ in length, the longer one is truncated.

    Returns:
        ``abs(a)`` as a Python float, or NaN when the inputs are not one-dimensional,
        when fewer than three distinct finite x values remain, or when the fit fails.
    """
    Y = np.asarray(ys, dtype=float)
    X = np.asarray(xs, dtype=float)
    if Y.ndim != 1 or X.ndim != 1:
        return float('nan')
    n = min(len(Y), len(X))
    X, Y = X[:n], Y[:n]
    # Drop non-finite samples so a single NaN/inf point does not void the whole fit.
    finite = np.isfinite(X) & np.isfinite(Y)
    X, Y = X[finite], Y[finite]
    # A quadratic is only identifiable from at least three distinct x positions; with
    # fewer, polyfit returns an arbitrary coefficient (and emits a RankWarning).
    if np.unique(X).size < 3:
        return float('nan')
    try:
        a = float(np.polyfit(X, Y, 2)[0])
    except (np.linalg.LinAlgError, ValueError, TypeError):
        return float('nan')
    return abs(a)


In [14]:
import pandas as pd
def preview(s: str, n: int = 80) -> str:
    """Return a single-line preview of ``s`` limited to ``n`` characters.

    Newlines are replaced by spaces and surrounding whitespace is stripped. If the
    result is longer than ``n`` characters it is cut to the first ``n`` and ``'...'``
    is appended. A falsy ``s`` (``None`` or ``''``) yields an empty string.
    """
    text = s if s else ''
    flattened = text.replace('\n', ' ').strip()
    if len(flattened) <= n:
        return flattened
    return flattened[:n] + '...'
# Column order of the report. Passing it to the DataFrame constructor keeps these columns
# present even when no rows are produced, so the sort below cannot fail with a KeyError.
report_columns = [
    'example_index',
    'chunk_index',
    'chunk_text',
    'kl_amplitude',
    'counterfactual_importance_kl',
    'counterfactual_accuracies',
    'different_trajectories_fraction',
    'overdeterminedness',
]


def _metric_at(values, idx):
    """Return ``float(values[idx])``, or NaN when the entry is absent, null, or non-numeric."""
    if values is None or idx >= len(values):
        return float('nan')
    try:
        return float(values[idx])
    except (TypeError, ValueError):
        # e.g. a JSON null stored for this chunk
        return float('nan')


rows = []
ex_list = annotated
if max_examples is not None:
    ex_list = ex_list[:int(max_examples)]
for ex_i, ex in enumerate(ex_list):
    # Annotated examples and anchor entries are paired by position; stop when anchors run out.
    if ex_i >= len(examples_anchors):
        break
    anchors_ex = examples_anchors[ex_i] or {}
    # Fall back to the last hidden layer when the anchor entry does not name a layer.
    layer_idx = anchors_ex.get('layer', model.config.num_hidden_layers - 1)
    # Split chunks
    cot_text = ex.get('cot') or ''
    try:
        chunks = split_solution_into_chunks(cot_text)
    except Exception as e:
        # Best-effort fallback (sentence / blank-line split), but no longer silent.
        print(f'[example {ex_i}] split_solution_into_chunks failed ({e!r}); using regex fallback')
        chunks = [p.strip() for p in re.split(r'(?<=[\.\!\?])\s+|\n\n+', cot_text) if p.strip()]
    if not chunks:
        continue
    # Metric arrays (indexed per chunk; may be missing or shorter than `chunks`)
    m_kl = ex.get('counterfactual_importance_kl')
    m_acc = ex.get('counterfactual_accuracies')
    m_diff = ex.get('different_trajectories_fraction')
    m_over = ex.get('overdeterminedness')
    for ci, chunk_text in enumerate(chunks):
        # NOTE(review): presumably returns one KL value per entry of `betas` — confirm in utils.
        ys = compute_kl_curve_for_chunk(model, tokenizer, ex, anchors_ex, layer_idx=int(layer_idx), betas=betas, device=device, chunk_index=int(ci))
        # Explicit emptiness test: `if ys` would be ambiguous should the helper return an array.
        has_curve = ys is not None and len(ys) > 0
        amp = kl_amplitude(ys, betas) if has_curve else float('nan')
        rows.append({
            'example_index': ex_i,
            'chunk_index': ci,
            'chunk_text': chunk_text,
            'kl_amplitude': amp,
            'counterfactual_importance_kl': _metric_at(m_kl, ci),
            'counterfactual_accuracies': _metric_at(m_acc, ci),
            'different_trajectories_fraction': _metric_at(m_diff, ci),
            'overdeterminedness': _metric_at(m_over, ci),
        })
df = pd.DataFrame(rows, columns=report_columns)
df_sorted = df.sort_values(['example_index','chunk_index']).reset_index(drop=True)
df_sorted.head(10)


Unnamed: 0,example_index,chunk_index,chunk_text,kl_amplitude,counterfactual_importance_kl,counterfactual_accuracies,different_trajectories_fraction,overdeterminedness
0,0,0,"Okay, so I have this problem to solve: evaluat...",7.662016e-06,0.764973,0.64,0.02,0.4
1,0,1,It's written as √[3]{12} × √[3]{20} × √[3]{15}...,1.379788e-05,0.044611,0.42,0.1,0.68
2,0,2,"Hmm, cube roots can sometimes be tricky, but I...",1.249081e-05,0.487986,0.44,0.02,0.42
3,0,3,Let me try to recall the properties of radicals.,2.899993e-05,0.011865,0.32,0.08,0.62
4,0,4,I think the rule is that √[n]{a} × √[n]{b} = √...,1.998099e-07,0.019019,0.48,0.16,0.68
5,0,5,"So, in this case, since all of them are cube r...",1.903506e-05,0.173069,0.28,0.1,0.18
6,0,6,That should simplify things a lot.,0.0001000826,0.0,0.32,0.0,0.88
7,0,7,Let me write that down:,0.0002761142,0.385662,0.3,0.04,0.7
8,0,8,√[3]{12} × √[3]{20} × √[3]{15} × √[3]{60} = √[...,2.514283e-05,0.0,0.32,0.0,0.98
9,0,9,"Alright, so now I just need to compute the pro...",1.677829e-05,0.010012,0.16,1.0,0.98


In [22]:
from matplotlib import cm
from matplotlib.colors import Normalize
from IPython.display import display, HTML

import html

import matplotlib

# Choose metric to color by
metric = 'kl_amplitude'  # options: 'kl_amplitude', 'counterfactual_importance_kl', 'counterfactual_accuracies', etc.

# Rows to render. The offset 24 is carried over unchanged from the original cell.
# NOTE(review): presumably it skips the chunks of the first example — confirm, and
# consider selecting by `example_index` instead of a fixed row offset.
start_row = 24
df_view = df_sorted.iloc[start_row:]

# Normalize the selected metric over the displayed rows for the colormap
norm = Normalize(vmin=df_view[metric].min(), vmax=df_view[metric].max())
# `cm.get_cmap` is deprecated (and removed in newer Matplotlib); use the colormap registry.
cmap = matplotlib.colormaps['coolwarm']

# Generate HTML with color-coded background (one <div> per chunk)
blocks = []
for _, row in df_view.iterrows():
    color = cmap(norm(row[metric]))
    bg_color = f"background-color: rgba({int(color[0]*255)}, {int(color[1]*255)}, {int(color[2]*255)}, {color[3]});"
    # Escape the text: chunks contain math such as `a < b` or `&`, which would otherwise
    # be parsed as markup and corrupt the rendering.
    safe_text = html.escape(str(row['chunk_text']))
    blocks.append(f"<div style='{bg_color} padding: 5px; margin: 2px;'>{safe_text}</div>")
html_content = "".join(blocks)

# Display the HTML
display(HTML(html_content))

  cmap = cm.get_cmap('coolwarm')
