In [None]:
!git clone https://github.com/imabeastdrew/Martydepth.git
%cd Martydepth

# Install the package in development mode
%pip install -e .

In [None]:
# Install dependencies
# NOTE(review): versions are unpinned, so re-running later may pull different
# releases; consider pinning them (e.g. `torch==X.Y.Z`). `transformers` is only
# needed for the `Adafactor` import below, which this notebook never uses.
%pip install torch wandb tqdm pyyaml transformers


In [None]:
# Import required libraries
import json
import os
import sys
from pathlib import Path

import torch
import torch.nn as nn
import wandb
import yaml
from torch.optim import Adam, AdamW
from tqdm.notebook import tqdm
from transformers import Adafactor

# Add project root to Python path so the `src.*` imports below resolve.
# The absolute path is used (a bare '.' is re-resolved against whatever the
# cwd happens to be at import time), and it is only appended once so that
# re-running this cell does not keep growing sys.path.
import sys
PROJECT_ROOT = os.getcwd()
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

print(os.getcwd())

# Import project modules
from src.data.dataset import create_dataloader
from src.models.offline_teacher_t5 import T5OfflineTeacherModel
from src.config.tokenization_config import PAD_TOKEN
from src.training.utils.schedulers import get_warmup_schedule


In [None]:
# Load configuration
print(os.getcwd())
config_path = 'src/training/configs/offline_teacher_base.yaml'
with open(config_path, 'r') as f:
    config = yaml.safe_load(f)

# Seed torch's RNG (CPU and all CUDA devices) so torch-driven randomness
# (weight init, dropout, and presumably DataLoader shuffling - confirm in
# create_dataloader) is reproducible. The seed comes from the config when
# present (default 42) and is written back into `config` so it is logged to
# wandb and stored in checkpoints with the rest of the config.
# NOTE(review): Python `random` / numpy are not seeded here; seed them too if
# the dataset code uses them.
config.setdefault('seed', 42)
torch.manual_seed(config['seed'])

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

# Initialize wandb
run_name = f"offline_L{config['num_layers']}_H{config['num_heads']}_D{config['embed_dim']}_seq{config['max_sequence_length']}_bs{config['batch_size']}_lr{config['learning_rate']}"

wandb.init(
    project=config['wandb_project'],
    name=run_name,
    config=config,
    job_type="offline_training"
)

In [None]:
# Create dataloaders.
# The train and validation loaders share every setting except the split and
# whether to shuffle, so both are built through one small helper.
def _make_offline_loader(split, shuffle):
    """Return `create_dataloader(...)` for `split`, configured from the notebook config."""
    return create_dataloader(
        data_dir=Path(config['data_dir']),
        split=split,
        batch_size=config['batch_size'],
        num_workers=config['num_workers'],
        sequence_length=config['max_sequence_length'],
        mode='offline',
        shuffle=shuffle,
    )


train_loader, tokenizer_info = _make_offline_loader("train", shuffle=True)
val_loader, _ = _make_offline_loader("valid", shuffle=False)


In [None]:
# Display model information and sanity-check a forward pass on two samples.
# NOTE: this cell reads `model` and the vocab-size variables that are created in
# the "Initialize model, optimizer, criterion, and scheduler" cell *below* it.
# On a fresh kernel (Restart & Run All) it would raise NameError, so it is
# skipped until that cell has been run; re-run this cell afterwards.
if 'model' not in globals():
    print("Model not initialized yet - run the model initialization cell, then re-run this cell.")
else:
    print("=== T5 Offline Teacher Model Info ===")
    model_info = model.get_model_info()
    for key, value in model_info.items():
        print(f"  {key}: {value}")

    print("\n=== Vocabulary Configuration ===")
    print(f"  Melody vocab size: {melody_vocab_size}")
    print(f"  Chord vocab size: {chord_vocab_size}")
    print(f"  PAD token ID: {pad_token_id}")
    print(f"  Tokenizer info - melody: {tokenizer_info['melody_vocab_size']}")
    print(f"  Tokenizer info - chord: {tokenizer_info['chord_vocab_size']}")

    # Verify the model handles a forward pass, called exactly as train_epoch /
    # validate call it (keyword args + padding masks). Eval mode turns dropout
    # off for the check; the previous train/eval mode is restored afterwards.
    print("\n=== Testing Model Forward Pass ===")
    sample_batch = next(iter(train_loader))
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            melody_tokens = sample_batch['melody_tokens'][:2].to(device)  # Use small batch
            chord_input = sample_batch['chord_input'][:2].to(device)

            print(f"  Input shapes - Melody: {melody_tokens.shape}, Chord: {chord_input.shape}")
            print(f"  Token ranges - Melody: {melody_tokens.min().item()}-{melody_tokens.max().item()}, "
                  f"Chord: {chord_input.min().item()}-{chord_input.max().item()}")

            logits = model(
                melody_tokens=melody_tokens,
                chord_tokens=chord_input,
                melody_mask=(melody_tokens == model.pad_token_id),  # src_key_padding_mask
                chord_mask=(chord_input == model.pad_token_id),     # tgt_key_padding_mask
            )
            print(f"  Output shape: {logits.shape}")
            print("  ✅ Forward pass successful!")
    finally:
        model.train(was_training)


In [None]:
# Initialize model, optimizer, criterion, and scheduler.
# The T5 model embeds melody and chord tokens in one unified vocabulary, so it
# is given the total vocabulary size in addition to the per-part sizes.
melody_vocab_size = tokenizer_info['melody_vocab_size']
chord_vocab_size = tokenizer_info['chord_vocab_size']
# NOTE(review): 4779 is a hard-coded fallback used only when tokenizer_info has
# no 'total_vocab_size' key - confirm it matches the tokenizer actually in use.
total_vocab_size = tokenizer_info.get('total_vocab_size', 4779)
pad_token_id = tokenizer_info.get('pad_token_id', PAD_TOKEN)

model_kwargs = dict(
    melody_vocab_size=melody_vocab_size,
    chord_vocab_size=chord_vocab_size,
    embed_dim=config['embed_dim'],
    num_heads=config['num_heads'],
    num_layers=config['num_layers'],
    dropout=config['dropout'],
    max_seq_length=config['max_sequence_length'],
    pad_token_id=pad_token_id,
    total_vocab_size=total_vocab_size,
)
model = T5OfflineTeacherModel(**model_kwargs).to(device)

num_parameters = sum(p.numel() for p in model.parameters())
print(f"Model created with {num_parameters:,} parameters")

# AdamW (decoupled weight decay); weight_decay falls back to 0.01 when the
# config does not set it.
optimizer = AdamW(
    model.parameters(),
    lr=config['learning_rate'],
    weight_decay=config.get('weight_decay', 0.01),
)

# Target positions equal to the pad id are ignored by the loss.
criterion = nn.CrossEntropyLoss(ignore_index=pad_token_id)

# Warmup learning-rate schedule.
scheduler = get_warmup_schedule(optimizer, num_warmup_steps=config['warmup_steps'])

# Enable gradient and parameter logging every `log_every_n_steps` steps.
wandb.watch(model, log="all", log_freq=config['log_every_n_steps'])


In [None]:
def train_epoch(model, train_loader, criterion, optimizer, scheduler, device, global_step, epoch,
                gradient_clip_val=None, log_every_n_steps=None):
    """Train `model` for one epoch, logging batch metrics to wandb.

    Args:
        model: the offline teacher model; must expose `pad_token_id`.
        train_loader: iterable of batch dicts with 'melody_tokens',
            'chord_input' and 'chord_target' tensors.
        criterion: loss applied to the flattened logits and targets.
        optimizer: stepped once per batch that produced a finite loss.
        scheduler: stepped once per batch that produced a finite loss.
        device: device the batch tensors are moved to.
        global_step: wandb step counter at the start of the epoch.
        epoch: zero-based epoch index (logged, and used for the early-epoch hint).
        gradient_clip_val: max total gradient norm; defaults to
            config['gradient_clip_val'] from the notebook config.
        log_every_n_steps: batch metrics are logged when batch_idx is a
            multiple of this; defaults to config['log_every_n_steps'].

    Returns:
        (mean_loss, global_step): the mean loss over the batches actually
        trained on (NaN if every batch was skipped), and the updated step counter.
    """
    if gradient_clip_val is None:
        gradient_clip_val = config['gradient_clip_val']
    if log_every_n_steps is None:
        log_every_n_steps = config['log_every_n_steps']

    model.train()
    total_loss = 0.0
    num_trained_batches = 0

    pbar = tqdm(train_loader, desc='Training')
    for batch_idx, batch in enumerate(pbar):
        # Move batch to device
        melody_tokens = batch['melody_tokens'].to(device)
        chord_input = batch['chord_input'].to(device)
        chord_target = batch['chord_target'].to(device)

        # Create padding masks (True means position should be masked)
        melody_padding_mask = (melody_tokens == model.pad_token_id)  # src_key_padding_mask
        chord_padding_mask = (chord_input == model.pad_token_id)     # tgt_key_padding_mask

        # Forward pass with padding masks
        optimizer.zero_grad()
        logits = model(
            melody_tokens=melody_tokens,
            chord_tokens=chord_input,
            melody_mask=melody_padding_mask,
            chord_mask=chord_padding_mask
        )

        # Flatten all leading dims so the loss sees (N, vocab) vs (N,).
        # reshape (not view) also works if logits are non-contiguous.
        loss = criterion(logits.reshape(-1, logits.size(-1)), chord_target.reshape(-1))

        # Skip batches whose loss is NaN or Inf: backpropagating them would
        # corrupt the weights. They are excluded from the epoch mean.
        if not torch.isfinite(loss):
            print(f"\nNon-finite loss ({loss.item()}) detected! Skipping batch.")
            if epoch < 3:  # Early epochs
                print("Non-finite loss in early epoch - may need to reduce learning rate or increase warmup steps")
            continue

        # Backward pass. clip_grad_norm_ returns the total norm computed
        # *before* clipping, which is the value worth logging (the grads
        # themselves are clipped in place afterwards).
        loss.backward()
        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clip_val)
        optimizer.step()
        scheduler.step()  # Update learning rate each batch

        lr = optimizer.param_groups[0]['lr']
        loss_value = loss.item()  # single host sync per batch

        # Log batch metrics
        if batch_idx % log_every_n_steps == 0:
            wandb.log({
                'train/batch_loss': loss_value,
                'train/learning_rate': lr,
                'train/batch': batch_idx,
                'train/epoch': epoch,
                'train/grad_norm': grad_norm.item()
            }, step=global_step)

        global_step += 1

        # Update running loss and progress bar
        total_loss += loss_value
        num_trained_batches += 1
        pbar.set_postfix({'loss': loss_value, 'lr': f"{lr:.2e}"})

    mean_loss = total_loss / num_trained_batches if num_trained_batches > 0 else float('nan')
    return mean_loss, global_step

def validate(model, val_loader, criterion, device):
    """Return the mean per-batch validation loss.

    The model is switched to eval mode and run under torch.no_grad(). Batches
    whose loss is NaN are left out of the average; if no batch remains, the
    result is NaN rather than a division-by-zero error.
    """
    model.eval()
    batch_losses = []
    skipped = 0

    with torch.no_grad():
        for batch in tqdm(val_loader, desc='Validating'):
            melody = batch['melody_tokens'].to(device)
            chords_in = batch['chord_input'].to(device)
            chords_out = batch['chord_target'].to(device)

            # Padding masks: True marks positions to ignore.
            logits = model(
                melody_tokens=melody,
                chord_tokens=chords_in,
                melody_mask=melody.eq(model.pad_token_id),
                chord_mask=chords_in.eq(model.pad_token_id),
            )
            vocab = logits.size(-1)
            loss = criterion(logits.view(-1, vocab), chords_out.view(-1))

            if torch.isnan(loss):
                skipped += 1
            else:
                batch_losses.append(loss.item())

    remaining = len(val_loader) - skipped
    if remaining <= 0:
        return float('nan')
    return sum(batch_losses) / remaining


In [None]:
# Training loop: train and validate every epoch, log epoch metrics to wandb,
# and save a checkpoint (local file + wandb artifact) whenever the validation
# loss improves. wandb.finish() runs however the loop exits.
best_val_loss = float('inf')
best_checkpoint_path = None      # most recent (= best) checkpoint written, if any
global_step = 0
training_finished = False        # True only when every epoch ran to completion
tokenizer_artifact_logged = False

try:
    for epoch in range(config['max_epochs']):
        print(f"\nEpoch {epoch + 1}/{config['max_epochs']}")

        # Clear GPU memory before each epoch
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            print(f"GPU memory at start of epoch: {torch.cuda.memory_allocated() / 1e9:.2f} GB")

        # Training Step
        train_loss, global_step = train_epoch(model, train_loader, criterion, optimizer, scheduler, device, global_step, epoch)

        # Validation Step
        val_loss = validate(model, val_loader, criterion, device)

        print(f"Epoch {epoch+1}: Train Loss: {train_loss:.4f}, Valid Loss: {val_loss:.4f}")
        wandb.log({
            'train/epoch_loss': train_loss,
            'valid/epoch_loss': val_loss,
            'epoch': epoch + 1,
            'train/epoch': epoch + 1
        }, step=global_step)

        # Save checkpoint if validation loss improved. A NaN val_loss never
        # compares as smaller, so it can never replace the best checkpoint.
        if val_loss < best_val_loss:
            best_val_loss = val_loss

            # Save checkpoint locally
            checkpoint_path = Path("checkpoints") / f"offline_teacher_epoch_{epoch+1}.pth"
            checkpoint_path.parent.mkdir(parents=True, exist_ok=True)

            checkpoint = {
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'train_loss': train_loss,
                'val_loss': val_loss,
                'config': config,
            }

            torch.save(checkpoint, checkpoint_path)
            best_checkpoint_path = checkpoint_path

            # Log checkpoint as wandb artifact
            artifact = wandb.Artifact(
                name=f"offline_teacher_model_{wandb.run.id}",
                type="model",
                description=f"Offline Teacher Model checkpoint from epoch {epoch+1}"
            )
            artifact.add_file(str(checkpoint_path))
            wandb.log_artifact(artifact)

            # Tokenizer info is the same for every epoch, so upload it only once per run.
            if not tokenizer_artifact_logged:
                tokenizer_artifact = wandb.Artifact(
                    name=f"tokenizer_info_{wandb.run.id}",
                    type="tokenizer",
                    description="Tokenizer information used for training"
                )
                tokenizer_path = Path("tokenizer_info.json")
                with open(tokenizer_path, 'w') as f:
                    json.dump(tokenizer_info, f)
                tokenizer_artifact.add_file(str(tokenizer_path))
                wandb.log_artifact(tokenizer_artifact)
                tokenizer_artifact_logged = True

            print(f"\nSaved checkpoint with validation loss: {val_loss:.4f}")

    training_finished = True

except KeyboardInterrupt:
    print("\nTraining interrupted by user")
except torch.cuda.OutOfMemoryError:
    print("\nOut of GPU memory! Try reducing batch size or sequence length")
finally:
    wandb.finish()
    print("\nTraining completed" if training_finished else "\nTraining stopped early")
    if best_checkpoint_path is not None:
        print(f"Best checkpoint: {best_checkpoint_path} (val loss {best_val_loss:.4f})")


In [None]:
def load_best_checkpoint(model, checkpoint_path, map_location=None):
    """Restore model weights from a checkpoint saved by the training loop.

    Args:
        model: module whose state dict is loaded in place.
        checkpoint_path: path to a `.pth` file holding at least the
            'model_state_dict' and 'val_loss' keys.
        map_location: forwarded to `torch.load`. When None, tensors are mapped
            onto the device of the model's first parameter (CPU if the model
            has no parameters). This lets a checkpoint saved on GPU load on a
            CPU-only machine.

    Returns:
        The validation loss stored in the checkpoint.
    """
    if map_location is None:
        first_param = next(model.parameters(), None)
        map_location = first_param.device if first_param is not None else 'cpu'
    checkpoint = torch.load(checkpoint_path, map_location=map_location)
    model.load_state_dict(checkpoint['model_state_dict'])
    return checkpoint['val_loss']

# Example usage (run after the training cell). That cell has already called
# wandb.finish(), so wandb.run is None at this point; load from the local
# `checkpoints/` directory instead. Checkpoints are only written when the
# validation loss improves, so within one run the highest-numbered epoch file
# is the best (stale files from earlier runs in the same dir would break this).
# best_checkpoint_path = max(
#     Path("checkpoints").glob("offline_teacher_epoch_*.pth"),
#     key=lambda p: int(p.stem.rsplit("_", 1)[-1]),
# )
# best_val_loss = load_best_checkpoint(model, best_checkpoint_path)
# print(f"Loaded checkpoint with validation loss: {best_val_loss:.4f}")
