# Pre-processing | Data loading | Data Augmentation | CutMix/MixUp in PyTorch
This notebook covers the preprocessing of audio samples into Mel spectrograms in PyTorch, the application of common audio and spectrogram augmentations with **audiomentations**, **torch-audiomentations** and **torchaudio**, and the use of CutMix and MixUp within the dataloader using **torchvision**. Since this data loading step is heavy when performed on CPU and can be a bottleneck for training, we will see how to perform these computations on GPU for a more than 2x speed-up, at the cost of some GPU memory.

In [None]:
# Install necessary packages
!pip install audiomentations # CPU (numpy) augmentations
!pip install torch-audiomentations # GPU (torch) augmentations

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torchaudio
import pandas as pd
import os
import matplotlib.pyplot as plt
from tqdm import tqdm
import IPython.display as ipd
from datetime import datetime
import time
import librosa
import ast
plt.style.use('ggplot')

from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.utils.data import default_collate
from torchvision.transforms import v2

import audiomentations
import torch_audiomentations

In [None]:
# Utilities for displaying audio
def _first_channel_numpy(signal):
    """Return `signal` as a NumPy array, keeping only index 0 of the first axis if it is multi-dimensional.

    Torch tensors are detached and copied to CPU first, so GPU tensors and tensors that
    require grad can be displayed as well.
    """
    if isinstance(signal, torch.Tensor):
        signal = signal.detach().cpu().numpy()
    if len(signal.shape) > 1:
        signal = signal[0]
    return signal


def visualize_audio(waveform, sr, original=None):
    """Play an audio signal and plot its waveform and its Mel spectrogram side by side.

    Args:
        waveform: signal as a torch tensor or numpy array. If it has more than one
            dimension, only the first channel (index 0 of the first axis) is used.
        sr: sample rate, used for playback and for the spectrogram.
        original: optional reference signal (same conventions as `waveform`), drawn
            under the waveform plot for comparison.
    """
    # `import librosa` alone does not load the display submodule in older librosa versions
    import librosa.display

    waveform = _first_channel_numpy(waveform)
    ipd.display(ipd.Audio(waveform, rate=sr))

    # Create subplots for audio visualizations
    _fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(14, 4))

    # Waveform, optionally overlaid on the original signal
    if original is not None:
        axs[0].plot(_first_channel_numpy(original), alpha=0.4, label='Original')
    axs[0].plot(waveform, alpha=0.6, label='Augmented')
    axs[0].set_title('Waveform')
    if original is not None:
        axs[0].legend()

    # Mel spectrogram in dB, relative to its own maximum
    mel_spectrogram = librosa.feature.melspectrogram(y=waveform, sr=sr)
    mel_spectrogram = librosa.power_to_db(mel_spectrogram, ref=np.max)
    librosa.display.specshow(mel_spectrogram, x_axis='time', y_axis='mel', sr=sr, ax=axs[1])
    axs[1].set_title('Mel Spectrogram')
    
    
def visualize_spec(spec):
    """Display the first channel of a single spectrogram as an image.

    Args:
        spec: array or tensor where ``spec[0]`` is a 2-D image (rows = Mel bins,
            columns = time steps). Torch tensors are detached and copied to CPU first,
            so GPU tensors and tensors that require grad can be shown too.
    """
    image = spec[0]
    if isinstance(image, torch.Tensor):
        # matplotlib needs host memory without autograd history
        image = image.detach().cpu().numpy()
    plt.figure(figsize=(7, 4))
    plt.imshow(image, origin='lower', cmap='magma')
    plt.grid(visible=False)
    plt.title('Mel Spectrogram')
    plt.xlabel('Time step')
    plt.ylabel('Mel bin')
    plt.show()
    
    
def visualize_spec_batch(spec_batch, ncols=4):
    """Display channel 0 of every spectrogram in a batch on a grid of subplots.

    Args:
        spec_batch: tensor indexed as ``spec_batch[k, 0]`` -> 2-D image; it may live on GPU.
        ncols: number of grid columns. Rows are added as needed, so any batch size fits
            (a batch of 16 gives the same 4x4 grid and figure size as before).
    """
    n_specs = spec_batch.shape[0]
    nrows = max(1, -(-n_specs // ncols))  # ceil(n_specs / ncols) without importing math
    plt.figure(figsize=(14, 2 * nrows))
    for k in range(n_specs):
        plt.subplot(nrows, ncols, k + 1)
        plt.imshow(spec_batch[k, 0].detach().cpu().numpy(), origin='lower', cmap='magma')
        plt.grid(visible=False)
    plt.tight_layout()
    plt.show()

# Read Data
### Config

Config class for all parameters used in the notebook. STFT parameters are the same as in the [Keras starter notebook](https://www.kaggle.com/code/awsaf49/birdclef24-kerascv-starter-train). Audio files are truncated to ```duration``` seconds, either at a random position or keeping only the beginning of the file depending on ```start_idx```. The parameters `n_mels` and `target_length` control the shape of the spectrograms (respectively the number of frequency bins and time bins). 

In [None]:
class Config:
    """Notebook-wide hyper-parameters, read through class attributes (e.g. ``Config.n_mels``)."""

    # --- Audio loading --------------------------------------------------------
    sample_rate = 32000
    duration = 10                  # seconds kept per recording
    start_idx = 'first'            # 'random' crop, anything else keeps the beginning
    audio_len = sample_rate * duration

    # --- Mel spectrogram (STFT) ------------------------------------------------
    n_mels = 128                   # number of frequency bins
    target_length = 384            # number of time bins
    n_fft = 2028
    window = 2028
    hop_length = audio_len // (target_length - 1)
    fmin, fmax = 20, 16000
    top_db = 80

    # --- Model / training -----------------------------------------------------
    n_classes = 182
    batch_size = 16
    model_name = 'efficientnet_v2_s'
    loss = 'bce'                   # one of 'crossentropy', 'bce'
    secondary_labels_weight = 0.3  # target weight of secondary labels (BCE only)

    # --- Normalization statistics of the dB Mel spectrograms -------------------
    dataset_mean = [-16.8828]
    dataset_std = [12.4019]

    # --- Augmentation switches -------------------------------------------------
    data_aug = True
    cutmix_mixup = False

    # --- Data paths --------------------------------------------------------------
    base_dir = '/kaggle/input/birdclef-2024'
    short_noises = '/kaggle/input/birdclef-2023-additional/esc50/use_label'
    background_noises = [
        '/kaggle/input/birdclef2021-background-noise/aicrowd2020_noise_30sec/noise_30sec',
        '/kaggle/input/birdclef2021-background-noise/ff1010bird_nocall/nocall',
        # '/kaggle/input/birdclef2021-background-noise/train_soundscapes/nocall'
    ]

### Files
The following reads the training metadata; the audio file paths are stored in ```metadata['filepath']```. The training/validation split is ignored in this notebook.

In [None]:
# Class vocabulary: one class per (sorted) entry of the training audio directory
train_dir = Config.base_dir + '/train_audio/'
class_names = sorted(os.listdir(train_dir))
n_classes = len(class_names)
class_labels = [*range(n_classes)]
label2name = dict(enumerate(class_names))
name2label = {name: label for label, name in label2name.items()}

def get_label_from_name(name, mapping=None):
    """Return the integer label of class `name`, or None if the name is unknown.

    Args:
        name: class name, e.g. an entry of the metadata's primary/secondary labels.
        mapping: optional ``name -> label`` dict; defaults to the module-level ``name2label``.
    """
    if mapping is None:
        mapping = name2label
    return mapping.get(name)

metadata = pd.read_csv(Config.base_dir + '/train_metadata.csv')

# Full path of every recording and integer targets for the primary/secondary labels
# (secondary labels are stored as the string repr of a Python list; unknown names map to None)
metadata['filepath'] = train_dir + metadata['filename']
metadata['target'] = metadata['primary_label'].map(name2label)
metadata['secondary_targets'] = metadata['secondary_labels'].map(
    lambda raw: list(map(get_label_from_name, ast.literal_eval(raw)))
)

### Reading audio files
The first step is reading the audio files. We use `torchaudio` for this, but you can use `librosa` or other libraries if you wish. All audio files are truncated (or padded) to the same duration, here 10 s, which gives torch tensors with 1 channel and 320000 time steps. The CPU audio augmentations below expect numpy arrays, so the tensors are converted with `.numpy()` before being augmented.

In [None]:
# Truncate or pad the audio sample to the desired duration
def trunc_or_pad(waveform, audio_len, start_idx='random'):
    """Truncate or zero-pad `waveform` along its last axis to exactly `audio_len` samples.

    Args:
        waveform: torch tensor or numpy array whose last axis is time, e.g.
            (channels, samples) as returned by ``torchaudio.load``; 1-D input works too.
        audio_len: target number of samples.
        start_idx: 'random' to crop a uniformly random window when the signal is too
            long; any other value keeps the beginning of the signal.

    Returns:
        Same type as `waveform` with last dimension `audio_len` (the input itself when it
        already has that length). Short signals are zero-padded, with the padding split
        at a random point between the beginning and the end (regardless of `start_idx`).
    """
    sig_len = waveform.shape[-1]
    diff_len = abs(sig_len - audio_len)

    if sig_len > audio_len:
        # diff_len + 1 valid start positions: 0 .. sig_len - audio_len (inclusive)
        start = np.random.randint(0, diff_len + 1) if start_idx == 'random' else 0
        waveform = waveform[..., start:start + audio_len]

    elif sig_len < audio_len:
        # Split the padding randomly between the beginning (pad1) and the end (pad2)
        pad1 = np.random.randint(0, diff_len + 1)
        pad2 = diff_len - pad1
        if isinstance(waveform, torch.Tensor):
            # F.pad pads the last dimension with (left, right)
            waveform = nn.functional.pad(waveform, pad=(pad1, pad2), mode='constant', value=0)
        else:
            pad_width = [(0, 0)] * (waveform.ndim - 1) + [(pad1, pad2)]
            waveform = np.pad(waveform, pad_width, mode='constant', constant_values=0)

    return waveform

# Load one training recording and crop/pad it to Config.duration seconds
file = metadata.loc[0, "filepath"]
waveform, sr = torchaudio.load(file)
waveform = trunc_or_pad(waveform, audio_len=Config.audio_len)
waveform.shape  # (channels, samples)

In [None]:
visualize_audio(waveform, sr=sr)  # un-augmented reference recording

# Audio Augmentations on CPU with Audiomentations
[Audiomentations](https://iver56.github.io/audiomentations/) is a Python library for audio signal augmentations, similar to Albumentations for images. It implements many data augmentations for audio signals in the time domain. The audiomentations transforms expect numpy arrays (not torch tensors) together with the sample rate, and run on CPU. All transforms can be applied with a random probability (controlled with the `p` parameter), and the library implements the useful `Compose` and `OneOf` transforms. All available transforms are listed in the [documentation](https://iver56.github.io/audiomentations/), and can be visualized [here](https://phrasenmaeher-audio-transformat-visualize-transformation-5s1n4t.streamlit.app/). First, let's look at some augmentations that are useful in our case:

### Time Shift
Shift the samples forwards or backwards (randomly chosen), by default with rollover.

In [None]:
# Time shift: move the signal forwards/backwards by up to half its length
aug = audiomentations.Shift(
    min_shift=-0.5,
    max_shift=0.5,
    p=1,
)
waveform_aug = aug(waveform.numpy(), sr)
visualize_audio(waveform_aug, sr=sr, original=waveform)

### Pitch Shift
Pitch shift the sound up or down without changing the tempo (scales frequencies without rescaling the time steps).

In [None]:
# Pitch shift by up to +/-2.5 semitones, tempo unchanged
aug = audiomentations.PitchShift(
    min_semitones=-2.5,
    max_semitones=2.5,
    p=1,
)
waveform_aug = aug(waveform.numpy(), sr)
visualize_audio(waveform_aug, sr=sr, original=waveform)

### Equalizer
Adjust the volume of different frequency bands.

In [None]:
# Seven-band parametric EQ with random band gains between -12 and +12 dB
aug = audiomentations.SevenBandParametricEQ(
    min_gain_db=-12.,
    max_gain_db=12.,
    p=1,
)
waveform_aug = aug(waveform.numpy(), sr)
visualize_audio(waveform_aug, sr=sr, original=waveform)

### Low pass filter
Cuts off high frequencies (higher than randomly picked cutoff frequency)

In [None]:
# Low-pass filter: random cutoff frequency between 750 and 7500, rolloff between 12 and 24
aug = audiomentations.LowPassFilter(
    min_cutoff_freq=750.,
    max_cutoff_freq=7500.,
    min_rolloff=12,
    max_rolloff=24,
    p=1,
)
waveform_aug = aug(waveform.numpy(), sr)
visualize_audio(waveform_aug, sr=sr, original=waveform)

### Air absorption
Simulates attenuation of high frequencies due to air absorption. High frequencies are attenuated quicker than low frequencies with the distance traveled, this transform simulates that. It does not simulate the global attenuation of the gain with the traveled distance.

In [None]:
# High-frequency attenuation due to the distance travelled through air
aug = audiomentations.AirAbsorption(
    min_temperature=10, max_temperature=20,
    min_humidity=30, max_humidity=90,
    min_distance=10, max_distance=100,
    p=1.,
)
waveform_aug = aug(waveform.numpy(), sr)
visualize_audio(waveform_aug, sr=sr, original=waveform)

### Gain 
Multiply the audio by a random amplitude factor to reduce or increase the volume. Helps the model to be somewhat invariant to the gain (volume) of the signal. ```GainTransition``` is a gradual transformation (fade in/fade out).

In [None]:
# Gain: either a constant gain change or a gradual gain transition
aug = audiomentations.OneOf([
    audiomentations.Gain(min_gain_db=-6., max_gain_db=6., p=1),
    audiomentations.GainTransition(min_gain_db=-12., max_gain_db=3., p=1),
])
waveform_aug = aug(waveform.numpy(), sr)
visualize_audio(waveform_aug, sr=sr, original=waveform)

### Gaussian noise
Add Gaussian noise (white noise). With ```AddGaussianSNR``` the noise amplitude is chosen relative to the amplitude of the given signal, to obtain the randomly chosen SNR (Signal-to-Noise Ratio), while with ```AddGaussianNoise``` the amplitude of the noise is independent of the input signal.

In [None]:
# Gaussian (white) noise, with either an absolute amplitude or a target SNR
aug = audiomentations.OneOf([
    audiomentations.AddGaussianNoise(min_amplitude=0.001, max_amplitude=0.015, p=1.),
    audiomentations.AddGaussianSNR(min_snr_db=5., max_snr_db=40., p=1.),
])
waveform_aug = aug(waveform.numpy(), sr)
visualize_audio(waveform_aug, sr=sr, original=waveform)

### Pink noise (and other colors)
You can find details about all colors of noise [here](https://en.wikipedia.org/wiki/Colors_of_noise), and how to obtain them in the [audiomentations documentation](https://iver56.github.io/audiomentations/waveform_transforms/add_color_noise/). Pink noise is commonly used as an augmentation.

In [None]:
# Pink noise (f_decay fixed at -3.01) at a random SNR between 5 and 40 dB
aug = audiomentations.AddColorNoise(
    min_snr_db=5., max_snr_db=40.,
    min_f_decay=-3.01, max_f_decay=-3.01,
    p=1.,
)
waveform_aug = aug(waveform.numpy(), sr)
visualize_audio(waveform_aug, sr=sr, original=waveform)

### Background and Short Noises
Mix another sound into the input signal. Background noises are mixed in for the full duration of the input, with an amplitude relative to the input. Short noises are added only to a few seconds of the input. The audio files are from the following datasets: [background](https://www.kaggle.com/datasets/christofhenkel/birdclef2021-background-noise?select=aicrowd2020_noise_30sec) and [short](https://www.kaggle.com/datasets/atsunorifujita/birdclef-2023-additional?select=train_datasets).

The short noises are rain and frog recordings from the ESC-50 dataset; the background noises are recordings without bird calls from previous competitions.

In [None]:
# Mix recorded noises into the clip: short noise events plus a full-length background
aug = audiomentations.Compose([
    audiomentations.AddShortNoises(
        sounds_path=Config.short_noises,
        min_snr_db=3., max_snr_db=30.,
        noise_rms='relative_to_whole_input',
        min_time_between_sounds=2., max_time_between_sounds=8.,
        noise_transform=audiomentations.PolarityInversion(),
        p=1.,
    ),
    audiomentations.AddBackgroundNoise(
        sounds_path=Config.background_noises,
        min_snr_db=3., max_snr_db=30.,
        noise_transform=audiomentations.PolarityInversion(),
        p=1.,
    ),
])

# These transforms want mono (1-D) audio: drop the channel axis, then restore it
waveform_aug = aug(waveform.squeeze().numpy(), sr)[None, :]
visualize_audio(waveform_aug, sr=sr, original=waveform)

### Randomly compose augmentations
We define the augmentations for the audio signal in the following:

In [None]:
# Random waveform augmentation pipeline (transforms are applied in this order)
waveform_transforms = audiomentations.Compose([
    # Shift and spectral colouring
    audiomentations.Shift(min_shift=-0.5, max_shift=0.5, p=0.5),
    audiomentations.SevenBandParametricEQ(min_gain_db=-12., max_gain_db=12., p=0.5),
    audiomentations.AirAbsorption(
        min_temperature=10, max_temperature=20,
        min_humidity=30, max_humidity=90,
        min_distance=10, max_distance=100,
        p=1.,
    ),

    # One volume change
    audiomentations.OneOf(
        [
            audiomentations.Gain(min_gain_db=-6., max_gain_db=6., p=1),
            audiomentations.GainTransition(min_gain_db=-12., max_gain_db=3., p=1),
        ],
        p=1.,
    ),

    # One kind of synthetic noise: white (absolute or SNR-based) or pink
    audiomentations.OneOf(
        [
            audiomentations.AddGaussianNoise(min_amplitude=0.001, max_amplitude=0.015, p=1.),
            audiomentations.AddGaussianSNR(min_snr_db=5., max_snr_db=40., p=1.),
            audiomentations.AddColorNoise(min_snr_db=5., max_snr_db=40., min_f_decay=-3.01, max_f_decay=-3.01, p=1.),
        ],
        p=1.,
    ),

    # Recorded noises
    audiomentations.AddShortNoises(
        sounds_path=Config.short_noises,
        min_snr_db=3., max_snr_db=30.,
        noise_rms='relative_to_whole_input',
        min_time_between_sounds=2., max_time_between_sounds=8.,
        noise_transform=audiomentations.PolarityInversion(),
        p=0.5,
    ),
    audiomentations.AddBackgroundNoise(
        sounds_path=Config.background_noises,
        min_snr_db=3., max_snr_db=30.,
        noise_transform=audiomentations.PolarityInversion(),
        p=0.5,
    ),

    # Filtering and pitch
    audiomentations.LowPassFilter(min_cutoff_freq=750., max_cutoff_freq=7500., min_rolloff=12, max_rolloff=24, p=0.8),
    audiomentations.PitchShift(min_semitones=-2.5, max_semitones=2.5, p=0.3),
])

# The pipeline works on mono 1-D samples; the channel axis is restored afterwards
waveform_aug = waveform_transforms(waveform.squeeze().numpy(), sr)[None, :]
visualize_audio(waveform_aug, sr=sr, original=waveform)

# Spectrogram augmentations
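The Mel spectrogram is computed with torchaudio, in the dB scale, and then normalized using the statistics of the whole dataset. This applies the same transformation to all samples, and is the common procedure for natural images with ImageNet statistics (which are of course not adequate here). The mean and standard deviation can be found in the Config class. They can be computed with the following code, for example if you use more samples from other datasets. Build `train_dataset` for this with `mean=None` (and preferably without augmentations); otherwise the statistics are measured on already-normalized spectrograms. Note also that averaging the per-sample standard deviations only approximates the standard deviation over the whole dataset: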
```python
# Compute dataset mean
base_mean = 0
base_std = 0

for k in tqdm(range(len(train_dataset))):
    spec, _ = train_dataset[k]
    base_mean += spec.mean()
    base_std += spec.std()

base_mean = base_mean/len(train_dataset)
base_std = base_std/len(train_dataset)
```

In [None]:
# dB Mel spectrogram, normalized with the dataset statistics stored in Config
to_mel_spectrogramn = nn.Sequential(
    torchaudio.transforms.MelSpectrogram(
        Config.sample_rate,
        n_fft=Config.n_fft,
        win_length=Config.window,
        hop_length=Config.hop_length,
        n_mels=Config.n_mels,
        f_min=Config.fmin,
        f_max=Config.fmax,
    ),
    torchaudio.transforms.AmplitudeToDB(top_db=Config.top_db),
    v2.Normalize(mean=Config.dataset_mean, std=Config.dataset_std),
)

spec = to_mel_spectrogramn(waveform)  # expected (1, n_mels, target_length) for a mono clip
visualize_spec(spec)

### Time and Frequency masking
Common augmentations for spectrograms are time and frequency masking, which respectively mask vertical and horizontal bands of the spectrograms. Keep in mind that most augmentations usually used with images make no sense for spectrograms, as they do not have the same invariances and symmetries of natural images. 

The transforms are custom implementations based on the torchaudio transforms, adding randomness and multiple bands. The maximum number of masked bands is controlled by ```n_masks```, the maximum width of the bands by ```max_mask_pct```, and the probability of applying each of the ```n_masks``` masks by ```prob```. Masked bands are set to either zero or the mean of the spectrogram.

In [None]:
# Spectrogram transforms
class FrequencyMaskingAug(torchaudio.transforms.FrequencyMasking):
    """Randomly mask up to `n_masks` horizontal (frequency) bands of a spectrogram.

    Each of the `n_masks` masks is applied independently with probability `prob`. Its
    width is drawn by torchaudio's FrequencyMasking with ``max_mask_pct * n_mels`` as the
    maximum. Masked bins are filled with the spectrogram mean (``mask_mode='mean'``) or 0.

    Inputs with more than 3 dimensions are treated as a batch along dim 0. Each sample is
    masked independently, and the input tensor itself is not written to.
    """

    def __init__(self, prob, max_mask_pct, n_mels, n_masks, mask_mode='mean'):
        # Initialise nn.Module state before setting any attribute on the module
        super().__init__(max_mask_pct * n_mels)
        self.prob = prob
        self.freq_mask_param = max_mask_pct * n_mels
        self.n_masks = n_masks
        self.mask_mode = mask_mode

    def _mask_value(self, specgram):
        """Fill value used for the masked bins of `specgram`."""
        return specgram.mean() if self.mask_mode == 'mean' else 0

    def forward(self, specgram):
        if specgram.dim() > 3:
            # Batched input: write into a copy, so the caller's tensor is left unmodified.
            # The copy is also writable when the input is an expanded view.
            specgram = specgram.clone()
            for _ in range(self.n_masks):
                for k in range(specgram.size(0)):
                    if np.random.random() < self.prob:
                        specgram[k] = super().forward(specgram[k], self._mask_value(specgram[k]))
            return specgram

        # Single spectrogram: the fill value is computed once, from the unmasked input
        mask_value = self._mask_value(specgram)
        for _ in range(self.n_masks):
            if np.random.random() < self.prob:
                specgram = super().forward(specgram, mask_value)
        return specgram
    
class TimeMaskingAug(torchaudio.transforms.TimeMasking):
    """Randomly mask up to `n_masks` vertical (time) bands of a spectrogram.

    Each of the `n_masks` masks is applied independently with probability `prob`. Its
    width is drawn by torchaudio's TimeMasking with ``max_mask_pct * n_steps`` as the
    maximum. Masked steps are filled with the spectrogram mean (``mask_mode='mean'``) or 0.

    Inputs with more than 3 dimensions are treated as a batch along dim 0. Each sample is
    masked independently, and the input tensor itself is not written to.
    """

    def __init__(self, prob, max_mask_pct, n_steps, n_masks, mask_mode='mean'):
        # Initialise nn.Module state before setting any attribute on the module
        super().__init__(max_mask_pct * n_steps)
        self.prob = prob
        self.time_mask_param = max_mask_pct * n_steps
        self.n_masks = n_masks
        self.mask_mode = mask_mode

    def _mask_value(self, specgram):
        """Fill value used for the masked steps of `specgram`."""
        return specgram.mean() if self.mask_mode == 'mean' else 0

    def forward(self, specgram):
        if specgram.dim() > 3:
            # Batched input: write into a copy, so the caller's tensor is left unmodified.
            # The copy is also writable when the input is an expanded view.
            specgram = specgram.clone()
            for _ in range(self.n_masks):
                for k in range(specgram.size(0)):
                    if np.random.random() < self.prob:
                        specgram[k] = super().forward(specgram[k], self._mask_value(specgram[k]))
            return specgram

        # Single spectrogram: the fill value is computed once, from the unmasked input
        mask_value = self._mask_value(specgram)
        for _ in range(self.n_masks):
            if np.random.random() < self.prob:
                specgram = super().forward(specgram, mask_value)
        return specgram
    

# Up to 3 frequency masks, then up to 3 time masks; each mask is applied with probability 0.3,
# is at most 10% of its axis wide, and is filled with the spectrogram mean
spec_transforms = nn.Sequential(
    FrequencyMaskingAug(prob=0.3, max_mask_pct=0.1, n_mels=Config.n_mels, n_masks=3, mask_mode='mean'),
    TimeMaskingAug(prob=0.3, max_mask_pct=0.1, n_steps=Config.target_length, n_masks=3, mask_mode='mean'),
)

In [None]:
spec_aug = spec_transforms(spec)  # random frequency/time masks on the example spectrogram
visualize_spec(spec_aug)

# Dataset
We can now write a Dataset object implementing the augmentations and spectrogram computation:

In [None]:
# Truncate or pad the audio sample to the desired duration
def trunc_or_pad(waveform, audio_len, start_idx='random'):
    """Truncate or zero-pad `waveform` along its last axis to exactly `audio_len` samples.

    Args:
        waveform: torch tensor or numpy array whose last axis is time, e.g.
            (channels, samples) as returned by ``torchaudio.load``; 1-D input works too.
        audio_len: target number of samples.
        start_idx: 'random' to crop a uniformly random window when the signal is too
            long; any other value keeps the beginning of the signal.

    Returns:
        Same type as `waveform` with last dimension `audio_len` (the input itself when it
        already has that length). Short signals are zero-padded, with the padding split
        at a random point between the beginning and the end (regardless of `start_idx`).
    """
    sig_len = waveform.shape[-1]
    diff_len = abs(sig_len - audio_len)

    if sig_len > audio_len:
        # diff_len + 1 valid start positions: 0 .. sig_len - audio_len (inclusive)
        start = np.random.randint(0, diff_len + 1) if start_idx == 'random' else 0
        waveform = waveform[..., start:start + audio_len]

    elif sig_len < audio_len:
        # Split the padding randomly between the beginning (pad1) and the end (pad2)
        pad1 = np.random.randint(0, diff_len + 1)
        pad2 = diff_len - pad1
        if isinstance(waveform, torch.Tensor):
            # F.pad pads the last dimension with (left, right)
            waveform = nn.functional.pad(waveform, pad=(pad1, pad2), mode='constant', value=0)
        else:
            pad_width = [(0, 0)] * (waveform.ndim - 1) + [(pad1, pad2)]
            waveform = np.pad(waveform, pad_width, mode='constant', constant_values=0)

    return waveform


class AudioDataset(Dataset):
    """Dataset of (Mel spectrogram, label) pairs built from a metadata DataFrame.

    Each item is processed as follows:
    - It is loaded from ``df['filepath']``, downmixed to mono and resampled to `sample_rate`.
    - It is cropped/padded to `duration` seconds; the crop position is controlled by `start_idx`.
    - It is optionally augmented in the waveform domain by ``waveform_transforms(samples, sample_rate)``.
    - It is converted to a dB Mel spectrogram, normalized with `mean`/`std` when `mean` is given,
      and optionally augmented by `spec_transforms`.
    - Finally it is repeated to 3 channels.

    Labels:
        - ``loss='crossentropy'``: 0-d tensor holding ``df['target']``.
        - ``loss='bce'``: float vector of length `n_classes` holding 1 at the primary target.
          Each secondary target in ``df['secondary_targets']`` is set to
          `secondary_labels_weight` (None entries are ignored). A target that is both primary
          and secondary, or a repeated secondary target, is not summed.
    """

    def __init__(
            self,
            df,
            n_classes,
            start_idx = 'random',
            duration = 10,
            sample_rate = 32000,
            target_length = 384,
            n_mels = 128,
            n_fft = 2028,
            window = 2028,
            hop_length = None,
            fmin = 20,
            fmax = 16000,
            top_db = 80,
            waveform_transforms=None,
            spec_transforms=None,
            mean=None,
            std=None,
            loss='crossentropy',
            secondary_labels_weight=0.
            ):
        super().__init__()
        self.df = df
        self.n_classes = n_classes
        self.start_idx = start_idx
        self.duration = duration
        self.sample_rate = sample_rate
        self.audio_len = duration * sample_rate
        self.target_length = target_length
        self.n_mels = n_mels
        self.n_fft = n_fft
        self.window = window
        # Default hop so that the spectrogram has `target_length` time frames
        self.hop_length = hop_length or self.audio_len // (target_length - 1)
        self.fmin = fmin
        self.fmax = fmax
        self.top_db = top_db
        self.loss = loss
        self.secondary_labels_weight = secondary_labels_weight

        self.to_mel_spectrogramn = nn.Sequential(
            torchaudio.transforms.MelSpectrogram(self.sample_rate, n_fft=self.n_fft, win_length=self.window,
                                                 hop_length=self.hop_length, n_mels=self.n_mels,
                                                 f_min=self.fmin, f_max=self.fmax),
            torchaudio.transforms.AmplitudeToDB(top_db=self.top_db)
        )
        if mean is not None:
            self.to_mel_spectrogramn.append(v2.Normalize(mean=mean, std=std))

        self.waveform_transforms = waveform_transforms
        self.spec_transforms = spec_transforms

    def __len__(self):
        return len(self.df)

    def _make_label(self, item):
        """Build the training target for one metadata row (see the class docstring)."""
        label = torch.tensor(item['target'])
        if self.loss != 'bce':
            return label

        label = nn.functional.one_hot(label, num_classes=self.n_classes).float()
        for secondary in item['secondary_targets']:
            # None marks secondary names that are not part of the class list
            if secondary is not None and label[secondary] < self.secondary_labels_weight:
                # Assign instead of `+=`: a secondary label equal to the primary one (or a
                # duplicated secondary) must not push the BCE target above 1
                label[secondary] = self.secondary_labels_weight
        return label

    def _load_waveform(self, filepath):
        """Load one recording as a (1, audio_len) float tensor at `self.sample_rate`, augmented if configured."""
        waveform, sr = torchaudio.load(filepath)
        if waveform.size(0) > 1:
            # Downmix: the spectrogram pipeline and `expand(3, ...)` expect a single channel
            waveform = waveform.mean(dim=0, keepdim=True)
        if sr != self.sample_rate:
            # audio_len and the STFT parameters are expressed at self.sample_rate
            waveform = torchaudio.functional.resample(waveform, sr, self.sample_rate)
        waveform = trunc_or_pad(waveform, self.audio_len, self.start_idx)

        if self.waveform_transforms is not None:
            # The waveform transforms work on mono 1-D numpy arrays
            augmented = self.waveform_transforms(waveform.squeeze().numpy(), self.sample_rate)
            waveform = torch.Tensor(augmented[None, :])
        return waveform

    def __getitem__(self, idx):
        item = self.df.iloc[idx]
        label = self._make_label(item)
        waveform = self._load_waveform(item['filepath'])

        spec = self.to_mel_spectrogramn(waveform)
        if self.spec_transforms is not None:
            spec = self.spec_transforms(spec)

        # expand to 3 channels for ImageNet-pretrained models (a view, no copy)
        spec = spec.expand(3, -1, -1)

        return spec, label

In [None]:
# Training dataset: every setting comes from Config, including how files are cropped
# (Config.start_idx) and whether the CPU augmentations are used (Config.data_aug)
train_dataset = AudioDataset(
    metadata,
    n_classes=Config.n_classes,
    start_idx=Config.start_idx,
    duration=Config.duration,
    sample_rate=Config.sample_rate,
    target_length=Config.target_length,
    n_mels=Config.n_mels,
    n_fft=Config.n_fft,
    window=Config.window,
    hop_length=Config.hop_length,
    fmin=Config.fmin,
    fmax=Config.fmax,
    top_db=Config.top_db,
    waveform_transforms=waveform_transforms if Config.data_aug else None,
    spec_transforms=spec_transforms if Config.data_aug else None,
    mean=Config.dataset_mean,
    std=Config.dataset_std,
    loss=Config.loss,
    secondary_labels_weight=Config.secondary_labels_weight
    )

# CutMix / MixUp in Dataloader
CutMix and MixUp are data augmentation methods which mix two different training samples, both in their data and in their labels. More details in their respective papers: [CutMix](https://arxiv.org/abs/1905.04899) and [MixUp](https://arxiv.org/abs/1710.09412).

As these transforms use multiple samples, they cannot be implemented in the same way as the previous ones. We can implement them easily with [Torchvision](https://pytorch.org/vision/main/auto_examples/transforms/plot_cutmix_mixup.html) inside the Dataloader. However, the torchvision transforms expect a class index for each label and do not support the one-hot encoded labels that you will use with BCE (Binary Cross Entropy) loss. We can easily modify them to add support for one-hot encoded labels. The only modification is in ```_BaseMixUpCutMix``` with ```one_hot_labels```:

In [None]:
from typing import Any, Callable, Dict, List, Tuple
import math
import numbers
import warnings
import PIL.Image
from torchvision.transforms.v2._utils import _parse_labels_getter, has_any, is_pure_tensor, query_chw, query_size
from torch.nn.functional import one_hot
from torch.utils._pytree import tree_flatten, tree_unflatten
from torchvision import transforms as _transforms, tv_tensors
from torchvision.transforms.v2 import functional as F
class _BaseMixUpCutMix(v2.Transform):
    """torchvision's ``_BaseMixUpCutMix`` extended with a ``one_hot_labels`` option.

    With ``one_hot_labels=False`` (torchvision behaviour), labels must be class indices of
    shape ``(batch_size,)`` and are one-hot encoded before being mixed. With
    ``one_hot_labels=True``, labels must already be a ``(batch_size, num_classes)`` tensor,
    e.g. multi-label BCE targets, and are mixed as they are.
    """

    def __init__(self, *, alpha: float = 1.0, num_classes: int, labels_getter="default", one_hot_labels: bool = False) -> None:
        super().__init__()
        self.alpha = float(alpha)
        self._dist = torch.distributions.Beta(torch.tensor([alpha]), torch.tensor([alpha]))

        self.num_classes = num_classes

        self._labels_getter = _parse_labels_getter(labels_getter)
        self.one_hot_labels = one_hot_labels

    def forward(self, *inputs):
        inputs = inputs if len(inputs) > 1 else inputs[0]
        flat_inputs, spec = tree_flatten(inputs)
        needs_transform_list = self._needs_transform_list(flat_inputs)

        if has_any(flat_inputs, PIL.Image.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask):
            raise ValueError(f"{type(self).__name__}() does not support PIL images, bounding boxes and masks.")

        labels = self._labels_getter(inputs)
        if not isinstance(labels, torch.Tensor):
            raise ValueError(f"The labels must be a tensor, but got {type(labels)} instead.")
        elif labels.ndim != 1 and not self.one_hot_labels:
            raise ValueError(
                f"labels tensor should be of shape (batch_size,) " f"but got shape {labels.shape} instead."
            )
        # One-hot labels must be 2-D AND have num_classes columns. `ndim` is tested first, so
        # `size(-1)` is never evaluated on a 0-d tensor.
        elif self.one_hot_labels and (labels.ndim != 2 or labels.size(-1) != self.num_classes):
            raise ValueError(
                f"labels tensor should be of shape (batch_size, {self.num_classes}) " f"but got shape {labels.shape} instead."
            )

        params = {
            "labels": labels,
            "batch_size": labels.shape[0],
            **self._get_params(
                [inpt for (inpt, needs_transform) in zip(flat_inputs, needs_transform_list) if needs_transform]
            ),
        }

        # By default, the labels will be False inside needs_transform_list, since they are a torch.Tensor coming
        # after an image or video. However, we need to handle them in _transform, so we make sure to set them to True
        needs_transform_list[next(idx for idx, inpt in enumerate(flat_inputs) if inpt is labels)] = True
        flat_outputs = [
            self._transform(inpt, params) if needs_transform else inpt
            for (inpt, needs_transform) in zip(flat_inputs, needs_transform_list)
        ]

        return tree_unflatten(flat_outputs, spec)

    def _check_image_or_video(self, inpt: torch.Tensor, *, batch_size: int):
        expected_num_dims = 5 if isinstance(inpt, tv_tensors.Video) else 4
        if inpt.ndim != expected_num_dims:
            raise ValueError(
                f"Expected a batched input with {expected_num_dims} dims, but got {inpt.ndim} dimensions instead."
            )
        if inpt.shape[0] != batch_size:
            raise ValueError(
                f"The batch size of the image or video does not match the batch size of the labels: "
                f"{inpt.shape[0]} != {batch_size}."
            )

    def _mixup_label(self, label: torch.Tensor, *, lam: float) -> torch.Tensor:
        # Index labels are one-hot encoded first; one-hot labels are used as they are
        if not self.one_hot_labels:
            label = one_hot(label, num_classes=self.num_classes)
        if not label.dtype.is_floating_point:
            label = label.float()
        # roll() returns a new tensor, so the in-place ops never touch the caller's labels
        return label.roll(1, 0).mul_(1.0 - lam).add_(label.mul(lam))


class MixUp(_BaseMixUpCutMix):
    """Apply MixUp to the provided batch of images and labels.

    Paper: `mixup: Beyond Empirical Risk Minimization <https://arxiv.org/abs/1710.09412>`_.

    .. note::
        This transform is meant to be used on **batches** of samples, not
        individual images. See
        :ref:`sphx_glr_auto_examples_transforms_plot_cutmix_mixup.py` for detailed usage
        examples.
        The sample pairing is deterministic and done by matching consecutive
        samples in the batch, so the batch needs to be shuffled (this is an
        implementation detail, not a guaranteed convention.)

    By default, the labels are expected to be a tensor of shape ``(batch_size,)`` holding class
    indices; they will be transformed into a tensor of shape ``(batch_size, num_classes)``.
    With ``one_hot_labels=True`` they must instead already be of shape
    ``(batch_size, num_classes)`` (e.g. multi-label BCE targets) and are mixed directly.

    Args:
        alpha (float, optional): hyperparameter of the Beta distribution used for mixup. Default is 1.
        num_classes (int): number of classes in the batch. Used for one-hot-encoding and for
            checking the shape of one-hot labels.
        labels_getter (callable or "default", optional): indicates how to identify the labels in the input.
            By default, this will pick the second parameter as the labels if it's a tensor. This covers the most
            common scenario where this transform is called as ``MixUp()(imgs_batch, labels_batch)``.
            It can also be a callable that takes the same input as the transform, and returns the labels.
        one_hot_labels (bool, optional): whether the labels are already one-hot / multi-hot encoded.
            Default is False.
    """

    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        return dict(lam=float(self._dist.sample(())))  # type: ignore[arg-type]

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        lam = params["lam"]

        if inpt is params["labels"]:
            return self._mixup_label(inpt, lam=lam)
        elif isinstance(inpt, (tv_tensors.Image, tv_tensors.Video)) or is_pure_tensor(inpt):
            self._check_image_or_video(inpt, batch_size=params["batch_size"])

            # Convex combination of each sample with its predecessor in the batch
            output = inpt.roll(1, 0).mul_(1.0 - lam).add_(inpt.mul(lam))

            if isinstance(inpt, (tv_tensors.Image, tv_tensors.Video)):
                output = tv_tensors.wrap(output, like=inpt)

            return output
        else:
            return inpt


class CutMix(_BaseMixUpCutMix):
    """Apply CutMix to the provided batch of images and labels.

    Paper: `CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features
    <https://arxiv.org/abs/1905.04899>`_.

    .. note::
        This transform is meant to be used on **batches** of samples, not
        individual images. See
        :ref:`sphx_glr_auto_examples_transforms_plot_cutmix_mixup.py` for detailed usage
        examples.
        The sample pairing is deterministic and done by matching consecutive
        samples in the batch, so the batch needs to be shuffled (this is an
        implementation detail, not a guaranteed convention.)

    .. note::
        If the second-to-last dimension (H) is too small for the box to get a non-zero
        half-height (e.g. a batch of mono waveforms whose size is reported as H = 1),
        the standard square-ish box would be empty and CutMix would silently do nothing.
        In that case a full-height band of width ``~(1 - lam) * W`` is cut instead, so
        the transform also works on 1-D signals. Inputs with a regular 2-D size are
        processed exactly as before.

    In the input, the labels are expected to be a tensor of shape ``(batch_size,)``. They will be transformed
    into a tensor of shape ``(batch_size, num_classes)``.

    Args:
        alpha (float, optional): hyperparameter of the Beta distribution used for mixup. Default is 1.
        num_classes (int): number of classes in the batch. Used for one-hot-encoding.
        labels_getter (callable or "default", optional): indicates how to identify the labels in the input.
            By default, this will pick the second parameter as the labels if it's a tensor. This covers the most
            common scenario where this transform is called as ``CutMix()(imgs_batch, labels_batch)``.
            It can also be a callable that takes the same input as the transform, and returns the labels.
    """

    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        """Sample the cut box and the label weight implied by its actual area."""
        lam = float(self._dist.sample(()))  # type: ignore[arg-type]

        # Size of the last two dims, as reported by query_size.
        H, W = query_size(flat_inputs)

        # Random box centre. Both draws are kept, in the original order, so the RNG
        # stream is unchanged for regular 2-D inputs.
        r_x = torch.randint(W, size=(1,))
        r_y = torch.randint(H, size=(1,))

        # Half-side ratio chosen so that the box covers ~(1 - lam) of the H x W area.
        r = 0.5 * math.sqrt(1.0 - lam)
        r_w_half = int(r * W)
        r_h_half = int(r * H)

        if r_h_half == 0:
            # Degenerate height: the square-ish box would have zero height, and thus
            # zero area (a silent no-op). Cut a full-height band instead, whose width
            # alone accounts for the (1 - lam) area fraction.
            r_w_half = int(0.5 * (1.0 - lam) * W)
            y1, y2 = 0, H
        else:
            y1 = int(torch.clamp(r_y - r_h_half, min=0))
            y2 = int(torch.clamp(r_y + r_h_half, max=H))
        x1 = int(torch.clamp(r_x - r_w_half, min=0))
        x2 = int(torch.clamp(r_x + r_w_half, max=W))
        box = (x1, y1, x2, y2)

        # The box may be clipped at the borders, so recompute lam from the real area.
        lam_adjusted = float(1.0 - (x2 - x1) * (y2 - y1) / (W * H))

        return dict(box=box, lam_adjusted=lam_adjusted)

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        """Paste the box region of the rolled batch into each sample; mix labels accordingly."""
        if inpt is params["labels"]:
            return self._mixup_label(inpt, lam=params["lam_adjusted"])
        elif isinstance(inpt, (tv_tensors.Image, tv_tensors.Video)) or is_pure_tensor(inpt):
            self._check_image_or_video(inpt, batch_size=params["batch_size"])

            x1, y1, x2, y2 = params["box"]
            # roll(1, 0) pairs sample i with sample i-1 of the batch.
            rolled = inpt.roll(1, 0)
            output = inpt.clone()
            output[..., y1:y2, x1:x2] = rolled[..., y1:y2, x1:x2]

            if isinstance(inpt, (tv_tensors.Image, tv_tensors.Video)):
                output = tv_tensors.wrap(output, like=inpt)

            return output
        else:
            return inpt

CutMix and MixUp are (optionally) applied in the collate function of the dataloader, and only to a random subset of batches.

In [None]:
# Mix each collated batch: with probability 0.7, apply either CutMix (65%) or MixUp (35%).
# `one_hot_labels` is passed as True only when training with the BCE loss.
_one_hot_labels = Config.loss == 'bce'
cutmix_or_mixup = v2.RandomApply(
    [
        v2.RandomChoice(
            [
                CutMix(num_classes=Config.n_classes, alpha=0.5, one_hot_labels=_one_hot_labels),
                MixUp(num_classes=Config.n_classes, alpha=0.5, one_hot_labels=_one_hot_labels),
            ],
            p=[0.65, 0.35],
        ),
    ],
    p=0.7,
)


def mix_collate_fn(batch):
    """Collate ``batch`` with ``default_collate``, then run CutMix/MixUp on the result."""
    collated = default_collate(batch)
    return cutmix_or_mixup(*collated)


if Config.cutmix_mixup:
    collate_fn = mix_collate_fn
else:
    collate_fn = None  # DataLoader falls back to its default collation

train_loader = DataLoader(
    train_dataset,
    batch_size=Config.batch_size,
    shuffle=True,
    num_workers=4,
    collate_fn=collate_fn,
)

### Result of all augmentations
Augmentations visualized for a batch:

In [None]:
# Draw one shuffled batch from the loader above and visualize it. CutMix/MixUp is
# only present when Config.cutmix_mixup is set, and even then only applied at random.
spec_batch, labels = next(iter(train_loader))
visualize_spec_batch(spec_batch)

# Augmentations and Spectrograms on GPU
Data loading can be slow because of the spectrogram computation and the data augmentations, especially since we cannot use a high number of dataloader workers on Kaggle. The following code shows how long it takes to go through the whole dataset (around 25 min on Kaggle):

In [None]:
# Uncomment to see the duration on CPU (20-30min)
#for spec_batch, labels in tqdm(train_loader):
#    pass

We can instead perform the augmentations and the spectrogram computation on GPU to avoid the data-loading bottleneck. All augmentations are then applied after data loading, on batches of data. We first define a Dataset that only returns truncated/padded raw audio samples (no augmentation). We also define a DataLoader that applies the same CutMix and MixUp augmentations, this time on the raw audio signals:

In [None]:
# Dataset returning truncated audio samples
class AudioDataset(Dataset):
    """Dataset returning fixed-length raw waveforms and their labels.

    No augmentation or feature extraction is done here. Each item is the audio file
    truncated/padded to ``duration`` seconds, so that augmentations and spectrograms
    can be computed later on whole batches (e.g. on GPU).

    Args:
        df: metadata table with ``'target'``, ``'secondary_targets'`` and ``'filepath'`` columns.
        n_classes: number of classes, used for one-hot encoding when ``loss == 'bce'``.
        start_idx: forwarded to ``trunc_or_pad`` to choose where the crop starts.
        duration: clip length in seconds.
        sample_rate: target sample rate; files loaded at another rate are resampled to it.
        loss: ``'bce'`` returns multi-hot float labels; any other value returns the class index.
        secondary_labels_weight: value added for each secondary target when ``loss == 'bce'``.
    """

    def __init__(
            self,
            df,
            n_classes,
            start_idx='random',
            duration=10,
            sample_rate=32000,
            loss='crossentropy',
            secondary_labels_weight=0.
            ):
        super().__init__()
        self.df = df
        self.n_classes = n_classes
        self.start_idx = start_idx
        self.duration = duration
        self.sample_rate = sample_rate
        # Length in samples; int() so a float duration still yields a valid length.
        self.audio_len = int(duration * sample_rate)
        self.loss = loss
        self.secondary_labels_weight = secondary_labels_weight

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(waveform, label)`` for row ``idx`` of the metadata."""
        item = self.df.iloc[idx]

        label = torch.tensor(item['target'])
        if self.loss == 'bce':
            label = nn.functional.one_hot(label, num_classes=self.n_classes).float()
            for secondary in item['secondary_targets']:
                if secondary is not None:
                    label += nn.functional.one_hot(torch.tensor(secondary), num_classes=self.n_classes) * self.secondary_labels_weight
            # A secondary target that repeats the primary one (or another secondary)
            # would push the value above 1, which is not a valid BCE target.
            label = label.clamp(max=1.0)

        file = item['filepath']
        waveform, sr = torchaudio.load(file)
        if sr != self.sample_rate:
            # audio_len counts samples at self.sample_rate, so a file at another rate
            # must be resampled before cropping/padding to get the right duration.
            waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=self.sample_rate)
        waveform = trunc_or_pad(waveform, self.audio_len, self.start_idx)
        return waveform, label

train_dataset = AudioDataset(
    metadata,
    n_classes=Config.n_classes,
    duration=Config.duration,
    sample_rate=Config.sample_rate,
    loss=Config.loss,
    secondary_labels_weight=Config.secondary_labels_weight,
)

# CutMix and MixUp in the dataloader, this time directly on the raw waveforms:
# with probability 0.7, apply either CutMix (65%) or MixUp (35%) to each collated batch.
_one_hot_labels = Config.loss == 'bce'
cutmix_or_mixup = v2.RandomApply(
    [
        v2.RandomChoice(
            [
                CutMix(num_classes=Config.n_classes, alpha=0.5, one_hot_labels=_one_hot_labels),
                MixUp(num_classes=Config.n_classes, alpha=0.5, one_hot_labels=_one_hot_labels),
            ],
            p=[0.65, 0.35],
        ),
    ],
    p=0.7,
)


def mix_collate_fn(batch):
    """Collate ``batch`` with ``default_collate``, then run CutMix/MixUp on the result."""
    collated = default_collate(batch)
    return cutmix_or_mixup(*collated)


if Config.cutmix_mixup:
    collate_fn = mix_collate_fn
else:
    collate_fn = None  # DataLoader falls back to its default collation

train_loader = DataLoader(
    train_dataset,
    batch_size=Config.batch_size,
    shuffle=True,
    num_workers=4,
    collate_fn=collate_fn,
)

In [None]:
# Get one batch on device. Falls back to CPU so the cell still runs without a GPU
# (everything below works on either device, just slower on CPU).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

waveform_batch, labels = next(iter(train_loader))
waveform_batch = waveform_batch.to(device)
labels = labels.to(device)
waveform_batch.shape, labels.shape

### Audio transforms
Since Audiomentations is designed for CPU data augmentation on numpy arrays, we instead use the [torch-audiomentations](https://github.com/asteroid-team/torch-audiomentations/tree/main) library. Its transforms take batches of torch tensors, can run on GPU, and support differentiability if you want to use them in the middle of your model. However, not all Audiomentations transforms are available in torch-audiomentations.

Each transform has three parameters controlling its random application:
- ```p``` is the probability of applying the transform.
- ```mode``` controls whether the transform is applied individually to each sample of the batch (with different random parameters) or uniformly to the whole batch.
- ```p_mode``` controls whether the random choice of applying the transform is made for each sample or for the whole batch.

See the documentation for more details. We set both ```mode``` and ```p_mode``` to ```per_example``` to get the same behavior as before. Missing transforms can be applied on CPU before the others.

In [None]:
# Batched waveform augmentations (run on whatever device the input batch is on).
# NOTE: torch-audiomentations' AddColoredNoise expresses the noise colour as a power-law
# exponent (spectrum ~ 1/f**f_decay: 0 = white, 1 = pink, 2 = brown, -1 = blue), unlike
# audiomentations' AddColorNoise, which uses dB/octave (-3.01 = pink). Using -3.01 here
# would produce strongly high-frequency-boosted noise instead of pink noise.
waveform_transforms = torch_audiomentations.Compose([
    torch_audiomentations.Shift(min_shift=-0.5, max_shift=0.5, p=0.5, mode='per_example', p_mode='per_example'),

    torch_audiomentations.Gain(min_gain_in_db=-6., max_gain_in_db=6., p=1, mode='per_example', p_mode='per_example'),

    torch_audiomentations.AddColoredNoise(min_snr_in_db=5., max_snr_in_db=40., min_f_decay=0, max_f_decay=0,
                                          p=1., mode='per_example'), # Gaussian (white) Noise
    torch_audiomentations.AddColoredNoise(min_snr_in_db=5., max_snr_in_db=40., min_f_decay=1., max_f_decay=1.,
                                          p=1., mode='per_example'), # Pink Noise (1/f)

    # AddBackgroundNoise only supports '.wav' files, you can convert the files or try to modify the transform
    #torch_audiomentations.AddBackgroundNoise(background_paths=Config.short_noises, min_snr_in_db=3., max_snr_in_db=30., p=0.5),
    #torch_audiomentations.AddBackgroundNoise(background_paths=Config.background_noises, min_snr_in_db=3., max_snr_in_db=30., p=0.5),

    torch_audiomentations.LowPassFilter(min_cutoff_freq=750., max_cutoff_freq=7500., p=0.8, mode='per_example', p_mode='per_example'),
    torch_audiomentations.PitchShift(sample_rate=Config.sample_rate, min_transpose_semitones=-2.5, max_transpose_semitones=2.5,
                                     p=0.3, mode='per_example', p_mode='per_example')
])

In [None]:
# Apply audio transforms to the batch.
# Pass the configured rate explicitly. The dataset crops clips assuming
# Config.sample_rate, while `sr` is not defined by this GPU pipeline (AudioDataset
# keeps it local). A bare `sr` would pick up whatever an earlier cell left behind.
waveform_batch = waveform_transforms(waveform_batch, sample_rate=Config.sample_rate)

### Spectrogram transforms
The same torchaudio spectrogram transforms and augmentations can be used on batched tensors, on GPU.

In [None]:
# Waveform batch -> log-Mel spectrogram (dB) -> dataset-level normalization.
to_mel_spectrogramn = nn.Sequential(
    torchaudio.transforms.MelSpectrogram(
        sample_rate=Config.sample_rate,
        n_fft=Config.n_fft,
        win_length=Config.window,
        hop_length=Config.hop_length,
        n_mels=Config.n_mels,
        f_min=Config.fmin,
        f_max=Config.fmax,
    ),
    torchaudio.transforms.AmplitudeToDB(top_db=Config.top_db),
    v2.Normalize(mean=Config.dataset_mean, std=Config.dataset_std),
)
to_mel_spectrogramn = to_mel_spectrogramn.to(device)

# Spectrogram augmentations: frequency masking, then time masking.
spec_transforms = nn.Sequential(
    FrequencyMaskingAug(0.3, 0.1, Config.n_mels, n_masks=3, mask_mode='mean'),
    TimeMaskingAug(0.3, 0.1, Config.target_length, n_masks=3, mask_mode='mean'),
)
spec_transforms = spec_transforms.to(device)

In [None]:
# Compute spectrogram and apply transforms (on `device`).
spec_batch = spec_transforms(to_mel_spectrogramn(waveform_batch))
# Plot a detached CPU copy. The notebook's visualization helpers convert tensors with
# .numpy(), which fails on CUDA tensors; spec_batch itself stays on `device`.
visualize_spec_batch(spec_batch.detach().cpu())

###  GPU results
With spectrograms and augmentations computed on GPU, the time to go through the whole dataset drops to less than 10 minutes on Kaggle:

In [None]:
# Uncomment to see the duration on GPU (around 8min)
#for waveform_batch, labels in tqdm(train_loader):
#    waveform_batch = waveform_batch.to(device)
#    labels = labels.to(device)
#    waveform_batch = waveform_transforms(waveform_batch, sr)
#    spec_batch = spec_transforms(to_mel_spectrogramn(waveform_batch))