## Baseline models

This notebook defines simple model-free baselines for the ReaLchords dataset and evaluates them
using the same offline metrics as the trained model.

Baselines:
- Global majority-chord predictor.
- Last-chord persistence predictor.



In [1]:
import sys
from collections import Counter
from pathlib import Path

import numpy as np
from torch.utils.data import DataLoader

# Locate project root (directory containing pyproject.toml or src/musicagent)
project_root = Path.cwd()
while project_root != project_root.parent:
    if (project_root / "pyproject.toml").exists() or (project_root / "src" / "musicagent").exists():
        break
    project_root = project_root.parent

print("Project root:", project_root)

src_path = project_root / "src"
if str(src_path) not in sys.path:
    sys.path.append(str(src_path))

from musicagent.config import DataConfig
from musicagent.dataset import MusicAgentDataset, collate_fn
from musicagent.eval.metrics import (
    chord_length_entropy,
    chord_lengths,
    note_in_chord_ratio,
    onset_interval_emd,
    onset_intervals,
)

# Make data_processed absolute so it works regardless of notebook CWD
d_cfg = DataConfig()
if not d_cfg.data_processed.is_absolute():
    d_cfg.data_processed = project_root / d_cfg.data_processed

print("Processed data dir:", d_cfg.data_processed)

test_ds = MusicAgentDataset(d_cfg, split="test")
print("Test examples:", len(test_ds))

test_loader = DataLoader(test_ds, batch_size=32, shuffle=False, collate_fn=collate_fn)

id_to_melody = {v: k for k, v in test_ds.vocab_melody.items()}
id_to_chord = {v: k for k, v in test_ds.vocab_chord.items()}

def decode_tokens(ids, id_to_token):
    return [id_to_token.get(int(i), "<unk>") for i in ids]



Project root: /Users/drewtaylor/data
Processed data dir: /Users/drewtaylor/data/realchords_data
Test examples: 2760


In [2]:
# Compute reference metrics (ground-truth chords) for comparison

all_nic_ref = []
all_ref_intervals = []
all_ref_lengths = []

for src, tgt in test_loader:
    for i in range(src.size(0)):
        mel_ids = src[i].cpu().tolist()
        ref_ids = tgt[i].cpu().tolist()

        mel_tokens = decode_tokens(mel_ids, id_to_melody)
        ref_tokens = decode_tokens(ref_ids, id_to_chord)

        all_nic_ref.append(note_in_chord_ratio(mel_tokens, ref_tokens))
        all_ref_intervals.extend(onset_intervals(mel_tokens, ref_tokens))
        all_ref_lengths.extend(chord_lengths(ref_tokens))

avg_nic_ref = sum(all_nic_ref) / len(all_nic_ref) if all_nic_ref else 0.0
ref_entropy = chord_length_entropy(all_ref_lengths)

print(f"Reference NiC Ratio:         {avg_nic_ref * 100:.2f}%")
print("Reference Chord Length Entropy:", ref_entropy)



Reference NiC Ratio:         65.31%
Reference Chord Length Entropy: 2.296499290802694


In [3]:
# Global majority-chord baseline (excluding special tokens)

special_ids = {d_cfg.pad_id, d_cfg.sos_id, d_cfg.eos_id, d_cfg.rest_id}

all_chords = []
for idx in range(len(test_ds)):
    tgt_arr = np.array(test_ds.tgt_data[idx])
    # Only count "real" chord IDs when choosing the majority chord.
    all_chords.extend(int(tid) for tid in tgt_arr.tolist() if int(tid) not in special_ids)

chord_counts = Counter(all_chords)
most_common_id, most_common_count = chord_counts.most_common(1)[0]
most_common_token = id_to_chord.get(int(most_common_id), "<unk>")

print("Most common non-special chord token:", most_common_token, "(count =", most_common_count, ")")


def predict_majority_like_ref(ref_ids, majority_id, eos_id):
    """Predict the majority chord wherever ref is not EOS; preserve EOS positions.

    This baseline deliberately predicts a single, common *real* chord token.
    """
    result = []
    for tid in ref_ids:
        tid = int(tid)
        if tid == int(eos_id):
            result.append(tid)
        else:
            result.append(int(majority_id))
    return result



Most common non-special chord token: C:4-3/0_hold (count = 25665 )


In [4]:
# Evaluate majority-chord baseline

all_nic_major = []
all_pred_intervals_major = []
all_pred_lengths_major = []

for src, tgt in test_loader:
    for i in range(src.size(0)):
        mel_ids = src[i].cpu().tolist()
        ref_ids = tgt[i].cpu().tolist()

        pred_ids = predict_majority_like_ref(ref_ids, most_common_id, d_cfg.eos_id)

        mel_tokens = decode_tokens(mel_ids, id_to_melody)
        pred_tokens = decode_tokens(pred_ids, id_to_chord)

        all_nic_major.append(note_in_chord_ratio(mel_tokens, pred_tokens))
        all_pred_intervals_major.extend(onset_intervals(mel_tokens, pred_tokens))
        all_pred_lengths_major.extend(chord_lengths(pred_tokens))

avg_nic_major = sum(all_nic_major) / len(all_nic_major) if all_nic_major else 0.0
emd_major = onset_interval_emd(all_pred_intervals_major, all_ref_intervals)
entropy_major = chord_length_entropy(all_pred_lengths_major)

print(f"Majority baseline NiC:         {avg_nic_major * 100:.2f}%")
print("Majority baseline Onset EMD:   ", emd_major)
print("Majority baseline Chord Ent.:  ", entropy_major)



Majority baseline NiC:         27.46%
Majority baseline Onset EMD:    16.318362787500355
Majority baseline Chord Ent.:   0.0


In [5]:
# Last-chord persistence baseline (ignoring special tokens for persistence)


def predict_last_chord_persistence(ref_ids, eos_id, special_ids):
    """Repeat the last seen non-special, non-EOS chord ID; keep EOS/specials where they appear.

    If no chord has been seen yet, use the current token as-is.
    """
    result = []
    last_chord = None
    for tid in ref_ids:
        tid = int(tid)
        if tid == int(eos_id):
            # Preserve EOS markers exactly.
            result.append(tid)
            continue
        if tid in special_ids:
            # Keep special tokens (pad/rest/sos) but do not update last_chord.
            result.append(tid)
            continue
        # Real chord token
        if last_chord is None:
            last_chord = tid
            result.append(tid)
        else:
            result.append(last_chord)
    return result


all_nic_last = []
all_pred_intervals_last = []
all_pred_lengths_last = []

for src, tgt in test_loader:
    for i in range(src.size(0)):
        mel_ids = src[i].cpu().tolist()
        ref_ids = tgt[i].cpu().tolist()

        pred_ids = predict_last_chord_persistence(ref_ids, d_cfg.eos_id, special_ids)

        mel_tokens = decode_tokens(mel_ids, id_to_melody)
        pred_tokens = decode_tokens(pred_ids, id_to_chord)

        all_nic_last.append(note_in_chord_ratio(mel_tokens, pred_tokens))
        all_pred_intervals_last.extend(onset_intervals(mel_tokens, pred_tokens))
        all_pred_lengths_last.extend(chord_lengths(pred_tokens))

avg_nic_last = sum(all_nic_last) / len(all_nic_last) if all_nic_last else 0.0
emd_last = onset_interval_emd(all_pred_intervals_last, all_ref_intervals)
entropy_last = chord_length_entropy(all_pred_lengths_last)

print(f"Last-chord baseline NiC:         {avg_nic_last * 100:.2f}%")
print("Last-chord baseline Onset EMD:   ", emd_last)
print("Last-chord baseline Chord Ent.:  ", entropy_last)



Last-chord baseline NiC:         50.67%
Last-chord baseline Onset EMD:    1.6636751333136883
Last-chord baseline Chord Ent.:   0.0


## Next steps

- Compare these baselines directly against your trained model metrics from `offline.ipynb`.
- Add additional baselines (e.g., simple rule-based harmonic functions, key-aware majority chords).

