In [1]:
import os
import sys
sys.path.append("..")

from glob import glob
import matplotlib.pyplot as plt
import ipywidgets as ipw
from IPython.display import Audio
import numpy as np 
import pickle

from quantizer import Quantizer
from lib.dataset_wrapper import Dataset
from lib.notebooks import show_ema
from external import lpcynet

In [2]:
# Trained quantizer run directories (the trailing "/" makes glob match directories only),
# in a deterministic sorted order.
quantizers_path = sorted(glob("../out/quantizer/*/"))

In [3]:
# Map a human-readable, multi-line label for each trained quantizer run to its directory,
# so the runs can be picked from a dropdown further down.
quantizers_alias = {}

for quantizer_path in quantizers_path:
    # Only `.config` is read here; load_nn=False presumably skips loading the network weights.
    quantizer = Quantizer.reload(quantizer_path, load_nn=False)
    config = quantizer.config

    # `quantizer_path` ends with a path separator (the glob pattern ends in "/"), so
    # normalize it before taking the basename. This yields the full run directory name.
    # A single trailing character would truncate long names and let different runs
    # map to the same alias, silently overwriting each other in the dict.
    quantizer_i = os.path.basename(os.path.normpath(quantizer_path))

    hidden_dims = config['model']['hidden_dims']
    # Guard against an empty layer list instead of raising IndexError on hidden_dims[0].
    hidden_desc = f"{len(hidden_dims)}x{hidden_dims[0]}" if hidden_dims else "0"

    quantizer_alias = "\n".join((
        f"{','.join(config['dataset']['names'])}",
        f"hidden_layers={hidden_desc}",
        f"{quantizer_i}",
    ))

    quantizers_alias[quantizer_alias] = quantizer_path

In [7]:
# Last item viewed for each dataset name, so re-rendering a dataset's widget reopens on it.
datasets_current_item = dict()

def show_quantizer(quantizer_alias):
    """Display interactive widgets to resynthesize dataset items through a quantizer.

    Reloads the quantizer registered under `quantizer_alias` in `quantizers_alias`,
    then displays a dataset selector restricted to the quantizer's configured
    datasets. For the chosen dataset, it displays an item selector. That selector
    plays the original audio and the VQ-VAE → LPCNet resynthesis, and plots the
    original vs. repeated cepstrum.

    Args:
        quantizer_alias: a key of the `quantizers_alias` dict built in an earlier cell.

    Raises:
        AssertionError: if the quantizer's configured data type is not "cepstrum".
    """
    quantizer_path = quantizers_alias[quantizer_alias]
    quantizer = Quantizer.reload(quantizer_path)

    sound_type = quantizer.config["dataset"]["data_type"]
    assert sound_type == "cepstrum"
    dataset_names = quantizer.config["dataset"]["names"]

    def show_dataset(dataset_name):
        """Load `dataset_name` and display the per-item resynthesis widget."""
        dataset = Dataset(dataset_name)
        # Index of the dataset in the quantizer's config; passed to autoencode()
        # (presumably as a speaker/dataset conditioning id — confirm in Quantizer).
        speaker_id = dataset_names.index(dataset_name)
        items_cepstrum = dataset.get_items_data(sound_type, cut_silences=True)
        items_source = dataset.get_items_data("source", cut_silences=True)
        sampling_rate = dataset.features_config["wav_sampling_rate"]

        items_name = dataset.get_items_list()
        # Reopen on the last item viewed for this dataset, otherwise the first one.
        # NOTE(review): `[0][0]` assumes get_items_list() returns a sequence of
        # sequences — confirm against Dataset.get_items_list.
        if dataset_name in datasets_current_item:
            current_item = datasets_current_item[dataset_name]
        else:
            current_item = items_name[0][0]

        def resynth_item(item_name=current_item, freeze_source=False):
            """Play original and resynthesized audio for one item and plot both cepstra.

            Args:
                item_name: key of the item in the dataset's per-item data.
                freeze_source: if True, every source frame is replaced by (1, 0)
                    before synthesis instead of using the item's own source.
            """
            datasets_current_item[dataset_name] = item_name

            item_cepstrum = items_cepstrum[item_name]
            item_source = items_source[item_name]
            item_wave = dataset.get_item_wave(item_name)
            nb_frames = len(item_cepstrum)

            repetition = quantizer.autoencode(item_cepstrum, speaker_id)
            resynth_cepstrum = repetition["seqs_pred"]

            if freeze_source:
                # Work on a copy: `items_source` is shared by every call for this
                # dataset. Writing in place would keep this item's source frozen
                # even after the checkbox is unticked.
                item_source = item_source.copy()
                item_source[:] = (1, 0)

            resynth_sound = np.concatenate((resynth_cepstrum, item_source), axis=1)
            repeated_wave = lpcynet.synthesize_frames(resynth_sound)

            print("Original sound:")
            display(Audio(item_wave, rate=sampling_rate))
            print("Resynth (VQ-VAE → LPCNet):")
            display(Audio(repeated_wave, rate=sampling_rate))

            # Width grows with item length (20 frames per inch). It is clamped so very
            # short items still get a usable figure. Two rows: one per cepstrum plot.
            fig, (ax_original, ax_repeated) = plt.subplots(
                2, 1, figsize=(max(nb_frames / 20, 4), 6), dpi=120
            )

            ax_original.set_title("original %s" % (sound_type))
            ax_original.imshow(item_cepstrum.T, origin="lower")

            ax_repeated.set_title("Repetition")
            ax_repeated.imshow(resynth_cepstrum.T, origin="lower")

            fig.tight_layout()
            plt.show()

        display(ipw.interactive(resynth_item, item_name=items_name, freeze_source=False))

    display(ipw.interactive(show_dataset, dataset_name=dataset_names))

# Top-level entry point: one dropdown entry per trained quantizer, in alphabetical order.
display(ipw.interactive(show_quantizer, quantizer_alias=sorted(quantizers_alias)))

interactive(children=(Dropdown(description='quantizer_alias', options=('pb2007\nhidden_layers=4x256\na',), val…