omroystrath/cot-distillation-organism

CoT Distillation Model Organism

Can a model learn to do multi-step reasoning in a single forward pass, without chain-of-thought, by distilling from a more capable model's final answers?

This repo builds a "model organism" for studying chain-of-thought distillation. The core idea:

  1. Take hard math problems (GSM8K)
  2. Have a powerful model (Claude Opus / Sonnet) solve them with explicit chain-of-thought
  3. Create two fine-tuning datasets:
    • Normal: prompt → CoT + answer (the model sees reasoning)
    • Distilled: prompt → answer only (reasoning is stripped; the model must internalize it)
  4. Fine-tune the same base model on each dataset
  5. Diff the models — compare weights, activations, and behavior to understand how the distilled model represents compressed reasoning

Why This Matters

If distillation works, the distilled model must be doing something interesting internally — it's performing multi-step computation in a single forward pass that the normal model spreads across tokens. Model diffing lets us peek at what that "something" is.

Repo Structure

├── configs/
│   └── default.yaml           # All hyperparameters and settings
├── prompts/
│   └── generation_prompts.py  # Carefully crafted prompts for CoT generation
├── scripts/
│   ├── 01_fetch_gsm8k.py      # Download and prep GSM8K
│   ├── 02_generate_cot.py     # Generate CoT solutions via Claude API
│   ├── 03_build_datasets.py   # Build normal + distilled datasets
│   ├── 04_finetune.py         # Fine-tune models (HF Trainer)
│   ├── 05_evaluate.py         # Evaluate both models on held-out set
│   └── 06_model_diff.py       # Weight/activation diffing & analysis
├── analysis/
│   ├── activation_probing.py  # Linear probes on intermediate activations
│   ├── logit_lens.py          # Logit lens across layers
│   └── visualization.py       # Plotting utilities
├── data/                      # Generated datasets go here
└── results/                   # Evaluation results and figures
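The logit lens in analysis/logit_lens.py decodes each layer's intermediate state through the unembedding matrix to see which token the model "currently predicts" at that depth. A toy sketch of the idea, using plain Python lists in place of real model tensors (the function name and data here are illustrative, not the repo's actual API):

```python
def logit_lens(hidden_states, unembed, vocab):
    """For each layer's hidden state, project through the unembedding
    matrix and return the most likely token at that depth."""
    tokens = []
    for h in hidden_states:  # one vector per layer
        logits = [sum(w * x for w, x in zip(row, h)) for row in unembed]
        tokens.append(vocab[max(range(len(logits)), key=logits.__getitem__)])
    return tokens

# Toy example: 2-dim hidden states, 3-token vocabulary.
vocab = ["7", "8", "9"]
unembed = [[1.0, 0.0],   # direction for token "7"
           [0.0, 1.0],   # direction for token "8"
           [0.7, 0.7]]   # direction for token "9"
layers = [[1.0, 0.0], [0.0, 1.0], [1.5, 1.5]]  # fake residual stream per layer
print(logit_lens(layers, unembed, vocab))  # -> ['7', '8', '9']
```

On a real model you would collect `hidden_states` from a forward pass (e.g. `output_hidden_states=True` in transformers) and use the model's own unembedding weights.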

Quick Start

# 1. Install dependencies
pip install -r requirements.txt

# 2. Set your API key
export ANTHROPIC_API_KEY=sk-ant-...

# 3. Fetch GSM8K
python scripts/01_fetch_gsm8k.py

# 4. Generate CoT solutions from Claude
python scripts/02_generate_cot.py --model claude-sonnet-4-20250514 --num-samples 1000

# 5. Build both datasets
python scripts/03_build_datasets.py

# 6. Fine-tune (requires GPU)
python scripts/04_finetune.py --dataset normal --base-model meta-llama/Llama-3.2-1B
python scripts/04_finetune.py --dataset distilled --base-model meta-llama/Llama-3.2-1B

# 7. Evaluate
python scripts/05_evaluate.py

# 8. Model diff
python scripts/06_model_diff.py
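The weight-diffing step amounts to comparing matching parameters between the two fine-tuned checkpoints, layer by layer. A minimal sketch of that comparison — the real script would operate on Hugging Face `state_dict()` tensors, but here parameters are plain float lists and the names are made up for illustration:

```python
import math

def per_layer_weight_diff(state_a, state_b):
    """Relative L2 distance between matching parameters, keyed by name."""
    diffs = {}
    for name, wa in state_a.items():
        wb = state_b[name]
        num = math.sqrt(sum((a - b) ** 2 for a, b in zip(wa, wb)))
        den = math.sqrt(sum(a ** 2 for a in wa)) or 1.0
        diffs[name] = num / den
    return diffs

# Toy "models": the MLP weights moved during fine-tuning, attention barely did.
normal    = {"layers.0.attn": [1.0, 2.0], "layers.0.mlp": [1.0, 2.0]}
distilled = {"layers.0.attn": [1.0, 2.0], "layers.0.mlp": [2.0, 0.0]}
diffs = per_layer_weight_diff(normal, distilled)
print(max(diffs, key=diffs.get))  # -> layers.0.mlp
```

Sorting these relative distances answers the "do diffs concentrate in specific layers?" question directly.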

Key Research Questions

  • Does the distilled model actually learn to solve problems, or just memorize answer patterns?
  • Where in the network does "compressed reasoning" live? (Early layers? MLPs? Attention?)
  • Do the weight diffs between normal and distilled concentrate in specific layers?
  • Can linear probes recover intermediate reasoning steps from distilled model activations?
  • Does the distilled model develop different attention patterns?
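The probing question above can be tested by fitting a linear readout from activations to an intermediate quantity and checking how well it recovers it. A self-contained sketch of that idea, with synthetic "activations" in which the target is linearly decodable by construction (real probes in analysis/activation_probing.py would run on actual model activations):

```python
def fit_linear_probe(acts, targets, lr=0.1, steps=3000):
    """Least-squares linear probe trained by plain gradient descent."""
    d, n = len(acts[0]), len(acts)
    w, b = [0.0] * d, 0.0
    for _ in range(steps):
        gw, gb = [0.0] * d, 0.0
        for x, y in zip(acts, targets):
            err = sum(wi * xi for wi, xi in zip(w, x)) + b - y
            for i in range(d):
                gw[i] += err * x[i] / n
            gb += err / n
        w = [wi - lr * gi for wi, gi in zip(w, gw)]
        b -= lr * gb
    return w, b

# Toy data: an "intermediate step" planted as a linear function of the activations.
acts = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, 1.0]]
targets = [2 * x[0] + 3 * x[1] for x in acts]
w, b = fit_linear_probe(acts, targets)
```

If a probe like this recovers intermediate subtotals from the distilled model's activations with high accuracy, that is evidence the reasoning steps are represented internally even though they are never emitted.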

Dataset Format

Normal (CoT preserved)

{
  "prompt": "Janet's ducks lay 16 eggs per day...",
  "completion": "<reasoning>\nStep 1: Janet's ducks lay 16 eggs per day.\nStep 2: She eats 3 for breakfast...\n</reasoning>\nThe answer is 9."
}

Distilled (CoT stripped)

{
  "prompt": "Janet's ducks lay 16 eggs per day...",
  "completion": "The answer is 9."
}
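The transformation from the normal record to the distilled one is just stripping the `<reasoning>` block from the completion. A sketch of that step — `to_distilled` is an illustrative name, not necessarily what 03_build_datasets.py calls it:

```python
import re

def to_distilled(record):
    """Drop the <reasoning>...</reasoning> block, keeping only the answer."""
    completion = re.sub(r"<reasoning>.*?</reasoning>\s*", "",
                        record["completion"], flags=re.DOTALL)
    return {"prompt": record["prompt"], "completion": completion}

normal = {
    "prompt": "Janet's ducks lay 16 eggs per day...",
    "completion": "<reasoning>\nStep 1: ...\n</reasoning>\nThe answer is 9.",
}
print(to_distilled(normal)["completion"])  # -> The answer is 9.
```

Both datasets therefore share identical prompts; only the completions differ, which keeps the comparison between the two fine-tunes clean.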

License

MIT
