In [None]:
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
import os
from pathlib import Path

import librosa
import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as tvf
from IPython.display import Audio
from scipy.spatial import geometric_slerp
from torchaudio.transforms import GriffinLim
from torchcrepe import predict
from torchcrepe.convert import frequency_to_bins, cents_to_bins
from torchcrepe.decode import argmax

from main import LitVAE, AudioDataset, next_power_of_2

In [None]:
def sample(n_dims=64, rng=None):
    """Draw a point uniformly at random from the unit sphere in ``n_dims`` dimensions.

    A standard-normal vector is isotropic, so scaling it to unit length yields a
    uniformly distributed direction.

    Args:
        n_dims: Dimensionality of the returned vector.
        rng: Optional NumPy random generator (e.g. ``np.random.default_rng(seed)``)
            for reproducible draws. When ``None`` the legacy global NumPy random
            state is used, exactly as before.

    Returns:
        1-D float array of shape ``(n_dims,)`` with Euclidean norm 1.
    """
    if rng is None:
        x = np.random.standard_normal(n_dims)
    else:
        x = rng.standard_normal(n_dims)
    return x / np.linalg.norm(x)

In [None]:
# Griffin-Lim phase reconstruction on the GPU: inverts a (freq, frames) magnitude
# spectrogram to a waveform. n_fft / hop_length match the STFT computed for the
# reference recording below; power=1.0 means the inputs are magnitudes, not power.
gl = GriffinLim(
    n_fft=2048,
    hop_length=512,
    n_iter=128,
    power=1.0,
).cuda()

In [None]:
def latest_checkpoint(v=None, base_path='../lightning_logs/'):
    """Return the newest Lightning checkpoint file of a training run.

    Args:
        v: Run version to load from (directory ``version_<v>``). When ``None``,
            the highest-numbered ``version_<N>`` directory under ``base_path``
            is used.
        base_path: Directory containing the ``version_<N>`` run directories.

    Returns:
        ``Path`` of the most recently modified ``*.ckpt`` file in
        ``<base_path>/version_<version>/checkpoints``.

    Raises:
        FileNotFoundError: if ``base_path`` has no ``version_<N>`` directory
            (when ``v`` is ``None``) or the chosen run has no ``.ckpt`` files.
    """
    base_path = Path(base_path)

    if v is None:
        # Only consider ``version_<int>`` directories; stray files or hidden
        # entries would otherwise crash the int() conversion.
        versions = []
        for entry in base_path.iterdir():
            prefix, _, number = entry.name.partition('_')
            if entry.is_dir() and prefix == 'version' and number.isdecimal():
                versions.append(int(number))
        if not versions:
            raise FileNotFoundError(f'No version_<N> run directories in {base_path}')
        version = max(versions)
    else:
        version = v

    checkpoint_dir = base_path / f'version_{version}' / 'checkpoints'
    checkpoints = list(checkpoint_dir.glob('*.ckpt'))
    if not checkpoints:
        raise FileNotFoundError(f'No .ckpt files in {checkpoint_dir}')
    # glob() order is arbitrary; pick the newest file so "latest" also holds
    # when a run saved several checkpoints.
    return max(checkpoints, key=lambda p: p.stat().st_mtime)

In [None]:
# Release any previously loaded model before loading again, so re-running this
# cell does not hold the old and new models at the same time. The only expected
# failure is that `model` does not exist yet (first run on a fresh kernel).
try:
    del model
except NameError:
    pass
model = LitVAE.load_from_checkpoint(latest_checkpoint())

In [None]:
# Move the model to the GPU and switch it to eval mode (affects layers such as
# dropout / batch-norm); the trailing ';' keeps Jupyter from printing the module.
model.cuda().eval();

In [None]:
# Latent walk: great-circle (slerp) paths between random points on the unit
# sphere. The radius ramps 1 -> 0 on one segment and 0 -> 1 on the next.
# The radius bounds get their own names so this cell no longer overwrites `s`,
# which later cells use for spectrograms.
zs = []
start = sample(64)
radius_from, radius_to = 1., 0.
block = 64  # frames per segment
for _ in range(50):
    end = sample(64)
    # The slerp excludes `end` because the next segment starts there; exclude the
    # final radius too, otherwise each boundary repeats a radius value (0,0 / 1,1).
    directions = geometric_slerp(start, end, np.linspace(0, 1, block, endpoint=False))
    radii = np.linspace(radius_from, radius_to, block, endpoint=False)
    zs.append(directions * radii[:, None])
    start = end
    radius_from, radius_to = radius_to, radius_from
zs = np.concatenate(zs, axis=0)

zs = torch.from_numpy(zs.astype('float32')).cuda()
# Constant pitch conditioning: every frame is one-hot on the CREPE bin nearest
# 98 Hz (about G2). 360 matches the bin count of the one-hot `probs` built later.
cs = torch.zeros(zs.shape[0], 360, device=zs.device)
freq = frequency_to_bins(torch.tensor([98.]), torch.round)
cs[:, int(freq)] = 1.0

In [None]:
# Decoder input: latent channels followed by pitch channels, laid out as
# (batch=1, channels, frames).
feat = torch.cat((zs, cs), dim=1).T.contiguous().unsqueeze(0)

In [None]:
# Decode the latent walk straight through the decoder, without autograd.
with torch.no_grad():
    x_hat = model.vae.decoder(feat)
    # Multiply by 1024 (presumably undoing the /1024 magnitude scaling used for
    # the model), drop the batch axis and transpose so frames come first.
    y_hats = x_hat.mul(1024).squeeze(0).T

In [None]:
s = y_hats.cpu().numpy()  # host copy of the generated spectrogram

# Prepend an all-zero column at frequency index 0 (presumably the DC bin the
# decoder does not produce), transpose to (freq, frames) for GriffinLim, and
# invert to audio.
zeros = torch.zeros(y_hats.shape[0], 1, device=y_hats.device)
magnitudes = torch.cat((zeros, y_hats), dim=1).T
sound = gl(magnitudes)
Audio(sound.cpu().numpy(), rate=44100, normalize=True)

In [None]:
# Reference recording for the reconstruction test. Set CELLO_SAMPLE_PATH to use
# another file; the default is the original author's local path.
PATH = os.environ.get(
    'CELLO_SAMPLE_PATH',
    '/home/kureta/Music/cello/Cello Samples/BachMinu1-00000-.wav',
)
y, sr = librosa.load(PATH, mono=True, sr=44100)
# Magnitude STFT, shape (1 + n_fft // 2, frames), scaled down by 1024.
s = np.abs(librosa.stft(y, n_fft=2048, hop_length=512)) / 1024
plt.matshow(s.T)
plt.xlabel('frequency bin')
plt.ylabel('frame')
plt.show()

Audio(y, rate=44100, normalize=False)

In [None]:
CREPE_SAMPLE_RATE = 16000
sample_rate = 44100
n_fft = 2048
hop_length = 512
# STFT hop rescaled to CREPE's 16 kHz rate, then passed through next_power_of_2.
# NOTE(review): torchcrepe.predict takes hop_length in samples at `sample_rate`
# and rescales it internally; confirm this pre-scaled value matches main.py.
crepe_hop_length = next_power_of_2(hop_length * CREPE_SAMPLE_RATE / sample_rate)

waveform = torch.from_numpy(y).unsqueeze(0)  # add a batch axis
# NOTE(review): the third return value is used as per-bin pitch scores; stock
# torchcrepe returns only (pitch, periodicity) here, so this presumably relies
# on a patched torchcrepe — verify.
_, _, probs = predict(waveform, sample_rate=sample_rate, hop_length=crepe_hop_length,
                      return_periodicity=True, device='cuda', decoder=argmax, batch_size=512)

# Best bin per frame for the first batch item (assumes probs is (batch, bins, frames)).
pitch_bins = probs.argmax(dim=1)[0]
# Resample the bin track to the STFT frame count.
# NOTE(review): tvf.resize defaults to bilinear interpolation, which can create
# in-between bins at note changes; nearest may be intended.
pitch_bins = tvf.resize(pitch_bins[None, None], [1, s.shape[1]]).squeeze(1)
# One-hot over 360 bins, laid out as (1, 360, frames).
probs = F.one_hot(pitch_bins[0], 360).T.unsqueeze(0)

In [None]:
# Encode/decode the reference spectrogram with its pitch track, without autograd.
with torch.no_grad():
    # Skip STFT row 0 and add a batch axis before moving to the GPU.
    spec_in = torch.from_numpy(s[1:]).unsqueeze(0).cuda()
    x_hat, _, _, _ = model.vae(spec_in, probs)
    # Undo the 1/1024 scaling, drop the batch axis, transpose so frames come first.
    y_hats = x_hat.mul(1024).squeeze(0).T

In [None]:
reconstruction = y_hats.cpu().numpy()
plt.matshow(reconstruction)
plt.show()

# Prepend a zero column at frequency index 0, transpose to (freq, frames) for
# GriffinLim, and listen to the result.
zeros = torch.zeros(y_hats.shape[0], 1, device=y_hats.device)
magnitudes = torch.cat((zeros, y_hats), dim=1)
sound = gl(magnitudes.T)
Audio(sound.cpu().numpy(), rate=44100, normalize=True)