## Cell 1: Setup and Installation



In [None]:
print("Installing dependencies...")

# Uncomment in a fresh environment to install the required packages:
# %pip install -q transformers torch pandas pyarrow tqdm accelerate

print("Installation complete!")

import json
import os

import pandas as pd
import torch
from huggingface_hub import login
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    T5ForConditionalGeneration,
    T5Tokenizer,
    Trainer,
    TrainingArguments,
)

import Utils as u  # project-local helpers (model loading, generation, dataset building)

# Report which device this notebook will run on.
if torch.cuda.is_available():
    print(f"\n  Device: GPU ({torch.cuda.get_device_name(0)})")
    print(f" GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
else:
    print("\n  Device: CPU")

Installing dependencies...
Installation complete!

  Device: CPU


## Load AIDA Dataset

**Purpose:** Load preprocessed AIDA-CoNLL entity linking dataset splits.

**What this does:**
- Loads train, validation, and test splits from parquet files
- Displays document counts for each split
- Counts total entities across all splits

**Dataset statistics:**
- Train: 946 documents
- Validation: 216 documents  
- Test: 231 documents

**Data format:** Each document contains `text` and `entities` with position, mention, and QID annotations.

In [None]:
print("\n Loading preprocessed AIDA data...")

# NOTE(review): this path is relative to the notebook's working directory and
# differs from CONFIG['data_dir'] ('data/processed/aida') used by Utils later.
AIDA_DIR = '../../data/processed/aida'

df_train = pd.read_parquet(f'{AIDA_DIR}/train.parquet')
df_val = pd.read_parquet(f'{AIDA_DIR}/validation.parquet')
df_test = pd.read_parquet(f'{AIDA_DIR}/test.parquet')

print(f"\n✓ Train: {len(df_train)} documents")
print(f"✓ Validation: {len(df_val)} documents")
print(f"✓ Test: {len(df_test)} documents")

# Count total entities: each row's `entities` value holds that document's
# entity annotations. Summing per-row lengths avoids a slow iterrows() loop;
# int() turns the numpy scalar into a plain int for the ',' formatting below.
train_entities = int(df_train['entities'].map(len).sum())
val_entities = int(df_val['entities'].map(len).sum())
test_entities = int(df_test['entities'].map(len).sum())

print("\n Total entities:")
print(f"   Train: {train_entities:,}")
print(f"   Val: {val_entities:,}")
print(f"   Test: {test_entities:,}")
print(f"   Total: {train_entities + val_entities + test_entities:,}")

Upload your AIDA data files:
   Required: train.parquet, validation.parquet, test.parquet

   Click 'Choose Files' and select all 3 files


## Configuration

**Purpose:** Set experiment configuration parameters.



**Device:** Automatically selects CUDA if available, otherwise CPU.

In [None]:
# Experiment settings read by later cells (prompt window, output directory, ...).
CONFIG = dict(
    model_name='meta-llama/Llama-3.2-1B',
    batch_size=32,  # Increase for faster processing on GPU
    max_new_tokens=50,
    temperature=0.3,
    context_window_size=100,
    data_dir='data/processed/aida',
    output_dir='data/experiments',
    checkpoint_interval=500,
    device='cuda' if torch.cuda.is_available() else 'cpu',
)

print("Configuration:")
print("\n".join(f"   {name}: {setting}" for name, setting in CONFIG.items()))

Configuration:
   model_name: meta-llama/Llama-3.2-1B
   batch_size: 32
   max_new_tokens: 50
   temperature: 0.3
   context_window_size: 100
   data_dir: data/processed/aida
   output_dir: data/experiments
   checkpoint_interval: 500
   device: cpu


## HuggingFace Authentication

**Purpose:** Authenticate with HuggingFace Hub to access Llama model.

**What this does:**
- Logs in using HuggingFace API token
- Required to download gated models like Llama-3.2-1B

**Note:** Token must have access permissions for the specified model.

In [None]:
# Read the Hugging Face access token from the environment rather than
# hard-coding it: a token written into a notebook is effectively public.
# NOTE(review): the token previously hard-coded in this cell was exposed and
# should be revoked at https://huggingface.co/settings/tokens.
token = os.environ.get('HF_TOKEN')
if token:
    login(token=token)
else:
    # No HF_TOKEN set: let huggingface_hub prompt for the token interactively.
    login()

print("Authenticated with HuggingFace!")

Authenticated with HuggingFace!


## Load Llama Model for Clarification Generation

**Purpose:** Load Llama-3.2-1B model and tokenizer for generating entity clarifications.

**Utils function used:** `load_model_and_tokenizer()`

**What this function does:**
- Loads AutoModelForCausalLM with float16 precision for GPU
- Configures tokenizer with pad_token
- Sets model to evaluation mode
- Uses device_map='auto' for multi-GPU support
- Displays GPU memory allocation

**Output:** Returns `(model, tokenizer)` tuple ready for batch generation.

In [None]:
model, tokenizer = u.load_model_and_tokenizer()

# Llama is decoder-only: in a padded batch every prompt must end at the same
# position so generation continues right after it, which requires LEFT padding.
# The generation logs warned "right-padding was detected" on every batch.
tokenizer.padding_side = 'left'


LOADING MODEL

 Model: meta-llama/Llama-3.2-1B
  Device: cpu


Some parameters are on the meta device because they were offloaded to the cpu and disk.


✓ Model loaded on cpu


## Prompt Creation Function

**Purpose:** Define prompt template for entity clarification generation.

**Utils function reference:** `create_prompt(mention, context_left, context_right)`

**What this does:**
- Truncates context to configured window size (100 chars)
- Formats prompt with mention and surrounding context
- Requests brief, factual description (max 40 words)
- Instructs model to identify what the mention refers to

**Prompt structure:**
```
Based on this context: "[left] mention [right]"
Provide a brief, factual description for the entity "mention".
```

In [None]:
def create_prompt(mention, context_left, context_right, window_size=None):
    """Build the clarification prompt for a single entity mention.

    Args:
        mention: Surface form of the entity as it appears in the document.
        context_left: Document text immediately preceding the mention.
        context_right: Document text immediately following the mention.
        window_size: Maximum number of context characters kept on each side.
            Defaults to ``CONFIG['context_window_size']`` when omitted.

    Returns:
        The prompt string. It ends with ``"Description:"`` so the model
        continues with the description itself.
    """
    if window_size is None:
        window_size = CONFIG['context_window_size']
    window_size = max(int(window_size), 0)

    # Keep only the characters closest to the mention. An explicit start index
    # is used instead of ``[-window_size:]`` because ``s[-0:]`` would return the
    # whole string when window_size is 0.
    context_left = context_left[max(len(context_left) - window_size, 0):]
    context_right = context_right[:window_size]

    prompt = f"""Based on this context: "{context_left} {mention} {context_right}"

Provide a brief, factual description for the entity "{mention}".
Identify what this specific mention refers to.
Use simple English (max 40 words).

Description:"""

    return prompt



 Helper functions loaded!


## Load AIDA Data (Alternative)

**Purpose:** Load AIDA dataset using utility function wrapper.

**Utils function used:** `load_aida_data()`

**What this function does:**
- Reads parquet files from `CONFIG['data_dir']`
- Returns `(df_train, df_val, df_test)` tuple
- Prints dataset statistics with document and entity counts
- Validates data structure with 'text' and 'entities' columns

**Output:** Three DataFrames ready for clarification generation.

In [None]:
# Re-read the three AIDA splits through the shared Utils helper; this rebinds
# the DataFrames that were loaded directly from parquet in the earlier cell.
(df_train,
 df_val,
 df_test) = u.load_aida_data()


LOADING AIDA DATA

 Loading from: data/processed/aida

✓ Train: 946 documents
✓ Validation: 216 documents
✓ Test: 231 documents

 Total entities:
   Train: 23,393
   Val: 5,916
   Test: 5,614


## Generate Clarifications - Validation Split

**Purpose:** Generate entity clarifications for validation set using Llama model.

**Utils function used:** `generate_clarifications_for_split(model, tokenizer, df_val, 'val')`

**What this function does:**
1. **Collects unique mentions** using `collect_unique_mentions()` to deduplicate repeated mentions (56–68% fewer generations on the AIDA splits, per the logs below)
2. **Batches mentions** in groups of 32 for efficient GPU utilization
3. **Generates clarifications** using `generate_clarifications_batch()` with prompt template
4. **Maps results** back to original documents
5. **Saves checkpoints** every 500 batches to prevent data loss
6. **Final output:** JSON file with clarifications mapped to documents

**Processing time:** ~2-3 minutes for validation set (216 documents)

**Output file:** `data/experiments/clarifications_val.json`

In [None]:
# Run batched Llama clarification generation over the validation split.
val_clarifications = u.generate_clarifications_for_split(
    model, tokenizer, df_val, 'val'
)

print("\n Validation complete: {} documents processed".format(len(val_clarifications)))


 Starting VALIDATION split generation...

PROCESSING: VAL

🔍 Collecting unique mentions from val...
   Unique mentions: 2597 (vs 5916 total)
   Reduction: 56.1%

 Batched generation:
   Batch size: 32
   Total batches: 82
   Estimated time: 0.7 minutes


Generating batches:   0%|          | 0/82 [00:00<?, ?it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.
Generating batches:   1%|          | 1/82 [00:06<08:49,  6.54s/it]A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.
Generating batches:   2%|‚ñè         | 2/82 [00:13<08:47,  6.59s/it]A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.
Generating batches:   4%|‚ñé         | 3/82 [00:19<08:39,  6.57s/it]A decoder-only architecture is being used, but right-padding was detected! Fo


 Mapping clarifications to documents...

 VAL complete!
   Saved to: data/experiments/clarifications_val.json

 Validation complete: 216 documents processed


## Generate Clarifications - Test Split

**Purpose:** Generate entity clarifications for test set using Llama model.

**Utils function used:** `generate_clarifications_for_split(model, tokenizer, df_test, 'test')`

**What this does:**
- Same pipeline as validation split
- Processes 231 test documents
- Uses batched generation for efficiency
- Saves periodic checkpoints during processing

**Processing time:** ~2-3 minutes for test set

**Output file:** `data/experiments/clarifications_test.json`

In [None]:
# Run batched Llama clarification generation over the test split.
test_clarifications = u.generate_clarifications_for_split(
    model, tokenizer, df_test, 'test'
)

print("\n Test complete: {} documents processed".format(len(test_clarifications)))


 Starting TEST split generation...

PROCESSING: TEST

🔍 Collecting unique mentions from test...
   Unique mentions: 2442 (vs 5614 total)
   Reduction: 56.5%

 Batched generation:
   Batch size: 32
   Total batches: 77
   Estimated time: 0.6 minutes


Generating batches:   0%|          | 0/77 [00:00<?, ?it/s]A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.
Generating batches:   1%|‚ñè         | 1/77 [00:01<01:59,  1.57s/it]A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.
Generating batches:   3%|‚ñé         | 2/77 [00:03<01:58,  1.58s/it]A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.
Generating batches:   4%|‚ñç         | 3/77 [00:04<01:57,  1.58s/it]A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.
Generating batches:   5%|‚ñå         | 4/77 [0


 Mapping clarifications to documents...

 TEST complete!
   Saved to: data/experiments/clarifications_test.json

 Test complete: 231 documents processed





## Generate Clarifications - Train Split

**Purpose:** Generate entity clarifications for training set using Llama model.

**Utils function used:** `generate_clarifications_for_split(model, tokenizer, df_train, 'train')`

**What this does:**
- Processes 946 training documents (largest split)
- Deduplicates mentions before generation
- Batched processing with progress tracking
- Clears GPU cache after completion to free memory

**Processing time:** ~10-15 minutes for training set

**Output file:** `data/experiments/clarifications_train.json`

**Memory management:** Calls `torch.cuda.empty_cache()` to release GPU memory for next stage.

In [None]:
train_clarifications = u.generate_clarifications_for_split(model, tokenizer, df_train, 'train')

print(f"\n Train complete: {len(train_clarifications)} documents processed")

# Release cached GPU memory before the T5 stage. This only matters (and is
# only worth reporting) when CUDA is in use; on CPU there is nothing to clear.
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    print("\n GPU memory cleared")


 Starting TRAIN split generation...

PROCESSING: TRAIN

🔍 Collecting unique mentions from train...
   Unique mentions: 7542 (vs 23393 total)
   Reduction: 67.8%

 Batched generation:
   Batch size: 32
   Total batches: 236
   Estimated time: 2.0 minutes


Generating batches:   0%|          | 0/236 [00:00<?, ?it/s]A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.
Generating batches:   0%|          | 1/236 [00:01<07:31,  1.92s/it]A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.
Generating batches:   1%|          | 2/236 [00:03<06:48,  1.74s/it]A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.
Generating batches:   1%|‚ñè         | 3/236 [00:05<06:35,  1.70s/it]A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.
Generating batches:   2%|‚ñè         | 4/236 [


 Mapping clarifications to documents...

 TRAIN complete!
   Saved to: data/experiments/clarifications_train.json

 Train complete: 946 documents processed

 GPU memory cleared


## Save Clarification Results

**Purpose:** Save generated clarifications to JSON files for persistence.

**What this does:**
- Creates output directory if it doesn't exist
- Saves clarifications for all three splits as separate JSON files
- Uses UTF-8 encoding and indentation for readability

**Utils function used:** `convert_to_serializable()` (called internally)
- Converts numpy arrays and pandas types to JSON-serializable Python types

**Output files:**
- `data/experiments/clarifications_val.json`
- `data/experiments/clarifications_test.json`
- `data/experiments/clarifications_train.json`

In [None]:
os.makedirs(CONFIG['output_dir'], exist_ok=True)


def _json_default(obj):
    """Convert numpy/pandas values (arrays, scalars, Series) to plain Python.

    ``json.dump`` calls this only for objects it cannot serialize natively.
    Data derived from parquet may contain numpy arrays, which ``json`` rejects.
    Objects without a ``tolist()`` method still raise ``TypeError``, as
    ``json`` would by default.
    """
    if hasattr(obj, 'tolist'):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


def save_clarifications(clarifications_data, filename, output_dir=None):
    """Write clarification records to ``<output_dir>/<filename>`` as UTF-8 JSON.

    Args:
        clarifications_data: JSON-compatible records; numpy values are
            converted to lists/scalars via ``_json_default``.
        filename: File name inside ``output_dir``.
        output_dir: Target directory. Defaults to ``CONFIG['output_dir']`` and
            is created if missing.

    Returns:
        The path of the file that was written.
    """
    if output_dir is None:
        output_dir = CONFIG['output_dir']
    os.makedirs(output_dir, exist_ok=True)
    filepath = os.path.join(output_dir, filename)
    with open(filepath, 'w', encoding='utf-8') as f:
        json.dump(clarifications_data, f, indent=2, ensure_ascii=False, default=_json_default)
    print(f"✓ Saved: {filepath}")
    return filepath


# Save all splits
save_clarifications(val_clarifications, 'clarifications_val.json')
save_clarifications(test_clarifications, 'clarifications_test.json')
save_clarifications(train_clarifications, 'clarifications_train.json')

print(f"\n All clarifications saved to: {CONFIG['output_dir']}")

 Preparing download package...

 Downloading results...


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>


 Download complete!

Files included:
   - clarifications_val.json
   - clarifications_test.json
   - clarifications_train.json
   - checkpoints/ (backup files)


## Preview Sample Results

**Purpose:** Display example clarifications to verify quality.

**What this does:**
- Shows first document from validation set
- Displays document text (first 200 characters)
- Lists first 3 entities with their generated clarifications
- Shows statistics (total entities and clarifications)

**Verification:** Ensures clarifications are meaningful and contextually appropriate.

**Example output:**
```
• Japan
  → Japan is an island country in East Asia...
```

In [None]:

# Show the first validation document together with a few of its clarifications.
sample_doc = val_clarifications[0]
MAX_PREVIEW_ENTITIES = 3

print(f"\n Document ID: {sample_doc['doc_id']}")
print("\n Text (first 200 chars):")
print(sample_doc['text'][:200] + "...")

print("\n  Entities and Clarifications:")
sample_clarifications = sample_doc['clarifications']
for entity in sample_doc['entities'][:MAX_PREVIEW_ENTITIES]:
    mention = entity['mention']
    # A mention without a generated clarification should not abort the preview.
    clarification = sample_clarifications.get(mention, '<no clarification>')
    print(f"\n   • {mention}")
    print(f"     → {clarification}")

print("\n Statistics:")
print(f"   Total entities: {sample_doc.get('num_entities', len(sample_doc['entities']))}")
print(f"   Total clarifications: {len(sample_clarifications)}")

 Sample Results Preview


 Document ID: 0

 Text (first 200 chars):
CRICKET - LEICESTERSHIRE TAKE OVER AT TOP AFTER INNINGS VICTORY . LONDON 1996-08-30 West Indian all-rounder Phil Simmons took four for 38 on Friday as Leicestershire beat Somerset by an innings and 39...

  Entities and Clarifications:

   • LEICESTERSHIRE
     → "LEICESTERSHIRE" is a county in the East Midlands region of England. It is bordered by Lincolnshire to the north, Rutland to the east, Northamptonshire to the south-east, and Derbyshire to the south-west

   • LONDON
     → #LONDON# is a city in the United Kingdom. It is the capital of England and the United Kingdom. It is the largest city in the United Kingdom and the United Kingdom's most populous city. It is the most populous city in the

   • West Indian
     → Question: What is the name of the West Indian?
Explanation: The West Indian is a region of the Caribbean Sea, which is located between the Caribbean Sea and the Atlantic Ocean. It is bor

## Create T5 Training Datasets

**Purpose:** Convert clarifications to T5 input/output format for model training.

**Utils functions used:**
1. **`process_split_for_training(clarifications_data, split_name)`**
   - Main orchestration function
   - Creates both baseline and clarified versions
   
2. **`create_training_samples(doc, use_clarifications=False)`** (called internally)
   - Converts documents to T5 format
   - Adds task prefix: `"link entity:"`
   - Formats mentions with `[START_ENT]` and `[END_ENT]` markers
   - Optionally appends clarifications: `[CLARIFY: description]`
   - Targets are Wikidata QIDs (e.g., `"Q170566"`)

**What this creates:**
- **Baseline samples:** Only entity markers, no clarifications
- **Clarified samples:** Entity markers + clarification text

**Output:** 6 JSONL files saved to `data/experiments/processed_for_training/`:
- `train_baseline.jsonl`, `train_clarified.jsonl`
- `val_baseline.jsonl`, `val_clarified.jsonl`
- `test_baseline.jsonl`, `test_clarified.jsonl`

**Sample format:**
```
Input: link entity: context [START_ENT]Japan[END_ENT][CLARIFY: Japan is an island country...]
Target: Q170566
```

In [None]:
print("CREATING TRAINING DATASETS (FIXED)")


# Build (baseline, clarified) T5 sample lists for every split.
train_baseline, train_clarified = u.process_split_for_training(train_clarifications, 'train')
val_baseline, val_clarified = u.process_split_for_training(val_clarifications, 'val')
test_baseline, test_clarified = u.process_split_for_training(test_clarifications, 'test')

# Persist them as JSON Lines: one sample object per line.
PROCESSED_FOR_TRAINING_DIR = 'data/experiments/processed_for_training'
os.makedirs(PROCESSED_FOR_TRAINING_DIR, exist_ok=True)


def save_samples(samples, filename):
    """Write ``samples`` to ``filename`` as UTF-8 JSON Lines."""
    with open(filename, 'w', encoding='utf-8') as out:
        out.writelines(json.dumps(record, ensure_ascii=False) + '\n' for record in samples)


for split_name, baseline_samples, clarified_samples in (
    ('train', train_baseline, train_clarified),
    ('val', val_baseline, val_clarified),
    ('test', test_baseline, test_clarified),
):
    save_samples(baseline_samples, f'{PROCESSED_FOR_TRAINING_DIR}/{split_name}_baseline.jsonl')
    save_samples(clarified_samples, f'{PROCESSED_FOR_TRAINING_DIR}/{split_name}_clarified.jsonl')

print("\n Training datasets saved!")

# Preview the first sample of each training variant.
print("\n Sample Preview:")
for heading, variant_samples in (
    ("\n Baseline sample:", train_baseline),
    ("\n Clarified sample:", train_clarified),
):
    first = variant_samples[0]
    print(heading)
    print(f"   Input: {first['input_text'][:150]}...")
    print(f"   Target: {first['target_text']}")

print("\n Dataset Statistics:")
for label, variant_samples in (
    ("Train baseline", train_baseline),
    ("Train clarified", train_clarified),
    ("Val baseline", val_baseline),
    ("Val clarified", val_clarified),
    ("Test baseline", test_baseline),
    ("Test clarified", test_clarified),
):
    print(f"   {label}: {len(variant_samples)} samples")


CREATING TRAINING DATASETS (FIXED)

🔧 Processing train split for training...


Creating train samples: 100%|██████████| 946/946 [00:00<00:00, 11561.59it/s]


✓ Created 18541 baseline samples
✓ Created 18541 clarified samples

🔧 Processing val split for training...


Creating val samples: 100%|██████████| 216/216 [00:00<00:00, 9331.72it/s]


✓ Created 4791 baseline samples
✓ Created 4791 clarified samples

🔧 Processing test split for training...


Creating test samples: 100%|██████████| 231/231 [00:00<00:00, 10398.21it/s]

✓ Created 4483 baseline samples
✓ Created 4483 clarified samples






✅ Training datasets saved!

📋 Sample Preview:

1️⃣ Baseline sample:
   Input: link entity: EU rejects [START_ENT] German [END_ENT] call to boycott British lamb . Peter Blackburn BRUSSELS 1996-08-22 The European Commission said o...
   Target: Q183

2️⃣ Clarified sample:
   Input: link entity: EU rejects [START_ENT] German [END_ENT] [CLARIFY: # German is a language spoken by 100 million people in Germany and 100 million people i...
   Target: Q183

📊 Dataset Statistics:
   Train baseline: 18541 samples
   Train clarified: 18541 samples
   Val baseline: 4791 samples
   Val clarified: 4791 samples
   Test baseline: 4483 samples
   Test clarified: 4483 samples


## Prepare T5 Model and Datasets

**Purpose:** Initialize T5 tokenizer and create PyTorch datasets for training.

**Utils function used:** `load_samples(filename)`
- Loads JSONL files line-by-line
- Returns list of `{'input_text': ..., 'target_text': ...}` dictionaries

**What this does:**
1. **Initialize T5 tokenizer** from pretrained `t5-base`
2. **Add special tokens:**
   - `[START_ENT]`, `[END_ENT]`: Entity boundary markers
   - `[CLARIFY:`, `]`: Clarification delimiters
3. **Load training samples** from JSONL files (6 files total)
4. **Create PyTorch datasets** using custom `EntityLinkingDataset` class
   - Tokenizes input and target text
   - Applies padding and truncation (max_length=512)
   - Returns tensors ready for T5 training

**Datasets created:**
- `train_baseline_dataset`, `train_clarified_dataset`
- `val_baseline_dataset`, `val_clarified_dataset`

**Note:** The T5 tokenizer created here is separate from the Llama tokenizer used for clarification generation.

In [None]:
class EntityLinkingDataset(Dataset):
    """PyTorch Dataset for entity linking with T5.

    Each sample is a dict with ``input_text`` (model input) and
    ``target_text`` (expected output). Items are tokenized lazily in
    ``__getitem__`` and padded/truncated to ``max_length``.
    """

    def __init__(self, samples, tokenizer, max_length=512):
        self.samples = samples
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample = self.samples[idx]

        encoding = self.tokenizer(
            sample['input_text'],
            text_target=sample['target_text'],
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        # return_tensors='pt' adds a leading batch dimension of 1; drop only
        # that one so other size-1 dimensions are preserved.
        item = {key: val.squeeze(0) for key, val in encoding.items()}

        # Padded label positions must not contribute to the loss. Hugging Face
        # seq2seq models ignore label positions equal to -100, so replace the
        # pad id there (otherwise the model is trained to emit padding).
        pad_id = getattr(self.tokenizer, 'pad_token_id', None)
        if 'labels' in item and pad_id is not None:
            labels = item['labels'].clone()
            labels[labels == pad_id] = -100
            item['labels'] = labels

        return item



# Initialize T5 tokenizer (NEW - separate from clarification tokenizer)
t5_tokenizer = T5Tokenizer.from_pretrained('t5-base')

# Register the entity and clarification markers as special tokens so the
# tokenizer keeps each one as a single token.
special_tokens = {
    'additional_special_tokens': [
        '[START_ENT]',
        '[END_ENT]',
        '[CLARIFY:',
        ']'
    ]
}
t5_tokenizer.add_special_tokens(special_tokens)

print(f"✓ T5 Tokenizer ready. Vocabulary size: {len(t5_tokenizer)}")

# The dataset-creation cell writes the JSONL files to
# data/experiments/processed_for_training, but this cell used to read them from
# a different directory. Use the first location that exists so both layouts work.
_TRAINING_DATA_CANDIDATES = (
    'data/experiments/processed_for_training',
    './data/processed/aida/clarifications_results/processed_for_training',
)
TRAINING_DATA_DIR = next(
    (candidate for candidate in _TRAINING_DATA_CANDIDATES if os.path.isdir(candidate)),
    _TRAINING_DATA_CANDIDATES[0],
)
print(f"   Loading samples from: {TRAINING_DATA_DIR}")

train_baseline = u.load_samples(os.path.join(TRAINING_DATA_DIR, 'train_baseline.jsonl'))
train_clarified = u.load_samples(os.path.join(TRAINING_DATA_DIR, 'train_clarified.jsonl'))
val_baseline = u.load_samples(os.path.join(TRAINING_DATA_DIR, 'val_baseline.jsonl'))
val_clarified = u.load_samples(os.path.join(TRAINING_DATA_DIR, 'val_clarified.jsonl'))

# Create datasets using T5 tokenizer
train_baseline_dataset = EntityLinkingDataset(train_baseline, t5_tokenizer)
train_clarified_dataset = EntityLinkingDataset(train_clarified, t5_tokenizer)
val_baseline_dataset = EntityLinkingDataset(val_baseline, t5_tokenizer)
val_clarified_dataset = EntityLinkingDataset(val_clarified, t5_tokenizer)

print(f"✓ Train baseline: {len(train_baseline_dataset)} samples")
print(f"✓ Train clarified: {len(train_clarified_dataset)} samples")
print(f"✓ Val baseline: {len(val_baseline_dataset)} samples")
print(f"✓ Val clarified: {len(val_clarified_dataset)} samples")

print("\n Datasets ready for training!")


 Preparing T5 model and tokenizer...
✓ T5 Tokenizer ready. Vocabulary size: 32103
✓ Train baseline: 18541 samples
✓ Train clarified: 18541 samples
✓ Val baseline: 4791 samples
✓ Val clarified: 4791 samples

 Datasets ready for training!


## Train Baseline T5 Model

**Purpose:** Fine-tune T5-base on entity linking task WITHOUT clarifications.

**What this cell does:**
1. **Load T5 model** from `t5-base` pretrained checkpoint
2. **Resize token embeddings** to accommodate special tokens
3. **Configure training arguments:**
   - Learning rate 3e-4, batch size 8 (train and eval), weight decay 0.01
   - Evaluation and checkpointing every 500 steps, keeping the 2 most recent checkpoints
   - Mixed precision (FP16), gradient accumulation, and early stopping are not configured yet
4. **Initialize Trainer** with baseline dataset
5. **Train model** for 3 epochs
6. **Save model and tokenizer** to `experiments/t5_models/t5_baseline/`


**Performance benchmark:** This establishes baseline accuracy for comparison with clarified model.

**Output:** Trained model saved for evaluation and comparison.

In [None]:
import inspect
import os

import torch
from transformers import (
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    T5ForConditionalGeneration,
    Trainer,
    TrainingArguments,
)


def _eval_strategy_kwargs(value):
    """Return the evaluation-strategy kwarg under the name this transformers accepts.

    Newer transformers releases renamed ``evaluation_strategy`` to
    ``eval_strategy``, so the argument name is picked from the installed
    version's signature.
    """
    params = inspect.signature(Seq2SeqTrainingArguments.__init__).parameters
    key = 'eval_strategy' if 'eval_strategy' in params else 'evaluation_strategy'
    return {key: value}


# Load base T5 model
model = T5ForConditionalGeneration.from_pretrained("t5-base")

# Resize embeddings (because you added special tokens)
model.resize_token_embeddings(len(t5_tokenizer))

OUTPUT_DIR = "experiments/t5_models"
os.makedirs(OUTPUT_DIR, exist_ok=True)

print("\n=== Training BASELINE model ===")

# `predict_with_generate` exists only on Seq2SeqTrainingArguments; passing it
# to plain TrainingArguments raises a TypeError at construction time.
training_args_baseline = Seq2SeqTrainingArguments(
    output_dir=os.path.join(OUTPUT_DIR, "t5_baseline"),
    overwrite_output_dir=True,
    num_train_epochs=3,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    learning_rate=3e-4,
    weight_decay=0.01,
    logging_steps=50,
    save_steps=500,
    eval_steps=500,
    save_total_limit=2,
    predict_with_generate=True,
    report_to="none",
    **_eval_strategy_kwargs("steps"),
)

trainer_baseline = Seq2SeqTrainer(
    model=model,
    args=training_args_baseline,
    train_dataset=train_baseline_dataset,
    eval_dataset=val_baseline_dataset,
)

trainer_baseline.train()
trainer_baseline.save_model(os.path.join(OUTPUT_DIR, "t5_baseline"))
t5_tokenizer.save_pretrained(os.path.join(OUTPUT_DIR, "t5_baseline"))



### TRAIN CLARIFIED MODEL

In [None]:

import inspect

from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments, T5ForConditionalGeneration


def _eval_strategy_kwargs(value):
    """Return the evaluation-strategy kwarg under the name this transformers accepts.

    Newer transformers releases renamed ``evaluation_strategy`` to
    ``eval_strategy``, so the argument name is picked from the installed
    version's signature.
    """
    params = inspect.signature(Seq2SeqTrainingArguments.__init__).parameters
    key = 'eval_strategy' if 'eval_strategy' in params else 'evaluation_strategy'
    return {key: value}


# Reload a fresh model to avoid contamination
model_clar = T5ForConditionalGeneration.from_pretrained("t5-base")
model_clar.resize_token_embeddings(len(t5_tokenizer))

# Same hyper-parameters as the baseline run so the two models are comparable.
# `predict_with_generate` requires Seq2SeqTrainingArguments (TypeError otherwise).
training_args_clarified = Seq2SeqTrainingArguments(
    output_dir=os.path.join(OUTPUT_DIR, "t5_clarified"),
    overwrite_output_dir=True,
    num_train_epochs=3,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    learning_rate=3e-4,
    weight_decay=0.01,
    logging_steps=50,
    save_steps=500,
    eval_steps=500,
    save_total_limit=2,
    predict_with_generate=True,
    report_to="none",
    **_eval_strategy_kwargs("steps"),
)

trainer_clarified = Seq2SeqTrainer(
    model=model_clar,
    args=training_args_clarified,
    train_dataset=train_clarified_dataset,
    eval_dataset=val_clarified_dataset,
)

trainer_clarified.train()
trainer_clarified.save_model(os.path.join(OUTPUT_DIR, "t5_clarified"))
t5_tokenizer.save_pretrained(os.path.join(OUTPUT_DIR, "t5_clarified"))
