In [None]:
import sys
from pathlib import Path

import torch
import wandb
from torch.utils.data import DataLoader

# Make the ``src`` directory under the current working directory importable
# so the ``musicagent`` imports below resolve; the entry is appended only
# once, so re-running this cell does not grow ``sys.path``.
project_root = Path.cwd()
src_path = project_root.joinpath("src")
if sys.path.count(str(src_path)) == 0:
    sys.path.append(str(src_path))

from musicagent.config import DataConfig, ModelConfig
from musicagent.dataset import MusicAgentDataset, collate_fn
from musicagent.eval.metrics import (
    chord_length_entropy,
    chord_lengths,
    note_in_chord_ratio,
    onset_interval_emd,
    onset_intervals,
)
from musicagent.model import OfflineTransformer

In [None]:
%cd /content/models

# Authenticate with Weights & Biases, then fetch the model artifact that
# holds the checkpoint to evaluate.
wandb.login()
ARTIFACT_REF = "your-entity/your-project/best-model:latest"

# Local directory (relative to the current working directory) that receives
# the downloaded artifact files.
CHECKPOINT_DIR = Path("checkpoints")
CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)

api = wandb.Api()
artifact = api.artifact(ARTIFACT_REF, type="model")
artifact_dir = Path(artifact.download(root=str(CHECKPOINT_DIR)))

# Sort the matches so the selected checkpoint does not depend on the
# filesystem's traversal order when more than one ``.pt`` file is present.
pt_files = sorted(artifact_dir.rglob("*.pt"))
if not pt_files:
    # Fail with an actionable message instead of a bare IndexError below.
    raise FileNotFoundError(
        f"No .pt checkpoint found under {artifact_dir} "
        f"(downloaded from artifact {ARTIFACT_REF!r})"
    )
CHECKPOINT_PATH = pt_files[0]
print(CHECKPOINT_PATH)

In [None]:
# Evaluation settings
BATCH_SIZE = 32
DEVICE = "cpu" if not torch.cuda.is_available() else "cuda"
# SAMPLE and TEMPERATURE are handed to ``model.generate`` by the evaluation
# loop (presumably SAMPLE=False means deterministic decoding -- confirm
# against OfflineTransformer.generate).
SAMPLE = False
TEMPERATURE = 1.0

# Default data/model configs; only the model's device is overridden here.
d_cfg = DataConfig()
m_cfg = ModelConfig()
m_cfg.device = DEVICE
device = torch.device(m_cfg.device)
print(device)

# Held-out test split, iterated in dataset order (no shuffling).
test_ds = MusicAgentDataset(d_cfg, split="test")
test_loader = DataLoader(
    dataset=test_ds,
    collate_fn=collate_fn,
    shuffle=False,
    batch_size=BATCH_SIZE,
)

# Inverted vocabularies (value -> key) so integer ids can be turned back
# into their melody / chord tokens.
id_to_melody = dict((idx, token) for token, idx in test_ds.vocab_melody.items())
id_to_chord = dict((idx, token) for token, idx in test_ds.vocab_chord.items())

def decode_tokens(ids, id_to_token):
    """Translate a sequence of ids into their tokens.

    Args:
        ids: Iterable of values convertible with ``int()``.
        id_to_token: Mapping from integer id to token.

    Returns:
        A list with one entry per id, in order; ids absent from
        ``id_to_token`` are rendered as ``"<unk>"``.
    """
    tokens = []
    for raw_id in ids:
        # Coerce first so non-int id types still hit the mapping's int keys.
        tokens.append(id_to_token.get(int(raw_id), "<unk>"))
    return tokens

# Model: source/target vocabulary sizes come from the test dataset's melody
# and chord vocabularies respectively.
vocab_src, vocab_tgt = len(test_ds.vocab_melody), len(test_ds.vocab_chord)

model = OfflineTransformer(m_cfg, d_cfg, vocab_src, vocab_tgt)
model = model.to(device)

# The checkpoint is loaded with ``weights_only=True`` and mapped straight
# onto the evaluation device, then applied as the model's state dict.
state = torch.load(CHECKPOINT_PATH, weights_only=True, map_location=device)
model.load_state_dict(state)
model.eval()

In [None]:
# Accumulators: one note-in-chord ratio per test sample, plus onset intervals
# and chord lengths pooled over the whole test set for predictions and
# references.
all_nic = []
all_pred_intervals, all_ref_intervals = [], []
all_pred_lengths, all_ref_lengths = [], []

with torch.no_grad():
    for src, tgt in test_loader:
        # Convert the melody and reference batches to plain Python lists once
        # per batch. The reference batch is only ever decoded, so it is never
        # sent to the device, and the melody rows are read before the device
        # transfer instead of being copied back one row at a time.
        mel_rows = src.tolist()
        ref_rows = tgt.tolist()

        pred = model.generate(
            src.to(device),
            max_len=d_cfg.max_len,
            sos_id=d_cfg.sos_id,
            eos_id=d_cfg.eos_id,
            temperature=TEMPERATURE,
            sample=SAMPLE,
        )

        for i, mel_ids in enumerate(mel_rows):
            # ``pred`` is indexed per row because its container type is
            # whatever ``model.generate`` returns.
            pred_ids = pred[i].cpu().tolist()
            ref_ids = ref_rows[i]

            mel_tokens = decode_tokens(mel_ids, id_to_melody)
            pred_tokens = decode_tokens(pred_ids, id_to_chord)
            ref_tokens = decode_tokens(ref_ids, id_to_chord)

            # NiC ratio
            nic = note_in_chord_ratio(mel_tokens, pred_tokens)
            all_nic.append(nic)

            # Onset intervals
            all_pred_intervals.extend(onset_intervals(mel_tokens, pred_tokens))
            all_ref_intervals.extend(onset_intervals(mel_tokens, ref_tokens))

            # Chord lengths
            all_pred_lengths.extend(chord_lengths(pred_tokens))
            all_ref_lengths.extend(chord_lengths(ref_tokens))

# Aggregate: mean NiC over samples (0.0 when there were none), EMD between
# the pooled predicted and reference onset intervals, and chord-length
# entropy for each side.
avg_nic = sum(all_nic) / len(all_nic) if all_nic else 0.0
emd = onset_interval_emd(all_pred_intervals, all_ref_intervals)
pred_entropy = chord_length_entropy(all_pred_lengths)
ref_entropy = chord_length_entropy(all_ref_lengths)

print(f"NiC Ratio:         {avg_nic * 100:.2f}%")
print(f"Onset Interval EMD:          {emd * 1e3:.2f} x 10^-3")
print(f"Chord Length Entropy (pred): {pred_entropy:.2f}")
print(f"Chord Length Entropy (ref):  {ref_entropy:.2f}")