# 04 - Training Engine

This notebook trains the ModernBERT-RGAT joint model on the SemEval restaurant datasets. Contents:

1. Setup and imports
2. Loss function unit tests
3. Single-batch overfit sanity check
4. Train on SemEval 2014
5. Train on SemEval 2015
6. Train on SemEval 2016
7. Training curves visualization
8. Results summary

---

## 1. Setup & Imports

In [None]:
import subprocess, sys, os

# Pin this process to a single MIG slice of the GPU.
# Per the original author's notes: the H100 is split into 3 MIG partitions,
# devices 0 and 1 are nearly full, and device 2 has about 8 GB free, which is
# more than enough for ModernBERT-RGAT.
# The variable MUST be exported before torch is imported.
os.environ.update(
    CUDA_VISIBLE_DEVICES='MIG-be0d4dc8-2244-5ed0-89ec-b674eacb6a9b',
)
print(f'CUDA_VISIBLE_DEVICES = {os.environ["CUDA_VISIBLE_DEVICES"]}')

def install_if_missing(package, pip_name=None):
    """Import ``package``; pip-install it into this interpreter if that fails.

    Args:
        package: Importable module name (e.g. ``'sklearn'``).
        pip_name: Distribution name to hand to pip when it differs from the
            module name (e.g. ``'scikit-learn'``). Defaults to ``package``.

    Raises:
        subprocess.CalledProcessError: If the pip invocation exits non-zero.
    """
    import importlib

    try:
        importlib.import_module(package)
    except ImportError:
        distribution = pip_name or package
        print(f'Installing {distribution}...')
        subprocess.check_call(
            [sys.executable, '-m', 'pip', 'install', '-q', distribution]
        )
        # Import finders cache directory listings; without this, a package
        # installed after interpreter start-up can still look missing to the
        # `import` statements that follow in the same session.
        importlib.invalidate_caches()

# Third-party requirements as (import name, pip distribution name) pairs.
for _module_name, _dist_name in (
    ('sklearn', 'scikit-learn'),
    ('transformers', 'transformers'),
    ('spacy', 'spacy'),
):
    install_if_missing(_module_name, _dist_name)

import spacy

# Download the small English pipeline only when spaCy fails to load it.
try:
    spacy.load('en_core_web_sm')
except OSError:
    _download_cmd = [sys.executable, '-m', 'spacy', 'download', 'en_core_web_sm']
    subprocess.check_call(_download_cmd)

print('Dependencies ready.')

In [None]:
# Locate the repository checkout, make it importable (sys.path) and switch the
# working directory to it so `src.*` imports and relative paths resolve.
PROJECT_ROOT = os.path.expanduser('~/SOTA-ModernBERT-RGAT-Joint-Aspect-Sentiment-Extraction-for-Food-Tech-Reviews')
# Explicit raise instead of `assert`: assertions are stripped under
# `python -O`, which would silently disable this check.
if not os.path.isdir(PROJECT_ROOT):
    raise FileNotFoundError(f'Project root not found: {PROJECT_ROOT}')
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)
os.chdir(PROJECT_ROOT)
print(f'Project root: {PROJECT_ROOT}')

import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from src.model import ModernBERT_RGAT
from src.dataset import ABSAPreprocessor, ABSADataset, create_dataloader
from src.data_pipeline import load_config, build_splits, compute_class_weights
from src.losses import JointTaskLoss, FocalLoss, AlphaScheduler
from src.trainer import Trainer, compute_strict_f1, extract_spans_from_bio

# Report the runtime environment and fail fast when no GPU is visible.
print(f'PyTorch: {torch.__version__}')
print(f'CUDA available: {torch.cuda.is_available()}')
# Explicit raise instead of `assert`: assertions are stripped under
# `python -O`, which would let the notebook continue without a GPU.
if not torch.cuda.is_available():
    raise RuntimeError('CUDA must be available!')

device = torch.device('cuda')
print(f'GPU: {torch.cuda.get_device_name(0)}')
print(f'Device count: {torch.cuda.device_count()}')

# Memory statistics are informational only, so a failure here is reported and
# then ignored instead of aborting the notebook.
try:
    free_bytes, total_bytes = torch.cuda.mem_get_info(0)
except Exception as e:
    print(f'mem_get_info error: {e}')
else:
    print(f'GPU memory: {total_bytes/1024**3:.1f} GB total, {free_bytes/1024**3:.1f} GB free')

print(f'Using device: {device}')
print('Imports successful.')

In [None]:
# Read configs/config.yaml from the project root and echo the main training
# hyper-parameters so the run is self-describing.
config_path = os.path.join(PROJECT_ROOT, 'configs', 'config.yaml')
config = load_config(config_path)

print('Config loaded.')
_training_cfg = config['training']
print(f'  Epochs:       {_training_cfg["epochs"]}')
print(f'  Batch size:   {_training_cfg["batch_size"]}')
print(f'  LR:           {_training_cfg["learning_rate"]}')
print(f'  Alpha:        {_training_cfg["alpha_start"]} -> {_training_cfg["alpha_end"]}')
print(f'  Focal gamma:  {_training_cfg["focal_gamma"]}')
print(f'  FP16:         {_training_cfg["fp16"]}')

## 2. Loss Function Unit Tests

In [None]:
# Smoke tests for the loss, the alpha schedule, BIO span extraction and the
# strict-F1 metric, run before any real training.
# All tensors are created without a device argument, so nothing here needs
# the GPU.
# NOTE(review): only Test 3 contains an assertion; Tests 1, 2 and 4 just print
# values for visual inspection, so the final "All unit tests passed." message
# is printed whenever nothing raises.

# --- Test 1: the joint loss runs forward and backward on random inputs ---
print('Test 1: JointTaskLoss forward pass')
# Shapes (2, 96, 3) and (2, 4): presumably (batch, seq_len, num_bio_tags) and
# (batch, num_sentiment_classes) -- TODO confirm against JointTaskLoss.
dummy_ate_logits = torch.randn(2, 96, 3, requires_grad=True)
dummy_ate_labels = torch.randint(0, 3, (2, 96))
dummy_asc_logits = torch.randn(2, 4, requires_grad=True)
dummy_asc_labels = torch.randint(0, 4, (2,))

loss_fn = JointTaskLoss(alpha_start=0.7, alpha_end=0.3, gamma=2.0, total_steps=100)
# The loss callable returns three values, unpacked as total / ATE / ASC.
total, ate, asc = loss_fn(dummy_ate_logits, dummy_ate_labels, dummy_asc_logits, dummy_asc_labels)
print(f'  Total: {total.item():.4f}, ATE: {ate.item():.4f}, ASC: {asc.item():.4f}')

# Verify gradients flow: `.grad` stays None on a leaf that received no
# gradient, in which case the `.shape` access below raises.
total.backward()
print(f'  ATE logits grad shape: {dummy_ate_logits.grad.shape}')
print(f'  ASC logits grad shape: {dummy_asc_logits.grad.shape}')
print('  Gradients flow: OK')

# --- Test 2: alpha schedule, sampled for 12 steps with total_steps=10 ---
# The loop deliberately runs two steps past the end of the schedule.
# NOTE(review): the expected 0.70 -> 0.30 endpoints are printed, not asserted.
print()
print('Test 2: Alpha scheduling')
scheduler = AlphaScheduler(alpha_start=0.7, alpha_end=0.3, total_steps=10)
alphas = []
for i in range(12):
    alphas.append(scheduler.get_alpha())
    scheduler.step()
print(f'  Alpha over 12 steps: {[f"{a:.2f}" for a in alphas]}')
print(f'  Starts at 0.70, ends at 0.30: {alphas[0]:.2f} -> {alphas[-1]:.2f}')

# --- Test 3: BIO tag sequence -> (start, end) spans, end exclusive ---
print()
print('Test 3: BIO span extraction')
test_bio = [0, 0, 1, 2, 0, 1, 0, 0]   # two aspects: (2,4) and (5,6)
spans = extract_spans_from_bio(test_bio)
print(f'  BIO: {test_bio}')
print(f'  Extracted spans: {spans}')
assert spans == [(2, 4), (5, 6)], f'Expected [(2,4), (5,6)], got {spans}'
print('  Span extraction: OK')

# --- Test 4: strict (exact span match) precision / recall / F1 ---
# NOTE(review): the scores are printed but not asserted. With 2 exact matches
# out of 3 predicted and 3 gold spans, a micro-averaged metric would give
# P = R = F1 = 0.6667 -- confirm how compute_strict_f1 aggregates.
print()
print('Test 4: Strict F1 computation')
pred = [[0, 0, 1, 2, 0, 1, 0],    # spans: (2,4), (5,6)
        [0, 1, 2, 2, 0, 0, 0]]     # spans: (1,4)
gold = [[0, 0, 1, 2, 0, 0, 1],    # spans: (2,4), (6,7)
        [0, 1, 2, 2, 0, 0, 0]]     # spans: (1,4)
f1_result = compute_strict_f1(pred, gold)
print(f'  Pred spans: (2,4),(5,6) and (1,4)')
print(f'  Gold spans: (2,4),(6,7) and (1,4)')
print(f'  Exact matches: 2 (out of 3 pred, 3 gold)')
print(f'  P={f1_result["precision"]:.4f}, R={f1_result["recall"]:.4f}, F1={f1_result["f1"]:.4f}')

print()
print('All unit tests passed.')

## 3. Single-Batch Overfit Test

Before full training, I verify the training loop is wired correctly by
overfitting on a single batch for 50 iterations. If the loss falls sharply
(the check below requires the final loss to be less than half of the initial
loss), the forward/backward/optimizer pipeline works.

In [None]:
# Preprocessor built from the backbone name and max length in the config; the
# same object is later passed to the full training runs.
preprocessor = ABSAPreprocessor(
    max_len=config['model']['max_len'],
    model_name=config['model']['backbone'],
)

# Two hand-written examples with character-level aspect offsets; together they
# form the single batch used by the overfit check.
_overfit_columns = ('sentence_id', 'sentence', 'aspect', 'polarity', 'span_start', 'span_end')
_overfit_records = (
    ('0001', 'The spicy ramen was incredibly delicious.', 'ramen', 'positive', 10, 15),
    ('0002', 'Bad service and rude staff.', 'service', 'negative', 4, 11),
)
overfit_df = pd.DataFrame(
    [dict(zip(_overfit_columns, _record)) for _record in _overfit_records]
)

label_map = config['labels']['polarity']
overfit_loader = create_dataloader(
    overfit_df,
    preprocessor,
    label_map,
    batch_size=2,
    shuffle=False,
)
# Only the first (and only) batch is needed.
overfit_batch = next(iter(overfit_loader))

print(f'Overfit batch keys: {list(overfit_batch.keys())}')
print(f'Batch size: {overfit_batch["input_ids"].shape[0]}')

In [None]:
# Overfit on a single batch on the GPU: 50 optimizer steps on the same two
# examples. A sharp drop in loss shows that forward pass, loss, backward pass
# and optimizer are wired together correctly.
import gc

torch.cuda.empty_cache()

overfit_model = ModernBERT_RGAT(
    model_name=config['model']['backbone'],
    hidden_dim=config['model']['hidden_dim'],
    num_sentiment_classes=config['model']['num_sentiment_classes'],
    num_bio_tags=config['model']['num_bio_tags'],
    num_relations=config['model']['rgat']['num_relations'],
    rgat_dropout=config['model']['rgat']['dropout'],
).to(device)

# alpha_start == alpha_end keeps the task weighting constant for this check.
overfit_loss_fn = JointTaskLoss(alpha_start=0.5, alpha_end=0.5, gamma=0.0, total_steps=50)
overfit_optimizer = torch.optim.AdamW(overfit_model.parameters(), lr=5e-5)

# Move batch to device. Entries that are not tensors (if the collate function
# produces any) are passed through unchanged instead of raising on `.to`.
batch_on_device = {
    k: (v.to(device) if torch.is_tensor(v) else v)
    for k, v in overfit_batch.items()
}

print(f'Starting overfit test on {device} (50 iterations on 1 batch)...')
overfit_model.train()
losses = []

for step in range(50):
    outputs = overfit_model(
        input_ids=batch_on_device['input_ids'],
        attention_mask=batch_on_device['attention_mask'],
        adj_matrix=batch_on_device['adj_matrix'],
        aspect_mask=batch_on_device['aspect_mask'],
    )
    total, ate, asc = overfit_loss_fn(
        outputs['ate_logits'], batch_on_device['bio_labels'],
        outputs['sentiment_logits'], batch_on_device['sentiment_label'],
    )

    overfit_optimizer.zero_grad()
    total.backward()
    overfit_optimizer.step()
    losses.append(total.item())

    if (step + 1) % 10 == 0:
        print(f'  Step {step+1:3d}: loss={total.item():.4f} (ate={ate.item():.4f}, asc={asc.item():.4f})')

print(f'\nLoss dropped from {losses[0]:.4f} to {losses[-1]:.4f}')
# Pass criterion: the final loss is below half of the first-step loss.
if losses[-1] < losses[0] * 0.5:
    print('Overfit test PASSED -- loss decreased significantly.')
else:
    print('WARNING: Loss did not decrease enough. Check the training pipeline.')

# Clean up to free GPU memory.
# `outputs`, `total`, `ate` and `asc` still reference the last step's autograd
# graph, whose leaf nodes hold the model parameters. They have to be dropped
# too, otherwise the weights stay allocated despite `del overfit_model`.
del overfit_model, overfit_optimizer, overfit_loss_fn, batch_on_device
del outputs, total, ate, asc
gc.collect()
torch.cuda.empty_cache()

## 4. Train on SemEval 2014

Full training run with early stopping and checkpointing.

In [None]:
def train_on_dataset(config, year, preprocessor, device):
    """Train a fresh ModernBERT-RGAT model on one SemEval dataset year.

    Loads the train/val/test splits, builds the dataloaders, derives sentiment
    class weights from the training split, runs ``Trainer.train()`` and then
    releases the GPU-heavy objects so the next dataset starts clean.

    Args:
        config: Nested configuration mapping. The ``labels``, ``training`` and
            ``model`` sections are read here, and the whole object is also
            passed on to ``build_splits`` and ``Trainer``.
        year: Dataset year identifier, e.g. ``'2014'``.
        preprocessor: Preprocessor forwarded to ``create_dataloader``.
        device: Device forwarded to ``Trainer``.

    Returns:
        A ``(trainer, results)`` tuple: the trainer with its ``model``,
        ``optimizer`` and ``scaler`` attributes removed, and the value
        returned by ``trainer.train()``.
    """
    import gc

    label_map = config['labels']['polarity']

    # Load and split data
    print(f'\nLoading SemEval {year} data...')
    train_df, val_df, test_df = build_splits(config, year, verbose=True)
    print(f'  Train: {len(train_df)} rows, Val: {len(val_df)} rows, Test: {len(test_df)} rows')

    # Create dataloaders (only the training loader is shuffled)
    train_loader = create_dataloader(
        train_df, preprocessor, label_map,
        batch_size=config['training']['batch_size'],
        shuffle=True,
    )
    val_loader = create_dataloader(
        val_df, preprocessor, label_map,
        batch_size=config['training']['eval_batch_size'],
        shuffle=False,
    )
    print(f'  Train batches: {len(train_loader)}, Val batches: {len(val_loader)}')

    # Compute class weights from training data. The weights are looked up by
    # integer class index 0..len(label_map)-1.
    class_weights = compute_class_weights(train_df, label_map)
    weight_tensor = torch.tensor(
        [class_weights[i] for i in range(len(label_map))],
        dtype=torch.float,
    )
    print(f'  Class weights: {class_weights}')

    # Clean GPU before loading model
    torch.cuda.empty_cache()

    # Build fresh model. It is not moved to `device` here -- presumably the
    # Trainer does that; TODO confirm.
    model = ModernBERT_RGAT(
        model_name=config['model']['backbone'],
        hidden_dim=config['model']['hidden_dim'],
        num_sentiment_classes=config['model']['num_sentiment_classes'],
        num_bio_tags=config['model']['num_bio_tags'],
        num_relations=config['model']['rgat']['num_relations'],
        rgat_dropout=config['model']['rgat']['dropout'],
    )

    # Create trainer and run
    trainer = Trainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        config=config,
        sentiment_weights=weight_tensor,
        device=device,
        dataset_year=year,
    )

    try:
        results = trainer.train()
    finally:
        # Clean up to free the GPU for the next dataset, also when training
        # raised. The local `model` name must go as well: while it is bound,
        # empty_cache() cannot release the weights.
        del model
        for heavy_attr in ('model', 'optimizer', 'scaler'):
            # Guarded so a trainer without one of these attributes does not
            # turn cleanup into an AttributeError.
            if hasattr(trainer, heavy_attr):
                delattr(trainer, heavy_attr)
        gc.collect()
        torch.cuda.empty_cache()

    return trainer, results

In [None]:
# Full training run on SemEval 2014. `results_2014` is consumed by the
# plotting and summary cells further down.
trainer_2014, results_2014 = train_on_dataset(config, '2014', preprocessor, device)

## 5. Train on SemEval 2015

In [None]:
# Full training run on SemEval 2015, reusing the same config, preprocessor and
# device. `results_2015` is consumed by the plotting and summary cells.
trainer_2015, results_2015 = train_on_dataset(config, '2015', preprocessor, device)

## 6. Train on SemEval 2016

In [None]:
# Full training run on SemEval 2016, reusing the same config, preprocessor and
# device. `results_2016` is consumed by the plotting and summary cells.
trainer_2016, results_2016 = train_on_dataset(config, '2016', preprocessor, device)

## 7. Training Curves

In [None]:
def plot_training_curves(results_dict, save_path='outputs/plots/training_curves.png'):
    """Plot loss and validation-metric curves, one column per dataset.

    The top row shows train/validation loss and the bottom row shows the
    validation ATE F1, ASC F1 and ASC accuracy. The grid gets as many columns
    as there are datasets, so partial runs (one or two years) and runs with
    more than three datasets both render correctly.

    Args:
        results_dict: Mapping of dataset year -> results mapping. Each results
            mapping must have a ``'history'`` entry holding equally long
            sequences under ``'train_loss'``, ``'val_loss'``,
            ``'val_ate_f1'``, ``'val_asc_f1'`` and ``'val_asc_accuracy'``.
        save_path: Where the figure is written. Missing parent directories
            are created.

    Returns:
        None. When ``results_dict`` is empty nothing is drawn or saved.
    """
    if not results_dict:
        print('No training results to plot.')
        return

    n_datasets = len(results_dict)
    # squeeze=False keeps `axes` two-dimensional even for a single dataset,
    # so axes[row, col] indexing works for every column count.
    fig, axes = plt.subplots(
        2, n_datasets, figsize=(6 * n_datasets, 10), squeeze=False,
    )

    for col, (year, results) in enumerate(results_dict.items()):
        history = results['history']
        epochs = range(1, len(history['train_loss']) + 1)

        # Top row: Loss curves
        ax = axes[0, col]
        ax.plot(epochs, history['train_loss'], 'b-', label='Train Loss', linewidth=2)
        ax.plot(epochs, history['val_loss'], 'r-', label='Val Loss', linewidth=2)
        ax.set_title(f'SemEval {year} - Loss', fontsize=13, fontweight='bold')
        ax.set_xlabel('Epoch')
        ax.set_ylabel('Loss')
        ax.legend()
        ax.grid(True, alpha=0.3)

        # Bottom row: F1 curves
        ax = axes[1, col]
        ax.plot(epochs, history['val_ate_f1'], 'g-', label='ATE F1 (Strict)', linewidth=2)
        ax.plot(epochs, history['val_asc_f1'], 'm-', label='ASC F1 (Macro)', linewidth=2)
        ax.plot(epochs, history['val_asc_accuracy'], 'c--', label='ASC Accuracy', linewidth=1.5)
        ax.set_title(f'SemEval {year} - Validation Metrics', fontsize=13, fontweight='bold')
        ax.set_xlabel('Epoch')
        ax.set_ylabel('Score')
        ax.legend()
        ax.set_ylim(0, 1)
        ax.grid(True, alpha=0.3)

    plt.suptitle('ModernBERT-RGAT Training Curves', fontsize=16, fontweight='bold', y=1.02)
    plt.tight_layout()

    # Save the plot. os.makedirs('') raises, so a bare filename skips it.
    out_dir = os.path.dirname(save_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.show()
    print(f'Training curves saved: {save_path}')

# Gather the results of whichever training cells have been run. Looking the
# variables up through globals() means a skipped cell leaves its year out
# instead of raising NameError; empty or None results are dropped as well.
_results_by_year = {
    year: globals().get(f'results_{year}')
    for year in ('2014', '2015', '2016')
}
all_results = {
    year: run_results
    for year, run_results in _results_by_year.items()
    if run_results
}

if all_results:
    plot_training_curves(all_results)

## 8. Results Summary

In [None]:
# Build one formatted row per dataset that reported best validation metrics,
# then show the table and write it to CSV.
_metric_columns = (
    ('ATE P', 'ate_precision'),
    ('ATE R', 'ate_recall'),
    ('ATE F1 (Strict)', 'ate_f1'),
    ('ASC Acc', 'asc_accuracy'),
    ('ASC F1 (Macro)', 'asc_f1'),
)
summary_rows = []

for year, results in all_results.items():
    best = results.get('best_val_metrics', {})
    if not best:
        # No best-epoch metrics recorded for this dataset: leave it out.
        continue
    _row = {'Dataset': f'SemEval {year}'}
    for _column, _metric_key in _metric_columns:
        # A missing metric is shown as 0.0000.
        _row[_column] = f"{best.get(_metric_key, 0):.4f}"
    _row['Time (min)'] = f"{results.get('total_time_seconds', 0)/60:.1f}"
    summary_rows.append(_row)

if not summary_rows:
    print('No training results available yet.')
else:
    summary_df = pd.DataFrame(summary_rows)
    display(summary_df)

    # Save results
    os.makedirs('outputs/results', exist_ok=True)
    summary_df.to_csv('outputs/results/training_summary.csv', index=False)
    print('\nResults saved: outputs/results/training_summary.csv')

In [None]:
# RGAT relation weights after training (interpretability check): for each
# trained dataset, load the best checkpoint and print the sigmoid of the
# learned `rgat.relation_importance` values, largest first.
print('RGAT Relation Importance Weights (post-training):')
print('='*50)

# Relation labels, assumed to be in the same order as the entries of
# `rgat.relation_importance` -- TODO confirm against the model definition.
rel_names = ['nsubj', 'amod', 'obj', 'advmod', 'neg', 'compound', 'conj']

for year in all_results:
    checkpoint_path = os.path.join(config['training']['checkpoint_dir'], f'best_model_{year}.pt')
    if not os.path.exists(checkpoint_path):
        # Say so explicitly instead of skipping the dataset without a trace.
        print(f'\n  SemEval {year}: checkpoint not found at {checkpoint_path}, skipping.')
        continue

    # NOTE: torch.load unpickles the file, so only load checkpoints written by
    # this project's own training runs.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    # Extract relation importance from state dict
    raw_weights = checkpoint['model_state_dict']['rgat.relation_importance']
    weights = torch.sigmoid(raw_weights).tolist()

    print(f'\n  SemEval {year}:')
    if len(weights) != len(rel_names):
        # zip() below stops at the shorter sequence; make that visible.
        print(f'    WARNING: {len(weights)} weights for {len(rel_names)} relation names; unmatched entries are omitted.')
    sorted_rels = sorted(zip(rel_names, weights), key=lambda x: -x[1])
    for name, w in sorted_rels:
        # Text bar scaled so that a weight of 1.0 is 30 characters wide.
        bar = '#' * int(w * 30)
        print(f'    {name:10s}: {w:.4f}  {bar}')

print(f'\n{"="*50}')

---

## Phase 4 Summary

| Component | Status |
|-----------|--------|
| JointTaskLoss (ATE CE + ASC Focal) | Done |
| Dynamic alpha scheduling (0.7 -> 0.3) | Done |
| Trainer with FP16 + gradient clipping | Done |
| LR warmup + linear decay | Done |
| Strict F1 (exact span match) | Done |
| Early stopping (patience=5) | Done |
| Checkpoint save/load | Done |
| Training curves + results | Done |

**Next step:** Phase 5 - Evaluation & Benchmarking