In [None]:
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
import wandb
from google.colab import files
from torch.utils.data import DataLoader

# Open the Colab upload dialog and keep whatever it returns.
uploaded = files.upload()

# Make <cwd>/src importable so the `musicagent` package resolves below.
project_root = Path.cwd()
src_path = project_root.joinpath("src")
src_path_str = str(src_path)
if src_path_str not in sys.path:
    sys.path.append(src_path_str)

from musicagent.config import DataConfig, OnlineConfig
from musicagent.data import OfflineDataset, collate_fn
from musicagent.models import OnlineTransformer

# All evaluation utilities from the shared eval module
from musicagent.eval import (
    chord_length_entropy,
    chord_lengths,
    decode_tokens,
    note_in_chord_at_beat,
    note_in_chord_ratio,
    onset_interval_emd,
    onset_intervals,
)


In [None]:
%cd /content/models

# Authenticate with Weights & Biases and fetch the model artifact.
wandb.login()
ARTIFACT_REF = "marty1ai/musicagent/best-model:v50"

# Artifact files are downloaded under this directory (relative to the cwd).
CHECKPOINT_DIR = Path("checkpoints")
CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)

api = wandb.Api()
artifact = api.artifact(ARTIFACT_REF, type="model")
artifact_dir = Path(artifact.download(root=str(CHECKPOINT_DIR)))

# rglob() yields paths in filesystem order, so sort to make the choice of
# checkpoint reproducible when more than one .pt file is present.
pt_files = sorted(artifact_dir.rglob("*.pt"))
if not pt_files:
    # Fail with an actionable message instead of a bare IndexError below.
    raise FileNotFoundError(
        f"No .pt checkpoint found in artifact {ARTIFACT_REF!r} "
        f"(downloaded to {artifact_dir})"
    )
CHECKPOINT_PATH = pt_files[0]
print(CHECKPOINT_PATH)

In [None]:
# --- Evaluation settings ---
BATCH_SIZE = 32  # kept small: online generation is run one sample at a time
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
SAMPLE = False
TEMPERATURE = 1.0

# Default data/model configs; the model config is pointed at the chosen device.
d_cfg = DataConfig()
m_cfg = OnlineConfig()
m_cfg.device = DEVICE
device = torch.device(m_cfg.device)
print(f"Device: {device}")

# --- Data ---
# OfflineDataset is used for the test split because the metrics need melody
# and chord sequences as separate tensors.
test_ds = OfflineDataset(d_cfg, split="test")
test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn)

# Inverted vocabularies (id -> token) used when decoding model output.
id_to_melody = {token_id: token for token, token_id in test_ds.vocab_melody.items()}
id_to_chord = {token_id: token for token, token_id in test_ds.vocab_chord.items()}

# --- Model ---
melody_vocab_size = len(test_ds.vocab_melody)
chord_vocab_size = len(test_ds.vocab_chord)

model = OnlineTransformer(m_cfg, d_cfg, melody_vocab_size, chord_vocab_size)
model = model.to(device)

# weights_only=True restricts unpickling to tensor data.
state = torch.load(CHECKPOINT_PATH, map_location=device, weights_only=True)
model.load_state_dict(state)
model.eval()

num_params = sum(param.numel() for param in model.parameters())
print(f"\nModel loaded: {num_params:,} parameters")
print(f"Melody vocab: {melody_vocab_size}, Chord vocab: {chord_vocab_size}")


In [None]:
# Accumulators for corpus-level metrics over the whole test split.
all_nic = []
all_pred_intervals, all_ref_intervals = [], []
all_pred_lengths, all_ref_lengths = [], []


def _strip_special_tokens(seq):
    """Return the ids between position 0 and the first EOS of a 1-D id tensor.

    Position 0 is dropped (assumed to hold SOS - this matches how the melody
    is sliced before generation and how `show_example` slices the reference).
    Everything from the first EOS onward (EOS itself and any padding after it)
    is dropped as well; if no EOS is present the sequence is kept to its end.
    """
    eos_positions = (seq == d_cfg.eos_id).nonzero(as_tuple=True)[0]
    eos_idx = int(eos_positions[0].item()) if len(eos_positions) > 0 else seq.size(0)
    return seq[1:eos_idx]


num_batches = len(test_loader)
print(f"Evaluating {num_batches} batches...")

with torch.no_grad():
    for batch_idx, (src, tgt) in enumerate(test_loader):
        src = src.to(device)

        for mel_seq, ref_seq in zip(src, tgt):
            # The online model is fed frame tokens only (no SOS/EOS/PAD).
            melody_frames = _strip_special_tokens(mel_seq)
            if melody_frames.numel() == 0:
                # Nothing to condition on; skip rather than call generate()
                # with an empty sequence.
                continue

            # Generate chord predictions for this single melody.
            pred_i = model.generate(
                melody_frames.unsqueeze(0),
                sos_id=d_cfg.sos_id,
                eos_id=d_cfg.eos_id,
                temperature=TEMPERATURE,
                sample=SAMPLE,
            )[0]

            # Decode exactly the frames the model saw, so melody and prediction
            # stay frame-aligned; strip the reference the same way so it is
            # compared on equal footing (consistent with show_example).
            mel_tokens = decode_tokens(melody_frames.cpu().tolist(), id_to_melody)
            pred_tokens = decode_tokens(pred_i.cpu().tolist(), id_to_chord)
            ref_tokens = decode_tokens(
                _strip_special_tokens(ref_seq).cpu().tolist(), id_to_chord
            )

            all_nic.append(note_in_chord_ratio(mel_tokens, pred_tokens))
            all_pred_intervals.extend(onset_intervals(mel_tokens, pred_tokens))
            all_ref_intervals.extend(onset_intervals(mel_tokens, ref_tokens))
            all_pred_lengths.extend(chord_lengths(pred_tokens))
            all_ref_lengths.extend(chord_lengths(ref_tokens))

        if (batch_idx + 1) % 10 == 0 or batch_idx == num_batches - 1:
            print(f"  Processed {batch_idx + 1}/{num_batches} batches")

# Compute aggregate metrics
avg_nic = sum(all_nic) / len(all_nic) if all_nic else 0.0
emd = onset_interval_emd(all_pred_intervals, all_ref_intervals)
pred_entropy = chord_length_entropy(all_pred_lengths)
ref_entropy = chord_length_entropy(all_ref_lengths)

print(f"NiC Ratio:         {avg_nic * 100:.2f}%")
print(f"Onset Interval EMD:          {emd * 1e3:.2f} × 10⁻³")
print(f"Chord Length Entropy (pred): {pred_entropy:.2f}")
print(f"Chord Length Entropy (ref):  {ref_entropy:.2f}")

In [None]:
# Compute per-beat NiC for a subset of samples
MAX_BEATS = 32
NUM_SAMPLES = min(100, len(test_ds))
# A beat is only reported when it has MORE than this many observations.
MIN_SAMPLES_PER_BEAT = 5

# beat index -> list of per-sample NiC values observed at that beat
beat_nic_all = {b: [] for b in range(MAX_BEATS)}

print(f"Analyzing adaptation dynamics on {NUM_SAMPLES} samples...")

with torch.no_grad():
    for sample_idx in range(NUM_SAMPLES):
        mel_seq = test_ds[sample_idx][0].to(device)

        # Keep only the frame tokens: drop position 0 (SOS) and everything
        # from the first EOS onward.
        eos_positions = (mel_seq == d_cfg.eos_id).nonzero(as_tuple=True)[0]
        eos_idx = int(eos_positions[0].item()) if len(eos_positions) > 0 else mel_seq.size(0)
        melody_frames = mel_seq[1:eos_idx]

        # Skip generation for an empty melody, but still report progress below.
        if melody_frames.numel() > 0:
            pred = model.generate(
                melody_frames.unsqueeze(0),
                sos_id=d_cfg.sos_id,
                eos_id=d_cfg.eos_id,
                temperature=TEMPERATURE,
                sample=SAMPLE,
            )[0]

            # Decode exactly the frames the model was conditioned on so that
            # beat indices line up between melody and predicted chords.
            mel_tokens = decode_tokens(melody_frames.cpu().tolist(), id_to_melody)
            pred_tokens = decode_tokens(pred.cpu().tolist(), id_to_chord)

            # Use the shared note_in_chord_at_beat function
            beat_nic = note_in_chord_at_beat(mel_tokens, pred_tokens)

            for beat, nic in beat_nic.items():
                if beat < MAX_BEATS and nic is not None:
                    beat_nic_all[beat].append(nic)

        if (sample_idx + 1) % 25 == 0:
            print(f"  Processed {sample_idx + 1}/{NUM_SAMPLES} samples")

# Compute mean and std per beat, keeping only beats with enough observations
beat_means, beat_stds, valid_beats = [], [], []
for beat in range(MAX_BEATS):
    if len(beat_nic_all[beat]) > MIN_SAMPLES_PER_BEAT:
        valid_beats.append(beat)
        beat_means.append(np.mean(beat_nic_all[beat]))
        beat_stds.append(np.std(beat_nic_all[beat]))

print(f"\nComputed dynamics for beats 0-{max(valid_beats) if valid_beats else 0}")


In [None]:
# Two-panel figure: per-beat NiC on the left, chord-length histograms on the right.
fig, (ax_dynamics, ax_lengths) = plt.subplots(1, 2, figsize=(12, 5))

# --- Left panel: mean NiC per beat with a +/- one-std band ---
if valid_beats:
    band_lower = [mean - std for mean, std in zip(beat_means, beat_stds)]
    band_upper = [mean + std for mean, std in zip(beat_means, beat_stds)]
    ax_dynamics.plot(valid_beats, beat_means, 'o-', color='#2ecc71', linewidth=2, markersize=4)
    ax_dynamics.fill_between(valid_beats, band_lower, band_upper, alpha=0.3, color='#2ecc71')
# Horizontal reference line at the overall NiC value.
ax_dynamics.axhline(y=avg_nic, color='#e74c3c', linestyle='--', label=f'Overall NiC ({avg_nic:.2%})')
ax_dynamics.set_xlabel('Beat', fontsize=12)
ax_dynamics.set_ylabel('Note-in-Chord Ratio', fontsize=12)
ax_dynamics.set_title('Adaptation Dynamics: Cold Start', fontsize=14)
ax_dynamics.legend()
ax_dynamics.grid(True, alpha=0.3)
ax_dynamics.set_ylim(0, 1)

# --- Right panel: chord-length histograms ---
# Lengths above max_len are clamped so they all land in the final bin.
max_len = 32
length_bins = range(1, max_len + 2)
pred_hist = [min(length, max_len) for length in all_pred_lengths]
ref_hist = [min(length, max_len) for length in all_ref_lengths]

ax_lengths.hist(ref_hist, bins=length_bins, alpha=0.5, label='Reference', color='#3498db')
ax_lengths.hist(pred_hist, bins=length_bins, alpha=0.5, label='Predicted', color='#e74c3c')
ax_lengths.set_xlabel('Chord Length (frames)', fontsize=12)
ax_lengths.set_ylabel('Count', fontsize=12)
ax_lengths.set_title('Chord Length Distribution', fontsize=14)
ax_lengths.legend()
ax_lengths.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()


In [None]:
def show_example(idx: int, max_tokens: int = 32):
    """Generate chords for one test item and print it next to the reference.

    Args:
        idx: Index of the item in ``test_ds``.
        max_tokens: Maximum number of decoded tokens printed per sequence.

    Prints the example's note-in-chord ratio, the melody, the predicted chords
    and the reference chords (reference tokens starting with ``<`` are hidden).
    """
    melody_ids, chord_ids = test_ds[idx]
    melody_ids = melody_ids.to(device)

    # Frame tokens run from position 1 up to (not including) the first EOS,
    # or to the end of the sequence when no EOS is present.
    eos_hits = (melody_ids == d_cfg.eos_id).nonzero(as_tuple=True)[0]
    if len(eos_hits) > 0:
        frame_end = int(eos_hits[0].item())
    else:
        frame_end = melody_ids.size(0)
    frames = melody_ids[1:frame_end]

    with torch.no_grad():
        generated = model.generate(
            frames.unsqueeze(0),
            sos_id=d_cfg.sos_id,
            eos_id=d_cfg.eos_id,
            temperature=TEMPERATURE,
            sample=SAMPLE,
        )
    pred = generated[0]

    mel_tokens = decode_tokens(frames.cpu().tolist(), id_to_melody)
    pred_tokens = decode_tokens(pred.cpu().tolist(), id_to_chord)
    # The reference only has position 0 removed here; bracketed tokens are
    # filtered out at print time instead.
    ref_tokens = decode_tokens(chord_ids[1:].cpu().tolist(), id_to_chord)

    nic = note_in_chord_ratio(mel_tokens, pred_tokens)

    print(f"\n{'='*60}")
    print(f"Example {idx} | NiC: {nic:.2%}")
    print(f"{'='*60}")

    print(f"\nMelody (first {max_tokens} frames):")
    print(" ".join(mel_tokens[:max_tokens]))

    print(f"\nPredicted Chords:")
    print(" ".join(pred_tokens[:max_tokens]))

    print(f"\nReference Chords:")
    ref_filtered = []
    for token in ref_tokens[:max_tokens]:
        if not token.startswith('<'):
            ref_filtered.append(token)
    print(" ".join(ref_filtered))

# Print three sample generations, skipping indices beyond the test set.
for example_idx in (0, 50, 100):
    if example_idx >= len(test_ds):
        continue
    show_example(example_idx)
