In [None]:
# Move the kernel's working directory up one level (presumably to the repo
# root so that `myddsp` and the relative `logs/...` path resolve — confirm).
# NOTE: the path is relative, so this is not idempotent: re-running this cell
# without restarting the kernel climbs one more directory each time.
%cd ..
import matplotlib.pyplot as plt

%matplotlib inline
from functools import partial
from multiprocessing import Pool
from pathlib import Path

import librosa
import numpy as np
import torch
import torchaudio
import torchaudio.functional as AF
from IPython.display import Audio
from rich.progress import track
from torch.nn import functional as F
from torch.utils.data import DataLoader

from myddsp import constants as C
from myddsp.preprocessors import get_centered_frames
from myddsp.train import Zak
from myddsp.vae import (
    load_file,
    persist,
    prepare_all,
    prepare_amp,
    prepare_pitch,
    prepare_stft,
)

In [None]:
# Scratch tensor of standard-normal samples with shape (2, 3, 5), used below
# to check how reductions change the shape.
a = torch.randn((2, 3, 5))

In [None]:
# Reducing over axis 2 drops that axis: the displayed shape is (2, 3).
torch.argmax(a, dim=2).shape

In [None]:
def plot(x, figsize=(12, 6), *args, **kwargs):
    """Draw ``x`` as a line plot on a freshly created figure.

    Args:
        x: Data handed to ``Axes.plot`` as its first argument.
        figsize: ``(width, height)`` of the new figure. Note that it is the
            second positional parameter, so any extra positional plot
            arguments must come after it.
        *args: Extra positional arguments forwarded to ``Axes.plot``.
        **kwargs: Extra keyword arguments forwarded to ``Axes.plot``.

    Returns:
        None. The figure is shown by the notebook's inline backend.
    """
    _, axes = plt.subplots(figsize=figsize)
    axes.plot(x, *args, **kwargs)


def matshow(x, figsize=(12, 12), *args, **kwargs):
    """Display a 2-D array as an image on a freshly created figure.

    Args:
        x: Matrix handed to ``Axes.matshow`` as its first argument.
        figsize: ``(width, height)`` of the new figure. Note that it is the
            second positional parameter, so any extra positional arguments
            must come after it.
        *args: Extra positional arguments forwarded to ``Axes.matshow``.
        **kwargs: Extra keyword arguments forwarded to ``Axes.matshow``.
            ``origin`` defaults to ``"lower"`` (row 0 at the bottom) but may
            be overridden by the caller.

    Returns:
        None. The figure is shown by the notebook's inline backend.
    """
    # Treat "lower" as a default rather than a hard-coded keyword: passing it
    # literally next to **kwargs made `matshow(x, origin=...)` raise
    # "got multiple values for keyword argument 'origin'".
    kwargs.setdefault("origin", "lower")
    fig, ax = plt.subplots(figsize=figsize)
    ax.matshow(x, *args, **kwargs)


def safe_log(x, eps=1e-5):
    """Natural log of ``x`` with every element first raised to at least ``eps``.

    Clamping keeps zero and negative entries from producing ``-inf`` or NaN.

    Args:
        x: Input tensor.
        eps: Lower bound applied element-wise before taking the log.

    Returns:
        Tensor of ``log(max(x, eps))`` values with the same shape as ``x``.
    """
    floored = torch.clamp(x, min=eps)
    return torch.log(floored)

In [None]:
# `load_from_checkpoint` is a classmethod that builds and RETURNS a new module
# with the checkpoint weights. Calling it on an instance and dropping the
# return value leaves that instance with its freshly initialised weights (and
# newer Lightning releases reject the instance call outright), so call it on
# the class and keep the result.
zak = Zak.load_from_checkpoint("logs/version_8/checkpoints/epoch=2-step=80340.ckpt")

# Inference only: switch to eval mode, turn off gradient tracking for every
# parameter, and move the model to the GPU (the decoding cell below feeds it
# CUDA tensors).
zak = zak.eval().requires_grad_(False).cuda()

In [None]:
# Unpack the three objects returned by `prepare_all()`. Later cells slice each
# of them with `[:size]` along the first axis.
# NOTE(review): presumably STFT frames plus the matching amplitude and pitch
# features, aligned frame-by-frame — confirm against `myddsp.vae.prepare_all`.
stft, amp, pitch = prepare_all()

In [None]:
# Decode only the first `size` entries of the amplitude and pitch features.
size = 4096

# inference_mode: no autograd bookkeeping is needed for a forward pass.
# Inputs go to the GPU where the model lives; the output is brought back to
# the CPU for plotting and NumPy processing.
with torch.inference_mode():
    result = zak.decoder(
        amp[:size].to("cuda"),
        pitch[:size].to("cuda"),
    ).to("cpu")

In [None]:
# Show the decoder output transposed, i.e. with its second axis running
# vertically and its first axis horizontally.
matshow(result.t())

In [None]:
# Value-range check: (min, max) of the decoder output followed by (min, max)
# of the first `size` entries of `stft`, displayed side by side.
result.min(), result.max(), stft[:size].min(), stft[:size].max()

In [None]:
# Reconstruct a waveform from the decoded output with Griffin-Lim phase
# estimation. librosa expects frequency along the first axis, so `result` is
# transposed, and one row of zeros is stacked in front of it as the first
# frequency row (presumably the DC bin the decoder output leaves out — confirm).
y_hat = librosa.griffinlim(
    np.concatenate([np.zeros((1, result.shape[0])), result.T]),
    n_iter=64,
    hop_length=C.HOP_LENGTH,
    n_fft=C.N_FFT,
    init="random",
    # Seed the random initial phase so that Restart & Run All reproduces the
    # same audio instead of a different reconstruction on every execution.
    random_state=0,
)
Audio(data=y_hat, rate=C.SAMPLE_RATE)