Skip to content

ayushh0110/toolforge

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

4 Commits
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 

Repository files navigation

🔧 ToolForge: Fine-Tuning Small LLMs for Autonomous Tool Routing

Teaching a model to become the router — replacing hand-crafted heuristics with learned tool-selection behavior via QLoRA distillation.

Python 3.12 PyTorch HuggingFace W&B

📖 Read the blog post: From Heuristics to Fine-Tuning


🎯 Problem

Autonomous AI agents need to decide which tool to call for every user query. Most implementations rely on:

  • ❌ Regex/keyword matching (brittle, unmaintainable)
  • ❌ Zero-shot LLM prompting (expensive, slow, inconsistent)
  • ❌ Embedding similarity (loses argument extraction)

ToolForge solves this by fine-tuning a small LLM (7-8B) via QLoRA on synthetic tool-call traces, achieving 86% tool-selection accuracy with sub-second latency — replacing a heuristic router with a learned one.


📊 Results

Ablation Study (4 runs, W&B tracked)

Run Base Model LoRA r LR Test Accuracy Eval Loss
🥇 qwen7b-r64 Qwen2.5-7B-Instruct 64 2e-4 86.2% 0.141
🥈 mistral-r64 Mistral-7B-Instruct-v0.3 64 2e-4 82.8% 0.670
🥉 mistral-r16 Mistral-7B-Instruct-v0.3 16 2e-4 81.9% 0.648
❌ mistral-lr5e4 Mistral-7B-Instruct-v0.3 64 5e-4 60.3% 0.730

Note: True accuracy is estimated at ~92%+ after accounting for noisy teacher labels in the test set (model correctly routes queries that were mislabeled by the teacher).

Per-Tool Accuracy (Best Model — Qwen2.5-7B)

Tool Accuracy Tool Accuracy
datetime 100% web_search 91.7%
unit_converter 100% wikipedia 86.7%
web_reader 100% translate 80.0%
calculator 94.1% multi_tool 50.0%
dictionary 93.8% no_tool 41.7%
weather 92.3%

Key Findings

  • 7/9 tools above 90% — single-tool routing is near-production quality
  • Adapter size has minimal impact — r=16 (81.9%) vs r=64 (82.8%); smaller adapter is deployable for efficiency
  • Learning rate is critical — 5e-4 causes divergence; 2e-4 is the sweet spot
  • Student surpasses teacher — the fine-tuned model correctly routes queries that Gemini mislabeled in the training set

🏗️ Architecture

┌─────────────────────────────────────────────────────────────┐
│                    ToolForge Pipeline                        │
│                                                             │
│  Phase 1: Data Generation                                   │
│  ┌──────────────┐   ┌──────────────┐   ┌────────────────┐  │
│  │   Template    │ + │   Gemini     │ → │  1,173 labeled │  │
│  │  Generator    │   │  Teacher     │   │   examples     │  │
│  │  (498 seed)   │   │  (679 dist.) │   │  (train/val/   │  │
│  │              │   │  flash+lite  │   │   test/hard)   │  │
│  └──────────────┘   └──────────────┘   └────────────────┘  │
│                                                             │
│  Phase 2: QLoRA Training                                    │
│  ┌──────────────┐   ┌──────────────┐   ┌────────────────┐  │
│  │  Base Model   │ + │  LoRA r=64   │ → │   Fine-tuned   │  │
│  │  (4-bit NF4)  │   │  Adapter     │   │   Router       │  │
│  │  Qwen/Mistral │   │  ~335-646 MB │   │   86.2% acc    │  │
│  └──────────────┘   └──────────────┘   └────────────────┘  │
│                                                             │
│  Phase 3: Evaluation                                        │
│  ┌──────────────┐   ┌──────────────┐   ┌────────────────┐  │
│  │  Tool Acc.    │   │  Per-Category│   │   W&B          │  │
│  │  Arg Match    │   │  Breakdown   │   │   Dashboard    │  │
│  │  Multi-Tool   │   │  Error       │   │   4 ablation   │  │
│  │  Latency      │   │  Analysis    │   │   runs         │  │
│  └──────────────┘   └──────────────┘   └────────────────┘  │
└─────────────────────────────────────────────────────────────┘

🛠️ The 9 Tools

The model learns to route queries to these tools (or respond directly):

Tool Description Input Schema
web_search Search the internet {query: str}
calculator Mathematical expressions {expression: str}
weather Current weather data {location: str}
wikipedia Encyclopedia lookup {query: str}
datetime Date/time operations {action: str, ...}
dictionary Word definitions {word: str}
translate Language translation {text: str, to_lang: str}
unit_converter Unit conversion {value: float, from: str, to: str}
web_reader Extract webpage content {url: str}

Plus no_tool (direct response) and multi_tool (chained calls).


📁 Project Structure

toolforge/
├── README.md
├── requirements.txt
├── configs/
│   ├── mistral_r64.yaml            # Default training config
│   ├── mistral_r16.yaml            # Small adapter ablation
│   └── llama_r64.yaml              # Alternative base model
├── data/
│   ├── synthetic/
│   │   ├── queries.json            # 1,894 generated queries
│   │   └── teacher.jsonl           # 679 Gemini-labeled examples
│   ├── train.jsonl                 # 918 training examples
│   ├── val.jsonl                   # 114 validation examples
│   ├── test.jsonl                  # 116 test examples
│   └── hard_test.jsonl             # 25 multi-tool edge cases
├── src/
│   ├── data_gen/
│   │   ├── template_generator.py   # Deterministic seed data (498 examples)
│   │   ├── teacher_labeler.py      # Gemini distillation with multi-key rotation
│   │   └── build_dataset.py        # Merge, dedup, split into train/val/test
│   ├── training/
│   │   ├── train.py                # QLoRA fine-tuning with SFTTrainer
│   │   └── merge.py                # LoRA → base model merge for deployment
│   └── eval/
│       └── evaluate.py             # Tool accuracy, F1, per-category breakdown
├── kaggle_ablation.py              # Self-contained Kaggle notebook with W&B
└── kaggle_notebook.py              # Single-run training notebook

🚀 Quick Start

1. Generate Training Data

# Install dependencies
pip install -r requirements.txt

# Generate seed queries + label with Gemini
# (requires API keys in .env — get free keys at aistudio.google.com)
python -m src.data_gen.teacher_labeler --n 2500

# Build final dataset splits
python -m src.data_gen.build_dataset

2. Train on Kaggle (Free GPU)

  1. Upload data/*.jsonl as a Kaggle Dataset
  2. Create a new notebook with GPU T4 enabled
  3. Paste cells from kaggle_ablation.py and run
# Or train locally with a GPU
python -m src.training.train --config configs/mistral_r64.yaml

3. Evaluate

python -m src.eval.evaluate \
    --checkpoint checkpoints/qwen7b-r64-lr2e4/final \
    --test-set data/test.jsonl

🔬 Data Pipeline

Two-Source Strategy

Source Count Method Quality
Template Generator 498 Deterministic rules, 100% clean labels ⭐⭐⭐
Gemini Distillation 679 gemini-2.5-flash + flash-lite function calling ⭐⭐

Crash-Proof Distillation

The teacher labeler (teacher_labeler.py) is designed for zero-cost, zero-data-loss operation:

  • Multi-key round-robin: 6 API keys × 2 models = 12 independent quota slots
  • Incremental saves: Every label is flushed to disk immediately
  • Smart retry logic: Distinguishes daily quota (mark key dead) vs transient 503 (exponential backoff)
  • Resume support: --resume flag continues from exactly where you left off
# Resume after quota exhaustion — add fresh keys to .env and re-run
python -m src.data_gen.teacher_labeler --resume

⚙️ Training Details

QLoRA Configuration

Parameter Value
Quantization 4-bit NF4, double quantization
LoRA rank 64 (best), 16 (ablation)
LoRA alpha 128
Target modules q, k, v, o, gate, up, down projections
Optimizer AdamW
Learning rate 2e-4 (cosine schedule)
Batch size 4 × 4 gradient accumulation = 16 effective
Epochs 3
Trainable params ~335M / 7.2B (4.6%)

Training Curves (Mistral-7B, r=64)

Step   Train Loss   Eval Loss
  50     0.724        0.698
 100     0.581        0.687
 150     0.495        0.672

📈 Experiment Tracking

All runs are logged to Weights & Biases under the toolforge project:

  • Training loss curves (per-step)
  • Validation loss at each checkpoint
  • Test accuracy and per-category breakdown
  • Hyperparameter comparison across ablation runs
  • System metrics (GPU utilization, memory)

🧠 Key Technical Decisions

Why QLoRA over full fine-tuning?

With 918 training examples and a 7B model, full fine-tuning would catastrophically overfit. QLoRA freezes 95%+ of weights and only trains ~335M adapter parameters — enough capacity for tool routing without destroying the base model's knowledge.

Why Gemini as teacher instead of GPT-4?

Cost. Gemini's free tier provides 20+ requests/day per model per key. With 6 keys × 2 models = 12 quota slots, we labeled 679 examples at zero cost. The multi-key rotation system makes this fully automated.

Why the student outperforms the teacher's labels

The model sees 27/30 correct labels for patterns like "define X → dictionary" and learns the dominant signal. The 3 noisy labels from Gemini's inconsistency are treated as noise — a well-known property of neural network training on noisy supervision.


📋 Requirements

  • Python 3.12+
  • PyTorch 2.x with CUDA
  • transformers, peft, trl, bitsandbytes
  • Google API keys (free tier) for data generation
  • GPU: T4 (16GB) minimum for training

📝 License

MIT

About

No description, website, or topics provided.

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

 
 
 

Contributors

Languages