In [0]:
%pip install --upgrade mlflow==3.9.0 openai datasets huggingface_hub fsspec

# MemAlign Tutorial: Aligning LLM Judges with Human Feedback

This tutorial demonstrates how to improve LLM judge accuracy using MLflow's **MemAlign** optimizer.

## What You'll Learn
1. How to evaluate LLM outputs using **judges** (scorers)
2. How to provide **human feedback** on judge assessments
3. How to use **MemAlign** to align judges with human preferences
4. How to compare judge performance before and after alignment

## Key Concepts
- **Judge/Scorer**: An LLM that evaluates another LLM's outputs (e.g., "Is this classification correct?")
- **Assessment**: The judge's evaluation result (e.g., "Yes" or "No")
- **Trace**: A recorded LLM interaction including inputs, outputs, and assessments
- **Alignment**: Improving a judge's instructions based on human feedback so it better matches human judgment

## Data Source
We'll use the [PubMed text classification dataset](https://huggingface.co/datasets/ml4pubmed/pubmed-text-classification-cased) where sentences from medical abstracts are classified into categories: BACKGROUND, OBJECTIVE, METHODS, RESULTS, or CONCLUSIONS.

In [0]:
import json
import mlflow
from mlflow.entities import AssessmentSource, AssessmentSourceType

mlflow.set_tracking_uri("your MLflow server")  # e.g. "databricks" or "http://localhost:5000"
mlflow.set_experiment("your experiment name")
mlflow.openai.autolog()


# Step 1: Load Evaluation Dataset

First, we load a sample of sentences from the PubMed dataset. Each sentence has a ground truth label that we'll use to measure judge accuracy.

In [0]:
from datasets import load_dataset
import random

# Load PubMed text classification dataset
ds = load_dataset("ml4pubmed/pubmed-text-classification-cased", cache_dir='/tmp/huggingface')
split = ds["train"] if "train" in ds else ds[list(ds.keys())[0]]

# Sample 50 examples for evaluation
random.seed(42)  # For reproducibility
idxs = random.sample(range(len(split)), 50)

samples_list = [
    {"text": split[i]["description_cln"], "answer": split[i]["target"]}
    for i in idxs
]

# Create ground truth lookup: sentence text -> correct label
ground_truth = {s["text"].strip(): s["answer"] for s in samples_list}

print(f"Loaded {len(samples_list)} evaluation samples")
print(f"Labels: {set(s['answer'] for s in samples_list)}")
print(f"\nExample:")
print(f"  Text: {samples_list[0]['text'][:100]}...")
print(f"  Label: {samples_list[0]['answer']}")
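Because `ground_truth` is keyed by stripped sentence text, two sampled rows with the same sentence but different labels would silently overwrite each other, and later lookups would only see the last label. Below is a minimal sketch of seeded sampling combined with a conflict check. `build_ground_truth` is a hypothetical helper, and the toy `rows` stand in for the Hugging Face split.

```python
import random


def build_ground_truth(rows, k, seed=42):
    """Sample k rows reproducibly and map stripped text -> label.

    Returns (ground_truth, conflicts), where conflicts lists sentences that
    appeared in the sample with more than one label.
    """
    rng = random.Random(seed)  # local RNG, so the global random state is untouched
    sample = rng.sample(rows, k)
    ground_truth, conflicts = {}, []
    for row in sample:
        key = row["text"].strip()
        if key in ground_truth and ground_truth[key] != row["answer"]:
            conflicts.append(key)
        ground_truth[key] = row["answer"]
    return ground_truth, conflicts


# Toy rows: the first and third are the same sentence once whitespace is stripped
rows = [
    {"text": "We enrolled 120 patients.", "answer": "METHODS"},
    {"text": "Mortality fell by 12%.", "answer": "RESULTS"},
    {"text": "We enrolled 120 patients. ", "answer": "BACKGROUND"},
]
gt, conflicts = build_ground_truth(rows, k=3)
print(len(gt), conflicts)
```

With the tutorial's data, something like `build_ground_truth(samples_list, k=len(samples_list))` would work, since `samples_list` uses the same `text`/`answer` keys. An empty `conflicts` list means every lookup is unambiguous.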

# Step 2: Format Dataset for Evaluation

MLflow's `evaluate()` function expects data in a specific format. For this dataset, we follow the schema expected by the built-in `Correctness()` judge. We format each sample with:
- `inputs`: The conversation history (user message with the sentence to classify)
- `expectations`: What we expect from a correct response (used by the Correctness judge)

In [0]:
# Format for MLflow evaluate() - compatible with built-in Correctness judge
eval_dataset_records = [
    {
        "inputs": {
            "input": [{"role": "user", "content": s["text"]}]
        },
        "expectations": {
            "expected_facts": [
                "Classification label must be 'CONCLUSIONS', 'RESULTS', 'METHODS', 'OBJECTIVE', or 'BACKGROUND'"
            ]
        }
    }
    for s in samples_list
]

print(f"Prepared {len(eval_dataset_records)} evaluation records")
print(f"\nSample record structure:")
print(f"  inputs: {{input: [{{role: 'user', content: '<sentence>'}}]}}")
print(f"  expectations: {{expected_facts: ['<constraint>']}}")
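As we understand the `mlflow.genai.evaluate` contract, each record's `inputs` dict is passed to `predict_fn` as keyword arguments. That is why the key (`input`) must match the function's parameter name. Here is a stdlib-only sketch of that call pattern. `make_record` and `fake_predict_fn` are hypothetical stand-ins: the latter returns a fixed label instead of calling an LLM.

```python
def make_record(sentence: str) -> dict:
    """Build one evaluation record in the same shape as eval_dataset_records."""
    return {
        "inputs": {"input": [{"role": "user", "content": sentence}]},
        "expectations": {
            "expected_facts": [
                "Classification label must be 'CONCLUSIONS', 'RESULTS', 'METHODS', 'OBJECTIVE', or 'BACKGROUND'"
            ]
        },
    }


def fake_predict_fn(input: list[dict]) -> str:
    # Hypothetical stand-in for predict_fn: a keyword rule instead of an LLM call
    return "METHODS" if "enrolled" in input[0]["content"] else "RESULTS"


record = make_record("We enrolled 120 patients.")
# The harness unpacks `inputs` as keyword arguments, i.e. predict_fn(input=[...])
output = fake_predict_fn(**record["inputs"])
print(output)
```

If the key were renamed (say to `messages`) without renaming the parameter, the unpacked call would fail with a `TypeError` for an unexpected keyword argument.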

# Step 3: Define the Prediction Function

This is the LLM we're evaluating. It takes a sentence and classifies it into one of the 5 PubMed categories.
The judges will evaluate whether this LLM's classifications are correct.

In [0]:

from openai import OpenAI

# Replace with the client for your LLM provider; OpenAI() reads OPENAI_API_KEY from the environment
client = OpenAI()


def predict_fn(input: list[dict]) -> str:
    """
    Classify a sentence into one of 5 PubMed categories.

    Args:
        input: Conversation history, e.g., [{"role": "user", "content": "..."}]

    Returns:
        The model's classification as a string (e.g., "METHODS")
    """
    user_text = input[0]["content"]

    prompt = (
        "Classify the sentence into exactly one label from: "
        "METHODS, RESULTS, CONCLUSIONS, BACKGROUND, OBJECTIVE.\n\n"
        f"Sentence: {user_text}"
    )

    completion = client.chat.completions.create(
        model="your model",
        messages=[{"role": "user", "content": prompt}],
    )
    return completion.choices[0].message.content
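The prompt asks for exactly one label, but a model may still answer with something like "The label is METHODS.", which `normalize_label` would score as wrong. A minimal sketch of a more tolerant parser follows. `parse_label` is a hypothetical helper, not part of MLflow, and could be applied to the LLM's answer before it is compared with ground truth.

```python
import re
from typing import Optional

# The five PubMed labels used throughout this tutorial
LABELS = {"BACKGROUND", "OBJECTIVE", "METHODS", "RESULTS", "CONCLUSIONS"}


def parse_label(raw: Optional[str]) -> Optional[str]:
    """Return the one label named in raw (case-insensitive, whole words), else None."""
    if not raw:
        return None
    words = {w.upper() for w in re.findall(r"[A-Za-z]+", raw)}
    found = words & LABELS
    # Zero or several distinct labels means the reply is ambiguous
    return found.pop() if len(found) == 1 else None


print(parse_label("METHODS"))
print(parse_label("The label is results."))
print(parse_label("Methods or results?"))
```

Matching whole words avoids accepting near-misses such as "CONCLUSION", but a reply that uses a label word in ordinary prose ("the results suggest...") would still be read as that label.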

# Step 4: Define Reusable Helper Functions

These helper functions will be used throughout the tutorial to:
- Analyze judge accuracy by comparing assessments to ground truth
- Log human feedback for judge alignment
- Print summaries and comparisons

**Key insight**: A judge is "correct" when it says "Yes" for correct LLM answers and "No" for incorrect ones.
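The rule reduces to a single comparison between whether the LLM was right and whether the judge said "Yes". The sketch below enumerates all four cases. `judge_was_right` is a standalone copy of the logic inside `analyze_judge_accuracy`, written here for illustration.

```python
def judge_was_right(llm_correct: bool, judge_feedback: str) -> bool:
    """Same rule as analyze_judge_accuracy: right when the verdict matches reality."""
    judge_said_correct = judge_feedback.strip().lower() == "yes"
    return llm_correct == judge_said_correct


# Enumerate all four combinations of (LLM correct?, judge verdict)
for llm_correct in (True, False):
    for verdict in ("Yes", "No"):
        print(f"LLM correct: {llm_correct!s:<5}  judge: {verdict:<3}  judge right: {judge_was_right(llm_correct, verdict)}")
```

Only the two "diagonal" cases count as correct judge assessments: "Yes" on a correct answer and "No" on a wrong one.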

In [0]:
import json
import mlflow

def normalize_label(x):
    """Normalize a label for comparison (uppercase, stripped)."""
    return str(x).strip().upper() if x else None


def extract_sentence_from_prompt(request: dict) -> str:
    """
    Extract the original sentence from a trace request.

    Handles two request shapes:
    - the predict_fn inputs, e.g. {"input": [{"role": "user", "content": "<TEXT>"}]}
    - an OpenAI chat request, e.g. {"messages": [{"role": "user", "content": "... Sentence: <TEXT>"}]}

    This extracts <TEXT> for matching against ground truth.
    """
    try:
        messages = request.get("messages") or request.get("input") or []
        if not messages:
            return ""
        content = messages[0].get("content", "")
        if "Sentence:" in content:
            return content.split("Sentence:", 1)[1].strip()
        return content.strip()
    except (IndexError, AttributeError):
        return ""


def extract_llm_response(response) -> str:
    """Extract the LLM's classification from a trace response (plain string or OpenAI-style dict)."""
    if isinstance(response, str):
        return response
    try:
        return response["choices"][0]["message"]["content"]
    except (KeyError, IndexError, TypeError):
        return None


def analyze_judge_accuracy(assessment_name: str, ground_truth: dict) -> dict:
    """
    Analyze how well a judge's assessments match ground truth.

    A judge is "correct" when:
    - It says "Yes" and the LLM's answer matches ground truth
    - It says "No" and the LLM's answer doesn't match ground truth

    Args:
        assessment_name: Name of the judge/assessment to analyze (e.g., "correctness")
        ground_truth: Dict mapping sentence text -> correct label

    Returns:
        Dict containing:
        - accuracy: Fraction of correct judge assessments
        - total: Number of traces analyzed
        - correct: Number of correct assessments
        - incorrect: Number of incorrect assessments
        - matches: List of traces where judge was correct
        - mismatches: List of traces where judge was wrong
    """
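    # Fetch enough recent traces to cover a full 50-sample run; traces without this judge's assessment are skipped below
    traces = mlflow.search_traces(max_results=100, return_type="list")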

    matches, mismatches = [], []

    for trace in traces:
        # Parse trace data
        request = json.loads(trace.data.request) if trace.data.request else {}
        response = json.loads(trace.data.response) if trace.data.response else {}

        # Extract components
        sentence = extract_sentence_from_prompt(request)
        llm_answer = extract_llm_response(response)
        true_answer = ground_truth.get(sentence)

        # Skip if we can't match to ground truth
        if not true_answer or not llm_answer:
            continue

        # Get judge's assessment for this trace
        assessments = trace.search_assessments(name=assessment_name)
        judge_assessments = [a for a in assessments if a.source.source_type == "LLM_JUDGE"]
        if not judge_assessments:
            continue

        judge_feedback = str(judge_assessments[0].feedback.value).strip().lower()
        judge_rationale = judge_assessments[0].rationale

        # Determine correctness
        llm_correct = normalize_label(llm_answer) == normalize_label(true_answer)
        judge_said_correct = judge_feedback == "yes"
        judge_was_right = llm_correct == judge_said_correct

        record = {
            "trace_id": trace.info.trace_id,
            "sentence": sentence,
            "llm_answer": llm_answer,
            "true_answer": true_answer,
            "llm_correct": llm_correct,
            "judge_feedback": judge_feedback,
            "judge_rationale": judge_rationale,
            "judge_correct": judge_was_right,
        }

        if judge_was_right:
            matches.append(record)
        else:
            mismatches.append(record)

    total = len(matches) + len(mismatches)
    accuracy = len(matches) / total if total > 0 else 0

    return {
        "accuracy": accuracy,
        "total": total,
        "correct": len(matches),
        "incorrect": len(mismatches),
        "matches": matches,
        "mismatches": mismatches,
    }


def print_accuracy_summary(name: str, results: dict):
    """Print a formatted summary of judge accuracy."""
    print(f"\n{'='*60}")
    print(f"Judge: {name}")
    print(f"{'='*60}")
    print(f"Accuracy: {results['accuracy']:.1%} ({results['correct']}/{results['total']})")
    print(f"  - Correct assessments: {results['correct']}")
    print(f"  - Incorrect assessments: {results['incorrect']}")

    if results['mismatches']:
        print(f"\nExamples where judge was WRONG:")
        for m in results['mismatches'][:3]:  # Show first 3
            status = "correct" if m['llm_correct'] else "wrong"
            print(f"  - LLM said '{m['llm_answer']}' (actually {status}), judge said '{m['judge_feedback']}'")


def log_feedback_interactive(records: list, judge_name: str, reviewer: AssessmentSource):
    """
    Interactively collect human feedback for all traces.

    For each trace, shows the context and asks for:
    - Feedback value (Yes/No)
    - Optional comment explaining the reasoning

    This feedback is used by MemAlign to improve judge instructions.
    """
    for i, record in enumerate(records, 1):
        print(f"\n{'='*50}")
        print(f"Record {i}/{len(records)}")
        print(f"{'='*50}")
        print(f"Sentence: {record['sentence']}")
        print(f"LLM Answer: {record['llm_answer']}")
        print(f"True Answer: {record['true_answer']}")
        print(f"LLM was: {'CORRECT' if record['llm_correct'] else 'WRONG'}")
        print(f"Judge said: '{record['judge_feedback']}'")
        print(f"Judge was: {'CORRECT' if record['judge_correct'] else 'WRONG'}")

        suggested_value = "Yes" if record["llm_correct"] else "No"
        print(f"\nSuggested feedback: '{suggested_value}'")

        # Ask for feedback value
        user_value = input(f"Your feedback (Yes/No, or press Enter for '{suggested_value}'): ").strip()
        if not user_value:
            user_value = suggested_value
        elif user_value.lower() in ["yes", "y"]:
            user_value = "Yes"
        elif user_value.lower() in ["no", "n"]:
            user_value = "No"

        # Ask for comment
        comment = input("Comment (optional, press Enter to skip): ").strip()

        mlflow.log_feedback(
            trace_id=record["trace_id"],
            name=judge_name,
            value=user_value,
            rationale=comment if comment else f"Human feedback: {user_value}",
            source=reviewer,
        )
        print(f"Logged: {user_value}")

    print(f"\n{'='*50}")
    print(f"Logged feedback for {len(records)} traces")
    print(f"{'='*50}")


# Dictionary to store results for final comparison
judge_results = {}

print("Helper functions defined successfully!")

# Step 5: Evaluate with Built-in Correctness Judge

MLflow provides a built-in `Correctness` judge that evaluates whether responses meet expected criteria.
Let's see how well it performs on our classification task.
We will use this general-purpose judge as a baseline for comparison with a custom judge.

In [0]:
from mlflow.genai.scorers import Correctness
from mlflow.genai import evaluate

print("Running evaluation with Correctness judge...")
print("This will classify each sentence and have the judge evaluate the results.\n")

correctness_eval_results = evaluate(
    data=eval_dataset_records,
    predict_fn=predict_fn,
    scorers=[Correctness()],
)

## Analyze Correctness Judge Results

Now let's see how accurate the Correctness judge was. Remember:
- The judge is "correct" if it says "Yes" when the LLM was right, or "No" when wrong
- The judge is "wrong" if it says "Yes" when the LLM was actually wrong, or vice versa

In [0]:
correctness_results = analyze_judge_accuracy("correctness", ground_truth)
print_accuracy_summary("Correctness (built-in)", correctness_results)

# Store for final comparison
judge_results["Correctness (built-in)"] = correctness_results["accuracy"]
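A single accuracy number hides which way the judge tends to fail. A judge that approves wrong classifications needs different feedback than one that rejects correct ones. Below is a minimal sketch that splits the errors by type. `error_breakdown` is a hypothetical helper that reads the `llm_correct` and `judge_feedback` keys of the records returned by `analyze_judge_accuracy`, and the `mismatches` list is a toy example.

```python
from collections import Counter


def error_breakdown(records: list[dict]) -> Counter:
    """Count judge errors by type from analyze_judge_accuracy-style records."""
    counts = Counter()
    for r in records:
        said_yes = r["judge_feedback"] == "yes"  # analyze_judge_accuracy stores lowercased feedback
        if said_yes and not r["llm_correct"]:
            counts["false_yes"] += 1  # judge approved a wrong classification
        elif not said_yes and r["llm_correct"]:
            counts["false_no"] += 1  # judge rejected a correct classification
    return counts


# Toy records for illustration
mismatches = [
    {"llm_correct": False, "judge_feedback": "yes"},
    {"llm_correct": False, "judge_feedback": "yes"},
    {"llm_correct": True, "judge_feedback": "no"},
]
print(error_breakdown(mismatches))
```

In the notebook, `error_breakdown(correctness_results["mismatches"])` would apply it to the built-in judge's errors.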

# Step 6: Create a Custom Judge

Now let's create our own judge specifically designed for this PubMed classification task.
We'll use `make_judge` to define custom evaluation instructions.

In [0]:
from mlflow.genai.judges import make_judge
from mlflow.genai.scorers import ScorerSamplingConfig

JUDGE_NAME = "pubmed_classifier"

pubmed_classification_judge = make_judge(
    name=JUDGE_NAME,
    instructions=(
        "Evaluate if the response in {{ outputs }} appropriately classifies the sentence "
        "in {{ inputs }}. The expectations are: {{ expectations }}. "
        "Your grading criteria should be: "
        "Yes: This is the correct classification. "
        "No: This is the wrong classification. "
        "Explain your rationale."
    ),
    feedback_value_type=str,
)

# Register the judge (or update if already exists)
try:
    registered_judge = pubmed_classification_judge.register()
    print(f"Registered new judge: {JUDGE_NAME}")
except ValueError as e:
    if "has already been registered" in str(e):
        registered_judge = pubmed_classification_judge.update(
            sampling_config=ScorerSamplingConfig(sample_rate=1)
        )
        print(f"Updated existing judge: {JUDGE_NAME}")
    else:
        raise
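To see what the judge actually receives, it helps to fill the `{{ ... }}` placeholders by hand. The sketch below uses a simple regex substitution. `render` is an illustrative stand-in and not how MLflow renders judge templates internally. The sentence and the `RESULTS` output are toy values.

```python
import re


def render(template: str, **values) -> str:
    """Replace each {{ name }} placeholder with str(values[name]); illustrative only."""
    return re.sub(r"\{\{\s*(\w+)\s*\}\}", lambda m: str(values[m.group(1)]), template)


template = (
    "Evaluate if the response in {{ outputs }} appropriately classifies the sentence "
    "in {{ inputs }}. The expectations are: {{ expectations }}."
)
prompt = render(
    template,
    inputs={"input": [{"role": "user", "content": "We enrolled 120 patients."}]},
    outputs="RESULTS",
    expectations={
        "expected_facts": [
            "Classification label must be 'CONCLUSIONS', 'RESULTS', 'METHODS', 'OBJECTIVE', or 'BACKGROUND'"
        ]
    },
)
print(prompt)
```

The rendered prompt lists the allowed labels but never states which one is correct, so the judge has to infer correctness on its own. This is the gap that human feedback and alignment address later in the tutorial.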

## Evaluate with Custom Judge

In [0]:
print(f"Running evaluation with custom '{JUDGE_NAME}' judge...")

custom_eval_results = evaluate(
    data=eval_dataset_records,
    predict_fn=predict_fn,
    scorers=[pubmed_classification_judge],
)

## Analyze Custom Judge Results

In [0]:
custom_results = analyze_judge_accuracy(JUDGE_NAME, ground_truth)
print_accuracy_summary(f"Custom ({JUDGE_NAME})", custom_results)

# Store for final comparison
judge_results[f"Custom ({JUDGE_NAME})"] = custom_results["accuracy"]

## Mandatory: Provide Human Feedback on Custom Judge

We now provide human feedback on the custom judge's assessments, covering both the traces it judged correctly and the ones it got wrong. The judge alignment optimizer uses this feedback to improve the judge.

**Important**: The feedback `name` must match the judge name so that the optimizer can pair each human assessment with the judge assessment it refers to.

In [0]:
from mlflow.entities import AssessmentSource, AssessmentSourceType

reviewer = AssessmentSource(
    source_type=AssessmentSourceType.HUMAN,
    source_id="your name or identifier",
)

# Combine all traces for interactive feedback
all_custom_traces = custom_results["matches"] + custom_results["mismatches"]
print(f"Found {len(all_custom_traces)} traces to review")
print(f"  - {len(custom_results['matches'])} where judge was correct")
print(f"  - {len(custom_results['mismatches'])} where judge was wrong")

log_feedback_interactive(all_custom_traces, JUDGE_NAME, reviewer)

# Step 7: Align the Judge with MemAlign

**MemAlign** is an optimizer that improves judge instructions based on human feedback.
It analyzes cases where the judge disagreed with humans and refines the instructions
to better capture human judgment criteria.

For alignment to work, we need traces that have both:
- LLM_JUDGE assessments (from running evaluate)
- HUMAN assessments (from our feedback above)
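The pairing rule can be sketched without a tracking server. In the sketch below, `FakeAssessment` is a hypothetical stand-in for an MLflow assessment that keeps only the fields the filter needs. The second example shows why the human feedback must use the judge's name: feedback logged under any other name is simply not paired.

```python
from dataclasses import dataclass


@dataclass
class FakeAssessment:
    # Hypothetical stand-in for an MLflow assessment
    name: str
    source_type: str  # "LLM_JUDGE" or "HUMAN"
    value: str


def usable_for_alignment(assessments: list[FakeAssessment], judge_name: str) -> bool:
    """True when a trace has both a judge and a human assessment under judge_name."""
    named = [a for a in assessments if a.name == judge_name]
    has_judge = any(a.source_type == "LLM_JUDGE" for a in named)
    has_human = any(a.source_type == "HUMAN" for a in named)
    return has_judge and has_human


ok = [FakeAssessment("pubmed_classifier", "LLM_JUDGE", "Yes"),
      FakeAssessment("pubmed_classifier", "HUMAN", "No")]
misnamed = [FakeAssessment("pubmed_classifier", "LLM_JUDGE", "Yes"),
            FakeAssessment("pubmed_feedback", "HUMAN", "No")]
print(usable_for_alignment(ok, "pubmed_classifier"), usable_for_alignment(misnamed, "pubmed_classifier"))
```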

## Find Traces with Both Judge and Human Feedback

In [0]:
traces_for_alignment = mlflow.search_traces(return_type="list", max_results=100)

JUDGE_NAME = "pubmed_classifier"

valid_traces = []
for trace in traces_for_alignment:
    feedbacks = trace.search_assessments(name=JUDGE_NAME)
    has_judge = any(f.source.source_type == "LLM_JUDGE" for f in feedbacks)
    has_human = any(f.source.source_type == "HUMAN" for f in feedbacks)
    if has_judge and has_human:
        valid_traces.append(trace)

print(f"Total traces found: {len(traces_for_alignment)}")
print(f"Traces with both judge + human feedback: {len(valid_traces)}")

if len(valid_traces) < 5:
    print("\nWarning: Few traces available for alignment. Results may be limited.")

## Run MemAlign Optimization

MemAlign will:
1. Analyze disagreements between judge and human feedback
2. Identify patterns in what the judge got wrong
3. Generate improved instructions that better match human judgment
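Step 1 starts from the traces where the judge and the human disagree. The sketch below isolates that signal from `(trace_id, judge_value, human_value)` triples. It is not MemAlign's algorithm; it only shows the kind of disagreement data the optimizer works from. `disagreements` is a hypothetical helper and the triples are toy values.

```python
def disagreements(pairs: list[tuple[str, str, str]]) -> tuple[float, list[str]]:
    """Return (agreement rate, IDs of traces where judge and human differ)."""
    if not pairs:
        return 0.0, []
    diff = [tid for tid, judge, human in pairs if judge.strip().lower() != human.strip().lower()]
    return 1 - len(diff) / len(pairs), diff


# Toy (trace_id, judge value, human value) triples
pairs = [("t1", "yes", "Yes"), ("t2", "yes", "No"), ("t3", "no", "No"), ("t4", "no", "Yes")]
rate, diff = disagreements(pairs)
print(rate, diff)
```

A low agreement rate on the custom judge's traces means more disagreement examples are available for the optimizer to learn from.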

In [0]:
import os
from mlflow.genai.judges.optimizers import MemAlignOptimizer

# Uncomment if the key is not already in your environment; use your provider's variable if not using OpenAI
# os.environ["OPENAI_API_KEY"] = "<your-api-key>"

print("Original judge instructions:")
print(pubmed_classification_judge.instructions)
print("\n" + "="*60)
print("Running MemAlign optimization...")
print("="*60)

# Run alignment
aligned_judge = pubmed_classification_judge.align(
    traces=valid_traces,
    optimizer=MemAlignOptimizer(
        reflection_lm="openai:/gpt-5-2",
        embedding_model="openai/text-embedding-3-large",
    )
)

## Compare Original vs Aligned Instructions

In [0]:
print("="*60)
print("ORIGINAL INSTRUCTIONS:")
print("="*60)
print(pubmed_classification_judge.instructions)

print("\n" + "="*60)
print("ALIGNED INSTRUCTIONS (after MemAlign):")
print("="*60)
print(aligned_judge.instructions)
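Printing both instruction strings in full makes small additions easy to miss. A unified diff from the standard library's `difflib` highlights only what changed. In this sketch, `instruction_diff` is a hypothetical helper that splits on ". " to compare sentence by sentence, and the `aligned` string is invented toy text, not real MemAlign output.

```python
import difflib


def _sentences(text: str) -> list[str]:
    # One sentence per line so the diff highlights added or changed sentences
    return [part.strip() + "\n" for part in text.replace(". ", ".\n").splitlines() if part.strip()]


def instruction_diff(original: str, aligned: str) -> str:
    """Unified diff between two instruction strings, compared sentence by sentence."""
    return "".join(
        difflib.unified_diff(_sentences(original), _sentences(aligned), fromfile="original", tofile="aligned")
    )


original = "Evaluate the response. Answer Yes or No."
aligned = "Evaluate the response. Answer Yes or No. Sentences describing enrollment are METHODS."
print(instruction_diff(original, aligned))
```

In the notebook, `print(instruction_diff(pubmed_classification_judge.instructions, aligned_judge.instructions))` would show only the lines that MemAlign added or changed.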

# Step 8: Evaluate with Aligned Judge

Now let's see if the aligned judge performs better!

In [0]:
# Register the aligned judge as a new scorer
ALIGNED_JUDGE_NAME = f"{JUDGE_NAME}_MemAligned"

aligned_judge_scorer = make_judge(
    name=ALIGNED_JUDGE_NAME,
    instructions=aligned_judge.instructions,
    feedback_value_type=str,
)

try:
    aligned_judge_scorer.register()
    print(f"Registered aligned judge: {ALIGNED_JUDGE_NAME}")
except ValueError as e:
    if "has already been registered" in str(e):
        aligned_judge_scorer.update(sampling_config=ScorerSamplingConfig(sample_rate=1))
        print(f"Updated existing aligned judge: {ALIGNED_JUDGE_NAME}")
    else:
        raise

## Run Evaluation with Aligned Judge

In [0]:
print(f"Running evaluation with aligned '{ALIGNED_JUDGE_NAME}' judge...")

aligned_eval_results = evaluate(
    data=eval_dataset_records,
    predict_fn=predict_fn,
    scorers=[aligned_judge_scorer],
)

## Analyze Aligned Judge Results

In [0]:
aligned_results = analyze_judge_accuracy(ALIGNED_JUDGE_NAME, ground_truth)
print_accuracy_summary(f"MemAligned ({ALIGNED_JUDGE_NAME})", aligned_results)

# Store for final comparison
judge_results[f"MemAligned ({JUDGE_NAME})"] = aligned_results["accuracy"]

# Step 9: Final Comparison

Let's compare all three judges side-by-side to see the impact of MemAlign!

In [0]:
import pandas as pd

# Create comparison DataFrame
comparison_data = [
    {"Judge": name, "Accuracy": f"{acc:.1%}", "Accuracy (raw)": acc}
    for name, acc in judge_results.items()
]

comparison_df = pd.DataFrame(comparison_data)

print("\n" + "="*60)
print("FINAL COMPARISON: Judge Accuracy")
print("="*60)
print("\nHow often did each judge correctly identify right/wrong LLM answers?\n")

display(comparison_df[["Judge", "Accuracy"]])
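When reading the table, keep the sample size in mind. With about 50 analyzed traces, a single trace moves accuracy by roughly 2 percentage points. Each `evaluate()` call also reruns `predict_fn`, so the LLM's answers can differ between runs. The sketch below reports each judge's change relative to a baseline in percentage points. `accuracy_deltas` is a hypothetical helper, and the accuracies are toy values for illustration only, not results from this tutorial.

```python
def accuracy_deltas(results: dict[str, float], baseline: str) -> list[tuple[str, float, float]]:
    """Rows of (judge, accuracy, change vs. baseline in percentage points), best first."""
    base = results[baseline]
    rows = [(name, acc, round((acc - base) * 100, 1)) for name, acc in results.items()]
    return sorted(rows, key=lambda row: row[1], reverse=True)


# Toy accuracies for illustration only, not results from this tutorial
toy = {"judge_a": 0.5, "judge_b": 0.75, "judge_c": 0.9}
for name, acc, delta in accuracy_deltas(toy, baseline="judge_a"):
    print(f"{name}: {acc:.1%} ({delta:+.1f} pts)")
```

With the real results, `accuracy_deltas(judge_results, baseline="Correctness (built-in)")` would express each judge relative to the built-in judge.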

## Summary

In this tutorial, we:

1. **Evaluated an LLM** on a PubMed text classification task
2. **Used judges** to assess whether classifications were correct
3. **Compared judge accuracy** against ground truth labels
4. **Provided human feedback** to teach judges what they got wrong
5. **Applied MemAlign** to improve judge instructions based on feedback
6. **Measured the change** in judge accuracy after alignment

### Key Takeaways

- **Judges aren't perfect**: Even well-designed judges can make mistakes
- **Human feedback is valuable**: It captures nuances that initial instructions miss
- **MemAlign automates improvement**: It learns from disagreements to refine instructions
- **Iteration helps**: Multiple rounds of feedback and alignment can continue improving judges