In [None]:
import torch
import torchaudio.functional as F

from senhance.data.audio import Audio
from senhance.data.augmentations.background_noise import BackgroundNoise
from senhance.data.augmentations.reverb import Reverb
from senhance.data.augmentations.clipping import Clipping
from senhance.data.augmentations.speed import Speed
from senhance.data.augmentations.dither import Dither
from senhance.data.augmentations.flanger import Flanger
from senhance.data.augmentations.overdrive import Overdrive
from senhance.data.augmentations.phaser import Phaser
from senhance.data.augmentations.filters import LowPass, HighPass, BandPass
from senhance.data.augmentations.chain import Chain
from senhance.data.augmentations.default import get_default_augmentation

from IPython.display import display
from IPython.display import Audio as AudioPlayer

In [2]:
# Configuration: input locations and the working sample rate.
SPEECH_PATH = "/data/denoising/speech/daps/clean/f10_script1_clean.wav"
NOISE_INDEX_PATH = "/data/denoising/noise/records/DEMAND/48k/index.train.json"
IR_INDEX_PATH = "/data/denoising/noise/irs/RoyJames/OPENAIR/IRs/air-museum/index.json"
TARGET_SAMPLE_RATE = 24_000

# Source recording used by every experiment below, resampled right after loading.
x = Audio(SPEECH_PATH).resample(TARGET_SAMPLE_RATE)

# Individual augmentations. `p` is presumably the application probability --
# TODO confirm against the augmentation base class.
bnoise = BackgroundNoise(NOISE_INDEX_PATH, min_snr=5.0, max_snr=25.0, p=1.0)
reverb = Reverb(
    ir_index_path=IR_INDEX_PATH,
    # min_drr=0.0,
    # max_drr=1.0,
    p=0.5,
)
clipping = Clipping(min_clip_percentile=0.0, max_clip_percentile=0.1, p=0.8)

# Ten evenly spaced candidate values between 1 kHz and 24 kHz.
# NOTE(review): the upper values exceed the Nyquist frequency of a 24 kHz
# signal -- confirm how LowPass interprets `freqs_hz`.
low_pass_freqs_hz = torch.linspace(1000, 24000, 10).tolist()
low_pass = LowPass(freqs_hz=low_pass_freqs_hz, p=1.0)
high_pass = HighPass(freqs_hz=[4000, 8000], p=1.0)
band_pass = BandPass(bands_hz=[[400, 800]], p=1.0)

speed = Speed(min_factor=0.5, max_factor=1.5, p=1.0)
dither = Dither()
flanger = Flanger()
overdrive = Overdrive(min_gain=10, max_gain=90, min_colour=10, max_colour=90)

# Sampling ranges for the phaser, gathered in one mapping for readability.
phaser_kwargs = {
    "min_gain_in": 0,
    "max_gain_in": 1,
    "min_gain_out": 0,
    "max_gain_out": 100,
    "min_delay_ms": 0,
    "max_delay_ms": 5,
    "min_decay": 0,
    "max_decay": 0.99,
    "min_mod_speed": 0.1,
    "max_mod_speed": 2,
}
phaser = Phaser(**phaser_kwargs)

# Composite augmentation built from three of the augmentations above.
chain = Chain(bnoise, reverb, low_pass)

In [None]:
# Quick listening check on a random excerpt of the source recording
# (the 1.7 argument is presumably a duration in seconds -- TODO confirm).
preview = x.random_excerpt(1.7)
AudioPlayer(preview.waveform.numpy(), rate=x.sample_rate)

In [None]:
# Working excerpt for the single-example experiments below: draw a random
# excerpt, then normalize it (the -24.0 argument is presumably a loudness
# target -- TODO confirm units against Audio.normalize).
raw_excerpt = x.random_excerpt(3)
excerpt = raw_excerpt.normalize(-24.0)
AudioPlayer(excerpt.waveform.numpy(), rate=excerpt.sample_rate)

In [None]:
# Build the default "train" augmentation and sample one set of parameters
# for the working excerpt.
sequence_length_s = 64 / 75
augment = get_default_augmentation(
    sequence_length_s=sequence_length_s,
    split="train",
    p=1.0,
)
# augment = high_pass
augment_params = augment.sample_parameters(excerpt)
print(augment_params)

In [None]:
# Apply the sampled parameters to the excerpt; a leading axis is added to the
# waveform before the call, and the first item of the result is played back.
batched_waveform = excerpt.waveform[None]
augmented = augment.augment(batched_waveform, augment_params)
AudioPlayer(augmented[0], rate=excerpt.sample_rate)

In [None]:
import matplotlib.pyplot as plt

from senhance.data.stft import MelSpectrogram

# Compare the log-mel magnitudes of the clean excerpt and of the augmented
# result, one figure each. The vertical axis is flipped before display.
mel_spectrogram = MelSpectrogram(1024, 256, 80, x.sample_rate)

clean_fig, clean_ax = plt.subplots()
clean_log_mel = mel_spectrogram.magnitudes(excerpt.waveform)[0].log().flip(0)
clean_ax.imshow(clean_log_mel, aspect="auto")

augmented_fig, augmented_ax = plt.subplots()
mels = mel_spectrogram.magnitudes(augmented)
augmented_ax.imshow(mels[0].log().flip(0), aspect="auto")

In [8]:
# noise = augment_params.noise[0]
# print(augment_params.snr[0])
# AudioPlayer(noise.numpy(), rate=excerpt.sample_rate)

In [9]:
from concurrent.futures import ThreadPoolExecutor

# Draw 16 salient excerpts from the source recording using a pool of
# 4 worker threads; results are collected in submission order.
with ThreadPoolExecutor(max_workers=4) as pool:
    excerpts = list(pool.map(lambda _: x.salient_excerpt(1), range(16)))

In [None]:
from senhance.data.augmentations.augmentations import BatchAugmentationParameters

# Sample one parameter set per excerpt, then collate them into a single
# batched parameter object. `map` is used (not a for loop) so the notebook
# global `excerpt` is not overwritten.
per_excerpt_params = list(map(augment.sample_parameters, excerpts))
augment_params = BatchAugmentationParameters.collate(per_excerpt_params)

In [11]:
# Stack the excerpt waveforms along a new leading axis and augment the batch
# with the collated parameters.
waveform_list = [item.waveform for item in excerpts]
excerpts_waveforms = torch.stack(waveform_list)
augmented = augment.augment(excerpts_waveforms, augment_params)

In [None]:
# Sanity check on how much the augmentation changed each batch item:
# first the summed absolute difference per item, then the count of samples
# that are (almost) unchanged, i.e. differ by less than 1e-5.
abs_diff = (augmented - excerpts_waveforms).abs()
print(abs_diff.sum(dim=(1, 2)))
(abs_diff < 1e-5).sum(dim=(1, 2))

In [None]:
# Print the sampled parameters and play back the first few augmented items.
# The playback rate is read from the source recording instead of repeating
# the literal 24000, so it stays correct if the resample rate is changed.
# The count is capped by the batch size to avoid indexing past the end.
num_previews = min(8, len(augmented))
for i in range(num_previews):
    print(augment_params[i].params)
    display(AudioPlayer(augmented[i], rate=x.sample_rate))

In [None]:
# Load one impulse response, downmix it with `.mono()` and resample it to the
# working excerpt's sample rate so both signals can be convolved.
ir_path = "/data/denoising/noise/irs/RoyJames/OPENAIR/IRs/air-museum/b-format/AR_bformat_S1R1_1.wav"
ir = Audio(ir_path).mono().resample(excerpt.sample_rate)
print(ir.sample_rate)

In [None]:
# Listen to the dry excerpt first, then to the impulse response.
excerpt_player = AudioPlayer(excerpt.waveform, rate=excerpt.sample_rate)
ir_player = AudioPlayer(ir.waveform, rate=ir.sample_rate)
display(excerpt_player, ir_player)

In [174]:
# Manual reverb: trim the impulse response, normalize it, and convolve it
# with the working excerpt.

# Start the IR one sample before its largest value. `argmax()` without `dim`
# indexes the flattened tensor; that equals the time index only for a
# single-channel waveform (assumed here because of `.mono()` -- TODO confirm).
peak_index = int(ir.waveform.argmax())
# Clamp at 0: with the peak at sample 0 the start index would be -1, and the
# slice would keep only the final sample of the IR.
start_index = max(peak_index - 1, 0)
rir = ir.waveform[..., start_index:]

# Scale the trimmed IR to unit L2 norm.
rir = rir / torch.linalg.vector_norm(rir, ord=2)

# Bring the IR to the excerpt's length. When the IR is longer than the
# excerpt the pad amount is negative (torch is expected to crop in that
# case -- verify against the torch.nn.functional.pad documentation).
num_samples = excerpt.waveform.shape[-1]
rir = torch.nn.functional.pad(rir, (0, num_samples - rir.shape[-1]))

# Convolve and keep only the first `num_samples` samples of the result.
out = F.fftconvolve(excerpt.waveform[None], rir[None])[..., :num_samples]

In [None]:
# Shapes of every tensor involved in the manual reverb, labeled so they can
# be compared at a glance. The original tuple listed the excerpt shape twice
# and omitted `rir`, the kernel that is actually convolved.
{
    "ir": tuple(ir.waveform.shape),
    "rir": tuple(rir.shape),
    "excerpt": tuple(excerpt.waveform.shape),
    "out": tuple(out.shape),
}

In [None]:
# Listen to the convolved (reverberated) excerpt: first item along the
# leading axis of `out`.
reverberated = out[0]
AudioPlayer(reverberated, rate=ir.sample_rate)

In [None]:
# Log-mel view of the processed impulse response `rir`, flipped vertically
# before display. Note this rebinds `mel_spectrogram` using the IR's rate.
mel_spectrogram = MelSpectrogram(1024, 256, 80, ir.sample_rate)
rir_log_mel = mel_spectrogram(rir[None]).log()[0].flip(0)
plt.imshow(rir_log_mel)

In [None]:
# Waveform of the processed impulse response (index 0 along the first axis).
# Title and axis labels are set so the figure is identifiable on its own.
rir_fig, rir_ax = plt.subplots()
rir_ax.plot(rir[0])
rir_ax.set(title="Impulse response (rir[0])", xlabel="Sample index", ylabel="Amplitude");

In [None]:
# Waveform of the dry working excerpt (index 0 along the first axis).
# Title and axis labels are set so the figure is identifiable on its own.
excerpt_fig, excerpt_ax = plt.subplots()
excerpt_ax.plot(excerpt.waveform[0])
excerpt_ax.set(title="Dry excerpt (excerpt.waveform[0])", xlabel="Sample index", ylabel="Amplitude");

In [None]:
# Waveform of the convolved output (index [0, 0] along the first two axes).
# Title and axis labels are set so the figure is identifiable on its own.
out_fig, out_ax = plt.subplots()
out_ax.plot(out[0, 0])
out_ax.set(title="Convolved output (out[0, 0])", xlabel="Sample index", ylabel="Amplitude");