# Fine-tuned vs Base Model Evaluation

This notebook loads the merged fine-tuned model from disk alongside the original base model and evaluates both on the PII masking validation split.


In [None]:
import os
import re
from pathlib import Path
from collections import Counter
from datetime import datetime
from typing import Optional

import torch
import pandas as pd
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM

# Mitigate CUDA memory fragmentation for long-running sessions.
# `setdefault` only writes the value when PYTORCH_CUDA_ALLOC_CONF is not
# already present, so a setting exported in the environment takes precedence.
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")

# Switch gradient tracking off globally for everything that runs afterwards;
# the cells visible in this notebook only run generation and scoring.
torch.set_grad_enabled(False)


In [None]:
# --- Evaluation configuration -------------------------------------------
# MODEL_ID is the base checkpoint used as the comparison baseline;
# FINE_TUNED_DIR is the local directory holding the merged fine-tuned weights.
MODEL_ID = "unsloth/gemma-2-2b-it"
FINE_TUNED_DIR = Path("pii_merged_model")
DEVICE_MAP = "auto"
TORCH_DTYPE = torch.float16
MAX_NEW_TOKENS = 256

# Fail early with an absolute path in the message instead of letting
# `from_pretrained` raise a less obvious error further down.
if not FINE_TUNED_DIR.exists():
    raise FileNotFoundError(f"Expected fine-tuned model at {FINE_TUNED_DIR.resolve()}")

# The same tokenizer (taken from the fine-tuned directory) is used for both models.
tokenizer = AutoTokenizer.from_pretrained(FINE_TUNED_DIR)

if tokenizer.pad_token is None:
    if tokenizer.eos_token is not None:
        # Reuse an existing token as padding. Adding a brand-new "[PAD]" token
        # would grow the tokenizer vocabulary, but the models loaded later are
        # never resized, so the new id would fall outside their embedding table.
        tokenizer.pad_token = tokenizer.eos_token
    else:
        # No EOS token to reuse: keep the original behaviour as a last resort.
        # NOTE(review): models would then need `resize_token_embeddings` before
        # any padded batch is fed to them.
        tokenizer.add_special_tokens({"pad_token": "[PAD]"})



In [None]:
def build_pii_prompt(source_text: str) -> str:
    """Wrap ``source_text`` in the few-shot PII-masking chat prompt.

    The prompt is a single user turn (task description, four worked
    examples, then the text to mask) followed by an opened model turn so
    that generation continues with the masked text.

    Args:
        source_text: Raw text whose PII should be masked.

    Returns:
        The complete prompt string, ending with ``"<start_of_turn>model\\n"``.
    """
    # (raw input, masked output) demonstrations shown to the model.
    few_shot_pairs = (
        (
            "My name is John Smith and my email is john.smith@email.com",
            "My name is [FIRSTNAME] [LASTNAME] and my email is [EMAIL]",
        ),
        (
            "Call me at 555-123-4567 or visit 123 Main Street, Boston MA 02101",
            "Call me at [PHONENUMBER] or visit [STREET] [CITY] [STATE] [ZIPCODE]",
        ),
        (
            "My username is alice_2023 and I was born on 03/15/1990",
            "My username is [USERNAME] and I was born on [DOB]",
        ),
        (
            "The SSN is 123-45-6789 and account number is ACC98765",
            "The SSN is [SSN] and account number is [ACCOUNTNUMBER]",
        ),
    )
    task_description = (
        "You are a data privacy assistant. Mask all personally identifiable information (PII) "
        "in the following text using the masking scheme shown in the examples. "
        "Output only the masked text."
    )
    demonstrations = "\n\n".join(
        f"Input: {raw}\nOutput: {masked}" for raw, masked in few_shot_pairs
    )
    instruction = f"{task_description}\n\nExamples:\n\n{demonstrations}"

    user_turn = f"<start_of_turn>user\n{instruction}\n\nText:\n{source_text}\n<end_of_turn>\n"
    model_turn_opener = "<start_of_turn>model\n"
    return user_turn + model_turn_opener



In [None]:
def load_validation_split() -> "datasets.arrow_dataset.Dataset":
    """Load the PII-masking validation examples and attach prompt/target columns.

    Steps:
      1. Load ``ai4privacy/pii-masking-300k``.
      2. Keep only rows whose ``language`` is ``"English"`` when that column exists.
      3. Use the dataset's own ``validation`` split whenever it is present,
         additionally filtering on ``set == "validation"`` only if that column
         exists. Only when no validation split is published is a seeded 2%
         hold-out carved from ``train``.
      4. Add ``prompt`` (built with ``build_pii_prompt``) and ``target``
         (a copy of ``target_text``) to every row.

    Returns:
        The validation ``Dataset`` with the extra ``prompt`` and ``target`` columns.

    Raises:
        KeyError: If ``source_text`` or ``target_text`` is missing from the split.
    """
    pii_ds = load_dataset("ai4privacy/pii-masking-300k")

    if "language" in pii_ds["train"].column_names:
        pii_ds = pii_ds.filter(lambda ex: ex.get("language", "English") == "English")

    if "validation" in pii_ds:
        # Prefer the published validation split. Previously a missing `set`
        # column made the code fall through to a hold-out taken from `train`,
        # i.e. from data the fine-tuned model may have been trained on.
        valid_split = pii_ds["validation"]
        if "set" in valid_split.column_names:
            valid_split = valid_split.filter(
                lambda ex: ex.get("set", "validation") == "validation"
            )
    else:
        # Fixed seed keeps the hold-out identical across runs.
        split = pii_ds["train"].train_test_split(test_size=0.02, seed=42)
        valid_split = split["test"]

    required = {"source_text", "target_text"}
    missing = required.difference(valid_split.column_names)
    if missing:
        raise KeyError(f"Validation split missing columns: {sorted(missing)}")

    def add_prompt_fields(example):
        # `map` merges the returned columns into the existing row.
        return {
            "prompt": build_pii_prompt(example["source_text"]),
            "target": example["target_text"],
        }

    return valid_split.map(add_prompt_fields)


# Materialise the validation data once for the cells below and show a
# truncated preview (first 200 characters) of the first example.
valid_split = load_validation_split()
print(f"Validation examples: {len(valid_split)}")
print({key: valid_split[0][key][:200] for key in ["prompt", "target"]})


In [None]:
def char_f1(pred: str, gold: str) -> float:
    """Return the bag-of-characters F1 score between ``pred`` and ``gold``.

    Characters are compared as multisets (order is ignored): the overlap is
    the size of the multiset intersection, precision divides it by
    ``len(pred)`` and recall by ``len(gold)``.

    Returns:
        The harmonic mean of precision and recall, or ``0.0`` when the two
        strings share no characters (which includes either one being empty).
    """
    shared = sum((Counter(pred) & Counter(gold)).values())
    if not shared:
        return 0.0

    # A non-zero overlap implies both strings are non-empty, so neither
    # division below can hit a zero denominator.
    precision = shared / len(pred)
    recall = shared / len(gold)
    return 2 * precision * recall / (precision + recall)


def _normalize_span(span):
    """Convert one gold span into a ``(value, label)`` pair of strings.

    Returns ``None`` for span shapes that are not understood, so the caller
    can skip them.
    """
    if isinstance(span, dict):
        value, label = span.get("value"), span.get("label")
    elif isinstance(span, str):
        value, label = span, ""
    elif isinstance(span, (list, tuple)):
        looks_like_offsets = len(span) >= 3 and all(
            isinstance(item, int) and not isinstance(item, bool) for item in span[:2]
        )
        if looks_like_offsets:
            # NOTE(review): treated as a (start, end, label) offset triple.
            # No literal PII value is available in that shape, so the span
            # only contributes its label; confirm against the data schema.
            value, label = "", span[2]
        else:
            value = span[0] if len(span) > 0 else ""
            label = span[1] if len(span) > 1 else ""
    else:
        return None
    # Coerce to text so non-string values (e.g. numeric IDs) can be searched
    # for in the prediction instead of failing on `.lower()`.
    return (str(value) if value else "", str(label) if label else "")


def compute_masking_metrics(prediction: str, spans) -> dict:
    """Score how well ``prediction`` masks the gold PII ``spans``.

    Args:
        prediction: Model output that should contain ``[LABEL]`` placeholders
            instead of the original PII values.
        spans: Gold spans, or ``None``. Each span may be a dict with
            ``"value"``/``"label"`` keys, a ``(value, label)`` sequence, a
            ``(start, end, label)`` offset triple (label only), or a plain
            string (value only). A single bare string or dict is treated as
            one span. Unrecognised span shapes are skipped.

    Returns:
        A dict with:
          * ``total_spans``: number of usable gold spans.
          * ``masked_spans`` / ``mask_recall``: spans whose value does not
            occur (case-insensitively) in ``prediction``; recall is ``1.0``
            when there are no spans.
          * ``placeholder_precision`` / ``placeholder_recall``: per-label
            multiset overlap between predicted ``[LABEL]`` placeholders and
            gold labels. Both are ``1.0`` when nothing was expected and
            nothing was predicted; recall is ``1.0`` whenever there are no
            gold spans.
          * ``extra_placeholders`` / ``missing_placeholders``: surplus
            predicted placeholders and gold spans without a matching one.
          * ``leaked_values``: gold values found verbatim in ``prediction``.
          * ``gold_label_counts`` / ``pred_label_counts``: ``Counter`` objects.
    """
    if spans is None:
        spans = []
    elif isinstance(spans, (str, dict)):
        # Iterating a bare string/dict would yield characters/keys, turning
        # one span into many bogus ones.
        spans = [spans]

    normalized_spans = [
        pair for pair in map(_normalize_span, spans) if pair is not None
    ]
    total_spans = len(normalized_spans)

    lower_prediction = prediction.lower()
    leaked_values = []
    gold_label_counts = Counter()
    for value, label in normalized_spans:
        gold_label_counts[label] += 1
        # Substring match: a gold value that survives anywhere in the output
        # counts as leaked. Spans without a value can never be "leaked".
        if value and value.lower() in lower_prediction:
            leaked_values.append(value)

    pred_label_counts = Counter(re.findall(r"\[([A-Z0-9_]+)\]", prediction))
    predicted_total = sum(pred_label_counts.values())

    # Multiset intersection == sum over labels of min(pred count, gold count).
    matched_placeholders = sum((pred_label_counts & gold_label_counts).values())

    if predicted_total:
        placeholder_precision = matched_placeholders / predicted_total
    else:
        placeholder_precision = 1.0 if total_spans == 0 else 0.0

    # With no gold spans there is nothing to recall; report 1.0 so this
    # agrees with `mask_recall` and `placeholder_precision` above.
    placeholder_recall = (
        matched_placeholders / total_spans if total_spans > 0 else 1.0
    )

    masked_spans = total_spans - len(leaked_values)
    mask_recall = masked_spans / total_spans if total_spans > 0 else 1.0

    # Counter subtraction keeps only positive surpluses.
    extra_placeholders = sum((pred_label_counts - gold_label_counts).values())
    # matched_placeholders can never exceed total_spans, so this is >= 0.
    missing_placeholders = total_spans - matched_placeholders

    return {
        "total_spans": total_spans,
        "masked_spans": masked_spans,
        "mask_recall": mask_recall,
        "placeholder_precision": placeholder_precision,
        "placeholder_recall": placeholder_recall,
        "extra_placeholders": extra_placeholders,
        "missing_placeholders": missing_placeholders,
        "leaked_values": leaked_values,
        "gold_label_counts": gold_label_counts,
        "pred_label_counts": pred_label_counts,
    }


@torch.no_grad()
def generate_masked_text(model: AutoModelForCausalLM, source_text: str) -> str:
    """Greedily generate the masked version of ``source_text`` with ``model``.

    The text is wrapped with ``build_pii_prompt``, tokenized with the
    notebook-level ``tokenizer`` and decoded greedily for at most
    ``MAX_NEW_TOKENS`` new tokens.

    Args:
        model: Causal LM to generate with; inputs are moved to ``model.device``.
        source_text: Raw text containing PII.

    Returns:
        Only the newly generated completion, with special tokens removed and
        surrounding whitespace stripped.
    """
    prompt = build_pii_prompt(source_text)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    gen = model.generate(
        **inputs,
        max_new_tokens=MAX_NEW_TOKENS,
        do_sample=False,
    )

    # `generate` returns prompt + continuation for decoder-only models. Slice
    # off the prompt by token count rather than splitting the decoded text on
    # "<start_of_turn>model": that marker is removed by
    # `skip_special_tokens=True`, so the split never matched and the echoed
    # prompt (including the unmasked source text) was scored as the prediction.
    prompt_length = inputs["input_ids"].shape[-1]
    completion_ids = gen[0][prompt_length:]
    return tokenizer.decode(completion_ids, skip_special_tokens=True).strip()


def evaluate_model(
    model: AutoModelForCausalLM,
    dataset,
    model_name: str,
    n_samples: int = 200,
    csv_path: Optional[Path] = None,
):
    """Run ``model`` over the first ``n_samples`` rows of ``dataset`` and score it.

    For every row the source text is masked with ``generate_masked_text`` and
    scored with ``char_f1`` (against ``target_text``) and
    ``compute_masking_metrics`` (against the row's gold spans).

    Args:
        model: Causal LM to evaluate; switched to eval mode.
        dataset: Indexable dataset whose rows provide ``source_text`` and
            ``target_text`` and, optionally, ``privacy_mask`` / ``span_labels``.
        model_name: Label used in log lines and in the ``model`` column.
        n_samples: Upper bound on the number of rows evaluated.
        csv_path: When given, per-sample results are written there as CSV.

    Returns:
        ``(avg_f1, avg_mask_recall, avg_placeholder_precision,
        avg_placeholder_recall, results_df)``; the averages are ``0.0`` when
        no sample was evaluated.
    """
    model.eval()
    n = min(n_samples, len(dataset))
    rows = []

    print(f"Evaluating {model_name} on {n} samples...")
    for idx in range(n):
        sample = dataset[idx]
        source = sample["source_text"]
        target = sample["target_text"]

        # `compute_masking_metrics` needs the literal PII value of each span
        # ("value"/"label" records).
        # NOTE(review): in the ai4privacy schema those records appear to live
        # in `privacy_mask`, while `span_labels` holds character offsets -
        # confirm against the dataset card. Fall back to `span_labels` when
        # `privacy_mask` is absent so other datasets keep working.
        spans = sample.get("privacy_mask")
        if spans is None:
            spans = sample.get("span_labels")

        prediction = generate_masked_text(model, source)
        f1 = char_f1(prediction, target)
        metrics = compute_masking_metrics(prediction, spans)

        rows.append(
            {
                "sample_id": idx,
                "model": model_name,
                "source_text": source,
                "target_text": target,
                "prediction": prediction,
                "f1_score": f1,
                "mask_recall": metrics["mask_recall"],
                "placeholder_precision": metrics["placeholder_precision"],
                "placeholder_recall": metrics["placeholder_recall"],
                "total_spans": metrics["total_spans"],
                "masked_spans": metrics["masked_spans"],
                "extra_placeholders": metrics["extra_placeholders"],
                "missing_placeholders": metrics["missing_placeholders"],
                "leaked_count": len(metrics["leaked_values"]),
                "leaked_values": " || ".join(metrics["leaked_values"]),
                "source_length": len(source),
                "target_length": len(target),
                "prediction_length": len(prediction),
                "timestamp": datetime.now().isoformat(),
            }
        )
        if (idx + 1) % 20 == 0:
            print(f"  Processed {idx + 1}/{n}")

    def _mean(column: str) -> float:
        # Averages come straight from the per-sample rows, so no parallel
        # accumulator lists need to be kept in sync with them.
        return float(sum(row[column] for row in rows) / n) if n > 0 else 0.0

    avg_f1 = _mean("f1_score")
    avg_mask_recall = _mean("mask_recall")
    avg_placeholder_precision = _mean("placeholder_precision")
    avg_placeholder_recall = _mean("placeholder_recall")

    results_df = pd.DataFrame(rows)
    if csv_path is not None:
        results_df.to_csv(csv_path, index=False)
        print(f"Saved detailed results to {csv_path}")

    print(
        f"Average F1 ({model_name}): {avg_f1:.4f}\n"
        f"Mask recall ({model_name}): {avg_mask_recall:.4f}\n"
        f"Placeholder precision ({model_name}): {avg_placeholder_precision:.4f}\n"
        f"Placeholder recall ({model_name}): {avg_placeholder_recall:.4f}\n"
    )

    return (
        avg_f1,
        avg_mask_recall,
        avg_placeholder_precision,
        avg_placeholder_recall,
        results_df,
    )



In [None]:
# Both checkpoints are loaded with identical settings so that any difference
# in the scores comes from the weights, not from the loading configuration.
shared_load_kwargs = {
    "device_map": DEVICE_MAP,
    "torch_dtype": TORCH_DTYPE,
    "low_cpu_mem_usage": True,
}

base_model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **shared_load_kwargs)
fine_tuned_model = AutoModelForCausalLM.from_pretrained(FINE_TUNED_DIR, **shared_load_kwargs)

# Put each model in eval mode and report where it ended up.
for mdl, name in ((base_model, "Base"), (fine_tuned_model, "Fine-tuned")):
    mdl.eval()
    print(f"Loaded {name} model on: {mdl.device}")


In [None]:
# Evaluation run: score both models on the same leading slice of the
# validation split, save per-sample CSVs, then save a one-row-per-model summary.
EVAL_SAMPLES = 100
OUTPUT_DIR = Path(".")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

fine_tuned_csv = OUTPUT_DIR / "pii_finetuned_results_eval.csv"
base_csv = OUTPUT_DIR / "pii_base_results_eval.csv"
comparison_csv = OUTPUT_DIR / "pii_comparison_results_eval.csv"

# Fine-tuned model first, then the base model, each on the same samples.
(
    fine_avg_f1,
    fine_mask_recall,
    fine_placeholder_precision,
    fine_placeholder_recall,
    fine_df,
) = evaluate_model(
    fine_tuned_model,
    valid_split,
    model_name="fine-tuned",
    n_samples=EVAL_SAMPLES,
    csv_path=fine_tuned_csv,
)

(
    base_avg_f1,
    base_mask_recall,
    base_placeholder_precision,
    base_placeholder_recall,
    base_df,
) = evaluate_model(
    base_model,
    valid_split,
    model_name="base",
    n_samples=EVAL_SAMPLES,
    csv_path=base_csv,
)

delta_f1 = fine_avg_f1 - base_avg_f1
print(f"Fine-tuned improvement over base (F1): {delta_f1:.4f}")

# One record per model; key order fixes the column order of the summary.
summary_records = [
    {
        "model": "fine-tuned",
        "avg_f1": fine_avg_f1,
        "mask_recall": fine_mask_recall,
        "placeholder_precision": fine_placeholder_precision,
        "placeholder_recall": fine_placeholder_recall,
    },
    {
        "model": "base",
        "avg_f1": base_avg_f1,
        "mask_recall": base_mask_recall,
        "placeholder_precision": base_placeholder_precision,
        "placeholder_recall": base_placeholder_recall,
    },
]
comparison_df = pd.DataFrame(summary_records)
comparison_df.to_csv(comparison_csv, index=False)
print(f"Saved comparison summary to {comparison_csv}")
comparison_df
