In [None]:
!git clone https://github.com/imabeastdrew/Martydepth.git
%cd Martydepth

# Install the package in development mode
%pip install -e .


In [None]:
# Install dependencies
%pip install torch wandb tqdm pyyaml

In [None]:
# Import required libraries
import torch
import wandb
import json
import numpy as np
import tempfile
from pathlib import Path
from tqdm.notebook import tqdm
import os
import sys

# Add project root to Python path
sys.path.append('.')

# Import project modules
from src.data.dataset import create_dataloader
from src.models.online_transformer import OnlineTransformer
from src.models.offline_teacher_t5 import T5OfflineTeacherModel
from src.evaluation.metrics import (
    calculate_harmony_metrics,
    calculate_emd_metrics,
    parse_sequences,  # Added this import
)
from src.config.tokenization_config import (
    SILENCE_TOKEN,
    MELODY_ONSET_HOLD_START,
    CHORD_TOKEN_START,
    PAD_TOKEN,
)
from src.evaluation.evaluate_offline import generate_offline
from src.evaluation.evaluate import generate_online

# Local fallback for tokenizer_info.json when an artifact does not ship one.
_DEFAULT_TOKENIZER_INFO_PATH = "data/interim/train/tokenizer_info.json"


def _load_fallback_tokenizer_info(tokenizer_info_path) -> dict:
    """Read tokenizer_info.json from a local path (used when the artifact lacks it).

    Raises:
        FileNotFoundError: if no file exists at ``tokenizer_info_path``.
    """
    path = Path(tokenizer_info_path)
    if not path.exists():
        raise FileNotFoundError(
            f"Could not find tokenizer_info.json in checkpoint or data directory ({path})"
        )
    with open(path, 'r') as f:
        return json.load(f)


def load_model_artifact(
    artifact_path: str,
    fallback_tokenizer_info_path: str = _DEFAULT_TOKENIZER_INFO_PATH,
) -> tuple[dict, dict, dict]:
    """
    Load a model artifact and its config from wandb.
    Compatible with both checkpoint-style and separate file artifacts.

    Two layouts are accepted:
      * exactly one ``*.pth`` file plus ``tokenizer_info.json``: the .pth is a
        bare state dict, and the config is taken from the run that logged the artifact;
      * exactly one ``*.pth`` file and no tokenizer file: the .pth is either a
        training checkpoint (a dict holding ``model_state_dict`` and optionally
        ``config``) or a bare state dict. The tokenizer info is then read from
        ``fallback_tokenizer_info_path``.

    Args:
        artifact_path: Full path to the artifact (e.g. 'marty1ai/martydepth/model_name:version')
        fallback_tokenizer_info_path: Where to read tokenizer_info.json when the
            artifact does not contain it (relative paths resolve against the CWD).

    Returns:
        tuple[dict, dict, dict]: Model state dict, config, and tokenizer_info.
        For training checkpoints, the run config overrides keys in the
        checkpoint's own config.

    Raises:
        ValueError: if the artifact does not contain exactly one ``*.pth`` file.
        FileNotFoundError: if a fallback tokenizer_info.json is needed but missing.
    """
    api = wandb.Api()
    artifact = api.artifact(artifact_path)

    # Everything is read inside the temporary directory, so the data survives its cleanup.
    with tempfile.TemporaryDirectory() as tmp_dir:
        artifact_dir = Path(artifact.download(tmp_dir))

        checkpoint_files = list(artifact_dir.glob("*.pth"))
        tokenizer_files = list(artifact_dir.glob("tokenizer_info.json"))

        if len(checkpoint_files) != 1:
            raise ValueError(f"Unexpected artifact structure in {artifact_dir}. Expected either model.pth+tokenizer_info.json or single checkpoint.pth")
        checkpoint_file = checkpoint_files[0]

        if tokenizer_files:
            # Separate files structure (model.pth + tokenizer_info.json)
            model_state_dict = torch.load(checkpoint_file, map_location='cpu', weights_only=True)
            with open(tokenizer_files[0], 'r') as f:
                tokenizer_info = json.load(f)
            config = dict(artifact.logged_by().config)
        else:
            # Checkpoint structure (single .pth file).
            # NOTE(review): no weights_only here, so on older torch versions this fully
            # unpickles a downloaded file, which can execute arbitrary code. Only load
            # trusted artifacts.
            checkpoint = torch.load(checkpoint_file, map_location='cpu')
            run_config = dict(artifact.logged_by().config)

            if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
                # Training checkpoint: start from its embedded config (which may be
                # absent or None) and let the run config win on conflicts.
                model_state_dict = checkpoint['model_state_dict']
                config = dict(checkpoint.get('config') or {})
                config.update(run_config)
            else:
                # Direct state dict
                model_state_dict = checkpoint
                config = run_config

            tokenizer_info = _load_fallback_tokenizer_info(fallback_tokenizer_info_path)

    return model_state_dict, config, tokenizer_info

In [None]:
# Configuration
config = {
    # Data parameters
    'data_dir': 'data/interim',
    'split': 'test',
    'batch_size': 32,
    'num_workers': 4,

    # Sampling parameters (forwarded to generate_online / generate_offline)
    'temperature': 1.0,    # sampling temperature
    'top_k': 50,           # top-k sampling cutoff
    'wait_beats': 1,       # only used by generate_online
    'seed': 42,            # seeds torch/numpy RNGs so sampled generations are repeatable

    # Model artifact paths
    'online_model_artifact': 'marty1ai/martydepth/online_transformer_model_bvwago40:v13',
    'offline_model_artifact': 'marty1ai/martydepth/offline_teacher_model_2hd3b6gi:v8'
}

# Frames per beat, standard in our dataset (kept for frame/beat conversions).
frames_per_beat = 4

print("\nSampling Parameters:")
print(f"Temperature: {config['temperature']}")
print(f"Top-k: {config['top_k']}")
print(f"Wait beats: {config['wait_beats']}")

# Tokenization constants (some are also imported in the imports cell).
from src.config.tokenization_config import (
    MELODY_VOCAB_SIZE,
    CHORD_TOKEN_START,
    SILENCE_TOKEN,
    PAD_TOKEN
)

# Seed global RNGs before any sampling so results are reproducible across runs.
# NOTE(review): assumes the generators sample from torch's / numpy's global RNGs; confirm.
torch.manual_seed(config['seed'])
np.random.seed(config['seed'])

# Prefer CUDA, then Apple MPS, then CPU.
device = torch.device("cuda" if torch.cuda.is_available() else
                      "mps" if torch.backends.mps.is_available() else
                      "cpu")
print(f"\nUsing device: {device}")

# Initialize wandb; the whole config dict (including the seed) is recorded with the run.
wandb.init(
    project="martydepth",
    name="model_evaluation_configurable_durations",
    config=config,
    job_type="evaluation"
)

# Load model artifacts
print("\nLoading model artifacts...")

# Online model
online_state_dict, online_config, online_tokenizer_info = load_model_artifact(config['online_model_artifact'])
print(f"Loaded online model from {config['online_model_artifact']}")

# Offline model
offline_state_dict, offline_config, offline_tokenizer_info = load_model_artifact(config['offline_model_artifact'])
print(f"Loaded offline model from {config['offline_model_artifact']}")

# Both evaluations use the online model's tokenizer info; the check below only
# warns if the offline artifact disagrees.
tokenizer_info = online_tokenizer_info
print("Loaded tokenizer info from model artifacts")

if online_tokenizer_info != offline_tokenizer_info:
    print("WARNING: Online and offline models have different tokenizer info!")
    print("This may cause evaluation issues.")
else:
    print("✓ Tokenizer info consistent between models")

In [None]:
# Evaluate Online Model
print("\n=== Evaluating Online Model with Diverse Sampling ===")

# Vocabulary size as recorded in the tokenizer info loaded with the artifacts.
total_vocab_size = tokenizer_info['total_vocab_size']

# The model is built with a fixed 512-position context. The dataloader further
# down reads its sequence length from the run config instead (default 512).
# NOTE(review): the two values are decoupled on purpose here. Presumably the
# interleaved melody/chord token stream is longer than the data window; confirm
# against training before unifying them.
max_seq_length = 512

# Patch an off-by-one seen in some tokenizer_info files (4664 instead of 4665).
# NOTE(review): hard-coded sizes; verify against the saved embedding shape.
if total_vocab_size == 4664:
    print(f"WARNING: Adjusting vocab_size from {total_vocab_size} to 4665 to fix off-by-one error")
    total_vocab_size = 4665

model = OnlineTransformer(
    vocab_size=total_vocab_size,
    embed_dim=online_config['embed_dim'],
    num_heads=online_config['num_heads'],
    num_layers=online_config['num_layers'],
    dropout=online_config['dropout'],
    max_seq_length=max_seq_length,
    pad_token_id=PAD_TOKEN
).to(device)

print(f"Initialized online model with max_seq_length: {max_seq_length}")

# Load the trained weights and switch to inference mode.
model.load_state_dict(online_state_dict)
model.eval()

# Dataloader over the evaluation split; the length comes from the run config when present.
max_seq_length = online_config.get('max_seq_length') or online_config.get('max_sequence_length') or 512
dataloader, _ = create_dataloader(
    data_dir=Path(config['data_dir']),
    split=config['split'],
    batch_size=config['batch_size'],
    num_workers=config['num_workers'],
    sequence_length=max_seq_length,
    mode='online',
    shuffle=False
)

print("\nGenerating sequences with diverse sampling...")
generated_sequences, ground_truth_sequences = generate_online(
    model=model,
    dataloader=dataloader,
    tokenizer_info=tokenizer_info,
    device=device,
    temperature=config['temperature'],
    top_k=config['top_k'],
    wait_beats=config['wait_beats']
)

# Sanity-check the format of what the generator returned.
print("\nDebug - Generated Sequences Format:")
print(f"Number of sequences: {len(generated_sequences)}")
if generated_sequences:
    first_generated = generated_sequences[0]
    print(f"First sequence length: {len(first_generated)}")
    print(f"First sequence sample: {first_generated[:10]}")
    # min()/max() raise on an empty sequence, so the range is only reported when tokens exist.
    if len(first_generated) > 0:
        print(f"Token range: [{min(first_generated)}, {max(first_generated)}]")

print("\nDebug - Ground Truth Format:")
if ground_truth_sequences:
    print(f"First GT sequence length: {len(ground_truth_sequences[0])}")
    print(f"First GT sequence sample: {ground_truth_sequences[0][:10]}")

print("\nCalculating metrics...")
harmony_metrics = calculate_harmony_metrics(generated_sequences, tokenizer_info)
emd_metrics = calculate_emd_metrics(generated_sequences, ground_truth_sequences, tokenizer_info)
online_metrics = {**harmony_metrics, **emd_metrics}

print("\n=== Online Model Results with Diverse Sampling ===")
for metric, value in online_metrics.items():
    print(f"{metric}: {value:.4f}")

# A single wandb.log call keeps every metric on the same step. Each separate
# call would advance the step counter.
wandb.log({f"online_diverse/{metric}": value for metric, value in online_metrics.items()})

# Record which artifact and sampling settings produced these numbers.
wandb.log({
    "online_model_artifact": config['online_model_artifact'],
    "sampling/temperature": config['temperature'],
    "sampling/top_k": config['top_k'],
    "sampling/wait_beats": config['wait_beats']
})

In [None]:
# Evaluate Offline Model
print("\n=== Evaluating Offline Model with Diverse Sampling ===")

# Context length from the run config under either key name, else 256.
max_seq_length = (offline_config.get('max_seq_length') or
                  offline_config.get('max_sequence_length') or
                  256)

# Vocabulary sizes come from the shared tokenizer info.
# NOTE(review): 4779 is a hard-coded fallback that differs from the online cell's
# 4664/4665 handling. Confirm which value the offline checkpoint was trained with.
model = T5OfflineTeacherModel(
    melody_vocab_size=tokenizer_info['melody_vocab_size'],
    chord_vocab_size=tokenizer_info['chord_vocab_size'],
    embed_dim=offline_config['embed_dim'],
    num_heads=offline_config['num_heads'],
    num_layers=offline_config['num_layers'],
    dropout=offline_config['dropout'],
    max_seq_length=max_seq_length,
    pad_token_id=PAD_TOKEN,
    total_vocab_size=tokenizer_info.get('total_vocab_size', 4779)
).to(device)

print(f"Initialized offline model with max_seq_length: {max_seq_length}")

# Load the trained weights and switch to inference mode.
model.load_state_dict(offline_state_dict)
model.eval()

dataloader, _ = create_dataloader(
    data_dir=Path(config['data_dir']),
    split=config['split'],
    batch_size=config['batch_size'],
    num_workers=config['num_workers'],
    sequence_length=max_seq_length,
    mode='offline',
    shuffle=False
)

print("\nGenerating sequences with diverse sampling...")
generated_chord_sequences, ground_truth_chord_sequences, melody_sequences = generate_offline(
    model=model,
    dataloader=dataloader,
    tokenizer_info=tokenizer_info,
    device=device,
    temperature=config['temperature'],
    top_k=config['top_k']
)

# The metric functions take interleaved melody/chord sequences, so rebuild them
# from the separate streams. Results go into offline_* names so this cell does not
# overwrite the online cell's `generated_sequences` / `ground_truth_sequences` /
# `harmony_metrics` / `emd_metrics`.
print("\nConverting to interleaved format for metrics calculation...")
from src.evaluation.metrics import create_interleaved_sequences
import numpy as np

offline_generated_sequences = create_interleaved_sequences(
    np.array(melody_sequences), np.array(generated_chord_sequences)
)
offline_ground_truth_sequences = create_interleaved_sequences(
    np.array(melody_sequences), np.array(ground_truth_chord_sequences)
)

print(f"Created {len(offline_generated_sequences)} interleaved sequences for evaluation.")

print("\nCalculating metrics...")
offline_harmony_metrics = calculate_harmony_metrics(offline_generated_sequences, tokenizer_info)
offline_emd_metrics = calculate_emd_metrics(
    offline_generated_sequences, offline_ground_truth_sequences, tokenizer_info
)
offline_metrics = {**offline_harmony_metrics, **offline_emd_metrics}

print("\n=== Offline Model Results with Diverse Sampling ===")
for metric, value in offline_metrics.items():
    print(f"{metric}: {value:.4f}")

# A single wandb.log call keeps every metric on the same step.
wandb.log({f"offline_diverse/{metric}": value for metric, value in offline_metrics.items()})

# Record which artifact and sampling settings produced these numbers.
wandb.log({
    "offline_model_artifact": config['offline_model_artifact'],
    "sampling/temperature": config['temperature'],
    "sampling/top_k": config['top_k']
})

In [None]:
# Compare Results with Test Set Baselines
print("\n=== COMPARISON WITH TEST SET BASELINES ===")

from src.evaluation.metrics import print_baseline_comparison

# Test-set baseline figures quoted in the ratios and summary below.
# NOTE(review): hard-coded. Confirm they match the baselines print_baseline_comparison uses.
BASELINE_HARMONY_PCT = 65.88
BASELINE_ONSET_INTERVAL_EMD = 28.89


def _format_metric(value, spec=".2f"):
    """Format a numeric metric with ``spec``; return non-numeric placeholders (e.g. 'N/A') as plain text."""
    try:
        return format(value, spec)
    except (TypeError, ValueError):
        return str(value)


# Use the metrics stored by the online cell. Recomputing them from
# `generated_sequences` here is unsafe, because another cell may have rebound that name.
print("\n--- Online Model vs Baselines ---")
print_baseline_comparison(
    harmony_metrics=online_metrics,
    emd_metrics=online_metrics
)

print("\n--- Offline Model vs Baselines ---")
print_baseline_comparison(
    harmony_metrics=offline_metrics,
    emd_metrics=offline_metrics
)

# Log baseline comparisons to wandb
wandb.log({
    "online_vs_baseline_harmony_ratio": online_metrics['melody_note_in_chord_ratio'] / BASELINE_HARMONY_PCT,
    "offline_vs_baseline_harmony_ratio": offline_metrics['melody_note_in_chord_ratio'] / BASELINE_HARMONY_PCT,
})

print("\n=== SUMMARY ===")
print(f"Online Model Harmony: {online_metrics['melody_note_in_chord_ratio']:.2f}% (vs {BASELINE_HARMONY_PCT}% baseline)")
print(f"Offline Model Harmony: {offline_metrics['melody_note_in_chord_ratio']:.2f}% (vs {BASELINE_HARMONY_PCT}% baseline)")
# A missing EMD metric prints 'N/A' instead of crashing on the ':.2f' format spec.
print(f"Online Model EMD: {_format_metric(online_metrics.get('onset_interval_emd', 'N/A'))} (vs {BASELINE_ONSET_INTERVAL_EMD} baseline)")
print(f"Offline Model EMD: {_format_metric(offline_metrics.get('onset_interval_emd', 'N/A'))} (vs {BASELINE_ONSET_INTERVAL_EMD} baseline)")
