In [None]:
# Install dependencies
# The %pip magic installs into the environment of the running kernel.
# NOTE(review): no versions are pinned here, so a later re-run may resolve different
# releases of these packages — consider pinning once known-good versions are established.
%pip install torch wandb tqdm pyyaml


In [None]:
# Import required libraries
import torch
import torch.nn as nn
from torch.optim import Adam
from pathlib import Path
import wandb
import yaml
import json
from tqdm.notebook import tqdm
import os
import sys

# Make the current working directory importable so the `src.*` imports below resolve.
# `sys` is already imported above, and the membership check keeps re-runs of this cell
# from stacking duplicate '.' entries onto sys.path.
if '.' not in sys.path:
    sys.path.append('.')

# Show the working directory: the relative '.' entry only helps if this is the project root.
print(os.getcwd())

# Import project modules
from src.data.dataset import create_dataloader
from src.models.online_transformer import OnlineTransformer
from src.config.tokenization_config import PAD_TOKEN
from src.training.utils.schedulers import get_warmup_schedule


In [None]:
# Load configuration
# Print the working directory first: the config path below is relative to it.
print(os.getcwd())
config_path = 'src/training/configs/online_transformer_base.yaml'
with open(config_path, 'r') as f:
    config = yaml.safe_load(f)

# Set device: use CUDA when it is available, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f'Using device: {device}')

# Initialize wandb
# The run name encodes the main hyperparameters read from the config.
run_name = (
    f"online_L{config['num_layers']}"
    f"_H{config['num_heads']}"
    f"_D{config['embed_dim']}"
    f"_seq{config['max_sequence_length']}"
    f"_bs{config['batch_size']}"
    f"_lr{config['learning_rate']}"
)

wandb.init(
    job_type="online_training",
    config=config,
    name=run_name,
    project=config['wandb_project'],
)


In [None]:
# Initialize model, criterion, optimizer, and scheduler
model = OnlineTransformer(
    vocab_size=tokenizer_info['total_vocab_size'],
    embed_dim=config['embed_dim'],
    num_heads=config['num_heads'],
    num_layers=config['num_layers'],
    dropout=config['dropout'],
    # Double for interleaved sequences
    max_seq_length=config['max_sequence_length'] * 2,
    pad_token_id=PAD_TOKEN,
)
model = model.to(device)

# Targets equal to PAD_TOKEN are excluded from the loss
criterion = nn.CrossEntropyLoss(ignore_index=PAD_TOKEN)
optimizer = Adam(model.parameters(), lr=config['learning_rate'])

# Initialize warmup scheduler
scheduler = get_warmup_schedule(
    optimizer,
    num_warmup_steps=config['warmup_steps'],
)


In [None]:
# Training loop
# Runs train/validate for each epoch, logs the epoch losses to wandb, and saves a
# checkpoint (also uploaded as a wandb artifact) whenever the validation loss improves.
best_val_loss = float('inf')
# Kept because other cells may read it. Nothing in this cell advances it, so it is NOT
# passed as the wandb step below: a constant step would pin every epoch to the same row.
global_step = 0
# The tokenizer info does not change inside this loop, so it is written and uploaded
# only once, together with the first checkpoint.
tokenizer_artifact_logged = False
# Exit code handed to wandb.finish(); stays non-zero unless training runs to completion
# or the user stops it deliberately, so failed runs are not recorded as successful.
run_exit_code = 1

try:
    for epoch in range(config['max_epochs']):
        print(f"\nEpoch {epoch + 1}/{config['max_epochs']}")

        # Clear GPU memory before each epoch
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            print(f"GPU memory at start of epoch: {torch.cuda.memory_allocated() / 1e9:.2f} GB")

        # Training Step
        train_loss = train_epoch(model, train_loader, criterion, optimizer, scheduler, device)

        # Validation Step
        val_loss = validate(model, val_loader, criterion, device)

        print(f"Epoch {epoch+1}: Train Loss: {train_loss:.4f}, Valid Loss: {val_loss:.4f}")
        # No explicit step: wandb assigns the next step itself. The 'epoch' value is
        # logged alongside the losses so charts can still be plotted against it.
        wandb.log({
            'train/epoch_loss': train_loss,
            'valid/epoch_loss': val_loss,
            'epoch': epoch + 1
        })

        # Save checkpoint if validation loss improved
        if val_loss < best_val_loss:
            best_val_loss = val_loss

            # Save checkpoint locally
            checkpoint_path = Path("checkpoints") / f"online_transformer_epoch_{epoch+1}.pth"
            checkpoint_path.parent.mkdir(exist_ok=True)

            checkpoint = {
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'train_loss': train_loss,
                'val_loss': val_loss,
                'config': config,
            }

            torch.save(checkpoint, checkpoint_path)

            # Log checkpoint as wandb artifact
            artifact = wandb.Artifact(
                name=f"online_transformer_model_{wandb.run.id}",
                type="model",
                description=f"Online Transformer Model checkpoint from epoch {epoch+1}"
            )
            artifact.add_file(str(checkpoint_path))
            wandb.log_artifact(artifact)

            # Also save tokenizer info as artifact (first checkpoint only)
            if not tokenizer_artifact_logged:
                tokenizer_artifact = wandb.Artifact(
                    name=f"tokenizer_info_{wandb.run.id}",
                    type="tokenizer",
                    description="Tokenizer information used for training"
                )
                tokenizer_path = Path("tokenizer_info.json")
                with open(tokenizer_path, 'w') as f:
                    json.dump(tokenizer_info, f)
                tokenizer_artifact.add_file(str(tokenizer_path))
                wandb.log_artifact(tokenizer_artifact)
                tokenizer_artifact_logged = True

            print(f"\nSaved checkpoint with validation loss: {val_loss:.4f}")

    # Reached only when every epoch ran without an exception.
    run_exit_code = 0
    print("\nTraining completed")

except KeyboardInterrupt:
    # A manual stop is treated as a deliberate end of the run, not a failure.
    run_exit_code = 0
    print("\nTraining interrupted by user")
except torch.cuda.OutOfMemoryError:
    print("\nOut of GPU memory! Try reducing batch size or sequence length")
finally:
    # Always close the wandb run; any other exception still propagates after this.
    wandb.finish(exit_code=run_exit_code)


In [None]:
def load_best_checkpoint(model, checkpoint_path, map_location='cpu'):
    """Restore model weights from a checkpoint file and return its validation loss.

    Args:
        model: Object whose ``load_state_dict`` receives the stored weights; it is
            updated in place.
        checkpoint_path: Path to a file written with ``torch.save`` holding a dict
            with at least the ``'model_state_dict'`` and ``'val_loss'`` entries.
        map_location: Forwarded to ``torch.load``. Defaults to ``'cpu'`` so a
            checkpoint saved during a GPU run can still be opened on a machine
            without CUDA; ``load_state_dict`` then copies the values into the
            model's existing parameters.

    Returns:
        The value stored under ``'val_loss'`` in the checkpoint.
    """
    checkpoint = torch.load(checkpoint_path, map_location=map_location)
    model.load_state_dict(checkpoint['model_state_dict'])
    return checkpoint['val_loss']

# Example usage (the training loop saves to checkpoints/online_transformer_epoch_<N>.pth;
# replace the epoch number with the one reported by the last "Saved checkpoint" message):
# best_checkpoint_path = Path("checkpoints") / "online_transformer_epoch_1.pth"
# best_val_loss = load_best_checkpoint(model, best_checkpoint_path)
# print(f"Loaded checkpoint with validation loss: {best_val_loss:.4f}")
