In [None]:
!git clone https://github.com/imabeastdrew/Martydepth.git
%cd Martydepth

# Install the package in development mode
%pip install -e .

In [None]:
# Install dependencies
# NOTE(review): versions are unpinned, so a later session may resolve different
# releases; consider pinning (e.g. `torch==X.Y.Z`) once a known-good set exists.
# Some of these may already be installed by `pip install -e .` if the package
# declares them -- TODO confirm against the project's packaging metadata.
%pip install torch wandb tqdm pyyaml transformers

In [None]:
# Import required libraries
import torch
import torch.nn as nn
from torch.optim import AdamW
from transformers import Adafactor
from pathlib import Path
import wandb
import yaml
import json
from tqdm.notebook import tqdm
import os
import sys

# Add project root to Python path.
# Use an absolute path so imports keep resolving even if the working directory
# changes later, and only add it once so re-running this cell does not keep
# appending duplicate entries (`sys` and `os` are imported earlier in this cell).
PROJECT_ROOT = os.path.abspath('.')
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

print(os.getcwd())

# Import project modules
from src.data.dataset import create_dataloader
from src.models.offline_teacher import OfflineTeacherModel
from src.models.offline_teacher_t5 import T5OfflineTeacherModel
from src.config.tokenization_config import PAD_TOKEN, CHORD_TOKEN_START
from src.training.utils.schedulers import get_warmup_schedule


In [None]:
# Configuration - Set model type here
MODEL_TYPE = "custom"  # Change to "t5" to use T5OfflineTeacherModel

# Training config file for each supported model type.
CONFIG_PATHS = {
    "custom": 'src/training/configs/offline_teacher_base.yaml',
    "t5": 'src/training/configs/offline_teacher_t5.yaml',
}
if MODEL_TYPE not in CONFIG_PATHS:
    raise ValueError(f"Unknown model type: {MODEL_TYPE}. Use 'custom' or 't5'")
config_path = CONFIG_PATHS[MODEL_TYPE]

print(f"Loading config from: {config_path}")
with open(config_path, 'r') as config_file:
    config = yaml.safe_load(config_file)

# The notebook-level MODEL_TYPE always wins over whatever the YAML says.
config['model_type'] = MODEL_TYPE

# Prefer the GPU when one is visible to torch.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f'Using device: {device}')
print(f'🎵 Training {MODEL_TYPE.upper()} model')

# Build a descriptive wandb run name from the main architecture/optimizer knobs.
model_type_for_name = config.get('model_type', 'custom')
run_name_parts = [
    f"offline_{model_type_for_name}",
    f"_L{config['num_layers']}",
    f"_H{config['num_heads']}",
    f"_D{config['embed_dim']}",
    f"_seq{config['max_sequence_length']}",
    f"_bs{config['batch_size']}",
    f"_lr{config['learning_rate']}",
]
run_name = "".join(run_name_parts)

wandb.init(
    project=config['wandb_project'],
    name=run_name,
    config=config,
    job_type="offline_training",
)

In [None]:
# Create dataloaders.
# Both splits share every argument except the split name and shuffling.
_shared_loader_kwargs = dict(
    data_dir=Path(config['data_dir']),
    batch_size=config['batch_size'],
    num_workers=config['num_workers'],
    sequence_length=config['max_sequence_length'],
    mode='offline',
)

# Only the training split's tokenizer info is kept; validation reuses it.
train_loader, tokenizer_info = create_dataloader(split="train", shuffle=True, **_shared_loader_kwargs)
val_loader, _ = create_dataloader(split="valid", shuffle=False, **_shared_loader_kwargs)


In [None]:
# Initialize model, optimizer, criterion, and scheduler
# Copy the vocabulary sizes reported by the dataloader into the config so the
# model factory and the config stored in checkpoints agree on them.
for _vocab_key in ('melody_vocab_size', 'chord_vocab_size', 'total_vocab_size'):
    config[_vocab_key] = tokenizer_info[_vocab_key]

# --- Model Creation (Configurable) ---
def create_model(model_type: str):
    """Build the offline teacher model selected by ``model_type``.

    Hyperparameters are read from the notebook-level ``config`` dict and the pad
    id from ``tokenizer_info`` (falling back to ``PAD_TOKEN``).

    Args:
        model_type: ``"custom"`` for ``OfflineTeacherModel`` or ``"t5"`` for
            ``T5OfflineTeacherModel``.

    Returns:
        The newly constructed model (not yet moved to a device).

    Raises:
        ValueError: if ``model_type`` is not ``"custom"`` or ``"t5"``.
    """
    if model_type not in ("custom", "t5"):
        raise ValueError(f"Unknown model_type: {model_type}. Use 'custom' or 't5'")

    # Constructor arguments common to both architectures.
    shared_kwargs = dict(
        melody_vocab_size=config['melody_vocab_size'],
        chord_vocab_size=config['chord_vocab_size'],
        embed_dim=config['embed_dim'],
        num_heads=config['num_heads'],
        num_layers=config['num_layers'],
        dropout=config['dropout'],
        max_seq_length=config['max_sequence_length'],
        pad_token_id=tokenizer_info.get('pad_token_id', PAD_TOKEN),
    )

    if model_type == "t5":
        # The T5 variant additionally needs the full (melody + chord) vocab size.
        return T5OfflineTeacherModel(**shared_kwargs, total_vocab_size=config['total_vocab_size'])
    return OfflineTeacherModel(**shared_kwargs)

# Resolve the architecture from the config (the config cell stores MODEL_TYPE there).
model_type = config.get('model_type', 'custom')  # Default to custom model
model = create_model(model_type)
model = model.to(device)

num_parameters = sum(p.numel() for p in model.parameters())
print(f"🎵 Using {model_type.upper()} model architecture")
print(f"Model created with {num_parameters:,} parameters")

# AdamW with decoupled weight decay; decay defaults to 0.01 when not configured.
optimizer = AdamW(
    model.parameters(),
    lr=config['learning_rate'],
    weight_decay=config.get('weight_decay', 0.01),
)

# The cross-entropy loss is computed inside the training/validation functions,
# which read the pad id from `model.pad_token_id`; this global is informational.
pad_token_id = tokenizer_info.get('pad_token_id', PAD_TOKEN)

# Warmup learning-rate schedule; train_epoch steps it once per optimizer step.
scheduler = get_warmup_schedule(optimizer, num_warmup_steps=config['warmup_steps'])

# Log gradients and parameters to wandb (log="all"), frequency from the config.
wandb.watch(model, log="all", log_freq=config['log_every_n_steps'])


In [None]:
def train_epoch(model, train_loader, optimizer, scheduler, device, global_step, epoch, model_type, config):
    model.train()
    total_loss = 0
    
    pbar = tqdm(train_loader, desc='Training')
    for batch_idx, batch in enumerate(pbar):
        # Move batch to device
        melody_tokens = batch['melody_tokens'].to(device)
        chord_input = batch['chord_input'].to(device)
        chord_target = batch['chord_target'].to(device)
        
        # Create padding masks (True means position should be masked)
        melody_padding_mask = (melody_tokens == model.pad_token_id)  # [batch_size, src_len]
        chord_padding_mask = (chord_input == model.pad_token_id)     # [batch_size, tgt_len]
        
        # Forward pass with proper masking (causal mask is handled internally)
        optimizer.zero_grad()
        logits = model(
            melody_tokens=melody_tokens,
            chord_tokens=chord_input,
            melody_mask=melody_padding_mask,  # src_key_padding_mask
            chord_mask=chord_padding_mask     # tgt_key_padding_mask
        )
        
        # Use chord vocab size for both models to isolate architectural differences
        vocab_size_for_loss = config['chord_vocab_size']
        
        # Extract chord-only logits for T5 model (which outputs full vocab)
        if model_type == "t5":
            # Chord tokens start at CHORD_TOKEN_START (179) in the full vocabulary
            chord_logits = logits[:, :, CHORD_TOKEN_START:]  # [batch, seq, chord_vocab_size]
            logits_for_loss = chord_logits
            
            # Adjust targets from full vocab space to chord-only space
            # Original targets: [179, 180, ..., 4778] -> [0, 1, ..., 4599]
            targets_for_loss = chord_target - CHORD_TOKEN_START
            # Handle PAD tokens (they should remain as pad_token_id for ignore_index)
            pad_mask = (chord_target == model.pad_token_id)
            targets_for_loss[pad_mask] = model.pad_token_id
        else:  # custom model already outputs chord-only
            logits_for_loss = logits
            targets_for_loss = chord_target  # Already in chord space
            
        # Calculate loss
        loss = nn.functional.cross_entropy(
            logits_for_loss.reshape(-1, vocab_size_for_loss),
            targets_for_loss.reshape(-1),
            ignore_index=model.pad_token_id
        )
        
        # Check for NaN loss
        if torch.isnan(loss):
            print(f"\nNaN loss detected! Skipping batch.")
            if epoch < 3:  # Early epochs
                print("NaN in early epoch - may need to reduce learning rate or increase warmup steps")
            continue
        
        # Backward pass
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), config['gradient_clip_val'])
        optimizer.step()
        scheduler.step()  # Update learning rate each batch
        
        # Get current learning rate
        lr = optimizer.param_groups[0]['lr']
        
        # Log batch metrics
        if batch_idx % config['log_every_n_steps'] == 0:
            wandb.log({
                'train/batch_loss': loss.item(),
                'train/learning_rate': lr,
                'train/batch': batch_idx,
                'train/epoch': epoch,
                'train/grad_norm': torch.nn.utils.clip_grad_norm_(model.parameters(), float('inf')).item()
            }, step=global_step)
        
        global_step += 1
        
        # Update progress bar
        total_loss += loss.item()
        pbar.set_postfix({'loss': loss.item(), 'lr': f"{lr:.2e}"})
        
    return total_loss / len(train_loader), global_step

def validate(model, val_loader, device, model_type, config):
    model.eval()
    total_loss = 0
    nan_batches = 0
    
    with torch.no_grad():
        for batch in tqdm(val_loader, desc='Validating'):
            # Move batch to device
            melody_tokens = batch['melody_tokens'].to(device)
            chord_input = batch['chord_input'].to(device)
            chord_target = batch['chord_target'].to(device)
            
            # Create padding masks (True means position should be masked)
            melody_padding_mask = (melody_tokens == model.pad_token_id)  # [batch_size, src_len]
            chord_padding_mask = (chord_input == model.pad_token_id)     # [batch_size, tgt_len]
            
            # Forward pass with proper masking (causal mask is handled internally)
            logits = model(
                melody_tokens=melody_tokens,
                chord_tokens=chord_input,
                melody_mask=melody_padding_mask,  # src_key_padding_mask
                chord_mask=chord_padding_mask     # tgt_key_padding_mask
            )
            
            # Use chord vocab size for both models to isolate architectural differences
            vocab_size_for_loss = config['chord_vocab_size']
            
            # Extract chord-only logits for T5 model (which outputs full vocab)
            if model_type == "t5":
                # Chord tokens start at CHORD_TOKEN_START (179) in the full vocabulary
                chord_logits = logits[:, :, CHORD_TOKEN_START:]  # [batch, seq, chord_vocab_size]
                logits_for_loss = chord_logits
                
                # Adjust targets from full vocab space to chord-only space
                # Original targets: [179, 180, ..., 4778] -> [0, 1, ..., 4599]
                targets_for_loss = chord_target - CHORD_TOKEN_START
                # Handle PAD tokens (they should remain as pad_token_id for ignore_index)
                pad_mask = (chord_target == model.pad_token_id)
                targets_for_loss[pad_mask] = model.pad_token_id
            else:  # custom model already outputs chord-only
                logits_for_loss = logits
                targets_for_loss = chord_target  # Already in chord space
                
            # Calculate loss
            loss = nn.functional.cross_entropy(
                logits_for_loss.reshape(-1, vocab_size_for_loss),
                targets_for_loss.reshape(-1),
                ignore_index=model.pad_token_id
            )
            
            # Check for NaN loss
            if torch.isnan(loss):
                nan_batches += 1
                continue
                
            total_loss += loss.item()
    
    # Avoid division by zero if all batches were NaN
    num_valid_batches = len(val_loader) - nan_batches
    return total_loss / num_valid_batches if num_valid_batches > 0 else float('nan')


In [None]:
# Training loop
best_val_loss = float('inf')
best_checkpoint_path = None  # checkpoint file with the lowest validation loss so far
global_step = 0
# Stop after this many consecutive epochs without validation improvement.
# None or a non-positive value disables early stopping.
early_stopping_patience = config.get('early_stopping_patience', 5)
epochs_without_improvement = 0

print("\n--- Offline Training Info ---")
print(f"  Model type: {model_type}")
print(f"  Max epochs: {config['max_epochs']}")
print(f"  Early stopping patience: {early_stopping_patience}")

try:
    for epoch in range(config['max_epochs']):
        print(f"\nEpoch {epoch + 1}/{config['max_epochs']}")

        # Release cached, unused GPU memory before each epoch and report usage.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            print(f"GPU memory at start of epoch: {torch.cuda.memory_allocated() / 1e9:.2f} GB")

        # Training Step
        train_loss, global_step = train_epoch(model, train_loader, optimizer, scheduler, device, global_step, epoch, model_type, config)

        # Validation Step
        val_loss = validate(model, val_loader, device, model_type, config)

        print(f"Epoch {epoch+1}: Train Loss: {train_loss:.4f}, Valid Loss: {val_loss:.4f}")
        wandb.log({
            'train/epoch_loss': train_loss,
            'valid/epoch_loss': val_loss,
            'epoch': epoch + 1,
            'train/epoch': epoch + 1
        }, step=global_step)

        # A NaN val_loss compares False here, so it counts as "no improvement".
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            epochs_without_improvement = 0

            # Save checkpoint locally
            checkpoint_path = Path("checkpoints") / f"offline_teacher_epoch_{epoch+1}.pth"
            checkpoint_path.parent.mkdir(exist_ok=True)

            checkpoint = {
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'train_loss': train_loss,
                'val_loss': val_loss,
                'config': config,
            }

            torch.save(checkpoint, checkpoint_path)
            best_checkpoint_path = checkpoint_path

            # Log checkpoint as wandb artifact
            artifact = wandb.Artifact(
                name=f"offline_teacher_model_{wandb.run.id}",
                type="model",
                description=f"Offline Teacher Model checkpoint from epoch {epoch+1}"
            )
            artifact.add_file(str(checkpoint_path))
            wandb.log_artifact(artifact)

            # Also save tokenizer info as artifact
            tokenizer_artifact = wandb.Artifact(
                name=f"tokenizer_info_{wandb.run.id}",
                type="tokenizer",
                description="Tokenizer information used for training"
            )
            tokenizer_path = Path("tokenizer_info.json")
            with open(tokenizer_path, 'w') as f:
                json.dump(tokenizer_info, f)
            tokenizer_artifact.add_file(str(tokenizer_path))
            wandb.log_artifact(tokenizer_artifact)

            print(f"\nSaved checkpoint with validation loss: {val_loss:.4f}")
        else:
            epochs_without_improvement += 1
            # Enforce the patience reported in the info banner above.
            if (
                early_stopping_patience is not None
                and early_stopping_patience > 0
                and epochs_without_improvement >= early_stopping_patience
            ):
                print(f"\nEarly stopping: no validation improvement for {epochs_without_improvement} epochs")
                break

except KeyboardInterrupt:
    print("\nTraining interrupted by user")
except torch.cuda.OutOfMemoryError:
    print("\nOut of GPU memory! Try reducing batch size or sequence length")
finally:
    wandb.finish()
    print("\nTraining completed")


In [None]:
def load_best_checkpoint(model, checkpoint_path, map_location="cpu"):
    """Load model weights from a checkpoint saved by the training loop.

    Args:
        model: Module whose ``state_dict`` matches the checkpoint's
            ``model_state_dict``; weights are copied into it in place.
        checkpoint_path: Path to a ``.pth`` file containing at least
            ``model_state_dict`` and ``val_loss``.
        map_location: Passed to ``torch.load``. Defaults to ``"cpu"`` so a
            checkpoint written on a GPU can be loaded on a CPU-only machine;
            ``load_state_dict`` then copies the tensors onto the model's own device.

    Returns:
        The validation loss stored in the checkpoint.
    """
    checkpoint = torch.load(checkpoint_path, map_location=map_location)
    model.load_state_dict(checkpoint['model_state_dict'])
    return checkpoint['val_loss']

# Example usage (the training loop saves to checkpoints/offline_teacher_epoch_<N>.pth,
# where <N> is the epoch reported as "Saved checkpoint ..."):
# best_checkpoint_path = Path("checkpoints") / "offline_teacher_epoch_<N>.pth"
# best_val_loss = load_best_checkpoint(model, best_checkpoint_path)
# print(f"Loaded checkpoint with validation loss: {best_val_loss:.4f}")


In [None]:
# Model Comparison Analysis (Optional)
def analyze_model_performance():
    """Print a short architecture summary for the currently selected model.

    Reads the notebook globals ``model``, ``model_type`` and ``config``. For the
    T5 variant it also estimates how much of the output head scores tokens below
    ``CHORD_TOKEN_START``, which the training code slices away before the loss.
    """
    print("🎵 Model Architecture Comparison")
    print("=" * 50)

    print(f"Selected Model Type: {model_type.upper()}")
    print(f"Total Parameters: {sum(p.numel() for p in model.parameters()):,}")

    if model_type == "custom":
        print("✅ Custom Model Advantages:")
        print("  - Native chord-only output head")
        print("  - 100% parameter efficiency")
        print("  - Task-specific architecture")
        print("  - No vocabulary confusion")
        print(f"  - Output shape: [batch, seq, {config['chord_vocab_size']}]")

    elif model_type == "t5":
        # Estimate assumes a bias-free output projection of shape
        # [total_vocab_size, embed_dim]; rows below CHORD_TOKEN_START (melody +
        # PAD tokens) never contribute to the chord loss. Computed once so the
        # summary line and the detailed line always agree.
        head_params = config['total_vocab_size'] * config['embed_dim']
        wasted_params = CHORD_TOKEN_START * config['embed_dim']
        waste_percent = (wasted_params / head_params) * 100

        print("❌ T5 Model Issues:")
        print("  - Unified vocabulary (melody + chord)")
        print(f"  - {waste_percent:.1f}% wasted parameters (unused melody outputs)")
        print("  - Text-to-text design for cross-domain task")
        print("  - Requires logit extraction + target adjustment")
        print(f"  - Output shape: [batch, seq, {config['total_vocab_size']}] → extract chord portion")
        print(f"  - Output head waste: {wasted_params:,} / {head_params:,} ({waste_percent:.1f}%)")

# Run analysis
analyze_model_performance()
