In [None]:
# Setup: Clone repository and install package
!git clone https://github.com/imabeastdrew/Martydepth.git
%cd Martydepth

# Install the package in development mode
%pip install -e .


In [None]:
# Install additional dependencies for balanced training
# NOTE(review): versions are unpinned, so a fresh run may resolve different releases
# than the ones this notebook was developed against — consider pinning (pkg==x.y.z).
%pip install torch wandb tqdm pyyaml transformers scikit-learn


In [None]:
# Import required libraries
import torch
import torch.nn as nn
from torch.optim import AdamW
from pathlib import Path
import wandb
import yaml
import json
import numpy as np
from tqdm.notebook import tqdm
import os
import sys
import math
from typing import Dict, Optional, Tuple

# Add project root to Python path
# ('.' is the repo root only because the setup cell `%cd`s into the clone; the
# `src.*` imports below depend on that.)
sys.path.append('.')

print("📍 Current directory:", os.getcwd())
print("🐍 Python path:", sys.path[:3])  # Show first 3 entries

# Import project modules
from src.data.dataset import create_dataloader
from src.models.online_transformer import OnlineTransformer
from src.config.tokenization_config import PAD_TOKEN
from src.training.losses import (
    create_loss_function, analyze_prediction_diversity, MixedLossTrainer
)
from src.training.sampling import AdvancedSampler
from src.data.analyze_token_distribution import analyze_chord_distribution
from src.evaluation.training_diagnostics import quick_model_diagnostic

print("✅ All modules imported successfully")


In [None]:
# 🎛️ Configuration Selection
# Choose one of the three training configurations:

# Option 1: Balanced Training (Recommended for first run)
CONFIG_NAME = "balanced_training_base"

# Option 2: Diversity Regularized (Strong diversity penalties)
# CONFIG_NAME = "diversity_regularized"

# Option 3: Focal Loss Heavy (For severe class imbalance)
# CONFIG_NAME = "focal_loss_heavy"

# Load configuration
config_path = f'src/training/configs/{CONFIG_NAME}.yaml'
print(f"📋 Loading config: {config_path}")

with open(config_path, 'r') as f:
    config = yaml.safe_load(f)

# Seed every RNG the pipeline may draw from (Python, NumPy, PyTorch CPU + CUDA) so
# weight initialization and data shuffling are reproducible across runs.
# Override with `training.seed` in the YAML; defaults to 42.
import random

SEED = config['training'].get('seed', 42)
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
print(f"🎲 Random seed: {SEED}")

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'🚀 Using device: {device}')

if torch.cuda.is_available():
    print(f"   GPU: {torch.cuda.get_device_name()}")
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

# Display key configuration settings
print(f"\n🎵 Training Configuration: {CONFIG_NAME}")
print(f"   Loss type: {config['loss']['type']}")
print(f"   Learning rate: {config['training']['learning_rate']}")
print(f"   Batch size: {config['training']['batch_size']}")
print(f"   Max epochs: {config['training']['max_epochs']}")
print(f"   Sequence length: {config['model']['sequence_length']}")


In [None]:
# 🔍 Data Distribution Analysis
# Summarize how skewed the training chord-token distribution is so the right
# balancing config can be chosen. On any failure, token_counts/stats fall back to None.
print("🔍 Analyzing chord distribution in training data...")

try:
    analysis = analyze_chord_distribution(
        Path(config['data']['data_dir']), split="train", max_files=2000
    )

    token_counts = analysis['token_counts']
    stats = analysis['distribution_stats']

    print("📊 Distribution Analysis Results:")
    print(f"   Total unique chord tokens: {len(token_counts)}")
    print(f"   Most common token: {stats['most_common_token']} ({stats['most_common_count']:,} occurrences)")
    print(f"   Dominance ratio: {stats['dominance_ratio']:.3f}")
    print(f"   Gini coefficient: {stats['gini_coefficient']:.3f}")
    print(f"   Vocabulary coverage: {stats['vocab_coverage']:.3f}")

    # Severity bands on the dominance ratio -> matching config advice.
    dominance = stats['dominance_ratio']
    if dominance > 0.5:
        advice = (
            "   ⚠️  HIGH DOMINANCE detected - strong balancing recommended",
            "   💡 Consider using 'focal_loss_heavy' or 'diversity_regularized' config",
        )
    elif dominance > 0.3:
        advice = (
            "   ⚠️  MODERATE DOMINANCE detected - balanced training recommended",
            "   💡 'balanced_training_base' config should work well",
        )
    else:
        advice = ("   ✅ Reasonable distribution detected",)
    print(*advice, sep="\n")

    gini = stats['gini_coefficient']
    if gini > 0.8:
        print(f"   ⚠️  Very high inequality (Gini: {gini:.3f})")

except Exception as e:
    print(f"   ⚠️  Could not analyze distribution: {e}")
    print("   Using standard training approach")
    token_counts = stats = None


In [None]:
# Initialize Weights & Biases
run_name = (
    f"{CONFIG_NAME}_L{config['model']['n_layers']}_H{config['model']['n_heads']}"
    f"_D{config['model']['d_model']}_seq{config['model']['sequence_length']}"
    f"_bs{config['training']['batch_size']}_lr{config['training']['learning_rate']}"
)

wandb_settings = config['logging']['wandb']
# The YAML may leave `name` empty/null; fall back to the generated run name alone
# instead of crashing on `None + str`.
name_prefix = wandb_settings.get('name')
full_run_name = f"{name_prefix}_{run_name}" if name_prefix else run_name

wandb.init(
    project=wandb_settings['project'],
    name=full_run_name,
    config={
        **config,
        # Explicit `is not None` checks: truthiness is ambiguous for array-like
        # results and would mislabel an empty-but-successful analysis as 'unknown'.
        'data_distribution': stats if stats is not None else 'unknown',
        'num_unique_chords': len(token_counts) if token_counts is not None else 'unknown'
    },
    job_type="balanced_training"
)

print(f"📊 Weights & Biases initialized: {wandb.run.name}")


In [None]:
# 📊 Create Data Loaders
# Train and validation loaders share every setting except the split and shuffling.
print("📊 Creating data loaders...")

data_config = config['data']


def _make_loader(split, shuffle):
    """Build the (loader, info) pair for one split from the shared data/training settings."""
    return create_dataloader(
        data_dir=Path(data_config['data_dir']),
        split=split,
        batch_size=config['training']['batch_size'],
        num_workers=data_config['num_workers'],
        sequence_length=data_config['sequence_length'],
        mode=data_config['mode'],
        shuffle=shuffle,
        pin_memory=data_config.get('pin_memory', True),
    )


train_loader, train_info = _make_loader("train", shuffle=True)
val_loader, val_info = _make_loader("valid", shuffle=False)

for split_label, split_loader in (("Train", train_loader), ("Valid", val_loader)):
    print(f"   {split_label} batches: {len(split_loader)}")
print(f"   Vocabulary size: {train_info.get('vocab_size', 'unknown')}")

# Prefer the vocabulary size reported by the dataset over the YAML value.
if 'vocab_size' in train_info:
    config['model']['vocab_size'] = train_info['vocab_size']
    print(f"   Updated vocab_size to: {config['model']['vocab_size']}")


In [None]:
# 🏗️ Create Model and Loss Function
print("🏗️ Creating model and loss function...")

# Initialize model
model_config = config['model']
model = OnlineTransformer(
    vocab_size=model_config['vocab_size'],
    embed_dim=model_config['d_model'],
    num_heads=model_config['n_heads'],
    num_layers=model_config['n_layers'],
    max_seq_length=model_config['sequence_length'],
    dropout=model_config['dropout'],
    pad_token_id=PAD_TOKEN
).to(device)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"   Model created with {total_params:,} total parameters ({trainable_params:,} trainable)")

# Create balanced loss function (token_counts is None if the distribution analysis failed)
loss_config = config['loss']
criterion = create_loss_function(
    loss_config=loss_config,
    vocab_size=model_config['vocab_size'],
    token_counts=token_counts,
    ignore_index=loss_config.get('ignore_index', PAD_TOKEN),
    device=device
)

print(f"   Loss function created: {loss_config['type']}")
# `hasattr` alone is not enough: a loss may expose a `weights` attribute that is
# None, and calling `.sum()` on it would crash this cell.
class_weights = getattr(criterion, 'weights', None)
if class_weights is not None:
    print(f"   Using class weights (sum = {class_weights.sum().item():.1f})")

# Create optimizer and scheduler
train_config = config['training']
# PyYAML loads exponent-only literals such as `3e-4` as *strings* (YAML 1.1 requires
# a dot, e.g. `3.0e-4`), which AdamW rejects; coerce numeric settings explicitly.
learning_rate = float(train_config['learning_rate'])
weight_decay = float(train_config['weight_decay'])
optimizer = AdamW(
    model.parameters(),
    lr=learning_rate,
    weight_decay=weight_decay,
    # Optional YAML overrides; the defaults are the previously hard-coded values.
    betas=tuple(float(b) for b in train_config.get('betas', (0.9, 0.95))),
    eps=float(train_config.get('eps', 1e-8))
)

# Cosine decay over every optimizer step: train_epoch steps the scheduler once per
# batch, so the horizon is epochs * batches (at least 1 to avoid a zero period).
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=max(1, train_config['max_epochs'] * len(train_loader)),
    eta_min=learning_rate * float(train_config.get('min_lr_ratio', 0.1))
)

print(f"   Optimizer: AdamW (lr={train_config['learning_rate']}, wd={train_config['weight_decay']})")
print(f"   Scheduler: CosineAnnealingLR")

# Enable gradient and parameter logging
wandb.watch(model, log="all", log_freq=config['logging']['log_every'])


In [None]:
# 🎯 Training Functions with Diversity Monitoring
def train_epoch(model, train_loader, criterion, optimizer, scheduler, device, global_step, epoch):
    """Run one training epoch with periodic prediction-diversity monitoring.

    Reads `config` (grad clip norm, log frequency) and `model_config` (vocab size)
    from the notebook's global scope.

    Args:
        model: Network called as ``model(input_tokens)`` to produce logits.
        train_loader: Iterable of batches with 'input_tokens' / 'target_tokens'.
        criterion: A ``MixedLossTrainer`` (uses ``compute_loss``) or a callable loss
            applied to flattened logits ``(N, vocab)`` and targets ``(N,)``.
        optimizer: Stepped once per processed batch.
        scheduler: Stepped once per processed batch (per-step schedule).
        device: Device the batch tensors are moved to.
        global_step: Step counter at the start of the epoch (wandb step axis).
        epoch: Zero-based epoch index (display/logging only).

    Returns:
        Tuple of (mean loss over processed batches, updated global_step,
        list of diversity-metric dicts). The mean is NaN if every batch was skipped.
    """
    model.train()
    total_loss = 0.0
    num_processed = 0
    diversity_history = []
    grad_clip_norm = config['training']['grad_clip_norm']
    log_every = config['logging']['log_every']
    vocab_size = model_config['vocab_size']

    pbar = tqdm(train_loader, desc=f'Training Epoch {epoch+1}')
    for batch_idx, batch in enumerate(pbar):
        # Move batch to device
        input_tokens = batch['input_tokens'].to(device)
        target_tokens = batch['target_tokens'].to(device)

        # Forward pass
        optimizer.zero_grad()
        logits = model(input_tokens)

        # Calculate loss
        if isinstance(criterion, MixedLossTrainer):
            loss, loss_metrics = criterion.compute_loss(logits, target_tokens, global_step)
        else:
            # reshape (not view) so non-contiguous logits/targets are also accepted.
            loss = criterion(logits.reshape(-1, logits.size(-1)), target_tokens.reshape(-1))
            loss_metrics = None

        # Skip NaN *and* ±inf losses: stepping on either would corrupt the weights.
        if not torch.isfinite(loss):
            print(f"\n⚠️ Non-finite loss ({loss.item()}) detected! Skipping batch.")
            continue

        # Backward pass
        loss.backward()
        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip_norm)
        optimizer.step()
        scheduler.step()

        # One host sync per batch, reused for metrics, totals and the progress bar.
        loss_value = loss.item()
        if loss_metrics is None:
            loss_metrics = {'loss': loss_value}

        # Get current learning rate
        lr = optimizer.param_groups[0]['lr']

        # Analyze diversity every few steps (on detached logits: analysis only)
        if batch_idx % 50 == 0:
            diversity_metrics = analyze_prediction_diversity(
                logits.detach(), target_tokens, vocab_size
            )
            diversity_history.append(diversity_metrics)

            # Check for diversity alerts
            if diversity_metrics['dominance_ratio'] > 0.7:
                print(f"\n⚠️ HIGH DOMINANCE: {diversity_metrics['dominance_ratio']:.3f}")
            if diversity_metrics['vocabulary_coverage'] < 0.05:
                print(f"\n⚠️ LOW COVERAGE: {diversity_metrics['vocabulary_coverage']:.3f}")

        # Log metrics
        if batch_idx % log_every == 0:
            log_dict = {
                'train/batch_loss': loss_value,
                'train/learning_rate': lr,
                'train/grad_norm': grad_norm.item(),
                'train/batch': batch_idx,
                'train/epoch': epoch
            }

            # Add loss-specific metrics
            for key, value in loss_metrics.items():
                log_dict[f'train/{key}'] = value

            # Add the most recent diversity metrics if available
            if diversity_history:
                for key, value in diversity_history[-1].items():
                    log_dict[f'train/diversity_{key}'] = value

            wandb.log(log_dict, step=global_step)

        global_step += 1
        total_loss += loss_value
        num_processed += 1

        # Update progress bar
        pbar.set_postfix({
            'loss': f"{loss_value:.4f}",
            'lr': f"{lr:.2e}",
            'dom': f"{diversity_history[-1]['dominance_ratio']:.3f}" if diversity_history else "---"
        })

    # Average over batches that actually contributed (skipped batches are excluded).
    mean_loss = total_loss / num_processed if num_processed else float('nan')
    return mean_loss, global_step, diversity_history

def validate(model, val_loader, criterion, device, vocab_size, global_step=0):
    """Evaluate on the validation set and average loss and diversity metrics.

    Args:
        model: Network called as ``model(input_tokens)`` to produce logits.
        val_loader: Iterable of batches with 'input_tokens' / 'target_tokens'.
        criterion: A ``MixedLossTrainer`` (uses ``compute_loss``) or a callable loss
            applied to flattened logits and targets.
        device: Device the batch tensors are moved to.
        vocab_size: Passed through to ``analyze_prediction_diversity``.
        global_step: Step passed to ``MixedLossTrainer.compute_loss`` (default 0,
            matching the previous behavior).

    Returns:
        Tuple of (mean loss over batches with a finite loss, dict of per-metric
        means as plain Python floats). The loss is NaN when no batch produced a
        finite loss (including an empty loader), so it can never be mistaken for
        an improvement by a ``val_loss < best`` comparison.
    """
    model.eval()
    total_loss = 0.0
    num_valid = 0
    all_diversity_metrics = []

    with torch.no_grad():
        for batch in tqdm(val_loader, desc='Validating'):
            input_tokens = batch['input_tokens'].to(device)
            target_tokens = batch['target_tokens'].to(device)

            # Forward pass
            logits = model(input_tokens)

            # Calculate loss
            if isinstance(criterion, MixedLossTrainer):
                loss, _ = criterion.compute_loss(logits, target_tokens, global_step)
            else:
                loss = criterion(logits.reshape(-1, logits.size(-1)), target_tokens.reshape(-1))

            # Non-finite batches are excluded from both the sum and the count.
            if not torch.isfinite(loss):
                continue

            total_loss += loss.item()
            num_valid += 1

            # Analyze diversity for this batch
            all_diversity_metrics.append(
                analyze_prediction_diversity(logits, target_tokens, vocab_size)
            )

    avg_loss = total_loss / num_valid if num_valid else float('nan')

    # Aggregate diversity metrics; plain floats keep checkpoints free of numpy
    # scalars (which weights-only torch.load refuses to unpickle).
    avg_diversity = {}
    if all_diversity_metrics:
        for key in all_diversity_metrics[0]:
            avg_diversity[key] = float(np.mean([m[key] for m in all_diversity_metrics]))

    return avg_loss, avg_diversity

print("🎯 Training functions defined with diversity monitoring")


In [None]:
# 🚀 Main Training Loop
print("🚀 Starting balanced training...")

best_val_loss = float('inf')
global_step = 0
patience_counter = 0
max_patience = 5

# Create checkpoints directory
checkpoints_dir = Path("checkpoints")
checkpoints_dir.mkdir(exist_ok=True)

try:
    for epoch in range(config['training']['max_epochs']):
        print(f"\n{'='*60}")
        print(f"🎵 Epoch {epoch + 1}/{config['training']['max_epochs']}")
        print(f"{'='*60}")
        
        # Clear GPU memory before each epoch
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            print(f"🔧 GPU memory at start: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
        
        # Training Step
        train_loss, global_step, train_diversity = train_epoch(
            model, train_loader, criterion, optimizer, scheduler, device, global_step, epoch
        )
        
        # Validation Step
        val_loss, val_diversity = validate(model, val_loader, criterion, device, model_config['vocab_size'])
        
        # Log epoch results
        print(f"\n📊 Epoch {epoch+1} Results:")
        print(f"   Train Loss: {train_loss:.4f}")
        print(f"   Valid Loss: {val_loss:.4f}")
        
        if val_diversity:
            print(f"   Validation Diversity:")
            print(f"     Dominance Ratio: {val_diversity['dominance_ratio']:.3f}")
            print(f"     Vocab Coverage: {val_diversity['vocabulary_coverage']:.3f}")
            print(f"     Prediction Entropy: {val_diversity['prediction_entropy']:.3f}")
            
            # Diversity alerts
            if val_diversity['dominance_ratio'] > 0.6:
                print(f"   ⚠️  HIGH DOMINANCE WARNING!")
            if val_diversity['vocabulary_coverage'] < 0.1:
                print(f"   ⚠️  LOW VOCABULARY COVERAGE WARNING!")
        
        # Log to wandb
        epoch_log = {
            'train/epoch_loss': train_loss,
            'valid/epoch_loss': val_loss,
            'epoch': epoch + 1,
            'train/epoch': epoch + 1
        }
        
        # Add validation diversity metrics
        for key, value in val_diversity.items():
            epoch_log[f'valid/diversity_{key}'] = value
        
        wandb.log(epoch_log, step=global_step)
        
        # Save checkpoint if validation loss improved
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
            
            # Save best checkpoint
            checkpoint_path = checkpoints_dir / f"balanced_best_epoch_{epoch+1}.pth"
            
            checkpoint = {
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'train_loss': train_loss,
                'val_loss': val_loss,
                'val_diversity': val_diversity,
                'config': config,
                'vocab_size': model_config['vocab_size'],
                'global_step': global_step
            }
            
            torch.save(checkpoint, checkpoint_path)
            
            # Log checkpoint as wandb artifact
            artifact = wandb.Artifact(
                name=f"balanced_model_{wandb.run.id}",
                type="model",
                description=f"Balanced model checkpoint from epoch {epoch+1}"
            )
            artifact.add_file(str(checkpoint_path))
            wandb.log_artifact(artifact)
            
            print(f"💾 Saved best checkpoint: {checkpoint_path}")
            print(f"   New best validation loss: {val_loss:.4f}")
            
        else:
            patience_counter += 1
            print(f"⏰ Patience counter: {patience_counter}/{max_patience}")
            
            if patience_counter >= max_patience:
                print(f"🛑 Early stopping triggered after {patience_counter} epochs without improvement")
                break
        
        # Save regular checkpoint every 5 epochs
        if (epoch + 1) % 5 == 0:
            regular_checkpoint_path = checkpoints_dir / f"balanced_epoch_{epoch+1}.pth"
            torch.save(checkpoint, regular_checkpoint_path)
            print(f"💾 Saved regular checkpoint: {regular_checkpoint_path}")
            
except KeyboardInterrupt:
    print("\n⏹️ Training interrupted by user")
except torch.cuda.OutOfMemoryError:
    print("\n❌ Out of GPU memory! Try reducing batch size or sequence length")
    print(f"   Current batch size: {config['training']['batch_size']}")
    print(f"   Current sequence length: {config['model']['sequence_length']}")
except Exception as e:
    print(f"\n❌ Training failed with error: {e}")
    import traceback
    traceback.print_exc()
finally:
    # Save final checkpoint
    final_checkpoint_path = checkpoints_dir / "balanced_final.pth"
    if 'checkpoint' in locals():
        torch.save(checkpoint, final_checkpoint_path)
        print(f"💾 Saved final checkpoint: {final_checkpoint_path}")
    
    wandb.finish()
    print("\n🎉 Training completed!")


In [None]:
# 🧪 Quick Model Evaluation
print("🧪 Running quick model diagnostic...")

try:
    # Load the most recently written "best" checkpoint from the training cell
    checkpoint_files = list(checkpoints_dir.glob("balanced_best_*.pth"))
    if checkpoint_files:
        latest_checkpoint = max(checkpoint_files, key=lambda path: path.stat().st_mtime)
        print(f"📂 Loading checkpoint: {latest_checkpoint}")

        # weights_only=False: this checkpoint was written by this notebook and also
        # stores non-tensor metadata (config, diversity metrics) that the weights-only
        # unpickler (the default since PyTorch 2.6) may reject.
        # Never load untrusted checkpoints this way.
        best_checkpoint = torch.load(latest_checkpoint, map_location=device, weights_only=False)
        model.load_state_dict(best_checkpoint['model_state_dict'])

        # Run diagnostic
        results = quick_model_diagnostic(
            model=model,
            data_loader=val_loader,
            device=device,
            num_batches=10,
            vocab_size=model_config['vocab_size']
        )

        print(f"\n📊 Model Diagnostic Results:")
        print(f"   Average Loss: {results['avg_loss']:.4f}")
        print(f"   Prediction Accuracy: {results['accuracy']:.3f}")
        print(f"   Perplexity: {results['perplexity']:.2f}")
        print(f"   Unique Predictions: {results['unique_predictions']}")
        print(f"   Vocabulary Coverage: {results['vocab_coverage']:.3f}")
        print(f"   Dominance Ratio: {results['dominance_ratio']:.3f}")

        # Issue alerts based on results
        if results['dominance_ratio'] > 0.8:
            print("   ⚠️  CRITICAL: Very high dominance - model may be stuck")
        elif results['dominance_ratio'] > 0.6:
            print("   ⚠️  WARNING: High dominance detected")
        else:
            print("   ✅ Reasonable prediction diversity")

        if results['vocab_coverage'] < 0.05:
            print("   ⚠️  CRITICAL: Very low vocabulary coverage")
        elif results['vocab_coverage'] < 0.1:
            print("   ⚠️  WARNING: Low vocabulary coverage")
        else:
            print("   ✅ Good vocabulary coverage")

        # The training cell calls wandb.finish(); logging afterwards raises, so only
        # log when a run is still active.
        if wandb.run is not None:
            wandb.log({
                'diagnostic/avg_loss': results['avg_loss'],
                'diagnostic/accuracy': results['accuracy'],
                'diagnostic/perplexity': results['perplexity'],
                'diagnostic/unique_predictions': results['unique_predictions'],
                'diagnostic/vocab_coverage': results['vocab_coverage'],
                'diagnostic/dominance_ratio': results['dominance_ratio']
            })
        else:
            print("   ℹ️  No active wandb run - diagnostic metrics not logged")

    else:
        print("⚠️ No checkpoint found for diagnostic")

except Exception as e:
    print(f"❌ Diagnostic failed: {e}")
    import traceback
    traceback.print_exc()


In [None]:
# 🎲 Test Advanced Sampling
# Draw repeated next-token samples from one validation sequence under each strategy
# and report how many distinct tokens come out.
print("🎲 Testing advanced sampling strategies...")

# Sampler settings come from the optional `sampling` config section. Note that the
# two strategies fall back to different temperatures when none is configured.
sampling_config = config.get('sampling', {})
nucleus_settings = dict(
    p=sampling_config.get('top_p', 0.9),
    temperature=sampling_config.get('temperature', 1.2),
)
top_k_settings = dict(
    k=50,
    temperature=sampling_config.get('temperature', 1.0),
)
strategies = {'nucleus': nucleus_settings, 'top_k': top_k_settings}

sampler = AdvancedSampler(strategies=strategies)

NUM_DRAWS = 10   # samples drawn per strategy
NUM_SHOWN = 5    # samples echoed per strategy
MIN_UNIQUE = 3   # fewer distinct draws than this is reported as low diversity

try:
    model.eval()
    with torch.no_grad():
        # Probe with the first sequence of the first validation batch.
        sample_batch = next(iter(val_loader))
        input_tokens = sample_batch['input_tokens'][:1].to(device)

        print(f"🎵 Input sequence length: {input_tokens.shape[1]}")
        print(f"   First few tokens: {input_tokens[0, :10].cpu().tolist()}")

        # Next-token scores at the final position of the probe sequence.
        # NOTE(review): if sequences are right-padded this position may be padding — confirm.
        logits = model(input_tokens)
        last_logits = logits[0, -1, :]

        print("\n🎯 Testing different sampling strategies:")

        for strategy in ('nucleus', 'top_k'):
            samples = [
                sampler.sample(
                    last_logits.unsqueeze(0),
                    strategy=strategy,
                    input_ids=input_tokens,
                    repetition_penalty=1.2,
                    frequency_penalty=0.1,
                ).item()
                for _ in range(NUM_DRAWS)
            ]

            unique_samples = len(set(samples))
            print(f"   {strategy.capitalize()}: {unique_samples}/{NUM_DRAWS} unique samples")
            print(f"     Samples: {samples[:NUM_SHOWN]}... (showing first {NUM_SHOWN})")

            verdict = "✅ Good" if unique_samples >= MIN_UNIQUE else "⚠️ Low"
            print(f"     {verdict} diversity with {strategy}")

        # Greedy decoding vs a single nucleus draw
        greedy_token = torch.argmax(last_logits).item()
        nucleus_token = sampler.sample(last_logits.unsqueeze(0), 'nucleus').item()
        differs_from_greedy = nucleus_token != greedy_token

        print("\n🔄 Comparison:")
        print(f"   Greedy prediction: {greedy_token}")
        print(f"   Nucleus sample: {nucleus_token}")
        print(f"   Different from greedy: {'✅' if differs_from_greedy else '❌'}")

except Exception as e:
    print(f"❌ Sampling test failed: {e}")
    import traceback
    traceback.print_exc()
