In [None]:
import sys
from pathlib import Path

import librosa
import librosa.display  # submodule used by specshow below; `import librosa` alone may not expose it
import numpy as np
import torch
from matplotlib import pyplot as plt

# Make the project root importable before pulling in the local `dataset` package.
sys.path.append("..")

from dataset.audio_data import WavDatasetPair


In [None]:
# Seed the global NumPy and torch RNGs so "Restart & Run All" reproduces the same crops and
# figures. NOTE(review): whether WavDatasetPair / the augmentations draw only from these RNGs
# is not visible here — TODO confirm.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

# Directory of raw clips; the first .wav in sorted order is the sample used throughout.
audio_dir = "../data/audioset/all"
wav_files = sorted(Path(audio_dir).glob("*.wav"))
if not wav_files:
    # Fail with a readable message instead of an IndexError on an empty glob.
    raise FileNotFoundError(f"no .wav files found in {Path(audio_dir).resolve()}")
sample_file = wav_files[0]


In [None]:
# Show which clip (the first .wav in sorted order) the cells below load and augment.
sample_file


In [None]:
from dataset.audio_augmentations import (
    get_contrastive_augment,
    get_weak_augment,
    PrecomputedNorm,
)


In [None]:
# Build the strong (contrastive) and weak augmentation pipelines passed to the paired dataset.
contrastive_augment, weak_augment = get_contrastive_augment(), get_weak_augment()


In [None]:
# One-file, unlabeled dataset used only for visualisation. Its batches are unpacked further
# down as wavs[0] (plotted as "contrastive aug") and wavs[1] (plotted as "weak aug").
pair_kwargs = dict(
    audio_files=[sample_file],
    labels=None,
    sample_rate=16000,
    random_crop=True,
    contrastive_aug=contrastive_augment,
    weak_aug=weak_augment,
)
ds = WavDatasetPair(**pair_kwargs)


In [None]:
# Loader over the single-clip dataset. Pinned host memory only speeds up host->GPU copies,
# so request it only when CUDA is present (recent PyTorch versions warn about, and ignore,
# pin_memory=True on CPU-only machines).
dl = torch.utils.data.DataLoader(
    ds,
    batch_size=1,
    num_workers=1,
    pin_memory=torch.cuda.is_available(),
    shuffle=True,
)


## Augmented waveforms

In [None]:
# Draw one batch; the temporary iterator is deleted right away instead of lingering in the namespace.
_loader_iter = iter(dl)
wavs = next(_loader_iter)
del _loader_iter


In [None]:
# Shape of the first clip in the contrastive-augmented view (same torch.Size that `.shape` gives).
wavs[0][0].size()


In [None]:
def random_crop_or_pad(wav: np.ndarray, target_len: int) -> np.ndarray:
    """Return a ``target_len``-sample excerpt of ``wav``.

    Longer inputs are cropped at a uniformly random offset drawn from the global NumPy RNG,
    shorter inputs are zero-padded at the end, and inputs of exactly ``target_len`` samples
    are returned unchanged.
    """
    n_samples = len(wav)
    if n_samples > target_len:
        # randint's upper bound is exclusive: +1 so the window may also end on the last sample.
        start = np.random.randint(n_samples - target_len + 1)
        return wav[start : start + target_len]
    if n_samples < target_len:
        # np.pad's fill-value keyword is `constant_values`; `value=` is rejected.
        return np.pad(wav, (0, target_len - n_samples), mode="constant", constant_values=0)
    return wav


# Reference excerpt of the raw clip: 0.95 s at 16 kHz.
# NOTE(review): this window is drawn independently of WavDatasetPair's own random crop, so it is
# presumably not the exact segment the augmented views came from; 0.95 s is assumed to match
# the dataset's crop length — TODO confirm.
orig_wav = librosa.load(sample_file, sr=16000)[0]
unit_samples = int(16000 * 0.95)
orig_wav = random_crop_or_pad(orig_wav, unit_samples)


In [None]:
import matplotlib.pyplot as plt

# Stack the raw crop and both augmented views on a shared sample axis.
fig, axs = plt.subplots(nrows=3, ncols=1, figsize=(8, 12), sharex=True)

# NOTE(review): orig_wav is cropped separately from the dataset, so the top panel is presumably
# a different excerpt than the one the augmented panels were derived from.
wave_panels = [
    (orig_wav, "original wav"),
    (wavs[0][0].numpy(), "contrastive aug"),
    (wavs[1][0].numpy(), "weak aug"),
]
for ax, (signal, title) in zip(axs, wave_panels):
    ax.plot(signal)
    ax.set_title(title)
    ax.set_ylabel("amplitude")
axs[-1].set_xlabel("sample index (16 kHz)")

plt.show()


## Log mel-spectrograms

In [None]:
from dataset.audio_augmentations import NormalizeBatch, PrecomputedNorm
import nnAudio.features


In [None]:
# Per-batch normaliser (not used by the cells below).
post_norm = NormalizeBatch()

# Mel power spectrogram front end (power=2; conversion to dB happens at plot time).
# 1024-sample FFT and window, 160-sample hop (10 ms at 16 kHz), 64 mel bands over 60-7800 Hz.
mel_config = {
    "sr": 16000,
    "n_fft": 1024,
    "win_length": 1024,
    "hop_length": 160,
    "n_mels": 64,
    "fmin": 60,
    "fmax": 7800,
    "center": True,
    "power": 2,
    "verbose": False,
}
to_spec = nnAudio.features.MelSpectrogram(**mel_config)


In [None]:
# Shape of the unaugmented crop once a leading batch axis is added, as fed to `to_spec`.
torch.from_numpy(orig_wav)[None].shape


In [None]:
# Mel power spectrograms for the original crop and the two augmented views. The machine epsilon
# of torch's default float dtype is added to every bin so no value is exactly zero before the
# dB conversion; unsqueeze(1) inserts a channel axis (assumes to_spec returns
# (batch, n_mels, frames) — TODO confirm).
eps = torch.finfo().eps
img0, img1, img2 = (
    (to_spec(batch) + eps).unsqueeze(1)
    for batch in (torch.from_numpy(orig_wav).unsqueeze(0), wavs[0], wavs[1])
)


In [None]:
# Plot the three mel spectrograms in dB. The QuadMesh handles get their own name (`mesh`) so
# img0/img1/img2 keep holding tensors and this cell can be re-run safely.
fig, axs = plt.subplots(nrows=3, ncols=1, figsize=(8, 12), sharex=True)

spec_panels = [
    (img0, "original wav"),
    (img1, "contrastive aug"),
    (img2, "weak aug"),
]
for ax, (spec, title) in zip(axs, spec_panels):
    # Each panel is referenced to its own peak (ref=np.max), so 0 dB is that panel's maximum.
    spec_db = librosa.power_to_db(spec[0, 0, :, :].numpy(), ref=np.max)
    # Pass the same hop and mel range that `to_spec` was built with. specshow otherwise assumes
    # hop_length=512 and its own fmin/fmax, which mislabels the time and mel-frequency axes.
    mesh = librosa.display.specshow(
        spec_db,
        x_axis="time",
        y_axis="mel",
        sr=16000,
        hop_length=160,
        fmin=60,
        fmax=7800,
        ax=ax,
    )
    ax.set_title(title)
    fig.colorbar(mesh, ax=ax, format="%+2.0f dB")

plt.show()
