# Libraries

In [1]:
import os
import time
import torch
import torchaudio
import subprocess
import numpy as np
import pandas as pd
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from tqdm import tqdm
from torch.utils.data import DataLoader
from torch.amp import GradScaler, autocast
from torch.utils.data import Dataset, Subset
from sklearn.model_selection import train_test_split
from torchaudio.transforms import MelSpectrogram, AmplitudeToDB
from sklearn.metrics._plot.confusion_matrix import ConfusionMatrixDisplay
from torch_audiomentations import AddBackgroundNoise, AddColoredNoise, Gain, PitchShift

# Constants

In [2]:
CACHE_DIR = "./cache"
SONGS_DIR = f"{CACHE_DIR}/songs"
SEGMENTS_DIR = f"{CACHE_DIR}/segments"  
SPECTOGRAMS_DIR = f"{CACHE_DIR}/spectograms"

ASSETS_DIR = "./assets"
BACKGROUND_NOISE_DIR = f"{ASSETS_DIR}/background_noises"

DATASET_FILE = f"{CACHE_DIR}/dataset.csv"
EMBEDDINGS_FILE = f"{CACHE_DIR}/embeddings.npz" 
TRAINED_MODEL_FILE = f"{CACHE_DIR}/trained_model.pt"

# SPOTIFY_PLAYLIST_URL="https://open.spotify.com/playlist/1Y0Qk1K1DEMXeKgvjjnN7m?si=80a2a297dded480b" # 10 songs
# SPOTIFY_PLAYLIST_URL="https://open.spotify.com/playlist/34NbomaTu7YuOYnky8nLXL?si=4bf54104cf4c480c" # 50 songs
SPOTIFY_PLAYLIST_URL="https://open.spotify.com/playlist/0sDahzOkMWOmLXfTMf2N4N?si=94c4a9796f934a34" # 100 songs
# SPOTIFY_PLAYLIST_URL="https://open.spotify.com/playlist/0MK9YSzaQn2D0fsuIHa94B?si=c2fb987f3b94496e" # 306 songs

torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True  # Optimize for consistent input sizes
torch.backends.cuda.matmul.allow_tf32 = True  # Allow TF32 for faster training
torch.backends.cudnn.allow_tf32 = True

TARGET_SAMPLE_RATE = 16000
SEGMENT_DURATION = 3 # seconds

# CNN model parameters  
BATCH_SIZE = 128
NUM_WORKERS = 0
SHUFFLE = True

NUM_TESTS = 100

device = "cuda" if torch.cuda.is_available() else "cpu"

# Create needed directories
os.makedirs(CACHE_DIR, exist_ok=True)
os.makedirs(SONGS_DIR, exist_ok=True)
os.makedirs(SPECTOGRAMS_DIR, exist_ok=True)
os.makedirs(SEGMENTS_DIR, exist_ok=True)
os.makedirs(ASSETS_DIR, exist_ok=True)  
os.makedirs(BACKGROUND_NOISE_DIR, exist_ok=True)  

In [3]:
if (len(os.listdir(BACKGROUND_NOISE_DIR)) == 0):
  raise Exception(f"Please download background noise files to {BACKGROUND_NOISE_DIR} directory.")

# Download Songs

In [4]:
if len(os.listdir(SONGS_DIR)) == 0:
  print(f"Songs directory is empty, downloading songs from {SPOTIFY_PLAYLIST_URL}")

  # Download songs
  subprocess.run(
    ["spotdl", SPOTIFY_PLAYLIST_URL, "--bitrate", "96k"],
    check=True,
    cwd=SONGS_DIR,
  )
else:
  print(f"The songs directory is not empty, skipping download.")

def is_int(s):
  try:
    int(s)
    return True
  except ValueError:
    return False

# Check if the songs dir already has the names changed
if not (all(is_int(song.split(".")[0]) for song in os.listdir(SONGS_DIR))): 
  print(f"Renaming songs in {SONGS_DIR} to integers")
  for i, song in enumerate(tqdm(os.listdir(SONGS_DIR), desc="Renaming songs", unit="song")):
    song_path = os.path.join(SONGS_DIR, song)
    song_extension = song.split(".")[-1]

    file_path = os.path.join(SONGS_DIR, f"{i}.{song_extension}")

    os.rename(
      song_path,
      file_path,
    )
else: 
  print(f"The songs directory already has the names changed, skipping renaming.")

The songs directory is not empty, skipping download.
The songs directory already has the names changed, skipping renaming.


# Extract features

In [5]:
mel_spec_transform = torch.nn.Sequential(
    MelSpectrogram(n_fft=2048, hop_length=512, n_mels=128, f_min=20, f_max=8000),
    AmplitudeToDB(),
  ).to(device)

def get_spectogram(waveform, as_numpy=True):
  with torch.no_grad(): 
    mel_spec = mel_spec_transform(waveform)
  mel_spec = mel_spec.to(device)  # Move to GPU if available  

  if as_numpy:
    return mel_spec.cpu().detach().numpy()

  return mel_spec 

def get_waveform_n_sr_from_file(file_path):
  waveform, sr = torchaudio.load(file_path, normalize=True)

  if sr != TARGET_SAMPLE_RATE:  
    resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=TARGET_SAMPLE_RATE)
    waveform = resampler(waveform)

  waveform = waveform.mean(dim=0, keepdim=True)  # Convert to mono
  waveform = waveform.unsqueeze(0)  
  waveform = waveform.to(device)  # Move to GPU if available

  return waveform, TARGET_SAMPLE_RATE

def get_global_mean_std(files):
  sum_        = 0.0
  sum_sq      = 0.0
  count       = 0

  for file in files:
    file_path = os.path.join(SONGS_DIR, file) 
    waveform, sr = get_waveform_n_sr_from_file(file_path) 
    spec_np = get_spectogram(waveform, as_numpy=False)

    vals = spec_np.flatten()

    sum_ += vals.sum()
    sum_sq += (vals ** 2).sum()
    count += vals.numel()

  mean_global = sum_ / count
  var_global  = sum_sq / count - mean_global**2
  std_global  = torch.sqrt(var_global)

  return mean_global.cpu().detach().numpy(), std_global.cpu().detach().numpy()

def save_spectogram(segment, song_id, chunk_id, aug_name, variation=0):
    file_name = f"{song_id}_{chunk_id}_{aug_name}_{variation}.npz"

    # Store as compressed file
    np.savez_compressed(
      os.path.join(SPECTOGRAMS_DIR, file_name),
      data=segment
    )

    return {
      "song_id": song_id,
      "file_name": file_name,
      "chunk_id": chunk_id,
      "aug_name": aug_name,
      "variation": variation,
    }

global_mean, global_std = get_global_mean_std(os.listdir(SONGS_DIR))

def extract_features():
  records = []

  for song in tqdm(os.listdir(SONGS_DIR), desc="Extracting features of the songs", unit="song"):
    song_path = os.path.join(SONGS_DIR, song)
    song_id = os.path.splitext(song)[0]

    waveform, sr = get_waveform_n_sr_from_file(song_path)

    segments_samples = int(SEGMENT_DURATION * sr)
    total_samples = waveform.shape[2]

    hop_length = segments_samples // 2
    chunk_id = 0
    for start in range(0, total_samples - segments_samples + 1, hop_length):

      if start + segments_samples > total_samples:
        break

      end = start + segments_samples
      segment = waveform[:, :, start:end]

      with torch.no_grad():
        mel_spec = get_spectogram(segment)

      # Normalize using the global mean & st
      mel_spec = (mel_spec - global_mean) / global_std

      segment_file = f"{song_id}_{chunk_id}.wav"

      # Save the segment
      torchaudio.save(
        os.path.join(SEGMENTS_DIR, segment_file),
        segment.cpu().squeeze(0),
        sr,
      )

      records.append(
        save_spectogram(mel_spec, song_id, chunk_id, "original")
      )

      chunk_id += 1
    
  # Save the dataframe
  df = pd.DataFrame(records)
  df.to_csv(DATASET_FILE, index=False)

if (len(os.listdir(SPECTOGRAMS_DIR)) == 0 or len(os.listdir(SEGMENTS_DIR)) == 0):
  print("Not founded spectograms, extracting them...")
  extract_features()
else:
  print("Founded spectograms, skiping it")

Founded spectograms, skiping it


# Balance classes using augmentations

In [6]:
augmentations = {
    "background_noise": AddBackgroundNoise(
      p=0.5,
      background_paths=BACKGROUND_NOISE_DIR,
      min_snr_in_db=10,
      max_snr_in_db=20,
      output_type="dict"
    ),  
    "colored_noise": AddColoredNoise(
      p=0.5,
      min_snr_in_db=10,
      max_snr_in_db=20,
      output_type="dict"
    ),
    "gain": Gain(
      p=0.5,  
      min_gain_in_db=-10,
      max_gain_in_db=20,
      output_type="dict"
    ),
    "pitch_shift": PitchShift(
      p=0.5,
      min_transpose_semitones=-2,
      max_transpose_semitones=2,
      sample_rate=TARGET_SAMPLE_RATE,
      output_type="dict"
    ),
  }

def augment_data():
  df = pd.read_csv(DATASET_FILE)

  if (df["aug_name"] == "original").all():
    print("The dataset only has original data, applying augmentation.")

    # Get the number of clases per song_id
    print(df["song_id"].value_counts())

    number_of_samples_target = df["song_id"].value_counts().max() * 1.2

    print(f"Target number of samples: {number_of_samples_target}")

    unique_song_ids = df["song_id"].unique()
    
    for song_id in tqdm(unique_song_ids, desc="Augmenting data", unit="song"):
      song_df = df[df["song_id"] == song_id]

      while df[df["song_id"] == song_id].shape[0] < number_of_samples_target:
        # Get a random sample of the original data
        sample = song_df.sample(n=1, random_state=np.random.randint(0, 10000)).iloc[0]
        segment_file = f"{song_id}_{sample['chunk_id']}.wav"
        file_path = os.path.join(SEGMENTS_DIR, segment_file)

        # Apply a random augmentation
        aug_name = np.random.choice(list(augmentations.keys()))
        augment = augmentations[aug_name]

        waveform, sr = get_waveform_n_sr_from_file(file_path)
        mel_spec = get_spectogram(waveform)
        augmented_segment = augment(waveform, sample_rate=TARGET_SAMPLE_RATE)['samples']
        mel_spec = get_spectogram(augmented_segment)
        mel_spec = (mel_spec - global_mean) / global_std  

        # Check if its a variation
        duplicated = df[
          (df["song_id"] == song_id) & 
          (df["chunk_id"] == sample["chunk_id"]) &
          (df["aug_name"] == aug_name)
        ]

        if duplicated.shape[0] > 0:
          variation_num = duplicated["variation"].max() + 1
        else:
          variation_num = 0
        

        record = save_spectogram(
          mel_spec,
          song_id,
          sample["chunk_id"],
          aug_name,
          variation=variation_num 
        )

        # Append the record to the dataframe
        df = pd.concat([df, pd.DataFrame([record])], ignore_index=True)

    df.to_csv(DATASET_FILE, index=False)
    print(df["song_id"].value_counts())

  else:
    print("The dataset already has augmented data, skipping augmentation.")
    return
  
augment_data()

The dataset already has augmented data, skipping augmentation.


# Dataloaders

In [7]:
class SpectogramDataset(Dataset):
  def __init__(self, spectograms_dir=SPECTOGRAMS_DIR, transform=None):
    self.spectograms_files = os.listdir(SPECTOGRAMS_DIR)
    self.spectograms_dir = spectograms_dir
    self.transform = transform
    dataset = pd.read_csv(DATASET_FILE)
    self.labels = dataset["song_id"]
    self.classes = dataset['song_id'].unique()

  def __len__(self):
    return len(self.spectograms_files)
  
  def __getitem__(self, idx):
    file_name = self.spectograms_files[idx]
    file_path = os.path.join(self.spectograms_dir, file_name)

    mel_spec = np.load(file_path)["data"] 

    if mel_spec.ndim == 4:
      mel_spec = mel_spec.squeeze(0).squeeze(0)
    elif mel_spec.ndim == 3:
      mel_spec = mel_spec.squeeze(0) 
        
    mel_spec = mel_spec[np.newaxis, ...]

    label = int(file_name.split("_")[0])  # Get the song id from the file name 

    if self.transform:
      mel_spec = self.transform(mel_spec) 

    mel_spec = torch.tensor(mel_spec, dtype=torch.float32)
    label = torch.tensor(label, dtype=torch.int64)

    return mel_spec, label

def get_spectograms():
  dataset = SpectogramDataset()

  idx = list(range(len(dataset)))
  train_idx, test_idx = train_test_split(idx, test_size=0.2, random_state=42, stratify=dataset.labels)

  train_subset = Subset(dataset, train_idx)   
  test_subset  = Subset(dataset, test_idx)

  return train_subset, test_subset

# Using a CNN te extract embeddings

In [8]:
class CNN(nn.Module):
  def __init__(self, num_classes=256):
    super(CNN, self).__init__()

    self.conv_block = nn.Sequential(
      nn.Conv2d(1, 32, kernel_size=3, padding=1),
      nn.BatchNorm2d(32),
      nn.ReLU(),
      nn.Dropout2d(0.1),
      nn.MaxPool2d((2, 1)),
      
      nn.Conv2d(32, 64, kernel_size=3, padding=1),
      nn.BatchNorm2d(64),
      nn.ReLU(),
      nn.MaxPool2d((2, 2)),

      nn.Conv2d(64, 128, kernel_size=3, padding=1),
      nn.BatchNorm2d(128),
      nn.ReLU(),
      nn.MaxPool2d((2, 2)),

      nn.Conv2d(128, 256, kernel_size=3, padding=1),
      nn.BatchNorm2d(256),
      nn.ReLU(),

      nn.AdaptiveAvgPool2d((4, 4)),
    )

    self.classifier = nn.Sequential(
      nn.Flatten(),
      nn.Linear(256 * 4 * 4, 1024),
      nn.ReLU(),
      nn.Dropout(0.3),
      nn.Linear(1024, num_classes),
    )

    self.scaler = GradScaler(device)

    self.training_ETA = 0.0
    self.total_epochs = 0
    self.current_epoch = 0
    self.is_model_trained = False
  
  def forward(self, x):
    if x.dim() > 4:
      x = x.squeeze(1)

    x = self.conv_block(x)
    x = self.classifier(x)

    return x

  def fit(self, train_loader, test_loader, epochs=50, learning_rate=0.001, patience=7):
    self.total_epochs = epochs
    training_epoch_duration = []

    loss_function = nn.CrossEntropyLoss()
    optimizer = optim.Adam(self.parameters(), lr=learning_rate, weight_decay=1e-3)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=patience)

    best_val_loss = float('inf')
    patience_counter = 0
    best_model_state = None

    for epoch in range(epochs):
      self.train()
      epoch_loss = 0.0
      correct = 0
      total = 0
      start_time = time.time()
      self.current_epoch = epoch

      i = 0
      for inputs, labels in train_loader:
        inputs = inputs.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        optimizer.zero_grad()
        with autocast(device_type=device):
          outputs = self(inputs)
          loss = loss_function(outputs, labels)
        
        self.scaler.scale(loss).backward()
        self.scaler.step(optimizer)
        self.scaler.update()

        epoch_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()
      
      # Calculate the average loss for the epoch
      train_acc = correct / total
      avg_train_loss = epoch_loss / len(train_loader)

      # Validation
      val_loss = 0.0
      correct = 0
      total = 0
      self.eval()
      with torch.no_grad():
        for inputs, labels in test_loader:
          inputs = inputs.to(device, non_blocking=True)
          labels = labels.to(device, non_blocking=True)

          with autocast(device_type=device):
            outputs = self(inputs)
            loss = loss_function(outputs, labels)

          val_loss += loss.item()
          _, predicted = outputs.max(1)
          total += labels.size(0)
          correct += predicted.eq(labels).sum().item()
      
      val_acc = correct / total
      avg_val_loss = val_loss / len(test_loader)
      scheduler.step(avg_val_loss)

      # Early stopping
      if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        best_model_state = self.state_dict()    
        patience_counter = 0
      else:
        patience_counter += 1
        if patience_counter >= patience:
          print(f"Early stopping at epoch {epoch}/{epochs}")
          break

      # Calculate the training ETA
      elapsed_time = time.time() - start_time
      training_epoch_duration.append(elapsed_time)
      self.training_ETA = (sum(training_epoch_duration) / (self.current_epoch + 1)) * (self.total_epochs - self.current_epoch)

      print(f"Epoch {epoch}: Train Loss {avg_train_loss:.4f}, Val Loss {avg_val_loss:.4f}, Train Acc {train_acc:.4f}, Val Acc {val_acc:.4f}, ETA: {self.training_ETA}")

    if best_model_state:
      self.load_state_dict(best_model_state)
      print(f"Best model loaded from epoch {self.current_epoch}/{self.total_epochs}")
    self.is_model_trained = True

  def predict(self, X):
    self.eval()

    X = torch.as_tensor(X, dtype=torch.float32)

    if isinstance(X, list):
        X = torch.stack(X)
    elif X.ndim == 2:
        X = X.unsqueeze(0)

    X = X.to(device)

    with torch.no_grad(), autocast(device_type=device):
        embeddings = self(X)

    embeddings = embeddings.cpu()

    return embeddings

  def save_trained_model(self):
    # Save the model
    torch.save(self.state_dict(), TRAINED_MODEL_FILE)

    print(f"Model saved to {TRAINED_MODEL_FILE}")
  
  def load_trained_model(self):
    # Load the model
    map_location = torch.device(device)
    self.load_state_dict(torch.load(TRAINED_MODEL_FILE, map_location=map_location))
    self.eval()
    self.is_model_trained = True

def initialize_model():
  train_ds, test_ds = get_spectograms()

  # Create DataLoaders for train and test sets
  train_loader = DataLoader(
    train_ds, 
    batch_size=BATCH_SIZE, 
    shuffle=SHUFFLE, 
    num_workers=NUM_WORKERS, 
    pin_memory=True,
    persistent_workers=True if NUM_WORKERS > 0 else False,
    prefetch_factor=2 if NUM_WORKERS > 0 else None,
  )

  test_loader = DataLoader(
    test_ds, 
    batch_size=BATCH_SIZE, 
    shuffle=False, 
    num_workers=NUM_WORKERS, 
    pin_memory=True,
    persistent_workers=True if NUM_WORKERS > 0 else False,
    prefetch_factor=2 if NUM_WORKERS > 0 else None,
  )

  n_classes = len(list(set(train_loader.dataset.dataset.classes) | set(test_loader.dataset.dataset.classes)))
  print(f"Num of classes: {n_classes}")

  model = CNN(n_classes).to(device)

  if torch.__version__ >= '2.0':
    model = torch.compile(model)

  train_model(model, train_loader, test_loader)

  return model 

def train_model(model, train_loader, test_loader):
  # Check if the model is already trained
  if (os.path.exists(TRAINED_MODEL_FILE)):
    print(f"Model already trained. Loading model...")
    model.load_trained_model()
  else:
    print(f"Model not trained. Training model with {device}...")
    model.fit(train_loader=train_loader, test_loader=test_loader, epochs=50, learning_rate=0.001)
    model.save_trained_model()
    print(f"Model trained and saved.")

model = initialize_model()

Num of classes: 99
Model not trained. Training model with cuda...
Epoch 0: Train Loss 3.0305, Val Loss 2.1618, Train Acc 0.2618, Val Acc 0.4368, ETA: 17912.929260730743
Epoch 1: Train Loss 1.1905, Val Loss 0.8628, Train Acc 0.6820, Val Acc 0.7570, ETA: 9689.028897047043
Epoch 2: Train Loss 0.6246, Val Loss 0.4328, Train Acc 0.8309, Val Acc 0.8829, ETA: 6902.692459106445
Epoch 3: Train Loss 0.3788, Val Loss 0.2548, Train Acc 0.8983, Val Acc 0.9370, ETA: 5488.51508128643
Epoch 4: Train Loss 0.2661, Val Loss 0.1896, Train Acc 0.9294, Val Acc 0.9495, ETA: 4633.679958629608
Epoch 5: Train Loss 0.2079, Val Loss 0.1394, Train Acc 0.9467, Val Acc 0.9680, ETA: 4045.557001233101
Epoch 6: Train Loss 0.1765, Val Loss 0.1269, Train Acc 0.9560, Val Acc 0.9728, ETA: 3616.589928490775
Epoch 7: Train Loss 0.1560, Val Loss 0.1899, Train Acc 0.9605, Val Acc 0.9513, ETA: 3288.9591942429543
Epoch 8: Train Loss 0.1325, Val Loss 0.0921, Train Acc 0.9669, Val Acc 0.9818, ETA: 3025.115770975749
Epoch 9: Train 