In [27]:
import torch
import numpy as np
import matplotlib.pyplot as plt

from pedalboard.pedalboard import load_plugin

from src.dataset.paired_audio_dataset import PairedAudioDataset
from src.wrappers.dafx_wrapper import DAFXWrapper
from src.wrappers.null_dafx_wrapper import NullDAFXWrapper

from src.models.style_transfer_vae import StyleTransferVAE

In [28]:
DAFX_FILE = "/home/kieran/Level5ProjectAudioVAE/src/dafx/mda.vst3"
DAFX_NAME = "clean"
SAMPLE_RATE = 24_000
AUDIO_DIR = "/home/kieran/Level5ProjectAudioVAE/src/audio"
DATASETS = ["vctk_24000"]
NUM_EXAMPLES = 10
CHECKPOINT = "/home/kieran/Level5ProjectAudioVAE/src/l5proj_style_vae/ync68xdq/checkpoints/epoch=193-step=121250.ckpt"
CHECKPOINT_ID = CHECKPOINT.split("/")[-3]

In [29]:
model = StyleTransferVAE.load_from_checkpoint(CHECKPOINT)
model.eval()
print()




In [30]:
if DAFX_NAME.lower() == "clean":
    dafx = NullDAFXWrapper()
else:
    dafx = load_plugin(DAFX_FILE, plugin_name=DAFX_NAME)
    dafx = DAFXWrapper(dafx=dafx, sample_rate=SAMPLE_RATE)

In [31]:
dataset = PairedAudioDataset(
    dafx=dafx,
    audio_dir=AUDIO_DIR,
    subset="train",
    input_dirs=DATASETS,
    num_examples_per_epoch=NUM_EXAMPLES,
    augmentations={},
    length=131_072,
    effect_input=False,
    effect_output=True,
    dummy_setting=True
)

loader= torch.utils.data.DataLoader(
    dataset,
    num_workers=4,
    batch_size=1,
    timeout=6000,
)

100%|████████████████████████████████████████| 88/88 [00:00<00:00, 22522.50it/s]


Loaded 88 files for train = 66.89 hours.





In [32]:
batch = next(iter(loader))
x_clean, y_clean = batch

In [33]:
from IPython.display import Audio

In [34]:
Audio(x_clean.squeeze().numpy(), rate=SAMPLE_RATE)

In [35]:
Audio(y_clean.squeeze().numpy(), rate=SAMPLE_RATE)

In [11]:
batch = next(iter(loader))
x, y = batch
x_s = model.audio_to_spectrogram(signal=x, return_phase=False)
y_s = model.audio_to_spectrogram(signal=y, return_phase=False)
X = torch.concat([x_s, y_s], dim=1)

X_hat, _, _, z = model(X)

tensor(2.0791, grad_fn=<StdBackward0>)