In [None]:
# Install dependencies
# `%pip` (rather than `!pip`) installs into the environment of the running kernel.
# NOTE(review): no versions are pinned here, so a fresh install may pull different
# releases than the ones this notebook was developed against — consider pinning.
%pip install torch wandb tqdm pyyaml


In [None]:
# Import required libraries
import torch
import torch.nn as nn
from torch.optim import Adam
from pathlib import Path
import wandb
import yaml
from tqdm.notebook import tqdm
import os
import sys

# Add project root to Python path so the `src.*` imports below resolve.
# '.' is the current working directory, so this only works when the notebook is
# started from the project root; the printed cwd makes a wrong start dir obvious.
import sys
if '.' not in sys.path:
    # Guarded so re-running this cell does not pile up duplicate entries.
    sys.path.append('.')

print(os.getcwd())

# Import project modules
from src.data.dataset import create_dataloader
from src.models.offline_teacher import OfflineTeacherModel
from src.config.tokenization_config import PAD_TOKEN
from src.training.utils.schedulers import get_warmup_schedule


In [None]:
# Load configuration
print(os.getcwd())  # config_path below is resolved relative to this directory
config_path = 'src/training/configs/offline_teacher_base.yaml'
# Decode as UTF-8 explicitly so parsing does not depend on the platform locale.
with open(config_path, 'r', encoding='utf-8') as f:
    config = yaml.safe_load(f)

# An empty YAML document parses to None; fail here with a clear message instead of
# a TypeError on the first config[...] lookup further down.
if not isinstance(config, dict):
    raise ValueError(
        f"Expected a mapping at the top level of {config_path}, "
        f"got {type(config).__name__}"
    )

# Set device: use the GPU when one is visible to torch, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

# Initialize wandb. The run name encodes the main hyperparameters read from the
# config so runs can be told apart in the W&B UI.
run_name = (
    f"offline_L{config['num_layers']}_H{config['num_heads']}_D{config['embed_dim']}"
    f"_seq{config['max_sequence_length']}_bs{config['batch_size']}_lr{config['learning_rate']}"
)

wandb.init(
    project=config['wandb_project'],
    name=run_name,
    config=config,
    job_type="offline_training"
)

In [None]:
# Build the training objects: model, loss, optimizer, and warmup scheduler.
# NOTE(review): `tokenizer_info` is not defined in any cell visible here — confirm
# an earlier cell creates it before this one runs on a fresh kernel.
model_kwargs = {
    'melody_vocab_size': tokenizer_info['melody_vocab_size'],
    'chord_vocab_size': tokenizer_info['chord_vocab_size'],
    'embed_dim': config['embed_dim'],
    'num_heads': config['num_heads'],
    'num_layers': config['num_layers'],
    'dropout': config['dropout'],
    'max_seq_length': config['max_sequence_length'],
    'pad_token_id': PAD_TOKEN,
}
model = OfflineTeacherModel(**model_kwargs).to(device)

# Targets equal to PAD_TOKEN are excluded from the loss.
criterion = nn.CrossEntropyLoss(ignore_index=PAD_TOKEN)
optimizer = Adam(model.parameters(), lr=config['learning_rate'])

# Warmup schedule wrapped around the optimizer; the step count comes from the config.
scheduler = get_warmup_schedule(
    optimizer,
    num_warmup_steps=config['warmup_steps'],
)


In [None]:
# Training loop: one training pass and one validation pass per epoch, metrics
# logged to wandb, and a checkpoint written whenever the validation loss improves.
# NOTE(review): `train_epoch`, `validate`, `train_loader` and `val_loader` are not
# defined in the cells visible here — confirm they exist on a fresh kernel.
best_val_loss = float('inf')
global_step = 0  # not read in this cell; kept in case other cells rely on it

# Outcome reported in `finally`. It stays "failed" unless the loop finishes or the
# user interrupts, so unexpected exceptions are reported as failures too.
training_status = "failed"

try:
    for epoch in range(config['max_epochs']):
        print(f"\nEpoch {epoch + 1}/{config['max_epochs']}")

        # Release cached GPU blocks before each epoch and report current usage
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            print(f"GPU memory at start of epoch: {torch.cuda.memory_allocated() / 1e9:.2f} GB")

        # Training
        train_loss = train_epoch(model, train_loader, criterion, optimizer, scheduler, device)

        # Validation
        val_loss = validate(model, val_loader, criterion, device)

        # Log metrics (GPU memory is reported in GB, 0 when running on CPU)
        wandb.log({
            'train/loss': train_loss,
            'valid/loss': val_loss,
            'learning_rate': optimizer.param_groups[0]['lr'],
            'epoch': epoch + 1,
            'gpu_memory_usage': torch.cuda.memory_allocated() / 1e9 if torch.cuda.is_available() else 0
        })

        # Save checkpoint if validation loss improved. Each improvement gets its own
        # file, so `checkpoint_path` always points at the best checkpoint so far.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            checkpoint_path = Path(wandb.run.dir) / f"offline_teacher_epoch_{epoch+1}.pth"

            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'train_loss': train_loss,
                'val_loss': val_loss,
                'config': config,
            }, checkpoint_path)

            wandb.save(str(checkpoint_path))
            print(f"\nSaved checkpoint with validation loss: {val_loss:.4f}")

    # Only reached when every epoch ran without an exception.
    training_status = "completed"

except KeyboardInterrupt:
    training_status = "interrupted"
    print("\nTraining interrupted by user")
except torch.cuda.OutOfMemoryError:
    # Status stays "failed"; the error is reported rather than re-raised so the
    # notebook keeps running, as before.
    print("\nOut of GPU memory! Try reducing batch size or sequence length")
finally:
    # A non-zero exit code marks the W&B run as failed instead of finished.
    wandb.finish(exit_code=1 if training_status == "failed" else None)
    if training_status == "completed":
        print("\nTraining completed")
    else:
        print(f"\nTraining stopped early ({training_status})")


In [None]:
def load_best_checkpoint(model, checkpoint_path, map_location='cpu'):
    """Restore model weights from a checkpoint file and return its validation loss.

    Args:
        model: Object whose ``load_state_dict`` receives the stored weights.
        checkpoint_path: Path to a file written by ``torch.save`` holding a dict
            with ``'model_state_dict'`` and ``'val_loss'`` entries.
        map_location: Forwarded to ``torch.load``. Defaults to ``'cpu'`` so that a
            checkpoint saved on a GPU can be opened on a machine without one;
            ``load_state_dict`` then copies the values into the model's own
            parameters.

    Returns:
        The ``'val_loss'`` value stored in the checkpoint.

    Raises:
        KeyError: If the checkpoint lacks one of the two required entries.
    """
    # NOTE(review): torch.load unpickles the file — only open checkpoints you trust.
    checkpoint = torch.load(checkpoint_path, map_location=map_location)
    model.load_state_dict(checkpoint['model_state_dict'])
    return checkpoint['val_loss']

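# Example usage (run after the training cell, while `checkpoint_path` is still defined).
# The training loop names its files "offline_teacher_epoch_<N>.pth" and never writes
# 'best_checkpoint.pth'; `checkpoint_path` holds the most recently saved (i.e. best)
# checkpoint. `wandb.run` is expected to be unset once wandb.finish() has run, so the
# path is not rebuilt from `wandb.run.dir` here.
# best_checkpoint_path = checkpoint_path
# best_val_loss = load_best_checkpoint(model, best_checkpoint_path)
# print(f"Loaded checkpoint with validation loss: {best_val_loss:.4f}")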
