In [1]:
# Install dependencies
!pip install torchaudio qdrant-client --quiet

# Download metadata and audio
!wget -nc https://os.unil.cloud.switch.ch/fma/fma_small.zip -O fma_small.zip
!wget -nc https://os.unil.cloud.switch.ch/fma/fma_metadata.zip -O fma_metadata.zip

# Unzip (audio: fma_small, metadata: CSVs)
!unzip -q -n fma_small.zip -d ./fma_small
!unzip -q -n fma_metadata.zip -d ./fma_metadata

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/337.3 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m337.3/337.3 kB[0m [31m24.3 MB/s[0m eta [36m0:00:00[0m
[?25h--2025-08-25 07:37:43--  https://os.unil.cloud.switch.ch/fma/fma_small.zip
Resolving os.unil.cloud.switch.ch (os.unil.cloud.switch.ch)... 86.119.28.16, 2001:620:5ca1:201::214
Connecting to os.unil.cloud.switch.ch (os.unil.cloud.switch.ch)|86.119.28.16|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 7679594875 (7.2G) [application/zip]
Saving to: ‘fma_small.zip’


2025-08-25 07:43:35 (20.9 MB/s) - ‘fma_small.zip’ saved [7679594875/7679594875]

--2025-08-25 07:43:35--  https://os.unil.cloud.switch.ch/fma/fma_metadata.zip
Resolving os.unil.cloud.switch.ch (os.unil.cloud.switch.ch)... 86.119.28.16, 2001:620:5ca1:201::214
Connecting to os.unil.cloud.switch.ch (os.unil.cloud.switch.ch)|86.119.28.16|:443... connected.
HTTP request s

In [3]:
import pandas as pd

tracks = pd.read_csv("fma_metadata/fma_metadata/tracks.csv", index_col=0, header=[0,1])
genres = pd.read_csv("fma_metadata/fma_metadata/genres.csv", index_col=0)

# Filter only small set
subset = tracks['set', 'subset'] == 'small'
small_tracks = tracks[subset]

print("Total FMA-Small tracks:", len(small_tracks))
small_tracks.head()

Total FMA-Small tracks: 8000


Unnamed: 0_level_0,album,album,album,album,album,album,album,album,album,album,...,track,track,track,track,track,track,track,track,track,track
Unnamed: 0_level_1,comments,date_created,date_released,engineer,favorites,id,information,listens,producer,tags,...,information,interest,language_code,license,listens,lyricist,number,publisher,tags,title
track_id,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2,Unnamed: 14_level_2,Unnamed: 15_level_2,Unnamed: 16_level_2,Unnamed: 17_level_2,Unnamed: 18_level_2,Unnamed: 19_level_2,Unnamed: 20_level_2,Unnamed: 21_level_2
2,0,2008-11-26 01:44:45,2009-01-05 00:00:00,,4,1,<p></p>,6073,,[],...,,4656,en,Attribution-NonCommercial-ShareAlike 3.0 Inter...,1293,,3,,[],Food
5,0,2008-11-26 01:44:45,2009-01-05 00:00:00,,4,1,<p></p>,6073,,[],...,,1933,en,Attribution-NonCommercial-ShareAlike 3.0 Inter...,1151,,6,,[],This World
10,0,2008-11-26 01:45:08,2008-02-06 00:00:00,,4,6,,47632,,[],...,,54881,en,Attribution-NonCommercial-NoDerivatives (aka M...,50135,,1,,[],Freeway
140,1,2008-11-26 01:49:59,2007-05-22 00:00:00,,1,61,<p>Alec K. Redfearn &amp; The Eyesores: Ellen ...,1300,"Alec K. Refearn, Rob Pemberton",[],...,,1593,en,Attribution-Noncommercial-No Derivative Works ...,1299,,2,,[],Queen Of The Wires
141,0,2008-11-26 01:49:57,2009-01-16 00:00:00,,1,60,"<p>A full ensamble of strings, drums, electron...",1304,,[],...,,839,en,Attribution-Noncommercial-No Derivative Works ...,725,,4,,[],Ohio


In [27]:
import os
import torch
import torchaudio
from tqdm import tqdm

DATASET_PATH = "fma_small/fma_small"
CACHE_PATH = "/content/fma_cache"
os.makedirs(CACHE_PATH, exist_ok=True)

# Parameters
clip_duration = 3       # seconds
sample_rate = 16000
n_mels = 64
n_samples = clip_duration * sample_rate

# Mel transform
mel_transform = torchaudio.transforms.MelSpectrogram(
    sample_rate=sample_rate,
    n_fft=1024,
    hop_length=512,
    n_mels=n_mels
)

# Gather all song paths
song_files = [p for p in glob.glob(os.path.join(DATASET_PATH, "*/*.mp3"))]
song_dirs = {f"song_{i:06d}": p for i, p in enumerate(song_files)}

# Precompute
for song_id, path in tqdm(song_dirs.items(), desc="Caching spectrograms"):
    cache_file = os.path.join(CACHE_PATH, f"{song_id}.pt")
    if os.path.exists(cache_file):
        continue  # skip if already cached

    try:
        waveform, sr = torchaudio.load(path)
    except Exception as e:
        print(f"Skipping {song_id}: {e}")
        continue  # skip problematic file
    if waveform.size(1) < 1000:
      print(f"Skipping too short file: {song_id}")
      continue
    if sr != sample_rate:
        waveform = torchaudio.functional.resample(waveform, sr, sample_rate)
    if waveform.size(0) > 1:
        waveform = waveform.mean(dim=0, keepdim=True)  # mono

    # Take first clip_duration seconds (or pad if short)
    if waveform.size(1) < n_samples:
        pad = n_samples - waveform.size(1)
        waveform = torch.nn.functional.pad(waveform, (0, pad))
    else:
        waveform = waveform[:, :n_samples]

    mel = mel_transform(waveform)
    log_mel = torch.log1p(mel)
    log_mel = (log_mel - log_mel.mean()) / (log_mel.std() + 1e-6)
    log_mel = log_mel.contiguous()

    torch.save(log_mel, cache_file)


Caching spectrograms:  62%|██████▏   | 4986/8000 [00:00<00:00, 43504.53it/s]

Skipping song_004984: Failed to open the input "fma_small/fma_small/099/099134.mp3" (Invalid argument).


Caching spectrograms:  73%|███████▎  | 5815/8000 [01:14<01:38, 22.18it/s]

Skipping song_005817: Failed to open the input "fma_small/fma_small/108/108925.mp3" (Invalid argument).


Caching spectrograms:  93%|█████████▎| 7464/8000 [03:41<00:41, 12.80it/s]

Skipping song_007461: Failed to open the input "fma_small/fma_small/133/133297.mp3" (Invalid argument).


Caching spectrograms: 100%|██████████| 8000/8000 [04:29<00:00, 29.72it/s]


In [28]:
import torch
from torch.utils.data import Dataset
import torchaudio
import random
import os

class SongTripletDatasetCached(Dataset):
    def __init__(self, cache_dir):
        self.cache_files = [os.path.join(cache_dir, f) for f in os.listdir(cache_dir)]
        self.song_ids = [os.path.basename(f).split(".pt")[0] for f in self.cache_files]

    def __len__(self):
        return 50000  # arbitrary for triplet sampling

    def __getitem__(self, idx):
        # --- Anchor ---
        anchor_id = random.choice(self.song_ids)
        anchor = torch.load(os.path.join(CACHE_PATH, f"{anchor_id}.pt"))

        # --- Positive (same song, different clip) ---
        # For simplicity, reuse same clip; could extend to multiple clips
        positive = torch.load(os.path.join(CACHE_PATH, f"{anchor_id}.pt"))

        # --- Negative (different song) ---
        neg_id = random.choice([s for s in self.song_ids if s != anchor_id])
        negative = torch.load(os.path.join(CACHE_PATH, f"{neg_id}.pt"))

        return anchor, positive, negative


In [29]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from torch.utils.data import Dataset, DataLoader
import random
import os

# ----------------------------
# Embedding Network
# ----------------------------
class AudioEmbeddingNet(nn.Module):
    def __init__(self, embedding_dim=128):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1), nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, 1))
        )
        self.fc = nn.Linear(128, embedding_dim)

    def forward(self, x):
        # x: (B, 1, n_mels, time)
        x = self.conv(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return F.normalize(x, p=2, dim=1)  # L2-normalized embeddings

# Triplet wrapper
class TripletNetwork(nn.Module):
    def __init__(self, embedding_net):
        super().__init__()
        self.embedding_net = embedding_net

    def forward(self, anchor, positive, negative):
        return (self.embedding_net(anchor),
                self.embedding_net(positive),
                self.embedding_net(negative))


In [30]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
import numpy as np
import os

# 🔹 1. Train/Val split
def create_loaders(dataset, batch_size=32, val_split=0.2):
    val_size = int(len(dataset) * val_split)
    train_size = len(dataset) - val_size
    train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, drop_last=False)

    return train_loader, val_loader


# 🔹 2. Training loop with early stopping + scheduler
def train_model(model, train_loader, val_loader, n_epochs=50, patience=5, save_path="best_model.pt"):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)

    criterion = nn.TripletMarginLoss(margin=1.0, p=2)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=2)

    best_val_loss = np.inf
    patience_counter = 0

    for epoch in range(1, n_epochs + 1):
        # --- Training ---
        model.train()
        total_train_loss = 0
        for anchor, positive, negative in train_loader:
            anchor, positive, negative = anchor.to(device), positive.to(device), negative.to(device)

            optimizer.zero_grad()
            anchor_out, positive_out, negative_out = model(anchor, positive, negative)
            loss = criterion(anchor_out, positive_out, negative_out)
            loss.backward()
            optimizer.step()
            total_train_loss += loss.item()

        avg_train_loss = total_train_loss / len(train_loader)

        # --- Validation ---
        model.eval()
        total_val_loss = 0
        with torch.no_grad():
            for anchor, positive, negative in val_loader:
                anchor, positive, negative = anchor.to(device), positive.to(device), negative.to(device)
                anchor_out, positive_out, negative_out = model(anchor, positive, negative)
                val_loss = criterion(anchor_out, positive_out, negative_out)
                total_val_loss += val_loss.item()

        avg_val_loss = total_val_loss / len(val_loader)

        # 🔹 Scheduler step
        scheduler.step(avg_val_loss)

        print(f"Epoch {epoch:03d} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f} | LR: {optimizer.param_groups[0]['lr']:.6f}")

        # --- Check early stopping ---
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            patience_counter = 0
            torch.save(model.state_dict(), save_path)  # save best model
            print(f"  ✅ New best model saved (val_loss={best_val_loss:.4f})")
        else:
            patience_counter += 1
            print(f"  ⚠️ No improvement (patience {patience_counter}/{patience})")

        if patience_counter >= patience:
            print("⏹️ Early stopping triggered")
            break

    # Load best model before returning
    model.load_state_dict(torch.load(save_path))
    print("🔄 Best model reloaded from checkpoint")
    return model


In [31]:
import os
import glob
from torch.utils.data import DataLoader, random_split

dataset = SongTripletDatasetCached(CACHE_PATH)
val_split = 0.1
val_size = int(len(dataset) * val_split)
train_size = len(dataset) - val_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=2, pin_memory=True)
val_loader   = DataLoader(val_dataset, batch_size=16, shuffle=False, num_workers=2, pin_memory=True)

# Model
model = TripletNetwork(AudioEmbeddingNet(embedding_dim=128))

# Train
best_model = train_model(model, train_loader, val_loader, n_epochs=20, patience=5, save_path="triplet_best.pt")


Epoch 001 | Train Loss: 0.0111 | Val Loss: 0.0035 | LR: 0.001000
  ✅ New best model saved (val_loss=0.0035)
Epoch 002 | Train Loss: 0.0031 | Val Loss: 0.0025 | LR: 0.001000
  ✅ New best model saved (val_loss=0.0025)
Epoch 003 | Train Loss: 0.0020 | Val Loss: 0.0015 | LR: 0.001000
  ✅ New best model saved (val_loss=0.0015)
Epoch 004 | Train Loss: 0.0017 | Val Loss: 0.0025 | LR: 0.001000
  ⚠️ No improvement (patience 1/5)
Epoch 005 | Train Loss: 0.0013 | Val Loss: 0.0019 | LR: 0.001000
  ⚠️ No improvement (patience 2/5)
Epoch 006 | Train Loss: 0.0010 | Val Loss: 0.0013 | LR: 0.001000
  ✅ New best model saved (val_loss=0.0013)
Epoch 007 | Train Loss: 0.0009 | Val Loss: 0.0006 | LR: 0.001000
  ✅ New best model saved (val_loss=0.0006)
Epoch 008 | Train Loss: 0.0008 | Val Loss: 0.0009 | LR: 0.001000
  ⚠️ No improvement (patience 1/5)
Epoch 009 | Train Loss: 0.0008 | Val Loss: 0.0011 | LR: 0.001000
  ⚠️ No improvement (patience 2/5)
Epoch 010 | Train Loss: 0.0006 | Val Loss: 0.0005 | LR: 0.00

In [32]:
import torch
import torchaudio
import torchaudio.transforms as T
import numpy as np

# Parameters
SAMPLE_RATE = 22050
N_MELS = 128
CLIP_DURATION = 30  # seconds
N_SAMPLES = SAMPLE_RATE * CLIP_DURATION

# Transforms
mel_transform = T.MelSpectrogram(
    sample_rate=SAMPLE_RATE,
    n_fft=1024,
    hop_length=512,
    n_mels=N_MELS
)
db_transform = T.AmplitudeToDB()

def preprocess_audio(path, device="cpu"):
    # Load audio
    waveform, sr = torchaudio.load(path)

    # Convert to mono
    if waveform.size(0) > 1:
        waveform = waveform.mean(dim=0, keepdim=True)

    # Resample if needed
    if sr != SAMPLE_RATE:
        waveform = torchaudio.functional.resample(waveform, sr, SAMPLE_RATE)

    # Clip or pad
    if waveform.size(1) < N_SAMPLES:
        waveform = torch.nn.functional.pad(waveform, (0, N_SAMPLES - waveform.size(1)))
    else:
        waveform = waveform[:, :N_SAMPLES]

    # Mel-spectrogram
    mel_spec = mel_transform(waveform)
    log_mel_spec = db_transform(mel_spec)

    # Normalize
    log_mel_spec = (log_mel_spec - log_mel_spec.mean()) / (log_mel_spec.std() + 1e-6)

    return log_mel_spec.unsqueeze(0).to(device).contiguous()  # (1, 1, n_mels, time)

def embed_song(embedding_net, path, device="cpu"):
    embedding_net.eval()
    with torch.no_grad():
        spec = preprocess_audio(path, device)
        emb = embedding_net(spec)  # (1, embedding_dim)
        return emb.squeeze(0).cpu().numpy().astype(np.float32)


In [60]:
from qdrant_client import QdrantClient

# Connect to Qdrant
client = QdrantClient(":memory:")

# Delete collection
client.delete_collection(collection_name="songs")

print("✅ Collection 'songs' deleted")


✅ Collection 'songs' deleted


In [61]:
import os
import glob
from tqdm import tqdm
from qdrant_client import QdrantClient
from qdrant_client.http import models as rest

# Qdrant client (local or cloud)
#client = QdrantClient(host="localhost", port=6333)
client = QdrantClient(":memory:")  # only works in current runtime, no network needed

# Pick first 500 MP3s from FMA-Small
dataset_path = "fma_small/fma_small"
song_files = sorted(glob.glob(os.path.join(dataset_path, "*/*.mp3")))[:1000]

# Create or reset collection
client.recreate_collection(
    collection_name="songs",
    vectors_config=rest.VectorParams(size=128, distance="Cosine")
)

embedding_net = AudioEmbeddingNet(embedding_dim=128)
# Load weights from TripletNetwork
state_dict = torch.load("triplet_best.pt", map_location="cpu")
# If keys are prefixed with 'embedding_net.', strip them
from collections import OrderedDict
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    if k.startswith("embedding_net."):
        new_state_dict[k.replace("embedding_net.", "")] = v

embedding_net.load_state_dict(new_state_dict)

# Upsert songs
points = []
for idx, path in enumerate(tqdm(song_files, desc="Indexing songs")):
    try:
        emb = embed_song(embedding_net, path, device="cpu")  # use embedding_net here
        emb = emb / np.linalg.norm(emb)
        points.append(
            rest.PointStruct(
                id=idx,
                vector=emb.tolist(),
                payload={"track": os.path.basename(path)}
            )
        )
    except Exception as e:
        print(f"❌ Failed {path}: {e}")

# Bulk insert
if points:
    client.upsert(
        collection_name="songs",
        points=points
    )

print(f"✅ Inserted {len(points)} songs into Qdrant")


  client.recreate_collection(
Indexing songs: 100%|██████████| 1000/1000 [02:01<00:00,  8.26it/s]


✅ Inserted 1000 songs into Qdrant


In [72]:
query_path = song_files[512]  # pick any
query_emb = embed_song(embedding_net, query_path)
query_emb = query_emb / np.linalg.norm(query_emb)

results = client.search(
    collection_name="songs",
    query_vector=query_emb.tolist(),
    limit=3
)

print("\n🔎 Query:", os.path.basename(query_path))
for hit in results:
    print(f"Match: {hit.payload['track']} (score={hit.score:.3f})")



🔎 Query: 011764.mp3
Match: 011764.mp3 (score=1.000)
Match: 010695.mp3 (score=0.576)
Match: 001069.mp3 (score=0.529)


  results = client.search(


In [74]:
from google.colab import drive
import shutil
import os

# 1️⃣ Mount Google Drive
drive.mount('/content/drive')

# 2️⃣ Set the path to your file in Drive
drive_file_path = "/content/drive/MyDrive/wide_open.mp3"  # adjust if needed
local_path = "/content/wide_open.mp3"

# 3️⃣ Copy to Colab working directory
shutil.copy(drive_file_path, local_path)

# 4️⃣ Check
assert os.path.exists(local_path), "File not found!"
print(f"✅ wide_open.mp3 copied to {local_path}")


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
✅ wide_open.mp3 copied to /content/wide_open.mp3


In [76]:
import torchaudio
import torch
import numpy as np
from qdrant_client import QdrantClient
from qdrant_client.http.models import PointStruct

song_path = "/content/wide_open.mp3"  # path to Beyoncé's Freedom
collection_name = "songs"

embedding_freedom = embed_song(embedding_net, song_path, device="cpu")
embedding_freedom = embedding_freedom / np.linalg.norm(embedding_freedom)  # normalize

# ------------------------------
# 3️⃣ Connect to Qdrant
# ------------------------------
qdrant_client = client

# Only create the collection if it doesn't exist
if not qdrant_client.collection_exists(collection_name=collection_name):
    qdrant_client.create_collection(
        collection_name=collection_name,
        vectors_config={"size": 128, "distance": "Cosine"}
    )

# ------------------------------
# 4️⃣ Upsert Beyoncé's Freedom
# ------------------------------
qdrant_client.upsert(
    collection_name=collection_name,
    points=[
        PointStruct(
            id=100001,  # unique ID
            vector=embedding_freedom.tolist(),
            payload={"track": "CB - Wide Open"}
        )
    ]
)

print("✅ Inserted 'CB - Wide Open' into Qdrant")


✅ Inserted 'CB - Wide Open' into Qdrant


In [77]:
query_path = "/content/wide_open.mp3"  # pick any
query_emb = embed_song(embedding_net, query_path)
query_emb = query_emb / np.linalg.norm(query_emb)

results = client.search(
    collection_name="songs",
    query_vector=query_emb.tolist(),
    limit=3
)

print("\n🔎 Query:", os.path.basename(query_path))
for hit in results:
    print(f"Match: {hit.payload['track']} (score={hit.score:.3f})")


🔎 Query: wide_open.mp3
Match: CB - Wide Open (score=1.000)
Match: 004848.mp3 (score=0.642)
Match: 012518.mp3 (score=0.629)


  results = client.search(
