In [None]:
from structure_derivation.model.model import StructureDerivationModel, StructureDerivationModelConfig
import os
import torch
import librosa

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
CHECKPOINTS_DIR = "/keshav/musical_structure_metrics/structure_derivation/artifacts/structure_derivation_model/checkpoint/"

config = StructureDerivationModelConfig()
model = StructureDerivationModel(config)
model.to(device)

ckpt = torch.load(os.path.join(CHECKPOINTS_DIR, "checkpoint.pt"), map_location=device)
# Weights saved from a DDP-wrapped run are nested under a "module" key.
# The model built above is NOT wrapped in DDP, so `model.module` does not
# exist; the nested dict must be loaded into `model` itself.
state_dict = ckpt["model"]
if "module" in state_dict:
    state_dict = state_dict["module"]
model.load_state_dict(state_dict)
model.eval()

# NOTE(review): both entries are the same file and only the first is used below.
audio_paths = [
    '/mnt/data/music_reward/musiccaps/data/JDWPJ1AiDKc.wav',
    '/mnt/data/music_reward/musiccaps/data/JDWPJ1AiDKc.wav'
]

# Example input: first path, resampled to 32 kHz mono.
audio, sr = librosa.load(audio_paths[0], sr=32000, mono=True)

audio_tensor = torch.tensor(audio).unsqueeze(0).to(device)  # (1, T)
with torch.no_grad():
    output = model(audio_tensor, infer_mode=True)

print(output['latent_output'].shape)  # Example output key

  from .autonotebook import tqdm as notebook_tqdm
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


In [1]:
import librosa
import os
import torch
import torch.nn.functional as F
from structure_derivation.model.model import StructureDerivationModel, StructureDerivationModelConfig

# Inference is pinned to CPU; set USE_CUDA = True to use a GPU when one is available.
USE_CUDA = False
device = torch.device('cuda' if USE_CUDA and torch.cuda.is_available() else 'cpu')

def load_model(ckpt_path, device):
    """Build a StructureDerivationModel and load weights from a checkpoint.

    Args:
        ckpt_path: Path to a checkpoint saved with ``torch.save``. The weights
            are read from ``ckpt["model"]``, or from ``ckpt["model"]["module"]``
            when that nested key is present (DDP-style checkpoint).
        device: Device the model is moved to and checkpoint tensors are mapped to.

    Returns:
        The model with weights loaded, in eval mode.
    """
    config = StructureDerivationModelConfig()
    model = StructureDerivationModel(config)
    model.to(device)

    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    ckpt = torch.load(ckpt_path, map_location=device)
    state_dict = ckpt["model"]
    # The model above is not wrapped in DDP, so `model.module` does not exist:
    # unwrap the nested dict and load it into `model` itself.
    if "module" in state_dict:
        state_dict = state_dict["module"]
    model.load_state_dict(state_dict)
    model.eval()
    return model

def split_audio(audio_path, segment_seconds=10, target_sr=32000):
    """Load audio and split it into consecutive, non-overlapping fixed-length segments.

    Trailing samples that do not fill a whole segment are dropped, so a clip
    shorter than ``segment_seconds`` yields an empty list.

    Args:
        audio_path: Path to an audio file readable by ``librosa.load``.
        segment_seconds: Segment length in seconds; may be fractional.
        target_sr: Sample rate passed to ``librosa.load`` (the audio is loaded mono).

    Returns:
        ``(segments, sr)``: a list of 1-D slices of the loaded waveform, each
        ``round(segment_seconds * sr)`` samples long, and the sample rate
        returned by ``librosa.load``.

    Raises:
        ValueError: if the segment length rounds to fewer than one sample.
    """
    audio, sr = librosa.load(audio_path, sr=target_sr, mono=True)
    # Use the rate librosa actually returned and cast to int so fractional
    # durations give a valid slice length / range step.
    segment_len = int(round(segment_seconds * sr))
    if segment_len <= 0:
        raise ValueError(
            f"segment_seconds={segment_seconds} at sr={sr} gives a segment of "
            f"{segment_len} samples; need at least 1"
        )
    n_full = len(audio) // segment_len
    segments = [audio[i * segment_len:(i + 1) * segment_len] for i in range(n_full)]
    return segments, sr

def compute_embeddings(model, segments, device):
    """Embed each audio segment with ``model`` and concatenate the results.

    Args:
        model: Callable invoked as ``model(x, infer_mode=True)`` on a (1, T)
            float32 tensor; its output must contain a ``"latent_output"`` tensor
            (presumably (1, D) — TODO confirm against the model).
        segments: Iterable of 1-D waveforms (arrays or lists); must not be empty.
        device: Device each input tensor is moved to.

    Returns:
        The per-segment ``latent_output`` tensors concatenated along dim 0.

    Raises:
        ValueError: if ``segments`` is empty (``torch.cat`` would otherwise
            fail with an unhelpful error).
    """
    embeddings = []
    with torch.no_grad():
        for seg in segments:
            seg_tensor = torch.tensor(seg, dtype=torch.float32).unsqueeze(0).to(device)  # (1, T)
            out = model(seg_tensor, infer_mode=True)
            embeddings.append(out["latent_output"])
    if not embeddings:
        raise ValueError(
            "compute_embeddings() got no segments; is the audio shorter than one segment?"
        )
    return torch.cat(embeddings, dim=0)

def compute_similarities(embeddings):
    """Cosine similarity between the first embedding row and every later row.

    Args:
        embeddings: 2-D tensor of stacked embeddings (N, D) with N >= 1; it
            may live on any device and may require grad.

    Returns:
        NumPy array of N-1 similarities, where element i compares row 0 with row i+1.
    """
    ref = embeddings[0].unsqueeze(0)  # (1, D), broadcast against the remaining rows
    sims = F.cosine_similarity(ref, embeddings[1:], dim=1)  # (N-1,)
    # detach() first: .numpy() refuses tensors that still track gradients.
    return sims.detach().cpu().numpy()

# ----------------- Usage -----------------
CHECKPOINTS_DIR = "/keshav/musical_structure_metrics/structure_derivation/artifacts/structure_derivation_model/checkpoint/"
ckpt_path = os.path.join(CHECKPOINTS_DIR, "checkpoint.pt")
SEGMENT_SECONDS = 10
TARGET_SR = 32000

model = load_model(ckpt_path, device)

audio_path = '/mnt/data/marble/mtg_jamendo/mtg-jamendo-dataset/data/raw_30s_audio/99/6699.mp3'
segments, sr = split_audio(audio_path, segment_seconds=SEGMENT_SECONDS, target_sr=TARGET_SR)
print(f"Split into {len(segments)} segments.")
# The comparison needs a reference segment (S1) plus at least one other;
# otherwise the similarity array is empty and its mean is NaN.
assert len(segments) >= 2, (
    f"need >= 2 segments of {SEGMENT_SECONDS}s to compare, got {len(segments)}; "
    "use a longer clip or a smaller SEGMENT_SECONDS"
)

embeddings = compute_embeddings(model, segments, device)
print("Embeddings shape:", embeddings.shape)  # (N, D)

similarities = compute_similarities(embeddings)
print("Cosine similarities with S1:", similarities)

# Average similarity
avg_sim = similarities.mean()
print("Average similarity with S1:", avg_sim)


  from .autonotebook import tqdm as notebook_tqdm
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


Split into 31 segments of 5s each.
Embeddings shape: torch.Size([31, 768])
Cosine similarities with S1: [0.86841476 0.6153427  0.40927917 0.352422   0.5193686  0.36905724
 0.40563023 0.47217578 0.751181   0.50920165 0.46983656 0.2635488
 0.26710814 0.27426744 0.27518922 0.31274518 0.24528386 0.31121287
 0.42605114 0.40733093 0.36269268 0.34492853 0.24020256 0.38971105
 0.31441623 0.38691163 0.24710664 0.24856237 0.30020675 0.37190536]
Average similarity with S1: 0.39104307
