In [1]:
import sys
sys.path.append("..")

from glob import glob
import matplotlib.pyplot as plt
import ipywidgets as ipw
import numpy as np 
from tqdm.notebook import tqdm

from imitative_agent import ImitativeAgent
from lib.dataset_wrapper import Dataset
from lib import utils
from lib import abx_utils
from lib import notebooks

In [4]:
# Index every trained agent directory two ways:
#   agents_alias: unique human-readable label -> agent directory
#   agents_group: label without the run index -> list of agent directories,
#                 i.e. all runs that share the same dataset / synthesizer
#                 articulatory type / jerk-loss / bidirectionality settings.
agents_path = sorted(glob("../out/imitative_agent/*/"))

agents_alias = {}
agents_group = {}

for agent_path in agents_path:
    # Only the configs are needed here, so skip loading network weights.
    agent = ImitativeAgent.reload(agent_path, load_nn=False)
    config = agent.config

    # Configuration key shared by every run of the same hyper-parameters.
    agent_group = " ".join((
        ",".join(config["dataset"]["names"]),
        f"synth_art={agent.synthesizer.config['dataset']['art_type']}",
        f"jerk_c={config['training']['jerk_loss_ceil']}",
        f"jerk_w={config['training']['jerk_loss_weight']}",
        f"bi={config['model']['inverse_model']['bidirectional']}",
    ))

    # agent_path ends with "/", so [-2] is the last character of the
    # directory name, used as the run index.
    # NOTE(review): assumes directory names end with a single-character run
    # index; a longer index would be truncated and could collide.
    agent_i = agent_path[-2]
    agent_alias = f"{agent_group} ({agent_i})"

    # A collision would silently overwrite an earlier agent, hiding it from
    # the dropdown and from the occlusion-metrics computation.
    if agent_alias in agents_alias:
        raise ValueError(
            f"Duplicate agent alias {agent_alias!r}: "
            f"{agents_alias[agent_alias]} and {agent_path}"
        )
    agents_alias[agent_alias] = agent_path

    agents_group.setdefault(agent_group, []).append(agent_path)

In [5]:
# Compute the occlusion metrics of every agent, reusing an on-disk cache keyed
# by agent directory. The cache is rewritten after each newly processed agent,
# so an interrupted run resumes where it stopped.
# NOTE(review): the cache is a pickle; only load files produced locally.
OCCLUSIONS_CACHE_PATH = "../out/imitative_agent/occlusions_cache.pickle"
agents_occlusions_metrics = utils.pickle_load(OCCLUSIONS_CACHE_PATH, {})

for agent_path in tqdm(agents_alias.values()):
    if agent_path in agents_occlusions_metrics:
        continue  # already cached

    # Full reload (networks included) since inference is needed below.
    agent = ImitativeAgent.reload(agent_path)
    synth_dataset = agent.synthesizer.dataset
    main_dataset = agent.get_main_dataset()

    # Datasplit index 2, presumably the held-out split; TODO confirm.
    agent_lab = agent.get_datasplit_lab(2)
    agent_features = agent.repeat_datasplit(2)

    # dataset name -> item name -> estimated articulation converted to EMA.
    datasets_estimated_ema = {
        dataset_name: {
            item_name: synth_dataset.art_to_ema(item_estimated_art)
            for item_name, item_estimated_art in dataset_features["art_estimated"].items()
        }
        for dataset_name, dataset_features in agent_features.items()
    }

    palate = synth_dataset.palate
    phones_infos = main_dataset.phones_infos
    consonants = phones_infos["consonants"]
    vowels = phones_infos["vowels"]
    consonants_indexes = abx_utils.get_datasets_phones_indexes(agent_lab, consonants, vowels)

    agents_occlusions_metrics[agent_path] = abx_utils.get_occlusions_metrics(
        consonants, consonants_indexes, datasets_estimated_ema, palate
    )
    utils.pickle_dump(OCCLUSIONS_CACHE_PATH, agents_occlusions_metrics)

  0%|          | 0/48 [00:00<?, ?it/s]

In [6]:
def show_agent(agent_alias):
    """Plot the cached occlusion metrics of a single agent against its palate.

    Parameters
    ----------
    agent_alias : str
        Key of ``agents_alias`` (the label picked in the dropdown).
    """
    agent_path = agents_alias[agent_alias]
    # Weights are not needed: the agent is reloaded only to reach the
    # synthesizer dataset's palate contour.
    palate = ImitativeAgent.reload(agent_path, load_nn=False).synthesizer.dataset.palate
    notebooks.show_occlusions_metrics(agents_occlusions_metrics[agent_path], palate)


ipw.interactive(show_agent, agent_alias=sorted(agents_alias))

interactive(children=(Dropdown(description='agent_alias', options=('pb2007 synth_art=art_params jerk_c=0 jerk_…

In [18]:
def show_group(agent_group_name):
    """Display a phone selector overlaying all agents of a configuration group.

    For the chosen phone, draws one subplot per entry of ``distances``. Each
    subplot shows the palate contour plus the EMA points stored under
    ``"min_<distance>_ema"`` for every agent of the group.

    Parameters
    ----------
    agent_group_name : str
        Key of ``agents_group`` (the label picked in the dropdown).
    """
    agent_group = agents_group[agent_group_name]

    # Only the palate contour is needed, taken from the first agent.
    # NOTE(review): assumes all agents of a group share the same palate.
    first_agent = ImitativeAgent.reload(agent_group[0], load_nn=False)
    palate = first_agent.synthesizer.dataset.palate

    # NOTE(review): phones are read from the first agent only; assumes every
    # agent of the group has metrics for the same phones.
    phones = list(agents_occlusions_metrics[agent_group[0]].keys())
    distances = ["tongue_tip", "tongue_mid"]

    def show_phone(phone):
        """Plot, for ``phone``, one palate subplot per distance in ``distances``."""
        # Grid size follows ``distances`` so adding a distance cannot overflow
        # the subplot grid; squeeze=False keeps ``axes`` 2-D even for one row.
        fig, axes = plt.subplots(
            len(distances), 1, dpi=120, squeeze=False, subplot_kw={"aspect": "equal"}
        )
        for ax, distance in zip(axes[:, 0], distances):
            metric_key = "min_%s_ema" % distance
            agents_phone_ema = np.concatenate(
                [agents_occlusions_metrics[agent_path][phone][metric_key]
                 for agent_path in agent_group],
                axis=0,
            )
            ax.plot(palate[:, 0], palate[:, 1])
            # Columns appear to be interleaved (x, y) coordinate pairs.
            ax.scatter(agents_phone_ema[:, 0::2], agents_phone_ema[:, 1::2], s=1)
            ax.set_title(f"{phone}: {metric_key}")

        fig.tight_layout()
        plt.show()

    ipw.interact(show_phone, phone=phones)


ipw.interactive(show_group, agent_group_name=sorted(agents_group.keys()))

interactive(children=(Dropdown(description='agent_group_name', options=('pb2007 synth_art=art_params jerk_c=0 …