Can a model learn to do multi-step reasoning in a single forward pass, without chain-of-thought, by distilling from a more capable model's final answers?
This repo builds a "model organism" for studying chain-of-thought distillation. The core idea:
- Take hard math problems (GSM8K)
- Have a powerful model (Claude Opus / Sonnet) solve them with explicit chain-of-thought
- Create two fine-tuning datasets:
  - Normal: `prompt → CoT + answer` (the model sees the reasoning)
  - Distilled: `prompt → answer only` (the reasoning is stripped; the model must internalize it)
- Fine-tune the same base model on each dataset
- Diff the models — compare weights, activations, and behavior to understand how the distilled model represents compressed reasoning
If distillation works, the distilled model must be doing something interesting internally — it's performing multi-step computation in a single forward pass that the normal model spreads across tokens. Model diffing lets us peek at what that "something" is.
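The normal/distilled split can be sketched in a few lines. This is illustrative, not the repo's actual API (the real logic lives in `scripts/03_build_datasets.py`); it assumes the generated completions wrap the CoT in `<reasoning>` tags, as in the example records later in this README:

```python
import re

# A raw generated record: prompt plus the teacher model's full CoT completion.
record = {
    "prompt": "Janet's ducks lay 16 eggs per day...",
    "completion": "<reasoning>\nStep 1: ...\n</reasoning>\nThe answer is 9.",
}

def to_normal(record):
    """Normal dataset: keep the chain-of-thought verbatim."""
    return record

def to_distilled(record):
    """Distilled dataset: strip the <reasoning> block, keep only the answer."""
    answer = re.sub(r"<reasoning>.*?</reasoning>\s*", "",
                    record["completion"], flags=re.DOTALL)
    return {"prompt": record["prompt"], "completion": answer}

print(to_distilled(record)["completion"])  # -> The answer is 9.
```

Both datasets come from identical prompts and identical underlying solutions; the only difference is whether the reasoning tokens appear in the training target.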
```
├── configs/
│   └── default.yaml            # All hyperparameters and settings
├── prompts/
│   └── generation_prompts.py   # Carefully crafted prompts for CoT generation
├── scripts/
│   ├── 01_fetch_gsm8k.py       # Download and prep GSM8K
│   ├── 02_generate_cot.py      # Generate CoT solutions via Claude API
│   ├── 03_build_datasets.py    # Build normal + distilled datasets
│   ├── 04_finetune.py          # Fine-tune models (HF Trainer)
│   ├── 05_evaluate.py          # Evaluate both models on held-out set
│   └── 06_model_diff.py        # Weight/activation diffing & analysis
├── analysis/
│   ├── activation_probing.py   # Linear probes on intermediate activations
│   ├── logit_lens.py           # Logit lens across layers
│   └── visualization.py        # Plotting utilities
├── data/                       # Generated datasets go here
└── results/                    # Evaluation results and figures
```
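A hypothetical sketch of what `configs/default.yaml` might hold — the keys below simply mirror the CLI flags used in the quickstart, and the real file is the source of truth:

```yaml
# Illustrative only — see configs/default.yaml for the actual settings.
generation:
  model: claude-sonnet-4-20250514
  num_samples: 1000
finetune:
  base_model: meta-llama/Llama-3.2-1B
  datasets: [normal, distilled]
```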
```bash
# 1. Install dependencies
pip install -r requirements.txt

# 2. Set your API key
export ANTHROPIC_API_KEY=sk-ant-...

# 3. Fetch GSM8K
python scripts/01_fetch_gsm8k.py

# 4. Generate CoT solutions from Claude
python scripts/02_generate_cot.py --model claude-sonnet-4-20250514 --num-samples 1000

# 5. Build both datasets
python scripts/03_build_datasets.py

# 6. Fine-tune (requires GPU)
python scripts/04_finetune.py --dataset normal --base-model meta-llama/Llama-3.2-1B
python scripts/04_finetune.py --dataset distilled --base-model meta-llama/Llama-3.2-1B

# 7. Evaluate
python scripts/05_evaluate.py

# 8. Model diff
python scripts/06_model_diff.py
```

Questions this setup lets us ask:

- Does the distilled model actually learn to solve problems, or just memorize answer patterns?
- Where in the network does "compressed reasoning" live? (Early layers? MLPs? Attention?)
- Do the weight diffs between normal and distilled concentrate in specific layers?
- Can linear probes recover intermediate reasoning steps from distilled model activations?
- Does the distilled model develop different attention patterns?
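To give a flavor of the weight-diffing step, here is a minimal per-layer diff sketch in plain NumPy. It is not the implementation in `scripts/06_model_diff.py` — just the core idea: compare the two fine-tuned checkpoints parameter by parameter and flag the layers where they diverge most. The toy state dicts below stand in for real model weights:

```python
import numpy as np

def layer_diff_norms(sd_a, sd_b):
    """Relative L2 norm of the weight delta for each named parameter.

    Large values flag layers where the two fine-tunes diverged most.
    """
    return {
        name: float(np.linalg.norm(sd_b[name] - w) / (np.linalg.norm(w) + 1e-8))
        for name, w in sd_a.items()
    }

# Toy "state dicts" standing in for the normal and distilled checkpoints.
rng = np.random.default_rng(0)
base = {f"layers.{i}.mlp.weight": rng.standard_normal((4, 4)) for i in range(3)}
normal = {k: v + 0.01 * rng.standard_normal(v.shape) for k, v in base.items()}
distilled = {k: v + 0.10 * rng.standard_normal(v.shape) for k, v in base.items()}

diffs = layer_diff_norms(normal, distilled)
top_layer = max(diffs, key=diffs.get)  # layer with the biggest relative change
```

If the diffs concentrate in particular layers, those are the first places to point the activation probes and logit lens.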
Both datasets use the same `prompt`/`completion` record format. Normal (reasoning included):

```json
{
  "prompt": "Janet's ducks lay 16 eggs per day...",
  "completion": "<reasoning>\nStep 1: Janet's ducks lay 16 eggs per day.\nStep 2: She eats 3 for breakfast...\n</reasoning>\nThe answer is 9."
}
```

Distilled (reasoning stripped):

```json
{
  "prompt": "Janet's ducks lay 16 eggs per day...",
  "completion": "The answer is 9."
}
```

MIT