# Text Anonymization Model Training Pipeline

This notebook demonstrates how to fine-tune a T5 model for text anonymization using the Hugging Face ecosystem. The pipeline includes:

1. Loading and preprocessing synthetic data
2. Setting up the model and tokenizer
3. Training configuration and process
4. Model evaluation and metrics
5. Inference examples

## Setup and Dependencies

In [1]:
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    DataCollatorForSeq2Seq,
    TrainingArguments,
    Trainer
)
import torch
import gc
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from tqdm import tqdm
import logging
import re
from collections import Counter

# Root logging setup shared by every cell: INFO level, timestamped lines.
logging.basicConfig(
    format="%(asctime)s [%(levelname)s] %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Seeding is currently disabled, so runs are not bit-for-bit reproducible.
# To enable it, uncomment:
# torch.manual_seed(42)
# if torch.cuda.is_available():
#     torch.cuda.manual_seed_all(42)

## 1. Data Loading

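We use our synthetic dataset from the Hugging Face Hub. Each example pairs an original text (`context`) with its anonymized version (`anonymized_context`) and the list of labels used (`used_labels`), which makes it well suited for supervised sequence-to-sequence training.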

In [None]:
# Download (or read from the local cache) the train/validation/test splits.
dataset = load_dataset("kurkowski/synthetic-contextual-anonymizer-dataset")

print("Dataset Statistics:")
for split_name, caption in (
    ("train", "Training examples"),
    ("validation", "Validation examples"),
    ("test", "Test examples"),
):
    print(f"{caption}: {len(dataset[split_name])}")

# Show one training pair so the record format is visible at a glance.
first_example = dataset["train"][0]
print("\nSample from training set:")
print("Original:", first_example["context"])
print("Anonymized:", first_example["anonymized_context"])
print("Used labels:", first_example["used_labels"], type(first_example["used_labels"]))

print(dataset)

## 2. Model Setup

We'll use FLAN-T5-small as our base model. This is a good choice because:
- It's relatively small and fast to train
- It has good text-to-text capabilities
- It has been instruction-tuned on a wide variety of tasks, so it already follows natural-language prompts

In [4]:
# Base checkpoint to fine-tune (FLAN-T5-small) and its matching tokenizer.
model_name = "google/flan-t5-small"
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

## 3. Data Preprocessing

We need to convert our text examples into a format suitable for the model. This includes:
- Adding a task-specific prompt
- Tokenizing inputs and targets
- Creating attention masks

In [None]:
def normalize_labels(text):
    """
    Strip the numeric suffix from anonymization labels (e.g. [NAME_1] -> [NAME]).

    Label types containing underscores are handled as well
    (e.g. [POLICY_NUMBER_2] -> [POLICY_NUMBER]).

    Args:
        text: A string, a list (processed element-wise, recursively), or any
            other object, which is returned unchanged.

    Returns:
        The input with every numbered label collapsed to its bare type.
    """
    if isinstance(text, str):
        return re.sub(r'\[([A-Z_]+)_\d+\]', r'[\1]', text)
    if isinstance(text, list):
        return list(map(normalize_labels, text))
    return text

def create_anonymization_prompt(labels) -> str:
    """Build the instruction prompt that is prepended to every input text.

    Callers append the raw text directly after the returned string
    (``prompt + text``). The prompt therefore ends with ``"Input: "``,
    followed by a newline and the template's indentation.

    NOTE: this template is the model's input format. The leading indentation
    of the template lines and the trailing whitespace are included in the
    prompt verbatim. Training, evaluation and inference all build their
    inputs with this function, so any edit to the text changes what
    fine-tuned checkpoints were trained on.

    NOTE(review): the prompt is long relative to T5's usual 512-token input
    limit, and callers tokenize with ``truncation=True``. Long texts may
    therefore lose their tail; confirm with the dataset's length
    distribution.

    Args:
        labels: Labels to use for anonymization, as stored in the dataset's
            ``used_labels`` field. They pass through ``normalize_labels``
            (numeric suffixes stripped) and are interpolated by the f-string,
            so a list is rendered in Python list syntax.

    Returns:
        String containing the formatted prompt.
    """
    return f"""You are a text anonymization expert. Your task is to replace sensitive information with the following labels: { normalize_labels(labels)}.

    Instructions:
    1. Replace each sensitive information with appropriate label from the provided list
    2. For multiple occurrences of the same type, use numbered labels (e.g. [NAME_1], [NAME_2])
    3. Preserve the original text structure and meaning
    4. Follow the examples precisely

    Example:
    Input: "John Smith called Mary Johnson. John's number is 555-0123 and Mary's is 555-4567."
    Output: "[NAME_1] called [NAME_2]. [NAME_1]'s number is [PHONE_1] and [NAME_2]'s is [PHONE_2]."

    Task:
    Anonymize the following text using only these labels: { normalize_labels(labels)}
    Input: 
    """

def convert_examples_to_features(example_batch):
    """Tokenize a batch of examples into seq2seq training features.

    Each model input is the anonymization prompt (built from the example's
    ``used_labels``) followed by the raw ``context``. The target is the
    ``anonymized_context``.

    Sequences are truncated to the tokenizer's maximum length but are NOT
    padded here. ``DataCollatorForSeq2Seq`` pads every training batch
    dynamically and pads ``labels`` with -100, which the loss ignores.
    Padding inside ``map`` would write the tokenizer's pad id into
    ``labels``. The collator leaves those values untouched, so the model
    would be trained to predict padding.

    Args:
        example_batch: Batched dict from ``Dataset.map(batched=True)`` with
            ``context``, ``used_labels`` and ``anonymized_context`` columns.

    Returns:
        Dictionary with input_ids, attention_mask, and labels.
    """
    input_texts = [
        create_anonymization_prompt(labels) + text
        for text, labels in zip(example_batch["context"], example_batch["used_labels"])
    ]

    # Debug aid only: map() calls this once per batch, so keep it out of INFO output.
    for text in input_texts[:3]:
        logger.debug("input_text: %s", text)

    # text_target tokenizes the targets in target mode and returns them as "labels"
    # (replaces the deprecated `as_target_tokenizer()` context manager).
    encodings = tokenizer(
        input_texts,
        text_target=example_batch["anonymized_context"],
        truncation=True,
    )

    return {
        "input_ids": encodings["input_ids"],
        "attention_mask": encodings["attention_mask"],
        "labels": encodings["labels"],
    }

# Prompt for the first training example, handy for eyeballing the input format.
test_prompt = create_anonymization_prompt(dataset["train"][0]["used_labels"])

# Tokenize every split in batches; map() keeps the original text columns next to
# the new model features (the smoke test below still reads "context").
processed_dataset = dataset.map(
    function=convert_examples_to_features,
    batched=True,
    desc="Processing dataset",
)

In [None]:
def test_anonymization(text_to_anonymize, labels):
    """Print the full prompt, run the current model on it and return the output.

    Args:
        text_to_anonymize: Raw text to anonymize.
        labels: Labels offered in the prompt (see ``create_anonymization_prompt``).

    Returns:
        The decoded generation with special tokens removed.
    """
    prompt = create_anonymization_prompt(labels)
    print(prompt + text_to_anonymize)
    inputs = tokenizer(
        prompt + text_to_anonymize,
        return_tensors="pt",  # a single string, so returning PyTorch tensors is fine here
        truncation=True,
    ).to(model.device)  # keep inputs on the same device as the model

    # No sampling is requested, so a `temperature` would be ignored by generate()
    # (it only triggers a warning); decoding is greedy.
    outputs = model.generate(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_length=512,
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)


# Smoke test on the first training example.
print(test_anonymization(processed_dataset['train'][0]['context'], processed_dataset['train'][0]['used_labels']))

## 4. Training Configuration

We configure the training run with settings chosen for this task. Key considerations include:
- Memory efficiency (batch size and gradient accumulation)
- Learning rate and warmup
- Evaluation strategy

In [None]:

# Pads each batch dynamically: inputs with the tokenizer's pad id, `labels` with
# -100 (the collator's default label_pad_token_id), which the loss ignores.
# Built once and handed to the Trainer below.
seq2seq_data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

trainer_args = TrainingArguments(
    output_dir="anonymizer_model",
    num_train_epochs=3,
    warmup_steps=500,
    # Effective per-device batch = 4 * gradient_accumulation_steps (2) = 8.
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    weight_decay=0.01,
    logging_steps=2,
    eval_strategy="steps",
    eval_steps=250,
    # With load_best_model_at_end=True, save_steps must be a multiple of eval_steps.
    save_steps=250,
    gradient_accumulation_steps=2,
    learning_rate=5e-5,
    save_total_limit=2,
    load_best_model_at_end=True,
    logging_first_step=True,
    logging_dir="./logs",
    optim="adamw_torch_fused",
    gradient_checkpointing=True,
    torch_compile=False,
    dataloader_pin_memory=False,
    # NOTE(review): emptying the accelerator cache every 2 steps lowers peak memory
    # at a noticeable speed cost; raise this if memory allows.
    torch_empty_cache_steps=2,
)

# Release memory held by earlier cells before training starts.
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()
# Only touch the MPS allocator where the MPS backend actually exists.
if torch.backends.mps.is_available():
    torch.mps.empty_cache()

trainer = Trainer(
    model=model,
    args=trainer_args,
    tokenizer=tokenizer,
    data_collator=seq2seq_data_collator,
    train_dataset=processed_dataset["train"],
    eval_dataset=processed_dataset["validation"],
)

train_result = trainer.train()

## 5. Model Evaluation

We'll evaluate our anonymization model using multiple metrics that provide a comprehensive assessment:

1. **Entity-level metrics** - Precision, recall and F1 of the predicted sequence of entity labels, compared position by position with the ground truth
2. **Text similarity metrics** - BLEU and ROUGE-L to measure text structure preservation
3. **Data privacy metrics** - Ensuring sensitive information doesn't appear in predictions
4. **Visualization** - Charts to understand model performance

Let's start by loading the model from our latest checkpoint.

In [None]:
# Load the fine-tuned model and the base tokenizer.
import os

from transformers.trainer_utils import get_last_checkpoint

model_name = "google/flan-t5-small"
checkpoint_root = "anonymizer_model"  # the TrainingArguments.output_dir used for training
# Use the newest checkpoint-* folder rather than a hard-coded step: the step number
# depends on dataset size and batch settings, and save_total_limit=2 prunes older
# checkpoints. Fall back to the previous fixed path if no checkpoint folder is found.
latest_checkpoint = get_last_checkpoint(checkpoint_root) if os.path.isdir(checkpoint_root) else None
checkpoint_path = latest_checkpoint or os.path.join(checkpoint_root, "checkpoint-1250")
# NOTE: with load_best_model_at_end=True the best checkpoint is not necessarily the
# newest one; right after training, `trainer.model` already holds the best weights.

model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint_path)
tokenizer = AutoTokenizer.from_pretrained(model_name)

logger.info("Loaded model weights from %s", checkpoint_path)
print("Model and tokenizer loaded successfully.")

In [None]:
from typing import Dict, List, Set, Tuple
from dataclasses import dataclass
import re
from collections import defaultdict
import numpy as np
from nltk import download
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

# Fetch the NLTK "punkt" resource; the downloader skips it when already up to date.
# NOTE(review): the BLEU/ROUGE helpers in this notebook tokenize with str.split(),
# so punkt may not actually be needed.
download(info_or_id='punkt')


def extract_sensitive_data(original_text: str, anonymized_text: str) -> dict:
    """
    Extract sensitive values by aligning an original text with its anonymization.

    Both texts are split on newlines and the lines are paired positionally
    (``zip``). The anonymized text is therefore assumed to keep the original's
    line layout; extra lines in the longer text are ignored.

    For every differing line, the anonymized line is split into literal
    fragments and numbered labels (e.g. ``[NAME_1]``). The literal fragments
    immediately before and after a label act as anchors that locate the
    replaced value in the original line:

    - both anchors: the text between them;
    - only a left anchor: the text from it to the end of the line;
    - only a right anchor: the text from the start of the line up to it;
    - no anchors: the longest word of the original line not already
      extracted. Ties are broken alphabetically so the result is deterministic.

    Args:
        original_text: The text containing sensitive information.
        anonymized_text: The anonymized version with numbered labels.

    Returns:
        Dictionary with:
          - ``sensitive_pairs``: list of ``(value, label, label_type)`` tuples in
            discovery order, de-duplicated by value (first occurrence wins);
          - ``original_values``: set of extracted values;
          - ``anonymized_labels``: set of labels as written, e.g. ``"[NAME_1]"``;
          - ``label_types``: set of label types without the numeric suffix,
            e.g. ``"NAME"`` or ``"POLICY_NUMBER"``.
    """
    # group(1) is the label type without its "_<number>" suffix.
    label_re = re.compile(r'\[([A-Z_]+)_\d+\]')

    sensitive_pairs = []

    for orig_line, anon_line in zip(original_text.split('\n'), anonymized_text.split('\n')):
        # Identical lines contain nothing that was anonymized.
        if orig_line == anon_line:
            continue

        label_matches = list(label_re.finditer(anon_line))
        if not label_matches:
            continue

        # Split the anonymized line into literal text and labels, in order.
        anon_fragments = []
        last_end = 0
        for label_match in label_matches:
            start = label_match.start()
            if start > last_end:
                anon_fragments.append(anon_line[last_end:start])
            anon_fragments.append(label_match.group(0))
            last_end = label_match.end()
        if last_end < len(anon_line):
            anon_fragments.append(anon_line[last_end:])

        # Drop whitespace-only fragments and trim the rest.
        anon_fragments = [f.strip() for f in anon_fragments if f.strip()]

        for i, fragment in enumerate(anon_fragments):
            label_match = label_re.match(fragment)
            if label_match is None:
                continue
            # Take the type from the full label pattern; a bare r'\[([A-Z_]+)' would
            # also capture the underscore before the number ("NAME_" instead of "NAME").
            label_type = label_match.group(1)

            before = anon_fragments[i - 1] if i > 0 else ""
            after = anon_fragments[i + 1] if i < len(anon_fragments) - 1 else ""

            if before or after:
                if before and after:
                    pattern = f"{re.escape(before)}(.*?){re.escape(after)}"
                elif before:
                    pattern = f"{re.escape(before)}(.*?)$"
                else:
                    pattern = f"^(.*?){re.escape(after)}"
                if match := re.search(pattern, orig_line):
                    value = match.group(1).strip()
                    if value:
                        sensitive_pairs.append((value, fragment, label_type))
            else:
                # No literal context on this line: take the longest remaining word.
                used_values = {v for v, _, _ in sensitive_pairs}
                orig_words = {w.strip() for w in re.split(r'[\s,.]', orig_line) if w.strip()}
                remaining = orig_words - used_values
                if remaining:
                    # sorted() makes tie-breaking independent of set iteration order.
                    value = max(sorted(remaining), key=len)
                    sensitive_pairs.append((value, fragment, label_type))

    # Remove duplicate values while preserving order.
    seen = set()
    unique_pairs = []
    for value, label, label_type in sensitive_pairs:
        if value not in seen:
            seen.add(value)
            unique_pairs.append((value, label, label_type))

    return {
        'sensitive_pairs': unique_pairs,
        'original_values': {v for v, _, _ in unique_pairs},
        'anonymized_labels': {label for _, label, _ in unique_pairs},
        'label_types': {t for _, _, t in unique_pairs},
    }

In [9]:
# Functions for entity recognition, BLEU and ROUGE metrics

def extract_entities(text: str):
    """
    Return the bare type of every numbered label in *text*, in order of appearance.

    ``"[NAME_1] met [POLICY_NUMBER_2]"`` -> ``["NAME", "POLICY_NUMBER"]``.
    Unnumbered labels such as ``[NAME]`` are not matched.
    """
    return [m.group(1) for m in re.finditer(r'\[([A-Z_]+)_\d+\]', text)]

def evaluate_entity_sequence(gt_text: str, pred_text: str):
    """
    Compare the label-type sequences of a ground truth and a prediction.

    Labels are compared position by position: the i-th label type of the
    prediction counts as correct only if it equals the i-th label type of
    the ground truth.

    Returns:
        Dict with both entity lists, the ``correct`` count, ``missing``
        (ground-truth labels not matched), ``extra`` (predicted labels not
        matched), and precision/recall/F1. Precision (recall) is 1.0 when the
        prediction (ground truth) has no labels.
    """
    gt_entities = extract_entities(gt_text)
    pred_entities = extract_entities(pred_text)

    # zip() stops at the shorter sequence; surplus labels are never correct.
    correct = sum(gt == pred for gt, pred in zip(gt_entities, pred_entities))

    missing = len(gt_entities) - correct
    extra = len(pred_entities) - correct

    precision = correct / len(pred_entities) if pred_entities else 1.0
    recall = correct / len(gt_entities) if gt_entities else 1.0
    denominator = precision + recall
    f1 = 2 * precision * recall / denominator if denominator > 0 else 0.0

    return {
        "gt_entities": gt_entities,
        "pred_entities": pred_entities,
        "correct": correct,
        "missing": max(missing, 0),
        "extra": max(extra, 0),
        "precision": precision,
        "recall": recall,
        "f1": f1,
    }

def compute_bleu(gt_text: str, pred_text: str):
    """
    Sentence-level BLEU of *pred_text* against *gt_text*.

    Both texts are tokenized on whitespace. NLTK smoothing method 4 keeps
    short or low-overlap predictions from collapsing to a score of zero.
    """
    reference = [gt_text.split()]
    hypothesis = pred_text.split()
    return sentence_bleu(
        reference,
        hypothesis,
        smoothing_function=SmoothingFunction().method4,
    )

def lcs_length(a, b):
    """
    Return the length of the longest common subsequence of sequences *a* and *b*.

    Classic dynamic programme, keeping only the previous row of the table.
    """
    previous = [0] * (len(b) + 1)
    for item_a in a:
        current = [0]
        for j, item_b in enumerate(b):
            if item_a == item_b:
                current.append(previous[j] + 1)
            else:
                current.append(max(previous[j + 1], current[j]))
        previous = current
    return previous[-1]

def compute_rouge_l(gt_text: str, pred_text: str):
    """
    ROUGE-L precision, recall and F1 from the whitespace-token LCS.

    Precision and recall are 0 when the respective text has no tokens; F1 is
    0 when both are 0.
    """
    gt_tokens, pred_tokens = gt_text.split(), pred_text.split()
    lcs = lcs_length(gt_tokens, pred_tokens)

    recall = lcs / len(gt_tokens) if gt_tokens else 0
    precision = lcs / len(pred_tokens) if pred_tokens else 0
    total = precision + recall
    f1 = 0 if total == 0 else 2 * precision * recall / total

    return {
        "rouge_l_precision": precision,
        "rouge_l_recall": recall,
        "rouge_l_f1": f1,
    }

In [23]:
def detect_privacy_leaks(original_text, ground_truth, prediction):
    """
    Detect original sensitive values that survive verbatim in a prediction.

    Sensitive values are recovered by aligning *original_text* with the gold
    anonymization (``extract_sensitive_data``). A value counts as leaked when
    its lower-cased, whitespace-collapsed form occurs as a substring of the
    prediction normalized the same way. Values of 3 characters or fewer are
    skipped to avoid spurious matches on short tokens.

    Args:
        original_text: The text containing sensitive information
        ground_truth: The correctly anonymized version
        prediction: The model's output to check for leaks

    Returns:
        Tuple of (leak_detected: bool, leaked_items: list). ``leaked_items`` is
        sorted so the output is deterministic across runs.
    """
    def normalize(text):
        # Lower-case and collapse whitespace runs so formatting differences don't hide leaks.
        return re.sub(r'\s+', ' ', text.lower().strip())

    pred_norm = normalize(prediction)
    sensitive_values = extract_sensitive_data(original_text, ground_truth)['original_values']

    # Normalize each value like the prediction (not just .lower()), so values with
    # internal tabs/double spaces are still found in the collapsed prediction.
    leaks = sorted(
        value for value in sensitive_values
        if len(value) > 3 and normalize(value) in pred_norm
    )

    return bool(leaks), leaks

In [24]:
def evaluate_anonymization(original_text: str, ground_truth: str, prediction: str):
    """
    Score one anonymized prediction against its gold anonymization.

    Args:
        original_text: The text containing sensitive information.
        ground_truth: The correctly anonymized version (gold standard).
        prediction: The model's anonymized output.

    Returns:
        Dict with ``entity_metrics`` (from ``evaluate_entity_sequence``),
        ``bleu``, ``rouge_l`` (dict of ROUGE-L precision/recall/F1),
        ``sensitive_data_leak`` (bool) and ``leaked_items`` (list).
    """
    metrics = {
        "entity_metrics": evaluate_entity_sequence(ground_truth, prediction),
        "bleu": compute_bleu(ground_truth, prediction),
        "rouge_l": compute_rouge_l(ground_truth, prediction),
    }
    metrics["sensitive_data_leak"], metrics["leaked_items"] = detect_privacy_leaks(
        original_text, ground_truth, prediction
    )
    return metrics

In [25]:
def generate_anonymization(text_to_anonymize, labels):
    """
    Generate anonymized text with the notebook's global ``model`` and ``tokenizer``.

    Args:
        text_to_anonymize: Original text containing sensitive information.
        labels: Labels offered in the prompt (see ``create_anonymization_prompt``).

    Returns:
        Anonymized text produced by the model (special tokens removed).
    """
    prompt = create_anonymization_prompt(labels)
    # Put the input tensors on the model's device so this also works after the
    # model has been moved to a GPU/MPS device.
    inputs = tokenizer(
        prompt + text_to_anonymize,
        return_tensors="pt",
        truncation=True,
    ).to(model.device)

    # No sampling is requested, so a `temperature` would be ignored by generate()
    # (it only triggers a warning); decoding is greedy.
    outputs = model.generate(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_length=512,
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)


def evaluate_model_performance(dataset, tokenizer, model, num_examples=5):
    """
    Generate predictions for the first test examples and score each one.

    For each of the first ``num_examples`` examples of ``dataset["test"]``,
    this prints the original text, the gold anonymization and the model
    prediction. It then prints the metrics from ``evaluate_anonymization``:
    entity precision/recall/F1, BLEU, ROUGE-L F1, the privacy-leak flag and
    any leaked values.

    No sampling is requested, so decoding is greedy. Inputs are moved to
    ``model.device``, so the function works whether the model is on CPU or
    an accelerator.

    Args:
        dataset: Mapping with a ``"test"`` split (a ``datasets.Dataset`` or a
            list of example dicts).
        tokenizer: Tokenizer matching ``model``.
        model: Trained anonymization model.
        num_examples: Maximum number of test examples to evaluate.

    Returns:
        List of per-example metric dicts, in dataset order.
    """
    results = []

    # Take the first examples; handle both Dataset objects and plain lists.
    test_split = dataset["test"]
    if hasattr(test_split, "select"):
        test_examples = test_split.select(range(min(num_examples, len(test_split))))
    else:
        test_examples = test_split[:num_examples]

    print("\n=== Privacy-Preserving Text Anonymization: Model Evaluation ===\n")

    for idx, example in enumerate(test_examples):
        print(f"\n--- Example {idx+1} ---")
        print("Original text:")
        print(example['context'])
        print("\nGround truth (expected anonymization):")
        print(example['anonymized_context'])

        # Same prompt format as training: prompt built from used_labels + raw text.
        prompt = create_anonymization_prompt(example['used_labels'])
        inputs = tokenizer(
            prompt + example['context'],
            return_tensors="pt",
            truncation=True,
        ).to(model.device)
        outputs = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_length=512,
        )
        prediction = tokenizer.decode(outputs[0], skip_special_tokens=True)

        print("\nModel prediction:")
        print(prediction)

        metrics = evaluate_anonymization(
            example['context'],
            example['anonymized_context'],
            prediction,
        )

        print("\nEvaluation metrics:")
        em = metrics["entity_metrics"]
        print(f"Entity Recognition - Precision: {em['precision']:.3f}, Recall: {em['recall']:.3f}, F1: {em['f1']:.3f}")
        print(f"BLEU Score: {metrics['bleu']:.3f}")
        print(f"ROUGE-L F1: {metrics['rouge_l']['rouge_l_f1']:.3f}")
        print(f"Privacy breach detected: {metrics['sensitive_data_leak']}")

        if metrics['sensitive_data_leak']:
            print("Leaked sensitive data:")
            for item in metrics['leaked_items']:
                print(f"  - '{item}'")

        # Counts of unmatched label positions (see evaluate_entity_sequence).
        if em["missing"] > 0:
            print(f"\nMissing entity types: {em['missing']}")
        if em["extra"] > 0:
            print(f"Extra entity types: {em['extra']}")

        results.append(metrics)

    return results

In [13]:
def calculate_average_metrics(results):
    """
    Average the per-example metrics produced by evaluate_model_performance.

    Args:
        results: Non-empty list of per-example metric dicts.

    Returns:
        Dict with NumPy means of entity precision/recall/F1, BLEU and ROUGE-L
        F1, plus ``leak_percentage``: the fraction (0-1) of examples flagged
        with a sensitive-data leak.
    """
    def mean_over(get):
        # Mean of one scalar pulled out of every result dict.
        return np.mean([get(r) for r in results])

    leaked_examples = sum(1 for r in results if r["sensitive_data_leak"])

    return {
        "avg_entity_precision": mean_over(lambda r: r["entity_metrics"]["precision"]),
        "avg_entity_recall": mean_over(lambda r: r["entity_metrics"]["recall"]),
        "avg_entity_f1": mean_over(lambda r: r["entity_metrics"]["f1"]),
        "avg_bleu": mean_over(lambda r: r["bleu"]),
        "avg_rouge_l_f1": mean_over(lambda r: r["rouge_l"]["rouge_l_f1"]),
        "leak_percentage": leaked_examples / len(results),
    }

In [26]:
def calculate_data_privacy_metrics(results, dataset, num_examples=5):
    """
    Aggregate leak statistics over the evaluated test examples.

    Per-example leak detection says *which* examples leaked. This function
    reports what fraction of all sensitive values, across the evaluated
    examples, leaked into the predictions.

    ``results`` must come from ``evaluate_model_performance`` called with the
    same ``dataset`` and ``num_examples``. The test examples are re-selected
    here and paired with ``results`` positionally.

    Sensitive values are recovered with ``extract_sensitive_data`` and kept
    only if longer than 3 characters. This is the same filter
    ``detect_privacy_leaks`` applies, so the leaked items counted here come
    from the same population as the sensitive items.

    Args:
        results: List of evaluation results from evaluate_model_performance
        dataset: Dataset used for evaluation
        num_examples: Number of examples that were evaluated

    Returns:
        Dictionary with:
        - total_examples / examples_with_leaks / example_leak_rate (0.0 when
          ``results`` is empty)
        - total_sensitive_items: Count of all sensitive data elements
        - total_leaked_items: Count of sensitive data that leaked
        - global_leak_rate: Fraction of sensitive data that leaked (0 when
          there are no sensitive items)
    """
    total_sensitive_items = 0
    total_leaked_items = 0

    # Re-select the same subset evaluate_model_performance used.
    test_split = dataset["test"]
    if hasattr(test_split, "select"):
        test_examples = test_split.select(range(min(num_examples, len(test_split))))
    else:
        test_examples = test_split[:num_examples]

    print("\n=== Global Data Privacy Analysis ===\n")

    for idx, (example, result) in enumerate(zip(test_examples, results)):
        extracted = extract_sensitive_data(example['context'], example['anonymized_context'])
        sensitive_values = {v for v in extracted['original_values'] if v and len(v) > 3}

        leaked_items = result['leaked_items'] if result['sensitive_data_leak'] else []

        example_sensitive_count = len(sensitive_values)
        example_leak_count = len(leaked_items)
        total_sensitive_items += example_sensitive_count
        total_leaked_items += example_leak_count

        print(f"Example {idx+1}:")
        print(f"  - Sensitive items: {example_sensitive_count}")
        print(f"  - Leaked items: {example_leak_count}")
        if example_sensitive_count > 0:
            print(f"  - Example privacy score: {(1 - example_leak_count/example_sensitive_count):.2%}")
        else:
            print("  - Example privacy score: 100.00%")
        if example_leak_count > 0:
            print(f"  - Leaked values: {', '.join(leaked_items)}")
        print()

    global_leak_rate = total_leaked_items / total_sensitive_items if total_sensitive_items > 0 else 0
    examples_with_leaks = sum(1 for r in results if r['sensitive_data_leak'])
    example_leak_rate = examples_with_leaks / len(results) if results else 0.0

    print("\n=== Summary ===")
    print(f"Total examples analyzed: {len(results)}")
    print(f"Examples with privacy breaches: {examples_with_leaks}")
    print(f"Total sensitive items across all examples: {total_sensitive_items}")
    print(f"Total leaked items across all examples: {total_leaked_items}")
    print(f"Global privacy protection rate: {(1-global_leak_rate):.2%} of sensitive data properly anonymized")

    return {
        "total_examples": len(results),
        "examples_with_leaks": examples_with_leaks,
        "example_leak_rate": example_leak_rate,
        "total_sensitive_items": total_sensitive_items,
        "total_leaked_items": total_leaked_items,
        "global_leak_rate": global_leak_rate,
    }

In [None]:
# Number of test examples to evaluate. Both calls below must use the same value:
# calculate_data_privacy_metrics re-selects the examples and pairs them with
# `results` by position.
NUM_EVAL_EXAMPLES = 500

# Run evaluation on test examples
results = evaluate_model_performance(dataset, tokenizer, model, num_examples=NUM_EVAL_EXAMPLES)

# Calculate and display average metrics
avg_metrics = calculate_average_metrics(results)

# Calculate global privacy metrics
global_privacy_metrics = calculate_data_privacy_metrics(results, dataset, num_examples=NUM_EVAL_EXAMPLES)

print("\n=== Overall Model Performance ===")
print("Entity recognition metrics - Precision: {:.3f}, Recall: {:.3f}, F1: {:.3f}".format(
    avg_metrics["avg_entity_precision"],
    avg_metrics["avg_entity_recall"],
    avg_metrics["avg_entity_f1"])
)
print("BLEU Score: {:.3f}".format(avg_metrics["avg_bleu"]))
print("ROUGE-L F1: {:.3f}".format(avg_metrics["avg_rouge_l_f1"]))
print("Privacy breaches detected in {:.1%} of examples".format(avg_metrics["leak_percentage"]))
print("Global privacy protection rate: {:.2%}".format(1 - global_privacy_metrics["global_leak_rate"]))

In [None]:
def visualize_evaluation_results_demo(results, avg_metrics, global_privacy_metrics):
    """
    Plot a 2x2 summary of the evaluation and print a short text summary.

    Panels:
      - top-left: overlaid histograms of per-example Entity F1, BLEU and
        ROUGE-L F1, with dashed lines at their averages;
      - top-right: pie of protected vs leaked sensitive items;
      - bottom-left: bar chart of examples with and without leaks;
      - bottom-right: bar chart of the average metrics.

    Args:
        results: List of evaluation results from evaluate_model_performance
        avg_metrics: Dictionary with average metrics
        global_privacy_metrics: Dictionary with global privacy metrics

    Raises:
        ValueError: If ``results`` is empty (there is nothing to plot).
    """
    if not results:
        raise ValueError("results is empty; nothing to visualize")

    df = pd.DataFrame(
        [
            {
                'Entity_F1': r["entity_metrics"]["f1"],
                'BLEU': r["bleu"],
                'ROUGE_L_F1': r["rouge_l"]["rouge_l_f1"],
                'Has_Leak': 1 if r["sensitive_data_leak"] else 0,
            }
            for r in results
        ]
    )

    plt.style.use('seaborn-v0_8-whitegrid')
    fig, axs = plt.subplots(2, 2, figsize=(14, 10))

    # 1. Histograms of the per-example metrics, dashed line at each average.
    metric_specs = [
        ('Entity_F1', 'avg_entity_f1', '#3498db'),
        ('BLEU', 'avg_bleu', '#2ecc71'),
        ('ROUGE_L_F1', 'avg_rouge_l_f1', '#9b59b6'),
    ]
    for metric, _, color in metric_specs:
        axs[0, 0].hist(df[metric], bins=20, alpha=0.7, label=metric.replace('_', ' '), color=color)
    for _, avg_key, color in metric_specs:
        axs[0, 0].axvline(x=avg_metrics[avg_key], color=color, linestyle='--')
    axs[0, 0].set_title('Distribution of Key Metrics', fontsize=14)
    axs[0, 0].set_xlabel('Score')
    axs[0, 0].set_ylabel('Number of Examples')
    axs[0, 0].legend()
    axs[0, 0].set_xlim(0, 1)

    # 2. Share of sensitive items protected vs leaked.
    total_items = global_privacy_metrics['total_sensitive_items']
    leaked_total = global_privacy_metrics['total_leaked_items']
    privacy_rate = 1 - global_privacy_metrics["global_leak_rate"]
    if total_items > 0:
        axs[0, 1].pie([total_items - leaked_total, leaked_total],
                      labels=['Protected', 'Leaked'], colors=['#3498db', '#e74c3c'],
                      autopct='%1.1f%%', startangle=90, shadow=True)
        axs[0, 1].text(0, 0, f"{privacy_rate:.1%}\nProtected",
                       ha='center', va='center', fontsize=14, fontweight='bold')
        axs[0, 1].axis('equal')  # Equal aspect ratio
    else:
        # An all-zero pie has no meaningful wedges; show a note instead.
        axs[0, 1].text(0.5, 0.5, "No sensitive items found", ha='center', va='center',
                       fontsize=14, transform=axs[0, 1].transAxes)
        axs[0, 1].axis('off')
    axs[0, 1].set_title('Global Privacy Protection', fontsize=14)

    # 3. Examples with / without at least one leak.
    leak_example_count = int(df['Has_Leak'].sum())
    leak_counts = [len(df) - leak_example_count, leak_example_count]
    axs[1, 0].bar(['No Leaks', 'Has Leaks'], leak_counts, color=['#2ecc71', '#e74c3c'])
    axs[1, 0].set_title('Examples with Privacy Breaches', fontsize=14)
    axs[1, 0].set_ylabel('Number of Examples')
    # Offset the percentage labels relative to the tallest bar and leave headroom.
    # A fixed +5 pushed them far outside the axes for small evaluation runs.
    tallest = max(leak_counts)  # >= 1 because results is non-empty
    axs[1, 0].set_ylim(0, tallest * 1.1)
    for i, count in enumerate(leak_counts):
        axs[1, 0].text(i, count + 0.02 * tallest, f"{100 * count / len(df):.1f}%",
                       ha='center', fontweight='bold')

    # 4. Average metrics side by side.
    metrics_labels = ['Entity F1', 'BLEU', 'ROUGE-L']
    metrics_values = [avg_metrics['avg_entity_f1'],
                      avg_metrics['avg_bleu'],
                      avg_metrics['avg_rouge_l_f1']]
    axs[1, 1].bar(metrics_labels, metrics_values, color='#3498db')
    axs[1, 1].set_title('Average Performance Metrics', fontsize=14)
    axs[1, 1].set_ylabel('Score')
    axs[1, 1].set_ylim(0, 1)
    for i, value in enumerate(metrics_values):
        axs[1, 1].text(i, value + 0.03, f"{value:.3f}", ha='center', fontweight='bold')

    plt.tight_layout()
    plt.subplots_adjust(top=0.9)
    fig.suptitle('Privacy-Preserving Text Anonymization: Performance Summary', fontsize=16)

    plt.show()

    print("\n=== Model Performance Summary ===")
    print(f"Total examples analyzed: {len(results)}")
    print(f"Entity recognition F1 Score: {avg_metrics['avg_entity_f1']:.3f}")
    print(f"BLEU Score: {avg_metrics['avg_bleu']:.3f}")
    print(f"ROUGE-L F1 Score: {avg_metrics['avg_rouge_l_f1']:.3f}")
    print(f"Privacy breaches in {df['Has_Leak'].mean():.1%} of examples")
    print(f"Privacy protection rate: {privacy_rate:.1%} of sensitive data anonymized")

# Usage example:
visualize_evaluation_results_demo(results, avg_metrics, global_privacy_metrics)

In [None]:
# Practical example of using the model on new text
def demonstrate_anonymization(text, entity_types):
    """
    Anonymize *text* with the trained model and report label types it never used.

    Each entry of *entity_types* is upper-cased and wrapped in brackets
    (e.g. ``"name"`` -> ``"[NAME]"``) before being offered in the prompt.
    Afterwards, any requested type that does not appear as a numbered label
    (e.g. ``[NAME_1]``) in the output is printed as a warning.

    Args:
        text: Text to anonymize.
        entity_types: Entity types to anonymize (e.g. ["NAME", "DATE", "ADDRESS"]).
    """
    print("=== Anonymization Demo ===")
    print("\nOriginal text:")
    print(text)

    anonymized = generate_anonymization(text, [f"[{entity.upper()}]" for entity in entity_types])

    print("\nAnonymized text:")
    print(anonymized)

    # Label types the model actually emitted, e.g. "NAME" from "[NAME_1]".
    found_types = {m.group(1) for m in re.finditer(r'\[([A-Z_]+)_\d+\]', anonymized)}
    missing_types = [t.upper() for t in entity_types if t.upper() not in found_types]

    if missing_types:
        print("\nWarning: The following entity types were not detected:")
        print("\n".join(f"- {t}" for t in missing_types))

    print("\nNote: For production use, verify all sensitive information has been properly anonymized.")

# Example usage
sample_text = "On December 15, 2023, Jane Smith (born 05/12/1990) made a payment of $1,245.00 to ABC Corporation from her account #12345678. Please contact her at jane.smith@example.com or call 555-123-4567 if there are any issues."

demonstrate_anonymization(sample_text, ["NAME", "DATE", "EMAIL", "PHONE", "ACCOUNT_NUMBER", "MONEY_AMOUNT", "COMPANY"])