In [155]:
import torch
import numpy as np

from tqdm import tqdm
from pedalboard.pedalboard import load_plugin

from src.dataset.paired_audio_dataset import PairedAudioDataset
from src.wrappers.dafx_wrapper import DAFXWrapper
from src.wrappers.null_dafx_wrapper import NullDAFXWrapper

In [156]:
# --- Run configuration ------------------------------------------------------
# NOTE(review): the two paths below are absolute and machine-specific; they must
# be edited before this notebook can run on another machine.
AUDIO_DIR = "/home/kieran/Level5ProjectAudioVAE/src/audio"
DAFX_FILE = "/home/kieran/Level5ProjectAudioVAE/src/dafx/mda.vst3"

# Effect to load from DAFX_FILE; the special value "clean" selects the null wrapper.
DAFX_NAME = "clean"

# Audio settings and the sub-directories of AUDIO_DIR to draw from.
SAMPLE_RATE = 24_000
DATASETS = ["vctk_24000", "musdb18_24000"]

# Number of examples the dataset yields per epoch.
NUM_EXAMPLES = 10_000

In [157]:
# Build the effect wrapper: a real plugin loaded from DAFX_FILE unless the
# configured name is "clean" (case-insensitive), in which case the null wrapper
# is used instead.
if DAFX_NAME.lower() != "clean":
    dafx = DAFXWrapper(
        dafx=load_plugin(DAFX_FILE, plugin_name=DAFX_NAME),
        sample_rate=SAMPLE_RATE,
    )
else:
    dafx = NullDAFXWrapper()

In [158]:
def audio_to_spectrogram(signal: torch.Tensor,
                         n_fft: int = 4096,
                         hop_length: int = 2048,
                         window_size: int = 4096,
                         return_complex: bool = True):
    """Compute a Hann-windowed STFT of a batch of waveforms.

    Args:
        signal: 3-D tensor whose first dimension is the batch. All remaining
            dimensions are flattened into a single time axis before the STFT,
            so a multi-channel input is analysed as one concatenated sequence.
        n_fft: FFT size.
        hop_length: Hop between successive frames, in samples.
        window_size: Length of the Hann window (passed to ``torch.stft`` as
            ``win_length``); must not exceed ``n_fft``.
        return_complex: If True, return the real and imaginary parts as two
            channels. If False, return the magnitude as a single channel.

    Returns:
        ``return_complex=True``: detached tensor of shape
        ``(batch, 2, frames, n_fft // 2 + 1)`` where channel 0 is the real
        part and channel 1 the imaginary part.
        ``return_complex=False``: tensor of shape
        ``(batch, 1, frames, n_fft // 2 + 1)`` holding ``|STFT|``.

    Raises:
        ValueError: If ``signal`` is not 3-dimensional.
    """
    if signal.dim() != 3:
        raise ValueError(
            f"expected a 3-D signal tensor, got {signal.dim()} dimension(s)"
        )

    bs = signal.size(0)

    # A plain tensor is enough for the window: wrapping it in nn.Parameter only
    # made the magnitude output require grad for a throw-away parameter.
    # Creating it with the signal's dtype/device keeps stft working for
    # non-float32 or non-CPU inputs.
    window = torch.hann_window(
        window_size, dtype=signal.dtype, device=signal.device
    )

    # reshape (rather than view) also accepts non-contiguous inputs.
    # win_length is passed explicitly so window_size may be smaller than n_fft.
    X = torch.stft(
        signal.reshape(bs, -1),
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=window_size,
        window=window,
        return_complex=True,
    )

    if return_complex:
        # view_as_real yields (batch, freq, frames, 2); reorder to
        # (batch, 2, frames, freq). This replaces the deprecated
        # torch.stft(..., return_complex=False) path.
        return torch.view_as_real(X).permute(0, 3, 2, 1).detach()

    # (batch, freq, frames) -> (batch, 1, frames, freq)
    X_abs = X.abs().unsqueeze(1).permute(0, 1, 3, 2)

    return X_abs


In [159]:
# Paired dataset built on the wrapper selected above, drawing from the
# configured sub-directories of AUDIO_DIR.
dataset = PairedAudioDataset(
    dafx=dafx,
    audio_dir=AUDIO_DIR,
    subset="train",
    input_dirs=DATASETS,
    num_examples_per_epoch=NUM_EXAMPLES,
    augmentations={},
    # Derived from SAMPLE_RATE instead of a second hard-coded 24_000 so the
    # two cannot drift apart. Presumably the clip length in samples (2.5 s);
    # the value is still a float, exactly as before -- TODO confirm that
    # PairedAudioDataset expects this.
    length=SAMPLE_RATE * 2.5,
    effect_input=False,
    effect_output=True,
    dummy_setting=True
)

# Batches of 16 examples, loaded by 4 worker processes.
loader = torch.utils.data.DataLoader(
    dataset,
    num_workers=4,
    batch_size=16,
    timeout=6000,
)

100%|██████████████████████████████████████| 208/208 [00:00<00:00, 15451.92it/s]


Loaded 208 files for train = 52.51 hours.





In [148]:
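# NOTE: disabled sanity check, kept for reference only. Both comparisons below are
# expected to print False: with return_complex=False the function returns the STFT
# magnitude, so x_spec[:, :, :, 0] is a slice of the magnitude tensor and does not
# even have the same shape as x_spec_complx[:, 0, :, :] (see the shapes printed below).
# for batch in loader: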
#     x, _ = batch
#
#     x_spec_complx = audio_to_spectrogram(x, return_complex=True)
#     x_spec = audio_to_spectrogram(x, return_complex=False)
#
#     print("Complex shape: ", x_spec_complx.shape)
#     print("Non-complex shape: ", x_spec.shape)
#
#     print("Real parts equal: ", torch.equal(x_spec_complx[:,0,:,:], x_spec[:,:,:,0]))
#     print("Imag parts equal: ", torch.equal(x_spec_complx[:,1,:,:], x_spec[:,:,:,1]))

Complex shape:  torch.Size([16, 2, 59, 2049])
Non-complex shape:  torch.Size([16, 1, 59, 2049])
Real parts equal:  False
Imag parts equal:  False


In [160]:
# Per-batch statistics of the STFT real and imaginary channels. Each list gets
# two entries per batch: first for the x signal, then for the y signal.
# Note: the "complex_*" lists hold statistics of channel 1 (the imaginary part).
real_means, real_stds = [], []
complex_means, complex_stds = [], []

for x, y in tqdm(loader):
    for waveform in (x, y):
        spec = audio_to_spectrogram(waveform, return_complex=True)

        real_part = spec[:, 0, :, :]
        imag_part = spec[:, 1, :, :]

        real_means.append(real_part.mean())
        real_stds.append(real_part.std())
        complex_means.append(imag_part.mean())
        complex_stds.append(imag_part.std())

100%|██████████| 625/625 [00:16<00:00, 39.03it/s]


In [161]:
# Convert each list of per-batch statistics into a NumPy array.
real_means_np, real_stds_np, complex_means_np, complex_stds_np = (
    np.array(per_batch_values)
    for per_batch_values in (real_means, real_stds, complex_means, complex_stds)
)

In [162]:
# Average of the per-batch means of the real channel.
np.mean(real_means_np)

-1.77781e-05

In [163]:
# Average of the per-batch standard deviations of the real channel
# (a mean of per-batch stds, not the std over the whole dataset).
np.mean(real_stds_np)

1.1515284

In [164]:
# Average of the per-batch means of the imaginary channel.
np.mean(complex_means_np)

7.083435e-08

In [165]:
# Average of the per-batch standard deviations of the imaginary channel
# (a mean of per-batch stds, not the std over the whole dataset).
np.mean(complex_stds_np)

1.1449312