In [None]:
import sys
from pathlib import Path

import torch
import wandb
from google.colab import files
from torch.utils.data import DataLoader

# Run google.colab's files.upload() and keep whatever it returns in `uploaded`.
uploaded = files.upload()

# Register <current working directory>/src on sys.path exactly once —
# presumably so the `musicagent` package imported below can be found.
project_root = Path.cwd()
src_path = project_root.joinpath("src")
if sys.path.count(str(src_path)) == 0:
    sys.path.append(str(src_path))

from musicagent.config import OfflineConfig
from musicagent.data import OfflineDataset, make_offline_collate_fn

# Evaluation: the unified evaluate_offline entry point and its result type
from musicagent.eval import OfflineEvalResult, evaluate_offline
from musicagent.models import OfflineTransformer
from musicagent.utils import load_configs_from_dir

In [None]:
%cd /content/models

wandb.login()
ARTIFACT_REF = "marty1ai/musicagent/best-model:v50"

CHECKPOINT_DIR = Path("checkpoints")
CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)

api = wandb.Api()
artifact = api.artifact(ARTIFACT_REF, type="model")
artifact_dir = Path(artifact.download(root=str(CHECKPOINT_DIR)))
CHECKPOINT_PATH = artifact_dir / "best_model.pt"
print(f"Artifact downloaded to: {artifact_dir}")

/content/models


  | |_| | '_ \/ _` / _` |  _/ -_)
[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
[34m[1mwandb[0m: Paste an API key from your profile and hit enter:

 ··········


[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mdrewtaylor[0m ([33mmarty1ai[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Downloading large artifact 'best-model:v50', 275.44MB. 1 files...
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 00:00:18.5 (14.9MB/s)


checkpoints/best_model.pt


In [None]:
# --- Evaluation settings -------------------------------------------------
BATCH_SIZE = 128
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
SAMPLE = False
TEMPERATURE = 1.0

# --- Configs -------------------------------------------------------------
# Both config objects are read from the directory the W&B artifact was downloaded to.
d_cfg, m_cfg = load_configs_from_dir(artifact_dir, OfflineConfig)

# Override the stored device with the one available in this runtime.
m_cfg.device = DEVICE
device = torch.device(m_cfg.device)
print(f"Device: {device}")

# --- Test split ----------------------------------------------------------
test_ds = OfflineDataset(d_cfg, split="test")
collate_fn = make_offline_collate_fn(pad_id=d_cfg.pad_id)
# shuffle=False keeps the iteration order fixed across runs.
test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn)

# Vocabulary sizes, taken from the dataset's two vocab mappings.
vocab_src = len(test_ds.vocab_melody)
vocab_tgt = len(test_ds.vocab_chord)

# Inverted vocab mappings (value -> key), passed to the evaluation call later on.
id_to_melody = {value: key for key, value in test_ds.vocab_melody.items()}
id_to_chord = {value: key for key, value in test_ds.vocab_chord.items()}

# --- Model ---------------------------------------------------------------
model = OfflineTransformer(m_cfg, d_cfg, vocab_src, vocab_tgt)
model = model.to(device)
# weights_only=True restricts unpickling to tensors/primitive containers.
state = torch.load(CHECKPOINT_PATH, map_location=device, weights_only=True)
model.load_state_dict(state)
model.eval()

print(f"\nModel loaded: {sum(p.numel() for p in model.parameters()):,} parameters")
print(f"Test set: {len(test_ds)} sequences")

cuda




OfflineTransformer(
  (src_embed): Embedding(204, 512, padding_idx=0)
  (tgt_embed): Embedding(12772, 512, padding_idx=0)
  (pos_enc): PositionalEncoding()
  (transformer): Transformer(
    (encoder): TransformerEncoder(
      (layers): ModuleList(
        (0-7): 8 x TransformerEncoderLayer(
          (self_attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
          )
          (linear1): Linear(in_features=512, out_features=2048, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
          (linear2): Linear(in_features=2048, out_features=512, bias=True)
          (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (dropout1): Dropout(p=0.1, inplace=False)
          (dropout2): Dropout(p=0.1, inplace=False)
        )
      )
      (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    )
    (de

In [None]:
# Evaluate the loaded model on the test loader, using the SAMPLE / TEMPERATURE
# settings defined in the eval-config cell.
result: OfflineEvalResult = evaluate_offline(
    model=model,
    test_loader=test_loader,
    d_cfg=d_cfg,
    id_to_melody=id_to_melody,
    id_to_chord=id_to_chord,
    device=device,
    temperature=TEMPERATURE,
    sample=SAMPLE,
)

# Report: one line per metric, emitted by a single print call.
# NiC ratio and its std are shown as percentages (x100); the onset-interval EMD
# is scaled by 1e3 and labelled as "x 10^-3".
print(
    f"\n{'='*50}",
    "Results",
    f"{'='*50}",
    f"NiC Ratio:            {result.nic_ratio * 100:.2f}% ± {result.nic_std * 100:.2f}%",
    f"Onset Interval EMD:   {result.onset_interval_emd * 1e3:.2f} × 10⁻³",
    f"Chord Length Entropy: {result.pred_entropy:.2f} (ref: {result.ref_entropy:.2f})",
    f"Total sequences:      {result.num_sequences:,}",
    sep="\n",
)

NiC Ratio:         60.13%
Onset Interval EMD:          149.75 x 10^-3
Chord Length Entropy (pred): 1.25
Chord Length Entropy (ref):  2.30
