In [1]:
import pandas as pd
import numpy as np
import os
import sys
import json
from tqdm import tqdm
import ollama
import re
import time
from collections import defaultdict
import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration, Trainer, TrainingArguments
from torch.utils.data import Dataset
import matplotlib.pyplot as plt
import seaborn as sns

# Make the repository root (two levels above the notebook's working directory)
# importable, appending it to sys.path only once even if this cell is re-run.
parent_dir = os.path.abspath(os.path.join(os.getcwd(), os.pardir, os.pardir))
if parent_dir not in sys.path:
    sys.path.append(parent_dir)




In [16]:
# Download the Llama 3.2 1B model (the model passed to the clarification
# generation calls later in this notebook), then list the locally available
# Ollama models to confirm it is present.
!ollama pull llama3.2:1b
!ollama list

[?2026h[?25l[1Gpulling manifest ‚†ã [K[?25h[?2026l[?2026h[?25l[1Gpulling manifest ‚†ô [K[?25h[?2026l[?2026h[?25l[1Gpulling manifest ‚†π [K[?25h[?2026l[?2026h[?25l[1Gpulling manifest ‚†∏ [K[?25h[?2026l[?2026h[?25l[1Gpulling manifest ‚†º [K[?25h[?2026l[?2026h[?25l[1Gpulling manifest ‚†¥ [K[?25h[?2026l[?2026h[?25l[1Gpulling manifest ‚†¶ [K[?25h[?2026l[?2026h[?25l[1Gpulling manifest ‚†ß [K[?25h[?2026l[?2026h[?25l[1Gpulling manifest ‚†á [K[?25h[?2026l[?2026h[?25l[1Gpulling manifest ‚†è [K[?25h[?2026l[?2026h[?25l[1Gpulling manifest ‚†ã [K[?25h[?2026l[?2026h[?25l[1Gpulling manifest ‚†ô [K[?25h[?2026l[?2026h[?25l[1Gpulling manifest [K
pulling 74701a8c35f6:   0% ‚ñï                  ‚ñè 1.5 MB/1.3 GB                  [K[?25h[?2026l[?2026h[?25l[A[1Gpulling manifest [K
pulling 74701a8c35f6:   0% ‚ñï                  ‚ñè 5.4 MB/1.3 GB                  [K[?25h[?2026l[?2026h[?25l[A[1Gpulling manifest [K
pulli

NAME                ID              SIZE      MODIFIED               
llama3.2:1b         baf6a787fdff    1.3 GB    Less than a second ago    
llama3.2:3b         a80c4f17acd5    2.0 GB    3 minutes ago             
llama3.1:8b         46e0c10c039e    4.9 GB    About an hour ago         
deepseek-r1:1.5b    e0979632db5a    1.1 GB    4 months ago              


## Experiment 1: Clarify-and-Link — Full Implementation

**Project:** Contextual Augmentation for Entity Linking
**Method:** Clarify-and-Link (vs Replace-and-Link from paper)
**Dataset:** AIDA-CoNLL (full train/val/test splits)
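**LLM:** Ollama Llama 3.2 1B (`llama3.2:1b`), the model used in the clarification-generation runs below (originally planned: Llama 3.1 8B)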
**Model:** T5-base fine-tuned for entity linking
 
**Pipeline:**
1. Load preprocessed AIDA data (train/val/test)
2. Generate LLM clarifications for all entities
3. Create augmented datasets (baseline + clarified)
4. Train T5 model on both versions
5. Evaluate and compare results
 
**Key Innovation:**
- APPEND clarifications instead of REPLACING mentions
- Preserves original context while adding semantic information
- More robust to LLM errors than Replace-and-Link


In [17]:
# Step 1: Load Preprocessed AIDA Data
# ====================================
# Read the three preprocessed parquet splits and report document / entity counts.

print("\nüìÇ Loading preprocessed AIDA data...")

df_train, df_val, df_test = (
    pd.read_parquet(f'../../data/processed/aida/{split}.parquet')
    for split in ('train', 'validation', 'test')
)

print(f"\n‚úì Train: {len(df_train)} documents")
print(f"‚úì Validation: {len(df_val)} documents")
print(f"‚úì Test: {len(df_test)} documents")

# Count total entities: each row's 'entities' cell is that document's entity list
train_entities, val_entities, test_entities = (
    int(frame['entities'].map(len).sum())
    for frame in (df_train, df_val, df_test)
)

print("\nüìä Total entities:")
print(f"   Train: {train_entities:,}")
print(f"   Val: {val_entities:,}")
print(f"   Test: {test_entities:,}")
print(f"   Total: {train_entities + val_entities + test_entities:,}")


üìÇ Loading preprocessed AIDA data...

‚úì Train: 946 documents
‚úì Validation: 216 documents
‚úì Test: 231 documents

üìä Total entities:
   Train: 23,393
   Val: 5,916
   Test: 5,614
   Total: 34,923


In [None]:
# Step 2: LLM Clarification Generation
# =====================================

def generate_clarification(mention, context_left, context_right, model, max_retries=3):
    """
    Generate a clarifying description for an entity mention using Ollama.
    
    Args:
        mention: Entity text (e.g., "Jordan")
        context_left: Text before the mention
        context_right: Text after the mention
        model: Ollama model name
        max_retries: Number of retry attempts if LLM fails
    
    Returns:
        String clarification (30-40 words)
    """
    # Create context window (limit to ~200 chars each side)
    context_left = context_left[-200:] if len(context_left) > 200 else context_left
    context_right = context_right[:200] if len(context_right) > 200 else context_right
    context = f"{context_left} {mention} {context_right}"
    
    prompt = f"""Based on this context: "{context}"

Provide a brief, factual description for the entity "{mention}".
Disambiguate what/who this specific mention refers to (person, place, organization, etc.).
Use simple English. Respond with ONLY the description, no extra text (max 40 words).

Description:"""
    
    for attempt in range(max_retries):
        try:
            response = ollama.chat(
                            time.sleep(1)  # Wait before retry
                continue
            else:
                print(f"‚ö†Ô∏è Error generating clarification for '{mention}': {e}")
                return f"Entity: {mention}"  # Fallback
    
    return f"Entity: {mention}"


def process_document_clarifications(doc_entities, model="llama3.1:8b"):
    """
    Generate clarifications for all entities in a document.
    
    Args:
        doc_entities: List of entity dicts from one document
        model: Ollama model name
    
    Returns:
        Dict mapping mention ‚Üí clarification
    """
    clarifications = {}
    
    for entity in doc_entities:
        mention = entity.get('mention', '')
        
        # Skip if already processed (same mention appears multiple times)
        if mention in clarifications:
            continue
        
        context_left = entity.get('context_left', '')
        context_right = entity.get('context_right', '')
        
        # Generate clarification
        clarification = generate_clarification(mention, context_left, context_right, model)
        clarifications[mention] = clarification
        
        # Small delay to avoid overwhelming LLM
        time.sleep(0.3)
    
    return clarifications


def batch_generate_all_clarifications(df, split_name, model="llama3.2:1b", save_checkpoint_every=100, use_parallel=True, max_workers=3):
    """
    OPTIMIZED: Generate clarifications only for UNIQUE mentions across entire split.
    With optional parallel processing.
    """
    checkpoint_dir = f'../../data/experiments/clarifications_checkpoints/{split_name}'
    os.makedirs(checkpoint_dir, exist_ok=True)
    
    print(f"\nüîÑ STEP 1: Collecting unique mentions from {split_name} split...")
    
    # Collect all unique mentions with their best context
    unique_mentions = {}
    for idx, row in df.iterrows():
        for entity in row['entities']:
            mention = entity.get('mention', '')
            if mention not in unique_mentions:
                unique_mentions[mention] = {
                    'context_left': entity.get('context_left', ''),
                    'context_right': entity.get('context_right', '')
                }
    
    total_entities = sum(len(row['entities']) for _, row in df.iterrows())
    print(f"   Found {len(unique_mentions)} unique mentions (vs {total_entities} total entities)")
    print(f"   Reduction: {(1 - len(unique_mentions)/total_entities)*100:.1f}%")
    
    # Generate clarifications
    print(f"\nüîÑ STEP 2: Generating clarifications for unique mentions...")
    print(f"   Mode: {'PARALLEL' if use_parallel else 'SEQUENTIAL'}")
    
    if use_parallel:
        # PARALLEL VERSION - Much faster but needs more memory
        from concurrent.futures import ThreadPoolExecutor, as_completed
        
        global_clarifications = {}
        
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            # Submit all tasks
            future_to_mention = {
                executor.submit(
                    generate_clarification, 
                    mention, 
                    context['context_left'], 
                    context['context_right'], 
                    model
                ): mention
                for mention, context in unique_mentions.items()
            }
            
            # Collect results with progress bar
            for i, future in enumerate(tqdm(as_completed(future_to_mention), total=len(unique_mentions), desc="Parallel clarifying")):
                mention = future_to_mention[future]
                try:
                    clarification = future.result()
                    global_clarifications[mention] = clarification
                except Exception as e:
                    print(f"‚ö†Ô∏è Error for '{mention}': {e}")
                    global_clarifications[mention] = f"Entity: {mention}"
                
                # Checkpoint every 100 mentions
                if (i + 1) % save_checkpoint_every == 0:
                    checkpoint_path = f'{checkpoint_dir}/clarifications_checkpoint_{i+1}.json'
                    with open(checkpoint_path, 'w') as f:
                        json.dump(global_clarifications, f, indent=2)
    
    else:
        # SEQUENTIAL VERSION - Your current implementation (safer)
        global_clarifications = {}
        
        for i, (mention, context) in enumerate(tqdm(unique_mentions.items(), desc="Sequential clarifying")):
            clarification = generate_clarification(
                mention, 
                context['context_left'], 
                context['context_right'], 
                model
            )
            global_clarifications[mention] = clarification
            
            # Checkpoint every 100 mentions
            if (i + 1) % save_checkpoint_every == 0:
                checkpoint_path = f'{checkpoint_dir}/clarifications_checkpoint_{i+1}.json'
                with open(checkpoint_path, 'w') as f:
                    json.dump(global_clarifications, f, indent=2)
            
            # Small delay to avoid overwhelming LLM
            time.sleep(0.1)
    
    # Map clarifications back to documents
    print(f"\nüîÑ STEP 3: Mapping clarifications to documents...")
    results = []
    
    for idx, row in df.iterrows():
        text = row['text']
        entities = row['entities']
        
        # Use pre-generated clarifications
        doc_clarifications = {
            entity['mention']: global_clarifications.get(entity['mention'], f"Entity: {entity['mention']}")
            for entity in entities
        }
        
        doc_result = {
            'doc_id': idx,
            'text': text,
            'entities': entities,
            'clarifications': doc_clarifications,
            'num_entities': len(entities),
            'num_clarifications': len(doc_clarifications)
        }
        results.append(doc_result)
    
    # Final save
    final_path = f'../../data/experiments/clarifications_{split_name}.json'
    with open(final_path, 'w') as f:
        json.dump(results, f, indent=2)
    
    print(f"\n‚úÖ {split_name} complete!")
    print(f"   Saved to: {final_path}")
    if not use_parallel:
        print(f"   Time saved: ~{(total_entities - len(unique_mentions)) * 0.3 / 60:.1f} minutes vs old approach")
    
    return results
    model=model,
                messages=[{"role": "user", "content": prompt}],
                options={"temperature": 0.3}  # Lower temperature for factual consistency
            )
            clarification = response['message']['content'].strip()
            
            # Clean up response
            clarification = clarification.strip('"').strip("'").strip()
            
            # Validate length (should be reasonable)
            word_count = len(clarification.split())
            if 5 <= word_count <= 60:  # Reasonable range
                return clarification
            elif attempt < max_retries - 1:
                continue  # Try again if too short/long
            else:
                return clarification  # Return anyway on last attempt
        
        except Exception as e:
            if attempt < max_retries - 1:


In [None]:
# Step 3: Generate Clarifications for All Splits
# ===============================================

# Each split is cached independently at ../../data/experiments/clarifications_<split>.json
# (the path batch_generate_all_clarifications writes). A cached split is loaded
# instead of regenerated, so declining the long train run no longer forces the
# val/test splits to be regenerated the next time this cell runs.
REGENERATE_CLARIFICATIONS = False  # Set to True to regenerate
CLARIFICATION_MODEL = "llama3.2:1b"


def load_or_generate_clarifications(df, split_name, confirm_prompt=None):
    """
    Return clarification records for one split, loading them from disk if cached.

    Args:
        df: DataFrame for the split (used only when generating)
        split_name: 'train', 'val' or 'test'
        confirm_prompt: If given, ask the user before generating; any answer
            other than "yes" skips generation and returns None.

    Returns:
        List of per-document clarification dicts, or None if skipped.
    """
    path = f'../../data/experiments/clarifications_{split_name}.json'
    if not REGENERATE_CLARIFICATIONS and os.path.exists(path):
        with open(path, 'r') as f:
            records = json.load(f)
        print(f"‚úì {split_name}: loaded {len(records)} documents from {path}")
        return records

    if confirm_prompt is not None:
        confirm = input(confirm_prompt)
        if confirm.strip().lower() != 'yes':
            print(f"‚è≠Ô∏è Skipping {split_name} split. You can run this cell later.")
            return None

    return batch_generate_all_clarifications(
        df, split_name, model=CLARIFICATION_MODEL, use_parallel=True, max_workers=3
    )


print("\n" + "="*70)
print("PREPARING CLARIFICATIONS FOR ALL SPLITS (load cached / generate missing)")
print("="*70)

# Start with validation (smaller, for testing)
print("\n1Ô∏è‚É£ Validation Split")
val_clarifications = load_or_generate_clarifications(df_val, 'val')

print("\n2Ô∏è‚É£ Test Split")
test_clarifications = load_or_generate_clarifications(df_test, 'test')

print("\n3Ô∏è‚É£ Train Split (this will take ~2-3 hours)")
train_clarifications = load_or_generate_clarifications(
    df_train,
    'train',
    confirm_prompt="‚ö†Ô∏è Train split will take ~2-3 hours. Continue? (yes/no): ",
)


GENERATING CLARIFICATIONS FOR ALL SPLITS

1Ô∏è‚É£ Validation Split

üîÑ STEP 1: Collecting unique mentions from val split...
   Found 2795 unique mentions (vs 5916 total entities)
   Reduction: 52.8%

üîÑ STEP 2: Generating clarifications for unique mentions...
   Mode: PARALLEL


Parallel clarifying:   0%|          | 0/2795 [00:00<?, ?it/s]

Parallel clarifying:   2%|‚ñè         | 63/2795 [01:28<1:03:37,  1.40s/it]


In [None]:
# Step 4: Create Augmented Datasets
# ==================================

def augment_text_with_clarifications(text, entities, clarifications):
    """
    Create clarified version of text.
    Format: [START_ENT] mention [END_ENT][CLARIFY: description]

    Each entity span text[start:end] is replaced by the marked mention. The
    [CLARIFY: ...] part is appended only when `clarifications` holds a
    non-empty description for that mention string.

    Args:
        text: Original document text.
        entities: Iterable of entity dicts with 'mention', 'start', 'end'.
        clarifications: Dict mapping mention string -> clarification text.

    Returns:
        The augmented text. Entities whose offsets are missing, reversed/empty,
        or outside the original text are left unmarked.
    """
    augmented_text = text

    # Sort entities by start position (descending) so replacing one span never
    # shifts the offsets of the spans still to be processed (all to its left).
    sorted_entities = sorted(entities, key=lambda e: e.get('start') or 0, reverse=True)

    for entity in sorted_entities:
        mention = entity.get('mention', '')
        start = entity.get('start')
        end = entity.get('end')

        # Skip malformed spans. Defaulting missing offsets to 0 would insert a
        # marker at the very beginning of the text without replacing anything.
        if start is None or end is None or not (0 <= start < end <= len(text)):
            continue

        # NOTE(review): a ']' inside the clarification collides with the
        # closing delimiter of [CLARIFY: ...].
        clarification = clarifications.get(mention, "")
        if clarification:
            augmented_mention = f"[START_ENT] {mention} [END_ENT][CLARIFY: {clarification}]"
        else:
            augmented_mention = f"[START_ENT] {mention} [END_ENT]"

        augmented_text = augmented_text[:start] + augmented_mention + augmented_text[end:]

    return augmented_text


def create_baseline_text(text, entities):
    """
    Create baseline version (no clarifications).
    Format: [START_ENT] mention [END_ENT]

    Args:
        text: Original document text.
        entities: Iterable of entity dicts with 'mention', 'start', 'end'.

    Returns:
        The text with every valid entity span text[start:end] replaced by the
        marked mention. Entities whose offsets are missing, reversed/empty, or
        outside the original text are left unmarked.
    """
    baseline_text = text

    # Right-to-left so earlier offsets stay valid after each replacement.
    sorted_entities = sorted(entities, key=lambda e: e.get('start') or 0, reverse=True)

    for entity in sorted_entities:
        mention = entity.get('mention', '')
        start = entity.get('start')
        end = entity.get('end')

        # Skip malformed spans instead of inserting a marker at offset 0.
        if start is None or end is None or not (0 <= start < end <= len(text)):
            continue

        marked_mention = f"[START_ENT] {mention} [END_ENT]"
        baseline_text = baseline_text[:start] + marked_mention + baseline_text[end:]

    return baseline_text


def create_target_text(entities):
    """
    Create target text with linked entity IDs.
    Format: mention -> entity_id, entries joined by " | "

    Entities without a QID are rendered as NIL. This applies whether the key
    is missing or is present with a None/empty value.
    """
    linked_entities = []
    for entity in entities:
        mention = entity.get('mention', '')
        # `.get('qid', 'NIL')` alone would render an explicit None (as found
        # in parquet/JSON records with a null field) as the string "None".
        qid = entity.get('qid') or 'NIL'
        linked_entities.append(f"{mention} -> {qid}")

    return " | ".join(linked_entities)


def process_split_for_training(clarifications_data, split_name):
    """
    Turn per-document clarification records into seq2seq training samples.

    One sample per document is produced for each of two conditions:
    1. Baseline:  [START_ENT] mention [END_ENT] ... target_el
    2. Clarified: [START_ENT] mention [END_ENT][CLARIFY: ...] ... target_el

    Both conditions share the same target text ("mention -> qid | ..."), so
    the input augmentation is the only difference between them.

    Args:
        clarifications_data: Iterable of dicts with 'text', 'entities' and
            'clarifications' keys.
        split_name: Split label used in progress messages.

    Returns:
        (baseline_samples, clarified_samples): two lists of
        {'input_text': ..., 'target_text': ...} dicts.
    """
    print(f"\nüîß Processing {split_name} split for training...")

    task_suffix = " target_el"  # T5 task suffix appended to every input
    baseline_samples, clarified_samples = [], []

    for record in tqdm(clarifications_data, desc=f"Creating {split_name} samples"):
        doc_text = record['text']
        doc_entities = record['entities']
        doc_clarifications = record['clarifications']

        target = create_target_text(doc_entities)

        baseline_samples.append({
            'input_text': create_baseline_text(doc_text, doc_entities) + task_suffix,
            'target_text': target,
        })
        clarified_samples.append({
            'input_text': augment_text_with_clarifications(doc_text, doc_entities, doc_clarifications) + task_suffix,
            'target_text': target,
        })

    print(f"‚úì Created {len(baseline_samples)} baseline samples")
    print(f"‚úì Created {len(clarified_samples)} clarified samples")

    return baseline_samples, clarified_samples


# Process all splits
print("\n" + "="*70)
print("CREATING TRAINING DATASETS")
print("="*70)

if not train_clarifications:
    print("‚ö†Ô∏è Train clarifications not available. Skipping train dataset creation.")
    train_baseline, train_clarified = [], []
else:
    train_baseline, train_clarified = process_split_for_training(train_clarifications, 'train')

val_baseline, val_clarified = process_split_for_training(val_clarifications, 'val')
test_baseline, test_clarified = process_split_for_training(test_clarifications, 'test')


def save_samples(samples, filename):
    """Write `samples` to `filename` as JSON Lines (one JSON object per line)."""
    with open(filename, 'w') as f:
        f.writelines(f"{json.dumps(sample)}\n" for sample in samples)


# Save processed datasets
os.makedirs('../../data/experiments/processed_for_training', exist_ok=True)

if train_baseline:
    save_samples(train_baseline, '../../data/experiments/processed_for_training/train_baseline.jsonl')
    save_samples(train_clarified, '../../data/experiments/processed_for_training/train_clarified.jsonl')

save_samples(val_baseline, '../../data/experiments/processed_for_training/val_baseline.jsonl')
save_samples(val_clarified, '../../data/experiments/processed_for_training/val_clarified.jsonl')
save_samples(test_baseline, '../../data/experiments/processed_for_training/test_baseline.jsonl')
save_samples(test_clarified, '../../data/experiments/processed_for_training/test_clarified.jsonl')

print("\n‚úÖ Training datasets saved!")

In [None]:
# Step 5: Prepare T5 Training
# ============================

class EntityLinkingDataset(Dataset):
    """
    PyTorch Dataset for entity linking with T5.

    Each item is one tokenized (input_text, target_text) pair, padded and
    truncated to `max_length`. Padding positions in `labels` are replaced by
    -100 so the seq2seq loss ignores them.
    """
    def __init__(self, samples, tokenizer, max_length=512):
        # samples: list of {'input_text': str, 'target_text': str} dicts
        self.samples = samples
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample = self.samples[idx]

        encoding = self.tokenizer(
            sample['input_text'],
            text_target=sample['target_text'],
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        # Drop the batch dimension added by return_tensors='pt'. squeeze(0)
        # rather than squeeze() so a length-1 sequence stays 1-D.
        item = {key: val.squeeze(0) for key, val in encoding.items()}

        # padding='max_length' pads labels with pad_token_id. Left as-is, the
        # loss would mostly train the model to emit padding; -100 is the
        # ignore index Hugging Face seq2seq models use for labels.
        if 'labels' in item:
            labels = item['labels'].clone()
            labels[labels == self.tokenizer.pad_token_id] = -100
            item['labels'] = labels

        return item


def load_samples(filename):
    """Load JSONL samples (one JSON object per line).

    Blank or whitespace-only lines, e.g. a stray trailing newline, are
    skipped instead of crashing json.loads.
    """
    with open(filename, 'r', encoding='utf-8') as f:
        return [json.loads(line) for line in f if line.strip()]


print("\nü§ñ Preparing T5 model and tokenizer...")

# Initialize tokenizer
tokenizer = T5Tokenizer.from_pretrained('t5-base')

# Add special tokens
special_tokens = {
    'additional_special_tokens': [
        '[START_ENT]',
        '[END_ENT]',
        '[CLARIFY:',
        ']'
    ]
}
tokenizer.add_special_tokens(special_tokens)

print(f"‚úì Tokenizer ready. Vocabulary size: {len(tokenizer)}")

# Load datasets
print("\nüì• Loading training samples...")

if os.path.exists('../../data/experiments/processed_for_training/train_baseline.jsonl'):
    train_baseline_samples = load_samples('../../data/experiments/processed_for_training/train_baseline.jsonl')
    train_clarified_samples = load_samples('../../data/experiments/processed_for_training/train_clarified.jsonl')
else:
    print("‚ö†Ô∏è Train samples not found. Using val for demonstration.")
    train_baseline_samples = load_samples('../../data/experiments/processed_for_training/val_baseline.jsonl')
    train_clarified_samples = load_samples('../../data/experiments/processed_for_training/val_clarified.jsonl')

val_baseline_samples = load_samples('../../data/experiments/processed_for_training/val_baseline.jsonl')
val_clarified_samples = load_samples('../../data/experiments/processed_for_training/val_clarified.jsonl')

print(f"‚úì Train baseline: {len(train_baseline_samples)} samples")
print(f"‚úì Train clarified: {len(train_clarified_samples)} samples")
print(f"‚úì Val baseline: {len(val_baseline_samples)} samples")
print(f"‚úì Val clarified: {len(val_clarified_samples)} samples")

# Create datasets
train_baseline_dataset = EntityLinkingDataset(train_baseline_samples, tokenizer)
train_clarified_dataset = EntityLinkingDataset(train_clarified_samples, tokenizer)
val_baseline_dataset = EntityLinkingDataset(val_baseline_samples, tokenizer)
val_clarified_dataset = EntityLinkingDataset(val_clarified_samples, tokenizer)

print("\n‚úÖ Datasets ready for training!")

In [None]:
# Step 6: Train Models
# ====================

def train_entity_linking_model(train_dataset, val_dataset, model_name, output_dir,
                               base_model='t5-base', num_train_epochs=3, batch_size=8):
    """
    Train T5 model for entity linking and save it together with its tokenizer.

    Uses the module-level `tokenizer` (which carries the added entity /
    clarification special tokens) to size the model's embedding matrix, so
    `base_model` must be compatible with that tokenizer.

    Args:
        train_dataset: Training dataset
        val_dataset: Validation dataset (evaluated every 500 steps)
        model_name: Model identifier used in log messages (baseline or clarified)
        output_dir: Directory for checkpoints, logs, the final model and tokenizer
        base_model: Pretrained checkpoint to fine-tune (default 't5-base')
        num_train_epochs: Number of training epochs (default 3)
        batch_size: Per-device train and eval batch size (default 8)

    Returns:
        (trainer, model). Training uses load_best_model_at_end=True with
        metric 'eval_loss'.
    """
    print(f"\n{'='*70}")
    print(f"TRAINING: {model_name.upper()}")
    print(f"{'='*70}")

    # Initialize model
    model = T5ForConditionalGeneration.from_pretrained(base_model)
    model.resize_token_embeddings(len(tokenizer))  # Resize for special tokens

    # Training arguments
    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=num_train_epochs,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        warmup_steps=500,
        weight_decay=0.01,
        logging_dir=f'{output_dir}/logs',
        logging_steps=100,
        eval_strategy='steps',
        eval_steps=500,
        save_steps=1000,  # must stay a multiple of eval_steps for load_best_model_at_end
        save_total_limit=3,
        load_best_model_at_end=True,
        metric_for_best_model='eval_loss',
        greater_is_better=False,
        report_to='none',  # Disable wandb etc.
        fp16=torch.cuda.is_available(),  # Use mixed precision if GPU available
    )

    # Initialize trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset
    )

    # Train
    print(f"\nüöÄ Starting training...")
    print(f"   Epochs: {training_args.num_train_epochs}")
    print(f"   Batch size: {training_args.per_device_train_batch_size}")
    print(f"   Device: {'GPU' if torch.cuda.is_available() else 'CPU'}")

    trainer.train()

    # Save the model AND the tokenizer: the model's embeddings were resized for
    # the added special tokens, so reloading output_dir later needs the
    # matching tokenizer stored alongside it.
    trainer.save_model(output_dir)
    tokenizer.save_pretrained(output_dir)
    print(f"\n‚úÖ {model_name} training complete!")
    print(f"   Model saved to: {output_dir}")

    return trainer, model


# Train both models: identical recipe, only the input format differs
print("\n" + "="*70)
print("MODEL TRAINING PIPELINE")
print("="*70)

# Create output directories
os.makedirs('../../models', exist_ok=True)

# Baseline: inputs carry only the [START_ENT] / [END_ENT] markers
baseline_trainer, baseline_model = train_entity_linking_model(
    train_dataset=train_baseline_dataset,
    val_dataset=val_baseline_dataset,
    model_name='baseline',
    output_dir='../../models/t5_baseline',
)

# Clarify-and-Link: inputs additionally carry the [CLARIFY: ...] descriptions
clarified_trainer, clarified_model = train_entity_linking_model(
    train_dataset=train_clarified_dataset,
    val_dataset=val_clarified_dataset,
    model_name='clarified',
    output_dir='../../models/t5_clarified',
)

print("\n‚úÖ Both models trained successfully!")

In [None]:
# Step 7: Evaluation
# ==================

def evaluate_model(model, tokenizer, test_samples, model_name):
    """
    Evaluate entity linking model.
    
    Metrics:
    - Exact match accuracy
    - Entity-level F1
    - Per-entity-type performance
    """
    print(f"\n{'='*70}")
    print(f"EVALUATING: {model_name.upper()}")
    print(f"{'='*70}")
    
    model.eval()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    
    predictions = []
    ground_truths = []
    
    print(f"\nüîç Running inference on {len(test_samples)} samples...")
    
    for sample in tqdm(test_samples, desc="Evaluating"):
        input_text = sample['input_text']
        target_text = sample['target_text']
        
        # Generate prediction
        inputs = tokenizer(input_text, return_tensors='pt', max_length=512, truncation=True)
        inputs = {k: v.to(device) for k, v in inputs.items()}
        
        with torch.no_grad():
            outputs = model.generate(**inputs, max_length=256)
        
        prediction = tokenizer.decode(outputs[0], skip_special_tokens=True)
        
        predictions.append(prediction)
        ground_truths.append(target_text)
    
    # Calculate metrics
    exact_matches = sum(1 for p, g in zip(predictions, ground_truths) if p.strip() == g.strip())
    accuracy = exact_matches / len(predictions)
    
    print(f"\nüìä Results for {model_name}:")
    print(f"   Total samples: {len(predictions)}")
    print(f"   Exact matches: {exact_matches}")
    print(f"   Accuracy: {accuracy:.2%}")
    
    # Sample predictions
    print(f"\nüìù Sample predictions:")
    for i in range(min(3, len(predictions))):
        print(f"\n   Example {i+1}:")
        print(f"   Input: {test_samples[i]['input_text'][:100]}...")
        print(f"   Predicted: {predictions[i][:100]}...")
        print(f"   Ground truth: {ground_truths[i][:100]}...")
    
    return {
        'accuracy': accuracy,
        'exact_matches': exact_matches,
        'total_samples': len(predictions),
        'predictions': predictions,
        'ground_truths': ground_truths
    }


# Load test samples (each model is scored on inputs in its own training format)
test_baseline_samples = load_samples('../../data/experiments/processed_for_training/test_baseline.jsonl')
test_clarified_samples = load_samples('../../data/experiments/processed_for_training/test_clarified.jsonl')

# Evaluate both models
baseline_results = evaluate_model(baseline_model, tokenizer, test_baseline_samples, 'Baseline')
clarified_results = evaluate_model(clarified_model, tokenizer, test_clarified_samples, 'Clarify-and-Link')

# Absolute and relative accuracy gain; the relative gain is reported as 0
# when the baseline accuracy is 0 (avoids division by zero).
results_comparison = {
    'baseline': baseline_results,
    'clarified': clarified_results,
    'improvement': {
        'accuracy_gain': clarified_results['accuracy'] - baseline_results['accuracy'],
        'accuracy_gain_percent': (
            (clarified_results['accuracy'] - baseline_results['accuracy']) / baseline_results['accuracy'] * 100
            if baseline_results['accuracy'] > 0 else 0
        ),
    },
}

with open('../../data/experiments/evaluation_results.json', 'w') as f:
    # Persist scalar metrics only; the per-sample lists are large.
    results_to_save = {
        'baseline': {key: value for key, value in baseline_results.items()
                     if key not in ('predictions', 'ground_truths')},
        'clarified': {key: value for key, value in clarified_results.items()
                      if key not in ('predictions', 'ground_truths')},
        'improvement': results_comparison['improvement'],
    }
    json.dump(results_to_save, f, indent=2)

print("\n‚úÖ Evaluation complete! Results saved.")

In [None]:
# Step 8: Results Visualization
# ==============================

print("\n" + "="*70)
print("RESULTS VISUALIZATION")
print("="*70)

# Create comparison plot
fig, axes = plt.subplots(1, 2, figsize=(14, 6))

# Left panel: absolute accuracy of both models
models = ['Baseline', 'Clarify-and-Link']
accuracies = [baseline_results['accuracy'], clarified_results['accuracy']]
colors = ['#3498db', '#e74c3c']

axes[0].bar(models, accuracies, color=colors, alpha=0.8)
axes[0].set_ylabel('Accuracy', fontsize=12, fontweight='bold')
axes[0].set_title('Entity Linking Accuracy', fontsize=14, fontweight='bold')
axes[0].set_ylim([0, 1])
axes[0].grid(axis='y', alpha=0.3)

# Add value labels
for i, acc in enumerate(accuracies):
    axes[0].text(i, acc + 0.02, f'{acc:.2%}', ha='center', fontsize=11, fontweight='bold')

# Right panel: accuracy difference in percentage points (may be negative)
improvement = clarified_results['accuracy'] - baseline_results['accuracy']
improvement_pp = improvement * 100
is_gain = improvement_pp >= 0

axes[1].bar(['Accuracy\nImprovement'], [improvement_pp], color='#2ecc71' if is_gain else '#e74c3c', alpha=0.8)
axes[1].axhline(0, color='black', linewidth=0.8)
axes[1].set_ylabel('Percentage Points', fontsize=12, fontweight='bold')
axes[1].set_title('Clarify-and-Link Improvement', fontsize=14, fontweight='bold')
axes[1].grid(axis='y', alpha=0.3)
# The '+' format spec prints the real sign (a hard-coded '+' showed "+-x%"
# for a regression). The label goes past the bar end on the side the bar
# extends, in the same unit as the y-axis.
axes[1].text(
    0,
    improvement_pp + (0.5 if is_gain else -0.5),
    f'{improvement_pp:+.2f} pp',
    ha='center',
    va='bottom' if is_gain else 'top',
    fontsize=11,
    fontweight='bold',
)

plt.tight_layout()
plt.savefig('../../data/experiments/results_comparison.png', dpi=300, bbox_inches='tight')
plt.show()

print("\n‚úÖ Visualization saved to: data/experiments/results_comparison.png")