In [None]:
!git clone https://github.com/imabeastdrew/Martydepth.git
%cd Martydepth

# Install the package in development mode
%pip install -e .


In [None]:
# Install dependencies
%pip install torch wandb tqdm pyyaml

In [None]:
# Import required libraries (stdlib, then third-party)
import json
import os
import sys
import tempfile
from pathlib import Path

import numpy as np
import torch
import wandb
import yaml
from tqdm.notebook import tqdm

# Add project root to Python path so the `src.*` imports below resolve from the
# current working directory (expected to be the repository root).
sys.path.append('.')

# Sanity check: this should print the repository root.
print(os.getcwd())

# Import project modules (all tokenization constants in a single import)
from src.data.dataset import create_dataloader
from src.models.online_transformer import OnlineTransformer
from src.models.offline_teacher import OfflineTeacherModel
from src.evaluation.metrics import (
    calculate_harmony_metrics,
    calculate_emd_metrics,
)
from src.config.tokenization_config import (
    PAD_TOKEN,
    SILENCE_TOKEN,
    MELODY_ONSET_HOLD_START,
    CHORD_TOKEN_START,
)
from src.evaluation.evaluate_offline import generate_offline
from src.evaluation.evaluate import generate_online

In [None]:
# Configuration shared by both evaluations (also recorded in the wandb run).
config = {
    'data_dir': 'data/interim',
    'split': 'test',
    'batch_size': 32,
    'num_workers': 4,
    'temperature': 1.0,
    'top_k': 50,
    # Generation samples with temperature / top-k, so fix the RNG seed to make
    # the reported metrics reproducible across runs.
    'seed': 42,
}

# Trained checkpoints are read from Google Drive, which must be mounted at
# /content/drive before running this cell.
online_checkpoint_path = "/content/drive/MyDrive/Martydepth/online_transformer_epoch_11.pth"
offline_checkpoint_path = "/content/drive/MyDrive/Martydepth/offline_teacher_epoch_12.pth"

# Model configs ship with the repository; resolve them relative to the repo
# root (the working directory), consistent with `data_dir` above, instead of a
# hard-coded /content prefix.
online_config_path = "src/training/configs/online_transformer_base.yaml"
offline_config_path = "src/training/configs/offline_teacher_base.yaml"

# Fail fast (before a wandb run is opened) if any input is missing, e.g.
# because Drive is not mounted or the cwd is not the repository root.
_missing_inputs = [
    p for p in (online_checkpoint_path, online_config_path,
                offline_checkpoint_path, offline_config_path)
    if not os.path.exists(p)
]
if _missing_inputs:
    raise FileNotFoundError(
        "Missing evaluation inputs (is Google Drive mounted and is the working "
        "directory the repository root?):\n  " + "\n  ".join(_missing_inputs)
    )

# Load tokenizer info (written alongside the training split)
tokenizer_info_path = Path(config['data_dir']) / "train" / "tokenizer_info.json"
with open(tokenizer_info_path, 'r') as f:
    tokenizer_info = json.load(f)

# Seed the RNGs used during sampling; torch.manual_seed seeds all devices.
np.random.seed(config['seed'])
torch.manual_seed(config['seed'])

# Set device: prefer CUDA, then Apple MPS, then CPU
device = torch.device("cuda" if torch.cuda.is_available() else
                      "mps" if torch.backends.mps.is_available() else
                      "cpu")
print(f"Using device: {device}")

# Initialize wandb
wandb.init(
    project="martydepth",
    name="model_evaluation",
    config=config,
    job_type="evaluation"
)

In [None]:
# Evaluate Online Model
print("\n=== Evaluating Online Model ===")
print(f"Loading model from checkpoint: {online_checkpoint_path}")

# Load model config
with open(online_config_path, 'r') as f:
    online_config = yaml.safe_load(f)

# Resolve the sequence length once and share it between the model and the
# dataloader so the two can never disagree. Configs may spell the key either
# 'max_seq_length' or 'max_sequence_length'; fall back to 512.
max_seq_length = (
    online_config.get('max_seq_length')
    or online_config.get('max_sequence_length')
    or 512
)

# Initialize model
model = OnlineTransformer(
    vocab_size=tokenizer_info['total_vocab_size'],
    embed_dim=online_config['embed_dim'],
    num_heads=online_config['num_heads'],
    num_layers=online_config['num_layers'],
    dropout=online_config['dropout'],
    max_seq_length=max_seq_length,
    pad_token_id=PAD_TOKEN,
).to(device)

# Load checkpoint: a dict whose 'model_state_dict' entry holds the weights.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
checkpoint = torch.load(online_checkpoint_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

# Create dataloader (unshuffled so runs are comparable)
dataloader, _ = create_dataloader(
    data_dir=Path(config['data_dir']),
    split=config['split'],
    batch_size=config['batch_size'],
    num_workers=config['num_workers'],
    sequence_length=max_seq_length,
    mode='online',
    shuffle=False,
)

# Generate sequences; evaluation needs no autograd graph.
print("\nGenerating sequences...")
with torch.no_grad():
    generated_sequences, ground_truth_sequences = generate_online(
        model=model,
        dataloader=dataloader,
        tokenizer_info=tokenizer_info,
        device=device,
        temperature=config['temperature'],
        top_k=config['top_k'],
    )

# Calculate metrics
print("\nCalculating metrics...")
harmony_metrics = calculate_harmony_metrics(generated_sequences, tokenizer_info)
emd_metrics = calculate_emd_metrics(generated_sequences, ground_truth_sequences, tokenizer_info)

online_metrics = {**harmony_metrics, **emd_metrics}

print("\n=== Online Model Results ===")
for metric, value in online_metrics.items():
    print(f"{metric}: {value:.4f}")
# Log every metric in one call so they share a single wandb step instead of
# advancing the step counter once per metric.
wandb.log({f"online/{metric}": value for metric, value in online_metrics.items()})

In [None]:
# Evaluate Offline Model
print("\n=== Evaluating Offline Model ===")
print(f"Loading model from checkpoint: {offline_checkpoint_path}")

# Load model config
with open(offline_config_path, 'r') as f:
    offline_config = yaml.safe_load(f)

# Resolve the sequence length once and share it between the model and the
# dataloader so the two can never disagree. Configs may spell the key either
# 'max_seq_length' or 'max_sequence_length'; fall back to 512.
max_seq_length = (
    offline_config.get('max_seq_length')
    or offline_config.get('max_sequence_length')
    or 512
)

# Initialize model
model = OfflineTeacherModel(
    melody_vocab_size=tokenizer_info['melody_vocab_size'],
    chord_vocab_size=tokenizer_info['chord_vocab_size'],
    embed_dim=offline_config['embed_dim'],
    num_heads=offline_config['num_heads'],
    num_layers=offline_config['num_layers'],
    max_seq_length=max_seq_length,
    pad_token_id=PAD_TOKEN,
).to(device)

# Load checkpoint: a dict whose 'model_state_dict' entry holds the weights.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
checkpoint = torch.load(offline_checkpoint_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

# Create dataloader (unshuffled so runs are comparable)
dataloader, _ = create_dataloader(
    data_dir=Path(config['data_dir']),
    split=config['split'],
    batch_size=config['batch_size'],
    num_workers=config['num_workers'],
    sequence_length=max_seq_length,
    mode='offline',
    shuffle=False,
)

# Generate sequences; evaluation needs no autograd graph.
print("\nGenerating sequences...")
with torch.no_grad():
    generated_sequences, ground_truth_sequences = generate_offline(
        model=model,
        dataloader=dataloader,
        tokenizer_info=tokenizer_info,
        device=device,
    )

# Calculate metrics
print("\nCalculating metrics...")
harmony_metrics = calculate_harmony_metrics(generated_sequences, tokenizer_info)
emd_metrics = calculate_emd_metrics(generated_sequences, ground_truth_sequences, tokenizer_info)

offline_metrics = {**harmony_metrics, **emd_metrics}

print("\n=== Offline Model Results ===")
for metric, value in offline_metrics.items():
    print(f"{metric}: {value:.4f}")
# Log every metric in one call so they share a single wandb step instead of
# advancing the step counter once per metric.
wandb.log({f"offline/{metric}": value for metric, value in offline_metrics.items()})

In [None]:
# Compare Results: online vs. offline, per metric (difference = online - offline)
print("\n=== Model Comparison ===")
print(f"{'Metric':<30} {'Online':>10} {'Offline':>10} {'Difference':>10}")
print("-" * 63)  # 30 + 3 * (1 + 10) columns

comparison_log = {}
for metric, online_value in online_metrics.items():
    # Only compare metrics both evaluations produced; indexing a missing key
    # would abort the cell with a KeyError before anything is logged.
    if metric not in offline_metrics:
        print(f"{metric:<30} {online_value:>10.4f} {'n/a':>10} {'n/a':>10}")
        continue
    offline_value = offline_metrics[metric]
    diff = online_value - offline_value
    print(f"{metric:<30} {online_value:>10.4f} {offline_value:>10.4f} {diff:>10.4f}")
    comparison_log[f"comparison/{metric}_diff"] = diff
    # A ratio against a zero baseline is undefined; omit it rather than
    # logging a misleading 0.
    if offline_value != 0:
        comparison_log[f"comparison/{metric}_ratio"] = online_value / offline_value

# Log the whole comparison as a single wandb step.
wandb.log(comparison_log)

# Finish wandb run
wandb.finish()