A minimal transformer model for learning basic arithmetic operations with chain-of-thought reasoning, starting with single-digit addition and scaling to multi-digit problems.
Currently the model learns to mimic human-understandable reasoning, but the goal is for it to learn to reason without explicit instructions and without RL, so that we can learn something from how the model "naturally" comes to reason, and whether we can understand it.
The current CoT looks like this:
```
➤ Enter expression: 99+21=
💭 Generating completion for: 99+21=
✨ Model output: 99+21=<think>9+1=10 9+2+1=11 0+1=1 1+0=1</think>110<end>
🤔 Chain of thought:
99+21=
🧮 Think:
9+1=10 9+2+1=11 0+1=1 1+0=1
110<end>
✅ Answer: 99+21=110
```
Quick start:

```bash
# Setup development environment
./setup.sh

# Generate training data
python scripts/generate_data.py

# Train the model
python scripts/train.py

# Test the model
python scripts/evaluate.py

# Try the model interactively
python scripts/interactive.py --checkpoint checkpoints/model.safetensors
```
Project structure:

```
math-llm/
├── scripts/
│   ├── generate_data.py   # Data generation
│   ├── train.py           # Training script
│   ├── evaluate.py        # Evaluation script
│   └── interactive.py     # Interactive inference
├── src/
│   ├── model.py           # Model architecture
│   ├── tokenizer.py       # Custom tokenizer
│   └── data.py            # Data loading utilities
├── data/                  # Generated datasets
├── checkpoints/           # Model checkpoints
└── logs/                  # Training logs
```
Generate a larger dataset with custom settings:

```bash
python scripts/generate_data.py --num-examples 100000 --max-digits 2
```
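For intuition, here is a minimal sketch of how a reasoning-format training example could be constructed: add digit by digit from the least significant position, carrying into the next step, mirroring the training format shown further below. The function is illustrative, not the actual code in `generate_data.py`:

```python
def make_cot_example(a: int, b: int) -> str:
    """Build a reasoning-format string such as
    '658+189=<think>8+9=17 5+8+1=14 6+1+1=8</think>847<end>'."""
    xs, ys = str(a)[::-1], str(b)[::-1]  # digits, least significant first
    steps, digits, carry = [], [], 0
    for i in range(max(len(xs), len(ys))):
        da = int(xs[i]) if i < len(xs) else 0
        db = int(ys[i]) if i < len(ys) else 0
        terms = [da, db] + ([carry] if carry else [])
        total = sum(terms)
        steps.append("+".join(map(str, terms)) + f"={total}")
        digits.append(str(total % 10))  # answer digit at this position
        carry = total // 10             # carry into the next position
    if carry:
        digits.append(str(carry))
    answer = "".join(reversed(digits))
    return f"{a}+{b}=<think>{' '.join(steps)}</think>{answer}<end>"

# Matches the with-reasoning example shown in the data format section below.
assert make_cot_example(658, 189) == "658+189=<think>8+9=17 5+8+1=14 6+1+1=8</think>847<end>"
```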
The training script supports comprehensive configuration with colored logging and W&B integration:
```bash
# Basic training with default settings
python scripts/train.py

# Custom configuration
python scripts/train.py \
    --model-size medium \
    --batch-size 64 \
    --learning-rate 2e-4 \
    --num-epochs 20 \
    --fp16 \
    --data-dir data \
    --output-dir checkpoints/experiment1

# Training without W&B logging
python scripts/train.py --no-wandb

# Resume from checkpoint
python scripts/train.py --output-dir checkpoints/experiment1

# Train with Gumbel-Softmax generation (no teacher forcing)
python scripts/train.py --use-gumbel --gumbel-temperature 0.5
```
Model Configuration:
- `--model-size`: Model size (`small`, `medium`, `large`) - default: `small`
- `--max-length`: Maximum sequence length - default: `128`

Training Hyperparameters:
- `--batch-size`: Training batch size - default: `32`
- `--eval-batch-size`: Evaluation batch size - default: `64`
- `--learning-rate`: Learning rate - default: `1e-4`
- `--weight-decay`: Weight decay - default: `0.01`
- `--num-epochs`: Number of training epochs - default: `10`
- `--warmup-steps`: Learning rate warmup steps - default: `500`

Data and I/O:
- `--data-dir`: Directory with train/val/test JSON files - default: `data`
- `--output-dir`: Checkpoint output directory - default: `checkpoints`
- `--num-workers`: Data loading workers - default: `4`

Logging and Checkpointing:
- `--save-steps`: Save checkpoint every N steps - default: `1000`
- `--eval-steps`: Evaluate every N steps - default: `1000`
- `--logging-steps`: Log metrics every N steps - default: `100`

System Options:
- `--fp16`: Enable mixed precision training
- `--no-wandb`: Disable Weights & Biases logging
- `--seed`: Random seed for reproducibility - default: `42`

Gumbel-Softmax Training:
- `--use-gumbel`: Use Gumbel-Softmax for differentiable sequence generation instead of teacher forcing (see the sketch below)
- `--gumbel-temperature`: Temperature for Gumbel-Softmax (lower = more discrete) - default: `1.0`
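As a rough sketch of the Gumbel-Softmax idea (not the project's actual training loop): PyTorch's built-in `gumbel_softmax` produces a relaxed, differentiable one-hot over the vocabulary, so a generated token can be fed back into the model as a soft mix of embeddings while gradients still flow. The sizes below are only for illustration:

```python
import torch
import torch.nn.functional as F

vocab_size, hidden = 16, 256                              # illustrative sizes
logits = torch.randn(2, vocab_size, requires_grad=True)   # next-token scores

# Relaxed, differentiable "sample": approximately one-hot over the vocabulary.
# Lower tau pushes it closer to a discrete one-hot (cf. --gumbel-temperature).
soft_one_hot = F.gumbel_softmax(logits, tau=0.5, hard=False)

# Feed the soft sample back in as a weighted mix of token embeddings,
# so the generation step stays differentiable end to end.
embedding = torch.nn.Embedding(vocab_size, hidden)
next_input = soft_one_hot @ embedding.weight  # (batch, hidden)
```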
Available model sizes:

| Size   | Parameters | Layers | Hidden Size | Heads | Feed-Forward |
|--------|------------|--------|-------------|-------|--------------|
| Small  | ~1M        | 4      | 256         | 4     | 512          |
| Medium | ~5M        | 6      | 512         | 8     | 1024         |
| Large  | ~10M       | 8      | 512         | 8     | 2048         |
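One plausible way to express the table in code, assuming a config object along these lines (the field names are hypothetical, not the actual API in `src/model.py`):

```python
from dataclasses import dataclass

@dataclass
class ModelConfig:
    num_layers: int
    hidden_size: int
    num_heads: int
    ffn_size: int

# Values taken from the table above; names are illustrative.
MODEL_SIZES = {
    "small":  ModelConfig(num_layers=4, hidden_size=256, num_heads=4, ffn_size=512),
    "medium": ModelConfig(num_layers=6, hidden_size=512, num_heads=8, ffn_size=1024),
    "large":  ModelConfig(num_layers=8, hidden_size=512, num_heads=8, ffn_size=2048),
}
```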
Training generates several output files in the checkpoint directory:
- `pytorch_model.bin`: Final trained model weights
- `training_config.json`: Complete training configuration
- `test_results.json`: Final evaluation metrics
- `logs/training.log`: Detailed training logs with timestamps
The evaluation script computes exact match accuracy and token-level accuracy on test data:
```bash
# Basic evaluation on test set (final model)
python scripts/evaluate.py --checkpoint checkpoints/model.safetensors

# Evaluate specific checkpoint
python scripts/evaluate.py --checkpoint checkpoints/checkpoint-1000/model.safetensors

# Evaluate specific data file
python scripts/evaluate.py \
    --checkpoint checkpoints/model.safetensors \
    --data-path data/custom_test.json

# Custom model size and batch size
python scripts/evaluate.py \
    --checkpoint checkpoints/model.safetensors \
    --model-size medium \
    --batch-size 128

# Save results to file
python scripts/evaluate.py \
    --checkpoint checkpoints/model.safetensors \
    --output-file results.json
```
Required:
- `--checkpoint`: Path to model checkpoint file

Model Configuration:
- `--model-size`: Model size (`small`, `medium`, `large`) - default: `small`

Data Options:
- `--data-path`: Specific evaluation data file (overrides `--data-dir`)
- `--data-dir`: Directory containing test.json - default: `data`

Evaluation Settings:
- `--batch-size`: Evaluation batch size - default: `64`
- `--max-length`: Maximum sequence length - default: `128`
- `--output-file`: Save results to JSON file

System:
- `--device`: Device to use (`cuda`, `cpu`, `auto`) - default: `auto`
The script reports the following (both accuracy metrics are sketched below):
- Exact Match Accuracy: Percentage of completely correct arithmetic answers
- Token-Level Accuracy: Per-token prediction accuracy during teacher forcing
- Number of Examples: Total examples evaluated
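In sketch form, the two accuracies differ only in granularity. The helpers below are illustrative, not the actual code in `evaluate.py` (in particular, the padding-mask handling is an assumption):

```python
from torch import Tensor

def exact_match_accuracy(predictions: list[str], references: list[str]) -> float:
    """Fraction of examples whose generated answer matches the reference exactly."""
    return sum(p == r for p, r in zip(predictions, references)) / len(references)

def token_accuracy(pred_ids: Tensor, target_ids: Tensor, pad_id: int) -> float:
    """Per-token accuracy under teacher forcing, ignoring padding positions."""
    mask = target_ids != pad_id
    return ((pred_ids == target_ids) & mask).sum().item() / mask.sum().item()
```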
Test your trained model interactively by providing arithmetic expressions and seeing the model's completions:
```bash
# Basic interactive session
python scripts/interactive.py --checkpoint checkpoints/model.safetensors

# Use specific model size and generation settings
python scripts/interactive.py \
    --checkpoint checkpoints/checkpoint-1000/model.safetensors \
    --model-size medium \
    --temperature 0.2 \
    --max-new-tokens 15
```
Required:
- `--checkpoint`: Path to model checkpoint file

Model Configuration:
- `--model-size`: Model size (`small`, `medium`, `large`) - default: `small`

Generation Settings:
- `--max-new-tokens`: Maximum tokens to generate - default: `20`
- `--temperature`: Sampling temperature (lower = more deterministic; see the sketch below) - default: `0.1`

System:
- `--device`: Device to use (`cuda`, `cpu`, `auto`) - default: `auto`
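For intuition, temperature divides the logits before the softmax, so lower values sharpen the distribution toward greedy decoding. A minimal sampling step might look like this (a sketch, not the script's actual code):

```python
import torch

def sample_next_token(logits: torch.Tensor, temperature: float = 0.1) -> int:
    """Sample one token id from a (vocab_size,) logits vector.
    Lower temperature sharpens the distribution toward argmax."""
    probs = torch.softmax(logits / max(temperature, 1e-6), dim=-1)
    return torch.multinomial(probs, num_samples=1).item()
```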
The interactive script accepts various input formats:
```
➤ Enter expression: 3+5=
✨ Model completion: '3+5=' → '3+5=8<end>'

➤ Enter expression: 12+34=
✨ Model completion: '12+34=' → '12+34=46<end>'

➤ Enter expression: 7+
✨ Model completion: '7+' → '7+3=10<end>'
```
Input Validation:
- Only accepts digits (0-9), plus sign (+), and equals sign (=); a sketch of this check follows below
- Provides helpful error messages for invalid input
- Type 'quit' or 'exit' to end the session, or use Ctrl+C/Ctrl+D
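A check along these lines can be a single regular expression; this is a sketch, not the script's actual validation code or error messages:

```python
import re

VALID_EXPR = re.compile(r"[0-9+=]+")  # digits, '+', and '=' only

def validate(expr: str) -> str | None:
    """Return an error message for invalid input, or None if the expression is usable."""
    if not VALID_EXPR.fullmatch(expr):
        return "Only digits 0-9, '+' and '=' are allowed."
    return None
```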
- Architecture: Small transformer decoder (1M-10M parameters)
- Vocabulary: 16 tokens (digits 0-9, +, =, plus special tokens such as `<think>`, `</think>`, `<end>`)
- Context Length: 128 tokens (sufficient for chain-of-thought reasoning)
- Task: Next-token prediction on arithmetic expressions with reasoning
- Formats (a tokenizer sketch follows this list):
  - Simple: `"3+5=8<end>"`
  - With reasoning: `"658+189=<think>8+9=17 5+8+1=14 6+1+1=8</think>847<end>"`
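A character-level tokenizer over such a vocabulary is only a few lines. The sketch below assumes `<pad>` as the fourth special token rounding out the 16 (an assumption; the actual `src/tokenizer.py` may differ in both the token set and the id assignment):

```python
# The <pad> token and the id assignment here are assumptions for illustration.
SPECIALS = ["<pad>", "<think>", "</think>", "<end>"]
VOCAB = SPECIALS + list("0123456789+=")  # 16 tokens in total
TOKEN_TO_ID = {tok: i for i, tok in enumerate(VOCAB)}

def encode(text: str) -> list[int]:
    """Greedily match special tokens, otherwise tokenize character by character."""
    ids, i = [], 0
    while i < len(text):
        for sp in SPECIALS:
            if text.startswith(sp, i):
                ids.append(TOKEN_TO_ID[sp])
                i += len(sp)
                break
        else:  # plain character: a digit, '+', or '='
            ids.append(TOKEN_TO_ID[text[i]])
            i += 1
    return ids

print(encode("3+5=8<end>"))  # [7, 14, 9, 15, 12, 3] under this id assignment
```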
Hardware requirements:
- GPU: CUDA-capable (tested on an RTX 3060 with 12GB VRAM)
- Memory: 8GB+ system RAM recommended
- Storage: ~1GB for datasets and checkpoints
Training progress is logged to Weights & Biases. Key metrics:
- Exact match accuracy (complete answer correctness)
- Token-level accuracy
- Loss curves
- Learning rate schedules
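Logging these values amounts to standard `wandb.init`/`wandb.log` calls; the project name and metric keys below are illustrative, not the script's actual ones:

```python
import wandb

# Project and metric names are illustrative placeholders.
wandb.init(project="math-llm", name="experiment1")
loss, exact_match = 0.5, 0.9  # stand-ins for values computed during training
wandb.log({"train/loss": loss, "eval/exact_match": exact_match})
wandb.finish()
```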
Results will be documented here as training progresses.
See DESIGN.md for architectural details and design decisions.