# RUN

In [1]:
# ============================================================================
# Cell 1: Setup and Installation
# ============================================================================

print("Installing dependencies...")

#pip install -q transformers torch pandas pyarrow tqdm accelerate

print("Installation complete!")

import torch
print(f"\n  Device: {'GPU (' + torch.cuda.get_device_name(0) + ')' if torch.cuda.is_available() else 'CPU'}")
if torch.cuda.is_available():
    print(f" GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

Installing dependencies...
Installation complete!

  Device: CPU


In [10]:
# ============================================================================
# Cell 2: Upload AIDA Data Files
# ============================================================================

from google.colab import files
import os

print("Upload your AIDA data files:")
print("   Required: train.parquet, validation.parquet, test.parquet")
print("\n   Click 'Choose Files' and select all 3 files")

os.makedirs('data/processed/aida', exist_ok=True)

uploaded = files.upload()

for filename in uploaded.keys():
    os.rename(filename, f'data/processed/aida/{filename}')
    print(f"‚úì {filename} uploaded")

Upload your AIDA data files:
   Required: train.parquet, validation.parquet, test.parquet

   Click 'Choose Files' and select all 3 files


# RUN

In [2]:
# ============================================================================
# Cell 3: Configuration
# ============================================================================

CONFIG = {
    'model_name': 'meta-llama/Llama-3.2-1B',  # Lightweight model for Colab
    ##'model_name': 'facebook/opt-1.3b',
    'batch_size': 32,  # Increase for faster processing on GPU
    'max_new_tokens': 50,
    'temperature': 0.3,
    'context_window_size': 100,
    'data_dir': 'data/processed/aida',
    'output_dir': 'data/experiments',
    'checkpoint_interval': 500,
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',
}

print("Configuration:")
for key, value in CONFIG.items():
    print(f"   {key}: {value}")

Configuration:
   model_name: meta-llama/Llama-3.2-1B
   batch_size: 32
   max_new_tokens: 50
   temperature: 0.3
   context_window_size: 100
   data_dir: data/processed/aida
   output_dir: data/experiments
   checkpoint_interval: 500
   device: cpu


# RUN

In [3]:
from huggingface_hub import login

token = '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'
login(token=token)

print("Authenticated with HuggingFace!")

Authenticated with HuggingFace!


# RUN

In [5]:
# ============================================================================
# Cell 4: Load Model and Tokenizer
# ============================================================================

import pandas as pd
import json
import os
from tqdm import tqdm
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

def load_model_and_tokenizer():
    """Load HuggingFace model once and keep in memory."""
    print("\n" + "="*70)
    print("LOADING MODEL")
    print("="*70)
    print(f"\n Model: {CONFIG['model_name']}")
    print(f"  Device: {CONFIG['device']}")

    tokenizer = AutoTokenizer.from_pretrained(CONFIG['model_name'])

    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        CONFIG['model_name'],
        torch_dtype=torch.float16 if CONFIG['device'] == 'cuda' else torch.float32,
        device_map='auto'
    )

    model.eval()

    print(f"‚úì Model loaded on {CONFIG['device']}")
    if CONFIG['device'] == 'cuda':
        print(f"‚úì GPU Memory: {torch.cuda.memory_allocated() / 1e9:.2f} GB allocated")

    return model, tokenizer

model, tokenizer = load_model_and_tokenizer()


LOADING MODEL

 Model: meta-llama/Llama-3.2-1B
  Device: cpu


Some parameters are on the meta device because they were offloaded to the cpu and disk.


‚úì Model loaded on cpu


In [None]:
# ============================================================================
# Cell 5: Helper Functions
# ============================================================================

def create_prompt(mention, context_left, context_right):
    """Create clarification prompt."""
    window_size = CONFIG['context_window_size']

    context_left = context_left[-window_size:] if len(context_left) > window_size else context_left
    context_right = context_right[:window_size] if len(context_right) > window_size else context_right

    prompt = f"""Based on this context: "{context_left} {mention} {context_right}"

Provide a brief, factual description for the entity "{mention}".
Identify what this specific mention refers to.
Use simple English (max 40 words).

Description:"""

    return prompt


def generate_clarifications_batch(model, tokenizer, batch_data):
    """
    Generate clarifications for a batch of mentions.

    Args:
        batch_data: List of (mention, context_left, context_right, normalized) tuples

    Returns:
        List of clarifications
    """
    # Create prompts
    prompts = [
        create_prompt(mention, ctx_left, ctx_right)
        for mention, ctx_left, ctx_right, _ in batch_data
    ]

    inputs = tokenizer(
        prompts,
        return_tensors='pt',
        padding=True,
        truncation=True,
        max_length=512
    ).to(CONFIG['device'])

    # Generate batch
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=CONFIG['max_new_tokens'],
            temperature=CONFIG['temperature'],
            do_sample=False,  # Deterministic
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id
        )

    # Decode batch
    clarifications = []
    for i, output in enumerate(outputs):
        # Remove prompt from output
        prompt_length = inputs['input_ids'][i].shape[0]
        generated_ids = output[prompt_length:]

        clarification = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()

        # Fallback if empty
        if not clarification:
            clarification = f"Entity: {batch_data[i][0]}"

        clarifications.append(clarification)

    return clarifications


def load_aida_data():
    """Load preprocessed AIDA train/val/test splits."""
    print("\n" + "="*70)
    print("LOADING AIDA DATA")
    print("="*70)

    data_dir = CONFIG['data_dir']
    print(f"\n Loading from: {data_dir}")

    df_train = pd.read_parquet(f'{data_dir}/train.parquet')
    df_val = pd.read_parquet(f'{data_dir}/validation.parquet')
    df_test = pd.read_parquet(f'{data_dir}/test.parquet')

    print(f"\n‚úì Train: {len(df_train)} documents")
    print(f"‚úì Validation: {len(df_val)} documents")
    print(f"‚úì Test: {len(df_test)} documents")

    # Count entities
    train_entities = sum(len(row['entities']) for _, row in df_train.iterrows())
    val_entities = sum(len(row['entities']) for _, row in df_val.iterrows())
    test_entities = sum(len(row['entities']) for _, row in df_test.iterrows())

    print(f"\n Total entities:")
    print(f"   Train: {train_entities:,}")
    print(f"   Val: {val_entities:,}")
    print(f"   Test: {test_entities:,}")

    return df_train, df_val, df_test


def collect_unique_mentions(df, split_name):
    """Collect all unique normalized mentions."""
    print(f"\nüîç Collecting unique mentions from {split_name}...")

    unique_mentions = {}
    original_case_map = {}

    for idx, row in df.iterrows():
        for entity in row['entities']:
            normalized = entity.get('normalized_mention', entity.get('mention', '').lower().strip())
            original = entity.get('mention', '')

            if normalized not in unique_mentions:
                unique_mentions[normalized] = {
                    'context_left': entity.get('context_left', ''),
                    'context_right': entity.get('context_right', '')
                }
                original_case_map[normalized] = original

    total_entities = sum(len(row['entities']) for _, row in df.iterrows())
    reduction = (1 - len(unique_mentions)/total_entities) * 100

    print(f"   Unique mentions: {len(unique_mentions)} (vs {total_entities} total)")
    print(f"   Reduction: {reduction:.1f}%")

    return unique_mentions, original_case_map

print(" Helper functions loaded!")

 Helper functions loaded!


In [None]:
# ============================================================================
# Cell 6: Main Generation Function (FIXED)
# ============================================================================

def convert_to_serializable(obj):
    """Convert numpy arrays and other non-serializable objects to Python types."""
    import numpy as np

    if isinstance(obj, np.ndarray):
        return obj.tolist()
    elif isinstance(obj, np.integer):
        return int(obj)
    elif isinstance(obj, np.floating):
        return float(obj)
    elif isinstance(obj, dict):
        return {key: convert_to_serializable(value) for key, value in obj.items()}
    elif isinstance(obj, list):
        return [convert_to_serializable(item) for item in obj]
    else:
        return obj


def generate_clarifications_for_split(model, tokenizer, df, split_name):
    """Generate clarifications for entire split using batching."""
    print("\n" + "="*70)
    print(f"PROCESSING: {split_name.upper()}")
    print("="*70)

    checkpoint_dir = f"{CONFIG['output_dir']}/clarifications_checkpoints/{split_name}"
    os.makedirs(checkpoint_dir, exist_ok=True)

    # Collect unique mentions
    unique_mentions, original_case_map = collect_unique_mentions(df, split_name)

    # Prepare batch data
    batch_data = [
        (original_case_map[norm], context['context_left'], context['context_right'], norm)
        for norm, context in unique_mentions.items()
    ]

    # Estimate time
    num_batches = len(batch_data) // CONFIG['batch_size'] + 1
    estimated_time = num_batches * 0.5 / 60  # ~0.5s per batch
    print(f"\n Batched generation:")
    print(f"   Batch size: {CONFIG['batch_size']}")
    print(f"   Total batches: {num_batches}")
    print(f"   Estimated time: {estimated_time:.1f} minutes")

    # Generate in batches
    global_clarifications = {}

    for i in tqdm(range(0, len(batch_data), CONFIG['batch_size']), desc="Generating batches"):
        batch = batch_data[i:i + CONFIG['batch_size']]

        clarifications = generate_clarifications_batch(model, tokenizer, batch)

        # Store results
        for (mention, _, _, normalized), clarification in zip(batch, clarifications):
            global_clarifications[normalized] = clarification

        # Save checkpoint
        if (i // CONFIG['batch_size'] + 1) % (CONFIG['checkpoint_interval'] // CONFIG['batch_size']) == 0:
            checkpoint_path = f'{checkpoint_dir}/checkpoint_{i + len(batch)}.json'
            with open(checkpoint_path, 'w', encoding='utf-8') as f:
                json.dump(global_clarifications, f, indent=2, ensure_ascii=False)

    # Map to documents
    print(f"\n Mapping clarifications to documents...")
    results = []

    for idx, row in df.iterrows():
        doc_clarifications = {}
        for entity in row['entities']:
            original_mention = entity['mention']
            normalized = entity.get('normalized_mention', original_mention.lower().strip())
            doc_clarifications[original_mention] = global_clarifications.get(
                normalized,
                f"Entity: {original_mention}"
            )

        # Convert entities to serializable format (FIX HERE)
        serializable_entities = [convert_to_serializable(entity) for entity in row['entities']]

        results.append({
            'doc_id': int(idx),  # Convert to int
            'text': str(row['text']),  # Ensure string
            'entities': serializable_entities,
            'clarifications': doc_clarifications
        })

    # Save final
    output_path = f"{CONFIG['output_dir']}/clarifications_{split_name}.json"
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2, ensure_ascii=False)

    print(f"\n {split_name.upper()} complete!")
    print(f"   Saved to: {output_path}")

    return results

print(" Main generation function loaded (FIXED)!")

 Main generation function loaded (FIXED)!


In [None]:
# ============================================================================
# Cell 7: Load Data
# ============================================================================

df_train, df_val, df_test = load_aida_data()


LOADING AIDA DATA

 Loading from: data/processed/aida

‚úì Train: 946 documents
‚úì Validation: 216 documents
‚úì Test: 231 documents

 Total entities:
   Train: 23,393
   Val: 5,916
   Test: 5,614


In [None]:
# ============================================================================
# Cell 8: Generate Clarifications - VALIDATION SPLIT
# ============================================================================

print("\n Starting VALIDATION split generation...")

val_clarifications = generate_clarifications_for_split(model, tokenizer, df_val, 'val')

print(f"\n Validation complete: {len(val_clarifications)} documents processed")


 Starting VALIDATION split generation...

PROCESSING: VAL

üîç Collecting unique mentions from val...
   Unique mentions: 2597 (vs 5916 total)
   Reduction: 56.1%

 Batched generation:
   Batch size: 32
   Total batches: 82
   Estimated time: 0.7 minutes


Generating batches:   0%|          | 0/82 [00:00<?, ?it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.
Generating batches:   1%|          | 1/82 [00:06<08:49,  6.54s/it]A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.
Generating batches:   2%|‚ñè         | 2/82 [00:13<08:47,  6.59s/it]A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.
Generating batches:   4%|‚ñé         | 3/82 [00:19<08:39,  6.57s/it]A decoder-only architecture is being used, but right-padding was detected! Fo


 Mapping clarifications to documents...

 VAL complete!
   Saved to: data/experiments/clarifications_val.json

 Validation complete: 216 documents processed


In [None]:
# ============================================================================
# Cell 9: Generate Clarifications - TEST SPLIT
# ============================================================================

print("\n Starting TEST split generation...")

test_clarifications = generate_clarifications_for_split(model, tokenizer, df_test, 'test')

print(f"\n Test complete: {len(test_clarifications)} documents processed")


 Starting TEST split generation...

PROCESSING: TEST

üîç Collecting unique mentions from test...
   Unique mentions: 2442 (vs 5614 total)
   Reduction: 56.5%

 Batched generation:
   Batch size: 32
   Total batches: 77
   Estimated time: 0.6 minutes


Generating batches:   0%|          | 0/77 [00:00<?, ?it/s]A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.
Generating batches:   1%|‚ñè         | 1/77 [00:01<01:59,  1.57s/it]A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.
Generating batches:   3%|‚ñé         | 2/77 [00:03<01:58,  1.58s/it]A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.
Generating batches:   4%|‚ñç         | 3/77 [00:04<01:57,  1.58s/it]A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.
Generating batches:   5%|‚ñå         | 4/77 [0


 Mapping clarifications to documents...

 TEST complete!
   Saved to: data/experiments/clarifications_test.json

 Test complete: 231 documents processed





In [None]:
# ============================================================================
# Cell 10: Generate Clarifications - TRAIN SPLIT
# ============================================================================

print("\n Starting TRAIN split generation...")

train_clarifications = generate_clarifications_for_split(model, tokenizer, df_train, 'train')

print(f"\n Train complete: {len(train_clarifications)} documents processed")

# Clear GPU cache
torch.cuda.empty_cache()
print("\n GPU memory cleared")


 Starting TRAIN split generation...

PROCESSING: TRAIN

üîç Collecting unique mentions from train...
   Unique mentions: 7542 (vs 23393 total)
   Reduction: 67.8%

 Batched generation:
   Batch size: 32
   Total batches: 236
   Estimated time: 2.0 minutes


Generating batches:   0%|          | 0/236 [00:00<?, ?it/s]A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.
Generating batches:   0%|          | 1/236 [00:01<07:31,  1.92s/it]A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.
Generating batches:   1%|          | 2/236 [00:03<06:48,  1.74s/it]A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.
Generating batches:   1%|‚ñè         | 3/236 [00:05<06:35,  1.70s/it]A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.
Generating batches:   2%|‚ñè         | 4/236 [


 Mapping clarifications to documents...

 TRAIN complete!
   Saved to: data/experiments/clarifications_train.json

 Train complete: 946 documents processed

 GPU memory cleared


In [None]:
# ============================================================================
# Cell 11: Download Results
# ============================================================================

from google.colab import files
import shutil

print(" Preparing download package...")

# Create zip file with all results
shutil.make_archive('clarifications_results', 'zip', CONFIG['output_dir'])

print("\n Downloading results...")
files.download('clarifications_results.zip')

print("\n Download complete!")
print(f"\nFiles included:")
print(f"   - clarifications_val.json")
print(f"   - clarifications_test.json")
print(f"   - clarifications_train.json")
print(f"   - checkpoints/ (backup files)")

 Preparing download package...

 Downloading results...


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>


 Download complete!

Files included:
   - clarifications_val.json
   - clarifications_test.json
   - clarifications_train.json
   - checkpoints/ (backup files)


In [None]:
# ============================================================================
# Cell 12: View Sample Results
# ============================================================================

print(" Sample Results Preview\n")

# Show sample from validation set
sample_doc = val_clarifications[0]

print(f"\n Document ID: {sample_doc['doc_id']}")
print(f"\n Text (first 200 chars):")
print(sample_doc['text'][:200] + "...")

print(f"\n  Entities and Clarifications:")
for entity in sample_doc['entities'][:3]:  # Show first 3
    mention = entity['mention']
    clarification = sample_doc['clarifications'][mention]
    print(f"\n   ‚Ä¢ {mention}")
    print(f"     ‚Üí {clarification}")

print(f"\n Statistics:")
print(f"   Total entities: {sample_doc.get('num_entities', len(sample_doc['entities']))}")
print(f"   Total clarifications: {len(sample_doc['clarifications'])}")

 Sample Results Preview


 Document ID: 0

 Text (first 200 chars):
CRICKET - LEICESTERSHIRE TAKE OVER AT TOP AFTER INNINGS VICTORY . LONDON 1996-08-30 West Indian all-rounder Phil Simmons took four for 38 on Friday as Leicestershire beat Somerset by an innings and 39...

  Entities and Clarifications:

   ‚Ä¢ LEICESTERSHIRE
     ‚Üí "LEICESTERSHIRE" is a county in the East Midlands region of England. It is bordered by Lincolnshire to the north, Rutland to the east, Northamptonshire to the south-east, and Derbyshire to the south-west

   ‚Ä¢ LONDON
     ‚Üí #LONDON# is a city in the United Kingdom. It is the capital of England and the United Kingdom. It is the largest city in the United Kingdom and the United Kingdom's most populous city. It is the most populous city in the

   ‚Ä¢ West Indian
     ‚Üí Question: What is the name of the West Indian?
Explanation: The West Indian is a region of the Caribbean Sea, which is located between the Caribbean Sea and the Atlantic Ocean. It is bor

In [None]:
# ============================================================================
# Cell 13: Create Augmented Datasets (FIXED WITH TASK PREFIX)
# ============================================================================

def create_training_samples(doc, use_clarifications=False):
    """
    Create individual training samples for each entity in the document.

    Key additions:
    - Task prefix: "link entity:"
    - Q prefix for Wikidata format
    - Better context windowing
    """
    text = doc['text']
    entities = doc['entities']
    clarifications = doc.get('clarifications', {})

    samples = []

    for entity in entities:
        mention = entity.get('mention', '')
        qid = entity.get('qid', 'NIL')
        start = entity.get('start', 0)
        end = entity.get('end', 0)

        # Skip if no valid QID
        if qid == 'NIL' or qid is None:
            continue

        # Clean QID (remove .0 decimal if present)
        qid_clean = str(qid).replace('.0', '')

        # Get context around entity (better than full text)
        context_window = 250
        context_left = text[max(0, start - context_window):start]
        context_right = text[end:min(len(text), end + context_window)]

        # Build marked entity
        if use_clarifications and mention in clarifications:
            clarification = clarifications[mention]
            marked_entity = f"[START_ENT] {mention} [END_ENT] [CLARIFY: {clarification}]"
        else:
            marked_entity = f"[START_ENT] {mention} [END_ENT]"

        # ‚úÖ ADD TASK PREFIX: This tells T5 what to do
        input_text = f"link entity: {context_left}{marked_entity}{context_right}"

        # Truncate if too long
        if len(input_text) > 512:
            input_text = input_text[:512]

        # ‚úÖ TARGET FORMAT: Q + QID (standard Wikidata format)
        target_text = f"Q{qid_clean}"

        samples.append({
            'input_text': input_text,
            'target_text': target_text
        })

    return samples


def process_split_for_training(clarifications_data, split_name):
    """
    Convert clarification data to training format.

    Creates TWO datasets:
    1. Baseline: link entity: [START_ENT]mention[END_ENT] ‚Üí Q12345
    2. Clarified: link entity: [START_ENT]mention[END_ENT][CLARIFY:...] ‚Üí Q12345
    """
    print(f"\nüîß Processing {split_name} split for training...")

    baseline_samples = []
    clarified_samples = []

    for doc in tqdm(clarifications_data, desc=f"Creating {split_name} samples"):
        # Create baseline samples (without clarifications)
        baseline_samples.extend(create_training_samples(doc, use_clarifications=False))

        # Create clarified samples (with clarifications)
        clarified_samples.extend(create_training_samples(doc, use_clarifications=True))

    print(f"‚úì Created {len(baseline_samples)} baseline samples")
    print(f"‚úì Created {len(clarified_samples)} clarified samples")

    return baseline_samples, clarified_samples


print("\n" + "="*70)
print("CREATING TRAINING DATASETS (FIXED)")
print("="*70)

# Process all splits
train_baseline, train_clarified = process_split_for_training(train_clarifications, 'train')
val_baseline, val_clarified = process_split_for_training(val_clarifications, 'val')
test_baseline, test_clarified = process_split_for_training(test_clarifications, 'test')

# Save processed datasets
os.makedirs('data/experiments/processed_for_training', exist_ok=True)

def save_samples(samples, filename):
    with open(filename, 'w', encoding='utf-8') as f:
        for sample in samples:
            f.write(json.dumps(sample, ensure_ascii=False) + '\n')

save_samples(train_baseline, 'data/experiments/processed_for_training/train_baseline.jsonl')
save_samples(train_clarified, 'data/experiments/processed_for_training/train_clarified.jsonl')
save_samples(val_baseline, 'data/experiments/processed_for_training/val_baseline.jsonl')
save_samples(val_clarified, 'data/experiments/processed_for_training/val_clarified.jsonl')
save_samples(test_baseline, 'data/experiments/processed_for_training/test_baseline.jsonl')
save_samples(test_clarified, 'data/experiments/processed_for_training/test_clarified.jsonl')

print("\n‚úÖ Training datasets saved!")

# Preview samples
print("\nüìã Sample Preview:")
print("\n1Ô∏è‚É£ Baseline sample:")
print(f"   Input: {train_baseline[0]['input_text'][:150]}...")
print(f"   Target: {train_baseline[0]['target_text']}")

print("\n2Ô∏è‚É£ Clarified sample:")
print(f"   Input: {train_clarified[0]['input_text'][:150]}...")
print(f"   Target: {train_clarified[0]['target_text']}")

print("\nüìä Dataset Statistics:")
print(f"   Train baseline: {len(train_baseline)} samples")
print(f"   Train clarified: {len(train_clarified)} samples")
print(f"   Val baseline: {len(val_baseline)} samples")
print(f"   Val clarified: {len(val_clarified)} samples")
print(f"   Test baseline: {len(test_baseline)} samples")
print(f"   Test clarified: {len(test_clarified)} samples")


CREATING TRAINING DATASETS (FIXED)

üîß Processing train split for training...


Creating train samples: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 946/946 [00:00<00:00, 11561.59it/s]


‚úì Created 18541 baseline samples
‚úì Created 18541 clarified samples

üîß Processing val split for training...


Creating val samples: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 216/216 [00:00<00:00, 9331.72it/s]


‚úì Created 4791 baseline samples
‚úì Created 4791 clarified samples

üîß Processing test split for training...


Creating test samples: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 231/231 [00:00<00:00, 10398.21it/s]

‚úì Created 4483 baseline samples
‚úì Created 4483 clarified samples






‚úÖ Training datasets saved!

üìã Sample Preview:

1Ô∏è‚É£ Baseline sample:
   Input: link entity: EU rejects [START_ENT] German [END_ENT] call to boycott British lamb . Peter Blackburn BRUSSELS 1996-08-22 The European Commission said o...
   Target: Q183

2Ô∏è‚É£ Clarified sample:
   Input: link entity: EU rejects [START_ENT] German [END_ENT] [CLARIFY: # German is a language spoken by 100 million people in Germany and 100 million people i...
   Target: Q183

üìä Dataset Statistics:
   Train baseline: 18541 samples
   Train clarified: 18541 samples
   Val baseline: 4791 samples
   Val clarified: 4791 samples
   Test baseline: 4483 samples
   Test clarified: 4483 samples


In [None]:
# ============================================================================
# Cell 13.5: QUICK TEST MODE - Use Small Subset
# ============================================================================

# ‚úÖ ENABLE THIS FOR FAST TESTING
QUICK_TEST_MODE = False  # Set to False for full training

if QUICK_TEST_MODE:
    print("\n" + "="*70)
    print("‚ö° QUICK TEST MODE ENABLED")
    print("="*70)
    print("\nUsing small subsets for fast validation:")

    # Use only first 50 samples from each split
    test_size = 50

    train_baseline = train_baseline[:test_size]
    train_clarified = train_clarified[:test_size]
    val_baseline = val_baseline[:test_size]
    val_clarified = val_clarified[:test_size]
    test_baseline = test_baseline[:test_size]
    test_clarified = test_clarified[:test_size]

    print(f"   Train samples: {len(train_baseline)}")
    print(f"   Val samples: {len(val_baseline)}")
    print(f"   Test samples: {len(test_baseline)}")
    print(f"\n   Estimated time: ~5-10 minutes total")
    print(f"   (vs ~2 hours for full dataset)")

else:
    print("\n FULL DATASET MODE")
    print(f"   Train: {len(train_baseline)} samples")
    print(f"   Val: {len(val_baseline)} samples")
    print(f"   Test: {len(test_baseline)} samples")


 FULL DATASET MODE
   Train: 18541 samples
   Val: 4791 samples
   Test: 4483 samples


# RUN

In [6]:
# ============================================================================
# Cell 14: Prepare T5 Training (FIXED)
# ============================================================================

from torch.utils.data import Dataset
from transformers import T5Tokenizer, T5ForConditionalGeneration, Trainer, TrainingArguments

class EntityLinkingDataset(Dataset):
    """PyTorch Dataset for entity linking with T5."""
    def __init__(self, samples, tokenizer, max_length=512):
        self.samples = samples
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample = self.samples[idx]

        encoding = self.tokenizer(
            sample['input_text'],
            text_target=sample['target_text'],
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        return {key: val.squeeze() for key, val in encoding.items()}


def load_samples(filename):
    """Load JSONL samples."""
    samples = []
    with open(filename, 'r', encoding='utf-8') as f:
        for line in f:
            samples.append(json.loads(line))
    return samples


print("\n Preparing T5 model and tokenizer...")

# Initialize T5 tokenizer (NEW - separate from clarification tokenizer)
t5_tokenizer = T5Tokenizer.from_pretrained('t5-base')

# Add special tokens
special_tokens = {
    'additional_special_tokens': [
        '[START_ENT]',
        '[END_ENT]',
        '[CLARIFY:',
        ']'
    ]
}
t5_tokenizer.add_special_tokens(special_tokens)

print(f"‚úì T5 Tokenizer ready. Vocabulary size: {len(t5_tokenizer)}")


train_baseline = load_samples('./data/processed/aida/clarifications_results/processed_for_training/train_baseline.jsonl')
train_clarified = load_samples('./data/processed/aida/clarifications_results/processed_for_training/train_clarified.jsonl')
val_baseline = load_samples('./data/processed/aida/clarifications_results/processed_for_training/val_baseline.jsonl')
val_clarified = load_samples('./data/processed/aida/clarifications_results/processed_for_training/val_clarified.jsonl')

# Create datasets using T5 tokenizer
train_baseline_dataset = EntityLinkingDataset(train_baseline, t5_tokenizer)
train_clarified_dataset = EntityLinkingDataset(train_clarified, t5_tokenizer)
val_baseline_dataset = EntityLinkingDataset(val_baseline, t5_tokenizer)
val_clarified_dataset = EntityLinkingDataset(val_clarified, t5_tokenizer)

print(f"‚úì Train baseline: {len(train_baseline_dataset)} samples")
print(f"‚úì Train clarified: {len(train_clarified_dataset)} samples")
print(f"‚úì Val baseline: {len(val_baseline_dataset)} samples")
print(f"‚úì Val clarified: {len(val_clarified_dataset)} samples")

print("\n Datasets ready for training!")


 Preparing T5 model and tokenizer...
‚úì T5 Tokenizer ready. Vocabulary size: 32103
‚úì Train baseline: 18541 samples
‚úì Train clarified: 18541 samples
‚úì Val baseline: 4791 samples
‚úì Val clarified: 4791 samples

 Datasets ready for training!


In [None]:
# # ============================================================================
# # Cell 14.5: Clear GPU Memory Before Training (FIXED)
# # ============================================================================

# import gc

# print(" Clearing GPU memory...")

# # Delete ONLY the clarification generation model (NOT t5_tokenizer)
# if 'model' in globals():
#     del model
#     print("‚úì Clarification model deleted")

# # Note: We keep t5_tokenizer - it's needed for training!

# # Force garbage collection
# gc.collect()

# # Clear PyTorch cache
# torch.cuda.empty_cache()

# # Check memory
# if torch.cuda.is_available():
#     print(f"\n‚úì GPU Memory freed")
#     print(f"  Allocated: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
#     print(f"  Reserved: {torch.cuda.memory_reserved() / 1e9:.2f} GB")
#     print(f"  Free: {torch.cuda.get_device_properties(0).total_memory / 1e9 - torch.cuda.memory_allocated() / 1e9:.2f} GB")

# print("\n Ready for training!")

 Clearing GPU memory...
‚úì Clarification model deleted

 Ready for training!


# RUN

In [7]:
# ============================================================================
# Cell 14.6: Create Small Subset for Quick Experiment (10% of data)
# ============================================================================

import random

print("\n" + "="*70)
print("CREATING SMALL SUBSET FOR QUICK EXPERIMENT")
print("="*70)

# Set seed for reproducibility
random.seed(42)

# Calculate subset sizes (10% of original)
subset_percentage = 0.10

print(f"\nüìä Original dataset sizes:")
print(f"   Train baseline: {len(train_baseline)} samples")
print(f"   Train clarified: {len(train_clarified)} samples")
print(f"   Val baseline: {len(val_baseline)} samples")
print(f"   Val clarified: {len(val_clarified)} samples")

# Create small subsets (10%)
train_baseline_small = random.sample(train_baseline, int(len(train_baseline) * subset_percentage))
train_clarified_small = random.sample(train_clarified, int(len(train_clarified) * subset_percentage))
val_baseline_small = random.sample(val_baseline, int(len(val_baseline) * subset_percentage))
val_clarified_small = random.sample(val_clarified, int(len(val_clarified) * subset_percentage))

print(f"\n‚úÇÔ∏è  Subset sizes ({subset_percentage*100:.0f}% of original):")
print(f"   Train baseline: {len(train_baseline_small)} samples")
print(f"   Train clarified: {len(train_clarified_small)} samples")
print(f"   Val baseline: {len(val_baseline_small)} samples")
print(f"   Val clarified: {len(val_clarified_small)} samples")

# Create PyTorch datasets from subsets
train_baseline_dataset = EntityLinkingDataset(train_baseline_small, t5_tokenizer)
train_clarified_dataset = EntityLinkingDataset(train_clarified_small, t5_tokenizer)
val_baseline_dataset = EntityLinkingDataset(val_baseline_small, t5_tokenizer)
val_clarified_dataset = EntityLinkingDataset(val_clarified_small, t5_tokenizer)

# Estimate new training time
samples = len(train_baseline_dataset)
batch_size = 8
gradient_accumulation = 2
effective_batch = batch_size * gradient_accumulation
epochs = 3  # Reduced epochs for quick experiment

steps_per_epoch = samples // effective_batch
total_steps = steps_per_epoch * epochs
estimated_minutes = total_steps * 0.5 / 60  # CPU: 0.5s per step

print(f"\n‚è±Ô∏è  Estimated training time:")
print(f"   Steps per epoch: {steps_per_epoch}")
print(f"   Total steps: {total_steps}")
print(f"   Estimated time per model: {estimated_minutes:.1f} minutes")
print(f"   Total time (both models): {estimated_minutes * 2:.1f} minutes")

print(f"\n‚úÖ Small datasets ready for quick training!")
print(f"üí° This is perfect for experimentation and testing!")


CREATING SMALL SUBSET FOR QUICK EXPERIMENT

üìä Original dataset sizes:
   Train baseline: 18541 samples
   Train clarified: 18541 samples
   Val baseline: 4791 samples
   Val clarified: 4791 samples

‚úÇÔ∏è  Subset sizes (10% of original):
   Train baseline: 1854 samples
   Train clarified: 1854 samples
   Val baseline: 479 samples
   Val clarified: 479 samples

‚è±Ô∏è  Estimated training time:
   Steps per epoch: 115
   Total steps: 345
   Estimated time per model: 2.9 minutes
   Total time (both models): 5.8 minutes

‚úÖ Small datasets ready for quick training!
üí° This is perfect for experimentation and testing!


# RUN

In [None]:
# ============================================================================
# Cell 15: Train Baseline Model (FAST - SMALL DATASET)
# ============================================================================

import torch
from transformers import (
    T5ForConditionalGeneration,
    Trainer,
    TrainingArguments,
    TrainerCallback,
    EarlyStoppingCallback
)
import time
import json

# ============================================================================
# PROGRESS CALLBACK
# ============================================================================

class QuickProgressCallback(TrainerCallback):
    """Lightweight progress display for quick experiments."""

    def __init__(self):
        self.start_time = None
        self.best_loss = float('inf')

    def on_train_begin(self, args, state, control, **kwargs):
        self.start_time = time.time()
        print(f"\n{'='*70}")
        print(f"üöÄ QUICK TRAINING STARTED (Small Dataset)")
        print(f"{'='*70}\n")

    def on_epoch_end(self, args, state, control, **kwargs):
        epoch = int(state.epoch)

        # Get eval loss
        eval_loss = None
        for log in reversed(state.log_history):
            if 'eval_loss' in log:
                eval_loss = log['eval_loss']
                break

        if eval_loss:
            improvement = "üìà NEW BEST!" if eval_loss < self.best_loss else ""
            self.best_loss = min(self.best_loss, eval_loss)
            elapsed = time.time() - self.start_time
            print(f"   Epoch {epoch}: Val Loss = {eval_loss:.4f} | Time: {elapsed/60:.1f}min {improvement}")

    def on_train_end(self, args, state, control, **kwargs):
        total_time = time.time() - self.start_time
        print(f"\n‚úÖ Training completed in {total_time/60:.1f} minutes")
        print(f"   Best validation loss: {self.best_loss:.4f}")


# ============================================================================
# FAST TRAINING FUNCTION
# ============================================================================

def train_baseline_quick(train_dataset, val_dataset, output_dir='models/t5_baseline'):
    """Ultra-fast training for quick experiments."""

    print(f"\n{'='*70}")
    print(f"TRAINING: BASELINE MODEL (QUICK EXPERIMENT)")
    print(f"{'='*70}")

    # Load model
    print("\nüì¶ Loading T5-base...")
    model = T5ForConditionalGeneration.from_pretrained('t5-base')
    model.resize_token_embeddings(len(t5_tokenizer))

    # Training config
    total_samples = len(train_dataset)
    batch_size = 8
    gradient_accumulation = 2
    num_epochs = 3  # ‚ö° Only 3 epochs for quick test

    steps_per_epoch = total_samples // (batch_size * gradient_accumulation)
    total_steps = steps_per_epoch * num_epochs

    print(f"\n‚öôÔ∏è  Quick Training Config:")
    print(f"   Dataset size: {total_samples:,} samples (10% of full data)")
    print(f"   Batch size: {batch_size} (effective: {batch_size * gradient_accumulation})")
    print(f"   Epochs: {num_epochs}")
    print(f"   Total steps: {total_steps}")
    print(f"   ‚è±Ô∏è  Estimated time: ~{total_steps * 0.5 / 60:.1f} minutes")

    # Training arguments (optimized for speed)
    training_args = TrainingArguments(
        output_dir=output_dir,

        # ‚ö° FAST TRAINING
        num_train_epochs=num_epochs,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        gradient_accumulation_steps=gradient_accumulation,

        # ‚ö° LEARNING RATE
        learning_rate=5e-5,
        warmup_ratio=0.1,
        lr_scheduler_type='linear',  # Faster than cosine

        # ‚ö° EVALUATION
        eval_strategy='epoch',
        save_strategy='epoch',
        save_total_limit=1,  # Only keep best checkpoint
        load_best_model_at_end=True,
        metric_for_best_model='eval_loss',

        # ‚ö° LOGGING (minimal)
        logging_strategy='epoch',  # Only log per epoch
        logging_dir=f'{output_dir}/logs',

        # ‚ö° SPEED OPTIMIZATIONS
        fp16=False,  # Disable for CPU (faster on CPU)
        dataloader_num_workers=0,  # Disable workers on CPU
        dataloader_pin_memory=False,

        # OTHER
        weight_decay=0.01,
        report_to='none',
        disable_tqdm=False,  # Show progress bar
        seed=42,
    )

    # Early stopping (stop after 2 epochs if no improvement)
    early_stopping = EarlyStoppingCallback(
        early_stopping_patience=2,
        early_stopping_threshold=0.01
    )

    # Progress callback
    progress_callback = QuickProgressCallback()

    # Initialize trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        callbacks=[early_stopping, progress_callback]
    )

    # Train
    print(f"\nüöÄ Starting training...")
    trainer.train()

    # Save
    print(f"\nüíæ Saving model...")
    trainer.save_model(output_dir)
    t5_tokenizer.save_pretrained(f"{output_dir}/tokenizer")

    # Save metrics
    metrics = {
        'dataset_size': total_samples,
        'subset_percentage': 10,
        'epochs_completed': int(trainer.state.epoch),
        'best_eval_loss': trainer.state.best_metric,
        'training_time_minutes': (time.time() - progress_callback.start_time) / 60
    }

    with open(f"{output_dir}/metrics.json", 'w') as f:
        json.dump(metrics, f, indent=2)

    print(f"\n{'='*70}")
    print(f"‚úÖ BASELINE TRAINING COMPLETE")
    print(f"{'='*70}")
    print(f"   Best validation loss: {trainer.state.best_metric:.4f}")
    print(f"   Training time: {metrics['training_time_minutes']:.1f} minutes")

    return trainer, model


# ============================================================================
# TRAIN BASELINE MODEL
# ============================================================================

print("\nüéØ (1) TRAINING BASELINE MODEL ON SMALL DATASET")

baseline_trainer, baseline_model = train_baseline_quick(
    train_baseline_dataset,
    val_baseline_dataset,
    output_dir='models/t5_baseline'
)

# Clear memory
import gc
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

print("\n‚úÖ Baseline model training complete!")
print("   Ready to train clarified model next...")


üéØ (1) TRAINING BASELINE MODEL ON SMALL DATASET

TRAINING: BASELINE MODEL (QUICK EXPERIMENT)

üì¶ Loading T5-base...

‚öôÔ∏è  Quick Training Config:
   Dataset size: 1,854 samples (10% of full data)
   Batch size: 8 (effective: 16)
   Epochs: 3
   Total steps: 345
   ‚è±Ô∏è  Estimated time: ~2.9 minutes

üöÄ Starting training...

üöÄ QUICK TRAINING STARTED (Small Dataset)



Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.


In [None]:
# # ============================================================================
# # Cell 15: Train Baseline Model (OPTIMIZED WITH PROGRESS MONITORING)
# # ============================================================================

# import torch
# from torch.utils.data import Dataset
# from transformers import (
#     T5ForConditionalGeneration,
#     Trainer,
#     TrainingArguments,
#     TrainerCallback,
#     EarlyStoppingCallback
# )
# from tqdm.auto import tqdm
# import time
# import json
# import os

# # ============================================================================
# # CUSTOM CALLBACK: Real-time Progress Display
# # ============================================================================

# class ProgressCallback(TrainerCallback):
#     """Display training progress with time estimates."""

#     def __init__(self):
#         self.start_time = None
#         self.epoch_start_time = None
#         self.best_loss = float('inf')

#     def on_train_begin(self, args, state, control, **kwargs):
#         self.start_time = time.time()
#         print(f"\n{'='*70}")
#         print(f"üöÄ TRAINING STARTED")
#         print(f"{'='*70}\n")

#     def on_epoch_begin(self, args, state, control, **kwargs):
#         self.epoch_start_time = time.time()
#         epoch = int(state.epoch) if state.epoch else 0
#         print(f"\nüìç Epoch {epoch + 1}/{args.num_train_epochs}")

#     def on_epoch_end(self, args, state, control, **kwargs):
#         epoch_time = time.time() - self.epoch_start_time
#         epoch = int(state.epoch)

#         # Get latest eval loss
#         eval_loss = None
#         for log in reversed(state.log_history):
#             if 'eval_loss' in log:
#                 eval_loss = log['eval_loss']
#                 break

#         print(f"   ‚è±Ô∏è  Epoch {epoch} completed in {epoch_time/60:.1f} minutes")
#         if eval_loss:
#             improvement = "üìà NEW BEST!" if eval_loss < self.best_loss else ""
#             self.best_loss = min(self.best_loss, eval_loss)
#             print(f"   üìä Validation Loss: {eval_loss:.4f} {improvement}")

#     def on_log(self, args, state, control, logs=None, **kwargs):
#         """Display training loss every 50 steps."""
#         if logs and 'loss' in logs and state.global_step % 50 == 0:
#             elapsed = time.time() - self.start_time
#             steps_remaining = state.max_steps - state.global_step
#             time_per_step = elapsed / state.global_step if state.global_step > 0 else 0
#             eta = steps_remaining * time_per_step / 60

#             print(f"   Step {state.global_step}/{state.max_steps} | "
#                   f"Loss: {logs['loss']:.4f} | "
#                   f"ETA: {eta:.1f}min")

#     def on_train_end(self, args, state, control, **kwargs):
#         total_time = time.time() - self.start_time
#         print(f"\n{'='*70}")
#         print(f"‚úÖ TRAINING COMPLETED")
#         print(f"{'='*70}")
#         print(f"   Total time: {total_time/60:.1f} minutes")
#         print(f"   Best validation loss: {self.best_loss:.4f}")


# # ============================================================================
# # OPTIMIZED TRAINING FUNCTION
# # ============================================================================

# def train_baseline_model_optimized(train_dataset, val_dataset, output_dir='models/t5_baseline'):
#     """
#     Train T5 model with MAXIMUM optimization for speed.

#     Optimizations:
#     - Larger batch size (8 vs 4)
#     - Gradient accumulation (2 steps = effective batch 16)
#     - Mixed precision (FP16)
#     - Optimized data loading (4 workers + pin memory)
#     - Fewer epochs with early stopping
#     - Efficient checkpointing
#     """

#     print(f"\n{'='*70}")
#     print(f"TRAINING: BASELINE MODEL (OPTIMIZED)")
#     print(f"{'='*70}")

#     # ============================================================================
#     # 1. INITIALIZE MODEL
#     # ============================================================================

#     print("\nüì¶ Loading T5-base model...")
#     model = T5ForConditionalGeneration.from_pretrained('t5-base')
#     model.resize_token_embeddings(len(t5_tokenizer))

#     # ‚ö° OPTIMIZATION: Enable gradient checkpointing to save memory
#     model.gradient_checkpointing_enable()

#     print(f"‚úì Model ready with {sum(p.numel() for p in model.parameters()):,} parameters")

#     # ============================================================================
#     # 2. CALCULATE TRAINING PARAMETERS
#     # ============================================================================

#     total_samples = len(train_dataset)
#     batch_size = 8  # ‚ö° Larger batch size for speed
#     gradient_accumulation = 2  # Effective batch = 16
#     effective_batch_size = batch_size * gradient_accumulation

#     num_epochs = 5  # ‚ö° Reduced from 10 to 5 (early stopping will handle it)
#     steps_per_epoch = total_samples // effective_batch_size
#     total_steps = steps_per_epoch * num_epochs

#     print(f"\n‚öôÔ∏è  Training Configuration:")
#     print(f"   Dataset size: {total_samples:,} samples")
#     print(f"   Batch size: {batch_size} (effective: {effective_batch_size})")
#     print(f"   Epochs: {num_epochs}")
#     print(f"   Steps per epoch: {steps_per_epoch}")
#     print(f"   Total steps: {total_steps}")
#     print(f"   Device: {'GPU' if torch.cuda.is_available() else 'CPU'}")

#     # Estimate time (0.3s per step for optimized training)
#     estimated_minutes = total_steps * 0.3 / 60
#     print(f"   ‚è±Ô∏è  Estimated time: {estimated_minutes:.1f} minutes")

#     # ============================================================================
#     # 3. TRAINING ARGUMENTS (OPTIMIZED)
#     # ============================================================================

#     training_args = TrainingArguments(
#         output_dir=output_dir,

#         # ‚ö° TRAINING SPEED OPTIMIZATIONS
#         num_train_epochs=num_epochs,
#         per_device_train_batch_size=batch_size,
#         per_device_eval_batch_size=batch_size,
#         gradient_accumulation_steps=gradient_accumulation,

#         # ‚ö° LEARNING RATE (slightly higher for faster convergence)
#         learning_rate=5e-5,
#         warmup_ratio=0.1,
#         lr_scheduler_type='cosine',

#         # ‚ö° MEMORY & SPEED
#         fp16=torch.cuda.is_available(),  # Mixed precision
#         dataloader_num_workers=4,  # ‚ö° Increased from 2 to 4
#         dataloader_pin_memory=True,
#         gradient_checkpointing=True,  # ‚ö° Save memory

#         # üìä EVALUATION & CHECKPOINTING
#         eval_strategy='epoch',
#         save_strategy='epoch',
#         save_total_limit=2,  # ‚ö° Keep only 2 best checkpoints
#         load_best_model_at_end=True,
#         metric_for_best_model='eval_loss',
#         greater_is_better=False,

#         # üìù LOGGING (optimized frequency)
#         logging_dir=f'{output_dir}/logs',
#         logging_strategy='steps',
#         logging_steps=50,
#         logging_first_step=True,

#         # ‚öôÔ∏è OTHER
#         weight_decay=0.01,
#         max_grad_norm=1.0,
#         report_to='none',
#         seed=42,

#         # ‚ö° DISABLE UNNECESSARY FEATURES
#         push_to_hub=False,
#         disable_tqdm=True,  # We use custom progress display
#     )

#     # ============================================================================
#     # 4. INITIALIZE TRAINER WITH CALLBACKS
#     # ============================================================================

#     print("\nüéØ Initializing trainer with callbacks...")

#     # Early stopping: stop if no improvement for 2 epochs
#     early_stopping = EarlyStoppingCallback(
#         early_stopping_patience=2,
#         early_stopping_threshold=0.001
#     )

#     # Custom progress display
#     progress_callback = ProgressCallback()

#     trainer = Trainer(
#         model=model,
#         args=training_args,
#         train_dataset=train_dataset,
#         eval_dataset=val_dataset,
#         callbacks=[early_stopping, progress_callback]
#     )

#     # ============================================================================
#     # 5. TRAIN MODEL
#     # ============================================================================

#     print(f"\n{'='*70}")
#     print(f"üöÄ STARTING OPTIMIZED TRAINING")
#     print(f"{'='*70}\n")

#     # Start training
#     train_result = trainer.train()

#     # ============================================================================
#     # 6. SAVE MODEL & RESULTS
#     # ============================================================================

#     print(f"\nüíæ Saving model and tokenizer...")
#     trainer.save_model(output_dir)
#     t5_tokenizer.save_pretrained(f"{output_dir}/tokenizer")

#     # Save training metrics
#     metrics = {
#         'final_train_loss': train_result.training_loss,
#         'best_eval_loss': trainer.state.best_metric,
#         'total_steps': trainer.state.global_step,
#         'epochs_completed': int(trainer.state.epoch),
#         'training_time_minutes': train_result.metrics['train_runtime'] / 60
#     }

#     with open(f"{output_dir}/training_metrics.json", 'w') as f:
#         json.dump(metrics, f, indent=2)

#     print(f"\n{'='*70}")
#     print(f"‚úÖ BASELINE MODEL TRAINING COMPLETE")
#     print(f"{'='*70}")
#     print(f"   Model saved to: {output_dir}")
#     print(f"   Best validation loss: {trainer.state.best_metric:.4f}")
#     print(f"   Training time: {metrics['training_time_minutes']:.1f} minutes")
#     print(f"   Epochs completed: {metrics['epochs_completed']}")

#     return trainer, model


# # ============================================================================
# # EXECUTE TRAINING
# # ============================================================================

# # Clear GPU memory before starting
# if torch.cuda.is_available():
#     torch.cuda.empty_cache()
#     print(f"\nüßπ GPU cache cleared")
#     print(f"üìä GPU Memory: {torch.cuda.memory_allocated() / 1e9:.2f} GB / "
#           f"{torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

# # Train baseline model
# baseline_trainer, baseline_model = train_baseline_model_optimized(
#     train_baseline_dataset,
#     val_baseline_dataset,
#     output_dir='models/t5_baseline'
# )

# # Clear cache after training
# if torch.cuda.is_available():
#     torch.cuda.empty_cache()
#     print(f"\nüßπ GPU cache cleared after training")

# print("\n‚úÖ Ready for next step: Train clarified model!")

In [None]:
# ============================================================================
# Cell 15.5: Visualize Baseline Training Progress
# ============================================================================

import matplotlib.pyplot as plt
import pandas as pd

print("\n" + "="*70)
print("BASELINE MODEL TRAINING VISUALIZATION")
print("="*70)

# Extract training history from trainer
log_history = baseline_trainer.state.log_history

# Separate training and validation logs
train_logs = [log for log in log_history if 'loss' in log and 'eval_loss' not in log]
eval_logs = [log for log in log_history if 'eval_loss' in log]

# Extract data
train_steps = [log['step'] for log in train_logs]
train_loss = [log['loss'] for log in train_logs]
eval_steps = [log['step'] for log in eval_logs]
eval_loss = [log['eval_loss'] for log in eval_logs]

# Create plot
fig, ax = plt.subplots(1, 1, figsize=(12, 6))

# Plot training loss
ax.plot(train_steps, train_loss, label='Training Loss', linewidth=2, color='#3498db', alpha=0.8)

# Plot validation loss
ax.plot(eval_steps, eval_loss, label='Validation Loss', linewidth=2.5, color='#e74c3c', marker='o', markersize=6)

# Formatting
ax.set_xlabel('Training Steps', fontsize=13, fontweight='bold')
ax.set_ylabel('Loss', fontsize=13, fontweight='bold')
ax.set_title('Baseline Model Training Progress', fontsize=15, fontweight='bold', pad=15)
ax.legend(fontsize=11, loc='upper right')
ax.grid(alpha=0.3, linestyle='--')
ax.set_ylim(bottom=0)

# Add annotations for best validation loss
best_eval_idx = eval_loss.index(min(eval_loss))
best_eval_step = eval_steps[best_eval_idx]
best_eval_loss = eval_loss[best_eval_idx]

ax.annotate(f'Best: {best_eval_loss:.4f}',
            xy=(best_eval_step, best_eval_loss),
            xytext=(best_eval_step, best_eval_loss + max(eval_loss) * 0.1),
            arrowprops=dict(arrowstyle='->', color='red', lw=2),
            fontsize=11, fontweight='bold', color='red',
            bbox=dict(boxstyle='round,pad=0.5', facecolor='yellow', alpha=0.7))

plt.tight_layout()
plt.savefig('models/t5_baseline/training_curve.png', dpi=300, bbox_inches='tight')
plt.show()

# Print statistics
print(f"\nüìä Training Statistics:")
print(f"   Initial training loss: {train_loss[0]:.4f}")
print(f"   Final training loss: {train_loss[-1]:.4f}")
print(f"   Training loss reduction: {train_loss[0] - train_loss[-1]:.4f}")
print(f"\n   Initial validation loss: {eval_loss[0]:.4f}")
print(f"   Best validation loss: {best_eval_loss:.4f}")
print(f"   Validation loss reduction: {eval_loss[0] - best_eval_loss:.4f}")
print(f"   Best model at step: {best_eval_step}")

print(f"\n‚úÖ Baseline training curve saved to: models/t5_baseline/training_curve.png")

In [None]:
# ============================================================================
# Cell 16: Train Clarified Model (Clarify-and-Link)
# ============================================================================

print("\n (2) CLARIFY-AND-LINK MODEL")
clarified_trainer, clarified_model = train_entity_linking_model(
    train_clarified_dataset,
    val_clarified_dataset,
    'clarified',
    'models/t5_clarified'
)

# Clear cache
torch.cuda.empty_cache()

print("\n Both models trained successfully!")

In [None]:
# ============================================================================
# Cell 16.5: Visualize Clarified Model Training Progress
# ============================================================================

import matplotlib.pyplot as plt

print("\n" + "="*70)
print("CLARIFY-AND-LINK MODEL TRAINING VISUALIZATION")
print("="*70)

# Extract training history
log_history = clarified_trainer.state.log_history

# Separate logs
train_logs = [log for log in log_history if 'loss' in log and 'eval_loss' not in log]
eval_logs = [log for log in log_history if 'eval_loss' in log]

# Extract data
train_steps = [log['step'] for log in train_logs]
train_loss = [log['loss'] for log in train_logs]
eval_steps = [log['step'] for log in eval_logs]
eval_loss = [log['eval_loss'] for log in eval_logs]

# Create plot
fig, ax = plt.subplots(1, 1, figsize=(12, 6))

# Plot training loss
ax.plot(train_steps, train_loss, label='Training Loss', linewidth=2, color='#3498db', alpha=0.8)

# Plot validation loss
ax.plot(eval_steps, eval_loss, label='Validation Loss', linewidth=2.5, color='#e74c3c', marker='o', markersize=6)

# Formatting
ax.set_xlabel('Training Steps', fontsize=13, fontweight='bold')
ax.set_ylabel('Loss', fontsize=13, fontweight='bold')
ax.set_title('Clarify-and-Link Model Training Progress', fontsize=15, fontweight='bold', pad=15)
ax.legend(fontsize=11, loc='upper right')
ax.grid(alpha=0.3, linestyle='--')
ax.set_ylim(bottom=0)

# Add annotations
best_eval_idx = eval_loss.index(min(eval_loss))
best_eval_step = eval_steps[best_eval_idx]
best_eval_loss = eval_loss[best_eval_idx]

ax.annotate(f'Best: {best_eval_loss:.4f}',
            xy=(best_eval_step, best_eval_loss),
            xytext=(best_eval_step, best_eval_loss + max(eval_loss) * 0.1),
            arrowprops=dict(arrowstyle='->', color='red', lw=2),
            fontsize=11, fontweight='bold', color='red',
            bbox=dict(boxstyle='round,pad=0.5', facecolor='yellow', alpha=0.7))

plt.tight_layout()
plt.savefig('models/t5_clarified/training_curve.png', dpi=300, bbox_inches='tight')
plt.show()

# Print statistics
print(f"\nüìä Training Statistics:")
print(f"   Initial training loss: {train_loss[0]:.4f}")
print(f"   Final training loss: {train_loss[-1]:.4f}")
print(f"   Training loss reduction: {train_loss[0] - train_loss[-1]:.4f}")
print(f"\n   Initial validation loss: {eval_loss[0]:.4f}")
print(f"   Best validation loss: {best_eval_loss:.4f}")
print(f"   Validation loss reduction: {eval_loss[0] - best_eval_loss:.4f}")
print(f"   Best model at step: {best_eval_step}")

print(f"\n‚úÖ Clarified training curve saved to: models/t5_clarified/training_curve.png")

In [None]:
# ============================================================================
# Cell 16.7: Compare Baseline vs Clarified Training
# ============================================================================

import matplotlib.pyplot as plt

print("\n" + "="*70)
print("TRAINING COMPARISON: BASELINE VS CLARIFY-AND-LINK")
print("="*70)

# Extract baseline logs
baseline_history = baseline_trainer.state.log_history
baseline_eval = [log for log in baseline_history if 'eval_loss' in log]
baseline_eval_steps = [log['step'] for log in baseline_eval]
baseline_eval_loss = [log['eval_loss'] for log in baseline_eval]

# Extract clarified logs
clarified_history = clarified_trainer.state.log_history
clarified_eval = [log for log in clarified_history if 'eval_loss' in log]
clarified_eval_steps = [log['step'] for log in clarified_eval]
clarified_eval_loss = [log['eval_loss'] for log in clarified_eval]

# Create comparison plot
fig, ax = plt.subplots(1, 1, figsize=(14, 7))

# Plot both validation losses
ax.plot(baseline_eval_steps, baseline_eval_loss,
        label='Baseline', linewidth=2.5, color='#3498db', marker='o', markersize=7)
ax.plot(clarified_eval_steps, clarified_eval_loss,
        label='Clarify-and-Link', linewidth=2.5, color='#e74c3c', marker='s', markersize=7)

# Formatting
ax.set_xlabel('Training Steps', fontsize=13, fontweight='bold')
ax.set_ylabel('Validation Loss', fontsize=13, fontweight='bold')
ax.set_title('Training Comparison: Baseline vs Clarify-and-Link', fontsize=15, fontweight='bold', pad=15)
ax.legend(fontsize=12, loc='upper right')
ax.grid(alpha=0.3, linestyle='--')
ax.set_ylim(bottom=0)

# Add final loss comparison
final_baseline = min(baseline_eval_loss)
final_clarified = min(clarified_eval_loss)

# Add text box with comparison
textstr = f'Best Validation Loss:\n'
textstr += f'Baseline: {final_baseline:.4f}\n'
textstr += f'Clarified: {final_clarified:.4f}\n'
textstr += f'Improvement: {final_baseline - final_clarified:.4f}'

props = dict(boxstyle='round', facecolor='wheat', alpha=0.8)
ax.text(0.02, 0.98, textstr, transform=ax.transAxes, fontsize=11,
        verticalalignment='top', bbox=props, fontfamily='monospace')

plt.tight_layout()
plt.savefig('data/experiments/training_comparison.png', dpi=300, bbox_inches='tight')
plt.show()

print(f"\nüìä Final Comparison:")
print(f"   Baseline best loss: {final_baseline:.4f}")
print(f"   Clarified best loss: {final_clarified:.4f}")
print(f"   Loss improvement: {final_baseline - final_clarified:.4f}")

if final_clarified < final_baseline:
    improvement = ((final_baseline - final_clarified) / final_baseline) * 100
    print(f"   Relative improvement: {improvement:.1f}% better ‚úÖ")
else:
    print(f"   Note: Clarified model has higher loss")

print(f"\n‚úÖ Comparison plot saved to: data/experiments/training_comparison.png")

In [None]:
# ============================================================================
# Cell 17: Evaluation (UPDATED FOR NEW FORMAT)
# ============================================================================

def evaluate_model(model, tokenizer, test_samples, model_name):
    """Evaluate entity linking model."""
    print(f"\n{'='*70}")
    print(f"EVALUATING: {model_name.upper()}")
    print(f"{'='*70}")

    model.eval()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)

    predictions = []
    ground_truths = []
    correct = 0

    print(f"\nüîç Running inference on {len(test_samples)} samples...")

    for sample in tqdm(test_samples, desc="Evaluating"):
        input_text = sample['input_text']
        target_text = sample['target_text']

        # Generate prediction
        inputs = tokenizer(input_text, return_tensors='pt', max_length=512, truncation=True)
        inputs = {k: v.to(device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = model.generate(**inputs, max_length=20)  # ‚úÖ Shorter for QIDs

        prediction = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()

        predictions.append(prediction)
        ground_truths.append(target_text)

        # ‚úÖ Check if correct (flexible matching)
        # Handles: "Q170566" == "Q170566" or "170566" == "Q170566"
        pred_clean = prediction.replace('Q', '').replace('.0', '')
        truth_clean = target_text.replace('Q', '').replace('.0', '')

        if pred_clean == truth_clean:
            correct += 1

    accuracy = correct / len(predictions) if len(predictions) > 0 else 0

    print(f"\nüìä Results for {model_name}:")
    print(f"   Total samples: {len(predictions)}")
    print(f"   Correct: {correct}")
    print(f"   Accuracy: {accuracy:.2%}")

    # Show sample predictions
    print(f"\nüìù Sample predictions:")
    for i in range(min(5, len(predictions))):
        pred_clean = predictions[i].replace('Q', '').replace('.0', '')
        truth_clean = ground_truths[i].replace('Q', '').replace('.0', '')
        match = "‚úÖ" if pred_clean == truth_clean else "‚ùå"

        print(f"\n   {match} Example {i+1}:")
        print(f"      Input: {test_samples[i]['input_text'][:100]}...")
        print(f"      Predicted: {predictions[i]}")
        print(f"      Expected: {ground_truths[i]}")

    return {
        'accuracy': accuracy,
        'correct': correct,
        'total_samples': len(predictions),
        'predictions': predictions,
        'ground_truths': ground_truths
    }


# Evaluate both models
baseline_results = evaluate_model(baseline_model, t5_tokenizer, test_baseline, 'Baseline')
clarified_results = evaluate_model(clarified_model, t5_tokenizer, test_clarified, 'Clarify-and-Link')

# Save results
results_comparison = {
    'baseline': {
        'accuracy': baseline_results['accuracy'],
        'correct': baseline_results['correct'],
        'total_samples': baseline_results['total_samples']
    },
    'clarified': {
        'accuracy': clarified_results['accuracy'],
        'correct': clarified_results['correct'],
        'total_samples': clarified_results['total_samples']
    },
    'improvement': {
        'accuracy_gain': clarified_results['accuracy'] - baseline_results['accuracy'],
        'accuracy_gain_percent': ((clarified_results['accuracy'] - baseline_results['accuracy']) / baseline_results['accuracy'] * 100) if baseline_results['accuracy'] > 0 else 0
    }
}

with open('data/experiments/evaluation_results.json', 'w') as f:
    json.dump(results_comparison, f, indent=2)

print("\n‚úÖ Evaluation complete! Results saved.")

In [None]:
# ============================================================================
# Cell 17.5: Debug - Check What Model Is Actually Predicting
# ============================================================================

print("\n" + "="*70)
print("DEBUG: CHECKING MODEL PREDICTIONS")
print("="*70)

# Test a few samples manually
print("\nüîç Testing Baseline Model on 5 samples:\n")

baseline_model.eval()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
baseline_model.to(device)

for i in range(min(5, len(test_baseline))):
    sample = test_baseline[i]

    print(f"üìù Sample {i+1}:")
    print(f"   Input text: {sample['input_text'][:120]}...")
    print(f"   Expected output: {sample['target_text']}")

    # Generate prediction
    inputs = t5_tokenizer(sample['input_text'], return_tensors='pt', max_length=512, truncation=True)
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        # Try different generation settings
        outputs = t5_tokenizer.batch_decode(
            baseline_model.generate(
                **inputs,
                max_new_tokens=50,
                num_beams=1,
                do_sample=False
            ),
            skip_special_tokens=True
        )

    prediction = outputs[0].strip()
    print(f"   Model prediction: '{prediction}'")
    print(f"   Prediction length: {len(prediction)}")
    print(f"   Match: {'‚úÖ YES' if prediction.replace('Q','').replace('.0','') == sample['target_text'].replace('Q','').replace('.0','') else '‚ùå NO'}")
    print()

print("\nüîç Checking data format:")
print(f"   First input starts with: {test_baseline[0]['input_text'][:50]}")
print(f"   First target: {test_baseline[0]['target_text']}")

In [None]:
# ============================================================================
# Cell 18: Results Visualization (FIXED)
# ============================================================================

import matplotlib.pyplot as plt
import seaborn as sns

print("\n" + "="*70)
print("RESULTS VISUALIZATION")
print("="*70)

# Set style
sns.set_style("whitegrid")
plt.rcParams['figure.figsize'] = (14, 6)
plt.rcParams['font.size'] = 11

# Create comparison plot
fig, axes = plt.subplots(1, 2, figsize=(15, 6))  # ‚úÖ Slightly larger

# Accuracy comparison
models = ['Baseline', 'Clarify-and-Link']
accuracies = [baseline_results['accuracy'], clarified_results['accuracy']]
colors = ['#3498db', '#e74c3c']

axes[0].bar(models, accuracies, color=colors, alpha=0.8, width=0.6)
axes[0].set_ylabel('Accuracy', fontsize=13, fontweight='bold')
axes[0].set_title('Entity Linking Accuracy', fontsize=15, fontweight='bold', pad=20)  # ‚úÖ Add padding
axes[0].set_ylim([0, min(1.0, max(accuracies) * 1.3)])  # ‚úÖ Dynamic y-limit
axes[0].grid(axis='y', alpha=0.3)

# Add value labels on bars
for i, (model, acc) in enumerate(zip(models, accuracies)):
    axes[0].text(i, acc + 0.02, f'{acc:.1%}', ha='center', fontsize=12, fontweight='bold')

# Improvement visualization
improvement = clarified_results['accuracy'] - baseline_results['accuracy']
improvement_percent = improvement * 100

axes[1].bar(['Accuracy\nImprovement'], [improvement_percent], color='#2ecc71', alpha=0.8, width=0.5)
axes[1].set_ylabel('Percentage Points', fontsize=13, fontweight='bold')
axes[1].set_title('Clarify-and-Link Improvement', fontsize=15, fontweight='bold', pad=20)  # ‚úÖ Add padding
axes[1].grid(axis='y', alpha=0.3)
axes[1].axhline(y=0, color='black', linestyle='-', linewidth=0.8)  # Add zero line

# Add value label
label_y = improvement_percent + (0.5 if improvement_percent > 0 else -0.5)
axes[1].text(0, label_y, f'+{improvement:.1%}' if improvement >= 0 else f'{improvement:.1%}',
             ha='center', fontsize=12, fontweight='bold')

# ‚úÖ FIX: Use constrained_layout instead of tight_layout
plt.subplots_adjust(left=0.08, right=0.95, top=0.88, bottom=0.12, wspace=0.25)

plt.savefig('data/experiments/results_comparison.png', dpi=300, bbox_inches='tight')
plt.show()

print("\n‚úÖ Visualization saved!")

# Print summary
print(f"\nüìä Final Results:")
print(f"   Baseline Accuracy: {baseline_results['accuracy']:.2%}")
print(f"   Clarify-and-Link Accuracy: {clarified_results['accuracy']:.2%}")
print(f"   Improvement: {'+' if improvement >= 0 else ''}{improvement:.2%}")
print(f"   Improvement (percentage points): {'+' if improvement_percent >= 0 else ''}{improvement_percent:.2f}pp")

# Additional statistics
if baseline_results['accuracy'] > 0:
    relative_improvement = (improvement / baseline_results['accuracy']) * 100
    print(f"   Relative improvement: {relative_improvement:.1f}%")

print(f"\nüìà Sample Counts:")
print(f"   Baseline: {baseline_results['correct']}/{baseline_results['total_samples']} correct")
print(f"   Clarified: {clarified_results['correct']}/{clarified_results['total_samples']} correct")

In [None]:
# # ============================================================================
# # Cell 19: Download Everything
# # ============================================================================

# from google.colab import files
# import shutil

# print("üì¶ Creating final package...")

# # Create archive with all results
# shutil.make_archive('clarify_and_link_complete', 'zip', 'data/experiments')
# shutil.make_archive('trained_models', 'zip', 'models')

# print("\nüì• Downloading results...")
# files.download('clarify_and_link_complete.zip')
# files.download('trained_models.zip')

# print("\n‚úÖ Download complete!")
# print("\nPackage contents:")
# print("  clarify_and_link_complete.zip:")
# print("    - clarifications_train.json")
# print("    - clarifications_val.json")
# print("    - clarifications_test.json")
# print("    - processed_for_training/ (6 JSONL files)")
# print("    - evaluation_results.json")
# print("    - results_comparison.png")
# print("\n  trained_models.zip:")
# print("    - t5_baseline/ (baseline model)")
# print("    - t5_clarified/ (clarify-and-link model)")