# 03 - Model Training with TensorBoard

**AI-Powered Code Review Assistant**  
**CS 5590 - Final Project**

---

## Objectives

This notebook implements the complete training pipeline:

1. **Initialize** CodeBERT model for multi-label classification
2. **Configure** training with justified hyperparameters
3. **Train** with TensorBoard monitoring
4. **Save** best model checkpoints
5. **Analyze** training curves

---

## CRISP-DM Phase: Modeling

This notebook corresponds to **Phase 4** of the CRISP-DM methodology.

---

## GPU Setup (Google Colab)

In [None]:
# Check GPU availability
import torch

if torch.cuda.is_available():
    print(f"✓ GPU available: {torch.cuda.get_device_name(0)}")
    print(f"  Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
    device = torch.device('cuda')
else:
    print("⚠ No GPU available, using CPU (training will be slow)")
    device = torch.device('cpu')

print(f"\nUsing device: {device}")

## 1. Setup

In [None]:
try:
    import google.colab
    IN_COLAB = True
    !git clone https://github.com/darshlukkad/Code-Review-Assistant.git
    %cd Code-Review-Assistant
except ImportError:
    IN_COLAB = False

In [None]:
!pip install -q transformers torch tensorboard scikit-learn tqdm pandas matplotlib

In [None]:
import sys
sys.path.append('src')

import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter
from torch.optim import AdamW  # transformers' AdamW is deprecated in favor of the PyTorch implementation
from transformers import get_linear_schedule_with_warmup
import pandas as pd
from tqdm import tqdm
import json
import os

# Import our modules
from models.model import CodeBERTClassifier
from data.preprocessing import CodePreprocessor
from training.config import training_config, model_config

print("✓ All libraries imported")

## 2. Load Preprocessed Data

In [None]:
# Load data splits from preprocessing notebook
from torch.utils.data import DataLoader

# Load splits
train_df = pd.read_csv('train_split.csv')
val_df = pd.read_csv('val_split.csv')

print(f"Train: {len(train_df):,} samples")
print(f"Val:   {len(val_df):,} samples")

# Initialize preprocessor (CodePreprocessor was imported in the setup cell)
preprocessor = CodePreprocessor()

# NOTE: `train_loader` and `val_loader` are used from Section 4 onward but are not
# created in this cell. Build them here (DataLoaders over the tokenized splits, using
# the project's preprocessing code) before running the cells below.
print("\n✓ Data loaded successfully")
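
The training cells read `input_ids`, `attention_mask`, and `labels` from each batch. As a rough sketch of loaders with that shape, the snippet below wraps already-tokenized tensors in a `Dataset`. The class name `EncodedCodeDataset`, the helper `make_loader`, and the random placeholder tensors are all hypothetical; the real notebook would tokenize `train_df` and `val_df` with the project's preprocessing code instead.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class EncodedCodeDataset(Dataset):
    """Hypothetical dataset wrapping already-tokenized inputs and multi-hot labels."""

    def __init__(self, input_ids, attention_mask, labels):
        self.input_ids = input_ids
        self.attention_mask = attention_mask
        self.labels = labels

    def __len__(self):
        return self.input_ids.size(0)

    def __getitem__(self, idx):
        # Keys match what train_epoch() and validate() read from each batch.
        return {
            "input_ids": self.input_ids[idx],
            "attention_mask": self.attention_mask[idx],
            "labels": self.labels[idx],
        }


def make_loader(dataset, batch_size=32, shuffle=False):
    """Hypothetical helper: shuffle for training, keep order for validation."""
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)


# Placeholder tensors standing in for tokenizer output: 8 samples, 16 tokens, 5 labels.
demo_ids = torch.randint(0, 1000, (8, 16))
demo_mask = torch.ones(8, 16, dtype=torch.long)
demo_labels = torch.randint(0, 2, (8, 5)).float()  # BCEWithLogitsLoss expects float targets

demo_loader = make_loader(
    EncodedCodeDataset(demo_ids, demo_mask, demo_labels), batch_size=4, shuffle=True
)
demo_batch = next(iter(demo_loader))
print({name: tuple(tensor.shape) for name, tensor in demo_batch.items()})
```

With real data, the same pattern would yield `train_loader` (shuffled) and `val_loader` (unshuffled) at `BATCH_SIZE`.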

## 3. Initialize Model

### Model Architecture: CodeBERT

**Why CodeBERT?**
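- Pre-trained on CodeSearchNet data covering 6 programming languages (2.1M paired natural-language/code samples plus 6.4M code-only samples)
- Learns joint representations of code and natural language, whereas BERT is pre-trained on natural-language text only
- Reported state-of-the-art results on code search and code documentation generation at release (2020)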

**Architecture Details:**
- 12 transformer layers
- 768 hidden dimensions
- 12 attention heads
- Multi-label classification head (5 outputs)
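
To make "multi-label" concrete: each of the 5 outputs gets its own sigmoid, so any number of labels can be active for one code sample. The plain-Python sketch below uses made-up logits and an assumed decision threshold of 0.5; it illustrates the idea only and is not the head implemented in `CodeBERTClassifier`. It also shows the numerically stable form of binary cross-entropy on logits, which is the kind of loss a multi-label head is typically trained with.

```python
import math


def sigmoid(x):
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def bce_with_logits(logits, targets):
    """Mean binary cross-entropy over labels, computed directly from logits.

    Uses the stable form max(x, 0) - x*y + log(1 + exp(-|x|)).
    """
    losses = [
        max(x, 0.0) - x * y + math.log1p(math.exp(-abs(x)))
        for x, y in zip(logits, targets)
    ]
    return sum(losses) / len(losses)


# Hypothetical logits for one code sample across the 5 labels.
logits = [2.0, -1.0, 0.5, 3.0, -4.0]
targets = [1.0, 0.0, 0.0, 1.0, 0.0]

probs = [sigmoid(x) for x in logits]
predicted = [int(p >= 0.5) for p in probs]  # each label is thresholded independently
loss = bce_with_logits(logits, targets)

print("probabilities:", [round(p, 3) for p in probs])
print("predicted labels:", predicted)
print(f"loss: {loss:.4f}")
```

Because the probabilities are independent, they do not need to sum to 1, which is the key difference from a softmax classifier.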

In [None]:
# Initialize model
model = CodeBERTClassifier(
    model_name="microsoft/codebert-base",
    num_labels=5,
    hidden_dropout_prob=0.1
)

model = model.to(device)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print("Model Architecture:")
print("="*80)
print(f"Total parameters:     {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")
print(f"Model size:          ~{total_params * 4 / 1e6:.0f} MB (fp32, 4 bytes per parameter)")
print("\n✓ Model initialized successfully")

## 4. Configure Training Hyperparameters

### Complete Hyperparameter Justifications

| Hyperparameter | Value | Justification |
|----------------|-------|---------------|
| **Learning Rate** | 2e-5 | Within the 2e-5 to 5e-5 range commonly used for BERT fine-tuning; the low end favors stable convergence |
| **Optimizer** | AdamW | Decouples weight decay from the gradient update, so regularization is not distorted by Adam's adaptive scaling |
| **Weight Decay** | 0.01 | Mild regularization; a common default for transformer fine-tuning |
| **Batch Size** | 32 | Balances GPU memory (16GB) with gradient stability |
| **Epochs** | 15 | Upper bound only; early stopping (patience=3) ends training once validation loss stops improving |
| **Warmup Steps** | 500 | Linear warmup avoids large, destabilizing updates in the first steps |
| **Max Grad Norm** | 1.0 | Clips the global gradient norm to guard against exploding gradients |
| **Dropout** | 0.1 | Light regularization; 0.1 is also the default in the base model's configuration |
| **Loss Function** | BCEWithLogitsLoss | Multi-label targets need an independent sigmoid probability per class rather than a softmax over classes |
| **Activation** | GELU | Fixed by the pre-trained architecture rather than tuned here; smoother than ReLU near zero |
| **Scheduler** | Linear | Learning rate decays linearly to zero after warmup |
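
The warmup and scheduler rows combine into a single learning-rate curve: a linear ramp from 0 to the peak rate over the warmup steps, then a linear decay to 0 at the final step. The plain-Python sketch below reproduces that shape for illustration. The function name `linear_warmup_decay_lr` is hypothetical, and `total_steps=10_000` is a placeholder, since the real value depends on the size of the training set.

```python
def linear_warmup_decay_lr(step, base_lr=2e-5, warmup_steps=500, total_steps=10_000):
    """Learning rate at `step` for linear warmup followed by linear decay to zero."""
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    remaining = max(0, total_steps - step)
    return base_lr * remaining / max(1, total_steps - warmup_steps)


# total_steps is a placeholder; the notebook derives it from len(train_loader) * NUM_EPOCHS.
for step in (0, 250, 500, 5_250, 10_000):
    print(f"step {step:>6}: lr = {linear_warmup_decay_lr(step):.2e}")
```

One practical consequence: if the training set is small enough that the total step count is not much larger than 500, warmup takes up most of training, so it is worth comparing the two numbers before a long run.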

In [None]:
# Training configuration
LEARNING_RATE = 2e-5
WEIGHT_DECAY = 0.01
BATCH_SIZE = 32
NUM_EPOCHS = 15
WARMUP_STEPS = 500
MAX_GRAD_NORM = 1.0
EARLY_STOPPING_PATIENCE = 3

# Optimizer: AdamW
optimizer = AdamW(
    model.parameters(),
    lr=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
    betas=(0.9, 0.999),
    eps=1e-8
)

# Learning rate scheduler with warmup
num_training_steps = len(train_loader) * NUM_EPOCHS
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=WARMUP_STEPS,
    num_training_steps=num_training_steps
)

# Loss function (already in model, but we can override)
# BCEWithLogitsLoss combines sigmoid + BCE for numerical stability

print("Training Configuration:")
print("="*80)
print(f"Learning rate:        {LEARNING_RATE}")
print(f"Weight decay:         {WEIGHT_DECAY}")
print(f"Batch size:           {BATCH_SIZE}")
print(f"Num epochs:           {NUM_EPOCHS}")
print(f"Warmup steps:         {WARMUP_STEPS}")
print(f"Total training steps: {num_training_steps:,}")
print(f"Early stopping:       {EARLY_STOPPING_PATIENCE} epochs")
print("\n✓ Training configured")

## 5. Set Up TensorBoard

TensorBoard will log:
- Training loss and learning rate every 100 steps
- Training and validation loss per epoch
- Hyperparameters for the run

In [None]:
# Initialize TensorBoard writer
writer = SummaryWriter('logs/tensorboard')

# Log hyperparameters
hparams = {
    'learning_rate': LEARNING_RATE,
    'weight_decay': WEIGHT_DECAY,
    'batch_size': BATCH_SIZE,
    'num_epochs': NUM_EPOCHS,
    'warmup_steps': WARMUP_STEPS,
    'max_grad_norm': MAX_GRAD_NORM
}

writer.add_hparams(hparams, {})

print("✓ TensorBoard initialized")
print("  Run: tensorboard --logdir=logs/tensorboard")

# In Colab, load TensorBoard extension
if IN_COLAB:
    %load_ext tensorboard
    %tensorboard --logdir logs/tensorboard

## 6. Training Loop

### Training Process:

For each epoch:
1. **Training phase:**
   - Forward pass through model
   - Compute loss (BCEWithLogitsLoss)
   - Backward pass (compute gradients)
   - Clip gradients (prevent explosion)
   - Update weights
   - Update learning rate

2. **Validation phase:**
   - Evaluate on validation set
   - Check for improvement
   - Save best model

3. **Early stopping:**
   - Stop if no improvement for 3 epochs
   - Prevents overfitting
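
The early-stopping rule can be isolated into a small helper, which makes the patience logic easy to check on its own. This is a minimal sketch with a hypothetical `EarlyStopping` class and made-up validation losses; the training cell in this notebook implements the same rule inline with plain variables.

```python
class EarlyStopping:
    """Track the best validation loss and signal when patience runs out."""

    def __init__(self, patience=3):
        self.patience = patience
        self.best_loss = float("inf")
        self.epochs_without_improvement = 0

    def step(self, val_loss):
        """Record one epoch's loss; return (improved, should_stop)."""
        if val_loss < self.best_loss:
            self.best_loss = val_loss
            self.epochs_without_improvement = 0
            return True, False
        self.epochs_without_improvement += 1
        return False, self.epochs_without_improvement >= self.patience


# Made-up validation losses, for illustration only.
stopper = EarlyStopping(patience=3)
history = []
for epoch, val_loss in enumerate([0.52, 0.41, 0.43, 0.42, 0.44, 0.40], start=1):
    improved, should_stop = stopper.step(val_loss)
    history.append((epoch, improved, should_stop))
    if should_stop:
        print(f"Stopping after epoch {epoch}; best loss {stopper.best_loss:.2f}")
        break
```

Note that the counter resets only on a strict improvement over the best loss so far, so a loss that merely beats the previous epoch still counts against patience.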

In [None]:
def train_epoch(model, loader, optimizer, scheduler, device, epoch, writer, global_step):
    """Train for one epoch."""
    model.train()
    total_loss = 0
    
    progress_bar = tqdm(loader, desc=f"Epoch {epoch}")
    
    for batch_idx, batch in enumerate(progress_bar):
        # Move to device
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        
        # Forward pass
        outputs = model(input_ids, attention_mask, labels)
        loss = outputs['loss']
        
        # Backward pass
        loss.backward()
        
        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_NORM)
        
        # Update weights
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
        
        # Log to TensorBoard
        global_step += 1
        if global_step % 100 == 0:
            writer.add_scalar('train/loss', loss.item(), global_step)
            writer.add_scalar('train/lr', scheduler.get_last_lr()[0], global_step)
        
        total_loss += loss.item()
        progress_bar.set_postfix({'loss': f'{loss.item():.4f}'})
    
    return total_loss / len(loader), global_step


@torch.no_grad()
def validate(model, loader, device):
    """Validate on validation set."""
    model.eval()
    total_loss = 0
    
    for batch in tqdm(loader, desc="Validating"):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        
        outputs = model(input_ids, attention_mask, labels)
        loss = outputs['loss']
        
        total_loss += loss.item()
    
    return total_loss / len(loader)

print("✓ Training functions defined")

## 7. Train the Model

In [None]:
# Training loop
best_val_loss = float('inf')
epochs_without_improvement = 0
global_step = 0

train_losses = []
val_losses = []

print("Starting training...")
print("="*80)

for epoch in range(1, NUM_EPOCHS + 1):
    print(f"\nEpoch {epoch}/{NUM_EPOCHS}")
    print("-" * 80)
    
    # Train
    train_loss, global_step = train_epoch(
        model, train_loader, optimizer, scheduler,
        device, epoch, writer, global_step
    )
    
    # Validate
    val_loss = validate(model, val_loader, device)
    
    # Log epoch metrics
    train_losses.append(train_loss)
    val_losses.append(val_loss)
    
    writer.add_scalar('epoch/train_loss', train_loss, epoch)
    writer.add_scalar('epoch/val_loss', val_loss, epoch)
    
    print(f"Train Loss: {train_loss:.4f}")
    print(f"Val Loss:   {val_loss:.4f}")
    
    # Check for improvement
    if val_loss < best_val_loss:
        print(f"✓ New best! Saving model...")
        best_val_loss = val_loss
        epochs_without_improvement = 0
        
        # Save best model
        os.makedirs('models', exist_ok=True)
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_loss': val_loss,
        }, 'models/best_model.pt')
    else:
        epochs_without_improvement += 1
        print(f"No improvement for {epochs_without_improvement} epoch(s)")
    
    # Early stopping
    if epochs_without_improvement >= EARLY_STOPPING_PATIENCE:
        print(f"\n⚠ Early stopping triggered after {epoch} epochs")
        break

print("\n" + "="*80)
print("Training complete!")
print(f"Best validation loss: {best_val_loss:.4f}")

writer.close()

## 8. Visualize Training Progress

In [None]:
import matplotlib.pyplot as plt

# Plot training curves
plt.figure(figsize=(10, 6))
epochs_range = range(1, len(train_losses) + 1)

plt.plot(epochs_range, train_losses, 'b-', label='Training Loss', linewidth=2)
plt.plot(epochs_range, val_losses, 'r-', label='Validation Loss', linewidth=2)

plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training and Validation Loss')
plt.legend()
plt.grid(True, alpha=0.3)

plt.savefig('training_curves.png', dpi=300, bbox_inches='tight')
plt.show()

print("✓ Saved: training_curves.png")

## 9. Save Training Metadata

In [None]:
# Save training history
training_history = {
    'epochs_trained': len(train_losses),
    'train_losses': train_losses,
    'val_losses': val_losses,
    'best_val_loss': best_val_loss,
    'hyperparameters': hparams,
    'total_steps': global_step
}

with open('training_history.json', 'w') as f:
    json.dump(training_history, f, indent=2)

print("✓ Training history saved")

## Summary

### Training Results

- **Epochs trained:** X (update after running)
- **Best validation loss:** X.XXXX
- **Total training steps:** X,XXX
- **Training time:** ~X hours

### Model Saved

Best model checkpoint saved to: `models/best_model.pt`
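
A checkpoint with this layout can be restored by rebuilding the model and loading `model_state_dict` into it. The sketch below shows the round trip on a stand-in `nn.Linear` module written to a temporary directory; the epoch and loss values are placeholders, and the real notebook would construct `CodeBERTClassifier` and point at `models/best_model.pt` instead.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Stand-in module; the notebook would construct CodeBERTClassifier here instead.
model_to_save = nn.Linear(768, 5)
ckpt_path = os.path.join(tempfile.mkdtemp(), "best_model.pt")

# Same dictionary keys as the checkpoint written in the training loop
# (optimizer state omitted for brevity; epoch and val_loss are placeholders).
torch.save(
    {"epoch": 4, "model_state_dict": model_to_save.state_dict(), "val_loss": 0.41},
    ckpt_path,
)

# map_location="cpu" lets a GPU-trained checkpoint load on a CPU-only machine.
checkpoint = torch.load(ckpt_path, map_location="cpu")
restored = nn.Linear(768, 5)
restored.load_state_dict(checkpoint["model_state_dict"])
restored.eval()  # switch off dropout before evaluation

print(f"Restored checkpoint from epoch {checkpoint['epoch']}")
```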

### Next Step: Evaluation (04-evaluation.ipynb)

The next notebook will:
- Load the best model
- Evaluate on test set
- Compute all metrics (F1, AUC, precision, recall)
- Create comprehensive visualizations
- Perform ablation studies