<a href="https://colab.research.google.com/github/joris-vaneyghen/mss-jazz-playalong/blob/main/segmentation/segment_and_tag.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install ruptures (imported later in this notebook as `rpt` for breakpoint detection)
!pip install ruptures -q
# Download our audio example (the mss-jazz-playalong repository)
!git clone https://github.com/joris-vaneyghen/mss-jazz-playalong.git
# Download our audio tagger (the EfficientAT repository)
!git clone https://github.com/fschmid56/EfficientAT

In [None]:
# Work from inside the cloned EfficientAT directory: the paths configured later
# (`../mss-jazz-playalong/...`) are relative to it, and the `models` / `helpers`
# imports presumably resolve from here -- NOTE(review): confirm.
%cd EfficientAT/

In [None]:
import torch
from models.dymn.model import get_model as get_dymn
from models.preprocess import AugmentMelSTFT
from helpers.utils import NAME_TO_WIDTH

import librosa
import numpy as np
from torch import autocast
from contextlib import nullcontext
import ruptures as rpt
import matplotlib.pyplot as plt
import json
import os
from sklearn.decomposition import PCA

In [None]:
# Directory that is scanned for the .mp3 files to process
input_path = '../mss-jazz-playalong/examples'
# Directory of the per-track .json files that are loaded and written back with the segments
output_path = '../mss-jazz-playalong/out/demucs'
resolution = 0.32 # resolution of EfficientAT model (seconds of audio per chunk)
min_size = 8 # Minimum length of a segment = 8 chunks (8 x 0.32 = 2.56 seconds)
# Audio sample rate in Hz, used to convert a number of chunks into a number of samples
sample_rate = 32000
max_size = 1024 # number of chunks to tag at once


In [None]:
def load_mel_and_dymn20_as(device):
    """
    Build the pretrained `dymn20_as` tagger together with its mel front end.

    Both modules are moved to `device` and switched to evaluation mode.

    Args:
        device (torch.device): Device that receives the model and the preprocessor
            (e.g. 'cuda' or 'cpu').

    Returns:
        mel (AugmentMelSTFT): Mel spectrogram processor.
        model (torch.nn.Module): Loaded model.
    """
    pretrained = 'dymn20_as'

    # Tagger network: the width multiplier is looked up from the pretrained model name
    model = get_dymn(
        width_mult=NAME_TO_WIDTH(pretrained),
        pretrained_name=pretrained,
        strides=[2, 2, 2, 2],
    )
    model.to(device)
    model.eval()

    # Preprocessor: 128 mel bands, 32 kHz audio, 800-sample window, 320-sample hop
    mel = AugmentMelSTFT(n_mels=128, sr=32000, win_length=800, hopsize=320)
    mel.to(device)
    mel.eval()

    return mel, model

def preds_over_time(mel, model, waveform, device):
  """
  Tag a waveform piece by piece and return the frame-wise outputs.

  The waveform is cut into pieces of at most `max_size * sample_rate * resolution`
  samples (module-level settings). Each piece goes through `mel`,
  `model._feature_forward` and `model._clf_forward`; the per-piece results are then
  concatenated along their first (time) axis.

  Args:
      mel: Callable turning a (C, L) waveform tensor into a spectrogram.
      model: Tagger exposing `_feature_forward` and `_clf_forward`.
      waveform (np.ndarray): Audio samples, shape (C, L), or (L,) for mono audio.
      device (torch.device or str): Device on which inference runs.

  Returns:
      tuple: `(all_preds, all_embeds, all_features)` as numpy arrays. `all_preds`
      holds the sigmoid of the classifier output, one row per time frame;
      `all_embeds` and `all_features` are returned as float32 regardless of whether
      mixed precision was used.

  Raises:
      ValueError: If `waveform` contains no samples.
  """
  # A mono file loaded with mono=False comes back 1-D; promote it to (C=1, L)
  waveform = torch.from_numpy(np.atleast_2d(waveform)).to(device)  # shape = (C, L)
  num_samples = waveform.shape[1]
  if num_samples == 0:
    raise ValueError("waveform contains no samples")

  # A torch.device never compares equal to the string 'cuda', and autocast needs the
  # device type as a string, so normalise once here (accepts str and torch.device).
  device_type = torch.device(device).type
  use_amp = device_type == 'cuda'

  all_preds = []
  all_embeds = []
  all_features = []

  max_input_length = int(max_size * sample_rate * resolution)

  # Process waveform in pieces of max_input_length samples
  for start_idx in range(0, num_samples, max_input_length):
    end_idx = min(start_idx + max_input_length, num_samples)
    waveform_segment = waveform[:, start_idx:end_idx]

    amp_context = autocast(device_type=device_type) if use_amp else nullcontext()
    with torch.no_grad(), amp_context:
      # Mel spectrogram of the current piece, one spectrogram per audio channel
      spec = mel(waveform_segment)  # shape = (C, F, T)
      spec = spec.unsqueeze(1)  # shape = (N=C, 1, F, T)
      features = model._feature_forward(spec)  # shape = (N, D, F', T')

      # Swap the time and batch axes so that the classifier's pooling runs over the
      # frequency and channel axes and yields one prediction per time frame
      features = features.permute(3, 1, 2, 0)  # shape = (T', D, F', N)
      preds, embed = model._clf_forward(features)
      preds = torch.sigmoid(preds.float()).squeeze()

    # A piece that yields a single frame loses its time axis in squeeze(); restore
    # it so every piece can be concatenated along axis 0
    if preds.dim() == 1:
      preds = preds.unsqueeze(0)
    if embed.dim() == 1:
      embed = embed.unsqueeze(0)

    # Store float32 so downstream code sees the same dtype with or without autocast
    all_preds.append(preds.cpu().numpy())
    all_embeds.append(embed.float().cpu().numpy())
    all_features.append(features.float().cpu().numpy())

  # Concatenate the results from all pieces over the time axis
  all_preds = np.concatenate(all_preds, axis=0)
  all_embeds = np.concatenate(all_embeds, axis=0)
  all_features = np.concatenate(all_features, axis=0)

  return all_preds, all_embeds, all_features



def features_to_preds(model, features, device):
  """
  Classify the average of a run of time-major feature frames.

  Args:
      model: Tagger exposing `_clf_forward`.
      features (np.ndarray): 4-D array whose first axis is time (the layout
          produced by `permute(3, 1, 2, 0)` on the feature-extractor output).
      device (torch.device or str): Device on which inference runs.

  Returns:
      tuple: `(preds, embed)` as numpy arrays with singleton axes squeezed out;
      `preds` is the sigmoid of the classifier output.
  """
  features = torch.from_numpy(features).to(device)
  # Undo the time-major layout: the last axis becomes time again, the first the batch
  features = features.permute(3, 1, 2, 0)
  # Average over the batch axis so a single item is classified for the whole run
  features = torch.mean(features, dim=0, keepdim=True)

  # A torch.device never compares equal to the string 'cuda', and autocast needs the
  # device type as a string, so normalise it (accepts str and torch.device).
  device_type = torch.device(device).type
  amp_context = autocast(device_type=device_type) if device_type == 'cuda' else nullcontext()

  with torch.no_grad(), amp_context:
    preds, embed = model._clf_forward(features)
    preds = torch.sigmoid(preds.float()).squeeze().cpu().numpy()
  return preds, embed.squeeze().cpu().numpy()

def load_json(dir, mp3_file):
    """
    Load the .json companion of an mp3 file.

    Args:
        dir (str): Directory that holds the .json file.
        mp3_file (str): Name of the mp3 file; every '.mp3' in it is replaced by '.json'.

    Returns:
        The parsed JSON content, or an empty dict when the .json file does not exist.
    """
    json_path = os.path.join(dir, mp3_file.replace('.mp3', '.json'))

    if os.path.exists(json_path):
        with open(json_path, 'r') as handle:
            return json.load(handle)

    # No companion file: callers receive an empty dictionary
    return {}

def save_json(dir, mp3_file, data):
    """
    Write `data` to the .json companion of an mp3 file.

    Args:
        dir (str): Target directory; created (including parents) when missing.
        mp3_file (str): Name of the mp3 file; every '.mp3' in it is replaced by '.json'.
        data: JSON-serializable object to store (written with indent=4).

    Raises:
        TypeError: If `data` is not JSON-serializable. An existing file is left
            untouched in that case.
    """
    json_file_name = mp3_file.replace('.mp3', '.json')
    file_path = os.path.join(dir, json_file_name)

    # Serialize before opening the file: opening with 'w' truncates it, so a
    # serialization error afterwards would leave a corrupt, half-written file behind
    payload = json.dumps(data, indent=4)

    # exist_ok avoids the check-then-create race of a separate exists() test
    os.makedirs(dir, exist_ok=True)

    with open(file_path, 'w') as file:
        file.write(payload)

def iterate_files(dir):
    """Lazily yield the name of every entry of `dir` whose name ends with '.mp3'."""
    yield from (entry for entry in os.listdir(dir) if entry.endswith('.mp3'))

def calculate_bkpts_from_demucs(demucs):
  """
  Detect breakpoints in the demucs stem signals with ruptures' Pelt search.

  Args:
      demucs (dict): Must provide 'drums', 'bass', 'vocals' and 'other' entries of
          identical shape; they are stacked side by side into one signal.

  Returns:
      The breakpoints returned by `Pelt.predict`.
  """
  stem_names = ('drums', 'bass', 'vocals', 'other')
  stacked_signal = np.stack([np.array(demucs[name]) for name in stem_names], axis=1)

  # Pelt search with the "normal" cost; segments span at least `min_size` samples
  detector = rpt.Pelt(model="normal", min_size=min_size, jump=1)
  detector.fit(stacked_signal)

  # The number of breakpoints is not fixed in advance: the penalty decides how
  # readily a breakpoint is accepted (a value to experiment with)
  return detector.predict(pen=100)


def calculate_bpts_from_tagger(preds, bkps_demucs):
  """
  Refine the demucs segmentation using the tagger predictions.

  The predictions are reduced with PCA, then every demucs segment is searched on its
  own for additional breakpoints (Pelt, "rbf" cost).

  Args:
      preds (np.ndarray): 2-D array of tagger predictions, one row per time frame.
      bkps_demucs (list): End index of each demucs segment, in increasing order.

  Returns:
      list: Breakpoints (segment end indices) expressed as row indices of `preds`.
      Every non-empty demucs segment contributes at least its own end.
  """
  # PCA cannot extract more components than min(n_frames, n_classes); the fixed seed
  # keeps the (possibly randomized) SVD solver, and so the segmentation, reproducible
  n_components = min(20, *np.shape(preds))
  reduced_data = PCA(n_components=n_components, random_state=0).fit_transform(preds)

  new_bkps = []
  start = 0
  for end in bkps_demucs:
    signal = reduced_data[start:end]
    n_frames = len(signal)

    if n_frames == 0:
      # The demucs signal extends past the available predictions: nothing to segment
      start = end
      continue

    if n_frames < 2 * min_size:
      # Too short to be split into two parts of at least `min_size` frames each;
      # ruptures also rejects signals shorter than `min_size`. Keep it as one segment.
      new_bkps.append(start + n_frames)
    else:
      # Pelt search with the "rbf" cost; the number of breakpoints is not fixed in
      # advance, the penalty decides how readily one is accepted (a value to tune)
      algo = rpt.Pelt(model="rbf", min_size=min_size, jump=1).fit(signal)
      sub_bkps = algo.predict(pen=3)
      # Breakpoints are relative to the slice; shift them back to global indices
      new_bkps.extend(bkp + start for bkp in sub_bkps)

    start = end

  return new_bkps


def avg_preds_per_segemnt(model, bkps, features, demucs, device=None):
  """
  Build one summary dict per segment.

  Args:
      model: Tagger handed to `features_to_preds`.
      bkps (list): End index of each segment, in increasing order.
      features (np.ndarray): Time-major features; sliced as `features[start:end]`.
      demucs (dict): Per-key sequences indexed like `features`; each is averaged
          over the segment.
      device (torch.device or str, optional): Inference device. Defaults to the
          device of the model's parameters, so the function does not depend on a
          notebook-level `device` variable.

  Returns:
      list[dict]: One dict per segment with 'start_idx', 'end_idx', 'preds' and one
      averaged value per key of `demucs`. All values are plain Python numbers/lists
      so the result can be passed to `json.dump`. The embedding returned by
      `features_to_preds` is not stored.
  """
  if device is None:
    device = next(model.parameters()).device

  segments = []
  start = 0
  for end in bkps:
    preds_i, _ = features_to_preds(model, features[start:end], device)
    segment = {
        'start_idx': int(start),
        'end_idx': int(end),
        'preds': preds_i.tolist(),
    }
    for key in demucs.keys():
      # tolist() turns numpy scalars and arrays alike into JSON-serializable values
      segment[key] = np.array(demucs[key][start:end]).mean(axis=0).tolist()
    segments.append(segment)
    start = end

  return segments



In [None]:
# Use the GPU when one is available and load the tagger onto that device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
mel, model = load_mel_and_dymn20_as(device)

for mp3_file in iterate_files(input_path):
  data = load_json(output_path, mp3_file)

  # Only handle tracks that already carry demucs data and were not segmented before
  if 'demucs' not in data or 'segments' in data:
    continue

  print(mp3_file)
  audio_path = os.path.join(input_path, mp3_file)
  waveform, _ = librosa.core.load(audio_path, sr=32000, mono=False)

  # Coarse breakpoints from the demucs signals, refined with the tagger predictions
  bkpts_demucs = calculate_bkpts_from_demucs(data['demucs'])
  preds, embed, features = preds_over_time(mel, model, waveform, device)
  bkpts_tagger = calculate_bpts_from_tagger(preds, bkpts_demucs)

  # Summarize every segment and store the result next to the demucs data
  segments = avg_preds_per_segemnt(model, bkpts_tagger, features, data['demucs'])
  data['segments'] = segments

  save_json(output_path, mp3_file, data)