In [1]:
import sys
sys.path.append("..")

from glob import glob
import matplotlib.pyplot as plt
import ipywidgets as ipw
from IPython.display import Audio
import numpy as np
import pickle

from communicative_agent import CommunicativeAgent
from lib.dataset_wrapper import Dataset
from lib.notebooks import show_ema

from external import lpcynet

In [2]:
# Every directory matching the communicative-agent output pattern,
# in sorted order so the listing is stable from one run to the next.
agents_path = sorted(glob("../out/communicative_agent/*/"))

In [3]:
# Build one (label, path) pair per agent directory. The label concatenates the
# agent path, the sound quantizer's dataset names, the synthesizer's
# articulatory type and the jerk loss weight, separated by single spaces.
agents_alias = []

for agent_path in agents_path:
    # Only configuration values are read below.
    # NOTE(review): load_nn=False presumably skips loading the networks — confirm.
    agent = CommunicativeAgent.reload(agent_path, load_nn=False)
    synth_art_type = agent.synthesizer.config["dataset"]["art_type"]

    # Agents whose synthesizer works on "ema" articulatory data are left out.
    if synth_art_type == "ema":
        continue

    dataset_names = ",".join(agent.sound_quantizer.config["dataset"]["names"])
    jerk_weight = agent.config["training"]["jerk_loss_weight"]
    label = f"{agent_path} {dataset_names} synth_art={synth_art_type} jerk={jerk_weight}"

    agents_alias.append((label, agent_path))

In [4]:
def show_agent(agent_path):
    """Display an interactive browser for the agent stored at ``agent_path``.

    The agent is reloaded, then a dataset selector is displayed with one entry
    per name listed in the sound quantizer's dataset configuration. Picking a
    dataset displays an item selector; picking an item plays the original and
    resynthesized audio, plots the corresponding sound features and shows the
    estimated articulatory trajectories.
    """
    agent = CommunicativeAgent.reload(agent_path)

    synth_dataset_config = agent.synthesizer.config["dataset"]
    sound_type = synth_dataset_config["sound_type"]
    art_type = synth_dataset_config["art_type"]
    synth_dataset = agent.synthesizer.dataset

    def show_dataset(dataset_name):
        """Load ``dataset_name`` and display a selector over its items."""
        dataset = Dataset(dataset_name)
        # Sound features of the synthesizer's sound type, plus the "source"
        # features that are appended to them before waveform synthesis.
        items_cepstrum = dataset.get_items_data(sound_type, cut_silences=False)
        items_source = dataset.get_items_data("source", cut_silences=False)
        sampling_rate = dataset.features_config["wav_sampling_rate"]
        items_name = dataset.get_items_list()

        def resynth_item(item_name):
            """Play, plot and show the agent's repetition of one item."""
            original_cepstrum = items_cepstrum[item_name]
            source_features = items_source[item_name]
            original_wave = dataset.get_item_wave(item_name)
            nb_frames = len(original_cepstrum)

            repetition = agent.repeat(original_cepstrum)
            repeated_cepstrum = repetition["sound_repeated"]
            estimated_cepstrum = repetition["sound_estimated"]
            estimated_art = repetition["art_estimated"]

            # Append the item's original source features to each predicted
            # sequence (along axis 1), then synthesize the waveforms.
            repeated_sound, estimated_sound = [
                np.concatenate((cepstrum, source_features), axis=1)
                for cepstrum in (repeated_cepstrum, estimated_cepstrum)
            ]
            repeated_wave, estimated_wave = [
                lpcynet.synthesize_frames(sound)
                for sound in (repeated_sound, estimated_sound)
            ]

            audio_clips = (
                ("Original sound:", original_wave),
                ("Repetition (Inverse model → Synthesizer → LPCNet):", repeated_wave),
                ("Estimation (Inverse model → Direct model → LPCNet):", estimated_wave),
            )
            for caption, wave in audio_clips:
                print(caption)
                display(Audio(wave, rate=sampling_rate))

            # Three stacked panels; the figure width grows with the item length.
            panels = (
                ("original %s" % (sound_type), original_cepstrum),
                ("Repetition", repeated_cepstrum),
                ("Estimation", estimated_cepstrum),
            )
            fig = plt.figure(figsize=(nb_frames / 20, 6), dpi=120)
            for row, (title, cepstrum) in enumerate(panels, start=1):
                ax = fig.add_subplot(3, 1, row)
                ax.set_title(title)
                ax.imshow(cepstrum.T, origin="lower")

            plt.tight_layout()
            plt.show()

            # Articulatory parameters are converted to EMA before display;
            # any other articulatory type is passed to show_ema unchanged.
            if art_type == "art_params":
                art_to_show = synth_dataset.art_to_ema(estimated_art)
            else:
                art_to_show = estimated_art
            # No reference EMA is loaded here, so none is overlaid.
            show_ema(art_to_show, reference=None, dataset=synth_dataset)

        item_selector = ipw.interactive(resynth_item, item_name=items_name)
        display(item_selector)

    dataset_names = agent.sound_quantizer.config["dataset"]["names"]
    display(ipw.interactive(show_dataset, dataset_name=dataset_names))

# Top-level selector: each option is a (label, agent path) pair from
# agents_alias, and the chosen agent path is handed to show_agent.
display(
    ipw.interactive(
        show_agent,
        agent_path=agents_alias,
    )
)

interactive(children=(Dropdown(description='agent_path', options=(('../out/communicative_agent/00b2b5ff5b7c6f9…