# CoT Latent Dynamics
This notebook explores how a model's chain-of-thought (CoT) moves through latent space from one reasoning chunk to the next, using simple trajectory metrics and visualizations.

In [None]:
import sys
from pathlib import Path

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

from utils import (
    split_solution_into_chunks,
    get_chunk_embeddings,
    compute_trajectory_metrics,
    reduce_embeddings,
    plot_trajectory,
)


In [None]:
# Load a small causal LM and greedily decode a short chain-of-thought continuation.
# Decoding is greedy (do_sample=False), so the generated `text` is deterministic for a
# given model/tokenizer version.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model_name = 'distilgpt2'
max_new_tokens = 40  # length budget for the generated reasoning
model = AutoModelForCausalLM.from_pretrained(model_name, output_hidden_states=True).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Double quotes: the prompt itself contains an apostrophe ("Let's").
prompt = "Question: what is 2+3? Let's think step by step."
encoded = tokenizer(prompt, return_tensors='pt').to(device)
input_ids = encoded.input_ids
with torch.no_grad():
    output_ids = model.generate(
        input_ids,
        attention_mask=encoded.attention_mask,
        max_new_tokens=max_new_tokens,
        do_sample=False,
        # distilgpt2's tokenizer defines no pad token; reuse EOS so generate() does not warn.
        pad_token_id=tokenizer.eos_token_id,
    )
# output_ids[0] includes the prompt tokens, so `text` is prompt + continuation.
text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
print(text)


In [None]:
# Split the generated text into chunks, embed them, compute trajectory metrics, and
# plot a low-dimensional projection of the chunk embeddings.
# Create the output directory up front: on a fresh kernel it does not exist yet, and
# it is not known here whether plot_trajectory creates it before saving.
figures_dir = Path('generated_data/figures')
figures_dir.mkdir(parents=True, exist_ok=True)
chunks, embeddings = get_chunk_embeddings(model, tokenizer, text, device=device)
print(chunks)
metrics = compute_trajectory_metrics(embeddings)
projection = reduce_embeddings(embeddings)
fig = plot_trajectory(projection, chunks, save_path='generated_data/figures/cot_latent_trajectory.png')
fig


In [None]:
# Persist the successive-chunk distances from `metrics['distances']` as a plain-text
# summary, one "i->i+1: distance" line per consecutive chunk pair.
summary_path = Path('generated_data/figures/cot_latent_summary.txt')
summary_path.parent.mkdir(parents=True, exist_ok=True)
with open(summary_path, 'w', encoding='utf-8') as f:
    f.write('Distances between successive chunks:\n')
    for i, d in enumerate(metrics['distances']):
        # float() handles tensor/numpy scalars as well as plain numbers.
        f.write(f'{i}->{i+1}: {float(d):.4f}\n')
summary_path


## Comparison with Thought Anchors
This section compares the trajectory metrics with the counterfactual importance scores from the Thought Anchors paper. It is optional: if the reference code (expected under `refs/thought-anchors`) is unavailable, the cell prints a message and skips the comparison.

In [None]:
# Best-effort comparison with the Thought Anchors reference code, expected under
# refs/thought-anchors. A missing module is reported as "unavailable"; a failure while
# running the reference code is reported separately so real errors are not disguised
# as a missing dependency.
thought_anchors_dir = str(Path('refs/thought-anchors').resolve())
if thought_anchors_dir not in sys.path:  # keep re-runs from stacking duplicate entries
    sys.path.append(thought_anchors_dir)

try:
    from step_attribution import compute_step_importance_matrix
except ImportError as e:
    print('Thought anchors comparison unavailable:', e)
else:
    try:
        # NOTE(review): assumes compute_step_importance_matrix accepts the chunk list and
        # returns (importance, chunk_texts) with an array-like `importance` — confirm
        # against refs/thought-anchors.
        importance, chunk_texts = compute_step_importance_matrix(chunks)
        print('Counterfactual importance matrix shape:', importance.shape)
    except Exception as e:
        print('Thought anchors comparison failed:', e)
