In [None]:
# Fetch the project sources and switch the kernel's working directory into the checkout.
# NOTE(review): this cell does not look safe to re-run in the same kernel -- %cd
# persists, so a second run would clone into a nested Martydepth/Martydepth
# directory. Restart the kernel (or skip this cell) before re-running.
!git clone https://github.com/imabeastdrew/Martydepth.git
%cd Martydepth

# Install the package in development (editable) mode from the checkout root
%pip install -e .


In [None]:
# Install dependencies
# NOTE(review): versions are unpinned, so results may drift between runs as new
# releases land; consider pinning (e.g. `pkg==x.y.z`) for a reproducible evaluation.
%pip install torch wandb tqdm pyyaml

In [None]:
# Import required libraries
import torch
import wandb
import json
import numpy as np
import tempfile
from pathlib import Path
from tqdm.notebook import tqdm
import os
import sys

# Add project root to Python path
sys.path.append('.')

# Import project modules
from src.data.dataset import create_dataloader
from src.models.online_transformer import OnlineTransformer
from src.models.offline_teacher import OfflineTeacherModel
from src.evaluation.metrics import (
    calculate_harmony_metrics,
    calculate_emd_metrics,
)
from src.config.tokenization_config import (
    SILENCE_TOKEN,
    MELODY_ONSET_HOLD_START,
    CHORD_TOKEN_START,
    PAD_TOKEN,
)
from src.evaluation.evaluate_offline import generate_offline
from src.evaluation.evaluate import generate_online

def load_model_artifact(artifact_path: str) -> tuple[dict, dict]:
    """
    Download a model checkpoint artifact from wandb and unpack it.

    The artifact is downloaded into a temporary directory, the first ``*.pth``
    file (in sorted name order) is loaded onto the CPU, and the temporary
    directory is removed again once the checkpoint is in memory.

    Args:
        artifact_path: Full path to the artifact (e.g. 'marty1ai/martydepth/model_name:version')

    Returns:
        tuple[dict, dict]: The checkpoint's 'model_state_dict' and 'config' entries.

    Raises:
        FileNotFoundError: If the downloaded artifact contains no ``*.pth`` file.
        KeyError: If the checkpoint lacks 'model_state_dict' or 'config'.
    """
    api = wandb.Api()
    artifact = api.artifact(artifact_path)

    # The checkpoint only needs to live on disk until it has been deserialized.
    with tempfile.TemporaryDirectory() as tmp_dir:
        artifact_dir = artifact.download(tmp_dir)

        # Sort so the choice is deterministic if the artifact ships several checkpoints.
        checkpoint_files = sorted(Path(artifact_dir).glob("*.pth"))
        if not checkpoint_files:
            # A bare StopIteration from next() would give no hint about what went wrong.
            raise FileNotFoundError(
                f"No '*.pth' checkpoint found in artifact '{artifact_path}'"
            )

        # NOTE: torch.load unpickles the file; only load artifacts from trusted projects.
        checkpoint = torch.load(checkpoint_files[0], map_location='cpu')

    return checkpoint['model_state_dict'], checkpoint['config']

In [None]:
# Configuration: everything a reader might tune for this evaluation run
config = {
    'data_dir': 'data/interim',
    'split': 'test',
    'batch_size': 32,
    'num_workers': 4,
    'temperature': 1.0,
    'top_k': 50,
    # Seed for the RNGs so sampled generations are repeatable across runs
    'seed': 42,
    # Model artifact paths
    'online_model_artifact': 'marty1ai/martydepth/online_transformer_model_np0myk09:v7',
    'offline_model_artifact': 'marty1ai/martydepth/offline_teacher_model_rv8se0ur:v9'
}

# Seed before any sampling happens; generation below uses temperature/top-k,
# so without a seed the reported metrics change from run to run.
torch.manual_seed(config['seed'])
np.random.seed(config['seed'])

# Set device: prefer CUDA, then Apple MPS, then fall back to CPU
device = torch.device("cuda" if torch.cuda.is_available() else
                     "mps" if torch.backends.mps.is_available() else
                     "cpu")
print(f"Using device: {device}")

# Initialize wandb (the config, including the seed, is recorded with the run)
wandb.init(
    project="martydepth",
    name="model_evaluation",
    config=config,
    job_type="evaluation"
)

# Load model artifacts
print("\nLoading model artifacts...")

# Online model
online_state_dict, online_config = load_model_artifact(config['online_model_artifact'])
print(f"Loaded online model from {config['online_model_artifact']}")

# Offline model
offline_state_dict, offline_config = load_model_artifact(config['offline_model_artifact'])
print(f"Loaded offline model from {config['offline_model_artifact']}")

# Load tokenizer info from data directory.
# It is read from the "train" split directory regardless of config['split'].
tokenizer_info_path = Path(config['data_dir']) / "train" / "tokenizer_info.json"
# Explicit encoding so the read does not depend on the platform default.
with open(tokenizer_info_path, 'r', encoding='utf-8') as f:
    tokenizer_info = json.load(f)
print("Loaded tokenizer info from data directory")

In [None]:
# Evaluate Online Model
print("\n=== Evaluating Online Model ===")

# Resolve the sequence length once so the model and the dataloader always agree.
# Checkpoint configs may store it under either key; fall back to 512 if neither is set.
max_seq_length = online_config.get('max_seq_length') or online_config.get('max_sequence_length') or 512

# Initialize model with the hyperparameters stored alongside the checkpoint
model = OnlineTransformer(
    vocab_size=tokenizer_info['total_vocab_size'],
    embed_dim=online_config['embed_dim'],
    num_heads=online_config['num_heads'],
    num_layers=online_config['num_layers'],
    dropout=online_config['dropout'],
    max_seq_length=max_seq_length,
    pad_token_id=PAD_TOKEN
).to(device)

# Load state dict and switch to eval mode (disables dropout)
model.load_state_dict(online_state_dict)
model.eval()

# Create dataloader (unshuffled so the evaluation order is stable)
dataloader, _ = create_dataloader(
    data_dir=Path(config['data_dir']),
    split=config['split'],
    batch_size=config['batch_size'],
    num_workers=config['num_workers'],
    sequence_length=max_seq_length,
    mode='online',
    shuffle=False
)

# Generate sequences
print("\nGenerating sequences...")
generated_sequences, ground_truth_sequences = generate_online(
    model=model,
    dataloader=dataloader,
    tokenizer_info=tokenizer_info,
    device=device,
    temperature=config['temperature'],
    top_k=config['top_k']
)

# Calculate metrics
print("\nCalculating metrics...")
harmony_metrics = calculate_harmony_metrics(generated_sequences, tokenizer_info)
emd_metrics = calculate_emd_metrics(generated_sequences, ground_truth_sequences, tokenizer_info)

# Merge both metric groups into one flat mapping (EMD keys win on a name clash)
online_metrics = {**harmony_metrics, **emd_metrics}

print("\n=== Online Model Results ===")
for metric, value in online_metrics.items():
    print(f"{metric}: {value:.4f}")

# A single log call keeps every online metric on the same wandb step;
# logging inside the loop would advance the step once per metric.
wandb.log({f"online/{metric}": value for metric, value in online_metrics.items()})

# Log the artifact version we evaluated
wandb.log({"online_model_artifact": config['online_model_artifact']})

In [None]:
# Evaluate Offline Model
print("\n=== Evaluating Offline Model ===")

# Resolve the sequence length once so the model and the dataloader always agree.
# Checkpoint configs may store it under either key; fall back to 512 if neither is set.
max_seq_length = offline_config.get('max_seq_length') or offline_config.get('max_sequence_length') or 512

# Initialize model with the hyperparameters stored alongside the checkpoint
model = OfflineTeacherModel(
    vocab_size=tokenizer_info['total_vocab_size'],
    embed_dim=offline_config['embed_dim'],
    num_heads=offline_config['num_heads'],
    num_layers=offline_config['num_layers'],
    dropout=offline_config['dropout'],
    max_seq_length=max_seq_length,
    pad_token_id=PAD_TOKEN
).to(device)

# Load state dict and switch to eval mode (disables dropout)
model.load_state_dict(offline_state_dict)
model.eval()

# Create dataloader (unshuffled so the evaluation order is stable)
dataloader, _ = create_dataloader(
    data_dir=Path(config['data_dir']),
    split=config['split'],
    batch_size=config['batch_size'],
    num_workers=config['num_workers'],
    sequence_length=max_seq_length,
    mode='offline',
    shuffle=False
)

# Generate sequences
print("\nGenerating sequences...")
generated_sequences, ground_truth_sequences = generate_offline(
    model=model,
    dataloader=dataloader,
    tokenizer_info=tokenizer_info,
    device=device,
    temperature=config['temperature'],
    top_k=config['top_k']
)

# Calculate metrics
print("\nCalculating metrics...")
harmony_metrics = calculate_harmony_metrics(generated_sequences, tokenizer_info)
emd_metrics = calculate_emd_metrics(generated_sequences, ground_truth_sequences, tokenizer_info)

# Merge both metric groups into one flat mapping (EMD keys win on a name clash)
offline_metrics = {**harmony_metrics, **emd_metrics}

print("\n=== Offline Model Results ===")
for metric, value in offline_metrics.items():
    print(f"{metric}: {value:.4f}")

# A single log call keeps every offline metric on the same wandb step;
# logging inside the loop would advance the step once per metric.
wandb.log({f"offline/{metric}": value for metric, value in offline_metrics.items()})

# Log the artifact version we evaluated
wandb.log({"offline_model_artifact": config['offline_model_artifact']})

In [None]:
# Compare Results
print("\n=== Model Comparison ===")
# Header columns are right-aligned to line up with the numeric columns below.
print(f"{'Metric':<30} {'Online':>10} {'Offline':>10} {'Difference':>10}")
# Rule width matches a table row: 30 + 3 * (1 space + 10 chars) = 63.
print("-" * 63)

# Collect every comparison value first so they are logged on one wandb step.
comparison_log = {}
for metric in online_metrics.keys():
    online_value = online_metrics[metric]
    offline_value = offline_metrics[metric]
    diff = online_value - offline_value
    print(f"{metric:<30} {online_value:>10.4f} {offline_value:>10.4f} {diff:>10.4f}")

    comparison_log[f"comparison/{metric}_diff"] = diff
    # The ratio is undefined when the offline value is zero; record NaN rather
    # than 0, which would be indistinguishable from a genuine zero ratio.
    comparison_log[f"comparison/{metric}_ratio"] = (
        online_value / offline_value if offline_value != 0 else float('nan')
    )

if comparison_log:
    wandb.log(comparison_log)

# Finish wandb run
wandb.finish()