## Model Training Pipeline: ESM2 with LoRA Fine-tuning

This notebook handles the complete model training pipeline:

1. **Environment Setup** - Check GPU availability and load the configuration file
2. **Data Loading** - Load tokenized datasets from disk
3. **Model Initialization** - Load ESM2 model with LoRA adapters
4. **Training Setup** - Configure training arguments and initialize trainer
5. **Model Training** - Train the model with validation monitoring
6. **Save Model** - Save the trained model checkpoint for later evaluation

After running this notebook, you should have:
- `models/checkpoints/` - Model checkpoints saved during training
- `models/final_model/` - Final trained model
- Training metrics logged to Weights & Biases (only when `wandb.enabled` is true in `config.yaml`)

**Note:** For detailed evaluation and analysis, run `03_evaluation.ipynb` after training.

### Import Libraries

In [None]:
import math
from pathlib import Path

import pandas as pd
import torch
import wandb
import yaml
from datasets import load_from_disk
from evaluate import load
from peft import LoraConfig, get_peft_model
from transformers import (
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer
)

# cuDNN benchmark mode autotunes kernel selection per input shape; any speed-up
# depends on input sizes staying fixed. Only relevant when CUDA is present.
_cuda_present = torch.cuda.is_available()
if _cuda_present:
    torch.backends.cudnn.benchmark = True

## 1. Environment Setup & Configuration

In [None]:
# Check GPU availability
print("=" * 70)
print("ENVIRONMENT CHECK")
print("=" * 70)

if torch.cuda.is_available():
    # Details below describe GPU 0 only.
    device_name = torch.cuda.get_device_name(0)
    cuda_version = torch.version.cuda
    # total_memory is the card's full capacity, not what is currently free;
    # mem_get_info reports the free portion (other processes may hold some).
    total_memory = torch.cuda.get_device_properties(0).total_memory / 1e9
    free_memory = torch.cuda.mem_get_info(0)[0] / 1e9

    print(f"✓ GPU detected: {device_name}")
    print(f"  CUDA version: {cuda_version}")
    print(f"  Total memory: {total_memory:.2f} GB")
    print(f"  Free memory: {free_memory:.2f} GB")
    print(f"  Number of GPUs: {torch.cuda.device_count()}")
else:
    print("⚠ WARNING: No GPU detected!")
    print("  Training will be extremely slow on CPU.")
    print("  Consider using a GPU-enabled environment.")

print(f"\nPyTorch version: {torch.__version__}")
print("=" * 70)

In [None]:
# Load configuration file
print("\n=== LOADING CONFIGURATION ===")

# safe_load constructs only plain YAML types (no arbitrary Python objects).
with open("../config.yaml", "r") as config_file:
    cfg = yaml.safe_load(config_file)

# Echo the settings that drive this run so the cell output records them.
print("\nModel Configuration:")
for caption, field in (("Base model", "name"), ("Number of labels", "num_labels")):
    print(f"  {caption}: {cfg['model'][field]}")

print("\nLoRA Configuration:")
for option, setting in cfg['model']['lora'].items():
    print(f"  {option}: {setting}")

print("\nTraining Configuration:")
for caption, field in (
    ("Epochs", "epochs"),
    ("Batch size", "batch_size"),
    ("Learning rate", "learning_rate"),
    ("Gradient accumulation steps", "grad_accum_steps"),
    ("FP16", "fp16"),
):
    print(f"  {caption}: {cfg['training'][field]}")

print("\n✓ Configuration loaded successfully")

## 2. Load Tokenized Datasets

In [None]:
print("\n=== LOADING TOKENIZED DATASETS ===")

# Locations of the tokenized splits (read with datasets.load_from_disk).
train_path = "../data/tokenized/train_dataset"
val_path = "../data/tokenized/val_dataset"


def _load_torch_split(split_name, path):
    """Announce, load one split from disk and expose its columns as torch tensors."""
    print(f"Loading {split_name} dataset from: {path}")
    return load_from_disk(path).with_format("torch")


train_ds = _load_torch_split("train", train_path)
val_ds = _load_torch_split("validation", val_path)

n_train, n_val = len(train_ds), len(val_ds)
print("\n=== DATASET SUMMARY ===")
print(f"Train samples: {n_train:,}")
print(f"Validation samples: {n_val:,}")
print(f"Total samples: {n_train + n_val:,}")

print("\nDataset features:")
print(f"  {train_ds.features}")

print("\n✓ Datasets loaded successfully")

In [None]:
# Verify class distribution in loaded datasets
print("\n=== VERIFYING CLASS DISTRIBUTION ===")

# Number of classes the classification head is configured with.
num_classes = cfg['model']['num_labels']

# Display names indexed by label id (ordered like top-level EC classes 1-7).
ec_names = [
    "Oxidoreductases",
    "Transferases",
    "Hydrolases",
    "Lyases",
    "Isomerases",
    "Ligases",
    "Translocases"
]

# Labels as plain Python ints (the datasets are torch-formatted).
train_labels = [int(label) for label in train_ds['labels']]
val_labels = [int(label) for label in val_ds['labels']]

# Fail fast on ids the classification head cannot represent; a fixed
# range(7) loop would otherwise silently leave them out of the table.
unexpected_labels = sorted((set(train_labels) | set(val_labels)) - set(range(num_classes)))
if unexpected_labels:
    raise ValueError(
        f"Found label ids outside [0, {num_classes}): {unexpected_labels}"
    )

train_dist = pd.Series(train_labels, dtype="int64").value_counts().sort_index()
val_dist = pd.Series(val_labels, dtype="int64").value_counts().sort_index()

# Full class names: 3-letter prefixes are ambiguous (Transferases vs
# Translocases both become "Tra"). Column width adapts to the longest label.
row_labels = [
    f"{i} ({ec_names[i] if i < len(ec_names) else 'unnamed'})"
    for i in range(num_classes)
]
label_width = max([len("Class"), len("Total")] + [len(r) for r in row_labels])

header = f"{'Class':<{label_width}} {'Train':>10} {'Val':>10} {'Total':>10}"
print(f"\n{header}")
print("-" * len(header))

for i, row_label in enumerate(row_labels):
    train_c = int(train_dist.get(i, 0))
    val_c = int(val_dist.get(i, 0))
    print(f"{row_label:<{label_width}} {train_c:>10,} {val_c:>10,} {train_c + val_c:>10,}")

print("-" * len(header))
print(
    f"{'Total':<{label_width}} {len(train_ds):>10,} {len(val_ds):>10,} "
    f"{len(train_ds) + len(val_ds):>10,}"
)

# A class with no training examples cannot be learned — call it out.
empty_train_classes = [i for i in range(num_classes) if train_dist.get(i, 0) == 0]
if empty_train_classes:
    print(f"\n⚠ WARNING: no training samples for class ids {empty_train_classes}")

print("\n✓ Class distribution verified")

## 3. Model Initialization with LoRA

In [None]:
print("\n=== LOADING BASE MODEL ===")

model_name, num_labels = cfg['model']['name'], cfg['model']['num_labels']

print(f"Model: {model_name}")
print(f"Task: Sequence Classification with {num_labels} labels")

# Pick the device the adapted model is moved to later in the notebook
device = "cpu" if not torch.cuda.is_available() else "cuda"
print(f"\nUsing device: {device}")

print("\nLoading base model...")
# NOTE(review): trust_remote_code=True executes Python code shipped with the
# hub repository; ESM2 is presumably supported natively by transformers, so
# this may be unnecessary — confirm before pointing model_name elsewhere.
model = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    num_labels=num_labels,
    trust_remote_code=True
)
print("✓ Base model loaded")

# Parameter count before any LoRA adapters are attached
total_params = sum(param.numel() for param in model.parameters())
print(f"\nTotal parameters in base model: {total_params:,}")

In [None]:
print("\n=== APPLYING LoRA ADAPTERS ===")

lora_cfg = cfg['model']['lora']

# Translate the config's LoRA section into a PEFT LoraConfig for a
# sequence-classification task.
lora_config = LoraConfig(
    task_type="SEQ_CLS",
    r=lora_cfg['r'],
    lora_alpha=lora_cfg['alpha'],
    target_modules=lora_cfg['target_modules'],
    lora_dropout=lora_cfg['dropout'],
    bias=lora_cfg['bias'],
)

print("LoRA Configuration:")
for caption, attr in (
    ("Rank (r)", "r"),
    ("Alpha", "lora_alpha"),
    ("Target modules", "target_modules"),
    ("Dropout", "lora_dropout"),
    ("Bias", "bias"),
):
    print(f"  {caption}: {getattr(lora_config, attr)}")

print("\nApplying LoRA adapters...")
model = get_peft_model(model, lora_config)

print("\n=== TRAINABLE PARAMETERS ===")
model.print_trainable_parameters()

model = model.to(device)
print(f"\n✓ Model moved to {device}")

## 4. Training Setup

In [None]:
# Define evaluation metrics
print("\n=== SETTING UP EVALUATION METRICS ===")

accuracy_metric = load("accuracy")

def compute_metrics(eval_pred):
    """Compute classification accuracy for the Trainer's evaluation loop.

    Args:
        eval_pred: ``(predictions, label_ids)`` pair supplied by the Trainer.
            ``predictions`` is normally the logits array; when the model
            returns extra outputs it arrives as a tuple, whose first element
            is presumably the logits.

    Returns:
        dict: the result of the ``evaluate`` accuracy metric (``{"accuracy": ...}``).
    """
    logits, labels = eval_pred
    # Keep only the logits if the model emitted additional outputs.
    if isinstance(logits, tuple):
        logits = logits[0]
    predictions = logits.argmax(-1)
    return accuracy_metric.compute(predictions=predictions, references=labels)

print("✓ Evaluation metrics configured")
print("  Primary metric: accuracy")

In [None]:
# Initialize Weights & Biases (only when enabled in the config); the whole
# config dict is attached to the run for provenance.
print("\n=== INITIALIZING WEIGHTS & BIASES ===")

wandb_cfg = cfg['wandb']
if not wandb_cfg['enabled']:
    print("⚠ W&B logging disabled in config")
else:
    wandb.init(project=wandb_cfg['project'], name=wandb_cfg['run_name'], config=cfg)
    print("✓ W&B initialized")
    print(f"  Project: {wandb_cfg['project']}")
    print(f"  Run name: {wandb_cfg['run_name']}")

In [None]:
# Configure training arguments
print("\n=== CONFIGURING TRAINING ARGUMENTS ===")

output_dir = (Path('..') / 'models' / 'checkpoints').resolve()
output_dir.mkdir(parents=True, exist_ok=True)
output_dir = str(output_dir)

train_cfg = cfg['training']
num_workers = train_cfg['num_workers']

training_args = TrainingArguments(
    output_dir=output_dir,
    num_train_epochs=train_cfg['epochs'],
    per_device_train_batch_size=train_cfg['batch_size'],
    per_device_eval_batch_size=train_cfg['batch_size'],
    gradient_accumulation_steps=train_cfg['grad_accum_steps'],
    learning_rate=float(train_cfg['learning_rate']),
    # fp16 mixed precision requires CUDA; fall back to fp32 on CPU.
    fp16=train_cfg['fp16'] and torch.cuda.is_available(),

    # Evaluation and saving
    eval_strategy=train_cfg['eval_strategy'],
    save_strategy=train_cfg['save_strategy'],
    save_total_limit=train_cfg['save_total_limit'],
    load_best_model_at_end=train_cfg['load_best_model_at_end'],
    metric_for_best_model=train_cfg['metric_for_best_model'],
    # A hard-coded True is wrong for loss-style metrics. None lets
    # TrainingArguments infer the direction from the metric name (False for
    # names ending in "loss", True otherwise); an explicit config value wins.
    greater_is_better=train_cfg.get('greater_is_better'),

    # Logging
    logging_dir=f"{output_dir}/logs",
    logging_steps=10,
    report_to=["wandb"] if cfg['wandb']['enabled'] else [],
    run_name=cfg['wandb']['run_name'],

    # Data loading
    dataloader_num_workers=num_workers,
    dataloader_pin_memory=train_cfg['pin_memory'],
    # torch's DataLoader rejects persistent_workers=True when num_workers == 0.
    dataloader_persistent_workers=bool(train_cfg['persistent_workers']) and num_workers > 0,
)

print("Training arguments configured:")
print(f"  Output directory: {output_dir}")
print(f"  Epochs: {training_args.num_train_epochs}")
print(f"  Batch size: {training_args.per_device_train_batch_size}")
print(f"  Learning rate: {training_args.learning_rate}")
print(f"  FP16: {training_args.fp16}")
print(f"  Evaluation strategy: {training_args.eval_strategy}")
print(f"  Save strategy: {training_args.save_strategy}")

# Samples consumed per optimizer step. train_batch_size already multiplies the
# per-device size by the number of GPUs this process drives; world_size
# accounts for multi-process launches.
samples_per_forward = training_args.train_batch_size * training_args.world_size
effective_batch_size = samples_per_forward * training_args.gradient_accumulation_steps
print(f"\nEffective batch size: {effective_batch_size}")

# Approximate the Trainer's bookkeeping: the dataloader keeps the final
# partial batch (ceil), optimizer steps per epoch are floored over gradient
# accumulation with a minimum of 1, and fractional epochs round the total up.
batches_per_epoch = math.ceil(len(train_ds) / samples_per_forward)
steps_per_epoch = max(batches_per_epoch // training_args.gradient_accumulation_steps, 1)
total_steps = math.ceil(steps_per_epoch * training_args.num_train_epochs)
print(f"Estimated steps per epoch: {steps_per_epoch:,}")
print(f"Estimated total training steps: {total_steps:,}")

print("\n✓ Training arguments configured")

In [None]:
# Initialize trainer: combine the LoRA model, arguments, both splits and the
# metric function.
print("\n=== INITIALIZING TRAINER ===")

# No data_collator is passed, so the Trainer's default collator is used —
# presumably the tokenized examples are already padded to a common length;
# TODO confirm.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    compute_metrics=compute_metrics,
)

print("✓ Trainer initialized")
for caption, split in (
    ("Training samples", trainer.train_dataset),
    ("Evaluation samples", trainer.eval_dataset),
):
    print(f"  {caption}: {len(split):,}")

## 5. Model Training

In [None]:
# Resume logic must look where the Trainer actually writes checkpoints:
# training_args.output_dir (set in the training-arguments cell). A separate
# config path could diverge from it and make resuming silently never happen.
checkpoint_dir = Path(training_args.output_dir).resolve()

def _checkpoint_number(path: Path):
    try:
        return int(path.name.split('-')[-1])
    except Exception:
        return -1

def get_latest_checkpoint(output_dir):
    """Return the latest checkpoint Path or None"""
    if not isinstance(output_dir, Path):
        output_dir = Path(output_dir)
    if not output_dir.exists():
        return None

    checkpoints = [p for p in output_dir.iterdir() if p.is_dir() and p.name.startswith('checkpoint-')]
    if not checkpoints:
        return None

    latest = max(checkpoints, key=_checkpoint_number)
    return latest

# Start training
print("\n" + "=" * 70)
print("STARTING TRAINING")
print("=" * 70)
print("\nThis may take a while depending on your hardware...\n")

checkpoint_path = get_latest_checkpoint(checkpoint_dir)

if checkpoint_path is not None:
    # A resumable PEFT checkpoint needs adapter weights plus the Trainer state.
    # PEFT writes adapter_model.safetensors by default and adapter_model.bin
    # when safetensors serialization is off; accept either so a valid
    # checkpoint is not discarded and training restarted from scratch.
    adapter_files = ('adapter_model.safetensors', 'adapter_model.bin')
    has_adapter = any((checkpoint_path / name).exists() for name in adapter_files)
    trainer_state = checkpoint_path / 'trainer_state.json'

    if has_adapter and trainer_state.exists():
        print(f"Resuming from checkpoint: {checkpoint_path}\n")
        # NOTE(review): the saved GradScaler state is deleted before resuming —
        # presumably a workaround for a scaler-loading error. This permanently
        # modifies the checkpoint on disk; confirm this is intended.
        scaler_path = checkpoint_path / 'scaler.pt'
        if scaler_path.exists():
            scaler_path.unlink()
        train_result = trainer.train(resume_from_checkpoint=str(checkpoint_path))
    else:
        print("Invalid checkpoint. Starting fresh training...\n")
        train_result = trainer.train()
else:
    print("No checkpoint found. Starting fresh training...\n")
    train_result = trainer.train()

# Display results
print("\n" + "=" * 70)
print("TRAINING COMPLETED")
print("=" * 70)

if train_result:
    print("\nFinal metrics:")
    for key, value in train_result.metrics.items():
        print(f"  {key}: {value}")

# Save final model under the configured checkpoints path (section 6 also
# saves a copy to models/final_model).
final_model_path = (Path('..') / cfg['data']['paths']['checkpoints'] / 'final_model').resolve()
print(f"\nSaving final model to: {final_model_path}")
trainer.save_model(str(final_model_path))
print("✓ Model saved successfully!\n")

## 6. Save Final Model

In [None]:
# Save the final model
print("\n=== SAVING FINAL MODEL ===")

final_model_dir = (Path('..') / 'models' / 'final_model').resolve()
final_model_dir.mkdir(parents=True, exist_ok=True)

print(f"Saving model to: {final_model_dir}")
# NOTE(review): for a PEFT-wrapped model this presumably writes the adapter
# (and modules_to_save) rather than full base weights — confirm the
# evaluation notebook loads it on top of the base model.
trainer.save_model(str(final_model_dir))

print("\nModel saved successfully")
print("\nSaved files:")
# iterdir order is filesystem-dependent; sort for a stable listing.
for file in sorted(final_model_dir.iterdir()):
    print(f"  - {file.name}")

In [None]:
# Quick validation evaluation: run the Trainer's evaluation loop once more and
# print every eval_* metric with a human-readable name.
print("\n=== QUICK VALIDATION CHECK ===")

eval_results = trainer.evaluate()

print("\nFinal Validation Metrics:")
eval_metrics = {name: val for name, val in eval_results.items() if name.startswith('eval_')}
for metric_key, metric_value in eval_metrics.items():
    pretty_name = metric_key.replace('eval_', '').replace('_', ' ').title()
    print(f"  {pretty_name}: {metric_value:.4f}")

print("\nValidation check complete")

In [None]:
# Finish W&B run: close the run opened during setup (only when enabled).
wandb_enabled = cfg['wandb']['enabled']
if wandb_enabled:
    print("\n=== FINALIZING WEIGHTS & BIASES ===")
    wandb.finish()
    print("✓ W&B run finished")