In [None]:
import random
from typing import Callable

import audiomentations as am
import IPython.display as ipd
import librosa
import librosa.display
import matplotlib.pyplot as plt
import numpy as np
import soundfile as sf
from torch.utils.data import Dataset
# import torch_audiomentations as ta

In [None]:
class ContrastiveAudioDataset(Dataset):
    """Dataset yielding (original, augmented) audio pairs for contrastive learning.

    Each item is a dict ``{"original": ..., "transformed": ...}`` of float32
    waveforms at ``sample_rate``.

    The transformed view is chosen in one of two ways:

    * With probability ``background_prob`` it is the sample's ``q_audio_back``
      clip, trimmed or padded to ``duration_seconds``.
    * Otherwise it is the original passed through ``self.audiomentations``.
      That pipeline trims or pads to ``duration_seconds`` and then applies
      exactly one randomly chosen augmentation.
    """

    # Rate assumed for clips whose audio dict has no "sampling_rate" entry
    # (the original code hard-coded this value for every clip).
    DEFAULT_SOURCE_SAMPLE_RATE = 44100

    def __init__(self, dataset, sample_rate=44100, duration_seconds=5.0, background_prob=1 / 8):
        """
        Args:
            dataset: The dataset to be used (expecting free-music-archive-retrieval).
                Items are read as ``sample["audio"]["array"]`` and, for the
                background branch, ``sample["q_audio_back"]["array"]``.
                Each audio dict may also carry a ``sampling_rate``.
            sample_rate: The target sample rate.
            duration_seconds: Length every transformed clip is trimmed/padded to.
                It is also the length of the zero-filled fallback item.
            background_prob: Probability of using ``q_audio_back`` as the
                transformed view instead of a synthetic augmentation.
        """
        self.dataset = dataset
        self.sample_rate = sample_rate
        self.duration_seconds = duration_seconds
        self.background_prob = background_prob

        # Shared by both branches so every transformed view has the same length.
        self._adjust_duration = am.AdjustDuration(duration_seconds=duration_seconds, p=1)

        self.audiomentations = am.Compose([
            self._adjust_duration,
            am.OneOf([
                # am.AddBackgroundNoise(p=1),
                am.Gain(min_gain_db=-10, max_gain_db=5, p=1),
                am.AddGaussianNoise(min_amplitude=0.01, max_amplitude=0.03, p=1),
                am.OneOf([
                    am.HighPassFilter(min_cutoff_freq=500, max_cutoff_freq=1000, p=1),
                    am.BandPassFilter(min_center_freq=500, max_center_freq=1000, p=1),
                    am.BandStopFilter(min_center_freq=500, max_center_freq=1000, p=1),
                    am.LowPassFilter(min_cutoff_freq=500, max_cutoff_freq=1000, p=1),
                ], p=1),
                am.PolarityInversion(p=1),
                am.TimeStretch(min_rate=0.8, max_rate=1.25, p=1),
                am.TimeMask(min_band_part=0.1, max_band_part=0.2, p=1),
                am.PitchShift(min_semitones=-4, max_semitones=4, p=1),
            ], p=1),
        ])

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        # NOTE(review): "original" is not duration-adjusted. PyTorch's default
        # collate needs equal lengths within a batch, so confirm the source
        # clips are fixed-length or use a custom collate_fn.
        try:
            sample = self.dataset[idx]
            audio_data = self._load_audio(sample["audio"])

            if random.random() < self.background_prob:
                # Use the dataset's pre-recorded background-noise query as the positive view.
                background = self._load_audio(sample["q_audio_back"])
                transformed = self._adjust_duration(background, sample_rate=self.sample_rate)
            else:
                transformed = self.audiomentations(audio_data, sample_rate=self.sample_rate)

            return {
                "original": audio_data,
                "transformed": transformed,
            }
        except Exception as e:
            # Best-effort: log and substitute silence so one bad item does not abort an epoch.
            print(f"Error processing sample {idx}: {e}")
            num_samples = int(self.duration_seconds * self.sample_rate)
            return {
                "original": np.zeros(num_samples, dtype=np.float32),
                "transformed": np.zeros(num_samples, dtype=np.float32),
            }

    def _load_audio(self, audio):
        """Return ``audio["array"]`` as float32, resampled to ``self.sample_rate`` if needed."""
        # audiomentations expects float32; converting here also keeps both views the same dtype.
        samples = np.asarray(audio["array"], dtype=np.float32)
        orig_sr = audio.get("sampling_rate", self.DEFAULT_SOURCE_SAMPLE_RATE)
        if orig_sr != self.sample_rate:
            samples = librosa.resample(samples, orig_sr=orig_sr, target_sr=self.sample_rate)
        return samples

## Demo

In [None]:
# Load the demo clip as mono at 44.1 kHz (the rate the augmentation demos below assume).
sample, sr = librosa.load("example.wav", sr=44100, mono=True)

# test audiomeantations
def test_audiomentations(aug: am.Compose, audio_data: np.ndarray, sample_rate: int):
    # Convert to float32
    audio_data = audio_data.astype(np.float32)

    # Apply transformations
    transformed_audio = aug(samples=audio_data, sample_rate=sample_rate)
    return transformed_audio

def compare_spectrogram(sample, processed, sample_rate=44100):
    """Plot stacked log-frequency dB spectrograms of an original and a processed clip.

    Args:
        sample: Original waveform.
        processed: Processed waveform.
        sample_rate: Sample rate (Hz) of both clips. ``librosa.display.specshow``
            assumes 22050 Hz unless told otherwise. That would mislabel the time
            and frequency axes of the 44.1 kHz demo audio.
    """
    fig, axes = plt.subplots(2, 1, figsize=(12, 8))
    panels = zip(axes, (sample, processed), ("Original Audio", "Processed Audio"))
    for ax, audio, title in panels:
        magnitude_db = librosa.amplitude_to_db(np.abs(librosa.stft(audio)), ref=np.max)
        img = librosa.display.specshow(
            magnitude_db, sr=sample_rate, y_axis='log', x_axis='time', ax=ax
        )
        ax.set_title(title)
        fig.colorbar(img, ax=ax, format='%+2.0f dB')
    fig.tight_layout()
    plt.show()

print("Original Audio:")
# normalize=False plays the clip at its true level, as a reference for the cells below.
ipd.Audio(data=sample, rate=sr, normalize=False)

In [None]:
print("Gain:")
# Fixed at the quietest setting of the training range (-10..+5 dB) so the change is obvious.
aug = am.Gain(min_gain_db=-10, max_gain_db=-10, p=1)
processed = test_audiomentations(aug, audio_data=sample, sample_rate=sr)
sf.write("processed.wav", processed, samplerate=sr)
# normalize=False keeps the gain change audible instead of rescaling it away.
ipd.Audio(processed, rate=sr, normalize=False)

In [None]:
print("GaussianNoise:")
# Same range as the training pipeline; set min == max to audition 0.01 or 0.03 alone.
aug = am.AddGaussianNoise(min_amplitude=0.01, max_amplitude=0.03, p=1)
processed = test_audiomentations(aug, audio_data=sample, sample_rate=sr)
sf.write("processed.wav", processed, samplerate=sr)
compare_spectrogram(sample, processed)
ipd.Audio(processed, rate=sr)

In [None]:
print("HighPassFilter:")
# Cutoff drawn from 500-1000 Hz as in training; pin min == max to hear one extreme.
aug = am.HighPassFilter(min_cutoff_freq=500, max_cutoff_freq=1000, p=1)
processed = test_audiomentations(aug, audio_data=sample, sample_rate=sr)
sf.write("processed.wav", processed, samplerate=sr)
# Spectrogram comparison intentionally disabled for this cell:
# compare_spectrogram(sample, processed)
ipd.Audio(processed, rate=sr)

In [None]:
print("LowPassFilter:")
# Cutoff drawn from 500-1000 Hz as in training; pin min == max to hear one extreme.
aug = am.LowPassFilter(min_cutoff_freq=500, max_cutoff_freq=1000, p=1)
processed = test_audiomentations(aug, audio_data=sample, sample_rate=sr)
sf.write("processed.wav", processed, samplerate=sr)
compare_spectrogram(sample, processed)
ipd.Audio(processed, rate=sr)

In [None]:
print("BandPassFilter:")
# Centre frequency drawn from 500-1000 Hz as in training; pin min == max to hear one extreme.
aug = am.BandPassFilter(min_center_freq=500, max_center_freq=1000, p=1)
processed = test_audiomentations(aug, audio_data=sample, sample_rate=sr)
sf.write("processed.wav", processed, samplerate=sr)
compare_spectrogram(sample, processed)
ipd.Audio(processed, rate=sr)

In [None]:
print("BandStopFilter:")
# Centre frequency drawn from 500-1000 Hz as in training; pin min == max to hear one extreme.
aug = am.BandStopFilter(min_center_freq=500, max_center_freq=1000, p=1)
processed = test_audiomentations(aug, audio_data=sample, sample_rate=sr)
sf.write("processed.wav", processed, samplerate=sr)
compare_spectrogram(sample, processed)
ipd.Audio(processed, rate=sr)

In [None]:
print("PolarityInversion:")
# Flips the sign of every sample; the magnitude spectrogram would look identical, so none is drawn.
aug = am.PolarityInversion(p=1)
processed = test_audiomentations(aug, audio_data=sample, sample_rate=sr)
sf.write("processed.wav", processed, samplerate=sr)
ipd.Audio(processed, rate=sr)

In [None]:
print("TimeStretch:")
# Rate drawn from 0.8-1.25 as in training; pin min == max to hear one extreme.
aug = am.TimeStretch(min_rate=0.8, max_rate=1.25, p=1)
processed = test_audiomentations(aug, audio_data=sample, sample_rate=sr)
sf.write("processed.wav", processed, samplerate=sr)
compare_spectrogram(sample, processed)
ipd.Audio(processed, rate=sr)

In [None]:
print("TimeMask:")
# Masked fraction drawn from 0.1-0.2 as in training; pin min == max to hear one extreme.
aug = am.TimeMask(min_band_part=0.1, max_band_part=0.2, p=1)
processed = test_audiomentations(aug, audio_data=sample, sample_rate=sr)
sf.write("processed.wav", processed, samplerate=sr)
compare_spectrogram(sample, processed)
ipd.Audio(processed, rate=sr)

In [None]:
print("PitchShift:")
# Shift drawn from -4..+4 semitones, the same range as the training pipeline.
aug = am.PitchShift(min_semitones=-4, max_semitones=4, p=1)
processed = test_audiomentations(aug, audio_data=sample, sample_rate=sr)
sf.write("processed.wav", processed, samplerate=sr)
compare_spectrogram(sample, processed)
ipd.Audio(processed, rate=sr)

In [None]:
print("Adjust Duration and apply one of them:")
# Take the pipeline straight from the dataset class instead of re-declaring it,
# so this demo always exercises exactly what training uses. The constructor
# only stores its arguments and builds the pipeline, so no dataset is needed.
aug = ContrastiveAudioDataset(dataset=None, sample_rate=sr).audiomentations
processed = test_audiomentations(aug, sample, sr)
sf.write("processed.wav", processed, sr)
compare_spectrogram(sample, processed)
ipd.Audio(processed, rate=sr)