# DSPy as a Service - GSM8K Math Optimization with GEPA

This notebook demonstrates how to:
1. Submit an optimization job using the GEPA optimizer on the GSM8K math dataset
2. Monitor progress in real-time
3. Retrieve and use the optimized program

GSM8K is a dataset of grade school math word problems requiring multi-step reasoning.
GEPA (Genetic-Pareto) is a reflective prompt-evolution optimizer that can achieve up to 93% accuracy on math benchmarks.

In [None]:
# TODO: On-premise - Update pip index URL to local artifactory
# Example: !pip install -q dspy requests --index-url https://artifactory.your-company.com/api/pypi/pypi-remote/simple
!pip install -q dspy requests

In [None]:
import base64
import inspect
import json
import os
import pickle
import textwrap
import time
from pathlib import Path
from typing import Any

import dspy
import requests

## Configuration

In [None]:
# TODO: On-premise - Update BASE_URL to your internal DSPy service endpoint
BASE_URL = os.environ.get("DSPY_SERVICE_URL", "http://localhost:8000")

# The API key is mandatory: read it from the environment and stop right away
# with a clear error when it is unset or empty.
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
if not OPENAI_API_KEY:
    raise ValueError("OPENAI_API_KEY environment variable is required")

# TODO: On-premise - Update model config for local LLM provider
# Example for local provider:
#   "name": "ollama/llama3" or "local/mistral"
#   "base_url": "http://your-local-llm:11434/v1"
#   "extra": {} (no API key needed)
#
# Model settings that travel to the service inside the job payload.
MODEL_CONFIG = dict(
    name="openai/gpt-4o-mini",
    base_url="https://api.openai.com/v1",
    model_type="responses",
    temperature=1.0,
    max_tokens=20000,
    extra=dict(api_key=OPENAI_API_KEY),
)

# TODO: On-premise - Update dspy.LM config to match MODEL_CONFIG above
# LM used by DSPy inside this notebook. The shared settings are read from
# MODEL_CONFIG so the two stay in sync; the API key is passed directly.
dspy.configure(
    lm=dspy.LM(
        MODEL_CONFIG["name"],
        model_type=MODEL_CONFIG["model_type"],
        temperature=MODEL_CONFIG["temperature"],
        max_tokens=MODEL_CONFIG["max_tokens"],
        api_key=OPENAI_API_KEY,
    )
)

## API Client

A simple client for interacting with the DSPy service.

In [None]:
class DSPyServiceClient:
    """Client for the DSPy optimization service.

    Provides methods to submit jobs, check status, and retrieve results.

    Args:
        base_url: Root URL of the DSPy service API. Trailing slashes are
            stripped so endpoint paths can be appended directly.
        timeout: Timeout in seconds passed to every ``requests`` call.
            ``None`` (the default) applies no limit.
    """

    def __init__(self, base_url: str = "http://localhost:8000", timeout: "float | None" = None):
        self.base_url = base_url.rstrip("/")
        self.timeout = timeout

    def _get(self, path: str) -> "requests.Response":
        """Send a GET request to ``path`` (relative to ``base_url``)."""
        return requests.get(f"{self.base_url}{path}", timeout=self.timeout)

    def _get_json(self, path: str) -> Any:
        """GET ``path``, raise on an HTTP error status, and return the parsed JSON body."""
        resp = self._get(path)
        resp.raise_for_status()
        return resp.json()

    def health(self) -> dict:
        """Check service health status.

        The parsed body is returned whatever the HTTP status code is; no
        ``raise_for_status`` check is performed here.

        Returns:
            dict: Health status with registered assets.
        """
        return self._get("/health").json()

    def submit(self, payload: dict) -> str:
        """Submit an optimization job.

        Args:
            payload: Job configuration including module, optimizer, and dataset.

        Returns:
            str: The job ID for tracking progress.

        Raises:
            requests.HTTPError: If submission fails.
        """
        resp = requests.post(f"{self.base_url}/run", json=payload, timeout=self.timeout)
        resp.raise_for_status()
        return resp.json()["job_id"]

    def status(self, job_id: str) -> dict:
        """Get current job status.

        Args:
            job_id: The job identifier.

        Returns:
            dict: Full job status including progress events and result.

        Raises:
            requests.HTTPError: If job not found.
        """
        return self._get_json(f"/jobs/{job_id}")

    def summary(self, job_id: str) -> dict:
        """Get lightweight job summary.

        Args:
            job_id: The job identifier.

        Returns:
            dict: Summary with timing and configuration info.

        Raises:
            requests.HTTPError: If the service answers with an error status.
        """
        return self._get_json(f"/jobs/{job_id}/summary")

    def logs(self, job_id: str) -> list:
        """Get job execution logs.

        Args:
            job_id: The job identifier.

        Returns:
            list: Log entries with timestamp, level, and message.

        Raises:
            requests.HTTPError: If the service answers with an error status.
        """
        return self._get_json(f"/jobs/{job_id}/logs")

    def artifact(self, job_id: str) -> dict:
        """Get the optimized program artifact.

        Args:
            job_id: The job identifier.

        Returns:
            dict: Artifact with base64-encoded program and metadata.

        Raises:
            requests.HTTPError: If job not complete or failed.
        """
        return self._get_json(f"/jobs/{job_id}/artifact")

    @staticmethod
    def _decode_program(artifact: dict) -> "dspy.Module":
        """Decode the pickled program stored in an artifact response.

        Args:
            artifact: Parsed artifact response; the program is read from
                ``artifact["program_artifact"]["program_pickle_base64"]``.

        Returns:
            The unpickled object.

        Raises:
            ValueError: If the program data is missing or empty.
        """
        program_artifact = artifact.get("program_artifact") if isinstance(artifact, dict) else None
        pickle_b64 = (
            program_artifact.get("program_pickle_base64")
            if isinstance(program_artifact, dict)
            else None
        )
        if not pickle_b64:
            raise ValueError(
                "Artifact response does not contain "
                "'program_artifact.program_pickle_base64'"
            )
        # SECURITY: unpickling runs code embedded in the payload. Only load
        # artifacts from a DSPy service you trust.
        return pickle.loads(base64.b64decode(pickle_b64))

    def load_program(self, job_id: str) -> "dspy.Module":
        """Load the optimized program from a completed job.

        Args:
            job_id: The job identifier.

        Returns:
            dspy.Module: The optimized DSPy module ready for inference.

        Raises:
            requests.HTTPError: If artifact not available.
            ValueError: If artifact missing program data.
        """
        return self._decode_program(self.artifact(job_id))

    def submit_df(self, df: "pd.DataFrame", column_mapping: dict, **kwargs) -> str:
        """Submit an optimization job using a pandas DataFrame.

        Automatically converts the DataFrame to the required format and validates
        that all columns specified in column_mapping exist in the DataFrame.

        Args:
            df: pandas DataFrame containing the dataset.
            column_mapping: Column mapping with 'inputs' and 'outputs' dicts.
                Example: {"inputs": {"question": "question_col"},
                         "outputs": {"answer": "answer_col"}}
            **kwargs: All other parameters for the optimization job
                (module_name, optimizer_name, signature_code, metric_code,
                 model_config, etc.).

        Returns:
            str: The job ID for tracking progress.

        Raises:
            ValueError: If the DataFrame is empty or required columns are missing.
            TypeError: If df is not a DataFrame.
            ImportError: If pandas is not installed.
        """
        # pandas is imported lazily so the rest of the client works without it.
        try:
            import pandas as pd
        except ImportError as exc:
            raise ImportError(
                "pandas is required to use submit_df(). "
                "Install it with: pip install pandas"
            ) from exc

        if not isinstance(df, pd.DataFrame):
            raise TypeError(f"Expected pandas DataFrame, got {type(df).__name__}")

        if df.empty:
            raise ValueError("DataFrame must contain at least one row")

        # The mapping values are the DataFrame column names the job will read.
        required_columns = set()
        for side in ("inputs", "outputs"):
            if side in column_mapping:
                required_columns.update(column_mapping[side].values())

        df_columns = set(df.columns)
        missing_columns = required_columns - df_columns
        if missing_columns:
            raise ValueError(
                f"DataFrame missing required columns: {sorted(missing_columns)}. "
                f"Available columns: {sorted(df_columns)}"
            )

        # One dict per row, the same shape submit() sends as "dataset".
        dataset = df.to_dict('records')

        payload = {"dataset": dataset, "column_mapping": column_mapping, **kwargs}
        return self.submit(payload)

    def submit_df_simple(self, df: "pd.DataFrame", input_cols: list, output_cols: list, **kwargs) -> str:
        """Submit DataFrame with automatic 1:1 column mapping.

        Convenience method for when DataFrame columns match signature field names exactly.

        Args:
            df: pandas DataFrame containing the dataset.
            input_cols: List of column names to use as inputs.
            output_cols: List of column names to use as outputs.
            **kwargs: All other optimization job parameters.

        Returns:
            str: The job ID for tracking progress.

        Example:
            job_id = client.submit_df_simple(
                df=my_df,
                input_cols=["question"],
                output_cols=["answer"],
                module_name="dspy.ChainOfThought",
                optimizer_name="dspy.GEPA",
                signature_code=SIGNATURE_CODE,
                metric_code=METRIC_CODE,
                model_config=MODEL_CONFIG
            )
        """
        # Each column maps to a signature field of the same name.
        column_mapping = {
            "inputs": {col: col for col in input_cols},
            "outputs": {col: col for col in output_cols}
        }
        return self.submit_df(df, column_mapping, **kwargs)
    

client = DSPyServiceClient(BASE_URL)
client.health()

## Job Monitor

Clean progress monitoring with formatted output.

In [None]:
class JobMonitor:
    """Monitor job progress with formatted console output.

    Keeps count of the progress events and log entries already printed so
    each poll only prints the new ones.

    Args:
        client: DSPyServiceClient instance.
        job_id: The job identifier to monitor.
    """

    def __init__(self, client: "DSPyServiceClient", job_id: str):
        self.client = client
        self.job_id = job_id
        self._printed_events = 0
        self._printed_logs = 0

    def poll(self, interval: int = 3, timeout: "int | None" = None, verbose: bool = True) -> dict:
        """Poll until job completes.

        Args:
            interval: Seconds between status checks.
            timeout: Maximum seconds to wait (None for unlimited).
            verbose: Whether to print status updates.

        Returns:
            dict: Final job status with result or error.

        Raises:
            TimeoutError: If job doesn't complete within timeout (when set).
        """
        # A monotonic clock keeps the elapsed time correct even if the system
        # clock is adjusted while polling.
        start = time.monotonic()

        while True:
            status = self.client.status(self.job_id)
            elapsed = time.monotonic() - start

            if verbose:
                self._print_status(status, elapsed)

            if status["status"] in {"success", "failed"}:
                return status

            # Checked after the terminal test, so a job that finished on the
            # last poll is returned rather than reported as timed out.
            if timeout is not None and elapsed > timeout:
                raise TimeoutError(f"Job {self.job_id} timed out after {timeout}s")

            time.sleep(interval)

    def _print_status(self, status: dict, elapsed: float) -> None:
        """Print formatted status line with new events and logs.

        Args:
            status: Current job status dict.
            elapsed: Seconds since polling started.
        """
        ts = time.strftime("%H:%M:%S")
        label = status["status"].upper().ljust(12)
        # "or" also covers keys that are present but null in the response.
        metrics = status.get("latest_metrics") or {}

        print(f"[{ts}] {label} | elapsed: {int(elapsed)}s | {self._format_metrics(metrics)}")

        events = status.get("progress_events") or []
        for event in events[self._printed_events:]:
            self._print_event(event)
        self._printed_events = len(events)

        logs = status.get("logs") or []
        for log in logs[self._printed_logs:]:
            self._print_log(log)
        self._printed_logs = len(logs)

    @staticmethod
    def _format_metrics(metrics: dict) -> str:
        """Format metrics dict for display.

        Args:
            metrics: Latest metrics from job status (may be empty or None).

        Returns:
            str: Known metrics joined by " | ", or "" when there are none.
        """
        if not metrics:
            return ""

        parts = []
        if "tqdm_percent" in metrics:
            parts.append(f"{round(metrics['tqdm_percent'], 1)}%")
        if "tqdm_n" in metrics and "tqdm_total" in metrics:
            parts.append(f"{metrics['tqdm_n']}/{metrics['tqdm_total']}")
        if "tqdm_desc" in metrics:
            parts.append(str(metrics["tqdm_desc"]))
        if "baseline_test_metric" in metrics:
            parts.append(f"baseline: {round(metrics['baseline_test_metric'], 2)}")
        if "optimized_test_metric" in metrics:
            parts.append(f"optimized: {round(metrics['optimized_test_metric'], 2)}")

        return " | ".join(parts)

    @staticmethod
    def _print_event(event: dict) -> None:
        """Print a progress event.

        Events without metrics print nothing.

        Args:
            event: Progress event dict with name and metrics.
        """
        name = event.get("event", "progress")
        metrics = event.get("metrics") or {}

        if name == "optimizer_progress" and "tqdm_desc" in metrics:
            desc = metrics["tqdm_desc"]
            pct = metrics.get("tqdm_percent", 0)
            n = metrics.get("tqdm_n", 0)
            total = metrics.get("tqdm_total", "?")
            print(f"       {desc}: {round(pct, 1)}% ({n}/{total})")
        elif metrics:
            print(f"       {name}: {metrics}")

    @staticmethod
    def _print_log(log: dict) -> None:
        """Print a log entry.

        Only non-empty messages at INFO, WARNING or ERROR level are printed.

        Args:
            log: Log entry dict with level and message.
        """
        level = log.get("level", "INFO")
        msg = log.get("message", "")
        if msg and level in {"INFO", "WARNING", "ERROR"}:
            print(f"       [{level}] {msg}")

## Dataset & Signature

Load the GSM8K dataset - grade school math word problems requiring multi-step reasoning.

In [None]:
DATA_PATH = Path("data/gsm8k.json")
# Read as UTF-8 explicitly so the result does not depend on the platform's
# default text encoding.
with open(DATA_PATH, encoding="utf-8") as f:
    DATASET = json.load(f)

print(f"Loaded {len(DATASET)} examples")
# Guarded so an empty dataset reports a count of 0 instead of raising IndexError.
if DATASET:
    print(f"Sample: {DATASET[0]}")

In [None]:
# DSPy signature for the GSM8K task: one input field (`question`) and one
# output field (`answer`), both typed as str.
# NOTE: the class docstring and the `desc` strings are runtime text, not just
# documentation. serialize_source() copies them into SIGNATURE_CODE, which is
# sent to the service as "signature_code", so edit them deliberately.
class MathReasoning(dspy.Signature):
    """Solve grade school math word problems step by step."""
    question: str = dspy.InputField(desc="A math word problem requiring multi-step reasoning")
    answer: str = dspy.OutputField(desc="The final numeric answer")


# Define metric as SOURCE STRING - bypasses inspect.getsource() issues in containers
# NOTE: this is a normal (non-raw) string, so the regex backslashes are doubled
# here and become single backslashes in the code that actually runs.
METRIC_CODE = '''
def gsm8k_metric(gold, pred, trace=None, pred_name=None, pred_trace=None):
    """Score math prediction with feedback for GEPA reflection.

    The last number found in each answer is compared by numeric value, so
    "72", "72.0" and "The answer is 72." all count as the same answer.

    Args:
        gold: Ground truth example with expected answer.
        pred: Model prediction with generated answer.
        trace: Optional execution trace.
        pred_name: Optional predictor name.
        pred_trace: Optional predictor trace.

    Returns:
        dspy.Prediction: Contains score (0.0 or 1.0) and feedback text.
    """
    import re
    from decimal import Decimal, InvalidOperation

    def extract_number(text):
        """Extract the last number from text, dropping thousands separators."""
        if text is None:
            return ""
        text = str(text)
        # Only remove commas that sit between a digit and a group of exactly
        # three digits, so list separators such as "3, 4" or "1,2" are kept.
        cleaned = re.sub(r'(?<=\\d),(?=\\d{3}(?!\\d))', '', text)
        # A fractional part needs a digit after the dot, so a sentence-ending
        # period is not swallowed into the number.
        numbers = re.findall(r'-?\\d+(?:\\.\\d+)?', cleaned)
        return numbers[-1] if numbers else text.strip()

    def same_answer(left, right):
        """Compare two extracted answers, numerically when both are numbers."""
        if left == right:
            return True
        try:
            return Decimal(left) == Decimal(right)
        except (InvalidOperation, ValueError):
            # At least one side is not a number and the strings differ.
            return False

    expected = extract_number(gold.answer)
    actual = extract_number(pred.answer)

    if same_answer(expected, actual):
        return dspy.Prediction(score=1.0, feedback="Correct answer.")

    feedback = "Incorrect. Expected " + str(gold.answer) + ", got " + str(pred.answer) + ". Check each arithmetic step carefully."
    return dspy.Prediction(score=0.0, feedback=feedback)
'''

# Execute to get callable for local testing (defines gsm8k_metric in this namespace)
exec(METRIC_CODE)

In [None]:
def serialize_source(obj: Any) -> str:
    """Extract source code from a function or DSPy Signature class.

    For metrics: define as string (METRIC_CODE = '''...''') and pass directly.
    For Signatures: reconstructs from metadata automatically.

    The reconstructed Signature keeps the class name, docstring, field names,
    field kind (input/output) and ``desc``. Fields annotated as ``str``,
    ``int``, ``float`` or ``bool`` keep that type; any other annotation is
    emitted as ``str``.

    Args:
        obj: A function or dspy.Signature subclass.

    Returns:
        str: Python source code as a string.

    Raises:
        RuntimeError: If source cannot be extracted.
    """
    # Handle DSPy Signatures via metadata reconstruction
    if isinstance(obj, type) and issubclass(obj, dspy.Signature):
        doc = obj.__doc__ or ""
        # Keep the triple-quoted form when it is safe. A backslash, an
        # embedded triple quote or a trailing quote would corrupt it, so
        # those cases use an escaped one-line literal instead.
        if "\\" in doc or '"""' in doc or doc.endswith('"'):
            doc_literal = json.dumps(doc, ensure_ascii=False)
        else:
            doc_literal = '"""' + doc + '"""'

        lines = [
            f"class {obj.__name__}(dspy.Signature):",
            f"    {doc_literal}",
        ]
        for name, field in obj.model_fields.items():
            extra = field.json_schema_extra or {}
            ftype = "InputField" if extra.get("__dspy_field_type") == "input" else "OutputField"
            # json.dumps escapes quotes, backslashes and newlines, so any desc
            # becomes a valid double-quoted Python literal.
            desc_literal = json.dumps(str(extra.get("desc", "")), ensure_ascii=False)
            annotation = getattr(field, "annotation", None)
            is_simple = any(annotation is simple for simple in (str, int, float, bool))
            type_name = annotation.__name__ if is_simple else "str"
            lines.append(f"    {name}: {type_name} = dspy.{ftype}(desc={desc_literal})")
        return "\n".join(lines)

    # For functions with __source_code__ attribute (from decorator)
    if hasattr(obj, '__source_code__'):
        return obj.__source_code__

    # Fallback to inspect (may fail in containers)
    try:
        return textwrap.dedent(inspect.getsource(obj)).strip()
    except (OSError, TypeError) as e:
        raise RuntimeError(
            "Cannot extract source from " + str(obj) + ". "
            "In containers, define metric as string: METRIC_CODE = '''def metric(...): ...'''"
        ) from e


SIGNATURE_CODE = serialize_source(MathReasoning)
# METRIC_CODE is already defined as a string above - no serialization needed!

print("Signature:")
print(SIGNATURE_CODE)
print("\nMetric:")
print(METRIC_CODE)

## Build Payload

Configure GEPA optimizer for prompt evolution on GSM8K math problems.

In [None]:
# Job description passed to client.submit(), which POSTs it as JSON to /run.
payload = {
    "module_name": "dspy.ChainOfThought",
    # Source text of the MathReasoning signature, built by serialize_source().
    "signature_code": SIGNATURE_CODE,
    # Source text of gsm8k_metric (the METRIC_CODE string defined above).
    "metric_code": METRIC_CODE,
    "optimizer_name": "dspy.GEPA",
    # NOTE(review): presumably forwarded to the optimizer as keyword
    # arguments - confirm against the service.
    "optimizer_kwargs": {
        "auto": "light",
        "num_threads": 8,
        "reflection_minibatch_size": 3,
    },
    "compile_kwargs": {},
    # The whole dataset is embedded in the request body.
    "dataset": DATASET,
    # NOTE(review): looks like signature field name -> dataset column name
    # (see the submit_df docstring example); here both are identical.
    "column_mapping": {
        "inputs": {"question": "question"},
        "outputs": {"answer": "answer"},
    },
    # The three fractions add up to 1.0.
    "split_fractions": {"train": 0.5, "val": 0.3, "test": 0.2},
    "shuffle": True,
    "seed": 42,
    "model_config": MODEL_CONFIG,
    # Same dict as model_config, so reflection uses the same model settings.
    "reflection_model_config": MODEL_CONFIG,
}

print(f"Module: {payload['module_name']}")
print(f"Optimizer: {payload['optimizer_name']}")
print(f"Dataset: {len(payload['dataset'])} examples")

## Submit & Monitor Job

In [None]:
job_id = client.submit(payload)
print(f"Submitted job: {job_id}")

In [None]:
monitor = JobMonitor(client, job_id)
result = monitor.poll(interval=3)

print("\nFinal status: " + result["status"])

## View Results

In [None]:
def print_results(result: dict) -> None:
    """Print optimization results summary.

    For a successful job this prints the baseline score, optimized score and
    runtime; a missing score prints as "N/A" and a missing runtime as 0.0s.
    For any other status it prints the job's message.

    Args:
        result: Final job status dict from polling.
    """
    if result["status"] != "success":
        print(f"Job failed: {result.get('message')}")
        return

    # "or" also covers a "result" / "runtime_seconds" that is present but
    # null, which would otherwise break the lookups and the float format.
    summary = result.get("result") or {}
    runtime = summary.get("runtime_seconds") or 0
    print(f"Baseline score:  {summary.get('baseline_test_metric', 'N/A')}")
    print(f"Optimized score: {summary.get('optimized_test_metric', 'N/A')}")
    print(f"Runtime: {runtime:.1f}s")


print_results(result)

In [None]:
client.summary(job_id)

In [None]:
def print_recent_logs(client: "DSPyServiceClient", job_id: str, n: int = 5) -> None:
    """Print the most recent log entries.

    Prints the total number of entries, then the last ``n`` of them with each
    message cut to 80 characters.

    Args:
        client: DSPyServiceClient instance.
        job_id: The job identifier.
        n: Number of recent logs to display. Zero or a negative value
            displays none.

    Returns:
        None.
    """
    logs = client.logs(job_id)
    print(f"Total log entries: {len(logs)}")

    # logs[-0:] is the whole list, so n <= 0 has to be handled explicitly.
    recent = logs[-n:] if n > 0 else []
    for log in recent:
        # Same defaults as JobMonitor._print_log for entries missing a field.
        level = log.get("level", "INFO")
        message = str(log.get("message", ""))
        print(f"  [{level}] {message[:80]}")


print_recent_logs(client, job_id)

## Load & Test Optimized Program

In [None]:
program = client.load_program(job_id)
print(f"Loaded program: {type(program).__name__}")

In [None]:
def test_program(program: dspy.Module, questions: list[str]) -> None:
    """Ask the program each question and print the question/answer pairs.

    The program is called once per question with ``question=<text>``; the
    ``answer`` attribute of what it returns is printed, followed by a blank
    line.

    Args:
        program: The optimized DSPy module.
        questions: List of test questions to run.
    """
    for question in questions:
        prediction = program(question=question)
        print(f"Q: {question}")
        print(f"A: {prediction.answer}\n")


test_questions = [
    "A bakery sells 24 cupcakes in the morning and 36 in the afternoon. If each cupcake costs $3, how much money did the bakery make?",
    "Lisa has 48 stickers. She gives 1/4 of them to her friend and then buys 12 more. How many stickers does she have now?",
    "A train travels at 60 mph for 2 hours, then at 80 mph for 1.5 hours. What is the total distance traveled?",
]

test_program(program, test_questions)

In [None]:
dspy.inspect_history(n=1)