In [None]:
import sys
from pathlib import Path

import torch
import wandb
from google.colab import files

# Interactive Colab upload prompt.
# NOTE(review): `uploaded` is not referenced in the cells shown here — confirm
# what the upload is for; an interactive prompt prevents an unattended "Run All".
uploaded = files.upload()

# Make the local `src/` package importable. The path is inserted at the FRONT of
# sys.path so this checkout's `musicagent` takes precedence over any copy that is
# already installed in site-packages (with `append`, an installed copy would win).
# NOTE(review): project_root comes from the current working directory, which is
# evaluated before the `%cd /content/models` in the next cell — confirm this cell
# is run from the repository root on a fresh kernel.
project_root = Path.cwd()
src_path = project_root / "src"
if str(src_path) not in sys.path:
    sys.path.insert(0, str(src_path))

from musicagent.utils import (
    create_test_loader,
    download_wandb_artifact,
    get_model_registry,
    load_model_from_artifact,
)

# Model family to evaluate; selects the registry entry whose `eval_func` is used later.
MODEL_TYPE = "offline"
reg = get_model_registry(MODEL_TYPE)

In [None]:
%cd /content/models

wandb.login()
ARTIFACT_REF = "marty1ai/musicagent/best-model:v50"

paths = download_wandb_artifact(ARTIFACT_REF, download_dir="checkpoints")
artifact_dir = paths.artifact_dir
CHECKPOINT_PATH = paths.checkpoint_path
print(f"Artifact downloaded to: {artifact_dir}")

/content/models


  | |_| | '_ \/ _` / _` |  _/ -_)
[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
[34m[1mwandb[0m: Paste an API key from your profile and hit enter:

 ··········


[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mdrewtaylor[0m ([33mmarty1ai[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Downloading large artifact 'best-model:v50', 275.44MB. 1 files...
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 00:00:18.5 (14.9MB/s)


checkpoints/best_model.pt


In [None]:
# Eval config
BATCH_SIZE = 128
SAMPLE = False  # forwarded to eval_func as `sample`
TEMPERATURE = 1.0  # forwarded to eval_func as `temperature`; presumably only relevant when SAMPLE is True — confirm
SEED = 0  # seeds torch's RNG so that runs with SAMPLE=True are repeatable

# Seed up front so any stochastic decoding during evaluation is reproducible.
torch.manual_seed(SEED)

# Load model from artifact
loaded = load_model_from_artifact(artifact_dir, CHECKPOINT_PATH, MODEL_TYPE)
model, d_cfg, device = loaded.model, loaded.d_cfg, loaded.device
# Defensive: switch the network to inference mode (disables its Dropout layers).
# Harmless if load_model_from_artifact or eval_func already does this.
model.eval()
print(f"Device: {device}")

# Create test dataloader
loader_result = create_test_loader(d_cfg, MODEL_TYPE, batch_size=BATCH_SIZE)
test_loader = loader_result.test_loader
id_to_melody = loader_result.id_to_melody
id_to_chord = loader_result.id_to_chord

print(f"\nModel loaded: {sum(p.numel() for p in model.parameters()):,} parameters")
print(f"Test set: {len(loader_result.test_dataset)} sequences")

cuda




OfflineTransformer(
  (src_embed): Embedding(204, 512, padding_idx=0)
  (tgt_embed): Embedding(12772, 512, padding_idx=0)
  (pos_enc): PositionalEncoding()
  (transformer): Transformer(
    (encoder): TransformerEncoder(
      (layers): ModuleList(
        (0-7): 8 x TransformerEncoderLayer(
          (self_attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
          )
          (linear1): Linear(in_features=512, out_features=2048, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
          (linear2): Linear(in_features=2048, out_features=512, bias=True)
          (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (dropout1): Dropout(p=0.1, inplace=False)
          (dropout2): Dropout(p=0.1, inplace=False)
        )
      )
      (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    )
    (de

In [None]:
# Run evaluation: score the test split with the registry's eval_func for MODEL_TYPE,
# using the decoding settings from the config cell.
eval_kwargs = dict(
    model=model,
    test_loader=test_loader,
    d_cfg=d_cfg,
    id_to_melody=id_to_melody,
    id_to_chord=id_to_chord,
    device=device,
    temperature=TEMPERATURE,
    sample=SAMPLE,
)
result = reg.eval_func(**eval_kwargs)

# Summary report. NiC is shown as a percentage (x100); the onset-interval EMD is
# scaled by 1e3 for readability.
banner = "=" * 50
pred_ent = result.pred_chord_length_entropy
ref_ent = result.ref_chord_length_entropy
report_lines = [
    "",
    banner,
    "Results",
    banner,
    f"NiC Ratio:            {result.nic_ratio * 100:.2f}% ± {result.nic_std * 100:.2f}%",
    f"Onset Interval EMD:   {result.onset_interval_emd * 1e3:.2f} × 10⁻³",
    f"Chord Length Entropy: {pred_ent:.2f} (ref: {ref_ent:.2f})",
    f"Total sequences:      {result.num_sequences:,}",
]
print("\n".join(report_lines))

NiC Ratio:         60.13%
Onset Interval EMD:          149.75 x 10^-3
Chord Length Entropy (pred): 1.25
Chord Length Entropy (ref):  2.30
