In [1]:
import math
import os

import librosa
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import soundfile as sf
import torch
from IPython.display import display, Audio
from torchvision import transforms
# from tqdm import tqdm

import data_loader.data_loaders as module_data_loaders
from data_loader.data_loaders import CollectDataLoader
import model.model as module_model
from utils import get_instance

In [2]:
# --- Paths -------------------------------------------------------------------
# Checkpoint to restore, and the directory name used for exported audio.
model_path = 'saved/models/Conv1dSpecVAE/0923_153845/model_best.pth'
audio_path = 'saved/audio'

# --- Audio / STFT settings ---------------------------------------------------
# Sample rate used for playback and wav export, plus the STFT parameters
# (the Griffin-Lim calls further down use these same values).
sr, n_fft, hop_length = 22050, 2048, 512

# --- Device ------------------------------------------------------------------
# Use the first CUDA device when one is available, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# --- Checkpoint --------------------------------------------------------------
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source.
resume = torch.load(model_path, map_location=device)

# The training configuration is stored inside the checkpoint; read the
# hyper-parameters this notebook needs from it.
config = resume['config']
# config.print_()
arch_args = config['arch']['args']
loader_args = config['data_loader']['args']
latent_dim = arch_args['latent_dim']
batch_size = loader_args['batch_size']

In [3]:
# Build the network named by the 'arch' entry of the checkpoint's config and
# restore the weights stored in the same checkpoint.
model = get_instance(module_model, 'arch', config)
model.load_state_dict(resume['state_dict'])
# Inference only: switch the module to evaluation mode.
model.eval()
# NOTE(review): project-specific flag; presumably turns off a feature-extraction
# mode so the forward pass returns its full output tuple -- confirm in model.model.
# NOTE(review): the model is not moved to `device` in this cell.
model.is_featExtract = False

# Build the data loader from the config's 'data_loader' entry, then ask it for
# its validation split.
data_loader = get_instance(module_data_loaders, 'data_loader', config)
valid_data_loader = data_loader.split_validation()

-- datasets.CollectData --
path_to_dataset:  ['dataset/textureDataset/m256-.5s']
path_to_data: [
	 dataset/textureDataset/m256-.5s/trainingdata/classA/512334__felix-blume__light-rain-in-a-field-of-bananas-trees_mono_5min_0.npy
	 dataset/textureDataset/m256-.5s/trainingdata/classA/512334__felix-blume__light-rain-in-a-field-of-bananas-trees_mono_5min_1.npy
	 dataset/textureDataset/m256-.5s/trainingdata/classA/512334__felix-blume__light-rain-in-a-field-of-bananas-trees_mono_5min_10.npy
	 dataset/textureDataset/m256-.5s/trainingdata/classA/512334__felix-blume__light-rain-in-a-field-of-bananas-trees_mono_5min_100.npy
	 dataset/textureDataset/m256-.5s/trainingdata/classA/512334__felix-blume__light-rain-in-a-field-of-bananas-trees_mono_5min_101.npy 
]
extensions:  ['wav', 'mp3', 'npy', 'pth']
subset:  train
transform:  Compose(
    <dataset.transformers.LoadNumpyAry object at 0x10626d7f0>
)
labels:  ['classA', 'classA', 'classA']
# of classes:  1
length:  615 

-- base data loader --
validati

In [4]:
def get_embeddings(model, data_loader, device):
    """Run every batch of ``data_loader`` through ``model`` and gather the latents.

    Args:
        model: Callable module whose forward pass returns a sequence with the
            latent mean at index 1 and the latent log-variance at index 2,
            e.g. ``(recon, mu, logvar)`` or ``(recon, mu, logvar, z)``.
        data_loader: Iterable yielding ``(images, labels)`` pairs of tensors.
        device: Device each input batch is moved to before the forward pass.
            The model itself is not moved here; it is expected to already be
            on a compatible device.

    Returns:
        ``(mu_all, logvar_all, labels_all)``: the per-batch tensors concatenated
        along dim 0, or ``(None, None, None)`` if the loader yields no batches.
    """
    model.eval()
    mu_chunks, logvar_chunks, label_chunks = [], [], []

    with torch.no_grad():
        for images, labels in data_loader:
            outputs = model(images.to(device))
            # Index rather than unpack a fixed number of values: the model used
            # in this notebook returns four items (recon, mu, logvar, z), which
            # made a three-target unpacking raise ValueError.
            mu_chunks.append(outputs[1])
            logvar_chunks.append(outputs[2])
            label_chunks.append(labels)

    if not mu_chunks:
        # Nothing was iterated; keep the "all None" result for an empty loader.
        return None, None, None

    # Concatenate once at the end instead of re-copying the running result on
    # every batch.
    mu_all = torch.cat(mu_chunks, dim=0)
    logvar_all = torch.cat(logvar_chunks, dim=0)
    labels_all = torch.cat(label_chunks, dim=0)
    return mu_all, logvar_all, labels_all

In [5]:
# Take two fixed dataset items; each item unpacks as (index, label, array).
idx1, label1, x1 = data_loader.dataset[10]
idx2, label2, x2 = data_loader.dataset[11]
# Sizes of the first two axes of an item (named as frequency bands x time frames).
n_freqBand, n_timeFrames = x1.shape[0], x1.shape[1]

# Full forward pass on each item. The model returns four values here:
# reconstruction, latent mean, latent log-variance and the latent code.
with torch.no_grad():
    # Add a leading batch axis of size 1 before calling the model.
    x1 = torch.from_numpy(x1).unsqueeze(0)
    x1_recon, mu1, logvar1, z1 = model(x1)

    x2 = torch.from_numpy(x2).unsqueeze(0)
    x2_recon, mu2, logvar2, z2 = model(x2)

In [6]:
def griffinlim_from_tensor(spec):
    """Invert a batched spectrogram tensor to a waveform with Griffin-Lim.

    Args:
        spec: CPU tensor with a leading batch axis of size 1, which is squeezed
            away before inversion (assumed to hold a magnitude spectrogram of
            shape (1, n_freq, n_frames) -- TODO confirm against the dataset).

    Returns:
        The waveform array returned by ``librosa.griffinlim``. The phase is
        initialised randomly (``init='random'``, no fixed ``random_state``), so
        the result differs slightly from run to run.
    """
    return librosa.griffinlim(
        S=spec.squeeze(0).numpy(),
        n_iter=32,
        hop_length=hop_length,  # shared STFT settings from the config cell
        win_length=n_fft,
        window='hann',
        center=True,
        dtype=None,
        length=None,
        pad_mode='reflect',
        momentum=0.99,
        init='random',
        random_state=None,
    )


# Note: the "original" clips are also Griffin-Lim inversions -- of the dataset
# spectrograms -- while the other two are inversions of the model outputs.
x1_au = griffinlim_from_tensor(x1)
x1_recon_au = griffinlim_from_tensor(x1_recon)

x2_au = griffinlim_from_tensor(x2)
x2_recon_au = griffinlim_from_tensor(x2_recon)

# Caption followed by an audio player for each clip, in listening order.
for caption, clip in (
    ("x1 original", x1_au),
    ("x1 recontructed with Griffin-Lim", x1_recon_au),
    ("x2 original", x2_au),
    ("x2 recontructed with Griffin-Lim", x2_recon_au),
):
    print(caption)
    display(Audio(clip, rate=sr))

x1 original


x1 recontructed with Griffin-Lim


x2 original


x2 recontructed with Griffin-Lim


In [7]:
# Reload the same two dataset items; each item unpacks as (index, label, array).
idx1, label1, x1 = data_loader.dataset[10]
idx2, label2, x2 = data_loader.dataset[11]

# Sizes of the first two axes of an item (named as frequency bands x time frames).
n_freqBand, n_timeFrames = x1.shape[0], x1.shape[1]

# Ten evenly spaced interpolation weights from 0.0 to 1.0 inclusive.
alpha = np.linspace(start=0.0, stop=1.0, num=10)

# Encode and decode explicitly (instead of calling the model) so the latent
# statistics and codes of both items are available for interpolation.
with torch.no_grad():
    # Add a leading batch axis of size 1 before encoding.
    x1 = torch.from_numpy(x1).unsqueeze(0)
    mu1, logvar1, z1 = model.encode(x1)
    x1_recon = model.decode(z1)

    x2 = torch.from_numpy(x2).unsqueeze(0)
    mu2, logvar2, z2 = model.encode(x2)
    x2_recon = model.decode(z2)

# Direction in latent space from the mean of item 1 to the mean of item 2.
delta_mu = (mu2 - mu1)

In [8]:
# Walk from item 1 towards item 2 in latent space and decode every step.
# NOTE(review): the walk starts at z1 and moves along (mu2 - mu1), so at
# alpha == 1 it ends at z1 + (mu2 - mu1), which equals z2 / mu2 only if
# encode() returns z == mu -- confirm this is the intended path.
segments = []
with torch.no_grad():
    for a in alpha:
        z_1to2 = z1 + float(a) * delta_mu
        segments.append(model.decode(z_1to2).squeeze(0).numpy())

# Join the decoded steps along axis 1 in a single concatenation rather than
# growing the array inside the loop.
x_1to2 = np.concatenate(segments, axis=1)

# Griffin-Lim settings shared by every inversion below; hop and window length
# come from the config cell instead of repeating the literals.
gl_kwargs = dict(
    n_iter=32,
    hop_length=hop_length,
    win_length=n_fft,
    window='hann',
    center=True,
    dtype=None,
    length=None,
    pad_mode='reflect',
    momentum=0.99,
    init='random',
    random_state=None,
)

# The "original" clips are Griffin-Lim inversions of the dataset items; the
# "recon" clips are inversions of the decoder outputs.
x1_au = librosa.griffinlim(S=x1.squeeze(0).numpy(), **gl_kwargs)
x1_recon_au = librosa.griffinlim(S=x1_recon.squeeze(0).numpy(), **gl_kwargs)

x2_au = librosa.griffinlim(S=x2.squeeze(0).numpy(), **gl_kwargs)
x2_recon_au = librosa.griffinlim(S=x2_recon.squeeze(0).numpy(), **gl_kwargs)

x_1to2_au = librosa.griffinlim(S=x_1to2, **gl_kwargs)

# Caption followed by an audio player for each clip; every player uses the
# notebook-wide sample rate `sr`.
for caption, clip in (
    ("x1 original", x1_au),
    ("x1 recontructed with Griffin-Lim", x1_recon_au),
    ("x2 original", x2_au),
    ("x2 recontructed with Griffin-Lim", x2_recon_au),
    ("x1 to x2 recontructed with Griffin-Lim", x_1to2_au),
):
    print(caption)
    display(Audio(clip, rate=sr))

x1 original


x1 recontructed with Griffin-Lim


x2 original


x2 recontructed with Griffin-Lim


x1 to x2 recontructed with Griffin-Lim


In [15]:
# Write output: export every clip as a 16-bit PCM wav file under `audio_path`.
# Create the directory first so the export also works on a fresh checkout.
os.makedirs(audio_path, exist_ok=True)

clips_to_save = {
    'x1_orig': x1_au,
    'x1_recon': x1_recon_au,
    'x2_orig': x2_au,
    'x2_recon': x2_recon_au,
    'x_1to2': x_1to2_au,
}
for stem, clip in clips_to_save.items():
    sf.write(os.path.join(audio_path, stem + '.wav'), clip, sr, format='wav', subtype='PCM_16')