# GSM8K: Advanced DSPy Techniques with Local Models

This notebook applies DSPy prompt optimization to GSM8K math word problems using a local Ollama model (`lfm2.5-thinking:latest`).

**Goal**: Measure how much DSPy optimization improves a local model's accuracy over zero-shot and few-shot baselines.

## What We'll Cover

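- Prompt optimization with MIPROv2 (falling back to legacy MIPRO where needed)
- Chain-of-thought prompting with `dspy.ChainOfThought`
- Recovering answers from truncated or unparseable model output
- Local model integration with Ollama
- Comparison against zero-shot and few-shot baselines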

## Prerequisites

- Ollama installed and running
- `lfm2.5-thinking:latest` model pulled: `ollama pull lfm2.5-thinking:latest`
- OpenAI API key (optional; only needed to compare against a hosted model)
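
Both Ollama prerequisites can be checked from Python before running the cells below. The following is a minimal sketch that queries Ollama's `/api/tags` endpoint, which lists locally pulled models. The helper names (`model_is_pulled`, `list_local_models`, `check_ollama`) are illustrative and not part of this project, and the URL assumes Ollama's default port.

```python
import json
import urllib.request

OLLAMA_URL = "http://localhost:11434"  # assumption: default Ollama address
MODEL_NAME = "lfm2.5-thinking:latest"


def model_is_pulled(tags_payload: dict, model_name: str) -> bool:
    """Return True if model_name appears in an /api/tags response body."""
    names = {m.get("name") for m in tags_payload.get("models", [])}
    return model_name in names


def list_local_models(base_url: str = OLLAMA_URL, timeout: float = 5.0) -> dict:
    """Fetch the locally available models from a running Ollama server."""
    with urllib.request.urlopen(f"{base_url}/api/tags", timeout=timeout) as resp:
        return json.load(resp)


def check_ollama(model_name: str = MODEL_NAME) -> None:
    """Print whether the server is reachable and the model is present."""
    try:
        payload = list_local_models()
    except OSError as exc:  # connection refused, timeout, HTTP errors
        print(f"Ollama does not appear to be running: {exc}")
        return
    status = "found" if model_is_pulled(payload, model_name) else "missing"
    print(f"{model_name}: {status}")
```

Calling `check_ollama()` prints a one-line status. If the model is reported missing, run the `ollama pull` command listed above.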

## 1. Setup

In [11]:
import sys
import os
import re
import inspect
from pathlib import Path

# Add project root to path (works from repo root or notebooks/)
cwd = Path.cwd()
project_root = cwd if (cwd / "config.py").exists() else cwd.parent
sys.path.insert(0, str(project_root))

import dspy
# Prefer MIPROv2 on modern DSPy, then legacy MIPRO, then local wrapper fallback.
try:
    from dspy.teleprompt import MIPROv2 as MIPRO, BootstrapFewShotWithRandomSearch
    MIPRO_NAME = "MIPROv2"
except ImportError:
    try:
        from dspy.teleprompt import MIPRO, BootstrapFewShotWithRandomSearch
        MIPRO_NAME = "MIPRO"
    except ImportError:
        from optimizers.dspy_optimizers import MIPROOptimizer as MIPRO
        from dspy.teleprompt import BootstrapFewShotWithRandomSearch
        MIPRO_NAME = "MIPROOptimizer wrapper"

# Some dspy versions don't export convenience primitives used in older notebooks;
# provide lightweight fallbacks so the notebook runs across dspy releases.
try:
    from dspy.primitives import assert_transform_module, suggest_transform_module
except ImportError:
    def assert_transform_module(module, *args, **kwargs):
        """Fallback that assumes module is valid (used only for demos)."""
        return True

    def suggest_transform_module(module, *args, **kwargs):
        """Fallback stub returning None for suggestions."""
        return None

from dspy.utils.exceptions import AdapterParseError

from data import prepare_gsm8k_splits, gsm8k_metric, evaluate_gsm8k
from modules import MathSolver
from baselines import create_baseline
from utils import Evaluator
from config import DATASET_CONFIGS

import warnings
warnings.filterwarnings('ignore')

print("✓ Advanced DSPy imports successful")
print(f"✓ Using optimizer class: {MIPRO_NAME}")


✓ Advanced DSPy imports successful
✓ Using optimizer class: MIPROv2


In [12]:
# Configure Ollama model
ollama_lm = dspy.LM(
    model="ollama/lfm2.5-thinking:latest",
    api_base="http://localhost:11434",
    api_key=None,  # don't send empty Bearer header to Ollama
    max_tokens=1536,
    temperature=0.0,  # deterministic output improves DSPy parse reliability
    num_retries=6,
)

# Test the model
print("Testing Ollama model...")
try:
    response = ollama_lm("What is 2+2?")
    print(f"✓ Model response: {response}")
except Exception as e:
    print(f"❌ Model test failed: {e}")
    print("Make sure Ollama is running and the model is pulled.")

# Configure DSPy
# Reasoning models on local providers are often more stable with JSON mode
# instead of native tool/function-calling structured output.
json_adapter = dspy.JSONAdapter(use_native_function_calling=False)
dspy.settings.configure(lm=ollama_lm, adapter=json_adapter, max_errors=50)

print("✓ Configured DSPy with Ollama lfm2.5-thinking:latest")
print("✓ Enabled JSONAdapter(use_native_function_calling=False) for parse stability")


Testing Ollama model...
✓ Model response: [{'text': '\n\n2 + 2 equals **4**. Simple math! 😊 Let me know if you need further clarification.', 'reasoning_content': ' Okay, let\'s see. The user asked "What is 2+2?" and I need to respond appropriately. Hmm, the question seems straightforward, but maybe there\'s a trick here. Wait, the user mentioned "User: What is 2+2?" and I have to figure out what they expect. Since it\'s a simple math question, the answer is probably just 4. But maybe they want some deeper explanation? Or perhaps there\'s a context I\'m missing. Let me think.\n\nWait, sometimes people play with questions like "What is 2+2?" expecting a play on words or a pun. But 2+2 is obviously 4. Unless there\'s a context where 2+2 represents something else, like in a different base? For example, in base 3, 2+2 would be 11 (since 2+2=4, which is 1*3 +1=4? Wait no, 4 in base 3 is 11 (1*3 +1=4). But the question is in base 10 probably. Alternatively, maybe the user is testing if I know

## 2. Load Data

In [13]:
# Load GSM8K splits
config = DATASET_CONFIGS['gsm8k'].copy()

train_examples, dev_examples, test_examples = prepare_gsm8k_splits(
    train_size=config['train_size'],
    dev_size=config['dev_size'],
    test_size=config['test_size'],
    seed=config['seed'],
)

print(f"✓ Data loaded:")
print(f"  Train: {len(train_examples)} examples")
print(f"  Dev:   {len(dev_examples)} examples")
print(f"  Test:  {len(test_examples)} examples")

Loading GSM8K train split...
Loaded 300 examples from GSM8K train
Loading GSM8K test split...
Loaded 100 examples from GSM8K test

Prepared GSM8K splits:
  Train: 200 examples
  Dev:   100 examples
  Test:  100 examples
✓ Data loaded:
  Train: 200 examples
  Dev:   100 examples
  Test:  100 examples
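
The log above suggests that the 300 loaded training examples are divided into 200 train and 100 dev examples, while the test set comes from GSM8K's own test split. `prepare_gsm8k_splits` is project code that is not shown here, so the following is only a sketch of how a seeded split of that shape might be written. `split_train_dev` is a hypothetical helper, and the seed value is arbitrary.

```python
import random
from typing import List, Sequence, Tuple, TypeVar

T = TypeVar("T")


def split_train_dev(
    examples: Sequence[T], train_size: int, dev_size: int, seed: int
) -> Tuple[List[T], List[T]]:
    """Shuffle with a fixed seed, then carve out disjoint train and dev sets."""
    if train_size + dev_size > len(examples):
        raise ValueError("train_size + dev_size exceeds the number of examples")
    shuffled = list(examples)
    random.Random(seed).shuffle(shuffled)  # local RNG; global state is untouched
    return shuffled[:train_size], shuffled[train_size:train_size + dev_size]


# Integers stand in for the 300 loaded examples.
train, dev = split_train_dev(range(300), train_size=200, dev_size=100, seed=42)
print(len(train), len(dev))  # 200 100
```

Using a dedicated `random.Random(seed)` instance keeps the split reproducible without affecting other code that relies on the global random state.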


## 3. Advanced DSPy Module with Chain-of-Thought

In [14]:
def extract_last_number(text: str) -> str:
    """Extract the final numeric token from model text."""
    cleaned = str(text).replace(',', '')
    matches = re.findall(r'-?\d+(?:\.\d+)?', cleaned)
    if not matches:
        return ""

    candidate = matches[-1]
    try:
        value = float(candidate)
        return str(int(value)) if value.is_integer() else str(value)
    except ValueError:
        return candidate


def extract_reasoning_field(raw_response: str) -> str:
    """Recover the reasoning field when JSON output is truncated."""
    text = str(raw_response).strip()
    match = re.search(r'\"reasoning\"\s*:\s*\"((?:\\.|[^\"\\])*)', text, flags=re.DOTALL)
    if not match:
        return text

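    encoded = match.group(1)
    try:
        # Encode via latin-1 with backslashreplace so non-ASCII text survives
        # the unicode_escape round trip (raw UTF-8 bytes would be mangled).
        return encoded.encode('latin-1', 'backslashreplace').decode('unicode_escape').strip()
    except Exception:
        return encoded.strip()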


# Create advanced math solver with chain-of-thought
class AdvancedMathSolver(dspy.Module):
    def __init__(self, max_parse_retries: int = 2):
        super().__init__()
        self.solve = dspy.ChainOfThought("question -> reasoning, answer")
        self.answer_only = dspy.Predict("question -> answer")
        self.max_parse_retries = max_parse_retries

    def _fallback_prediction(self, error: Exception, question: str):
        raw = getattr(error, "lm_response", "") or str(error)
        reasoning = extract_reasoning_field(raw)
        answer = extract_last_number(reasoning) or extract_last_number(raw)

        if not answer:
            # Final fallback: ask for answer-only output (single field is easier to parse).
            try:
                answer_text = str(self.answer_only(question=question).answer).strip()
                answer = extract_last_number(answer_text) or answer_text
            except Exception:
                answer = ""

        return dspy.Prediction(reasoning=reasoning, answer=answer)

    def forward(self, question):
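        last_error = None
        # Note: with temperature=0.0 and DSPy's LM cache enabled, a retry may
        # return the same cached completion, so the fallback below does most
        # of the recovery work.
        for _ in range(self.max_parse_retries):
            try:
                return self.solve(question=question)
            except AdapterParseError as error:
                last_error = error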

        if last_error is not None:
            return self._fallback_prediction(last_error, question)

        # Defensive fallback for unexpected control flow.
        return dspy.Prediction(reasoning="", answer="")


# Test the advanced module
advanced_solver = AdvancedMathSolver()
test_example = train_examples[0]

print("Testing Advanced Math Solver:")
print(f"Question: {test_example.question}")
# Call the module directly rather than .forward() so DSPy's call hooks run.
result = advanced_solver(question=test_example.question)
print(f"Reasoning: {result.reasoning}")
print(f"Answer: {result.answer}")
print(f"Expected: {test_example.answer}")




Testing Advanced Math Solver:
Question: Mimi picked up 2 dozen seashells on the beach.  Kyle found twice as many shells as Mimi and put them in his pocket. Leigh grabbed one-third of the shells that Kyle found.  How many seashells did Leigh have?
Reasoning: Kyle found twice as many shells as Mimi (2 dozen * 2 = 4 dozen = 48 shells), Leigh grabbed one-third of Kyle's shells (48 / 3 = 16), so Leigh has 16.
Answer: 16
Expected: 16
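
The fallback path in `_fallback_prediction` matters most when `max_tokens` cuts a JSON response off mid-string. The standalone sketch below illustrates the idea on a hand-written truncated payload. It is a simplified variant of the helpers above (it skips escape decoding and number normalization), and `recover_answer` is a hypothetical name.

```python
import re

# Matches the (possibly unterminated) string value of a "reasoning" key.
REASONING_RE = re.compile(r'"reasoning"\s*:\s*"((?:\\.|[^"\\])*)', re.DOTALL)
NUMBER_RE = re.compile(r"-?\d+(?:\.\d+)?")


def recover_answer(raw: str) -> str:
    """Return the last number in the reasoning text, or '' if none is found."""
    match = REASONING_RE.search(raw)
    text = match.group(1) if match else raw
    numbers = NUMBER_RE.findall(text.replace(",", ""))
    return numbers[-1] if numbers else ""


# A response cut off before the closing quote and brace.
truncated = '{"reasoning": "Kyle found 48 shells. 48 / 3 = 16'
print(recover_answer(truncated))  # 16
```

Taking the last number is a heuristic: it works when the reasoning ends with the result, and it can return an intermediate value when the text is cut off earlier.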


## 4. MIPRO Optimization

In [15]:
# MIPRO (Multiprompt Instruction PRoposal Optimizer)
mipro_kwargs = dict(
    metric=gsm8k_metric,
    num_candidates=10,  # Candidate prompts/programs
    init_temperature=1.0,  # High temperature for diverse initial prompts
)

print("Starting MIPRO optimization...")
print("This will generate multiple candidate programs and select the best one.")
print(f"Optimizer backend: {MIPRO_NAME}")

# Instantiate with signature-aware kwargs so this works across MIPROv2/MIPRO/wrapper.
mipro_init_sig = inspect.signature(MIPRO).parameters
mipro_init = dict(mipro_kwargs)
if "verbose" in mipro_init_sig:
    mipro_init["verbose"] = True
if "num_threads" in mipro_init_sig:
    mipro_init["num_threads"] = 2
if MIPRO_NAME == "MIPROv2" and mipro_init.get("num_candidates") is not None and "auto" in mipro_init_sig:
    # DSPy MIPROv2 requires auto=None when num_candidates is explicitly set.
    mipro_init["auto"] = None
mipro_optimizer = MIPRO(**mipro_init)

compile_kwargs = dict(
    trainset=train_examples[:30],  # Smaller training set for demo
    valset=dev_examples[:30],      # Validation set (required by some optimizers)
)

compile_sig = inspect.signature(mipro_optimizer.compile).parameters
if "eval_kwargs" in compile_sig:
    compile_kwargs["eval_kwargs"] = dict(num_threads=2, display_table=False)

# MIPROv2 constraint from DSPy docs:
# if auto=None and num_candidates is set, num_trials must also be provided.
if MIPRO_NAME == "MIPROv2":
    auto_setting = getattr(mipro_optimizer, "auto", None)
    num_candidates = getattr(mipro_optimizer, "num_candidates", None)
    if auto_setting is None and num_candidates is not None and compile_kwargs.get("num_trials") is None:
        compile_kwargs["num_trials"] = max(1, int(round(2.6 * int(num_candidates))))
        print(
            f"Configured MIPROv2 num_trials={compile_kwargs['num_trials']} "
            f"(auto=None, num_candidates={num_candidates})"
        )

    # MIPROv2 default minibatch_size=35 can exceed our small valset.
    valset_size = len(compile_kwargs["valset"])
    if compile_kwargs.get("minibatch", True):
        requested_minibatch_size = int(compile_kwargs.get("minibatch_size", 35))
        if requested_minibatch_size > valset_size:
            compile_kwargs["minibatch_size"] = valset_size
            print(
                f"Adjusted MIPROv2 minibatch_size={valset_size} "
                f"to match valset size"
            )

optimized_solver = mipro_optimizer.compile(
    AdvancedMathSolver(),
    **compile_kwargs,
)

print("✓ MIPRO optimization complete!")


Starting MIPRO optimization...
This will generate multiple candidate programs and select the best one.
Optimizer backend: MIPROv2


ValueError: If auto is None, num_trials must also be provided. Given num_candidates=10, we'd recommend setting num_trials to ~26.
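
The error above states the constraint directly: when the optimizer is created with `auto=None` and an explicit `num_candidates`, `compile` must also receive `num_trials`. A minimal sketch of a helper that builds those arguments is shown below. The `2.6` multiplier is this notebook's own heuristic (it reproduces the suggested value of about 26 for 10 candidates), not a DSPy constant, and `mipro_compile_kwargs` is a hypothetical name.

```python
def mipro_compile_kwargs(
    num_candidates: int, valset_size: int, default_minibatch_size: int = 35
) -> dict:
    """Build compile() arguments for MIPROv2 when auto=None.

    The 2.6 factor is a heuristic taken from this notebook, and 35 is the
    minibatch default that the notebook reports for MIPROv2.
    """
    if num_candidates < 1 or valset_size < 1:
        raise ValueError("num_candidates and valset_size must be positive")
    return {
        "num_trials": max(1, round(2.6 * num_candidates)),
        # A minibatch cannot be larger than the validation set it samples from.
        "minibatch_size": min(default_minibatch_size, valset_size),
    }


kwargs = mipro_compile_kwargs(num_candidates=10, valset_size=30)
print(kwargs)
```

The resulting dictionary would then be unpacked into the call, for example `mipro_optimizer.compile(AdvancedMathSolver(), trainset=..., valset=..., **kwargs)`.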

## 5. Evaluation with Concurrent Processing

In [None]:
# Evaluate all approaches
evaluator = Evaluator(
    metric_fn=gsm8k_metric, 
    show_progress=True, 
    verbose=False, 
    max_concurrent=2  # Lower concurrency reduces local parse/truncation errors
)

# Zero-shot baseline
zero_shot = create_baseline("zero-shot", "gsm8k", ollama_lm)
zero_shot_result = evaluator.evaluate(
    zero_shot, dev_examples, "Zero-Shot (Ollama)", "gsm8k"
)

# Few-shot baseline
few_shot = create_baseline("few-shot", "gsm8k", ollama_lm, num_examples=3)
few_shot_result = evaluator.evaluate(
    few_shot, dev_examples, "Few-Shot (Ollama)", "gsm8k"
)

# DSPy optimized
dspy_result = evaluator.evaluate(
    optimized_solver, dev_examples, "DSPy MIPRO (Ollama)", "gsm8k"
)

print("\nResults Summary:")
print(f"Zero-Shot: {zero_shot_result.accuracy:.1%}")
print(f"Few-Shot: {few_shot_result.accuracy:.1%}")
print(f"DSPy MIPRO: {dspy_result.accuracy:.1%}")


## 6. Inspect Optimized Program

In [None]:
# Inspect what MIPRO learned
print("Optimized Program Structure:")
print(optimized_solver)

# Test on a few examples
print("\nSample Predictions:")
for i in range(3):
    example = dev_examples[i]
    pred = optimized_solver(question=example.question)
    correct = gsm8k_metric(example, pred) > 0.5
    print(f"\nExample {i+1}: {'✓' if correct else '✗'}")
    print(f"Question: {example.question[:60]}...")
    print(f"Predicted: {pred.answer}")
    print(f"Expected: {example.answer}")

## 7. Advanced DSPy Insights

### What We Learned

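1. **MIPRO vs BootstrapFewShot**: BootstrapFewShot only bootstraps few-shot demonstrations, whereas MIPRO also proposes candidate instructions and uses Bayesian optimization to search over combinations of instructions and demonstrations
2. **Chain-of-Thought**: Asking for explicit reasoning before the answer is intended to help local models on multi-step problems; the evaluation in Section 5 is where that effect is measured
3. **Local Models**: Local Ollama models need extra care with structured output; in this notebook that meant JSON mode, temperature 0, and a parse-error fallback
4. **Concurrent Processing**: Concurrency speeds up optimization and evaluation, but with a local server a low limit (2 here) reduces truncation and parse errors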

### Advanced Techniques Used

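- **MIPRO Optimizer**: Multiprompt Instruction PRoposal Optimizer, which tunes instructions and few-shot demonstrations together
- **ChainOfThought Module**: Prompts the model to reason step by step before answering
- **Bayesian Optimization**: Searches the space of candidate prompts with fewer evaluations than an exhaustive search
- **Parallel Processing**: Concurrent model calls for speed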
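
The parallel processing item can be illustrated with a bounded thread pool. The project's `Evaluator` class is not shown in this notebook, so the sketch below is only an assumption about how a `max_concurrent` limit might be enforced. `evaluate_concurrently` and the lookup-table "model" are hypothetical stand-ins.

```python
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, Sequence


def evaluate_concurrently(
    predict: Callable[[str], str],
    examples: Sequence[tuple],  # (question, expected_answer) pairs
    max_concurrent: int = 2,
) -> float:
    """Return accuracy, with at most max_concurrent predictions in flight."""
    def score(example: tuple) -> bool:
        question, expected = example
        try:
            return predict(question).strip() == expected
        except Exception:
            return False  # count failed calls as wrong rather than aborting

    if not examples:
        return 0.0
    with ThreadPoolExecutor(max_workers=max_concurrent) as pool:
        results = list(pool.map(score, examples))
    return sum(results) / len(results)


# Stand-in for a model call: looks up a hard-coded answer.
fake_answers = {"2+2?": "4", "3*3?": "9", "10-7?": "2"}
accuracy = evaluate_concurrently(
    fake_answers.get,
    [("2+2?", "4"), ("3*3?", "9"), ("10-7?", "3")],
)
print(f"{accuracy:.1%}")  # 66.7%
```

Threads suit this workload because each prediction spends most of its time waiting on the model server, and a small `max_workers` keeps a single local server from being overloaded.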

### Future Enhancements

- Try different local models (Mixtral, Llama variants)
- Experiment with assertion-based optimization
- Add multi-stage reasoning pipelines
- Implement custom metrics and constraints
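
As a starting point for the custom metrics item, the sketch below shows a numeric-match metric in the `(example, prediction, trace=None)` shape that DSPy metrics conventionally take. It is not the project's `gsm8k_metric`, whose implementation is not shown here. `numeric_match_metric` is a hypothetical name, and `SimpleNamespace` objects stand in for DSPy examples and predictions.

```python
import re
from types import SimpleNamespace


def numeric_match_metric(example, prediction, trace=None, tol: float = 1e-6) -> float:
    """Return 1.0 if the last number in each answer agrees within tol, else 0.0."""
    def last_number(text):
        found = re.findall(r"-?\d+(?:\.\d+)?", str(text).replace(",", ""))
        return float(found[-1]) if found else None

    expected = last_number(example.answer)
    predicted = last_number(prediction.answer)
    if expected is None or predicted is None:
        return 0.0
    return 1.0 if abs(expected - predicted) <= tol else 0.0


# SimpleNamespace stands in for dspy.Example / dspy.Prediction objects.
gold = SimpleNamespace(answer="16")
pred = SimpleNamespace(answer="Leigh has 16.0 shells")
print(numeric_match_metric(gold, pred))  # 1.0
```

Comparing parsed numbers rather than raw strings makes the metric tolerant of formatting differences such as `16` versus `16.0` or `1,250` versus `1250`.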