In [1]:
# -*- coding: utf-8 -*-
from __future__ import absolute_import, division, print_function, unicode_literals

import torch
from torch.utils.data import Dataset, TensorDataset, DataLoader

from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms

import random
import numpy as np

import os, sys, argparse, time
from pathlib import Path

sys.path.append('../')

import librosa
import soundfile as sf
import configparser
import random
import json
import matplotlib.pyplot as plt


from ltsp.model import VAE, loss_function
from ltsp.tests import init_test_audio

In [2]:
sampling_rate = 44100
n_bins = 384
n_units = 2048
latent_dim = 256
device = 'cuda:0'

batch_size = 256

dataset = Path(r'C:\Users\Kivanc\Documents\my_workspace\datasets\latent-timbre-synthesis-pytorch-2\erokia')
my_test_audio = dataset / 'test_audio'
my_cqt = dataset / 'npy'

In [3]:
# LOAD FROM MODEL
model = VAE(n_bins, n_units, latent_dim).to(device)
model_path = dataset / 'lts-pytorch' / 'run-004' / 'model' / 'best_model.pt'
model = torch.load(model_path)
model.eval()

VAE(
  (fc1): Linear(in_features=384, out_features=2048, bias=True)
  (fc21): Linear(in_features=2048, out_features=256, bias=True)
  (fc22): Linear(in_features=2048, out_features=256, bias=True)
  (fc3): Linear(in_features=256, out_features=2048, bias=True)
  (fc4): Linear(in_features=2048, out_features=384, bias=True)
)

In [5]:
# DONT RUN THIS
# LOAD FROM CHECKPOINT - PASS IF YOU WOULD LIKE TO LOAD FROM MODEL

state = torch.load(Path(r'C:\Users\Kivanc\Documents\my_workspace\datasets\latent-timbre-synthesis-pytorch-2\erokia\lts-pytorch\run-004\model\checkpoints\ckpt_00500'))
model = VAE(n_bins, n_units, latent_dim).to(device)
model.load_state_dict(state['state_dict'])
model.eval()

VAE(
  (fc1): Linear(in_features=384, out_features=2048, bias=True)
  (fc21): Linear(in_features=2048, out_features=256, bias=True)
  (fc22): Linear(in_features=2048, out_features=256, bias=True)
  (fc3): Linear(in_features=256, out_features=2048, bias=True)
  (fc4): Linear(in_features=2048, out_features=384, bias=True)
)

In [11]:
# List the test audio files from the dataset
test_files = [f for f in my_test_audio.glob('*.wav')]
init = True

for test in test_files:
    
    audio_full, _ = librosa.load(test, sr=sampling_rate)
    dataname = Path(test).stem
    cqt_full = np.load(my_cqt.joinpath(dataname + '.npy'))

    if init:
        test_dataset_audio = audio_full
        test_dataset_cqt = cqt_full
        init = False
    else:
        test_dataset_audio = np.concatenate((test_dataset_audio, audio_full ),axis=0)
        test_dataset_cqt = np.concatenate((test_dataset_cqt, cqt_full ),axis=0)

# Create a dataloader for test dataset
test_tensor = torch.Tensor(test_dataset_cqt)
test_dataset = TensorDataset(test_tensor)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

torch.float32

In [18]:
init_test = True
for iterno, test_tuple in enumerate(test_dataloader):
    test_sample, = test_tuple
    with torch.no_grad():
        test_sample = test_sample.cuda()
        test_pred_z = model.encode(test_sample.double())
        test_pred = model.decode(test_pred_z[0])
    if init_test:
        test_predictions = test_pred
        init_test = False
    else:
        test_predictions = torch.cat((test_predictions, test_pred ),0)

y_inv_32 = librosa.griffinlim_cqt(test_predictions.permute(1,0).cpu().numpy(), sr=sampling_rate, n_iter=1, hop_length=128, bins_per_octave=48, dtype=np.float32)
sf.write('test_reconst.wav', y_inv_32, sampling_rate)

# Interpolations

In [33]:
test_dataset1, test_dataset2 = torch.tensor_split(test_tensor, 2)
test_dataloader1 = DataLoader(test_dataset1, batch_size = batch_size, shuffle=False)
test_dataloader2 = DataLoader(test_dataset2, batch_size = batch_size, shuffle=False)

In [45]:
init_test = True
for iterno, test_sample in enumerate(test_dataloader1):
    with torch.no_grad():
        test_sample = test_sample.double().to(device)
        test1_mu, test1_logvar = model.encode(test_sample)
        test1_z = model.reparameterize(test1_mu, test1_logvar)
        
    if init_test:
        test1_z_all = test1_z 
        init_test = False

    else:
        test1_z_all = torch.cat((test1_z_all, test1_z),0)

In [46]:
init_test = True
for iterno, test_sample in enumerate(test_dataloader2):
    with torch.no_grad():
        test_sample = test_sample.double().to(device)
        test2_mu, test2_logvar = model.encode(test_sample)
        test2_z = model.reparameterize(test2_mu, test2_logvar)

    if init_test:
        test2_z_all = test2_z 
        init_test = False

    else:
        test2_z_all = torch.cat((test2_z_all, test2_z ),0)

In [54]:
def reconstructions(dataloader, model, save = True, path = './reconstruction.wav', sampling_rate = 44100):
    init_test = True
        
    for iterno, test_sample in enumerate(dataloader):
        with torch.no_grad():
            test_sample = test_sample.double().to(device)
            test_pred, _, _ = model(test_sample)

        if init_test:
            test_predictions = test_pred
            init_test = False

        else:
            test_predictions = torch.cat((test_predictions, test_pred ),0)
    
    if save:
        outpath = Path(path)
        y_inv_32 = librosa.griffinlim_cqt(test_predictions.permute(1,0).cpu().numpy(), sr=sampling_rate, n_iter=1, hop_length=128, bins_per_octave=48, dtype=np.float32)
        sf.write( outpath, y_inv_32, sampling_rate)
    else:
        return test_predictions

In [55]:
reconstructions(test_dataloader1, model, path = './test1_original.wav' )
reconstructions(test_dataloader2, model, path = './test2_original.wav' )

In [51]:
# (mu1 * a) + (mu2 * (1-a)) 
inter_amount = 0.5
inter_z = torch.add( torch.mul(test1_z_all, (1-inter_amount)), torch.mul(test2_z_all, inter_amount) )

In [52]:
init_test = True
      
with torch.no_grad():
    test_pred = model.decode(inter_z)

if init_test:
    test_predictions = test_pred
    init_test = False

else:
    test_predictions = torch.cat((test_predictions, test_pred ),0)

In [53]:
outpath = Path('./inter-0.5.wav')
y_inv_32 = librosa.griffinlim_cqt(test_predictions.permute(1,0).cpu().numpy(), sr=sampling_rate, n_iter=1, hop_length=128, bins_per_octave=48, dtype=np.float32)
sf.write( outpath, y_inv_32, sampling_rate)