In [None]:
# Clone the project repository and move into it so the relative paths used
# below (src/..., data/interim/...) resolve against the repo root.
# NOTE(review): `git clone` fails on a re-run if Martydepth/ already exists,
# and `%cd` is relative, so re-running from inside the repo will fail too.
!git clone https://github.com/imabeastdrew/Martydepth.git
%cd Martydepth

# Install the package in development mode
%pip install -e .


In [None]:
# Install runtime dependencies used by the cells below.
# NOTE(review): versions are unpinned, so results can drift between runs as
# these libraries update; consider pinning (e.g. torch==X.Y).
%pip install torch wandb tqdm pyyaml

In [None]:
# Import required libraries
import torch
import wandb
import json
import numpy as np
import tempfile
from pathlib import Path
from tqdm.notebook import tqdm
import os
import sys

# Add project root to Python path
sys.path.append('.')

# Import project modules
from src.data.dataset import create_dataloader
from src.models.online_transformer import OnlineTransformer
from src.models.offline_teacher import OfflineTeacherModel
from src.evaluation.metrics import (
    calculate_harmony_metrics,
    calculate_emd_metrics,
    parse_sequences,  # Added this import
)
from src.config.tokenization_config import (
    SILENCE_TOKEN,
    MELODY_ONSET_HOLD_START,
    CHORD_TOKEN_START,
    PAD_TOKEN,
)
from src.evaluation.evaluate_offline import generate_offline
from src.evaluation.evaluate import generate_online

def _read_json(path) -> dict:
    """Read and parse a JSON file."""
    with open(path, 'r') as f:
        return json.load(f)


def _logged_run_config(artifact):
    """Return the config of the run that logged ``artifact`` as a dict, or None if there is no such run."""
    run = artifact.logged_by()
    return None if run is None else dict(run.config)


def _require_run_config(artifact, artifact_path: str) -> dict:
    """Return the producing run's config, raising ValueError when the run is unavailable."""
    run_config = _logged_run_config(artifact)
    if run_config is None:
        raise ValueError(
            f"Artifact {artifact_path} has no producing run, so its model config "
            "cannot be recovered"
        )
    return run_config


def _load_fallback_tokenizer_info(path) -> dict:
    """Load a local tokenizer_info.json, raising FileNotFoundError with the path if it is missing."""
    path = Path(path)
    if not path.exists():
        raise FileNotFoundError(
            f"Could not find tokenizer_info.json in checkpoint or at {path}"
        )
    return _read_json(path)


def load_model_artifact(
    artifact_path: str,
    tokenizer_fallback_path: "str | Path" = "data/interim/train/tokenizer_info.json",
) -> tuple[dict, dict, dict]:
    """
    Load a model artifact and its config from wandb.

    Two artifact layouts are supported:

    * Separate files: exactly one ``*.pth`` (a bare state dict) plus
      ``tokenizer_info.json``. The config comes from the run that logged
      the artifact.
    * Single checkpoint: exactly one ``*.pth`` and no tokenizer file. The
      file is either a training checkpoint dict with a ``'model_state_dict'``
      key (optionally ``'config'``, which the run config overrides key by key)
      or a bare state dict whose config comes from the run. Tokenizer info is
      read from ``tokenizer_fallback_path``.

    Args:
        artifact_path: Full path to the artifact
            (e.g. 'marty1ai/martydepth/model_name:version').
        tokenizer_fallback_path: Local tokenizer_info.json used when the
            artifact does not ship one.

    Returns:
        tuple[dict, dict, dict]: Model state dict, config, and tokenizer_info.

    Raises:
        FileNotFoundError: No tokenizer file in the artifact and
            ``tokenizer_fallback_path`` does not exist.
        ValueError: Unrecognised artifact layout, or the config cannot be
            recovered because the artifact has no producing run.
    """
    api = wandb.Api()
    artifact = api.artifact(artifact_path)

    # Everything is loaded into memory before the temp dir is removed.
    with tempfile.TemporaryDirectory() as tmp_dir:
        artifact_dir = Path(artifact.download(tmp_dir))
        checkpoint_files = list(artifact_dir.glob("*.pth"))
        tokenizer_files = list(artifact_dir.glob("tokenizer_info.json"))

        if len(checkpoint_files) == 1 and len(tokenizer_files) == 1:
            # Separate files structure (model.pth + tokenizer_info.json)
            model_state_dict = torch.load(checkpoint_files[0], map_location='cpu', weights_only=True)
            tokenizer_info = _read_json(tokenizer_files[0])
            config = _require_run_config(artifact, artifact_path)

        elif len(checkpoint_files) == 1 and len(tokenizer_files) == 0:
            # Checkpoint structure (single .pth file with everything).
            # SECURITY(review): without weights_only=True this may unpickle
            # arbitrary objects (depending on the torch version's default);
            # only load artifacts you trust.
            checkpoint = torch.load(checkpoint_files[0], map_location='cpu')

            if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
                # Training checkpoint structure
                model_state_dict = checkpoint['model_state_dict']
                # Copy so the merge below does not mutate the loaded checkpoint;
                # `or {}` also tolerates an explicit `'config': None`.
                config = dict(checkpoint.get('config') or {})
                # The run config, when a producing run exists, overrides the
                # checkpoint's values.
                config.update(_logged_run_config(artifact) or {})
            else:
                # Direct state dict: the run is the only source of the config.
                model_state_dict = checkpoint
                config = _require_run_config(artifact, artifact_path)

            tokenizer_info = _load_fallback_tokenizer_info(tokenizer_fallback_path)
        else:
            raise ValueError(f"Unexpected artifact structure in {artifact_dir}. Expected either model.pth+tokenizer_info.json or single checkpoint.pth")

    return model_state_dict, config, tokenizer_info

In [None]:
# Configuration
config = {
    # Data parameters
    'data_dir': 'data/interim',
    'split': 'test',
    'batch_size': 32,
    'num_workers': 4,

    # Sampling parameters
    'temperature': 1.0,    # sampling temperature, passed to both generators
    'top_k': 50,           # top-k cutoff, passed to both generators
    'wait_beats': 1,       # passed only to generate_online
    'seed': 42,            # seeds torch/numpy RNGs so sampled outputs are reproducible

    # Model artifact paths
    'online_model_artifact': 'marty1ai/martydepth/online_transformer_model_vercmpgd:v12',
    'offline_model_artifact': 'marty1ai/martydepth/offline_teacher_model_2hd3b6gi:v8'
}

# Frame/beat conversion. NOTE(review): not referenced by the cells below.
frames_per_beat = 4  # Standard in our dataset

print(f"\nSampling Parameters:")
print(f"Temperature: {config['temperature']}")
print(f"Top-k: {config['top_k']}")
print(f"Wait beats: {config['wait_beats']}")
print(f"Seed: {config['seed']}")

# Import necessary constants
# NOTE(review): all except MELODY_VOCAB_SIZE are already imported in the imports cell.
from src.config.tokenization_config import (
    MELODY_VOCAB_SIZE,
    CHORD_TOKEN_START,
    SILENCE_TOKEN,
    PAD_TOKEN
)

# Set device (prefer CUDA, then Apple MPS, then CPU)
device = torch.device("cuda" if torch.cuda.is_available() else
                     "mps" if torch.backends.mps.is_available() else
                     "cpu")
print(f"\nUsing device: {device}")

# Seed torch (all devices) and numpy so top-k/temperature sampling repeats
# across Restart & Run All. Presumably generate_online/generate_offline draw
# from torch's RNG — confirm if results still vary.
torch.manual_seed(config['seed'])
np.random.seed(config['seed'])

# Initialize wandb
wandb.init(
    project="martydepth",
    name="model_evaluation_configurable_durations",  # Updated name to reflect changes
    config=config,
    job_type="evaluation"
)

# Load model artifacts
print("\nLoading model artifacts...")

# Online model
online_state_dict, online_config, online_tokenizer_info = load_model_artifact(config['online_model_artifact'])
print(f"Loaded online model from {config['online_model_artifact']}")

# Offline model
offline_state_dict, offline_config, offline_tokenizer_info = load_model_artifact(config['offline_model_artifact'])
print(f"Loaded offline model from {config['offline_model_artifact']}")

# Use tokenizer info from the online model (they should be identical)
tokenizer_info = online_tokenizer_info
print("Loaded tokenizer info from model artifacts")

# Verify tokenizer consistency
if online_tokenizer_info != offline_tokenizer_info:
    print("WARNING: Online and offline models have different tokenizer info!")
    print("This may cause evaluation issues.")
else:
    print("✓ Tokenizer info consistent between models")

In [None]:
# Evaluate Online Model
print("\n=== Evaluating Online Model with Diverse Sampling ===")

# Resolve the sequence length once from the training config (either key
# spelling) so the model and the dataloader are built with the same length.
max_seq_length = (online_config.get('max_seq_length') or
                  online_config.get('max_sequence_length') or
                  512)

# Initialize model
total_vocab_size = tokenizer_info['total_vocab_size']
model = OnlineTransformer(
    vocab_size=total_vocab_size,
    embed_dim=online_config['embed_dim'],
    num_heads=online_config['num_heads'],
    num_layers=online_config['num_layers'],
    dropout=online_config['dropout'],
    max_seq_length=max_seq_length,
    pad_token_id=PAD_TOKEN
).to(device)

print(f"Initialized online model with max_seq_length: {max_seq_length}")

# Load trained weights and switch to inference mode
model.load_state_dict(online_state_dict)
model.eval()

# Create dataloader (same sequence length as the model)
dataloader, _ = create_dataloader(
    data_dir=Path(config['data_dir']),
    split=config['split'],
    batch_size=config['batch_size'],
    num_workers=config['num_workers'],
    sequence_length=max_seq_length,
    mode='online',
    shuffle=False
)

# Generate sequences with new sampling parameters
print("\nGenerating sequences with diverse sampling...")
generated_sequences, ground_truth_sequences = generate_online(
    model=model,
    dataloader=dataloader,
    tokenizer_info=tokenizer_info,
    device=device,
    temperature=config['temperature'],
    top_k=config['top_k'],
    wait_beats=config['wait_beats']
)

# Debug: Check what format the generated sequences have
print(f"\nDebug - Generated Sequences Format:")
print(f"Number of sequences: {len(generated_sequences)}")
if len(generated_sequences) > 0:
    first_seq = generated_sequences[0]
    print(f"First sequence length: {len(first_seq)}")
    print(f"First sequence sample: {first_seq[:10]}")
    # min/max raise on an empty sequence, so only report the range when it has tokens.
    if len(first_seq) > 0:
        print(f"Token range: [{min(first_seq)}, {max(first_seq)}]")

print(f"\nDebug - Ground Truth Format:")
if len(ground_truth_sequences) > 0:
    print(f"First GT sequence length: {len(ground_truth_sequences[0])}")
    print(f"First GT sequence sample: {ground_truth_sequences[0][:10]}")

# Calculate metrics
print("\nCalculating metrics...")
harmony_metrics = calculate_harmony_metrics(generated_sequences, tokenizer_info)
emd_metrics = calculate_emd_metrics(generated_sequences, ground_truth_sequences, tokenizer_info)

online_metrics = {**harmony_metrics, **emd_metrics}

print("\n=== Online Model Results with Diverse Sampling ===")
for metric, value in online_metrics.items():
    print(f"{metric}: {value:.4f}")

# Log metrics, artifact version and sampling parameters in a single call so
# they share one wandb step (each wandb.log call advances the step).
wandb.log({
    **{f"online_diverse/{metric}": value for metric, value in online_metrics.items()},
    "online_model_artifact": config['online_model_artifact'],
    "sampling/temperature": config['temperature'],
    "sampling/top_k": config['top_k'],
    "sampling/wait_beats": config['wait_beats'],
})

In [None]:
# Evaluate Offline Model
print("\n=== Evaluating Offline Model with Diverse Sampling ===")

# Get max sequence length from config with common parameter names
max_seq_length = (offline_config.get('max_seq_length') or
                  offline_config.get('max_sequence_length') or
                  256)  # Default to 256 for offline models

# Initialize model; melody/chord vocab sizes come from the tokenizer info
model = OfflineTeacherModel(
    melody_vocab_size=tokenizer_info['melody_vocab_size'],
    chord_vocab_size=tokenizer_info['chord_vocab_size'],
    embed_dim=offline_config['embed_dim'],
    num_heads=offline_config['num_heads'],
    num_layers=offline_config['num_layers'],
    dropout=offline_config['dropout'],
    max_seq_length=max_seq_length,
    pad_token_id=PAD_TOKEN
).to(device)

print(f"Initialized offline model with max_seq_length: {max_seq_length}")

# Load trained weights and switch to inference mode
model.load_state_dict(offline_state_dict)
model.eval()

# Create dataloader (same sequence length as the model)
dataloader, _ = create_dataloader(
    data_dir=Path(config['data_dir']),
    split=config['split'],
    batch_size=config['batch_size'],
    num_workers=config['num_workers'],
    sequence_length=max_seq_length,
    mode='offline',
    shuffle=False
)

# Generate sequences (generate_offline takes no wait_beats)
print("\nGenerating sequences with diverse sampling...")
generated_sequences, ground_truth_sequences = generate_offline(
    model=model,
    dataloader=dataloader,
    tokenizer_info=tokenizer_info,
    device=device,
    temperature=config['temperature'],
    top_k=config['top_k']
)

# Calculate metrics
print("\nCalculating metrics...")
harmony_metrics = calculate_harmony_metrics(generated_sequences, tokenizer_info)
emd_metrics = calculate_emd_metrics(generated_sequences, ground_truth_sequences, tokenizer_info)

offline_metrics = {**harmony_metrics, **emd_metrics}

print("\n=== Offline Model Results with Diverse Sampling ===")
for metric, value in offline_metrics.items():
    print(f"{metric}: {value:.4f}")

# Log metrics, artifact version and sampling parameters in a single call so
# they share one wandb step (each wandb.log call advances the step).
wandb.log({
    **{f"offline_diverse/{metric}": value for metric, value in offline_metrics.items()},
    "offline_model_artifact": config['offline_model_artifact'],
    "sampling/temperature": config['temperature'],
    "sampling/top_k": config['top_k'],
})

In [None]:
# Recalculate metrics from previous sequences
# NOTE: `generated_sequences` / `ground_truth_sequences` hold whatever the most
# recently executed evaluation cell produced. Under Restart & Run All that is
# the offline model, not the online one.
print("\n=== Recalculating Metrics from Previous Run ===")

# Recalculate metrics using the sequences we already generated
harmony_metrics = calculate_harmony_metrics(generated_sequences, tokenizer_info)
emd_metrics = calculate_emd_metrics(generated_sequences, ground_truth_sequences, tokenizer_info)

print("\n=== Results for Most Recently Generated Sequences ===")
for metric, value in {**harmony_metrics, **emd_metrics}.items():
    print(f"{metric}: {value:.4f}")

# Print raw histograms for debugging
print("\n=== Debug: Raw Histograms ===")

# Materialize once: parse_sequences is a generator.
sequences = list(parse_sequences(generated_sequences, tokenizer_info))

# For every note, the absolute frame distance to the nearest chord onset.
intervals = []
example_printed = False
for data in sequences:
    if not data['notes'] or not data['chords']:
        continue

    note_onsets = np.array([n['start'] for n in data['notes']])
    chord_onsets = np.array([c['start'] for c in data['chords']])
    # Pairwise |note - chord| distances, shape (num_notes, num_chords). Chords
    # are non-empty here, so every row has a minimum.
    distances = np.abs(note_onsets[:, None] - chord_onsets[None, :])
    intervals.extend(distances.min(axis=1).tolist())

    # Debug output for the first sequence with both notes and chords only.
    if not example_printed:
        example_printed = True
        print("\nExample sequence structure:")
        print(f"Number of notes: {len(data['notes'])}")
        print(f"Number of chords: {len(data['chords'])}")
        print("\nFirst few notes:")
        for note in data['notes'][:3]:
            print(f"Note: {note}")
        print("\nFirst few chords:")
        for chord in data['chords'][:3]:
            print(f"Chord: {chord}")
        print("\nFirst few interval calculations:")
        for n_onset, row in zip(note_onsets[:3], distances[:3]):
            closest_idx = np.argmin(row)
            print(f"Note onset {n_onset}, closest chord onset: {chord_onsets[closest_idx]}, absolute distance: {row[closest_idx]} frames")

if not intervals:
    print("No interval data found!")
else:
    # Print raw interval statistics
    print(f"\nInterval statistics:")
    print(f"Min interval: {min(intervals)}")
    print(f"Max interval: {max(intervals)}")
    print(f"Mean interval: {np.mean(intervals):.2f}")
    print(f"Median interval: {np.median(intervals):.2f}")

    # Bins as specified in the paper: edges [0, 1, ..., 17, ∞] give 18 bins.
    # np.histogram closes only the last bin, so it is [17, ∞].
    onset_bins = list(range(18)) + [np.inf]
    counts, bin_edges = np.histogram(intervals, bins=onset_bins)
    total = counts.sum()
    # Keep integer counts for display; derive probabilities separately.
    probs = counts / total if total > 0 else counts.astype(float)

    print("\nOnset Interval Distribution:")
    print(f"Total intervals: {len(intervals)}")
    print(f"Bin edges: {bin_edges}")
    print("\nIntervals show absolute distance between note and nearest chord onset")
    print("0 means note and chord are simultaneous")
    print("Higher values mean more frames between note and nearest chord")
    print("")

    for i, (count, prob) in enumerate(zip(counts, probs)):
        label = f"Bin {i}" if i < len(counts) - 1 else "Bin >=17"
        print(f"{label}: {prob:.4f} ({count} intervals)")

# Finish wandb run
wandb.finish()

In [None]:
# Calculate chord length distribution
print("\n=== Debug: Chord Length Distribution ===")

# Materialize once: parse_sequences is a generator.
sequences = list(parse_sequences(generated_sequences, tokenizer_info))

# Show the first chord of the first sequence to illustrate the structure
if sequences and sequences[0]['chords']:
    print("\nExample chord structure:")
    print(sequences[0]['chords'][0])

# Chord length = end - start, over every chord in every sequence
chord_lengths = [
    chord['end'] - chord['start']
    for data in sequences
    for chord in data['chords']
]

if not chord_lengths:
    print("No chord data found!")
else:
    print(f"\nTotal chords: {len(chord_lengths)}")
    print(f"Length range: [{min(chord_lengths)}, {max(chord_lengths)}]")

    # Edges [0, 1, ..., 33, ∞] give 34 bins. np.histogram closes only the
    # last bin, so it is [33, ∞]. Negative lengths fall outside every bin.
    length_bins = list(range(34)) + [np.inf]
    counts, _ = np.histogram(chord_lengths, bins=length_bins)
    total = counts.sum()
    if total < len(chord_lengths):
        print(f"Note: {len(chord_lengths) - total} chords have negative length and are not binned")
    # Keep integer counts for display; derive probabilities separately.
    probs = counts / total if total > 0 else counts.astype(float)

    print("\nChord Length Distribution:")
    for i, (count, prob) in enumerate(zip(counts, probs)):
        label = f"Bin {i}" if i < len(counts) - 1 else "Bin >=33"
        print(f"{label}: {prob:.4f} ({count} chords)")

    # Shannon entropy (natural log) of the normalized histogram
    nonzero = probs[probs > 0]
    entropy = -np.sum(nonzero * np.log(nonzero))
    print(f"\nChord Length Entropy: {entropy:.4f} nats")
