In [None]:
# Install dependencies
# Installs the third-party packages this notebook imports (torch, wandb, tqdm,
# pyyaml) via the %pip magic, which targets the running kernel's environment.
# NOTE(review): versions are unpinned, so each run may pick up different
# releases — consider pinning once a known-good set is established.
%pip install torch wandb tqdm pyyaml


In [None]:
# Clone the repository if running on Colab
import os
if not os.path.exists('Martydepth'):
    !git clone https://github.com/yourusername/Martydepth.git
    %cd Martydepth
else:
    %cd Martydepth


In [None]:
# Import required libraries
import torch
import torch.nn as nn
from torch.optim import Adam
from pathlib import Path
import wandb
import yaml
from tqdm.notebook import tqdm

# Add project root to Python path
import sys
sys.path.append('.')

# Import project modules
from src.data.dataset import create_dataloader
from src.models.offline_teacher import OfflineTeacherModel
from src.config.tokenization_config import PAD_TOKEN


In [None]:
# Parse the YAML training configuration
config_path = 'src/training/configs/offline_teacher_base.yaml'
with open(config_path, 'r') as config_file:
    config = yaml.safe_load(config_file)

# Use the GPU when torch can see one, otherwise fall back to the CPU
device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_name)
print(f'Using device: {device}')

# Name the wandb run after the key hyperparameters so runs are easy to tell apart
run_name = f"offline_L{config['num_layers']}_H{config['num_heads']}_D{config['embed_dim']}_seq{config['max_sequence_length']}_bs{config['batch_size']}_lr{config['learning_rate']}"

# Start the wandb run; the full config dict is attached to it
wandb_init_args = {
    'project': config['wandb_project'],
    'name': run_name,
    'config': config,
    'job_type': "offline_training",
}
wandb.init(**wandb_init_args)


In [None]:
# Settings common to both splits; only the split name and shuffling differ
shared_loader_args = dict(
    data_dir=Path(config['data_dir']),
    batch_size=config['batch_size'],
    num_workers=config['num_workers'],
    sequence_length=config['max_sequence_length'],
    mode='offline',
)

# Training split: shuffled. The second return value (tokenizer_info) is kept
# because the model cell reads the vocabulary sizes from it.
train_loader, tokenizer_info = create_dataloader(
    split="train", shuffle=True, **shared_loader_args
)

# Validation split: fixed order. The second return value is discarded.
val_loader, _ = create_dataloader(
    split="valid", shuffle=False, **shared_loader_args
)


In [None]:
# Initialize model
# Vocabulary sizes come from the tokenizer metadata returned with the training
# dataloader; architecture hyperparameters come from the YAML config. The
# padding id falls back to the project-wide PAD_TOKEN when the tokenizer
# metadata does not provide one.
model = OfflineTeacherModel(
    melody_vocab_size=tokenizer_info['melody_vocab_size'],
    chord_vocab_size=tokenizer_info['chord_vocab_size'],
    embed_dim=config['embed_dim'],
    num_heads=config['num_heads'],
    num_layers=config['num_layers'],
    dropout=config['dropout'],
    max_seq_length=config['max_sequence_length'],
    pad_token_id=tokenizer_info.get('pad_token_id', PAD_TOKEN)
).to(device)

# Loss function: targets equal to the padding id do not contribute to the loss
criterion = nn.CrossEntropyLoss(ignore_index=model.pad_token_id)

# Optimizer
optimizer = Adam(model.parameters(), lr=config['learning_rate'])

# Learning rate scheduler: multiply the LR by 0.5 once the monitored (minimised)
# metric has stopped improving for more than `patience` scheduler steps.
# `verbose` is deliberately not passed: it is deprecated since PyTorch 2.2 and
# not accepted by recent releases, so with an unpinned torch install it can
# raise a TypeError. The current LR is still available from
# optimizer.param_groups[0]['lr'].
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=5
)

print(f"Model created with {sum(p.numel() for p in model.parameters()):,} parameters.")


In [None]:
def train_epoch(model, train_loader, criterion, optimizer, device):
    """Run one optimisation pass over ``train_loader`` and return the mean batch loss.

    Args:
        model: Module called as ``model(melody_tokens, chord_input, melody_mask,
            chord_mask)``; must expose a ``pad_token_id`` attribute.
        train_loader: Iterable of dict batches with the keys ``'melody_tokens'``,
            ``'chord_input'`` and ``'chord_target'`` (tensors).
        criterion: Loss callable applied to flattened logits and flattened targets.
        optimizer: Optimizer stepped once per batch.
        device: Device the batch tensors are moved to.

    Returns:
        float: Sum of the per-batch losses divided by the number of batches
        actually processed.

    Raises:
        ValueError: If ``train_loader`` yields no batches.
    """
    model.train()
    total_loss = 0.0
    num_batches = 0

    pbar = tqdm(train_loader, desc='Training')
    for batch in pbar:
        # Move batch to device
        melody_tokens = batch['melody_tokens'].to(device)
        chord_input = batch['chord_input'].to(device)
        chord_target = batch['chord_target'].to(device)

        # Boolean masks that are True at padding positions
        # (how the model consumes them is defined by the model itself)
        melody_mask = (melody_tokens == model.pad_token_id)
        chord_mask = (chord_input == model.pad_token_id)

        # Forward pass
        optimizer.zero_grad()
        logits = model(melody_tokens, chord_input, melody_mask, chord_mask)

        # Flatten to (N, vocab) / (N,). reshape() is used instead of view() so
        # that non-contiguous logits or targets (e.g. sliced tensors) also work.
        loss = criterion(logits.reshape(-1, logits.size(-1)), chord_target.reshape(-1))

        # Backward pass with gradient-norm clipping at 1.0
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        # Read the scalar once and reuse it for both the total and the progress bar
        batch_loss = loss.item()
        total_loss += batch_loss
        num_batches += 1
        pbar.set_postfix({'loss': batch_loss})

    if num_batches == 0:
        # Fail with a clear message instead of a bare ZeroDivisionError
        raise ValueError("train_loader yielded no batches; cannot compute an average loss")

    return total_loss / num_batches

def validate(model, val_loader, criterion, device):
    """Evaluate ``model`` on ``val_loader`` without gradients and return the mean batch loss.

    Args:
        model: Module called as ``model(melody_tokens, chord_input, melody_mask,
            chord_mask)``; must expose a ``pad_token_id`` attribute.
        val_loader: Iterable of dict batches with the keys ``'melody_tokens'``,
            ``'chord_input'`` and ``'chord_target'`` (tensors).
        criterion: Loss callable applied to flattened logits and flattened targets.
        device: Device the batch tensors are moved to.

    Returns:
        float: Sum of the per-batch losses divided by the number of batches
        actually processed.

    Raises:
        ValueError: If ``val_loader`` yields no batches.
    """
    model.eval()
    total_loss = 0.0
    num_batches = 0

    with torch.no_grad():
        for batch in tqdm(val_loader, desc='Validating'):
            # Move batch to device
            melody_tokens = batch['melody_tokens'].to(device)
            chord_input = batch['chord_input'].to(device)
            chord_target = batch['chord_target'].to(device)

            # Boolean masks that are True at padding positions
            melody_mask = (melody_tokens == model.pad_token_id)
            chord_mask = (chord_input == model.pad_token_id)

            # Forward pass
            logits = model(melody_tokens, chord_input, melody_mask, chord_mask)

            # Flatten to (N, vocab) / (N,). reshape() is used instead of view()
            # so that non-contiguous logits or targets also work.
            loss = criterion(logits.reshape(-1, logits.size(-1)), chord_target.reshape(-1))
            total_loss += loss.item()
            num_batches += 1

    if num_batches == 0:
        # Fail with a clear message instead of a bare ZeroDivisionError
        raise ValueError("val_loader yielded no batches; cannot compute an average loss")

    return total_loss / num_batches


In [None]:
# Training loop
# Runs config['epochs'] epochs of train + validation, steps the plateau
# scheduler on the validation loss, logs metrics to wandb, and writes a
# checkpoint into the wandb run directory whenever validation loss improves.
best_val_loss = float('inf')
# Path of the most recently saved (i.e. best so far) checkpoint. Kept in a
# variable so it can still be used after wandb.finish() has closed the run.
best_checkpoint_path = None
# Set only when every epoch ran, so the final message is not printed after an
# interrupt or a crash.
training_completed = False

try:
    for epoch in range(config['epochs']):
        print(f"\nEpoch {epoch + 1}/{config['epochs']}")

        # Training
        train_loss = train_epoch(model, train_loader, criterion, optimizer, device)

        # Validation
        val_loss = validate(model, val_loader, criterion, device)

        # Update learning rate
        scheduler.step(val_loss)

        # Log metrics
        wandb.log({
            'train/loss': train_loss,
            'valid/loss': val_loss,
            'learning_rate': optimizer.param_groups[0]['lr'],
            'epoch': epoch + 1
        })

        # Save checkpoint if validation loss improved
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            checkpoint_path = Path(wandb.run.dir) / f"offline_teacher_epoch_{epoch+1}.pth"

            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'train_loss': train_loss,
                'val_loss': val_loss,
                'config': config,
            }, checkpoint_path)

            wandb.save(str(checkpoint_path))
            best_checkpoint_path = checkpoint_path
            print(f"\nSaved checkpoint with validation loss: {val_loss:.4f}")

    training_completed = True

except KeyboardInterrupt:
    print("\nTraining interrupted by user")
finally:
    # Always close the wandb run, whether training finished, was interrupted,
    # or raised.
    wandb.finish()
    if training_completed:
        print("\nTraining completed")


In [None]:
def load_best_checkpoint(model, checkpoint_path, map_location=None):
    """Load model weights from a checkpoint file and return its validation loss.

    Args:
        model: Module whose ``load_state_dict`` receives the checkpoint's
            ``'model_state_dict'`` entry.
        checkpoint_path: Path of a file written with ``torch.save`` containing
            at least the keys ``'model_state_dict'`` and ``'val_loss'``.
        map_location: Forwarded to ``torch.load``. When omitted, tensors are
            mapped to the device of the model's first parameter (or ``'cpu'``
            if the model has no parameters), so a checkpoint saved on a GPU
            can be loaded on a CPU-only machine.

    Returns:
        The value stored under ``'val_loss'`` in the checkpoint.
    """
    if map_location is None:
        first_param = next(model.parameters(), None)
        map_location = first_param.device if first_param is not None else 'cpu'

    checkpoint = torch.load(checkpoint_path, map_location=map_location)
    model.load_state_dict(checkpoint['model_state_dict'])
    return checkpoint['val_loss']

# Example usage:
# The training loop names its checkpoints offline_teacher_epoch_<N>.pth, so
# point this at the file of the epoch you want to restore.
# NOTE(review): wandb.run is likely unavailable once wandb.finish() has been
# called — confirm, and prefer an explicit path over wandb.run.dir here.
# best_checkpoint_path = Path('<wandb run dir>') / 'offline_teacher_epoch_<N>.pth'
# best_val_loss = load_best_checkpoint(model, best_checkpoint_path)
# print(f"Loaded checkpoint with validation loss: {best_val_loss:.4f}")
