# 1. Setup and Dependencies #
First, install the required packages:

In [None]:
!uv add openai-agents mlflow datasets openai gepa

In [1]:
import asyncio
import os
from typing import Any

import mlflow
from agents import Agent, Runner
from datasets import load_dataset
from mlflow.entities import Feedback
from mlflow.genai import evaluate, scorer
from mlflow.genai.optimize import GepaPromptOptimizer
from mlflow.genai.judges import CategoricalRating

# Configure MLflow
mlflow.set_tracking_uri("http://localhost:5000")
mlflow.set_experiment("HotpotQA Optimization")

mlflow.openai.autolog()

# Avoid hanging due to the conflict between async and threading (not necessary for sync agents)
os.environ["MLFLOW_GENAI_EVAL_MAX_WORKERS"] = "1"
# Skip trace validation to avoid NonRecordingSpan errors when evaluate() runs predict_fn with tracing disabled
os.environ["MLFLOW_GENAI_EVAL_SKIP_TRACE_VALIDATION"] = "True"

# If running on notebooks
import nest_asyncio
nest_asyncio.apply()

# 2. Create and Register Your Base Prompt

Start with a simple, straightforward prompt template:

In [2]:
prompt_template = """You are a question answering assistant. Answer questions based ONLY on the provided context.

IMPORTANT INSTRUCTIONS:
- For yes/no questions, answer ONLY "yes" or "no"
- Do NOT include phrases like "based on the context" or "according to the documents"

Context:
{{context}}

Question: {{question}}

Answer:"""

# Register the prompt in MLflow
base_prompt = mlflow.genai.register_prompt(
    name="hotpotqa-user-prompt",
    template=prompt_template,
)

2026/02/09 16:27:54 INFO mlflow.store.model_registry.abstract_store: Waiting up to 300 seconds for prompt version to finish creation. Prompt name: hotpotqa-user-prompt, version 2


# 3. Initialize the OpenAI Agent

Setup your agent

In [None]:
import mlflow
from dotenv import load_dotenv
from utils.clnt_utils import is_databricks_ai_gateway_client, get_databricks_ai_gateway_client, get_openai_client, get_ai_gateway_model_names

# Load environment
load_dotenv()

# Configure MLflow
mlflow.set_tracking_uri("http://localhost:5000")

# Configure client and model based on provider
use_databricks_provider = is_databricks_ai_gateway_client()
if use_databricks_provider:
    client = get_databricks_ai_gateway_client()
    model_name = get_ai_gateway_model_names()[0]
    optimizer_model = f"databricks:/{model_name}"
else:
    client = get_openai_client()
    model_name = "gpt-5.2"
    optimizer_model = f"openai:/{model_name}"

# Enable autologging
mlflow.openai.autolog()

# Runner.run() expects an Agent instance, not a model string. Create a simple agent that uses this model.
# The full prompt (context + question) is passed as the user message in predict_fn.
qa_agent = Agent(
    name="HotpotQA",
    model=model_name,
    instructions="Answer the question using only the information provided in the conversation.",
)

print("\u2705 Environment configured")
print(f"   Provider: {'Databricks AI Gateway' if use_databricks_provider else 'OpenAI'}")
print(f"   Model: {model_name}")
print(f"   Optimizer model: {optimizer_model}")
print(f"   Tracking URI: {mlflow.get_tracking_uri()}")

âœ… Environment configured
   Provider: OpenAI
   Model: gpt-5.2
   Optimizer model: openai:/gpt-5.2
   Tracking URI: http://localhost:5000


# 4. Create a Prediction Function

The prediction function formats the context and question using the prompt template, then runs the agent:

In [4]:
# Create a wrapper for `predict_fn` to run the agent with different prompts
def create_predict_fn(prompt_uri: str):
    prompt = mlflow.genai.load_prompt(prompt_uri)

    # Do not decorate with @mlflow.trace here: evaluate() runs predict_fn under trace_disabled
    # for validation, which would raise 'NonRecordingSpan' object has no attribute 'context'.
    # MLflow wraps predict_fn with tracing when running the actual evaluation.
    def predict_fn(context: str, question: str) -> str:
        """Predict function that uses the agent with the MLflow prompt."""
        # Use prompt.format() with template variables
        user_message = prompt.format(context=context, question=question)

        # Run your agent (Runner.run expects an Agent instance, not a model string)
        result = asyncio.run(Runner.run(qa_agent, user_message))

        return result.final_output

    return predict_fn

# 5. Baseline Evaluation

Before optimizing, establish a baseline by evaluating the agent on a validation set. Here, we define a simple custom scorer that compares the system outputs and expected outputs for equality, but you can use any Scorer objects. See the Scorer Overview for more information.

In [5]:
# Ensure trace validation is skipped (avoids NonRecordingSpan when predict_fn runs under trace_disabled)
os.environ["MLFLOW_GENAI_EVAL_SKIP_TRACE_VALIDATION"] = "True"

def prepare_hotpotqa_data(num_samples: int, split: str = "validation") -> list[dict]:
    """Load and prepare HotpotQA data for MLflow GenAI (evaluate/optimize)."""
    print(f"\nLoading HotpotQA dataset ({split} split)...")
    dataset = load_dataset("hotpot_qa", "distractor", split=split)
    dataset = dataset.select(range(0, min(num_samples, len(dataset))))

    data = []
    for example in dataset:
        # Format context from HotpotQA
        context_text = "\n\n".join([
            f"Document {i+1}: {title}\n{' '.join(sentences)}"
            for i, (title, sentences) in enumerate(zip(example["context"]["title"], example["context"]["sentences"]))
        ])

        data.append({
            "inputs": {
                "context": context_text,
                "question": example["question"],
            },
            "expectations": {
                "expected_response": example["answer"],
            }
        })

    print(f"Prepared {len(data)} samples")
    return data

# Define a scorer for exact match
@scorer
def equivalence(outputs: str, expectations: dict[str, Any]) -> Feedback:
    return Feedback(
        name="equivalence",
        value=CategoricalRating.YES if outputs == expectations["expected_response"] else CategoricalRating.NO,
    )

def run_benchmark(
    prompt_uri: str,
    num_samples: int,
    split: str = "validation",
) -> dict:
    """Run the agent on HotpotQA benchmark using mlflow.genai.evaluate()."""

    # Prepare evaluation data
    eval_data = prepare_hotpotqa_data(num_samples, split)

    # Create prediction function
    predict_fn = create_predict_fn(prompt_uri)

    # Run evaluation
    print(f"\nRunning evaluation on {len(eval_data)} samples...\n")

    results = evaluate(
        data=eval_data,
        predict_fn=predict_fn,
        scorers=[equivalence],
    )

    # Extract metrics
    accuracy = results.metrics.get("equivalence/mean", 0.0) / 100.0

    return {
        "accuracy": accuracy,
        "metrics": results.metrics,
        "results": results,
    }


# Run baseline evaluation
baseline_metrics = run_benchmark(base_prompt.uri, num_samples=100)

print(f"Baseline Accuracy: {baseline_metrics['accuracy']:.2%}")
# Output: Baseline Accuracy: 50.0%


Loading HotpotQA dataset (validation split)...




Prepared 100 samples

Running evaluation on 100 samples...



2026/02/09 16:28:29 INFO mlflow.models.evaluation.utils.trace: Auto tracing is temporarily enabled during the model evaluation for computing some metrics and debugging. To disable tracing, call `mlflow.autolog(disable=True)`.


Evaluating:   0%|          | 0/100 [Elapsed: 00:00, Remaining: ?] 

Baseline Accuracy: 0.58%


# 6. Optimize the Prompt

Now comes the exciting part - using MLflow to automatically improve the prompt:

In [None]:
# Prepare training data using shared function
train_data = prepare_hotpotqa_data(num_samples=100, split="train")

# Run optimization
result = mlflow.genai.optimize_prompts(
    predict_fn=create_predict_fn(base_prompt.uri),
    train_data=train_data,
    prompt_uris=[base_prompt.uri],
    optimizer=GepaPromptOptimizer(
        reflection_model="openai:/gpt-4o",
        max_metric_calls=500,
    ),
    scorers=[equivalence],
    enable_tracking=True,
)

# Get the optimized prompt URI
optimized_prompt_uri = result.optimized_prompts[0].uri
print(f"  Base prompt: {base_prompt.uri}")
print(f"  Optimized prompt: {optimized_prompt_uri}")


Loading HotpotQA dataset (train split)...
Prepared 100 samples


  return _dataset_source_registry.resolve(
