In [0]:
%pip install --upgrade mlflow>=3.9.0 openai datasets huggingface_hub fsspec litellm dspy

# MemAlign Tutorial: Aligning LLM Judges with Human Feedback

This tutorial demonstrates how to improve LLM judge accuracy using MLflow's **MemAlign** optimizer.

## What You'll Learn
1. How to evaluate LLM outputs using **judges** (scorers)
2. How to provide **human feedback** on judge assessments
3. How to use **MemAlign** to align judges with human preferences
4. How to compare judge performance before and after alignment

## Key Concepts
- **Judge/Scorer**: An LLM that evaluates another LLM's outputs (e.g., "Is this classification correct?")
- **Assessment**: The judge's evaluation result (e.g., "Yes" or "No")
- **Trace**: A recorded LLM interaction including inputs, outputs, and assessments
- **Alignment**: Improving a judge's instructions based on human feedback so it better matches human judgment

## Data Source
We'll use the [PubMed text classification dataset](https://huggingface.co/datasets/ml4pubmed/pubmed-text-classification-cased) where sentences from medical abstracts are classified into categories: BACKGROUND, OBJECTIVE, METHODS, RESULTS, or CONCLUSIONS.

In [0]:
import json
import mlflow
from mlflow.entities import AssessmentSource, AssessmentSourceType

# Point MLflow at your tracking server and experiment before anything is traced.
# Both string values are placeholders: replace them with your own settings.
mlflow.set_tracking_uri("your MLflow server")
mlflow.set_experiment("refer to your experiment")
mlflow.openai.autolog()


# Step 0: Set your Variables!

In [None]:
import os
from openai import OpenAI

# Only fall back to the placeholder when no key is configured yet, so a real
# OPENAI_API_KEY already present in the environment is never overwritten.
# Prefer exporting the key outside the notebook over pasting it here.
os.environ.setdefault("OPENAI_API_KEY", "your OpenAI API key if you're using OpenAI")

# If you are not using OpenAI, replace `client` below with your own LLM client
# and provider call.

# OpenAI Client Example
LLM_NAME = "gpt-5-2"
# The model is not a constructor argument: it is passed on every request
# (see `model=LLM_NAME` in the chat completion calls further down).
client = OpenAI()

# Step 1: Load Evaluation Dataset

First, we load a sample of sentences from the PubMed dataset. Each sentence has a ground truth label that we'll use to measure judge accuracy.

In [None]:
from datasets import load_dataset
import random

# Sampling configuration. A fixed seed makes "Restart & Run All" draw the same
# sentences every time, so a given sample ID always refers to the same sentence.
SAMPLE_SEED = 42
NUM_SAMPLES = 50

ds = load_dataset("ml4pubmed/pubmed-text-classification-cased", cache_dir='/tmp/huggingface')
# Prefer the "train" split; otherwise fall back to whichever split comes first.
split = ds["train"] if "train" in ds else ds[list(ds.keys())[0]]

# A dedicated RNG keeps the draw reproducible without reseeding the global
# `random` module; min() avoids a ValueError if the split has fewer rows.
idxs = random.Random(SAMPLE_SEED).sample(range(len(split)), min(NUM_SAMPLES, len(split)))

# The ID is the position in the drawn sample (as a string), not the dataset row index.
samples_list = [
    {"id": str(i), "text": split[idx]["description_cln"], "answer": split[idx]["target"]}
    for i, idx in enumerate(idxs)
]

# Lookup table from sample ID to its sentence and true label.
ground_truth = {s["id"]: {"text": s["text"], "answer": s["answer"]} for s in samples_list}

print(f"Loaded {len(samples_list)} evaluation samples")
print(f"Labels: {set(s['answer'] for s in samples_list)}")
print("\nExample:")
print(f"  ID: {samples_list[0]['id']}")
print(f"  Text: {samples_list[0]['text'][:100]}...")
print(f"  Label: {samples_list[0]['answer']}")

# Step 2: Format Dataset for Evaluation

MLflow's `evaluate()` function expects data in a specific format. We format each sample with:
- `inputs`: The conversation history (user message with the sentence to classify) and the sample ID for tracking
- `expectations`: Ground truth information used by the custom judge to evaluate correctness

In [None]:
def _build_eval_record(sample: dict) -> dict:
    """Turn one sample (keys: id, text, answer) into an evaluation record."""
    user_turn = {"role": "user", "content": sample["text"]}
    facts = [
        f"The correct classification for this sentence is: {sample['answer']}",
        "Classification label must be one of: CONCLUSIONS, RESULTS, METHODS, OBJECTIVE, or BACKGROUND",
    ]
    return {
        "inputs": {"input": [user_turn], "sample_id": sample["id"]},
        "expectations": {"expected_facts": facts},
    }


# One record per sampled sentence, in the same order as `samples_list`.
eval_dataset_records = list(map(_build_eval_record, samples_list))

print(f"Prepared {len(eval_dataset_records)} evaluation records")
print(f"\nSample record structure:")
print(f"  inputs: {{input: [{{role: 'user', content: '<sentence>'}}], sample_id: '<id>'}}")
print(f"  expectations: {{expected_facts: ['The correct classification...', '...']}}")

# Step 3: Define the Prediction Function

This is the LLM we're evaluating. It takes a sentence and classifies it into one of the 5 PubMed categories.
The judges will evaluate whether this LLM's classifications are correct.

In [None]:
@mlflow.trace
def predict_fn(input: list[dict], sample_id: str) -> str | None:
    """
    Classify a sentence into one of 5 PubMed categories.

    Args:
        input: Chat-style message list; the sentence to classify is read from
            the ``content`` of the first message. The name shadows the
            ``input`` builtin inside this function; it presumably has to match
            the ``"input"`` key of each record's ``inputs`` dict (verify
            against ``mlflow.genai.evaluate``).
        sample_id: Identifier of the evaluation sample. It is stored in the
            trace metadata so the trace can later be joined to ground truth.

    Returns:
        The raw message content of the first completion choice. The prompt and
        ``response_format`` request a JSON object with a single ``label`` key,
        but the text is returned as-is and is not parsed or validated here.
    """

    # Tag the current trace so analysis code can look the sample up by ID.
    mlflow.update_current_trace(metadata={"sample_id": sample_id})

    user_text = input[0]["content"]

    prompt = (
        "Classify the sentence into exactly one label from: "
        "METHODS, RESULTS, CONCLUSIONS, BACKGROUND, OBJECTIVE.\n\n"
        f"Sentence: {user_text}\n\n"
        "Respond with json containing a single key 'label' with your classification."
    )

    completion = client.chat.completions.create(
        model=LLM_NAME,
        messages=[{"role": "user", "content": prompt}],
        response_format={"type": "json_object"},
    )
    return completion.choices[0].message.content

# Step 4: Define Reusable Helper Functions

These helper functions will be used throughout the tutorial to:
- Analyze judge accuracy by comparing assessments to ground truth
- Log human feedback for judge alignment
- Print summaries and comparisons
- Automatically generate the human feedback with an LLM (an interactive, manual option is also provided)

**Key insight**: A judge is "correct" when it says "Yes" for correct LLM answers and "No" for incorrect ones.

In [None]:
import json
import mlflow


def normalize_label(x):
    """Return a label as a stripped, upper-cased string, or None if it is missing.

    Args:
        x: Any value; it is converted with ``str()`` before normalizing.

    Returns:
        The normalized label, or None when ``x`` is falsy or contains only
        whitespace.
    """
    if not x:
        return None
    normalized = str(x).strip().upper()
    # Whitespace-only input carries no label; report it as missing rather than "".
    return normalized or None


def extract_classification_label(response) -> str | None:
    try:
        if isinstance(response, str):
            content = response
        elif isinstance(response, dict):
            if "choices" in response:
                content = response["choices"][0]["message"]["content"]
            else:
                content = response
        else:
            return None

        if isinstance(content, str):
            parsed = json.loads(content)
        else:
            parsed = content

        label = parsed.get("label") or parsed.get("classification") or parsed.get("answer")
        if label:
            return normalize_label(label)

        for value in parsed.values():
            if isinstance(value, str) and value.upper() in {"METHODS", "RESULTS", "CONCLUSIONS", "BACKGROUND", "OBJECTIVE"}:
                return normalize_label(value)
        return None
    except (KeyError, IndexError, TypeError, json.JSONDecodeError):
        return None


def judge_approved(feedback_value) -> bool:
    """Interpret a judge's feedback value as approve (True) or reject (False).

    Booleans are returned unchanged and None counts as a rejection. Any other
    value is converted to text, trimmed and lower-cased, and approves only if
    it is one of: yes, true, pass, correct, 1.
    """
    if isinstance(feedback_value, bool):
        return feedback_value
    if feedback_value is None:
        return False
    verdict_text = str(feedback_value).strip().lower()
    return verdict_text in ("yes", "true", "pass", "correct", "1")


def analyze_judge_accuracy(assessment_name: str, ground_truth: dict, max_results: int = 100) -> dict:
    """Measure how often a judge's verdict agrees with the ground-truth labels.

    For each trace carrying an ``LLM_JUDGE`` assessment named
    ``assessment_name``, the model's label is compared to the true label, and
    the judge is counted as right when its verdict (approve/reject) matches
    whether the model was actually correct.

    Args:
        assessment_name: Name of the judge assessment to analyze.
        ground_truth: Mapping of sample ID to a dict with ``text`` and ``answer``.
        max_results: Maximum number of traces to fetch. Raise it when the
            experiment holds more traces than the default window.

    Returns:
        Dict with ``accuracy`` (0.0 when nothing could be scored), ``total``,
        ``correct``, ``incorrect``, and the per-trace record lists ``matches``
        and ``mismatches``.
    """
    traces = mlflow.search_traces(max_results=max_results, return_type="list")

    matches, mismatches = [], []
    # Sample IDs already handled for this judge. Repeated evaluation runs leave
    # several traces per sample; only the first one encountered is scored.
    # NOTE(review): this keeps the latest run only if search_traces returns the
    # newest traces first - confirm for your tracking backend.
    seen_sample_ids = set()

    for trace in traces:
        # Get sample_id from trace metadata
        trace_metadata = getattr(trace.info, "trace_metadata", None) or {}
        sample_id = trace_metadata.get("sample_id")
        if not sample_id or sample_id in seen_sample_ids:
            continue

        # Look up ground truth by ID
        sample_data = ground_truth.get(sample_id)
        if not sample_data:
            continue

        # Get this judge's assessment for the trace; traces produced for other
        # judges are skipped without being marked as seen.
        assessments = trace.search_assessments(name=assessment_name)
        judge_assessments = [a for a in assessments if a.source.source_type == "LLM_JUDGE"]
        if not judge_assessments:
            continue
        seen_sample_ids.add(sample_id)

        sentence = sample_data["text"]
        true_label = normalize_label(sample_data["answer"])

        # Parse trace response to get LLM's classification
        # Note: trace.data.response might be a JSON string or already parsed
        response = trace.data.response
        if isinstance(response, str):
            try:
                response = json.loads(response)
            except json.JSONDecodeError:
                # Leave the raw text; the extractor decides whether it is usable.
                pass
        llm_label = extract_classification_label(response)
        if not llm_label:
            continue

        assessment = judge_assessments[0]
        judge_feedback_value = assessment.feedback.value
        judge_rationale = assessment.rationale

        # Determine if LLM was actually correct (compared to ground truth)
        llm_was_correct = llm_label == true_label

        # Determine if judge said the LLM was correct
        judge_said_correct = judge_approved(judge_feedback_value)

        # Judge is right if its assessment matches reality
        judge_was_right = llm_was_correct == judge_said_correct

        record = {
            "trace_id": trace.info.trace_id,
            "sample_id": sample_id,
            "sentence": sentence,
            "llm_label": llm_label,
            "true_label": true_label,
            "llm_was_correct": llm_was_correct,
            "judge_feedback": judge_feedback_value,
            "judge_said_correct": judge_said_correct,
            "judge_rationale": judge_rationale,
            "judge_was_right": judge_was_right,
        }

        if judge_was_right:
            matches.append(record)
        else:
            mismatches.append(record)

    total = len(matches) + len(mismatches)
    accuracy = len(matches) / total if total > 0 else 0.0

    return {
        "accuracy": accuracy,
        "total": total,
        "correct": len(matches),
        "incorrect": len(mismatches),
        "matches": matches,
        "mismatches": mismatches,
    }


def print_accuracy_summary(name: str, results: dict):
    """Print a judge's accuracy figures and up to three of its wrong verdicts.

    Args:
        name: Display name of the judge.
        results: Dict with the keys ``accuracy``, ``correct``, ``total``,
            ``incorrect`` and ``mismatches`` (a list of per-trace records).
    """
    rule = "=" * 60
    print(f"\n{rule}")
    print(f"Judge: {name}")
    print(rule)
    print(f"Accuracy: {results['accuracy']:.1%} ({results['correct']}/{results['total']})")
    print(f"  - Correct assessments: {results['correct']}")
    print(f"  - Incorrect assessments: {results['incorrect']}")

    wrong_verdicts = results['mismatches']
    if not wrong_verdicts:
        return

    print("\nExamples where judge was WRONG:")
    for miss in wrong_verdicts[:3]:
        if miss['llm_was_correct']:
            llm_status = "correct"
        else:
            llm_status = "wrong"
        print(f"  - LLM: '{miss['llm_label']}' vs Truth: '{miss['true_label']}' (LLM was {llm_status})")
        print(f"    Judge said: '{miss['judge_feedback']}' (approved: {miss['judge_said_correct']})")


def log_feedback_interactive(records: list, judge_name: str, reviewer: AssessmentSource):
    """Walk a reviewer through each record and log their Yes/No feedback.

    For every record the details are printed, a verdict is suggested from the
    ground truth ("Yes" when the LLM was correct, otherwise "No"), and the
    reviewer is prompted for a verdict and an optional comment. The result is
    logged with ``mlflow.log_feedback`` under ``judge_name``.

    Args:
        records: Per-trace records as produced by ``analyze_judge_accuracy``.
        judge_name: Assessment name to log the feedback under.
        reviewer: Assessment source recorded as the author of the feedback.
    """
    # Accepted spellings mapped to the canonical values that get logged.
    canonical_answers = {"yes": "Yes", "y": "Yes", "no": "No", "n": "No"}

    for i, record in enumerate(records, 1):
        print(f"\n{'='*50}")
        print(f"Record {i}/{len(records)}")
        print(f"{'='*50}")
        print(f"Sample ID: {record['sample_id']}")
        print(f"Sentence: {record['sentence']}")
        print(f"LLM Label: {record['llm_label']}")
        print(f"True Label: {record['true_label']}")
        print(f"LLM was: {'CORRECT' if record['llm_was_correct'] else 'WRONG'}")
        print(f"Judge feedback: '{record['judge_feedback']}'")
        print(f"Judge was: {'CORRECT' if record['judge_was_right'] else 'WRONG'}")

        suggested_value = "Yes" if record["llm_was_correct"] else "No"
        print(f"\nSuggested feedback: '{suggested_value}'")

        # Re-prompt until the answer is recognised, so a typo is never logged
        # verbatim as the feedback value.
        while True:
            raw_answer = input(f"Your feedback (Yes/No, or press Enter for '{suggested_value}'): ").strip()
            if not raw_answer:
                user_value = suggested_value
                break
            if raw_answer.lower() in canonical_answers:
                user_value = canonical_answers[raw_answer.lower()]
                break
            print(f"Unrecognized answer '{raw_answer}'. Please enter Yes or No.")

        comment = input("Comment (optional, press Enter to skip): ").strip()

        mlflow.log_feedback(
            trace_id=record["trace_id"],
            name=judge_name,
            value=user_value,
            rationale=comment if comment else f"Human feedback: {user_value}",
            source=reviewer,
        )
        print(f"Logged: {user_value}")

    print(f"\n{'='*50}")
    print(f"Logged feedback for {len(records)} traces")
    print(f"{'='*50}")

def log_feedback_automatic(records: list, judge_name: str, reviewer: AssessmentSource):
    """Log ground-truth-based feedback for every record, with an LLM-written rationale.

    The feedback value comes from the ground truth alone ("Yes" when the LLM's
    label was correct, otherwise "No"). The LLM is asked only for a short
    rationale. If that call fails or returns no text, a templated rationale is
    used and the failure is reported instead of being hidden.

    Args:
        records: Per-trace records as produced by ``analyze_judge_accuracy``.
        judge_name: Assessment name to log the feedback under.
        reviewer: Assessment source recorded as the author of the feedback.
    """
    print(f"Automatically generating feedback for {len(records)} traces...")
    fallback_count = 0

    for i, record in enumerate(records, 1):
        # Determine correct feedback based on ground truth
        feedback_value = "Yes" if record["llm_was_correct"] else "No"

        # Only mark the sentence as truncated when it really was cut.
        sentence = record["sentence"]
        sentence_excerpt = sentence[:200] + "..." if len(sentence) > 200 else sentence

        # Generate rationale using LLM
        rationale_prompt = f"""You are reviewing an LLM classification task. Generate a brief rationale (1-2 sentences) explaining why the feedback should be "{feedback_value}".

Sentence to classify: {sentence_excerpt}
LLM's classification: {record['llm_label']}
Ground truth label: {record['true_label']}
LLM was {'CORRECT' if record['llm_was_correct'] else 'INCORRECT'}

Provide a concise rationale for why the feedback is "{feedback_value}"."""

        rationale = ""
        try:
            # NOTE(review): some chat models reject `max_tokens` and expect
            # `max_completion_tokens`; such a rejection lands in the except
            # branch below - confirm against your provider.
            completion = client.chat.completions.create(
                model=LLM_NAME,
                messages=[{"role": "user", "content": rationale_prompt}],
                max_tokens=100,
            )
            # The message content may be None; treat that like an empty reply.
            rationale = (completion.choices[0].message.content or "").strip()
        except Exception as e:
            # Best effort: report the problem and fall back to a template.
            print(f"  Warning: rationale generation failed for trace {record['trace_id']}: {e}")

        if not rationale:
            # Fallback rationale if the LLM call failed or returned nothing
            fallback_count += 1
            if record["llm_was_correct"]:
                rationale = f"Correct: LLM classified as '{record['llm_label']}' which matches ground truth '{record['true_label']}'."
            else:
                rationale = f"Incorrect: LLM classified as '{record['llm_label']}' but ground truth is '{record['true_label']}'."

        # Log the feedback
        mlflow.log_feedback(
            trace_id=record["trace_id"],
            name=judge_name,
            value=feedback_value,
            rationale=rationale,
            source=reviewer,
        )

        if i % 10 == 0 or i == len(records):
            print(f"  Processed {i}/{len(records)} traces")

    print(f"\nAutomatically logged feedback for {len(records)} traces")
    print(f"  - Correct (Yes): {sum(1 for r in records if r['llm_was_correct'])}")
    print(f"  - Incorrect (No): {sum(1 for r in records if not r['llm_was_correct'])}")
    if fallback_count:
        print(f"  - Templated fallback rationales used: {fallback_count}")

# Accuracy per judge, keyed by display name; filled in as each judge is analyzed.
judge_results = dict()

# Step 5: Evaluate with Built-in Guidelines Judge

MLflow provides a built-in `Guidelines` judge that evaluates whether responses follow specific constraints or instructions.
Let's see how well it performs on our classification task when we give it guidelines about valid classification labels.
We will use this general-purpose judge as a baseline to compare against the custom judge we build in the next step.

In [None]:
from mlflow.genai.scorers import Guidelines
from mlflow.genai import evaluate

# Plain-language rules the built-in judge checks each response against.
guideline_rules = [
    "The response must be exactly one of: CONCLUSIONS, RESULTS, METHODS, OBJECTIVE, or BACKGROUND",
    "The classification must be appropriate for the medical/scientific sentence provided",
]
classification_guidelines = Guidelines(name="classification_guidelines", guidelines=guideline_rules)

print(
    "Running evaluation with Guidelines judge...",
    "This will classify each sentence and have the judge evaluate the results.\n",
    sep="\n",
)

# Run the classifier over every record and score its outputs with the Guidelines judge.
guidelines_eval_results = evaluate(
    scorers=[classification_guidelines],
    predict_fn=predict_fn,
    data=eval_dataset_records,
)

## Analyze Guidelines Judge Results

Now let's see how accurate the Guidelines judge was. Remember:
- The judge is "correct" if it says "Yes" when the LLM was right, or "No" when wrong
- The judge is "wrong" if it says "Yes" when the LLM was actually wrong, or vice versa

In [None]:
# Score the built-in judge against ground truth, print the summary, and keep
# its accuracy for the final comparison.
guidelines_label = "Guidelines (built-in)"
guidelines_results = analyze_judge_accuracy("classification_guidelines", ground_truth)
print_accuracy_summary(guidelines_label, guidelines_results)
judge_results[guidelines_label] = guidelines_results["accuracy"]

# Step 6: Create a Custom Judge

Now let's create our own judge specifically designed for this PubMed classification task.
We'll use `make_judge` to define custom evaluation instructions.

In [None]:
from mlflow.genai.judges import make_judge
from mlflow.genai.scorers import get_scorer
from mlflow.genai.scorers import ScorerSamplingConfig

JUDGE_NAME = "pubmed_classifier"

# Grading instructions for the task-specific judge, assembled from short fragments.
judge_instructions = "".join(
    [
        "Evaluate if the response in {{ outputs }} appropriately classifies the sentence ",
        "in {{ inputs }}. The expectations are: {{ expectations }}. ",
        "Your grading criteria should be: ",
        "Yes: This is the correct classification. ",
        "No: This is the wrong classification. ",
        "Explain your rationale.",
    ]
)

pubmed_classification_judge = make_judge(
    name=JUDGE_NAME,
    feedback_value_type=str,
    instructions=judge_instructions,
)

# Reuse a scorer already registered under this name (setting its sample rate
# to 1.0); if the lookup or update fails for any reason, register the judge.
try:
    registered_judge = get_scorer(name=JUDGE_NAME)
    registered_judge = registered_judge.update(
        sampling_config=ScorerSamplingConfig(sample_rate=1.0)
    )
except Exception:
    registered_judge = pubmed_classification_judge.register(name=JUDGE_NAME)
    print(f"Registered new judge: {JUDGE_NAME}")
else:
    print(f"Updated existing judge: {JUDGE_NAME}")

## Evaluate with Custom Judge

In [0]:
# Same records and prediction function as the Guidelines run; only the scorer changes.
custom_scorers = [pubmed_classification_judge]

print(f"Running evaluation with custom '{JUDGE_NAME}' judge...")

custom_eval_results = evaluate(
    scorers=custom_scorers,
    predict_fn=predict_fn,
    data=eval_dataset_records,
)

## Analyze Custom Judge Results

In [0]:
# Score the custom judge against ground truth and record its accuracy under
# the same label that is shown in the printed summary.
custom_label = f"Custom ({JUDGE_NAME})"
custom_results = analyze_judge_accuracy(JUDGE_NAME, ground_truth)
print_accuracy_summary(custom_label, custom_results)
judge_results[custom_label] = custom_results["accuracy"]

## Mandatory: Provide Human Feedback on Custom Judge

We need to provide human feedback on the judge's assessments, covering both the ones it got right and the ones it got wrong. The judge alignment optimizer uses this feedback to improve the judge.

**Important**: The feedback `name` must match the judge name so that the optimizer can associate the human feedback with the judge being aligned.

In [None]:
from mlflow.entities import AssessmentSource, AssessmentSourceType

# Source recorded on every piece of feedback logged below.
reviewer = AssessmentSource(source_id="auto_feedback_generator", source_type=AssessmentSourceType.HUMAN)

# Feedback is logged for every analyzed trace, whether or not the judge was right.
judge_hits = custom_results["matches"]
judge_misses = custom_results["mismatches"]
all_custom_traces = judge_hits + judge_misses
print(f"Found {len(all_custom_traces)} traces to provide feedback on")
print(f"  - {len(custom_results['matches'])} where judge was correct")
print(f"  - {len(custom_results['mismatches'])} where judge was wrong")

# Default path: feedback values come from ground truth, rationales from the LLM.
log_feedback_automatic(all_custom_traces, JUDGE_NAME, reviewer)

# Alternative: Use interactive feedback (uncomment to manually review each trace)
# log_feedback_interactive(all_custom_traces, JUDGE_NAME, reviewer)

# Step 7: Align the Judge with MemAlign

**MemAlign** is an optimizer that improves judge instructions based on human feedback.
It analyzes cases where the judge disagreed with humans and refines the instructions
to better capture human judgment criteria.

For alignment to work, we need traces that have both:
- LLM_JUDGE assessments (from running evaluate)
- HUMAN assessments (from our feedback above)

## Find Traces with Both Judge and Human Feedback

In [None]:
JUDGE_NAME = "pubmed_classifier"

traces_for_alignment = mlflow.search_traces(max_results=100, return_type="list")


def _has_source(assessments, source_type: str) -> bool:
    """Return True when any of the given assessments has the given source type."""
    return any(a.source.source_type == source_type for a in assessments)


# Keep the IDs of traces that carry both a judge and a human assessment for this judge.
candidate_trace_ids = []
for trace in traces_for_alignment:
    named_assessments = trace.search_assessments(name=JUDGE_NAME)
    if _has_source(named_assessments, "LLM_JUDGE") and _has_source(named_assessments, "HUMAN"):
        candidate_trace_ids.append(trace.info.trace_id)

print(f"Found {len(candidate_trace_ids)} traces with both judge + human assessments")

# Re-fetch each candidate by ID and keep those whose info lists assessments.
refetched_traces = (mlflow.get_trace(trace_id) for trace_id in candidate_trace_ids)
valid_traces = [t for t in refetched_traces if t and t.info.assessments]

print(f"Traces ready for alignment: {len(valid_traces)}")

## Run MemAlign Optimization

MemAlign will:
1. Analyze disagreements between judge and human feedback
2. Identify patterns in what the judge got wrong
3. Generate improved instructions that better match human judgment

In [None]:
from mlflow.genai.judges.optimizers import MemAlignOptimizer

banner = "=" * 60

print("Original judge instructions:")
print(pubmed_classification_judge.instructions)
print("\n" + banner)
print("Running MemAlign optimization...")
print(banner)

# Configure the optimizer first, then align the judge on the selected traces.
memalign_optimizer = MemAlignOptimizer(
    reflection_lm="databricks:/databricks-gpt-5-2",
    embedding_model="openai/text-embedding-3-large",
)
aligned_judge = pubmed_classification_judge.align(
    traces=valid_traces,
    optimizer=memalign_optimizer,
)

## Compare Original vs Aligned Instructions

In [0]:
section_rule = "=" * 60

# (heading, instructions) pairs, printed in order with a blank line between sections.
instruction_versions = [
    ("ORIGINAL INSTRUCTIONS:", pubmed_classification_judge.instructions),
    ("ALIGNED INSTRUCTIONS (after MemAlign):", aligned_judge.instructions),
]

for position, (heading, instructions_text) in enumerate(instruction_versions):
    print(section_rule if position == 0 else "\n" + section_rule)
    print(heading)
    print(section_rule)
    print(instructions_text)

# Step 8: Evaluate with Aligned Judge

Now let's see if the aligned judge performs better!

In [0]:
# Build a scorer from the aligned instructions under its own name, so its
# assessments can be told apart from the original judge's.
ALIGNED_JUDGE_NAME = f"{JUDGE_NAME}_MemAligned"

aligned_judge_scorer = make_judge(
    name=ALIGNED_JUDGE_NAME,
    instructions=f"{aligned_judge.instructions}",
    feedback_value_type=str,
)

# Reuse a scorer already registered under the aligned name (setting its sample
# rate to 1.0); if the lookup or update fails, register the aligned scorer.
try:
    registered_judge = get_scorer(name=ALIGNED_JUDGE_NAME)
    registered_judge = registered_judge.update(
        sampling_config=ScorerSamplingConfig(sample_rate=1.0)
    )
    print(f"Updated existing judge: {ALIGNED_JUDGE_NAME}")
except Exception:
    # Register the aligned scorer itself, not the original un-aligned judge.
    registered_judge = aligned_judge_scorer.register(name=ALIGNED_JUDGE_NAME)
    print(f"Registered new judge: {ALIGNED_JUDGE_NAME}")

## Run Evaluation with Aligned Judge

In [0]:
# Same records and prediction function as before, scored by the aligned judge.
aligned_scorers = [aligned_judge_scorer]

print(f"Running evaluation with aligned '{ALIGNED_JUDGE_NAME}' judge...")

aligned_eval_results = evaluate(
    scorers=aligned_scorers,
    predict_fn=predict_fn,
    data=eval_dataset_records,
)

## Analyze Aligned Judge Results

In [0]:
# The printed summary is labelled with the aligned scorer's name, while the
# comparison entry is keyed by the base judge name.
summary_label = f"MemAligned ({ALIGNED_JUDGE_NAME})"
comparison_key = f"MemAligned ({JUDGE_NAME})"

aligned_results = analyze_judge_accuracy(ALIGNED_JUDGE_NAME, ground_truth)
print_accuracy_summary(summary_label, aligned_results)

# Store for final comparison
judge_results[comparison_key] = aligned_results["accuracy"]

# Step 9: Final Comparison

Let's compare all three judges side-by-side to see the impact of MemAlign!

In [0]:
import pandas as pd

# One row per judge: a formatted percentage for display plus the raw accuracy.
comparison_data = []
for judge_label, judge_accuracy in judge_results.items():
    comparison_data.append(
        {"Judge": judge_label, "Accuracy": f"{judge_accuracy:.1%}", "Accuracy (raw)": judge_accuracy}
    )

comparison_df = pd.DataFrame(comparison_data)

heading_rule = "=" * 60
print("\n" + heading_rule)
print("FINAL COMPARISON: Judge Accuracy")
print(heading_rule)
print("\nHow often did each judge correctly identify right/wrong LLM answers?\n")

# Show only the formatted columns; the raw accuracy stays in comparison_df.
display(comparison_df[["Judge", "Accuracy"]])

## Summary

In this tutorial, we:

1. **Evaluated an LLM** on a PubMed text classification task
2. **Used judges** to assess whether classifications were correct
3. **Compared judge accuracy** against ground truth labels
4. **Provided human feedback** to teach judges what they got wrong
5. **Applied MemAlign** to improve judge instructions based on feedback
6. **Measured improvement** in judge accuracy after alignment

### Key Takeaways

- **Judges aren't perfect**: Even well-designed judges can make mistakes
- **Human feedback is valuable**: It captures nuances that initial instructions miss
- **MemAlign automates improvement**: It learns from disagreements to refine instructions
- **Iteration helps**: Multiple rounds of feedback and alignment can continue improving judges