# Audio Visualization

In [None]:
import os
from typing import Dict, Tuple, List

import librosa
import torch
import torchaudio
import pandas as pd
from matplotlib import pyplot as plt

## Utils

In [None]:
def plot_spectrogram(
    spectrogram: torch.Tensor,
    interpolation: str = "none",
    figsize: Tuple[int, int] = (20, 5),
    stype: str = "power",
):
    """Plots a spectrogram with shape (n_freq_bins, n_frames), converting its
    values to decibels (via torchaudio's AmplitudeToDB) for better visualization.

    Args:
        spectrogram (torch.Tensor): Spectrogram data (n_freq_bins, n_frames)
        interpolation (str): Matplotlib interpolation for plt.imshow()
        figsize (Tuple[int, int]): Figsize of the matplotlib figure
        stype (str): Scale of the input values for AmplitudeToDB: "power"
            (default, e.g. Spectrogram/MelSpectrogram output with power=2)
            or "magnitude"
    """
    amplitude_to_db = torchaudio.transforms.AmplitudeToDB(stype=stype)
    plt.figure(figsize=figsize)
    plt.title("Spectrogram (dB)")
    plt.ylabel("Frequency bins")
    plt.xlabel("Frame")
    plt.imshow(
        amplitude_to_db(spectrogram),
        origin="lower",  # frequency bin 0 at the bottom
        aspect="auto",
        interpolation=interpolation,
    )
    plt.colorbar()
    plt.show()


def plot_mask(
    mask: torch.Tensor,
    labels: List[str],
    figsize: Tuple[int, int] = (20, 5)
):
    """Display a per-frame label mask as a heatmap.

    Each row of the mask is tagged on the y-axis with the matching entry of
    `labels`; columns are frames. Row 0 is drawn at the top (imshow's default
    origin).

    Args:
        mask (torch.Tensor): Mask data to plot (n_labels, n_frames)
        labels (List[str]): Names for the y-axis ticks, one per mask row
        figsize (Tuple[int, int]): Size of the matplotlib figure
    """
    tick_positions = range(len(labels))
    plt.figure(figsize=figsize)
    plt.imshow(mask, cmap="jet", interpolation="none", aspect="auto")
    plt.yticks(tick_positions, labels=labels)
    plt.xlabel("Frame")
    plt.colorbar()
    plt.show()


def spec_mask_plot(
    spectrogram: torch.Tensor,
    mask: torch.Tensor,
    mask_labels: List[str],
    spec_interpolation: str = "none",
    figsize: Tuple[int, int] = (20, 10),
    stype: str = "power",
):
    """Plots a spectrogram with shape (n_freq_bins, n_frames) and a mask
    with shape (n_labels, n_frames) in a vertical subplot configuration that
    shares the frame axis. The spectrogram values are converted to decibels
    (via torchaudio's AmplitudeToDB) for better visualization.

    Args:
        spectrogram (torch.Tensor): Spectrogram data (n_freq_bins, n_frames)
        mask (torch.Tensor): Mask data to plot (n_labels, n_frames)
        mask_labels (List[str]): List of labels to use in the y-axis of the mask
        spec_interpolation (str): Matplotlib interpolation for the spectrogram
        figsize (Tuple[int, int]): Figsize of the matplotlib figure
        stype (str): Scale of the spectrogram values for AmplitudeToDB:
            "power" (default, e.g. Spectrogram/MelSpectrogram output with
            power=2) or "magnitude"
    """
    # Prepare the figure: spectrogram on top, mask below, sharing the frame axis
    fig, axs = plt.subplots(2, 1, sharex=True, figsize=figsize)
    fig.subplots_adjust(hspace=0)

    # Plot the spectrogram (in dB) in the top plot, frequency bin 0 at the bottom
    amplitude_to_db = torchaudio.transforms.AmplitudeToDB(stype=stype)
    spec_im = axs[0].imshow(
        amplitude_to_db(spectrogram),
        origin="lower",
        aspect="auto",
        interpolation=spec_interpolation,
    )
    axs[0].set_ylabel("Frequency bins")
    fig.colorbar(spec_im, ax=axs[0], shrink=0.9, pad=0.01)

    # Plot the mask in the bottom plot, one labelled row per class
    mask_im = axs[1].imshow(mask, aspect="auto", interpolation="none", cmap="jet")
    axs[1].set_yticks(range(len(mask_labels)), labels=mask_labels)
    axs[1].set_xlabel("Frame")
    fig.colorbar(mask_im, ax=axs[1], shrink=0.9, pad=0.01)

    plt.show()


def labels_to_mask(
    audio_dir: str,
    audio_name: str,
    audio_annot: pd.DataFrame,
    frame_size: int,
    hop_size: int,
    n_labels: int,
    labels2idx: Dict[str, int],
    silence_label: bool = False,
) -> torch.Tensor:
    """Given a DataFrame with all the annotations of an audio file, creates a binary
    2D tensor with shape (n_labels, n_frames) containing the labels at each frame
    (for a given `frame_size` and `hop_size`).

    The number of frames is ``1 + (n_samples - frame_size) // hop_size`` (0 when
    the audio is shorter than one frame), i.e. the frame count of a non-centered
    STFT. An annotation [start, end] (in seconds) activates the frames from
    ``int(start * sample_rate / hop_size)`` up to, but excluding,
    ``int(end * sample_rate / hop_size)``.

    Args:
        audio_dir (str): Path to the folder with the audio data
        audio_name (str): Name of the audio file without the ".wav" extension
            ("path" column in raw data)
        audio_annot (pd.DataFrame): Annotations of the audio file to mask; must
            have "start", "end" (seconds) and "label" columns
        frame_size (int): Frame size (in samples) to create the mask
        hop_size (int): Hop size (in samples) to create the mask
        n_labels (int): Number of labels (rows) to use in the mask
        labels2idx (Dict[str, int]): Mapping from label name to index in the mask
        silence_label (bool): If True, row 0 is reserved for "silence" and is set
            on every frame where no other label is active

    Returns:
        torch.Tensor: Binary int32 mask of the labels (n_labels, n_frames)

    Raises:
        ValueError: If `silence_label` is True but some annotation maps to row 0
    """
    # Only the length and sample rate of the audio are needed
    signal, sample_rate = torchaudio.load(os.path.join(audio_dir, f"{audio_name}.wav"))

    # Prepare the mask 2D matrix (labels, frames)
    n_frames = max(0, (signal.shape[-1] - frame_size) // hop_size + 1)
    mask = torch.zeros((n_labels, n_frames), dtype=torch.bool)

    # Frame i starts at sample i * hop_size, so time t (s) maps to frame
    # int(t * sample_rate / hop_size). A single factor avoids compounding the
    # rounding of 1 / sample_rate * hop_size.
    frames_per_second = sample_rate / hop_size
    for row in audio_annot.itertuples(index=False):
        # Clamp at 0: a negative index would otherwise wrap to the end of the mask
        start_frame = max(0, int(row.start * frames_per_second))
        end_frame = max(0, int(row.end * frames_per_second))
        mask[labels2idx[row.label], start_frame:end_frame] = True

    if silence_label:
        if mask[0].any():
            raise ValueError(
                "Row 0 is reserved for the silence label, but an annotation was mapped to it"
            )
        # Silence = frames where no other label is active
        mask[0] = ~mask.any(dim=0)

    return mask.int()

## Config

In [None]:
ANNOTTATIONS_PATH = "../dataset/labels.csv"
AUDIO_DIR = "../dataset/audios"
AUDIO_ID = "559e3da599a31024cf3744e61a788309"

# Spectrogram / chunking parameters (sizes in samples unless noted)
FRAME_SIZE = 2048  # STFT window length
HOP_SIZE = 1024  # STFT hop length
CHUNK_SIZE = 128  # Number of spectrogram frames per audio chunk
SAMPLE_RATE = 50000  # All the dataset is sampled at 50kHz
SPEC_TYPE = "mel"  # "mel" or "base"
N_MELS = 128
SILENCE_LABEL = True  # Add the "silence" label at idx 0

# Derived durations, in seconds
sample_duration = 1 / SAMPLE_RATE
frame_duration = sample_duration * FRAME_SIZE
hop_duration = sample_duration * HOP_SIZE
# A chunk spans (CHUNK_SIZE - 1) hops plus one full final frame
chunk_duration = hop_duration * (CHUNK_SIZE - 1) + frame_duration

for _name, _value, _digits in (
    ("sample_duration", sample_duration, 6),
    ("frame_duration", frame_duration, 3),
    ("hop_duration", hop_duration, 3),
    ("chunk_duration", chunk_duration, 3),
):
    print(f"{_name}={_value:.{_digits}f}s")

## Handle Annotations

In [None]:
annotations = pd.read_csv(ANNOTTATIONS_PATH)
# Preview a few random rows; sample() raises if asked for more rows than exist
annotations.sample(min(5, len(annotations)))

In [None]:
# Extract labels, most frequent first; optionally prepend "silence" so it takes idx 0
sorted_labels = (["silence"] if SILENCE_LABEL else []) + list(
    annotations["label"].value_counts().index
)
n_labels = len(sorted_labels)
labels2idx = dict(zip(sorted_labels, range(n_labels)))
labels2idx

In [None]:
# Keep only the annotation rows whose "path" is the selected audio
audio_annot = annotations.loc[annotations["path"].eq(AUDIO_ID)]
audio_annot

## Handle Audio

In [None]:
# Load the audio signal. Every row of `audio_annot` has path == AUDIO_ID, so use
# it directly; this also works for an audio with no annotations.
audio_path = AUDIO_ID
signal, audio_sample_rate = torchaudio.load(os.path.join(AUDIO_DIR, f"{audio_path}.wav"))
# The spectrogram transforms and durations are configured with SAMPLE_RATE
if audio_sample_rate != SAMPLE_RATE:
    raise ValueError(
        f"Expected audio sampled at {SAMPLE_RATE} Hz, got {audio_sample_rate} Hz ({audio_path})"
    )

In [None]:
# Select the transformation to create the spectrogram (non-centered, so its
# frame count matches the one computed by labels_to_mask)
if SPEC_TYPE == "base":
    spec_transform = torchaudio.transforms.Spectrogram(
        n_fft=FRAME_SIZE,
        hop_length=HOP_SIZE,
        center=False,
    )
elif SPEC_TYPE == "mel":
    spec_transform = torchaudio.transforms.MelSpectrogram(
        sample_rate=SAMPLE_RATE,
        n_fft=FRAME_SIZE,
        win_length=FRAME_SIZE,
        hop_length=HOP_SIZE,
        center=False,
        n_mels=N_MELS,
    )
else:
    raise ValueError(f"Unexpected spectrogram type ('{SPEC_TYPE}')")

# Compute the spectrogram (channels, freq_bins, n_frames)
spectrogram = spec_transform(signal)
# Compute the binary mask (n_labels, n_frames). Pass SILENCE_LABEL so the mask
# rows stay aligned with `sorted_labels` / `labels2idx`.
mask = labels_to_mask(
    AUDIO_DIR,
    audio_path,
    audio_annot,
    FRAME_SIZE,
    HOP_SIZE,
    n_labels,
    labels2idx,
    silence_label=SILENCE_LABEL,
)
print(f"{spectrogram.shape=}")
print(f"{mask.shape=}")
if spectrogram.shape[-1] != mask.shape[-1]:
    raise ValueError(
        f"Spectrogram has {spectrogram.shape[-1]} frames but the mask has {mask.shape[-1]}"
    )
spec_mask_plot(spectrogram[0], mask, sorted_labels, figsize=(40, 15))

In [None]:
# Split the whole-audio spectrogram and mask into non-overlapping chunks of
# CHUNK_SIZE frames, framing along the last (time) axis
spec_chunks, mask_chunks = (
    librosa.util.frame(data, frame_length=CHUNK_SIZE, hop_length=CHUNK_SIZE)
    for data in (spectrogram, mask)
)
print(f"{spec_chunks.shape=}")  # shape: (channels, freq, chunk_frames, chunks)
print(f"{mask_chunks.shape=}")  # shape: (labels, chunk_frames, chunks)
assert spec_chunks.shape[-2:] == mask_chunks.shape[-2:]

In [None]:
SHOW_TIME_STEP = 15  # Time (in seconds) used to select the chunk from the audio
# Chunks start every CHUNK_SIZE frames (no overlap), i.e. every
# CHUNK_SIZE * hop_duration seconds. chunk_duration is longer than this stride
# because the last frame of a chunk spans a full FRAME_SIZE window.
chunk_stride = CHUNK_SIZE * hop_duration
chunk_idx = int(SHOW_TIME_STEP / chunk_stride)
n_chunks = spec_chunks.shape[-1]
if not 0 <= chunk_idx < n_chunks:
    raise ValueError(
        f"SHOW_TIME_STEP={SHOW_TIME_STEP}s is outside the {n_chunks} complete chunks of this audio"
    )
# Select the target chunk (copied: the framed arrays are read-only views)
spec_chunk = spec_chunks[:, :, :, chunk_idx].copy()
mask_chunk = mask_chunks[:, :, chunk_idx].copy()
print(f"{spec_chunk.shape=}")
print(f"{mask_chunk.shape=}")
spec_mask_plot(torch.from_numpy(spec_chunk[0]), torch.from_numpy(mask_chunk), sorted_labels, figsize=(20, 15))