# Phase 6: Multi-Task BERT Training
## Fine-tune bert-base-uncased for Emotion + Crisis + Informativeness Classification

This notebook trains a **multi-task BERT model** with three classification heads:
- **Emotion** (13 classes): fear, anger, sadness, anxiety, confusion, surprise, disgust, caring, joy, excitement, gratitude, disappointment, neutral
- **Crisis** (2 classes): crisis vs. non-crisis
- **Informativeness** (3 classes): related_informative, related_not_informative, not_related

### Architecture
```
Input text -> BERT Tokenizer (max_length=128)
           -> BERT base encoder (shared, 768-dim)
           -> [CLS] token embedding
           |-> Dropout(0.3) -> Linear(768, 13) -> Emotion logits
           |-> Dropout(0.3) -> Linear(768,  2) -> Crisis logits
           |-> Dropout(0.3) -> Linear(768,  3) -> Informativeness logits
```

### Training Config
- Learning rate: 2e-5, Batch size: 16, Epochs: 3, Max length: 128 tokens
- Optimizer: AdamW (weight_decay=0.01), Scheduler: Linear warmup (10%)
- Loss: CrossEntropyLoss per task (class-weighted for emotion); the three task losses are summed with equal weight
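
The scheduler setting above can be pictured with a small plain-Python sketch. It only illustrates the shape of the schedule (a linear ramp from 0 to the base rate during warmup, then a linear decay to 0) and is not the library's code; `lr_at_step` and the 1,500-step total are hypothetical, chosen for the example.

```python
def lr_at_step(step, total_steps, warmup_ratio=0.1, base_lr=2e-5):
    """Learning rate after `step` optimizer updates under linear warmup + linear decay."""
    warmup_steps = int(total_steps * warmup_ratio)
    if step < warmup_steps:
        # Ramp up linearly from 0 to base_lr
        return base_lr * step / max(1, warmup_steps)
    # Decay linearly from base_lr down to 0 at total_steps
    remaining = max(0, total_steps - step)
    return base_lr * remaining / max(1, total_steps - warmup_steps)

# Hypothetical run: 1,500 optimizer steps, so 150 warmup steps
for step in (0, 75, 150, 825, 1500):
    print(f"step {step:>4}: lr = {lr_at_step(step, 1500):.2e}")
```

Under these assumptions the rate peaks at 2e-5 at step 150 and is back to half of that by step 825.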

### Prerequisites
- Run notebook 06 first to create the labelled datasets
- Start with the 10K sample for development, then switch to the full 52.8K dataset for final training

## 1. Environment Setup & Installs

In [None]:
# Google Colab: Uncomment the following lines to install required packages
# !pip install transformers accelerate -q
# !pip install scikit-learn tqdm matplotlib seaborn -q

import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertModel, get_linear_schedule_with_warmup
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.utils.class_weight import compute_class_weight
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
import json
import os
import time
import warnings
warnings.filterwarnings('ignore')

pd.set_option('display.max_columns', None)
pd.set_option('display.max_colwidth', 100)

print("\u2705 Libraries loaded successfully!")

## 2. Device Setup & Configuration

In [None]:
# Google Colab: mount Google Drive (comment out these two lines when running locally)
from google.colab import drive
drive.mount('/content/drive')

# Device setup
if torch.cuda.is_available():
    device = torch.device('cuda')
    gpu_name = torch.cuda.get_device_name(0)
    # Total GPU memory in GiB (total_memory is reported in bytes)
    gpu_mem = torch.cuda.get_device_properties(0).total_memory / (1024**3)
    print(f"✅ GPU available: {gpu_name} ({gpu_mem:.1f} GB)")
else:
    device = torch.device('cpu')
    print("⚠️  No GPU found - using CPU (training will be slow)")

print(f"Device: {device}")

In [None]:
# ============================================================
# CONFIGURATION - Edit this cell to change settings
# ============================================================

CONFIG = {
    # --- DATA ---
    # Switch between 10K sample and full dataset by changing this path:
    # 'DATA_PATH': 'master_training_data/master_training_sample_v5_labelled (3).csv',
    # 'DATA_PATH': 'master_training_data/master_training_data_v5_labelled.csv',  # Full 52.8K
    
    # For Colab with Google Drive:
    'DATA_PATH': '/content/drive/MyDrive/Tempo/master_training_sample_v5_labelled (3).csv',
    
    # --- MODEL ---
    'MODEL_NAME': 'bert-base-uncased',
    'MAX_LENGTH': 128,
    
    # --- TRAINING ---
    'BATCH_SIZE': 16,
    'EPOCHS': 3,
    'LEARNING_RATE': 2e-5,
    'WEIGHT_DECAY': 0.01,
    'WARMUP_RATIO': 0.1,
    'DROPOUT': 0.3,
    'GRADIENT_ACCUMULATION_STEPS': 1,  # Increase to 2 or 4 if GPU runs out of memory
    
    # --- TASKS ---
    'NUM_EMOTION_CLASSES': 13,
    'NUM_CRISIS_CLASSES': 2,
    'NUM_INFO_CLASSES': 3,
    
    # --- SAVING ---
    'SAVE_DIR': '/content/drive/MyDrive/Tempo/saved_models/multitask_bert',
    'RANDOM_SEED': 42,
}

# Set seeds for reproducibility
torch.manual_seed(CONFIG['RANDOM_SEED'])
np.random.seed(CONFIG['RANDOM_SEED'])
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(CONFIG['RANDOM_SEED'])

print("\u2705 Configuration set:")
for k, v in CONFIG.items():
    print(f"  {k}: {v}")

## 3. Load Data

In [None]:
df = pd.read_csv(CONFIG['DATA_PATH'])
print(f"\u2705 Loaded {len(df):,} rows from {CONFIG['DATA_PATH']}")
print(f"\nColumns: {df.columns.tolist()}")
print(f"\nDtypes:\n{df.dtypes}")
print(f"\nNull counts:\n{df.isnull().sum()}")
print(f"\nFirst 3 rows:")
df.head(3)

## 4. Emotion Label Merge (18 -> 13)

The dataset contains **18 unique emotion names** from two different labeling schemes (GoEmotions original vs Gemini LLM). The `emotion_label` column is **inconsistent** (same number maps to different emotions depending on the source), so we use `emotion_name` as the source of truth and merge 5 emotions:

| Merge From | Merge To | Rationale |
|-----------|----------|----------|
| admiration | gratitude | Both express positive regard/appreciation |
| amusement | joy | Both express positive/happy feelings |
| annoyance | anger | Annoyance is a milder form of anger |
| curiosity | confusion | Both involve uncertainty/seeking understanding |
| desire | excitement | Both involve positive anticipation |

In [None]:
# Pre-merge distribution
print("PRE-MERGE emotion distribution (18 unique):")
print(df['emotion_name'].value_counts())
print(f"\nUnique emotion names: {df['emotion_name'].nunique()}")

In [None]:
# Define the merge mapping
EMOTION_MERGE = {
    'admiration': 'gratitude',
    'amusement': 'joy',
    'annoyance': 'anger',
    'curiosity': 'confusion',
    'desire': 'excitement',
}

# Apply merge
df['emotion_name'] = df['emotion_name'].replace(EMOTION_MERGE)

# Verify exactly 13 unique emotions
assert df['emotion_name'].nunique() == 13, f"Expected 13, got {df['emotion_name'].nunique()}"

# Define the canonical 13 emotions (ordered for label indices 0-12)
EMOTION_NAMES_13 = [
    'fear', 'anger', 'sadness', 'anxiety', 'confusion',
    'surprise', 'disgust', 'caring', 'joy', 'excitement',
    'gratitude', 'disappointment', 'neutral'
]

# Create 0-indexed mapping for PyTorch CrossEntropyLoss
EMOTION_TO_IDX = {name: idx for idx, name in enumerate(EMOTION_NAMES_13)}
IDX_TO_EMOTION = {idx: name for name, idx in EMOTION_TO_IDX.items()}

# Apply new labels
df['emotion_idx'] = df['emotion_name'].map(EMOTION_TO_IDX)

# Validate no NaNs
assert df['emotion_idx'].isna().sum() == 0, "Unmapped emotion names found!"

# Drop the old unreliable emotion_label column
df.drop(columns=['emotion_label'], inplace=True, errors='ignore')  # errors='ignore' keeps the cell safe to re-run

print("\u2705 POST-MERGE emotion distribution (13 classes, 0-indexed):")
for name, idx in EMOTION_TO_IDX.items():
    count = (df['emotion_idx'] == idx).sum()
    print(f"  {idx:2d}: {name:<16} ({count:,} samples)")

## 5. Encode Crisis & Informativeness Labels

In [None]:
# Crisis: ensure integer type (may be float in full dataset)
df['crisis_label'] = df['crisis_label'].astype(int)

# Informativeness: string -> 0-indexed integer
INFO_NAMES = ['related_informative', 'related_not_informative', 'not_related']
INFO_TO_IDX = {name: idx for idx, name in enumerate(INFO_NAMES)}
IDX_TO_INFO = {idx: name for name, idx in INFO_TO_IDX.items()}

df['info_idx'] = df['informativeness'].map(INFO_TO_IDX)
assert df['info_idx'].isna().sum() == 0, "Unmapped informativeness values found!"

print("\u2705 Crisis label distribution:")
print(df['crisis_label'].value_counts())

print(f"\n\u2705 Informativeness distribution (0-indexed):")
for name, idx in INFO_TO_IDX.items():
    count = (df['info_idx'] == idx).sum()
    print(f"  {idx}: {name:<28} ({count:,} samples)")

## 6. Compute Class Weights (Imbalanced Emotions)

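Emotion classes are highly imbalanced (e.g., anxiety: ~132 vs neutral: ~12,649 in the full dataset). We use inverse-frequency (`'balanced'`) class weights so that each class carries roughly the same total weight in the emotion loss, instead of the loss being dominated by `neutral`. Note that the weights are computed on the full dataframe, before the train/validation split.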

In [None]:
# Compute balanced class weights for emotion
emotion_labels_array = df['emotion_idx'].values
emotion_class_weights = compute_class_weight(
    class_weight='balanced',
    classes=np.arange(CONFIG['NUM_EMOTION_CLASSES']),
    y=emotion_labels_array
)
emotion_class_weights = torch.tensor(emotion_class_weights, dtype=torch.float32).to(device)

print("\u2705 Emotion class weights (balanced):")
for idx in range(CONFIG['NUM_EMOTION_CLASSES']):
    name = IDX_TO_EMOTION[idx]
    weight = emotion_class_weights[idx].item()
    count = (df['emotion_idx'] == idx).sum()
    print(f"  {idx:2d} {name:<16}: weight={weight:.4f}  (n={count:,})")
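
The `'balanced'` heuristic used above boils down to `n_samples / (n_classes * count_c)` for each class `c`. The sketch below reproduces that formula in plain Python on toy labels; `balanced_class_weights` is a hypothetical helper written for this illustration, not part of scikit-learn.

```python
from collections import Counter

def balanced_class_weights(labels, num_classes):
    """weight_c = n_samples / (num_classes * count_c), the 'balanced' heuristic."""
    counts = Counter(labels)
    n_samples = len(labels)
    # A class with zero samples would divide by zero here
    return [n_samples / (num_classes * counts[c]) for c in range(num_classes)]

# Toy labels: class 0 appears 8 times, class 1 only twice
toy_labels = [0] * 8 + [1] * 2
weights = balanced_class_weights(toy_labels, num_classes=2)
print(weights)  # [0.625, 2.5]
```

Multiplying each weight by its class count gives the same total for every class (8 × 0.625 = 2 × 2.5 = 5), which is the sense in which rare and frequent classes are balanced.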

## 7. Train/Validation Split (80/20, Stratified)

In [None]:
train_df, val_df = train_test_split(
    df,
    test_size=0.2,
    random_state=CONFIG['RANDOM_SEED'],
    stratify=df['emotion_idx']
)

train_df = train_df.reset_index(drop=True)
val_df = val_df.reset_index(drop=True)

print(f"\u2705 Train: {len(train_df):,} rows")
print(f"\u2705 Val:   {len(val_df):,} rows")

# Verify all classes present in both splits
train_emotions = set(train_df['emotion_idx'].unique())
val_emotions = set(val_df['emotion_idx'].unique())
assert train_emotions == set(range(13)), f"Missing train emotions: {set(range(13)) - train_emotions}"
assert val_emotions == set(range(13)), f"Missing val emotions: {set(range(13)) - val_emotions}"
print(f"\n\u2705 All 13 emotion classes present in both splits")

print(f"\nTrain crisis: {dict(train_df['crisis_label'].value_counts())}")
print(f"Val crisis:   {dict(val_df['crisis_label'].value_counts())}")
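
To make the effect of `stratify=` concrete, here is a simplified stand-in written in plain Python. It is not the algorithm `train_test_split` uses internally; it only shows the idea of splitting each class separately so that rare classes appear on both sides. `stratified_split` and the toy labels are hypothetical.

```python
import random
from collections import defaultdict

def stratified_split(labels, val_fraction=0.2, seed=42):
    """Return (train_indices, val_indices), splitting each class separately."""
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)

    rng = random.Random(seed)
    train_idx, val_idx = [], []
    for label in sorted(by_class):
        indices = by_class[label]
        rng.shuffle(indices)
        # Aim for at least one validation sample while leaving at least one for training
        n_val = min(len(indices) - 1, max(1, round(len(indices) * val_fraction)))
        val_idx.extend(indices[:n_val])
        train_idx.extend(indices[n_val:])
    return train_idx, val_idx

toy_labels = ['neutral'] * 50 + ['anxiety'] * 5
train_idx, val_idx = stratified_split(toy_labels)
print(len(train_idx), len(val_idx))  # 44 11
```

With a purely random 20% split, a class with only five rows could easily end up with no validation samples at all, which is what the assertions in the cell above guard against.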

## 8. Dataset & DataLoader

In [None]:
class MultiTaskDataset(Dataset):
    """PyTorch Dataset for multi-task BERT training."""
    
    def __init__(self, dataframe, tokenizer, max_length):
        self.texts = dataframe['text'].tolist()
        self.emotion_labels = dataframe['emotion_idx'].tolist()
        self.crisis_labels = dataframe['crisis_label'].tolist()
        self.info_labels = dataframe['info_idx'].tolist()
        self.tokenizer = tokenizer
        self.max_length = max_length
    
    def __len__(self):
        return len(self.texts)
    
    def __getitem__(self, idx):
        text = str(self.texts[idx])
        
        encoding = self.tokenizer(
            text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        
        return {
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0),
            'emotion_label': torch.tensor(self.emotion_labels[idx], dtype=torch.long),
            'crisis_label': torch.tensor(self.crisis_labels[idx], dtype=torch.long),
            'info_label': torch.tensor(self.info_labels[idx], dtype=torch.long),
        }

print("\u2705 MultiTaskDataset class defined")

In [None]:
# Initialize tokenizer
tokenizer = BertTokenizer.from_pretrained(CONFIG['MODEL_NAME'])

# Create datasets
train_dataset = MultiTaskDataset(train_df, tokenizer, CONFIG['MAX_LENGTH'])
val_dataset = MultiTaskDataset(val_df, tokenizer, CONFIG['MAX_LENGTH'])

# Create dataloaders (num_workers=0 for Windows/Colab compatibility)
train_loader = DataLoader(train_dataset, batch_size=CONFIG['BATCH_SIZE'], shuffle=True, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=CONFIG['BATCH_SIZE'], shuffle=False, num_workers=0)

print(f"\u2705 Train: {len(train_dataset):,} samples, {len(train_loader):,} batches")
print(f"\u2705 Val:   {len(val_dataset):,} samples, {len(val_loader):,} batches")

# Sanity check: inspect one batch
batch = next(iter(train_loader))
print(f"\nSample batch shapes:")
for key, val in batch.items():
    print(f"  {key}: {val.shape}")
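
The `padding='max_length'` and `truncation=True` settings explain why every tensor in the batch has the same width. The sketch below mimics that behaviour for a single sequence in plain Python. The special-token ids are those of the `bert-base-uncased` vocabulary; the word-piece ids and the `pad_or_truncate` helper are made up for illustration.

```python
CLS_ID, SEP_ID, PAD_ID = 101, 102, 0  # [CLS], [SEP], [PAD] in bert-base-uncased

def pad_or_truncate(token_ids, max_length):
    """Add [CLS]/[SEP], cut to max_length, then right-pad. Returns (input_ids, attention_mask)."""
    body = token_ids[:max_length - 2]            # leave room for [CLS] and [SEP]
    input_ids = [CLS_ID] + body + [SEP_ID]
    attention_mask = [1] * len(input_ids)        # 1 = real token
    padding = max_length - len(input_ids)
    return input_ids + [PAD_ID] * padding, attention_mask + [0] * padding  # 0 = padding

# Made-up word-piece ids for illustration
short_ids, short_mask = pad_or_truncate([7592, 2088], max_length=6)
long_ids, long_mask = pad_or_truncate(list(range(1000, 1010)), max_length=6)
print(short_ids, short_mask)  # [101, 7592, 2088, 102, 0, 0] [1, 1, 1, 1, 0, 0]
print(long_ids, long_mask)    # [101, 1000, 1001, 1002, 1003, 102] [1, 1, 1, 1, 1, 1]
```

The attention mask is what lets the encoder ignore the padded positions, so short texts are not penalised for being padded out to 128 tokens.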

## 9. Model Architecture

Shared BERT encoder with three independent classification heads:

```
[CLS] embedding (768-dim)
   |
   |---> Dropout(0.3) -> Linear(768, 13) -> Emotion logits
   |---> Dropout(0.3) -> Linear(768,  2) -> Crisis logits
   |---> Dropout(0.3) -> Linear(768,  3) -> Informativeness logits
```

In [None]:
class MultiTaskBERT(nn.Module):
    def __init__(self, model_name, num_emotions, num_crisis, num_info, dropout):
        super(MultiTaskBERT, self).__init__()
        
        # Shared BERT encoder
        self.bert = BertModel.from_pretrained(model_name)
        hidden_size = self.bert.config.hidden_size  # 768
        
        # Task-specific classification heads
        self.emotion_head = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(hidden_size, num_emotions)
        )
        
        self.crisis_head = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(hidden_size, num_crisis)
        )
        
        self.info_head = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(hidden_size, num_info)
        )
    
    def forward(self, input_ids, attention_mask):
        # Shared encoder
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        
        # [CLS] token embedding
        cls_output = outputs.last_hidden_state[:, 0, :]  # (batch_size, 768)
        
        # Task-specific logits
        return {
            'emotion': self.emotion_head(cls_output),   # (batch_size, 13)
            'crisis': self.crisis_head(cls_output),     # (batch_size, 2)
            'info': self.info_head(cls_output),         # (batch_size, 3)
        }


# Instantiate model
model = MultiTaskBERT(
    model_name=CONFIG['MODEL_NAME'],
    num_emotions=CONFIG['NUM_EMOTION_CLASSES'],
    num_crisis=CONFIG['NUM_CRISIS_CLASSES'],
    num_info=CONFIG['NUM_INFO_CLASSES'],
    dropout=CONFIG['DROPOUT']
)
model.to(device)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"\u2705 Model loaded on {device}")
print(f"   Total parameters:     {total_params:,}")
print(f"   Trainable parameters: {trainable_params:,}")
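
As a quick sanity check on the parameter counts printed above, the three heads can be sized by hand. Dropout has no parameters, so each head is a single linear layer with `in * out` weights plus `out` biases. The helper name `linear_params` is hypothetical.

```python
def linear_params(in_features, out_features):
    """Weights plus biases of one fully connected layer."""
    return in_features * out_features + out_features

hidden_size = 768
heads = {'emotion': 13, 'crisis': 2, 'info': 3}

head_params = {name: linear_params(hidden_size, n) for name, n in heads.items()}
print(head_params)                 # {'emotion': 9997, 'crisis': 1538, 'info': 2307}
print(sum(head_params.values()))   # 13842
```

Since a BERT-base encoder has on the order of 110 million parameters, almost all of the trainable parameters belong to the shared encoder rather than to the task heads.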

## 10. Loss Functions, Optimizer & Scheduler

In [None]:
# Loss functions
emotion_criterion = nn.CrossEntropyLoss(weight=emotion_class_weights)
crisis_criterion = nn.CrossEntropyLoss()
info_criterion = nn.CrossEntropyLoss()

# Optimizer
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=CONFIG['LEARNING_RATE'],
    weight_decay=CONFIG['WEIGHT_DECAY']
)

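# Scheduler with linear warmup
# Ceiling division: train_one_epoch also performs one optimizer step for leftover batches at the end of an epoch
steps_per_epoch = -(-len(train_loader) // CONFIG['GRADIENT_ACCUMULATION_STEPS'])
total_steps = steps_per_epoch * CONFIG['EPOCHS']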
warmup_steps = int(total_steps * CONFIG['WARMUP_RATIO'])

scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=total_steps
)

print(f"\u2705 Training setup:")
print(f"   Total training steps: {total_steps:,}")
print(f"   Warmup steps: {warmup_steps:,}")
print(f"   Effective batch size: {CONFIG['BATCH_SIZE'] * CONFIG['GRADIENT_ACCUMULATION_STEPS']}")
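
The scheduler counts optimizer updates, not batches. Because `train_one_epoch` (section 11) applies one extra update for any leftover batches at the end of an epoch, the number of updates per epoch is a ceiling division. The sketch below works through that arithmetic with hypothetical dataset sizes; `optimizer_steps_per_epoch` is an illustrative helper, not part of the notebook.

```python
import math

def optimizer_steps_per_epoch(num_batches, accumulation_steps):
    """Updates per epoch when leftover batches trigger one extra update."""
    return math.ceil(num_batches / accumulation_steps)

# Hypothetical sizes: 8,000 training rows, batch size 16
num_batches = math.ceil(8000 / 16)           # 500 batches per epoch
for accumulation_steps in (1, 3, 4):
    per_epoch = optimizer_steps_per_epoch(num_batches, accumulation_steps)
    total = per_epoch * 3                    # 3 epochs
    warmup = int(total * 0.1)                # 10% warmup
    print(f"accum={accumulation_steps}: {per_epoch} steps/epoch, {total} total, {warmup} warmup")
```

With an accumulation factor of 3, a floor division would give 166 updates per epoch instead of 167, so the scheduler would be stepped past its configured total by the end of training.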

## 11. Training Loop

In [None]:
def train_one_epoch(model, loader, optimizer, scheduler,
                    emotion_criterion, crisis_criterion, info_criterion,
                    device, accumulation_steps=1):
    """Train for one epoch. Returns dict of metrics."""
    model.train()
    
    total_loss = 0
    task_losses = {'emotion': 0, 'crisis': 0, 'info': 0}
    task_correct = {'emotion': 0, 'crisis': 0, 'info': 0}
    total_samples = 0
    
    optimizer.zero_grad()
    
    pbar = tqdm(enumerate(loader), total=len(loader), desc="Training")
    for step, batch in pbar:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        emotion_labels = batch['emotion_label'].to(device)
        crisis_labels = batch['crisis_label'].to(device)
        info_labels = batch['info_label'].to(device)
        
        # Forward pass
        logits = model(input_ids, attention_mask)
        
        # Compute losses for each task
        e_loss = emotion_criterion(logits['emotion'], emotion_labels)
        c_loss = crisis_criterion(logits['crisis'], crisis_labels)
        i_loss = info_criterion(logits['info'], info_labels)
        
        # Total loss (equal weighting)
        loss = e_loss + c_loss + i_loss
        
        # Scale loss for gradient accumulation
        scaled_loss = loss / accumulation_steps
        scaled_loss.backward()
        
        if (step + 1) % accumulation_steps == 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
        
        # Track metrics
        batch_size = input_ids.size(0)
        total_loss += loss.item() * batch_size
        task_losses['emotion'] += e_loss.item() * batch_size
        task_losses['crisis'] += c_loss.item() * batch_size
        task_losses['info'] += i_loss.item() * batch_size
        
        task_correct['emotion'] += (logits['emotion'].argmax(dim=1) == emotion_labels).sum().item()
        task_correct['crisis'] += (logits['crisis'].argmax(dim=1) == crisis_labels).sum().item()
        task_correct['info'] += (logits['info'].argmax(dim=1) == info_labels).sum().item()
        total_samples += batch_size
        
        # Update progress bar
        pbar.set_postfix({
            'loss': f"{total_loss/total_samples:.4f}",
            'e_acc': f"{task_correct['emotion']/total_samples:.3f}",
            'c_acc': f"{task_correct['crisis']/total_samples:.3f}",
        })
    
    # Handle remaining gradients if steps not divisible by accumulation_steps
    if (step + 1) % accumulation_steps != 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
    
    return {
        'loss': total_loss / total_samples,
        'emotion_loss': task_losses['emotion'] / total_samples,
        'crisis_loss': task_losses['crisis'] / total_samples,
        'info_loss': task_losses['info'] / total_samples,
        'emotion_acc': task_correct['emotion'] / total_samples,
        'crisis_acc': task_correct['crisis'] / total_samples,
        'info_acc': task_correct['info'] / total_samples,
    }


def evaluate(model, loader, emotion_criterion, crisis_criterion, info_criterion, device):
    """Evaluate on validation set. Returns dict of metrics + predictions."""
    model.eval()
    
    total_loss = 0
    task_losses = {'emotion': 0, 'crisis': 0, 'info': 0}
    task_correct = {'emotion': 0, 'crisis': 0, 'info': 0}
    total_samples = 0
    
    all_preds = {'emotion': [], 'crisis': [], 'info': []}
    all_labels = {'emotion': [], 'crisis': [], 'info': []}
    
    with torch.no_grad():
        for batch in tqdm(loader, desc="Evaluating"):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            emotion_labels = batch['emotion_label'].to(device)
            crisis_labels = batch['crisis_label'].to(device)
            info_labels = batch['info_label'].to(device)
            
            logits = model(input_ids, attention_mask)
            
            e_loss = emotion_criterion(logits['emotion'], emotion_labels)
            c_loss = crisis_criterion(logits['crisis'], crisis_labels)
            i_loss = info_criterion(logits['info'], info_labels)
            loss = e_loss + c_loss + i_loss
            
            batch_size = input_ids.size(0)
            total_loss += loss.item() * batch_size
            task_losses['emotion'] += e_loss.item() * batch_size
            task_losses['crisis'] += c_loss.item() * batch_size
            task_losses['info'] += i_loss.item() * batch_size
            
            e_preds = logits['emotion'].argmax(dim=1)
            c_preds = logits['crisis'].argmax(dim=1)
            i_preds = logits['info'].argmax(dim=1)
            
            task_correct['emotion'] += (e_preds == emotion_labels).sum().item()
            task_correct['crisis'] += (c_preds == crisis_labels).sum().item()
            task_correct['info'] += (i_preds == info_labels).sum().item()
            total_samples += batch_size
            
            all_preds['emotion'].extend(e_preds.cpu().numpy())
            all_preds['crisis'].extend(c_preds.cpu().numpy())
            all_preds['info'].extend(i_preds.cpu().numpy())
            all_labels['emotion'].extend(emotion_labels.cpu().numpy())
            all_labels['crisis'].extend(crisis_labels.cpu().numpy())
            all_labels['info'].extend(info_labels.cpu().numpy())
    
    return {
        'loss': total_loss / total_samples,
        'emotion_loss': task_losses['emotion'] / total_samples,
        'crisis_loss': task_losses['crisis'] / total_samples,
        'info_loss': task_losses['info'] / total_samples,
        'emotion_acc': task_correct['emotion'] / total_samples,
        'crisis_acc': task_correct['crisis'] / total_samples,
        'info_acc': task_correct['info'] / total_samples,
        'all_preds': all_preds,
        'all_labels': all_labels,
    }

print("\u2705 Training and evaluation functions defined")
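
The `loss / accumulation_steps` scaling in `train_one_epoch` is what makes several small batches behave like one large batch. A minimal sketch of that equivalence, using a one-parameter model and mean squared error in plain Python (all names and numbers are illustrative):

```python
def mse_grad(w, xs, ys):
    """Gradient of mean squared error for the one-parameter model y = w * x."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)

w = 0.5
xs = [1.0, 2.0, 3.0, 4.0]
ys = [2.0, 4.0, 6.0, 8.0]

# One large batch of 4 samples
full_batch_grad = mse_grad(w, xs, ys)

# Two micro-batches of 2 samples, each scaled by 1 / accumulation_steps
accumulation_steps = 2
accumulated_grad = 0.0
for start in range(0, len(xs), 2):
    micro_grad = mse_grad(w, xs[start:start + 2], ys[start:start + 2])
    accumulated_grad += micro_grad / accumulation_steps

print(full_batch_grad, accumulated_grad)  # -22.5 -22.5
```

The match is exact only when the micro-batches have equal size and the loss is a plain mean. With the class-weighted emotion loss, PyTorch's mean reduction normalises by the summed weights of the targets in each batch, so the equivalence is approximate there.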

In [None]:
# Create save directory
os.makedirs(CONFIG['SAVE_DIR'], exist_ok=True)

# Training history for plotting
history = {
    'train_loss': [], 'val_loss': [],
    'train_emotion_acc': [], 'val_emotion_acc': [],
    'train_crisis_acc': [], 'val_crisis_acc': [],
    'train_info_acc': [], 'val_info_acc': [],
}

best_val_loss = float('inf')
best_epoch = -1

print("=" * 80)
print("STARTING TRAINING")
print("=" * 80)
start_time = time.time()

for epoch in range(CONFIG['EPOCHS']):
    print(f"\n{'='*80}")
    print(f"Epoch {epoch+1}/{CONFIG['EPOCHS']}")
    print(f"{'='*80}")
    
    # Train
    train_metrics = train_one_epoch(
        model, train_loader, optimizer, scheduler,
        emotion_criterion, crisis_criterion, info_criterion,
        device, CONFIG['GRADIENT_ACCUMULATION_STEPS']
    )
    
    # Evaluate
    val_metrics = evaluate(
        model, val_loader,
        emotion_criterion, crisis_criterion, info_criterion,
        device
    )
    
    # Store history
    history['train_loss'].append(train_metrics['loss'])
    history['val_loss'].append(val_metrics['loss'])
    history['train_emotion_acc'].append(train_metrics['emotion_acc'])
    history['val_emotion_acc'].append(val_metrics['emotion_acc'])
    history['train_crisis_acc'].append(train_metrics['crisis_acc'])
    history['val_crisis_acc'].append(val_metrics['crisis_acc'])
    history['train_info_acc'].append(train_metrics['info_acc'])
    history['val_info_acc'].append(val_metrics['info_acc'])
    
    # Print epoch summary
    print(f"\n--- Epoch {epoch+1} Summary ---")
    print(f"{'Metric':<25} {'Train':>10} {'Val':>10}")
    print(f"{'-'*45}")
    print(f"{'Total Loss':<25} {train_metrics['loss']:>10.4f} {val_metrics['loss']:>10.4f}")
    print(f"{'Emotion Loss':<25} {train_metrics['emotion_loss']:>10.4f} {val_metrics['emotion_loss']:>10.4f}")
    print(f"{'Crisis Loss':<25} {train_metrics['crisis_loss']:>10.4f} {val_metrics['crisis_loss']:>10.4f}")
    print(f"{'Informativeness Loss':<25} {train_metrics['info_loss']:>10.4f} {val_metrics['info_loss']:>10.4f}")
    print(f"{'Emotion Accuracy':<25} {train_metrics['emotion_acc']:>10.4f} {val_metrics['emotion_acc']:>10.4f}")
    print(f"{'Crisis Accuracy':<25} {train_metrics['crisis_acc']:>10.4f} {val_metrics['crisis_acc']:>10.4f}")
    print(f"{'Info Accuracy':<25} {train_metrics['info_acc']:>10.4f} {val_metrics['info_acc']:>10.4f}")
    
    # Save best model based on validation loss
    if val_metrics['loss'] < best_val_loss:
        best_val_loss = val_metrics['loss']
        best_epoch = epoch + 1
        
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'val_loss': val_metrics['loss'],
            'config': CONFIG,
            'emotion_to_idx': EMOTION_TO_IDX,
            'info_to_idx': INFO_TO_IDX,
        }, os.path.join(CONFIG['SAVE_DIR'], 'best_model.pt'))
        
        print(f"\n  ** New best model saved (val_loss={best_val_loss:.4f}) **")
    else:
        print(f"\n  Val loss did not improve (best={best_val_loss:.4f} at epoch {best_epoch})")

elapsed = time.time() - start_time
print(f"\n{'='*80}")
print(f"\u2705 Training complete in {elapsed/60:.1f} minutes")
print(f"   Best val loss: {best_val_loss:.4f} (epoch {best_epoch})")
print(f"{'='*80}")
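
The losses in the epoch summaries are sample-weighted averages: each batch loss is multiplied by its batch size before being summed, so a smaller final batch does not count as much as a full one. A small sketch with made-up numbers shows the difference from a naive mean over batches; `epoch_mean` is a hypothetical helper.

```python
def epoch_mean(batch_losses, batch_sizes):
    """Sample-weighted mean loss from per-batch mean losses."""
    total = sum(loss * size for loss, size in zip(batch_losses, batch_sizes))
    return total / sum(batch_sizes)

# Hypothetical epoch: two full batches of 16 and a final batch of 4
batch_losses = [1.0, 1.0, 3.0]
batch_sizes = [16, 16, 4]

weighted = epoch_mean(batch_losses, batch_sizes)
naive = sum(batch_losses) / len(batch_losses)
print(f"{weighted:.4f} {naive:.4f}")  # 1.2222 1.6667
```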

## 12. Training Curves

In [None]:
fig, axes = plt.subplots(2, 2, figsize=(14, 10))
epochs_range = range(1, CONFIG['EPOCHS'] + 1)

# Total loss
axes[0, 0].plot(epochs_range, history['train_loss'], 'b-o', label='Train')
axes[0, 0].plot(epochs_range, history['val_loss'], 'r-o', label='Val')
axes[0, 0].set_title('Total Loss')
axes[0, 0].set_xlabel('Epoch')
axes[0, 0].legend()
axes[0, 0].grid(True)

# Emotion accuracy
axes[0, 1].plot(epochs_range, history['train_emotion_acc'], 'b-o', label='Train')
axes[0, 1].plot(epochs_range, history['val_emotion_acc'], 'r-o', label='Val')
axes[0, 1].set_title('Emotion Accuracy')
axes[0, 1].set_xlabel('Epoch')
axes[0, 1].legend()
axes[0, 1].grid(True)

# Crisis accuracy
axes[1, 0].plot(epochs_range, history['train_crisis_acc'], 'b-o', label='Train')
axes[1, 0].plot(epochs_range, history['val_crisis_acc'], 'r-o', label='Val')
axes[1, 0].set_title('Crisis Accuracy')
axes[1, 0].set_xlabel('Epoch')
axes[1, 0].legend()
axes[1, 0].grid(True)

# Informativeness accuracy
axes[1, 1].plot(epochs_range, history['train_info_acc'], 'b-o', label='Train')
axes[1, 1].plot(epochs_range, history['val_info_acc'], 'r-o', label='Val')
axes[1, 1].set_title('Informativeness Accuracy')
axes[1, 1].set_xlabel('Epoch')
axes[1, 1].legend()
axes[1, 1].grid(True)

plt.suptitle('Multi-Task BERT Training Curves', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.savefig(os.path.join(CONFIG['SAVE_DIR'], 'training_curves.png'), dpi=150, bbox_inches='tight')
plt.show()
print("\u2705 Training curves saved")

## 13. Load Best Model & Final Evaluation

In [None]:
# Load best model checkpoint
checkpoint = torch.load(
    os.path.join(CONFIG['SAVE_DIR'], 'best_model.pt'),
    map_location=device,
    weights_only=False
)
model.load_state_dict(checkpoint['model_state_dict'])
print(f"\u2705 Loaded best model from epoch {checkpoint['epoch']+1} (val_loss={checkpoint['val_loss']:.4f})")

# Final evaluation
final_metrics = evaluate(
    model, val_loader,
    emotion_criterion, crisis_criterion, info_criterion,
    device
)

print(f"\nFinal Validation Metrics:")
print(f"  Total Loss:       {final_metrics['loss']:.4f}")
print(f"  Emotion Accuracy: {final_metrics['emotion_acc']:.4f}")
print(f"  Crisis Accuracy:  {final_metrics['crisis_acc']:.4f}")
print(f"  Info Accuracy:    {final_metrics['info_acc']:.4f}")

## 14. Classification Reports

In [None]:
print("=" * 80)
print("EMOTION CLASSIFICATION REPORT (13 classes)")
print("=" * 80)
print(classification_report(
    final_metrics['all_labels']['emotion'],
    final_metrics['all_preds']['emotion'],
    target_names=EMOTION_NAMES_13,
    digits=4,
    zero_division=0
))

In [None]:
print("=" * 80)
print("CRISIS CLASSIFICATION REPORT (binary)")
print("=" * 80)
print(classification_report(
    final_metrics['all_labels']['crisis'],
    final_metrics['all_preds']['crisis'],
    target_names=['non_crisis', 'crisis'],
    digits=4,
    zero_division=0
))

In [None]:
print("=" * 80)
print("INFORMATIVENESS CLASSIFICATION REPORT (3 classes)")
print("=" * 80)
print(classification_report(
    final_metrics['all_labels']['info'],
    final_metrics['all_preds']['info'],
    target_names=INFO_NAMES,
    digits=4,
    zero_division=0
))
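
For readers less familiar with the report columns, the following sketch computes precision, recall and F1 for a single class by hand. The labels are toy values, and `per_class_scores` is a hypothetical helper; the zero fallbacks mirror the `zero_division=0` setting used above.

```python
def per_class_scores(y_true, y_pred, label):
    """Precision, recall and F1 for one class, treating it as the positive label."""
    tp = sum(t == label and p == label for t, p in zip(y_true, y_pred))
    fp = sum(t != label and p == label for t, p in zip(y_true, y_pred))
    fn = sum(t == label and p != label for t, p in zip(y_true, y_pred))
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1

# Toy labels: 0 = non_crisis, 1 = crisis
y_true = [0, 0, 0, 0, 1, 1, 1, 1]
y_pred = [0, 0, 0, 1, 1, 1, 0, 0]
precision, recall, f1 = per_class_scores(y_true, y_pred, label=1)
print(f"precision={precision:.4f} recall={recall:.4f} f1={f1:.4f}")
# precision=0.6667 recall=0.5000 f1=0.5714
```

For the imbalanced emotion task, the per-class rows and the macro average are usually more informative than overall accuracy, since a model can score well on accuracy by favouring `neutral`.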

## 15. Confusion Matrices

In [None]:
fig, axes = plt.subplots(1, 3, figsize=(22, 6))

# Emotion confusion matrix
cm_emotion = confusion_matrix(
    final_metrics['all_labels']['emotion'],
    final_metrics['all_preds']['emotion']
)
sns.heatmap(cm_emotion, annot=True, fmt='d', cmap='Blues', ax=axes[0],
            xticklabels=EMOTION_NAMES_13, yticklabels=EMOTION_NAMES_13)
axes[0].set_title('Emotion Confusion Matrix')
axes[0].set_xlabel('Predicted')
axes[0].set_ylabel('True')
axes[0].tick_params(axis='x', rotation=45)
axes[0].tick_params(axis='y', rotation=0)

# Crisis confusion matrix
cm_crisis = confusion_matrix(
    final_metrics['all_labels']['crisis'],
    final_metrics['all_preds']['crisis']
)
sns.heatmap(cm_crisis, annot=True, fmt='d', cmap='Oranges', ax=axes[1],
            xticklabels=['non_crisis', 'crisis'], yticklabels=['non_crisis', 'crisis'])
axes[1].set_title('Crisis Confusion Matrix')
axes[1].set_xlabel('Predicted')
axes[1].set_ylabel('True')

# Informativeness confusion matrix
cm_info = confusion_matrix(
    final_metrics['all_labels']['info'],
    final_metrics['all_preds']['info']
)
sns.heatmap(cm_info, annot=True, fmt='d', cmap='Greens', ax=axes[2],
            xticklabels=['informative', 'not_inform.', 'not_related'],
            yticklabels=['informative', 'not_inform.', 'not_related'])
axes[2].set_title('Informativeness Confusion Matrix')
axes[2].set_xlabel('Predicted')
axes[2].set_ylabel('True')

plt.suptitle('Multi-Task BERT - Confusion Matrices', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.savefig(os.path.join(CONFIG['SAVE_DIR'], 'confusion_matrices.png'), dpi=150, bbox_inches='tight')
plt.show()
print("\u2705 Confusion matrices saved")

## 16. Save Final Artifacts

In [None]:
# Save tokenizer
tokenizer.save_pretrained(CONFIG['SAVE_DIR'])

# Save label mappings as JSON
label_mappings = {
    'emotion_to_idx': EMOTION_TO_IDX,
    'idx_to_emotion': {str(k): v for k, v in IDX_TO_EMOTION.items()},
    'info_to_idx': INFO_TO_IDX,
    'idx_to_info': {str(k): v for k, v in IDX_TO_INFO.items()},
    'emotion_merge_map': EMOTION_MERGE,
    'emotion_names_13': EMOTION_NAMES_13,
    'info_names': INFO_NAMES,
}

with open(os.path.join(CONFIG['SAVE_DIR'], 'label_mappings.json'), 'w') as f:
    json.dump(label_mappings, f, indent=2)

# Save training history
with open(os.path.join(CONFIG['SAVE_DIR'], 'training_history.json'), 'w') as f:
    json.dump(history, f, indent=2)

print(f"\u2705 All artifacts saved to: {CONFIG['SAVE_DIR']}/")
print(f"   - best_model.pt (model checkpoint)")
print(f"   - vocab.txt, tokenizer_config.json (tokenizer)")
print(f"   - label_mappings.json")
print(f"   - training_history.json")
print(f"   - training_curves.png")
print(f"   - confusion_matrices.png")

## 17. Quick Inference Test

In [None]:
def predict(text, model, tokenizer, device):
    """Run inference on a single text. Returns predicted labels and confidences."""
    model.eval()
    encoding = tokenizer(
        text,
        max_length=CONFIG['MAX_LENGTH'],
        padding='max_length',
        truncation=True,
        return_tensors='pt'
    )
    
    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)
    
    with torch.no_grad():
        logits = model(input_ids, attention_mask)
    
    emotion_probs = torch.softmax(logits['emotion'], dim=1)
    crisis_probs = torch.softmax(logits['crisis'], dim=1)
    info_probs = torch.softmax(logits['info'], dim=1)
    
    emotion_idx = emotion_probs.argmax(dim=1).item()
    crisis_idx = crisis_probs.argmax(dim=1).item()
    info_idx = info_probs.argmax(dim=1).item()
    
    return {
        'emotion': IDX_TO_EMOTION[emotion_idx],
        'emotion_confidence': emotion_probs[0, emotion_idx].item(),
        'crisis': 'crisis' if crisis_idx == 1 else 'non_crisis',
        'crisis_confidence': crisis_probs[0, crisis_idx].item(),
        'informativeness': IDX_TO_INFO[info_idx],
        'info_confidence': info_probs[0, info_idx].item(),
    }


# Test on sample texts
test_texts = [
    "Massive earthquake hits the coast, buildings collapsed, people trapped under rubble",
    "GOOOAL! France wins the World Cup! What an incredible match!",
    "Please donate blood at the Red Cross center to help hurricane victims",
    "Just finished watching the new Game of Thrones episode, it was okay",
    "I am so scared, the wildfire is getting closer to our neighborhood",
]

print("=" * 80)
print("INFERENCE TEST")
print("=" * 80)
for text in test_texts:
    result = predict(text, model, tokenizer, device)
    display_text = f'"{text[:80]}..."' if len(text) > 80 else f'"{text}"'
    print(f"\nText: {display_text}")
    print(f"  Emotion: {result['emotion']} ({result['emotion_confidence']:.3f})")
    print(f"  Crisis:  {result['crisis']} ({result['crisis_confidence']:.3f})")
    print(f"  Info:    {result['informativeness']} ({result['info_confidence']:.3f})")

## 18. Final Summary

In [None]:
print("=" * 80)
print("MULTI-TASK BERT TRAINING COMPLETE")
print("=" * 80)

print(f"\nDataset: {CONFIG['DATA_PATH']}")
print(f"  Train: {len(train_df):,} | Val: {len(val_df):,}")

print(f"\nBest Model (epoch {best_epoch}):")
print(f"  Val Loss:       {best_val_loss:.4f}")
print(f"  Emotion Acc:    {final_metrics['emotion_acc']:.4f}")
print(f"  Crisis Acc:     {final_metrics['crisis_acc']:.4f}")
print(f"  Info Acc:       {final_metrics['info_acc']:.4f}")

print(f"\nArtifacts saved to: {CONFIG['SAVE_DIR']}/")

print(f"\nNext Steps:")
print(f"  1. If using 10K sample: re-run with full 52.8K dataset (change DATA_PATH in CONFIG)")
print(f"  2. Apply trained model to original full datasets (67K crisis + 2.3M non-crisis)")
print(f"  3. Extract emotion features per tweet for RL agent")
print(f"  4. Create episodes and hourly aggregations for RL training")

print(f"\n{'='*80}")