In [None]:
# Fetch the project sources and make the repository root the working directory;
# later cells use paths relative to it (e.g. 'data/interim') and add '.' to sys.path.
!git clone https://github.com/imabeastdrew/Martydepth.git
%cd Martydepth

# Install the package in development mode
%pip install -e .


In [None]:
# Install dependencies
# NOTE(review): versions are unpinned, so a re-run may pick up different releases — consider pinning.
%pip install torch wandb tqdm pyyaml

In [None]:
# Import required libraries
import torch
import wandb
import json
import numpy as np
import tempfile
from pathlib import Path
from tqdm.notebook import tqdm
import os
import sys

# Add project root to Python path
sys.path.append('.')

# Import project modules
from src.data.dataset import create_dataloader
from src.models.online_transformer import OnlineTransformer
from src.models.offline_teacher_t5 import T5OfflineTeacherModel
from src.evaluation.metrics import (
    calculate_harmony_metrics,
    calculate_emd_metrics,
    parse_sequences,  # Added this import
)
from src.config.tokenization_config import (
    SILENCE_TOKEN,
    MELODY_ONSET_HOLD_START,
    CHORD_TOKEN_START,
    PAD_TOKEN,
)
from src.evaluation.evaluate_offline import generate_offline
from src.evaluation.evaluate import generate_online

def _read_json_file(path: Path) -> dict:
    """Open `path` and decode its contents as JSON."""
    with open(path, 'r') as f:
        return json.load(f)


def _load_local_tokenizer_info() -> dict:
    """
    Load tokenizer info from the local data directory.

    Used as a fallback when the artifact does not ship its own tokenizer_info.json.

    Raises:
        FileNotFoundError: If data/interim/train/tokenizer_info.json does not exist.
    """
    tokenizer_info_path = Path("data/interim/train/tokenizer_info.json")
    if not tokenizer_info_path.exists():
        raise FileNotFoundError("Could not find tokenizer_info.json in checkpoint or data directory")
    return _read_json_file(tokenizer_info_path)


def _logged_run_config(artifact, artifact_path: str, required: bool = True) -> dict:
    """
    Return a copy of the config of the run that logged `artifact`.

    Args:
        artifact: The wandb artifact whose producing run is looked up.
        artifact_path: Artifact path, used only for the error message.
        required: When True a missing run is an error; when False an empty dict is returned.

    Raises:
        ValueError: If no logging run is available and `required` is True.
    """
    run = artifact.logged_by()
    if run is None:
        if required:
            raise ValueError(
                f"Artifact {artifact_path} has no logging run, so its model config cannot be recovered"
            )
        return {}
    return dict(run.config)


def load_model_artifact(artifact_path: str) -> tuple[dict, dict, dict]:
    """
    Load a model artifact and its config from wandb.
    Compatible with both checkpoint-style and separate file artifacts.

    Supported artifact layouts (exactly one ``*.pth`` file is expected):
      * ``*.pth`` + ``tokenizer_info.json``: the .pth is a plain state dict and the
        config comes from the run that logged the artifact.
      * a single ``*.pth``: either a training checkpoint (dict with 'model_state_dict'
        and optionally 'config') or a plain state dict. Tokenizer info is then read
        from data/interim/train/tokenizer_info.json.

    Args:
        artifact_path: Full path to the artifact (e.g. 'marty1ai/martydepth/model_name:version')

    Returns:
        tuple[dict, dict, dict]: Model state dict, config, and tokenizer_info

    Raises:
        ValueError: If the artifact layout is not one of the supported ones, or the
            model config cannot be recovered because no logging run is available.
        FileNotFoundError: If tokenizer info is needed from the data directory but absent.
    """
    api = wandb.Api()
    artifact = api.artifact(artifact_path)

    # Everything is read into memory inside the context manager; the download
    # directory is deleted as soon as the block exits.
    with tempfile.TemporaryDirectory() as tmp_dir:
        artifact_dir = artifact.download(tmp_dir)
        artifact_path_obj = Path(artifact_dir)

        checkpoint_files = list(artifact_path_obj.glob("*.pth"))
        tokenizer_files = list(artifact_path_obj.glob("tokenizer_info.json"))

        if len(checkpoint_files) != 1 or len(tokenizer_files) > 1:
            raise ValueError(f"Unexpected artifact structure in {artifact_dir}. Expected either model.pth+tokenizer_info.json or single checkpoint.pth")

        checkpoint_file = checkpoint_files[0]

        if tokenizer_files:
            # Separate files structure (model.pth + tokenizer_info.json)
            model_state_dict = torch.load(checkpoint_file, map_location='cpu', weights_only=True)
            tokenizer_info = _read_json_file(tokenizer_files[0])
            config = _logged_run_config(artifact, artifact_path)
        else:
            # Checkpoint structure (single .pth file with everything).
            # NOTE(security): this load is not restricted to weights because training
            # checkpoints may carry extra objects; only use it on trusted artifacts.
            checkpoint = torch.load(checkpoint_file, map_location='cpu')

            if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
                # Training checkpoint structure
                model_state_dict = checkpoint['model_state_dict']
                # Copy so the merge below does not mutate the checkpoint, and
                # tolerate an explicit `'config': None` entry.
                config = dict(checkpoint.get('config') or {})
                # Run config wins on key conflicts. It is only mandatory when the
                # checkpoint itself carried no config.
                config.update(_logged_run_config(artifact, artifact_path, required=not config))
            else:
                # Direct state dict
                model_state_dict = checkpoint
                config = _logged_run_config(artifact, artifact_path)

            tokenizer_info = _load_local_tokenizer_info()

    return model_state_dict, config, tokenizer_info

In [None]:
# Configuration
# This dict is also passed to wandb.init below as the run config.
config = {
    # Data parameters
    'data_dir': 'data/interim',
    'split': 'test',
    'batch_size': 32,
    'num_workers': 4,
    
    # Sampling parameters
    'temperature': 1.0,     # Sampling temperature forwarded to generate_online and generate_offline
    'top_k': 50,           # Top-k value forwarded to generate_online and generate_offline
    'wait_beats': 1,       # Forwarded to generate_online only
    
    # Model artifact paths
    'online_model_artifact': 'marty1ai/martydepth/online_transformer_model_bvwago40:v13',
    'offline_model_artifact': 'marty1ai/martydepth/offline_teacher_model_2hd3b6gi:v8'
}

# Add some helpful frame/beat conversions
# NOTE(review): not referenced by any other cell visible in this notebook.
frames_per_beat = 4  # Standard in our dataset

print(f"\nSampling Parameters:")
print(f"Temperature: {config['temperature']}")
print(f"Top-k: {config['top_k']}")
print(f"Wait beats: {config['wait_beats']}")

# Import necessary constants
# CHORD_TOKEN_START, SILENCE_TOKEN and PAD_TOKEN are already imported in the imports cell;
# MELODY_VOCAB_SIZE is the only new name brought in here.
from src.config.tokenization_config import (
    MELODY_VOCAB_SIZE,
    CHORD_TOKEN_START,
    SILENCE_TOKEN,
    PAD_TOKEN
)

# Set device
# Preference order: CUDA, then Apple MPS, then CPU.
device = torch.device("cuda" if torch.cuda.is_available() else
                     "mps" if torch.backends.mps.is_available() else
                     "cpu")
print(f"\nUsing device: {device}")

# Initialize wandb
# Starts the run that the evaluation cells log their metrics to.
wandb.init(
    project="martydepth",
    name="model_evaluation_configurable_durations",  # Display name of this evaluation run
    config=config,
    job_type="evaluation"
)

# Load model artifacts
print("\nLoading model artifacts...")

# Online model
online_state_dict, online_config, online_tokenizer_info = load_model_artifact(config['online_model_artifact'])
print(f"Loaded online model from {config['online_model_artifact']}")

# Offline model
offline_state_dict, offline_config, offline_tokenizer_info = load_model_artifact(config['offline_model_artifact'])
print(f"Loaded offline model from {config['offline_model_artifact']}")

# Use tokenizer info from the online model (they should be identical)
tokenizer_info = online_tokenizer_info
print("Loaded tokenizer info from model artifacts")

# Verify tokenizer consistency
# A mismatch only produces a warning; evaluation continues with the online model's tokenizer info.
if online_tokenizer_info != offline_tokenizer_info:
    print("WARNING: Online and offline models have different tokenizer info!")
    print("This may cause evaluation issues.")
else:
    print("✓ Tokenizer info consistent between models")

In [None]:
# Evaluate Online Model
# Builds the online transformer from the loaded artifact, generates sequences on the
# configured split, computes harmony/EMD metrics and logs them to the active W&B run.
print("\n=== Evaluating Online Model with Diverse Sampling ===")

# Initialize model
total_vocab_size = tokenizer_info['total_vocab_size']

# Maximum sequence length the model is constructed with. Kept under its own name so it
# is not confused with the dataloader's `max_seq_length` further down, which is read
# from the run config and may hold a different value.
model_max_seq_length = 512

# Fix vocabulary size if needed (should be 4665, not 4664)
if total_vocab_size == 4664:
    print(f"WARNING: Adjusting vocab_size from {total_vocab_size} to 4665 to fix off-by-one error")
    total_vocab_size = 4665

model = OnlineTransformer(
    vocab_size=total_vocab_size,
    embed_dim=online_config['embed_dim'],
    num_heads=online_config['num_heads'],
    num_layers=online_config['num_layers'],
    dropout=online_config['dropout'],
    max_seq_length=model_max_seq_length,
    pad_token_id=PAD_TOKEN
).to(device)

print(f"Initialized online model with max_seq_length: {model_max_seq_length}")

# Load state dict
model.load_state_dict(online_state_dict)
model.eval()

# Create dataloader
# Sequence length for the data comes from the training run config, trying both common
# key names and falling back to 512 when neither is set.
max_seq_length = online_config.get('max_seq_length') or online_config.get('max_sequence_length') or 512
dataloader, _ = create_dataloader(
    data_dir=Path(config['data_dir']),
    split=config['split'],
    batch_size=config['batch_size'],
    num_workers=config['num_workers'],
    sequence_length=max_seq_length,
    mode='online',
    shuffle=False
)

# Generate sequences with new sampling parameters
print("\nGenerating sequences with diverse sampling...")
generated_sequences, ground_truth_sequences = generate_online(
    model=model,
    dataloader=dataloader,
    tokenizer_info=tokenizer_info,
    device=device,
    temperature=config['temperature'],
    top_k=config['top_k'],
    wait_beats=config['wait_beats']
)

# Debug: Check what format the generated sequences have
print(f"\nDebug - Generated Sequences Format:")
print(f"Number of sequences: {len(generated_sequences)}")
if generated_sequences:
    print(f"First sequence length: {len(generated_sequences[0])}")
    print(f"First sequence sample: {generated_sequences[0][:10]}")
    print(f"Token range: [{min(generated_sequences[0])}, {max(generated_sequences[0])}]")

print(f"\nDebug - Ground Truth Format:")
if ground_truth_sequences:
    print(f"First GT sequence length: {len(ground_truth_sequences[0])}")
    print(f"First GT sequence sample: {ground_truth_sequences[0][:10]}")

# Calculate metrics
print("\nCalculating metrics...")
harmony_metrics = calculate_harmony_metrics(generated_sequences, tokenizer_info)
emd_metrics = calculate_emd_metrics(generated_sequences, ground_truth_sequences, tokenizer_info)

online_metrics = {**harmony_metrics, **emd_metrics}

print("\n=== Online Model Results with Diverse Sampling ===")
for metric, value in online_metrics.items():
    print(f"{metric}: {value:.4f}")

# Log metrics, artifact version and sampling parameters in a single call: every
# wandb.log call advances the run's step, so logging one metric per call would scatter
# a single evaluation across many steps.
wandb.log({
    **{f"online_diverse/{metric}": value for metric, value in online_metrics.items()},
    "online_model_artifact": config['online_model_artifact'],
    "sampling/temperature": config['temperature'],
    "sampling/top_k": config['top_k'],
    "sampling/wait_beats": config['wait_beats']
})

In [None]:
# Evaluate Offline Model
# Builds the offline teacher from the loaded artifact, generates sequences on the
# configured split, computes harmony/EMD metrics and logs them to the active W&B run.
# The names `model`, `dataloader`, `generated_sequences`, `ground_truth_sequences` and
# `max_seq_length` are rebound here; the debug cells below read them.
print("\n=== Evaluating Offline Model with Diverse Sampling ===")

# Get max sequence length from config with common parameter names
max_seq_length = (offline_config.get('max_seq_length') or
                  offline_config.get('max_sequence_length') or
                  256)  # Default to 256 for offline models

# Initialize model
model = T5OfflineTeacherModel(
    melody_vocab_size=tokenizer_info['melody_vocab_size'],  # Get from tokenizer info
    chord_vocab_size=tokenizer_info['chord_vocab_size'],    # Get from tokenizer info
    embed_dim=offline_config['embed_dim'],
    num_heads=offline_config['num_heads'],
    num_layers=offline_config['num_layers'],
    dropout=offline_config['dropout'],
    max_seq_length=max_seq_length,
    pad_token_id=PAD_TOKEN,
    # Falls back to 4779 only when the tokenizer info has no 'total_vocab_size' entry.
    total_vocab_size=tokenizer_info.get('total_vocab_size', 4779)  # Use unified vocabulary
).to(device)

print(f"Initialized offline model with max_seq_length: {max_seq_length}")

# Load state dict
model.load_state_dict(offline_state_dict)
model.eval()

# Same split and batch settings as the online evaluation, but in 'offline' mode.
dataloader, _ = create_dataloader(
    data_dir=Path(config['data_dir']),
    split=config['split'],
    batch_size=config['batch_size'],
    num_workers=config['num_workers'],
    sequence_length=max_seq_length,
    mode='offline',
    shuffle=False
)

# Generate sequences with new sampling parameters
print("\nGenerating sequences with diverse sampling...")
generated_sequences, ground_truth_sequences = generate_offline(
    model=model,
    dataloader=dataloader,
    tokenizer_info=tokenizer_info,
    device=device,
    temperature=config['temperature'],
    top_k=config['top_k']
)

# Calculate metrics
print("\nCalculating metrics...")
harmony_metrics = calculate_harmony_metrics(generated_sequences, tokenizer_info)
emd_metrics = calculate_emd_metrics(generated_sequences, ground_truth_sequences, tokenizer_info)

offline_metrics = {**harmony_metrics, **emd_metrics}

print("\n=== Offline Model Results with Diverse Sampling ===")
for metric, value in offline_metrics.items():
    print(f"{metric}: {value:.4f}")

# Log metrics, artifact version and sampling parameters in a single call: every
# wandb.log call advances the run's step, so logging one metric per call would scatter
# a single evaluation across many steps.
wandb.log({
    **{f"offline_diverse/{metric}": value for metric, value in offline_metrics.items()},
    "offline_model_artifact": config['offline_model_artifact'],
    "sampling/temperature": config['temperature'],
    "sampling/top_k": config['top_k']
})

In [None]:
# Debug: Examine what the offline model generated
# Reads `generated_sequences`, `ground_truth_sequences` and `tokenizer_info` from the
# kernel state left by the evaluation cells.
print("=== DEBUGGING OFFLINE MODEL GENERATION ===")

pad_token = tokenizer_info['pad_token_id']


def print_sequence_stats(header, seq, pad_token_id=None, top_n=5):
    """
    Print length, token range, distinct-token count and the most frequent tokens of one sequence.

    Args:
        header: Title printed above the statistics (e.g. "Sequence 0").
        seq: Token sequence; a list or array of integer token ids.
        pad_token_id: When given, also report how often this token occurs.
        top_n: Number of most frequent tokens to list.
    """
    seq_array = np.asarray(seq)

    print(f"\n{header}:")
    print(f"  Length: {len(seq_array)}")
    if len(seq_array) == 0:
        # min()/max() and the percentage maths below are undefined for empty input.
        print("  (empty sequence - no token statistics available)")
        return

    unique_tokens, counts = np.unique(seq_array, return_counts=True)
    print(f"  Token range: [{seq_array.min()}, {seq_array.max()}]")
    print(f"  Unique tokens: {len(unique_tokens)}")
    print(f"  First 10 tokens: {seq_array[:10]}")

    # np.unique returns tokens sorted by id, not by frequency, so rank by count
    # explicitly. The stable sort keeps ties ordered by ascending token id.
    most_common = np.argsort(-counts, kind="stable")[:top_n]
    print(f"  Top {top_n} most common tokens:")
    for idx in most_common:
        token, count = unique_tokens[idx], counts[idx]
        print(f"    Token {token}: {count} times ({count/len(seq_array)*100:.1f}%)")

    if pad_token_id is not None:
        # Check if mostly silence/PAD tokens
        pad_count = np.sum(seq_array == pad_token_id)
        print(f"  PAD token ({pad_token_id}) count: {pad_count} ({pad_count/len(seq_array)*100:.1f}%)")


print(f"\nGenerated sequences info:")
print(f"  Number of sequences: {len(generated_sequences)}")
print(f"  Number of ground truth sequences: {len(ground_truth_sequences)}")

if generated_sequences:
    print(f"\nFirst few generated sequences analysis:")
    for i, seq in enumerate(generated_sequences[:3]):
        print_sequence_stats(f"Sequence {i}", seq, pad_token_id=pad_token)

# Check ground truth sequences too
if ground_truth_sequences:
    print(f"\nGround truth sequences analysis:")
    for i, seq in enumerate(ground_truth_sequences[:3]):
        print_sequence_stats(f"Ground Truth Sequence {i}", seq)

# Show tokenizer info for reference
print(f"\nTokenizer info for reference:")
print(f"  Melody vocab size: {tokenizer_info['melody_vocab_size']}")
print(f"  Chord vocab size: {tokenizer_info['chord_vocab_size']}")
print(f"  Total vocab size: {tokenizer_info['total_vocab_size']}")
print(f"  Melody tokens: 0 - {tokenizer_info['melody_vocab_size']-1}")
print(f"  PAD token: {pad_token}")
# NOTE(review): the chord range is derived here as everything above the PAD token up to
# the end of the vocabulary — confirm against the tokenizer's actual layout.
print(f"  Chord tokens: {pad_token+1} - {tokenizer_info['total_vocab_size']-1}")


In [None]:
# Debug: Test metrics parsing with generated sequences
# Reads `generated_sequences` and `tokenizer_info` from the kernel state; when the
# notebook is run top to bottom, `generated_sequences` is the offline model's output.
print("=== DEBUGGING METRICS PARSING ===")

# Test the parse_sequences function with our generated sequences
print(f"\nTesting parse_sequences function...")
try:
    # Materialize the result so it can be counted and indexed below.
    sequences = list(parse_sequences(generated_sequences, tokenizer_info))
    print(f"Successfully parsed {len(sequences)} sequences")
    
    if sequences:
        print(f"\nFirst parsed sequence analysis:")
        # Each parsed item is indexed by the keys 'notes' and 'chords'.
        seq_data = sequences[0]
        print(f"  Number of notes: {len(seq_data['notes'])}")
        print(f"  Number of chords: {len(seq_data['chords'])}")
        
        if seq_data['notes']:
            print(f"  First 3 notes:")
            for i, note in enumerate(seq_data['notes'][:3]):
                print(f"    Note {i}: {note}")
        else:
            print(f"  No notes found!")
            
        if seq_data['chords']:
            print(f"  First 3 chords:")
            for i, chord in enumerate(seq_data['chords'][:3]):
                print(f"    Chord {i}: {chord}")
        else:
            print(f"  No chords found!")
            
        # Check if we have the expected structure
        if not seq_data['notes'] and not seq_data['chords']:
            print(f"  WARNING: No notes or chords found in parsed sequence!")
            
    else:
        print(f"  No sequences were successfully parsed!")
        
except Exception as e:
    # Deliberately broad: this is a diagnostic cell, so report the failure with a
    # traceback and let the rest of the cell run.
    print(f"Error parsing sequences: {e}")
    import traceback
    traceback.print_exc()

# Debug the harmony calculation specifically
print(f"\n=== DEBUGGING HARMONY CALCULATION ===")
try:
    # Test harmony metrics calculation step by step
    # (calculate_harmony_metrics is already imported in the imports cell; re-importing is harmless.)
    from src.evaluation.metrics import calculate_harmony_metrics
    
    # Create a small test with just a few sequences
    test_sequences = generated_sequences[:5]
    test_harmony_metrics = calculate_harmony_metrics(test_sequences, tokenizer_info)
    
    print(f"Test harmony metrics (first 5 sequences):")
    for key, value in test_harmony_metrics.items():
        print(f"  {key}: {value}")
        
    # Show what tokens are being considered as melody vs chord
    # NOTE(review): these ranges are derived from tokenizer_info fields here, not read
    # from the metrics code — confirm they match what the metrics actually use.
    print(f"\nToken classification:")
    print(f"  Melody token range: 0 to {tokenizer_info['melody_vocab_size']-1}")
    print(f"  PAD token: {tokenizer_info['pad_token_id']}")
    print(f"  Chord token range: {tokenizer_info['pad_token_id']+1} to {tokenizer_info['total_vocab_size']-1}")
    
    # Check what tokens are actually in our sequences
    if test_sequences:
        # Convert sequences to arrays if they're lists
        test_arrays = [np.array(seq) if isinstance(seq, list) else seq for seq in test_sequences]
        # Pool the tokens of all test sequences into one flat array for counting.
        all_tokens = np.concatenate(test_arrays)
        unique_tokens = np.unique(all_tokens)
        print(f"\nActual tokens in generated sequences:")
        print(f"  Unique tokens: {unique_tokens}")
        print(f"  Token counts:")
        # np.unique returns sorted values, so these are the 10 smallest token ids present.
        for token in unique_tokens[:10]:  # Show first 10
            count = np.sum(all_tokens == token)
            # Classify by id: below melody_vocab_size -> melody, equal to pad id -> PAD,
            # anything else -> chord.
            if token < tokenizer_info['melody_vocab_size']:
                token_type = "MELODY"
            elif token == tokenizer_info['pad_token_id']:
                token_type = "PAD"
            else:
                token_type = "CHORD"
            print(f"    Token {token} ({token_type}): {count} times")
    
except Exception as e:
    # Same best-effort handling as above: report and continue.
    print(f"Error in harmony calculation: {e}")
    import traceback
    traceback.print_exc()


In [None]:
# Debug: Compare with ground truth and identify the root cause
# Reads `generated_sequences`, `ground_truth_sequences`, `tokenizer_info` and `config`
# from the kernel state left by the offline evaluation cell.
print("=== DEBUGGING ROOT CAUSE ===")

# The issue is that offline model sequences are CHORD-ONLY
# but the metrics expect INTERLEAVED melody+chord sequences
print(f"\nROOT CAUSE ANALYSIS:")
print(f"1. Online model generates INTERLEAVED sequences (melody + chord)")
print(f"2. Offline model generates CHORD-ONLY sequences")
print(f"3. Metrics parsing expects INTERLEAVED sequences")

# Check if this is true by examining sequence structure
print(f"\nSequence structure comparison:")
print(f"Generated sequences (offline):")
if generated_sequences:
    seq = generated_sequences[0]
    seq_array = np.array(seq) if isinstance(seq, list) else seq
    print(f"  Length: {len(seq_array)}")
    print(f"  Token range: [{seq_array.min()}, {seq_array.max()}]")
    # NOTE(review): the label says "chord_start" but the comparison is against
    # pad_token_id, so PAD tokens also satisfy it — confirm the intended threshold.
    print(f"  All tokens >= chord_start? {np.all(seq_array >= tokenizer_info['pad_token_id'])}")

print(f"\nGround truth sequences:")
if ground_truth_sequences:
    seq = ground_truth_sequences[0]
    seq_array = np.array(seq) if isinstance(seq, list) else seq
    print(f"  Length: {len(seq_array)}")
    print(f"  Token range: [{seq_array.min()}, {seq_array.max()}]")
    # Same check (and same caveat about the label) as for the generated sequence above.
    print(f"  All tokens >= chord_start? {np.all(seq_array >= tokenizer_info['pad_token_id'])}")

# Show the key insight
print(f"\nKEY INSIGHT:")
print(f"The offline model evaluation is trying to calculate melody-in-chord ratio")
print(f"from CHORD-ONLY sequences, which is impossible!")
print(f"")
print(f"For offline evaluation, we need to:")
print(f"1. Get the melody tokens from the input batch")
print(f"2. Pair them with the generated chord tokens")
print(f"3. Create interleaved sequences for metrics calculation")

# Check what the dataloader provides
# In the cells visible in this notebook, `batch` is only assigned by the loop in the
# later "SOLUTION" cell, so on a fresh top-to-bottom run this prints 'No batch loaded'.
print(f"\nDataloader batch structure:")
print(f"Available keys: {list(batch.keys()) if 'batch' in globals() else 'No batch loaded'}")

# Manual test of creating interleaved sequences
print(f"\n=== TESTING MANUAL INTERLEAVED CREATION ===")
try:
    # Create a small test dataloader to get melody tokens
    # (create_dataloader is already imported in the imports cell; re-importing is harmless.)
    from src.data.dataset import create_dataloader
    
    # NOTE(review): sequence_length is hard-coded to 256 here, while the offline
    # evaluation cell uses `max_seq_length` from the run config — these may differ.
    test_dataloader, _ = create_dataloader(
        data_dir=Path(config['data_dir']),
        split=config['split'],
        batch_size=2,
        num_workers=0,
        sequence_length=256,
        mode='offline',
        shuffle=False
    )
    
    # Get a batch
    test_batch = next(iter(test_dataloader))
    melody_tokens = test_batch['melody_tokens']
    
    print(f"Melody tokens shape: {melody_tokens.shape}")
    print(f"Melody token range: [{melody_tokens.min()}, {melody_tokens.max()}]")
    
    # Create interleaved sequences manually
    if generated_sequences:
        print(f"\nCreating interleaved test sequence:")
        # Pairs the first melody of this fresh loader with the first generated sequence;
        # presumably they refer to the same example because both loaders use
        # shuffle=False — TODO confirm.
        test_melody = melody_tokens[0].numpy()  # First melody
        test_chords = generated_sequences[0]    # First generated chord sequence
        test_chords = np.array(test_chords) if isinstance(test_chords, list) else test_chords
        
        # Ensure same length
        min_len = min(len(test_melody), len(test_chords))
        test_melody = test_melody[:min_len]
        test_chords = test_chords[:min_len]
        
        # Create interleaved sequence
        # Layout: [chord_0, melody_0, chord_1, melody_1, ...]
        interleaved = np.empty(len(test_melody) * 2, dtype=np.int64)
        interleaved[1::2] = test_melody  # Odd indices: melody
        interleaved[0::2] = test_chords  # Even indices: chords
        
        print(f"Interleaved sequence created:")
        print(f"  Length: {len(interleaved)}")
        print(f"  First 10 tokens: {interleaved[:10]}")
        
        # Test metrics on this interleaved sequence
        test_harmony = calculate_harmony_metrics([interleaved], tokenizer_info)
        print(f"Test harmony metrics on interleaved sequence:")
        for key, value in test_harmony.items():
            print(f"  {key}: {value}")
            
except Exception as e:
    # Deliberately broad: diagnostic cell, so report the failure with a traceback
    # without stopping the notebook.
    print(f"Error in manual test: {e}")
    import traceback
    traceback.print_exc()

In [None]:
# SOLUTION: Fix the offline evaluation to create proper interleaved sequences
# Reads `generated_sequences`, `ground_truth_sequences`, `tokenizer_info`, `config` and
# `max_seq_length` from the kernel state left by the offline evaluation cell.
print("=== SOLUTION: FIXING OFFLINE EVALUATION ===")

print(f"The issue is clear: offline model generates chord-only sequences,")
print(f"but metrics expect interleaved melody+chord sequences.")
print(f"")
print(f"We need to modify the offline evaluation to:")
print(f"1. Collect melody tokens from the input batches")
print(f"2. Pair them with generated chord sequences")
print(f"3. Create interleaved sequences for metrics")


def interleave_chords_and_melody(chord_seq, melody_seq):
    """
    Interleave a chord sequence with a melody sequence.

    Args:
        chord_seq: Chord token ids (list or array).
        melody_seq: Melody token ids (list or array).

    Returns:
        An int64 array laid out as [chord_0, melody_0, chord_1, melody_1, ...],
        truncated to the length of the shorter input.
    """
    chord_arr = np.asarray(chord_seq)
    melody_arr = np.asarray(melody_seq)
    min_len = min(len(chord_arr), len(melody_arr))

    interleaved = np.empty(min_len * 2, dtype=np.int64)
    interleaved[0::2] = chord_arr[:min_len]   # Even indices: chords
    interleaved[1::2] = melody_arr[:min_len]  # Odd indices: melody
    return interleaved


# Let's create a fixed version of the evaluation manually
print(f"\n=== MANUAL FIX FOR TESTING ===")
try:
    # Re-run the offline generation but also collect melody tokens
    print("Re-running offline generation with melody collection...")

    # Get fresh dataloader (same settings as the offline evaluation, shuffle=False, so
    # melodies are presumably yielded in the same order as the generated sequences)
    dataloader_fixed, _ = create_dataloader(
        data_dir=Path(config['data_dir']),
        split=config['split'],
        batch_size=config['batch_size'],
        num_workers=config['num_workers'],
        sequence_length=max_seq_length,
        mode='offline',
        shuffle=False
    )

    # Collect one melody per existing sequence. Counting melodies instead of whole
    # batches keeps the final, partially filled batch; a floor-divided batch count would
    # drop it, and would collect nothing at all for fewer than batch_size sequences.
    num_needed = max(len(generated_sequences), len(ground_truth_sequences))
    melody_sequences = []
    for batch in dataloader_fixed:
        if len(melody_sequences) >= num_needed:
            break
        melody_tokens = batch['melody_tokens']
        # .cpu() is a no-op for CPU tensors and keeps .numpy() valid otherwise.
        melody_sequences.extend(melody.cpu().numpy() for melody in melody_tokens)
    # The last batch read may overshoot; keep only as many melodies as are needed.
    del melody_sequences[num_needed:]

    print(f"Collected {len(melody_sequences)} melody sequences")
    print(f"Generated {len(generated_sequences)} chord sequences")

    # Create interleaved sequences for metrics
    num_sequences = min(len(melody_sequences), len(generated_sequences))
    print(f"Creating {num_sequences} interleaved sequences...")

    # zip stops at the shorter of the two lists, i.e. after num_sequences pairs.
    interleaved_sequences = [
        interleave_chords_and_melody(chord_seq, melody_seq)
        for chord_seq, melody_seq in zip(generated_sequences, melody_sequences)
    ]

    print(f"Created {len(interleaved_sequences)} interleaved sequences")

    if not interleaved_sequences:
        raise ValueError("No interleaved sequences were created (no generated sequences or no melodies available)")

    # Now calculate metrics with proper interleaved sequences
    print(f"\nCalculating metrics with fixed interleaved sequences...")

    # Also create interleaved ground truth sequences
    ground_truth_interleaved = [
        interleave_chords_and_melody(gt_chord_seq, melody_seq)
        for gt_chord_seq, melody_seq in zip(ground_truth_sequences, melody_sequences)
    ]

    # Calculate fixed metrics
    fixed_harmony_metrics = calculate_harmony_metrics(interleaved_sequences, tokenizer_info)
    fixed_emd_metrics = calculate_emd_metrics(interleaved_sequences, ground_truth_interleaved, tokenizer_info)

    print(f"\n=== FIXED OFFLINE MODEL RESULTS ===")
    for metric, value in {**fixed_harmony_metrics, **fixed_emd_metrics}.items():
        print(f"{metric}: {value:.4f}")

except Exception as e:
    # Deliberately broad: this is an exploratory cell, so report the failure with a
    # traceback without stopping the notebook.
    print(f"Error in manual fix: {e}")
    import traceback
    traceback.print_exc()

print(f"\nTo permanently fix this, we need to modify the generate_offline function")
print(f"to return interleaved sequences instead of chord-only sequences.")
