In [1]:
import os
import sys
from pathlib import Path

from torch.utils.data import DataLoader

def find_project_root(start: Path) -> Path:
    """Return the nearest directory at or above ``start`` that contains ``src/musicagent``.

    Searching upward (instead of blindly going to the parent) makes this cell
    idempotent: re-running it from the project root finds the same root again.
    If no such directory exists, fall back to ``start.parent``, which matches
    the original ``%cd ..`` behaviour.
    """
    for candidate in (start, *start.parents):
        if (candidate / "src" / "musicagent").is_dir():
            return candidate
    return start.parent


# The previous `%cd ..` moved one level higher on every re-run of this cell,
# so a second run resolved paths against the wrong directory.
project_root = find_project_root(Path.cwd())
os.chdir(project_root)
src_path = project_root / "src"
if str(src_path) not in sys.path:
    # Prepend so the working-tree package wins over any installed copy.
    sys.path.insert(0, str(src_path))

from musicagent.config import DataConfig
from musicagent.dataset import MusicAgentDataset, collate_fn
from musicagent.eval.metrics import (
    chord_length_entropy,
    chord_lengths,
    note_in_chord_ratio,
    onset_interval_emd,
    onset_intervals,
)

# Anchor a relative processed-data path at project_root so loading does not
# depend on whatever the kernel's working directory happens to be.
d_cfg = DataConfig()
processed_dir = d_cfg.data_processed
if not processed_dir.is_absolute():
    d_cfg.data_processed = project_root / processed_dir

TEST_BATCH_SIZE = 32

test_ds = MusicAgentDataset(d_cfg, split="test")
test_loader = DataLoader(
    test_ds,
    batch_size=TEST_BATCH_SIZE,
    shuffle=False,  # fixed order so repeated metric runs see the same batches
    collate_fn=collate_fn,
)

# Invert the token -> id vocabularies into id -> token lookups for decoding.
melody_vocab = test_ds.vocab_melody
chord_vocab = test_ds.vocab_chord
id_to_melody = dict(zip(melody_vocab.values(), melody_vocab.keys()))
id_to_chord = dict(zip(chord_vocab.values(), chord_vocab.keys()))

def decode_tokens(ids, id_to_token):
    """Translate a sequence of token ids into their string tokens.

    Args:
        ids: Iterable of token ids. Each element is passed through ``int()``
            before lookup, so tensor scalars, floats and numeric strings work.
        id_to_token: Mapping from integer id to token string.

    Returns:
        A list of token strings in the same order as ``ids``. Ids missing
        from ``id_to_token`` decode to ``"<unk>"``.
    """
    decoded = []
    for token_id in ids:
        key = int(token_id)
        decoded.append(id_to_token.get(key, "<unk>"))
    return decoded

/Users/drewtaylor/data


In [2]:
# Reference metrics: score the test split's target chord tracks against their
# melodies. These are the values generated chords should be compared with.
# NOTE(review): rows come out of collate_fn and presumably include padding;
# confirm the metric functions ignore pad tokens, or strip them here.

all_nic_ref = []
all_ref_intervals = []
all_ref_lengths = []

for src, tgt in test_loader:
    # Iterating a batch tensor yields one row per example.
    for mel_row, ref_row in zip(src, tgt):
        mel_tokens = decode_tokens(mel_row.tolist(), id_to_melody)
        ref_tokens = decode_tokens(ref_row.tolist(), id_to_chord)

        # Note-in-chord ratio is collected per sequence and averaged below.
        all_nic_ref.append(note_in_chord_ratio(mel_tokens, ref_tokens))
        # Onset intervals and chord lengths are pooled across the whole split.
        all_ref_intervals.extend(onset_intervals(mel_tokens, ref_tokens))
        all_ref_lengths.extend(chord_lengths(ref_tokens))

# NaN (not 0.0) on an empty split, so "no data" cannot masquerade as 0% NiC.
avg_nic_ref = sum(all_nic_ref) / len(all_nic_ref) if all_nic_ref else float("nan")

# EMD of a distribution against itself is 0 by definition, so this is only a
# sanity check of onset_interval_emd, not a measurement. The informative
# number is EMD(generated intervals, reference intervals).
ref_emd_self = onset_interval_emd(all_ref_intervals, all_ref_intervals)
assert not all_ref_intervals or abs(ref_emd_self) < 1e-9, (
    f"self-EMD should be 0, got {ref_emd_self}"
)

ref_entropy = chord_length_entropy(all_ref_lengths)

print(f"Reference NiC:                        {avg_nic_ref * 100:.2f}%")
print(f"Reference Onset Interval EMD (self):  {ref_emd_self:.4f}")
print(f"Reference Chord Length Entropy:       {ref_entropy:.4f}")



Reference NiC:            65.31%
Reference Onset Interval EMD: 0.0
Reference Chord Length Entropy: 2.296499290802694
