# Tutorial about Source separation and introducing conditioning

- date: 2024-10-10
- author: gmeseguerbrocal@deezer.com

The code of the model is based on UNet and CUNet for source separation
- Papers:
    - https://openaccess.city.ac.uk/id/eprint/19289/1/7bb8d1600fba70dd79408775cd0c37a4ff62.pdf
    - https://arxiv.org/pdf/1907.01277

## Import packages

In [1]:
import os
import random
import requests
import shutil
import itertools
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Any, Dict, List, Tuple, Union, Iterator

import numpy as np
import numpy.typing as npt
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from tensorflow.keras.utils import Progbar
from torch.utils.data import IterableDataset

!pip install musdb --quiet
import musdb

!pip install museval --quiet
from museval.metrics import bss_eval

!pip install einops --quiet

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m963.5/963.5 kB[0m [31m15.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m137.9/137.9 kB[0m [31m4.4 MB/s[0m eta [36m0:00:00[0m
[?25h


## Dataset

This section of the code will help you:

* Download the **MUSDB** dataset: A rich collection of music tracks with mixtures separated into 4 distinct stems: vocals, drums, bass, and others.

* Create an Infinite DataLoader: This DataLoader is designed to continuously provide data points, ensuring a seamless supply of both mixed audio and target sources when calling next. Here's how it works:

    * **MusDBDataset**: Loads the audio files of the desired split in parallel, accessing all the stems. It chunks each song into individual segments and yields them.

    * **SourceSeparationDataloader**: Takes individual segments from MusDBDataset and creates a buffer to mix segments from different audio tracks, ensuring a diverse and continuous stream of data.


In [2]:
FULL_MUSDB = True  # Set this accordingly
MUSDB_PATH = '/content/musdb/'

if not os.path.exists(MUSDB_PATH):
    os.makedirs(MUSDB_PATH)

def download_file(url: str, dest_path: str):
    response = requests.get(url, stream=True)
    total_size = int(response.headers.get('content-length', 0))
    block_size = 1024 * 1024  # 1 MB

    with open(dest_path, 'wb') as file:
        for chunk in tqdm(response.iter_content(chunk_size=block_size), total=total_size // block_size, unit='MB', unit_scale=True):
            file.write(chunk)
    print("\nDownload complete!")


def extract_zip(zip_path: str, extract_to: str):
    print(f"Extracting {zip_path} to {extract_to}...")
    shutil.unpack_archive(zip_path, extract_to)

def remove_dir(path: str):
    if os.path.exists(path):
        shutil.rmtree(path)

if FULL_MUSDB:
    download_file('https://zenodo.org/records/1117372/files/musdb18.zip?download=1', os.path.join(MUSDB_PATH, 'musdb.zip'))

    remove_dir(os.path.join(MUSDB_PATH, 'test/'))
    remove_dir(os.path.join(MUSDB_PATH, 'train/'))

    extract_zip(os.path.join(MUSDB_PATH, 'musdb.zip'), MUSDB_PATH)

    musdb.DB(root=MUSDB_PATH, download=False)
else:
    musdb.DB(root=MUSDB_PATH, download=True)


4.47kMB [02:45, 26.9MB/s]                           



Download complete!
Extracting /content/musdb/musdb.zip to /content/musdb/...


In [None]:
class SourceSeparationDataloader:
    """
    A custom data loader with shuffle buffer functionality.

    Attributes:
    dataset (torch.utils.data.Dataset): The dataset to shuffle.
    buffer_size (int): Size of the shuffle buffer.
        The bigger the buffer size is more audio segments  will be include.
        Since we are getting all the segments of a song to speed up the process
        we want this to be big to mix segments from different songs
    batch_size (int): Size of each batch.
    buffer (list): The buffer holding dataset elements.
    dataset_iter (iterator): An iterator over the dataset.
    """

    def __init__(self, dataset: torch.utils.data.Dataset, buffer_size: int, batch_size: int) -> None:
        """
        Initialize a shuffle buffer for the dataset.

        Parameters:
        dataset (torch.utils.data.Dataset): The dataset to shuffle.
        buffer_size (int): Size of the shuffle buffer.
        batch_size (int): Size of each batch.
        """
        self.dataset = dataset
        self.buffer_size = buffer_size
        self.batch_size = batch_size
        self.buffer = []
        self.device = device
        self.dataset_iter = iter(dataset)

        for i in range(buffer_size):
            if (i + 1) % (buffer_size // 10) == 0:
                print(f"Buffer {int(((i + 1) / buffer_size) * 100)}% filled.")
            self.buffer.append(next(self.dataset_iter))

    def get_next(self) -> Tuple[torch.Tensor, torch.Tensor, List[str]]:
        """
        Get the next batch of shuffled elements from the buffer.

        Returns:
        tuple[torch.Tensor, torch.Tensor, List[str]]: A batch of randomly shuffled elements.
        """
        batch_x = []
        batch_y = []
        batch_label = []

        for _ in range(self.batch_size):
            if len(self.buffer) == 0:
                break  # If buffer is empty, stop forming the batch

            idx = random.randint(0, len(self.buffer) - 1)
            batch_x.append(self.buffer[idx][0])
            batch_y.append(self.buffer[idx][1])
            batch_label.append(self.buffer[idx][2])

            try:
                self.buffer[idx] = next(self.dataset_iter)
            except StopIteration:
                self.buffer.pop(idx)

        return torch.stack(batch_x), torch.stack(batch_y), batch_label

    def __iter__(self) -> 'SourceSeparationDataloader':
        return self

    def __next__(self) -> Tuple[torch.Tensor, torch.Tensor, List[str]]:
        return self.get_next()


class MusDBDataset(IterableDataset):
    """
    A dataset for MusDB tracks with segment processing and multiprocessing
    loading.

    Attributes:
    split (str): The dataset split ('train', 'valid', etc.).
    targets (list): List of target audios ('vocals', 'drums', etc.).
    segment_dur (float): Duration of each segment in seconds.
    segment_overlap (float): Overlap between segments in seconds.
    num_workers (int): Number of threads for parallel processing.
    mus (musdb.DB): MusDB dataset.
    """

    def __init__(
        self,
        split: str,
        targets: List[str],
        segment_dur: float,
        segment_overlap: float,
        num_workers: int,
        musdb_path: str = MUSDB_PATH,
    ) -> None:
        """
        Initialize the MusDBDataset.

        Parameters:
        split (str): The dataset split ('train', 'valid', etc.).
        targets (list): List of target audios ('vocals', 'drums', etc.).
        segment_dur (float): Duration of each segment in seconds.
        segment_overlap (float): Overlap between segments in seconds.
        num_workers (int): Number of threads for parallel processing.
        """

        valid_targets = {'vocals', 'drums', 'bass', 'other'}
        assert all(target in valid_targets for target in targets)

        self.split = split
        self.targets = targets
        self.segment_dur = segment_dur
        self.segment_overlap = segment_overlap
        self.num_workers = num_workers
        self.mus = musdb.DB(subsets=self.split, root=musdb_path)

        if not self.mus.tracks:
            raise ValueError(f"The dataset for split '{self.split}' is empty or not loaded properly.")

    def _process_track(self, track) -> List[Tuple[torch.Tensor, torch.Tensor, str]]:
        """
        Process a track by unfolding and rearranging its audio data.

        Parameters:
        track (musdb.Track): A musdb track object.

        Returns:
        list[tuple[torch.Tensor, torch.Tensor, str]]: Processed input and target audio segments.
        """
        size = int(track.rate * self.segment_dur)
        step = int(track.rate * self.segment_overlap)
        name = track.name
        rate = track.rate
        x = torch.tensor(track.audio).to(torch.float32)
        x = rearrange(x.unfold(dimension=0, size=size, step=step), 's c d -> s d c')
        segments = []
        for target in self.targets:
            y = torch.tensor(track.targets[target].audio).to(torch.float32)
            y = rearrange(y.unfold(dimension=0, size=size, step=step), 's c d -> s d c')
            segments.extend(list(zip(x, y, [target]*x.shape[0])))
        del(x, y, track)
        return (segments, name, rate)


    def __iter__(self) -> Iterator[Tuple[torch.Tensor, torch.Tensor, str]]:
        """
        Iterate over the dataset with parallel track processing.

        Yields:
        tuple[torch.Tensor, torch.Tensor, str]: Segmented input and target audio.
        """

        if self.split == 'train':
            # Continuous processing for training data
            while True:
                with ThreadPoolExecutor(max_workers=self.num_workers) as executor:
                    tracks = [self.mus.tracks[i] for i in torch.randperm(len(self.mus.tracks)).tolist()][:self.num_workers]
                    futures = [executor.submit(self._process_track, track) for track in tracks]
                    for future in as_completed(futures):
                        segments = future.result()[0]
                        for x, y, l in segments:
                            yield x, y, l
        elif self.split == 'test':
            # One-time processing for test data
            with ThreadPoolExecutor(max_workers=self.num_workers) as executor:
                futures = [executor.submit(self._process_track, track) for track in self.mus.tracks]
                for future in as_completed(futures):
                    yield future.result()
        else:
            raise ValueError(f"Unknown split: {self.split}")


# Example of how to initialize the data loader
duration = 4
ds_train = SourceSeparationDataloader(MusDBDataset('train', ['vocals'], duration, duration/2, num_workers=8), buffer_size=2999, batch_size=32)

Buffer 9% filled.
Buffer 19% filled.
Buffer 29% filled.
Buffer 39% filled.
Buffer 49% filled.
Buffer 59% filled.
Buffer 69% filled.
Buffer 79% filled.
Buffer 89% filled.
Buffer 99% filled.


In [4]:
d = next(ds_train)
print(d[0].shape, d[1].shape, d[2])
del(ds_train)

torch.Size([32, 176400, 2]) torch.Size([32, 176400, 2]) ['vocals', 'vocals', 'vocals', 'vocals', 'vocals', 'vocals', 'vocals', 'vocals', 'vocals', 'vocals', 'vocals', 'vocals', 'vocals', 'vocals', 'vocals', 'vocals', 'vocals', 'vocals', 'vocals', 'vocals', 'vocals', 'vocals', 'vocals', 'vocals', 'vocals', 'vocals', 'vocals', 'vocals', 'vocals', 'vocals', 'vocals', 'vocals']


## Spectral operations

In this section, we cover the following spectral operations critical for the source separation model:

* **Waveform Processing and STFT Computation**: Functions to process waveforms, compute the complex Short Fast Fourier Transform (STFT), and then use the inverse STFT (iSTFT) to revert to the waveform.

* **Model Input and Output**: The input and output of the source separation model are waveforms.

* **Internal Operations**: We compute the FFT of the input waveform. The deep neural network receives the real and imaginary parts of the FFT as two separate channels. The mask is applied independently to each channel (real/imaginary). The iFFT then converts the masked output back to the waveform.

* **Model Blocks**: Detailed description of the model components responsible for these operations.



In [5]:
def get_audio_prepro_args(
    dur: float,
    window: str = "hanning",
    n_fft: int = 2048,
    sr: int = 44100,
    hop_factor: float = 0.5,
    stereo: bool = True,
) -> tuple[torch.Tensor, int, int, int, int, int, int, bool]:
    """
    Prepare audio preprocessing arguments.

    Parameters:
    dur (float): Duration in seconds.
    window (str): Window type. Default is "hanning".
    n_fft (int): Number of FFT points. Default is 2048.
    sr (int): Sample rate. Default is 44100.
    hop_factor (float): Factor to calculate hop length. Default is 0.5.
    stereo (bool): If True, stereo audio is used. Default is True.

    Returns:
    tuple: window tensor, number of FFT points, hop length, sample rate, number of frames,
           number of bins, length in samples, stereo flag.
    """
    hop_fft = np.round(n_fft * hop_factor).astype(np.int32)
    length_in_samples = int(np.ceil(dur * sr))
    n_frames = int(np.ceil(length_in_samples / hop_fft))
    n_bins = n_fft // 2 + 1
    if window == "hanning":
        w = torch.hann_window(n_fft)
    return w, n_fft, hop_fft, sr, n_frames, n_bins, length_in_samples, stereo


def view_as_real_img(x: torch.Tensor) -> torch.Tensor:
    """
    Convert complex tensor to a real image tensor.

    Parameters:
    x (torch.Tensor): Complex input tensor.

    Returns:
    torch.Tensor: Real image tensor with separated real and imaginary parts.
    """
    return torch.cat((x.real.unsqueeze(-1), x.imag.unsqueeze(-1)), dim=-1)


def waveform2spec(
    x: torch.Tensor,
    window: torch.Tensor,
    n_fft_audio: int,
    fft_hop: int
) -> torch.Tensor:
    """
    Convert waveform tensor to spectrogram tensor using Short-Time Fourier Transform (STFT).

    Parameters:
    x (torch.Tensor): Input waveform tensor.
    window (torch.Tensor): Window function tensor.
    n_fft_audio (int): Number of FFT points.
    fft_hop (int): Hop length for STFT.

    Returns:
    torch.Tensor: Spectrogram tensor with separated real and imaginary parts.
    """
    return view_as_real_img(x.stft(n_fft=n_fft_audio, window=window, hop_length=fft_hop, return_complex=True).type(torch.complex64))


class STFTModule(nn.Module):
    """
    Module for Short-Time Fourier Transform (STFT) on audio signals.

    Attributes:
    duration (float): Duration of audio signal.
    window_fft (torch.Tensor): Window function tensor.
    n_fft (int): Number of FFT points.
    hop_fft (int): Hop length.
    sr (int): Sample rate.
    n_frames (int): Number of frames.
    n_bins (int): Number of frequency bins.
    length_in_samples (int): Length of audio signal in samples.
    stereo (bool): If True, stereo audio is used.
    stft_fn (function): Function for STFT computation.
    """

    def __init__(self, duration: float) -> None:
        super(STFTModule, self).__init__()
        self.window_fft, self.n_fft, self.hop_fft, self.sr, self.n_frames, self.n_bins, self.length_in_samples, self.stereo = get_audio_prepro_args(duration)
        self.duration = duration
        self.stft_fn = waveform2spec

    def get_input_shape(self) -> tuple[int, int]:
        """
        Get the input shape for the module.

        Returns:
        tuple: Shape of input tensor.
        """
        return self.length_in_samples, 2 if self.stereo else 1

    def get_out_shape(self) -> list[int]:
        """
        Get the output shape for the module.

        Returns:
        list: Shape of output tensor.
        """
        return [4 if self.stereo else 2, int(self.n_frames), int(self.n_bins)]

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass for the module.

        Parameters:
        x (torch.Tensor): Input tensor.

        Returns:
        tuple: Output tensor after STFT.
        """
        shape = x.shape
        x = rearrange(x, "b s c -> (b c) s")
        x = self.stft_fn(x, self.window_fft.to(x.device), self.n_fft, self.hop_fft)
        x = rearrange(x, "(b c) f t r -> b (c r) t f", b=shape[0], c=shape[-1])
        return x


class iSTFTModule(nn.Module):
    """
    Module for Inverse Short-Time Fourier Transform (iSTFT) on audio signals.

    Attributes:
    duration (float): Duration of audio signal.
    window_fft (torch.Tensor): Window function tensor.
    n_fft (int): Number of FFT points.
    hop_fft (int): Hop length.
    sr (int): Sample rate.
    n_frames (int): Number of frames.
    n_bins (int): Number of frequency bins.
    length_in_samples (int): Length of audio signal in samples.
    stereo (bool): If True, stereo audio is used.
    """

    def __init__(self, duration: float) -> None:
        super(iSTFTModule, self).__init__()
        self.window_fft, self.n_fft, self.hop_fft, self.sr, self.n_frames, self.n_bins, self.length_in_samples, self.stereo = get_audio_prepro_args(duration)
        self.duration = duration

    def istft_fn(
        self,
        x_fft: torch.Tensor,
        window: torch.Tensor,
        n_fft_audio: int,
        fft_hop: int,
        length: int
    ) -> torch.Tensor:
        """
        Compute Inverse Short-Time Fourier Transform (iSTFT).

        Parameters:
        x_fft (torch.Tensor): Input spectrogram tensor.
        window (torch.Tensor): Window function tensor.
        n_fft_audio (int): Number of FFT points.
        fft_hop (int): Hop length for iSTFT.
        length (int): Length of the output waveform.

        Returns:
        torch.Tensor: Reconstructed waveform tensor.
        """
        return x_fft.istft(
            n_fft=n_fft_audio,
            window=window,
            hop_length=fft_hop,
            return_complex=False,
            length=length,
        )

    def get_input_shape(self) -> list[int]:
        """
        Get the input shape for the module.

        Returns:
        list: Shape of input tensor.
        """
        return [4 if self.stereo else 2, self.n_frames, self.n_bins]

    def get_out_shape(self) -> tuple[int, int]:
        """
        Get the output shape for the module.

        Returns:
        tuple: Shape of output tensor.
        """
        return self.length_in_samples, 2 if self.stereo else 1

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for the module.

        Parameters:
        x (torch.Tensor): Input tensor.

        Returns:
        torch.Tensor: Output tensor after iSTFT.
        """
        shape = x.shape
        x = rearrange(x, "b (c r) t f -> (b c) f t r", r=2, c=shape[1] // 2)
        x = torch.view_as_complex(x.contiguous())
        x = self.istft_fn(x, self.window_fft.to(x.device), self.n_fft, self.hop_fft, self.length_in_samples)
        return rearrange(x, "(b c) s -> b s c", b=shape[0], c=shape[1] // 2)

## Blocks

Basic building blocks for our Source Separation UNet:

*   **Convolutional block for the Encoder**: This block is essential for capturing spatial hierarchies and feature representations from the input data.
*   **Up-convolution for the Decoder**: This includes an interpolation function for up-sampling any tensor dimension to match the corresponding skip connection dimension.

These foundational blocks can be seamlessly swapped out for more advanced options, such as ConvNeXt, to enhance model performance.

In [6]:
def interpolate_tensor(x: torch.Tensor, shape: tuple[int, int], mode: str = "bilinear") -> torch.Tensor:
    """
    Interpolate a tensor to a given shape using specified interpolation mode.

    Parameters:
    x (torch.Tensor): Input tensor.
    shape (tuple[int, int]): Target shape for interpolation.
    mode (str): Interpolation mode. Default is "bilinear".

    Returns:
    torch.Tensor: Interpolated tensor.
    """
    return F.interpolate(x, size=shape, mode=mode, align_corners=False, antialias=False)


def get_activation(name: str) -> nn.Module:
    """
    Retrieve activation function by name.

    Parameters:
    name (str): Name of the activation function.

    Returns:
    nn.Module: Activation function module.

    Raises:
    ValueError: If the activation function name is not found.
    """
    activation_map = {
        "relu": nn.ReLU(),
        "leaky_relu": nn.LeakyReLU(),
        "gelu": nn.GELU(),
        "prelu": torch.nn.PReLU(),
        "tanh": nn.Tanh(),
        "identity": nn.Identity(),
    }
    if name in activation_map:
        return activation_map[name]
    else:
        raise ValueError(f"Activation function {name} not found.")


class ConvBlock(nn.Module):
    """
    Convolutional block with optional dropout and activation.

    Attributes:
    in_channels (int): Number of input channels.
    out_channels (int): Number of output channels.
    ops (nn.ModuleList): List of operations in the block.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[tuple[int, int], int] = 3,
        dropout: bool = False,
        activation: str = 'relu',
        **kargs: dict
    ) -> None:
        """
        Initialize the ConvBlock.

        Parameters:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        kernel_size (Union[tuple[int, int], int]): Size of the convolving kernel. Default is 3.
        dropout (bool): If True, includes a dropout layer. Default is True.
        activation (str): Activation function name. Default is 'relu'.
        **kargs (dict): Additional keyword arguments for Conv2d.
        """
        super(ConvBlock, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.ops = nn.ModuleList()

        self.ops.append(
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                **kargs
            )
        )
        self.ops.append(nn.BatchNorm2d(out_channels))
        if dropout:
            self.ops.append(nn.Dropout(p=0.25))
        self.ops.append(get_activation(activation))

    def get_out_shape(self, shape: tuple[int, int]) -> list[int]:
        """
        Compute the output shape of the block given input shape.

        Parameters:
        shape (tuple[int, int]): Input shape.

        Returns:
        list[int]: Output shape.
        """
        def to_dim_fn(v: int, p: int, d: int, k: int, s: int) -> int:
            return int(1 + ((v + 2 * p - d * (k - 1) - 1) / s))

        if self.ops[0].padding != "same":
            p = self.ops[0].padding
            d = self.ops[0].dilation
            k = self.ops[0].kernel_size
            s = self.ops[0].stride
            a = to_dim_fn(shape[1], p[0], d[0], k[0], s[0])
            b = to_dim_fn(shape[2], p[1], d[1], k[1], s[1])
        else:
            a, b = shape[1:]
        return [self.out_channels, a, b]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through the convolutional block.

        Parameters:
        x (torch.Tensor): Input tensor.

        Returns:
        torch.Tensor: Output tensor.
        """
        for op in self.ops:
            x = op(x)
        return x


class UpConvBlock(ConvBlock):
    """
    Upsampling convolutional block with interpolation.

    Attributes:
    out_shape (list[int]): Target shape after upsampling.
    """

    def __init__(self, in_channels: int, out_channels: int, out_shape: list[int], **kargs: dict) -> None:
        super(UpConvBlock, self).__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            **kargs,
        )
        self.out_shape = out_shape

    def get_out_shape(self) -> list[int]:
        """
        Get the output shape for the block.

        Returns:
        list[int]: Output shape.
        """
        return [self.out_channels, *self.out_shape]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through the upsampling block.

        Parameters:
        x (torch.Tensor): Input tensor.

        Returns:
        torch.Tensor: Output tensor.
        """
        x = interpolate_tensor(x, self.out_shape[1:])
        return super(UpConvBlock, self).forward(x)


## Model

Code for creating the model:

* **STFT**: Computes the spectral transformation from the waveform.

* **Encoder**: Codifies and highlights the relevant information from the signal. It downsamples the input tensor (the complex STFT with imaginary and real components stored independently per channel) by a factor of 2 in both time and frequency dimensions, while simultaneously increasing the number of channels by a power of 2.

* **Decoder**: Transforms the latent space back to the original input. It has skip connections to retrieve fine-grain details in the signal. The output maintains the same dimensions as the input to the encoder, allowing it to be used either as a mask for the input or as a final output signal. Both cases will then be fed into the iSTFT.

* **iSTFT**: Computes the waveform from the complex spectrogram.

In [7]:
class Encoder(nn.Module):
    """
    Encoder module consisting of multiple convolutional blocks.

    Attributes:
    n_layers (int): Number of layers in the encoder.
    ops (nn.ModuleList): List of convolutional operations.
    block_shapes (List[List[int]]): List of shapes of the blocks in the encoder.
    """

    def __init__(
        self,
        block_input_shape: List[int],
        n_layers: int = 6,
        n_filters: int = 16,
        max_n_filters: int = 512
    ) -> None:
        """
        Initialize the Encoder module.

        Parameters:
        block_input_shape (List[int]): Shape of the input block.
        n_layers (int): Number of layers in the encoder. Default is 6.
        n_filters (int): Initial number of filters for the convolutions. Default is 16.
        max_n_filters (int): Maximum number of filters for the convolutions. Default is 512.
        """
        super(Encoder, self).__init__()
        self.n_layers = n_layers
        self.ops = nn.ModuleList([])
        self.block_shapes = [block_input_shape]
        kargs = {
            "out_channels": n_filters,
            "stride": 2,
            "in_channels": block_input_shape[0]
        }
        for _ in range(self.n_layers):
            self.ops.append(ConvBlock(**kargs))
            block_input_shape = self.ops[-1].get_out_shape(block_input_shape)
            self.block_shapes.append(block_input_shape)
            kargs['in_channels'] = self.block_shapes[-1][0]
            kargs['out_channels'] = min(self.block_shapes[-1][0] * 2, max_n_filters)

    def forward(self, x: torch.Tensor) -> list[Union[torch.Tensor, Any]]:
        """
        Forward pass through the encoder.

        Parameters:
        x (torch.Tensor): Input tensor.

        Returns:
        list[Union[torch.Tensor, Any]]: List of outputs from each layer of the encoder.
        """
        outputs = [x]
        for op in self.ops:
            x = op(x)
            outputs.append(x)
        return outputs

class Decoder(nn.Module):
    """
    Decoder module for upsampling and reconstructing the input.

    Attributes:
    mask_act (str): Activation function for the mask.
    encoder_block_shapes (List[List[Any]]): Shapes of the encoder blocks.
    ops (nn.ModuleList): List of upsampling convolutional operations.
    """

    def __init__(
        self,
        encoder_block_shapes: List[List[Any]],
        mask_act: str,
    ) -> None:
        """
        Initialize the Decoder module.

        Parameters:
        encoder_block_shapes (List[List[Any]]): Shapes of the encoder blocks.
        mask_act (str): Activation function for the mask.
        """
        super(Decoder, self).__init__()
        self.mask_act = mask_act
        self.encoder_block_shapes = encoder_block_shapes
        self.ops = nn.ModuleList([])

        for i in range(len(self.encoder_block_shapes) - 1):
            in_channels = self.encoder_block_shapes[i][0] * 2 if i != 0 else self.encoder_block_shapes[i][0]
            kargs = {
                'out_shape': self.encoder_block_shapes[i+1],
                'in_channels': in_channels,
                'out_channels': self.encoder_block_shapes[i+1][0],
                'padding': 'same',
            }
            if i == len(self.encoder_block_shapes) - 2:
                kargs['activation'] = mask_act
            self.ops.append(UpConvBlock(**kargs))

    def forward(self, *x: Any) -> torch.Tensor:
        """
        Forward pass through the decoder.

        Parameters:
        x (Any): Input tensors.

        Returns:
        torch.Tensor: Output tensor after upsampling and reconstruction.
        """
        x, encoder_outputs = x[0], x[1:]
        for i in range(len(encoder_outputs)):
            if i < len(encoder_outputs) - 1:
                x = torch.cat((self.ops[i](x), encoder_outputs[i]), dim=1)
            else:
                x = self.ops[i](x)
        return x


class SourceSeparation(nn.Module):
    """
    Source separation module using STFT, Encoder, Decoder, and iSTFT.

    Attributes:
    duration (float): Duration of the audio.
    behavior (str): Behavior for the output ("masking" or "mapping").
    stft (STFTModule): Short-Time Fourier Transform module.
    encoder (Encoder): Encoder module.
    decoder (Decoder): Decoder module.
    istft (iSTFTModule): Inverse Short-Time Fourier Transform module.
    """

    def __init__(
        self,
        duration: float,
        mask_act: str = "tanh",
        behavior: str = "masking",
    ) -> None:
        """
        Initialize the SourceSeparation module.

        Parameters:
        duration (float): Duration of the audio.
        mask_act (str): Activation function for the mask. Default is "tanh".
        behavior (str): Behavior for the output. Either "masking" or "mapping". Default is "masking".
        """
        super(SourceSeparation, self).__init__()
        assert behavior in ["masking", "mapping"]
        self.duration = duration
        self.behavior = behavior
        self.stft = STFTModule(duration)
        self.encoder = Encoder(self.stft.get_out_shape())
        self.decoder = Decoder(
            encoder_block_shapes=self.encoder.block_shapes[::-1],
            mask_act=mask_act
        )
        self.istft = iSTFTModule(duration)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through the source separation module.

        Parameters:
        x (torch.Tensor): Input tensor.

        Returns:
        torch.Tensor: Output tensor after source separation.
        """

        # we transform the waveform into a time-frequency representation
        x = self.stft(x)

        """
        The encoder takes spec with the real/img info on each channel and processes
        it through multiple convolutional layers. This step essentially extracts features
        and compresses the audio representation.
        """
        x_encoder = self.encoder(x)

        """
        The decoder takes the output from the encoder and reconstructs the separated
        audio signal. It uses upsampling and skip connections to recover fine details.
        """
        y = self.decoder(*x_encoder[::-1])

        if self.behavior == "masking":
            y = torch.multiply(x, y)

        return self.istft(y)

### Loss

Several waveform loss approaches exist. We use only the basic cosine distance loss, which optimizes well and efficiently guides the model, though not perfectly.

In [8]:
def cosine_loss() -> torch.nn.CosineSimilarity:
    """
    Create a cosine similarity-based loss function.

    Returns:
    torch.nn.CosineSimilarity: Cosine similarity function with an embedded loss function.
    """
    cosine = torch.nn.CosineSimilarity(dim=1, eps=1e-08)
    """
    (1 - cosine) to have a loss between 2 and 0.
    Minimizing it will do the signals be close together
    """
    loss_fn = lambda y_pred, y_true: 1 - cosine(y_pred, y_true)
    return loss_fn

## Training

We create the data loader, the model, the optimizer, and the loss function. The data loader always provides a batch, and each batch should differ from the previous ones (especially with data augmentation, which is common in source separation but not in our case). In this context, the idea of an epoch is not well-defined. Instead, we use an epoch as a moment to validate the model and save the training progress (not implemented in this tutorial), but it doesn't necessarily mean completing one iteration over the entire dataset.

In [9]:
def training_loop(
    model,
    ds_iter: Any,
    device: str,
    n_epochs: int = 50,
    n_steps: int = 512,
    loss_fn: torch.nn.CosineSimilarity  = cosine_loss(),
    lr: float = 5e-4,
) -> Tuple[float, bool]:
    """
    Run the training loop for the given model.

    Parameters:
    model: The model to be trained.
    ds_iter (Any): An iterator over the dataset.
    device (str): The device to use for training (e.g., 'cuda' or 'cpu').
    n_epochs (int): Number of epochs to train. Default is 10.
    n_steps (int): Number of steps per epoch. Default is 512.
    loss_fn (torch.nn.CosineSimilarity): The loss function to use. Default is cosine_loss().
    lr (float): Learning rate for the optimizer. Default is 5e-4.

    Returns:
    Tuple[float, bool]: The trained model and a boolean indicating success.
    """

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    # --- TRAINING ---
    for i in range(n_epochs):
        print('\n epoch {}'.format(i))
        model.train()
        progress_bar = Progbar(n_steps + 1)
        for j in range(n_steps):
            batch = next(ds_iter)
            optimizer.zero_grad()
            mix = batch[0].to(device)
            target = batch[1].to(device)
            loss = loss_fn(model(mix), target).mean()
            loss.backward()
            optimizer.step()
            progress_bar.update(j, [("train_loss", loss.item())])
            torch.cuda.empty_cache()
    return model


duration = 4
# Initialize the data loader
ds_train = SourceSeparationDataloader(MusDBDataset('train', ['vocals'], duration, duration/2, num_workers=4), 1999, 32)
model = SourceSeparation(duration=duration).to(device)
model = training_loop(model, ds_train, device)

Buffer 9% filled.
Buffer 19% filled.
Buffer 29% filled.
Buffer 39% filled.
Buffer 49% filled.
Buffer 59% filled.
Buffer 69% filled.
Buffer 79% filled.
Buffer 89% filled.
Buffer 99% filled.

 epoch 0
[1m  2/513[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m4:11:58[0m 30s/step - train_loss: 0.7753

KeyboardInterrupt: 

## Evaluation
In this section, you will:

* Loop over the **MusDBDataset**: Ensure that all stems for a given song is provided. No shuffle is needed since we can evaluate each track independly. We go over the test set once and then finish the computation.

* Compute Predictions with the Model: Generate predictions for each segment using the model.

* Prepare Stems for Metrics Calculation: Always include both the source (real and predicted) and the "residual" (mixture - source) to compute the metrics accurately.


In [None]:
def get_metrics(
    ref: npt.NDArray[np.float32],
    est: npt.NDArray[np.float32],
    sr: int,
    window: int = 1,
    hop: int = 1,
) -> Dict[str, Dict[str, int]]:
    """
    Computes and logs various audio separation metrics.

    Parameters:
    ref (npt.NDArray[np.float32]): Reference signal.
    est (npt.NDArray[np.float32]): Estimated signal.
    sr (int): Sampling rate.
    window (int): Window size in seconds. Defaults to 1.
    hop (int): Hop size in seconds. Defaults to 1.

    Returns:
    Dict[str, Dict[str, int]]: A dictionary containing metrics for vocals and accompaniment (acc).
    """
    print("Computing the metrics")
    sdr, isr, sir, sar, _ = bss_eval(
        ref,
        est,
        window=window * sr,
        hop=hop * sr,
        framewise_filters=False,
        bsseval_sources_version=False,
        compute_permutation=False,
    )
    print("Done!")
    output["target"] = {"sdr": sdr[0], "isr": isr[0], "sir": sir[0], "sar": sar[0]}
    output["rest"] = {"sdr": sdr[1], "isr": isr[1], "sir": sir[1], "sar": sar[1]}

    # The median is the most standard way of comparing metrics on source separation
    print(
        "Results target: \n\tSDR: {}, \n\tSIR: {}, \n\tSAR: {}".format(
            np.nanmedian(sdr[0]), np.nanmedian(sir[0]), np.nanmedian(sar[0])
        )
    )
    print(
        "Results rest: \n\tSDR: {}, \n\tSIR: {}, \n\tSAR: {}".format(
            np.nanmedian(sdr[1]), np.nanmedian(sir[1]), np.nanmedian(sar[1])
        )
    )
    return output


ds_test = MusDBDataset('test', ['vocals'], duration, duration, num_workers=8)

model.eval()
output = {}
for track, name, rate in iter(ds_test):
    print("Processing track {}".format(name))
    m = torch.stack(list(zip(*track))[0], dim=0)
    with torch.no_grad():
        p = model(m.to(device)).detach().cpu()

    m = rearrange(m, 'b s c -> (b s) c')
    p = rearrange(p, 'b s c -> (b s) c')
    t = torch.cat(list(zip(*track))[1], dim=0)

    ref = torch.stack((t, m - t), dim=0)
    est = torch.stack((p, m - p), dim=0)
    output[name] = get_metrics(ref, est, rate)



Processing track Angels In Amplifiers - I'm Alright


## Conditioning

In this section, we delve into the use of FiLM layers for conditioning. These layers offer a straightforward method to tailor audio features based on the target instrument. While our implementation focuses on applying FiLM layers at the bottleneck layer of the encoder, this approach can be easily adapted to other layers, providing flexibility in your audio processing pipeline. This versatility ensures that you can fine-tune the conditioning to best suit your needs.

In [None]:
class FiLMBlock(nn.Module):
    """
    Feature-wise Linear Modulation (FiLM) block for conditioning.

    Attributes:
    gammas (nn.Embedding): Embedding layer for gamma values.
    betas (nn.Embedding): Embedding layer for beta values.
    """

    def __init__(self, n_channels_bottleneck: int, valid_targets: List[str] = ['vocals', 'drums', 'bass', 'other'],) -> None:
        """
        Initialize the FiLMBlock.
        This block learns instrument-specific scaling and shifting factors,
        which are applied to the features before decoding.

        Parameters:
        n_channels_bottleneck (int): Number of bottleneck channels.
        valid_targets (List[str]): List of valid targets. Default is ['vocals', 'drums', 'bass', 'other'].

        """
        super(FiLMBlock, self).__init__()
        self.valid_targets = valid_targets
        """
        Embedding layers are a way to represent categorical data
        (like instrument names) as numerical vectors that the network can
        understand. gammas and betas are learned parameters that will be
        used to modulate the audio features.
        """
        self.gammas = nn.Embedding(len(valid_targets), n_channels_bottleneck)
        self.betas = nn.Embedding(len(valid_targets), n_channels_bottleneck)

    def forward(self, x: torch.Tensor, ctxt: List) -> torch.Tensor:
        """
        Forward pass through the FiLM block.

        Parameters:
        x (torch.Tensor): Input tensor.
        ctxt (torch.Tensor): Context tensor with indices of the desired instrument.

        Returns:
        torch.Tensor: Modulated tensor.
        """
        device = x.device
        # Instrument names to numerical indices using the valid_targets list.
        ctxt = torch.tensor([self.valid_targets.index(item) for item in ctxt]).to(device)
        # We retrieve the right gamma and beta values from the embedding layers.
        gammas = rearrange(self.gammas(ctxt), 'b c -> b c 1 1')
        betas = rearrange(self.betas(ctxt), 'b c -> b c 1 1')
        # It returns the modulated audio features.
        x = gammas * x + betas
        return x


class ConditoningSourceSeparation(SourceSeparation):
    """
    Source separation model with conditioning using FiLM.

    Attributes:
    conditioning (FiLMBlock): FiLM block for conditioning.
    """
    def __init__(
        self,
        duration: float,
        **kargs: dict
    ) -> None:
        """
        Initialize the ConditoningSourceSeparation module.

        Parameters:
        duration (float): Duration of the audio.
        **kargs (dict): Additional keyword arguments for SourceSeparation.
        """
        super(ConditoningSourceSeparation, self).__init__(duration=duration, **kargs)
        self.conditioning = FiLMBlock(n_channels_bottleneck=self.encoder.block_shapes[-1][0])

    def forward(self, x: torch.Tensor, ctxt: List) -> torch.Tensor:
        """
        Forward pass through the conditioning source separation model.

        Parameters:
        x (torch.Tensor): Input tensor.
        ctxt (torch.Tensor): Context tensor with indices of the desired instrument.

        Returns:
        torch.Tensor: Output tensor after source separation and conditioning.
        """
        x = self.stft(x)
        x_encoder = self.encoder(x)
        # Conditioning happens only at the bottleneck of the encoder.
        x_encoder[-1] = self.conditioning(x_encoder[-1], ctxt)
        y = self.decoder(*x_encoder[::-1])
        if self.behavior == "masking":
            y = torch.multiply(x, y)
        return self.istft(y)


We updated training loop, which closely mirrors the previous one. The primary distinction is that the model now receives two parameters: the mixture and the desired instrument. Additionally, the dataloader has been enhanced to provide target audio segments from various instruments and different songs, ensuring a richer and more diverse training experience.


In [None]:
def training_loop(
    model,
    ds_iter: Any,
    device: str,
    n_epochs: int = 50,
    n_steps: int = 512,
    loss_fn: torch.nn.CosineSimilarity  = cosine_loss(),
    lr: float = 5e-4,
) -> Tuple[float, bool]:

    """
    Run the training loop for the given model.

    Parameters:
    model: The model to be trained.
    ds_iter (Any): An iterator over the dataset.
    device (str): The device to use for training (e.g., 'cuda' or 'cpu').
    n_epochs (int): Number of epochs to train. Default is 10.
    n_steps (int): Number of steps per epoch. Default is 512.
    loss_fn (torch.nn.CosineSimilarity): The loss function to use. Default is cosine_loss().
    lr (float): Learning rate for the optimizer. Default is 5e-4.

    Returns:
    Tuple[float, bool]: The trained model and a boolean indicating success.
    """

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    # --- TRAINING ---
    for i in range(n_epochs):
        print('\n epoch {}'.format(i))
        model.train()
        progress_bar = Progbar(n_steps + 1)
        for j in range(n_steps):
            batch = next(ds_iter)
            optimizer.zero_grad()
            mix = batch[0].to(device)
            target = batch[1].to(device)
            label = batch[2]
            loss = loss_fn(model(mix, label), target).mean()
            loss.backward()
            optimizer.step()
            progress_bar.update(j, [("train_loss", loss.item())])
            torch.cuda.empty_cache()
    return model


duration = 4
instruments = ['vocals', 'drums', 'bass', 'other']
# Initialize the data loader
ds_train = SourceSeparationDataloader(MusDBDataset('train', instruments, duration, duration/2, num_workers=8), 3999, 32)
cmodel = ConditoningSourceSeparation(duration=duration).to(device)
cmodel = training_loop(cmodel, ds_train, device)

In [None]:
ds_test = MusDBDataset('test', instruments, duration, duration, num_workers=8)
model.eval()
output = {}

for track, name, rate in iter(ds_test):
    print("Processing track {}".format(name))
    ll = list(zip(*track))[2]
    mm = torch.stack(list(zip(*track))[0], dim=0)
    tt = torch.cat(list(zip(*track))[1], dim=0)
    for i, (instrument, ctxt) in enumerate(itertools.groupby(ll)):
        print("Separating {}".format(instrument))
        ctxt = list(ctxt)
        tmp = len(ctxt)
        m = mm[i:i+tmp]

        with torch.no_grad():
            p = cmodel(m.to(device), ctxt).detach().cpu()

        m = rearrange(m, 'b s c -> (b s) c')
        p = rearrange(p, 'b s c -> (b s) c')
        t = tt[:m.shape[0]]

        ref = torch.stack((t, m - t), dim=0)
        est = torch.stack((p, m - p), dim=0)

        output.setdefault(instrument, {})
        output[instrument][name] = get_metrics(ref, est, rate)