Math LLM

A minimal transformer model for learning basic arithmetic operations with chain-of-thought reasoning, starting with single-digit addition and scaling to multi-digit problems.

Currently the model learns to mimic human-understandable reasoning, but the goal is for it to learn to reason without instructions and without RL, to see how a model "naturally" learns to reason and whether we can understand it.

The current CoT looks like this (note that the model can still make arithmetic mistakes; here it answers 110 rather than the correct 120):

➤ Enter expression: 99+21=
💭 Generating completion for: 99+21=
✨ Model output: 99+21=<think>9+1=10 9+2+1=11 0+1=1 1+0=1</think>110<end>
🤔 Chain of thought:
  99+21=
  🧮 Think:
    9+1=10 9+2+1=11 0+1=1 1+0=1
  110<end>
✅ Answer: 99+21=110
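
A correct digit-by-digit trace in this format can be built as in the sketch below (illustrative only; the repo's actual generator in scripts/generate_data.py may differ):

def addition_cot(a: int, b: int) -> str:
    """Build an 'a+b=<think>...</think>answer<end>' trace: one column
    sum per step, least-significant digit first, with carries."""
    da, db = str(a)[::-1], str(b)[::-1]
    steps, digits, carry = [], [], 0
    for i in range(max(len(da), len(db))):
        x = int(da[i]) if i < len(da) else 0
        y = int(db[i]) if i < len(db) else 0
        terms = [str(x), str(y)] + ([str(carry)] if carry else [])
        s = x + y + carry
        steps.append("+".join(terms) + f"={s}")
        digits.append(str(s % 10))   # this column's output digit
        carry = s // 10              # carry into the next column
    if carry:
        digits.append(str(carry))
    return f"{a}+{b}=<think>{' '.join(steps)}</think>{''.join(reversed(digits))}<end>"

print(addition_cot(658, 189))
# 658+189=<think>8+9=17 5+8+1=14 6+1+1=8</think>847<end>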

Quick Start

# Setup development environment
./setup.sh

# Generate training data
python scripts/generate_data.py

# Train the model
python scripts/train.py

# Test the model
python scripts/evaluate.py

# Try the model interactively
python scripts/interactive.py --checkpoint checkpoints/model.safetensors

Project Structure

math-llm/
├── scripts/
│   ├── generate_data.py    # Data generation
│   ├── train.py            # Training script
│   ├── evaluate.py         # Evaluation script
│   └── interactive.py      # Interactive inference
├── src/
│   ├── model.py            # Model architecture
│   ├── tokenizer.py        # Custom tokenizer
│   └── data.py             # Data loading utilities
├── data/                   # Generated datasets
├── checkpoints/            # Model checkpoints
└── logs/                   # Training logs

Usage

Data Generation

python scripts/generate_data.py --num-examples 100000 --max-digits 2
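
The on-disk schema isn't documented here; a plausible shape for the generated files, assuming a simple list of completion strings (the real format in scripts/generate_data.py may differ), is:

import json, random

# Hypothetical record shape for data/train.json; --max-digits 2 would
# mean operands in 0-99. Check generate_data.py for the actual schema.
examples = [f"{a}+{b}={a + b}<end>"
            for a, b in ((random.randint(0, 99), random.randint(0, 99))
                         for _ in range(100_000))]
with open("data/train.json", "w") as f:
    json.dump(examples, f)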

Training

The training script supports comprehensive configuration with colored logging and W&B integration:

# Basic training with default settings
python scripts/train.py

# Custom configuration
python scripts/train.py \
  --model-size medium \
  --batch-size 64 \
  --learning-rate 2e-4 \
  --num-epochs 20 \
  --fp16 \
  --data-dir data \
  --output-dir checkpoints/experiment1

# Training without W&B logging
python scripts/train.py --no-wandb

# Resume from checkpoint
python scripts/train.py --output-dir checkpoints/experiment1

# Train with Gumbel-Softmax generation (no teacher forcing)
python scripts/train.py --use-gumbel --gumbel-temperature 0.5

Training Arguments

Model Configuration:

  • --model-size: Model size (small, medium, large) - default: small
  • --max-length: Maximum sequence length - default: 128

Training Hyperparameters:

  • --batch-size: Training batch size - default: 32
  • --eval-batch-size: Evaluation batch size - default: 64
  • --learning-rate: Learning rate - default: 1e-4
  • --weight-decay: Weight decay - default: 0.01
  • --num-epochs: Number of training epochs - default: 10
  • --warmup-steps: Learning rate warmup steps - default: 500

Data and I/O:

  • --data-dir: Directory with train/val/test JSON files - default: data
  • --output-dir: Checkpoint output directory - default: checkpoints
  • --num-workers: Data loading workers - default: 4

Logging and Checkpointing:

  • --save-steps: Save checkpoint every N steps - default: 1000
  • --eval-steps: Evaluate every N steps - default: 1000
  • --logging-steps: Log metrics every N steps - default: 100

System Options:

  • --fp16: Enable mixed precision training
  • --no-wandb: Disable Weights & Biases logging
  • --seed: Random seed for reproducibility - default: 42

Gumbel-Softmax Training:

  • --use-gumbel: Use Gumbel-Softmax for differentiable sequence generation instead of teacher forcing (sketched below)
  • --gumbel-temperature: Temperature for Gumbel-Softmax (lower = more discrete) - default: 1.0
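
One Gumbel-Softmax decoding step looks roughly like the sketch below (an assumed mechanism; train.py's actual loop may differ). Instead of feeding back the ground-truth token as in teacher forcing, the model's own soft sample is embedded and fed back, keeping generation differentiable:

import torch
import torch.nn.functional as F

logits = torch.randn(32, 16)             # (batch, vocab) from the model's head
embedding = torch.nn.Embedding(16, 256)  # vocab 16, hidden 256 (small model)

# Straight-through Gumbel-Softmax: hard one-hot on the forward pass,
# soft gradients on the backward pass. Lower tau = closer to discrete argmax.
y = F.gumbel_softmax(logits, tau=0.5, hard=True)
next_token_emb = y @ embedding.weight    # differentiable embedding of the sample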

Model Sizes

Size     Parameters   Layers   Hidden Size   Heads   Feed-Forward
Small    ~1M          4        256           4       512
Medium   ~5M          6        512           8       1024
Large    ~10M         8        512           8       2048
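
In code these presets presumably map to something like the following (field names are assumptions; see src/model.py for the real definitions):

# Assumed preset layout mirroring the table above:
MODEL_SIZES = {
    "small":  dict(num_layers=4, hidden_size=256, num_heads=4, ffn_size=512),   # ~1M params
    "medium": dict(num_layers=6, hidden_size=512, num_heads=8, ffn_size=1024),  # ~5M params
    "large":  dict(num_layers=8, hidden_size=512, num_heads=8, ffn_size=2048),  # ~10M params
}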

Output Files

Training generates several output files in the checkpoint directory:

  • model.safetensors: Final trained model weights
  • training_config.json: Complete training configuration
  • test_results.json: Final evaluation metrics
  • logs/training.log: Detailed training logs with timestamps

Evaluation

The evaluation script computes exact match accuracy and token-level accuracy on test data:

# Basic evaluation on test set (final model)
python scripts/evaluate.py --checkpoint checkpoints/model.safetensors

# Evaluate specific checkpoint
python scripts/evaluate.py --checkpoint checkpoints/checkpoint-1000/model.safetensors

# Evaluate specific data file
python scripts/evaluate.py \
  --checkpoint checkpoints/model.safetensors \
  --data-path data/custom_test.json

# Custom model size and batch size
python scripts/evaluate.py \
  --checkpoint checkpoints/model.safetensors \
  --model-size medium \
  --batch-size 128

# Save results to file
python scripts/evaluate.py \
  --checkpoint checkpoints/model.safetensors \
  --output-file results.json

Evaluation Arguments

Required:

  • --checkpoint: Path to model checkpoint file

Model Configuration:

  • --model-size: Model size (small, medium, large) - default: small

Data Options:

  • --data-path: Specific evaluation data file (overrides --data-dir)
  • --data-dir: Directory containing test.json - default: data

Evaluation Settings:

  • --batch-size: Evaluation batch size - default: 64
  • --max-length: Maximum sequence length - default: 128
  • --output-file: Save results to JSON file

System:

  • --device: Device to use (cuda, cpu, auto) - default: auto

Metrics

  • Exact Match Accuracy: Percentage of completely correct arithmetic answers
  • Token-Level Accuracy: Per-token prediction accuracy during teacher forcing
  • Number of Examples: Total examples evaluated
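
The two accuracies reduce to something like this sketch (illustrative; evaluate.py's implementation may differ):

import torch

def token_accuracy(preds: torch.Tensor, targets: torch.Tensor,
                   mask: torch.Tensor) -> float:
    """Per-token accuracy over non-padding positions (teacher forcing)."""
    correct = ((preds == targets) & mask).sum()
    return (correct / mask.sum()).item()

def exact_match(pred_answers: list[str], true_answers: list[str]) -> float:
    """Fraction of examples whose decoded answer matches exactly."""
    return sum(p == t for p, t in zip(pred_answers, true_answers)) / len(true_answers)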

Interactive Inference

Test your trained model interactively by providing arithmetic expressions and seeing the model's completions:

# Basic interactive session
python scripts/interactive.py --checkpoint checkpoints/model.safetensors

# Use specific model size and generation settings
python scripts/interactive.py \
  --checkpoint checkpoints/checkpoint-1000/model.safetensors \
  --model-size medium \
  --temperature 0.2 \
  --max-new-tokens 15

Interactive Arguments

Required:

  • --checkpoint: Path to model checkpoint file

Model Configuration:

  • --model-size: Model size (small, medium, large) - default: small

Generation Settings:

  • --max-new-tokens: Maximum tokens to generate - default: 20
  • --temperature: Sampling temperature (lower = more deterministic) - default: 0.1 (see the sampling sketch below)

System:

  • --device: Device to use (cuda, cpu, auto) - default: auto
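
Temperature scaling amounts to the following (a sketch of a standard sampling step; interactive.py is assumed to do something equivalent):

import torch

logits = torch.randn(1, 16)  # next-token logits over the 16-token vocab
temperature = 0.1            # the --temperature value

# Dividing logits by a small temperature sharpens the softmax toward
# the argmax, so completions become more deterministic.
probs = torch.softmax(logits / temperature, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)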

Usage Examples

The interactive script accepts various input formats:

➤ Enter expression: 3+5=
✨ Model completion: '3+5=' → '3+5=8<end>'

➤ Enter expression: 12+34=
✨ Model completion: '12+34=' → '12+34=46<end>'

➤ Enter expression: 7+
✨ Model completion: '7+' → '7+3=10<end>'

Input Validation:

  • Only accepts digits (0-9), plus sign (+), and equals sign (=); see the validation sketch below
  • Provides helpful error messages for invalid input
  • Type 'quit' or 'exit' to end the session, or use Ctrl+C/Ctrl+D
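
The validation rule amounts to a check like this (a sketch; the script's actual check may differ):

import re

def is_valid_expression(expr: str) -> bool:
    """Accept only digits, '+', and '='."""
    return re.fullmatch(r"[0-9+=]+", expr) is not None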

Model Details

  • Architecture: Small transformer decoder (1M-10M parameters)
  • Vocabulary: 16 tokens (digits 0-9, +, =, and special tokens including <think>, </think>, and <end>; a plausible layout is sketched below)
  • Context Length: 128 tokens (sufficient for chain-of-thought reasoning)
  • Task: Next-token prediction on arithmetic expressions with reasoning
  • Formats:
    • Simple: "3+5=8<end>"
    • With reasoning: "658+189=<think>8+9=17 5+8+1=14 6+1+1=8</think>847<end>"
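
A plausible layout for the 16-token vocabulary referenced above (the <pad> token and the ordering are assumptions; see src/tokenizer.py for the real mapping):

# Digits, operators, and special tokens; <pad> is a guess for the 16th slot.
VOCAB = [str(d) for d in range(10)] + ["+", "=", "<think>", "</think>", "<end>", "<pad>"]
TOKEN_TO_ID = {tok: i for i, tok in enumerate(VOCAB)}
assert len(VOCAB) == 16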

Hardware Requirements

  • GPU: CUDA-capable (tested on an RTX 3060 with 12GB VRAM)
  • Memory: 8GB+ system RAM recommended
  • Storage: ~1GB for datasets and checkpoints

Monitoring

Training progress is logged to Weights & Biases. Key metrics:

  • Exact match accuracy (complete answer correctness)
  • Token-level accuracy
  • Loss curves
  • Learning rate schedules
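
Logging presumably boils down to calls like these (the project and metric names are assumptions; train.py defines the real ones):

import wandb

wandb.init(project="math-llm")  # hypothetical project name
wandb.log({"train/loss": 0.42, "eval/exact_match": 0.97, "lr": 1e-4}, step=100)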

Results

Results will be documented here as training progresses.

See DESIGN.md for architectural details and design decisions.
