In [1]:
# -*- coding: utf-8 -*-
from __future__ import absolute_import, division, print_function, unicode_literals

import torch
from torch.utils.data import Dataset, TensorDataset, DataLoader

from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms

import random
import numpy as np

import os, sys, argparse, time
from pathlib import Path

import librosa
import soundfile as sf
import configparser
import random
import json
import matplotlib.pyplot as plt
import pdb

In [2]:
class VAE(nn.Module):
  def __init__(self, n_bins, n_units, latent_dim):
    super(VAE, self).__init__()

    self.n_bins = n_bins
    self.n_units = n_units
    self.latent_dim = latent_dim
    
    self.fc1 = nn.Linear(n_bins, n_units)
    self.fc21 = nn.Linear(n_units, latent_dim)
    self.fc22 = nn.Linear(n_units, latent_dim)
    self.fc3 = nn.Linear(latent_dim, n_units)
    self.fc4 = nn.Linear(n_units, n_bins)

  def encode(self, x):
      h1 = F.relu(self.fc1(x))
      return self.fc21(h1), self.fc22(h1)

  def reparameterize(self, mu, logvar):
      std = torch.exp(0.5*logvar)
      eps = torch.randn_like(std)
      return mu + eps*std

  def decode(self, z):
      h3 = F.relu(self.fc3(z))
      return F.relu(self.fc4(h3))

  def forward(self, x):
      mu, logvar = self.encode(x.view(-1, self.n_bins))
      z = self.reparameterize(mu, logvar)
      return self.decode(z), mu, logvar

In [3]:
sampling_rate = 44100
n_bins = 384
n_units = 2048
latent_dim = 256
device = 'cuda:0'

batch_size = 256

dataset = Path(r'D:\datasets\Audio\latent-timbre-synthesis\erokia')
my_test_audio = dataset / 'test_audio'
my_cqt = dataset / 'npy'

In [7]:
# LOAD FROM CHECKPOINT - PASS IF YOU WOULD LIKE TO LOAD FROM MODEL

state = torch.load(Path(r"D:\datasets\Audio\latent-timbre-synthesis\erokia\lts-pytorch\run-001\model\checkpoints\ckpt_00400"))
model = VAE(n_bins, n_units, latent_dim).to(device)
model.load_state_dict(state['state_dict'])
model.eval()

VAE(
  (fc1): Linear(in_features=384, out_features=2048, bias=True)
  (fc21): Linear(in_features=2048, out_features=256, bias=True)
  (fc22): Linear(in_features=2048, out_features=256, bias=True)
  (fc3): Linear(in_features=256, out_features=2048, bias=True)
  (fc4): Linear(in_features=2048, out_features=384, bias=True)
)

In [4]:
# LOAD FROM MODEL
model = VAE(n_bins, n_units, latent_dim).to(device)
model = torch.load(Path(r'D:\datasets\Audio\latent-timbre-synthesis\erokia\lts-pytorch\run-001\model\last_model.pt'))
model.eval()

VAE(
  (fc1): Linear(in_features=384, out_features=2048, bias=True)
  (fc21): Linear(in_features=2048, out_features=256, bias=True)
  (fc22): Linear(in_features=2048, out_features=256, bias=True)
  (fc3): Linear(in_features=256, out_features=2048, bias=True)
  (fc4): Linear(in_features=2048, out_features=384, bias=True)
)

In [8]:
# List the test audio files from the dataset
test_files = [f for f in my_test_audio.glob('*.wav')]
init = True

for test in test_files:
    
    audio_full, _ = librosa.load(test, sr=sampling_rate)
    dataname = Path(test).stem
    cqt_full = np.load(my_cqt.joinpath(dataname + '.npy'))

    if init:
        test_dataset_audio = audio_full
        test_dataset_cqt = cqt_full
        init = False
    else:
        test_dataset_audio = np.concatenate((test_dataset_audio, audio_full ),axis=0)
        test_dataset_cqt = np.concatenate((test_dataset_cqt, cqt_full ),axis=0)

# Create a dataloader for test dataset
test_tensor = torch.Tensor(test_dataset_cqt)
test_dataset = TensorDataset(test_tensor)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [9]:
test_dataset_cqt.shape

(15411, 384)

In [16]:
init_test = True
for iterno, test_tuple in enumerate(test_dataloader):
    test_sample, = test_tuple
    with torch.no_grad():
        test_sample = test_sample.cuda()
        test_pred_z = model.encode(test_sample)
        test_pred = model.decode(test_pred_z[0])
    if init_test:
        test_predictions = test_pred
        init_test = False
    else:
        test_predictions = torch.cat((test_predictions, test_pred ),0)

y_inv_32 = librosa.griffinlim_cqt(test_predictions.permute(1,0).cpu().numpy(), sr=sampling_rate, n_iter=4, hop_length=128, bins_per_octave=48, dtype=np.float32)
sf.write('test_reconst.wav', y_inv_32, sampling_rate)

In [23]:
sf.write('test_reconst.wav', y_inv_32, sampling_rate)

In [13]:
model.encode(test_sample)

(tensor([[nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         ...,
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
        grad_fn=<AddmmBackward>),
 tensor([[nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         ...,
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
        grad_fn=<AddmmBackward>))

In [72]:
s = np.random.rand(np.random.randint(1024))
if s.shape[0] % 1024 != 0:
    num_zeros = 1024 - s.shape[0] % 1024
    s = np.pad(s, (0, num_zeros), 'constant', constant_values=(0,0))

s.shape

(1024,)

In [34]:
cqt_fmin_hz = librosa.note_to_hz('A1')
cqt_fmin_samples = 1 / cqt_fmin_hz * sampling_rate
cqt_fmin_samples

801.8181818181818

In [58]:
s = np.random.rand(np.random.randint(1024))
print(s.shape)

(192,)
