<a href="https://colab.research.google.com/github/joris-vaneyghen/mss-jazz-playalong/blob/main/explore_audio_segmentation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Audio Segmentation & Musical Instrument Tagger

### Objective:
The goal of this project is to segment an audio file into distinct sections and tag each section with the instruments being played.

### Requirements:
- Each segment should be at least **2.5 seconds** in length.
- Consecutive segments should feature **different sets of instruments**.
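
Once segments carry instrument tags, both requirements can be checked mechanically. The sketch below is only an illustration: the `Segment` tuple, the `merge_segments` and `too_short` helpers, and the example tags are hypothetical and not part of this notebook. It merges consecutive segments that share the same instrument set and lists the segments that are shorter than the minimum length.

```python
from typing import FrozenSet, List, NamedTuple

MIN_SEGMENT_SECONDS = 2.5  # minimum length from the requirements above


class Segment(NamedTuple):
    start: float  # start time in seconds
    end: float  # end time in seconds
    instruments: FrozenSet[str]  # hypothetical instrument tags


def merge_segments(segments: List[Segment]) -> List[Segment]:
    """Merge consecutive segments that share the same instrument set."""
    merged: List[Segment] = []
    for seg in segments:
        if merged and merged[-1].instruments == seg.instruments:
            # Same instruments as the previous segment: extend it
            merged[-1] = merged[-1]._replace(end=seg.end)
        else:
            merged.append(seg)
    return merged


def too_short(segments: List[Segment]) -> List[Segment]:
    """Return the segments that violate the minimum length."""
    return [s for s in segments if s.end - s.start < MIN_SEGMENT_SECONDS]


example = [
    Segment(0.0, 4.0, frozenset({"Piano"})),
    Segment(4.0, 9.0, frozenset({"Piano"})),
    Segment(9.0, 11.0, frozenset({"Piano", "Saxophone"})),
]
merged = merge_segments(example)
short = too_short(merged)
```

Here the first two segments are merged because they carry the same tags, and the last one is reported as too short.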

In [None]:
!pip install demucs -q

In [None]:
!pip install ruptures -q

In [None]:
# download our audio example
!git clone https://github.com/joris-vaneyghen/mss-jazz-playalong.git

In [None]:
# download our audio tagger
!git clone https://github.com/fschmid56/EfficientAT

In [None]:
# Let's listen to our audio example

from IPython.display import Audio

Audio('mss-jazz-playalong/examples/Sweet Dreams_Single Ladies.mp3')


In [None]:
# Use the audio tagger to detect the acoustic events in our audio example. This prints the top 10 detected events (set the runtime type to GPU for a faster run).
#!cd EfficientAT && python inference.py --cuda --model_name=dymn20_as --audio_path="../mss-jazz-playalong/examples/Jazz Standards Medley.mp3"

### Instrument Detection Limitations:

Our audio tagger successfully detects instruments such as **Singing, Saxophone, Trombone, and Trumpet**, but tends to ignore **drums** and **double bass**. This limitation arises because the tagger was trained on the **AudioSet** dataset, which is **weakly labeled**. In this dataset, **drums** and **bass** were often overlooked in the labels, which reduces detection accuracy for these instruments.


### Segmentation Approach:

To segment the audio, we convert the waveform into a **multi-dimensional time series** of sound class detections using our audio tagger. This results in a time series with **527 dimensions**, each corresponding to one of the sound classes detected by the tagger.

To detect change points in this time series, we use the **Ruptures** library, which is well suited to this task for several reasons:

- **Versatility**: Ruptures can handle a wide range of data types and is adaptable to different segmentation problems, making it ideal for complex multi-dimensional audio data.
- **Efficiency**: It is optimized for large datasets, allowing fast and accurate detection of change points, even when dealing with high-dimensional time series.
- **Customizability**: Ruptures offers a variety of methods (e.g., dynamic programming, window-based detection) that can be tailored to our specific needs, ensuring robust and reliable segmentation.

By using Ruptures, we can effectively identify moments where the instrument set or sound profile changes, leading to precise audio segmentation.
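
To make the role of the penalty and the minimum segment size concrete, here is a minimal pure-Python sketch of penalized change-point detection on a one-dimensional signal. It is not Ruptures' implementation: it uses a plain quadratic-time optimal-partitioning search with a least-squares cost, without the pruning that PELT adds, and all function names are hypothetical. Like Ruptures, it returns the end index of every segment, so the last value equals the signal length.

```python
def l2_cost_factory(signal):
    """Return cost(s, t): sum of squared deviations from the mean of signal[s:t]."""
    prefix, prefix_sq = [0.0], [0.0]
    for x in signal:
        prefix.append(prefix[-1] + x)
        prefix_sq.append(prefix_sq[-1] + x * x)

    def cost(s, t):
        total = prefix[t] - prefix[s]
        total_sq = prefix_sq[t] - prefix_sq[s]
        return total_sq - total * total / (t - s)

    return cost


def penalized_breakpoints(signal, penalty, min_size=1):
    """Minimize (sum of segment costs + penalty per segment) by dynamic programming."""
    n = len(signal)
    cost = l2_cost_factory(signal)
    inf = float("inf")
    best = [inf] * (n + 1)  # best[t]: minimal penalized cost of segmenting signal[:t]
    best[0] = 0.0
    prev = [0] * (n + 1)  # prev[t]: start index of the last segment ending at t
    for t in range(min_size, n + 1):
        for s in range(0, t - min_size + 1):
            candidate = best[s] + cost(s, t) + penalty
            if candidate < best[t]:
                best[t], prev[t] = candidate, s
    # Walk back from the end to recover the segment boundaries
    bkps, t = [], n
    while t > 0:
        bkps.append(t)
        t = prev[t]
    return bkps[::-1]


# Toy signal with three flat sections of 10 samples each
toy_signal = [0.0] * 10 + [5.0] * 10 + [0.0] * 10
toy_bkps = penalized_breakpoints(toy_signal, penalty=1.0, min_size=2)
```

A higher penalty makes every additional segment more expensive and therefore yields fewer change points; `min_size` rules out segments shorter than the given number of time steps.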


In [None]:
%cd EfficientAT/

Inspired by `EfficientAT/inference.py`, we load the audio tagger model.

In [None]:
import torch
from models.dymn.model import get_model as get_dymn
from models.preprocess import AugmentMelSTFT
from helpers.utils import NAME_TO_WIDTH


def load_mel_and_dymn20_as(device):
    """
    Load the model and mel spectrogram processor for audio tagging.

    Args:
        device (torch.device): The device to load the model onto (e.g., 'cuda' or 'cpu').

    Returns:
        mel (AugmentMelSTFT): Mel spectrogram processor.
        model (torch.nn.Module): Loaded model.
    """
    sample_rate=32000
    window_size=800
    hop_size=320
    n_mels=128
    strides=[2, 2, 2, 2]
    model_name = 'dymn20_as'

    model = get_dymn(width_mult=NAME_TO_WIDTH(model_name), pretrained_name=model_name, strides=strides)

    # Send model to the specified device
    model.to(device)
    model.eval()

    # Create a mel spectrogram processor (preprocessor)
    mel = AugmentMelSTFT(n_mels=n_mels, sr=sample_rate, win_length=window_size, hopsize=hop_size)
    mel.to(device)
    mel.eval()

    return mel, model


### Customizing Model Output: Retaining the Time Dimension

In our customized model, we aim to keep the **Time** dimension while processing the waveform:

1. **Stereo Channels as Batch**:  
   In the original implementation, a mono waveform is processed. In our case, we use a stereo waveform and stack the left and right channels along the batch dimension.

2. **Time-Frequency Domain Conversion**:  
   After the input waveform is converted into the Time-Frequency domain, the model compresses both the **Time** and **Frequency** dimensions by a factor of 32.

3. **Pooling Operation Before MLP Layers**:  
   Before the final MLP layers, the model performs an **Average Pooling** operation. However, instead of averaging over the Time and Frequency dimensions, we choose to:
   - Retain the **Time** dimension.
   - Perform the averaging across the **Batch** and **Frequency** dimensions.

This approach ensures that the model's output can be used as a multi-dimensional time series.
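
The effect of the changed pooling can be illustrated with plain NumPy on a small random array. This is only a sketch with made-up shapes and hypothetical variable names; it does not call the tagger itself. It contrasts the usual pooling over Frequency and Time with pooling over Frequency and Batch after the Time axis has been moved to the front.

```python
import numpy as np

rng = np.random.default_rng(0)
N, D, F, T = 2, 4, 3, 5  # batch (stereo channels), feature maps, frequency, time
features = rng.normal(size=(N, D, F, T))

# Usual pooling: average over Frequency and Time, one vector per batch item
standard = features.mean(axis=(2, 3))  # shape = (N, D)

# Time-preserving pooling: move Time to the front, then average over
# Frequency and Batch, which gives one vector per time step
time_first = features.transpose(3, 1, 2, 0)  # shape = (T, D, F, N)
pooled = time_first.mean(axis=(2, 3))  # shape = (T, D)
```

Each row of `pooled` then describes one time step, which is what allows the output to be treated as a time series.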


In [None]:
import librosa
import numpy as np
from torch import autocast
from contextlib import nullcontext

def preds_over_time(mel, model, waveform, device):
  """Run the tagger on a stereo waveform and return per-time-step predictions, embeddings and features."""
  waveform = torch.from_numpy(waveform).to(device) # shape = (C=2, L)
  with torch.no_grad(), autocast(device_type=device) if device == 'cuda' else nullcontext():
    spec = mel(waveform) # shape = (C, F=128, T=L/320)
    x = spec.unsqueeze(1) # shape = (N=C, D=1, F, T)
    features = model._feature_forward(x) # shape = (N, D=1920, F'=F/32, T'≃T/32)
    # Swap the Time and Batch dimensions so that pooling is done over the Frequency and Batch dimensions
    features = features.permute(3, 1, 2, 0) # shape = (T', D, F', N)
    preds, embed = model._clf_forward(features)
    preds = torch.sigmoid(preds.float()).squeeze().cpu().numpy() # shape = (T', D'=527)
  return preds, embed.cpu().numpy(), features.cpu().numpy()


In [None]:
# Let's test this on our example
audio_path = '../mss-jazz-playalong/examples/Sweet Dreams_Single Ladies.mp3'
device = 'cuda' if torch.cuda.is_available() else 'cpu'
mel, model = load_mel_and_dymn20_as(device)
(waveform, _) = librosa.core.load(audio_path, sr=32000, mono=False)

preds, embed, features = preds_over_time(mel, model, waveform, device)

In [None]:
# The time dimension is reduced by a factor of 10240 (hop_size * compression factor = 320 * 32)
import math
assert math.ceil(waveform.shape[1] / (320 * 32)) == preds.shape[0]
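
The same factor converts between time steps of the prediction series and seconds of audio. The helper names below are hypothetical; the constants are the ones used above (32 kHz sample rate, hop size 320, compression factor 32). A minimal sketch:

```python
import math

SAMPLE_RATE = 32000
HOP_SIZE = 320
COMPRESSION = 32
SAMPLES_PER_STEP = HOP_SIZE * COMPRESSION  # 10240 samples per time step


def steps_to_seconds(steps: int) -> float:
    """Duration in seconds covered by a number of time steps."""
    return steps * SAMPLES_PER_STEP / SAMPLE_RATE


def min_steps_for(seconds: float) -> int:
    """Smallest number of time steps that covers at least `seconds`."""
    return math.ceil(seconds * SAMPLE_RATE / SAMPLES_PER_STEP)


def num_steps(num_samples: int) -> int:
    """Length of the prediction series for a waveform of `num_samples` samples."""
    return math.ceil(num_samples / SAMPLES_PER_STEP)
```

One time step covers 0.32 seconds, so the 2.5-second requirement translates to a minimum of 8 time steps (2.56 seconds).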

In [None]:
import matplotlib.pyplot as plt

plt.plot(preds[:, 27])
plt.xlabel("Time")
plt.ylabel("Probability of Singing")
plt.title("Probability of Singing over Time")
plt.show()

plt.plot(preds[:, 153])
plt.xlabel("Time")
plt.ylabel("Probability of Piano")
plt.title("Probability of Piano over Time")
plt.show()

plt.plot(preds[:, 197])
plt.xlabel("Time")
plt.ylabel("Probability of Sax")
plt.title("Probability of Sax over Time")
plt.show()


In [None]:
import ruptures as rpt

# signal = preds[:, [27, 153, 197]]
signal = embed

# Set the minimum segment length to 8 time steps (= 2.56 seconds, i.e. 8 * (32 * 320) / 32000)
min_size = 8

# Use the PELT method for change-point detection
# (named cost_model so that it does not shadow the audio tagger `model`)
cost_model = "rbf"  # Kernel-based cost: detects changes in distribution (alternatives include "l1", "l2", "normal")
algo = rpt.Pelt(model=cost_model, min_size=min_size, jump=1).fit(signal)


# Detect change points without specifying their number in advance
penalty = 4  # The penalty controls how readily change points are accepted; feel free to tune it
bkps = algo.predict(pen=penalty)

# Plot the result
rpt.display(preds[:, 153], bkps, figsize=(10, 6))  # Plot only the Piano dimension
plt.title("Change-point detection in the multi-dimensional signal")
plt.show()


# Print the detected change points
print("# change points:", len(bkps))
print("Detected change points:", bkps)


In [None]:
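# Reload the audio as mono at librosa's default sample rate for playback
(waveform, sr) = librosa.core.load(audio_path, mono=True)
# Number of playback samples per time step of the prediction series
f = (10240 * sr) // 32000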

In [None]:
display(Audio(waveform[0: bkps[0] * f], rate=sr))
for i in range(len(bkps)-1):
  display(Audio(waveform[bkps[i] * f: bkps[i+1] * f], rate=sr))
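
The slicing above can be written as a small helper that turns change points into sample ranges. This is a sketch with hypothetical function names; it assumes, as in the cell above, that the change points are end indices of segments in time steps and that one time step corresponds to 10240 samples at 32 kHz.

```python
def samples_per_step(sr: int, model_sr: int = 32000, step: int = 10240) -> int:
    """Number of playback samples per time step at sample rate `sr`."""
    return (step * sr) // model_sr


def segment_sample_ranges(bkps, f):
    """Convert segment end indices (in time steps) to (start, end) sample ranges."""
    starts = [0] + list(bkps[:-1])
    return [(start * f, end * f) for start, end in zip(starts, bkps)]


# Example with librosa's default sample rate of 22050 Hz and made-up change points
example_f = samples_per_step(22050)
example_ranges = segment_sample_ranges([8, 20, 31], example_f)
```

Each `(start, end)` pair can then be used to slice the waveform for playback.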

In [None]:
!demucs -n htdemucs_ft "../mss-jazz-playalong/examples/Sweet Dreams_Single Ladies.mp3" -o out

In [None]:
Audio("out/htdemucs_ft/Sweet Dreams_Single Ladies/bass.wav")

In [None]:
# Given a numpy.ndarray waveform, calculate the average of the absolute values per block of 10240 samples

import numpy as np

def calculate_average_db_per_block(waveform, block_size=10240):
    """
    Calculates the mean absolute amplitude per block of a given size in a waveform.

    Note: despite its name, this function returns linear amplitudes, not decibels.

    Args:
      waveform: A numpy.ndarray representing the waveform.
      block_size: The size of the block for calculating the average.

    Returns:
      A numpy.ndarray with the mean absolute amplitude of each block.
    """
    averages = []

    for i in range(0, len(waveform), block_size):
        block = waveform[i:i + block_size]
        if len(block) > 0:
            # A decibel variant would average 20 * np.log10(np.abs(block) + 1e-10) instead
            averages.append(np.mean(np.abs(block)))

    return np.array(averages)


# Example usage (assuming you have the 'waveform' variable defined):
# averages_per_block = calculate_average_db_per_block(waveform)
# print(averages_per_block)
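
If a decibel scale is preferred over linear amplitudes, the per-block averages can be converted afterwards. The helper below is a hypothetical sketch, not part of the notebook; the small `epsilon` is an assumption that keeps the logarithm defined for silent blocks.

```python
import math


def amplitude_to_db(amplitude: float, epsilon: float = 1e-10) -> float:
    """Convert a linear amplitude to decibels relative to full scale (1.0)."""
    return 20.0 * math.log10(abs(amplitude) + epsilon)


full_scale_db = amplitude_to_db(1.0)  # close to 0 dB
tenth_db = amplitude_to_db(0.1)  # close to -20 dB
silence_db = amplitude_to_db(0.0)  # floor set by epsilon, close to -200 dB
```

A logarithmic scale compresses the difference between loud and quiet passages, which may change where a change-point detector places its boundaries.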


In [None]:
(bass, sr) = librosa.core.load("out/htdemucs_ft/Sweet Dreams_Single Ladies/bass.wav", mono=True)
(drums, sr) = librosa.core.load("out/htdemucs_ft/Sweet Dreams_Single Ladies/drums.wav", mono=True)
(other, sr) = librosa.core.load("out/htdemucs_ft/Sweet Dreams_Single Ladies/other.wav", mono=True)
(vocals, sr) = librosa.core.load("out/htdemucs_ft/Sweet Dreams_Single Ladies/vocals.wav", mono=True)

bass = calculate_average_db_per_block(bass, block_size = 10240 * sr//32000)
other = calculate_average_db_per_block(other, block_size = 10240 * sr//32000)
drums = calculate_average_db_per_block(drums, block_size = 10240 * sr//32000)
vocals = calculate_average_db_per_block(vocals, block_size = 10240 * sr//32000)
stacked_signal = np.stack((drums, bass, vocals, other), axis=1)


In [None]:
# Append the Trumpet probability (AudioSet class index 187) to the per-stem loudness signal

combined_signal = np.concatenate((stacked_signal, preds[:, [187]]), axis=1)
combined_signal.shape

In [None]:
plt.plot(bass)
plt.xlabel("Time")
plt.ylabel("Volume of bass")
plt.title("Volume of bass over Time")
plt.show()

plt.plot(other)
plt.xlabel("Time")
plt.ylabel("Volume of other")
plt.title("Volume of other over Time")
plt.show()

plt.plot(drums)
plt.xlabel("Time")
plt.ylabel("Volume of drums")
plt.title("Volume of drums over Time")
plt.show()

plt.plot(vocals)
plt.xlabel("Time")
plt.ylabel("Volume of vocals")
plt.title("Volume of vocals over Time")
plt.show()


In [None]:
import ruptures as rpt


# Set the minimum segment length to 8 time steps (= 2.56 seconds, i.e. 8 * (32 * 320) / 32000)
min_size = 8

# Use the PELT method for change-point detection
# (named cost_model so that it does not shadow the audio tagger `model`)
cost_model = "normal"  # Gaussian cost: detects changes in mean and covariance (alternatives include "l1", "l2", "rbf")
algo = rpt.Pelt(model=cost_model, min_size=min_size, jump=1).fit(stacked_signal)
# algo = rpt.Dynp(model=cost_model, min_size=min_size, jump=1).fit(combined_signal)
# algo = rpt.Window(width=min_size, model=cost_model).fit(stacked_signal)

# Detect change points without specifying their number in advance
penalty = 100  # The penalty controls how readily change points are accepted; feel free to tune it
bkps = algo.predict(pen=penalty)

# Plot the change points on top of each stem (same column order as in stacked_signal)
for idx, stem in enumerate(["drums", "bass", "vocals", "other"]):
    rpt.display(stacked_signal[:, idx], bkps, figsize=(10, 6))
    plt.title(f"Change-point detection in the multi-dimensional signal ({stem})")
    plt.show()


# Print the detected change points
print("# change points:", len(bkps))
print("Detected change points:", bkps)


In [None]:
(waveform, sr) = librosa.core.load(audio_path, mono=True)
f = (10240 * sr) // 32000

display(Audio(waveform[0: bkps[0] * f], rate=sr))
for i in range(len(bkps)-1):
  display(Audio(waveform[bkps[i] * f: bkps[i+1] * f], rate=sr))

In [None]:
from sklearn.decomposition import PCA
pca = PCA(n_components=20)  # Reduce to 20 dimensions
reduced_data = pca.fit_transform(preds)

In [None]:
start = 0
new_bkps = []
parent = []
for i in range(len(bkps)):
  end = bkps[i]
  signal = reduced_data[start:end]
  # Set the minimum segment length to 8 time steps (= 2.56 seconds, i.e. 8 * (32 * 320) / 32000)
  min_size = 8
  # Use the PELT method for change-point detection
  # (named cost_model so that it does not shadow the audio tagger `model`)
  cost_model = "rbf"  # Kernel-based cost: detects changes in distribution (alternatives include "l1", "l2", "normal")
  algo = rpt.Pelt(model=cost_model, min_size=min_size, jump=1).fit(signal)
  # Detect change points without specifying their number in advance
  penalty = 3  # The penalty controls how readily change points are accepted; feel free to tune it
  sub_bkps = algo.predict(pen=penalty)
  parent.extend([i] * len(sub_bkps))  # Index of the coarse segment each sub-segment belongs to
  new_bkps.extend([bkp + start for bkp in sub_bkps])  # Shift local indices to global time steps
  start  = end


print(bkps)
print(new_bkps)
print(parent)
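
The bookkeeping of this two-level segmentation can be isolated from the detector. In the sketch below, `refine_breakpoints` and the toy detector `split_in_half` are hypothetical names; the toy detector merely stands in for the Pelt call and returns local end indices in the same convention (the last value equals the segment length).

```python
def refine_breakpoints(coarse_bkps, detect):
    """Split every coarse segment with `detect` and return global change points.

    `detect(start, end)` must return local segment end indices for the slice
    [start, end), with the last value equal to end - start.
    """
    new_bkps, parent = [], []
    start = 0
    for i, end in enumerate(coarse_bkps):
        sub_bkps = detect(start, end)
        new_bkps.extend(start + b for b in sub_bkps)  # local -> global indices
        parent.extend([i] * len(sub_bkps))  # remember the coarse segment
        start = end
    return new_bkps, parent


def split_in_half(start, end):
    """Toy stand-in for a detector: cut each segment in the middle."""
    length = end - start
    half = length // 2
    return [half, length] if half > 0 else [length]


fine_bkps, fine_parent = refine_breakpoints([10, 16], split_in_half)
```

Because every local result ends at the segment length, the coarse change points are always preserved in the refined list.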

In [None]:
display(Audio(waveform[0: new_bkps[0] * f], rate=sr))
for i in range(len(new_bkps)-1):
  print(parent[i + 1])
  print(new_bkps[i])
  display(Audio(waveform[new_bkps[i] * f: new_bkps[i+1] * f], rate=sr))

In [None]:
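def features_to_preds(model, features, device):
  """Classify one segment; `features` is a time slice of shape (T_seg, D, F', N) as returned by preds_over_time."""
  features = torch.from_numpy(features).to(device)
  features = features.permute(3, 1, 2, 0) # back to shape = (N, D, F', T_seg)
  features = torch.mean(features, dim=0, keepdim=True) # average the stereo channels, shape = (1, D, F', T_seg)
  with torch.no_grad(), autocast(device_type=device) if device == 'cuda' else nullcontext():
    preds, embed = model._clf_forward(features) # pools over F' and T_seg
    preds = torch.sigmoid(preds.float()).squeeze().cpu().numpy() # shape = (D'=527,)
  return preds, embed.squeeze().cpu().numpy()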

In [None]:
# Reload the audio tagger to make sure `mel` and `model` refer to the network
mel, model = load_mel_and_dymn20_as(device)


In [None]:
start = 0
preds_per_segm = []
embed_per_segm = []
demucs_per_segm = []
for i in range(len(new_bkps)):
  end = new_bkps[i]
  preds_i, embed_i = features_to_preds(model, features[start: end], device)
  preds_per_segm.append(preds_i)
  embed_per_segm.append(embed_i)
  demucs_per_segm.append(stacked_signal[start: end].mean(axis=0))
  start = end

preds_per_segm = np.vstack(preds_per_segm)
embed_per_segm = np.vstack(embed_per_segm)
demucs_per_segm= np.vstack(demucs_per_segm)

demucs_per_segm.shape
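
The per-segment averaging used for the Demucs stems can be sketched without NumPy. The function name and the toy data are hypothetical; the change points follow the same convention as above (end index of each segment).

```python
def per_segment_mean(rows, bkps):
    """Average each column of `rows` within the segments delimited by `bkps`."""
    means = []
    start = 0
    for end in bkps:
        segment = rows[start:end]
        n_cols = len(segment[0])
        means.append(
            [sum(row[c] for row in segment) / len(segment) for c in range(n_cols)]
        )
        start = end
    return means


# Four time steps with two features each, split into two segments
toy_rows = [[1.0, 10.0], [3.0, 30.0], [5.0, 50.0], [7.0, 70.0]]
toy_means = per_segment_mean(toy_rows, [2, 4])
```

The result has one row per segment, which is the shape the clustering step further down expects.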

In [None]:
# Per-segment probability of Double bass (AudioSet class index 194)
preds_per_segm[:,194]
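
Per-segment probabilities can be turned into instrument tags by thresholding. The sketch below is an assumption-laden illustration: the class indices are the ones plotted earlier in this notebook, while the `tag_segment` helper and the threshold of 0.5 are hypothetical and would need tuning on real data.

```python
# Class indices as used in the plots above; extend as needed
INSTRUMENT_CLASSES = {27: "Singing", 153: "Piano", 197: "Saxophone"}


def tag_segment(probs, threshold=0.5, classes=INSTRUMENT_CLASSES):
    """Return the names of all classes whose probability reaches the threshold."""
    return {name for idx, name in classes.items() if probs[idx] >= threshold}


# Made-up probability vector with 527 classes
toy_probs = [0.0] * 527
toy_probs[27] = 0.9
toy_probs[153] = 0.2
toy_probs[197] = 0.6
toy_tags = tag_segment(toy_probs)
```

Applied to every row of `preds_per_segm`, this would give one tag set per segment, which is the form needed to compare the instrument sets of consecutive segments.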

In [None]:
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score

# Range of cluster counts to try (from 2 to 9)
cluster_range = range(2, 10)

# Lists to store the SSE (sum of squared distances, for the elbow method) and the silhouette scores
sse = []
silhouette_scores = []

pca = PCA(n_components=6)  # Reduce the embeddings to 6 dimensions
data = pca.fit_transform(embed_per_segm)
# Note: the PCA-reduced embeddings are overwritten here; clustering uses the tag probabilities plus the stem loudness
data = np.concatenate((preds_per_segm, demucs_per_segm), axis=1)

# Perform KMeans clustering for different values of k
for k in cluster_range:
    kmeans = KMeans(n_clusters=k, random_state=42)
    kmeans.fit(data)
    sse.append(kmeans.inertia_)  # SSE for elbow method
    silhouette_avg = silhouette_score(data, kmeans.labels_)
    silhouette_scores.append(silhouette_avg)

# Plot SSE for elbow method
plt.figure(figsize=(10, 5))
plt.plot(cluster_range, sse, 'bx-')
plt.xlabel('Number of clusters (k)')
plt.ylabel('SSE (Sum of Squared Distances)')
plt.title('Elbow Method for Optimal k')
plt.show()

# Plot Silhouette Score for each k
plt.figure(figsize=(10, 5))
plt.plot(cluster_range, silhouette_scores, 'bx-')
plt.xlabel('Number of clusters (k)')
plt.ylabel('Silhouette Score')
plt.title('Silhouette Score for Different k')
plt.show()

# Choose the k with the highest silhouette score (compare with the elbow plot above)
best_k = cluster_range[np.argmax(silhouette_scores)]
print(f"Best number of clusters: {best_k}")

# Perform KMeans clustering with the best k
kmeans = KMeans(n_clusters=best_k, random_state=42)
labels = kmeans.fit_predict(data)

# Print cluster labels for each sample
print("Cluster labels for the data points:", labels)
