In [10]:
import sys
sys.path.append("..")

from glob import glob
import matplotlib.pyplot as plt
import ipywidgets as ipw
import numpy as np 
from tqdm.notebook import tqdm
import pandas as pd

from imitative_agent import ImitativeAgent
from lib.dataset_wrapper import Dataset
from lib import utils
from lib import abx_utils
from lib import notebooks
import os

In [11]:
# Discover every trained imitative agent and index it two ways:
#   agents_alias: unique human-readable alias -> checkpoint directory
#   agents_group: alias without the run index  -> list of checkpoint directories
agents_path = sorted(glob("../out/imitative_agent/*/"))

agents_alias = {}
agents_group = {}

# The glob pattern is relative, so show once which directory it was resolved from.
print(os.getcwd())

for agent_path in agents_path:
    # Printed before reloading so a failing checkpoint can be identified.
    print(agent_path)
    agent = ImitativeAgent.reload(agent_path, load_nn=False)
    config = agent.config

    # Optional filter, kept from the original notebook:
    #if config["training"]["jerk_loss_ceil"] != 0.014: continue

    # Run index = text after the last "-" of the directory name
    # (e.g. ".../<hash>-12/" -> "12"). Taking only `agent_path[-2]` kept a
    # single character, so run 10 collided with run 0 in `agents_alias`.
    agent_dirname = os.path.basename(os.path.normpath(agent_path))
    agent_i = agent_dirname.rsplit("-", 1)[-1]

    # The group label describes the configuration shared by all runs.
    agent_group = " ".join((
        f"{','.join(config['dataset']['names'])}",
        f"synth_art={agent.synthesizer.config['dataset']['art_type']}",
        f"jerk_c={config['training']['jerk_loss_ceil']}",
        f"jerk_w={config['training']['jerk_loss_weight']}",
        f"bi={config['model']['inverse_model']['bidirectional']}",
    ))
    # The alias is the group label plus the run index, which makes it unique.
    agent_alias = f"{agent_group} ({agent_i})"

    agents_alias[agent_alias] = agent_path
    agents_group.setdefault(agent_group, []).append(agent_path)

/mnt/c/Users/vpaul/Documents/Inner_Speech/agent/imitative_agent
../out/imitative_agent/bfa2f1a4bdf7d85496bc8c867342f96e-0/


TypeError: to() received an invalid combination of arguments - got (tuple), but expected one of:
 * (torch.device device, torch.dtype dtype, bool non_blocking, bool copy, *, torch.memory_format memory_format)
 * (torch.dtype dtype, bool non_blocking, bool copy, *, torch.memory_format memory_format)
 * (Tensor tensor, bool non_blocking, bool copy, *, torch.memory_format memory_format)


In [3]:
# Consonant -> name of the occlusion detection method handed to
# abx_utils.get_occlusions_indexes. Insertion order is p, b, t, d, k, g.
DETECTION_METHODS = {
    **dict.fromkeys(("p", "b"), "lips"),
    **dict.fromkeys(("t", "d"), "tongue_tip"),
    **dict.fromkeys(("k", "g"), "tongue_mid"),
}
# Consonants analysed in this notebook, in the same order as the mapping above.
# Despite the name, the list also holds "p" and "b", which are detected via "lips".
TONGUE_CONSONANTS = list(DETECTION_METHODS)

In [4]:
# agents_ema: agent_path -> dataset_name -> item_name -> EMA estimated by the agent.
agents_ema = {}
# datasets_occlusions: dataset_name -> result of abx_utils.get_occlusions_indexes.
# Later cells iterate it with .items() as (consonant, occlusions) pairs.
datasets_occlusions = {}

# NOTE(review): when the agent discovery cell fails, `agents_alias` is empty and
# this loop runs zero times, leaving both dictionaries empty.
for agent_alias, agent_path in tqdm(agents_alias.items()):
    agent_ema = agents_ema[agent_path] = {}
    
    # Reloaded without load_nn=False here, unlike the discovery cell;
    # presumably the networks are needed by repeat_datasplit — TODO confirm.
    agent = ImitativeAgent.reload(agent_path)
    synth_dataset = agent.synthesizer.dataset
    
    # NOTE(review): `main_dataset` is not read anywhere in this cell.
    main_dataset = agent.get_main_dataset()
    # Used below as a mapping dataset_name -> features containing "art_estimated".
    # NOTE(review): the meaning of the `None` argument is not visible here.
    agent_features = agent.repeat_datasplit(None)
    
    for dataset_name, dataset_features in agent_features.items():
        # Occlusions depend only on the dataset, so they are computed once and
        # reused for every agent that shares that dataset.
        if dataset_name not in datasets_occlusions:
            dataset = Dataset(dataset_name)
            palate = dataset.palate
            vowels = dataset.phones_infos["vowels"]
            # The abx_utils helpers are given dictionaries keyed by dataset name.
            datasets_lab = {dataset_name: dataset.lab}
            datasets_ema = {dataset_name: dataset.get_items_data("ema")}
            consonants_indexes = abx_utils.get_datasets_phones_indexes(
                datasets_lab, TONGUE_CONSONANTS, vowels
            )
            datasets_occlusions[dataset_name] = abx_utils.get_occlusions_indexes(
                TONGUE_CONSONANTS, consonants_indexes, DETECTION_METHODS, datasets_ema, palate,
            )
        
        items_estimated_ema = agent_ema[dataset_name] = {}
        
        # Convert each item's estimated articulatory data with the synthesizer
        # dataset's art_to_ema and store the result under the item name.
        items_estimated_art = dataset_features["art_estimated"]
        for item_name, item_estimated_art in items_estimated_art.items():
            item_estimated_ema = synth_dataset.art_to_ema(item_estimated_art)
            items_estimated_ema[item_name] = item_estimated_ema

0it [00:00, ?it/s]

In [5]:
def show_dataset(dataset_name):
    """Show the recorded EMA frames around each consonant occlusion of a dataset.

    For every consonant in ``datasets_occlusions[dataset_name]`` a two-panel
    figure is drawn (frames taken ``offset`` frames before ``occlusion[2]`` and
    ``offset`` frames after ``occlusion[3]``) on top of the palate contour,
    followed by a table of mean ± std distances per consonant.

    Parameters
    ----------
    dataset_name : str
        Key of ``datasets_occlusions`` and name accepted by ``Dataset``.

    Returns
    -------
    None
        An ``offset`` slider is created through ``ipw.interact`` as a side effect.
    """
    dataset = Dataset(dataset_name)
    items_ema = dataset.get_items_data("ema")
    dataset_occlusions = datasets_occlusions[dataset_name]
    palate = dataset.palate

    def padded_limits(low, high, margin=0.05):
        # Pad outwards whatever the sign: `low * 0.95` moves the bound inwards
        # (cropping the data) as soon as `low` is negative.
        return (low - abs(low) * margin, high + abs(high) * margin)

    display_xlim = padded_limits(dataset.ema_limits["xmin"], dataset.ema_limits["xmax"])
    display_ylim = padded_limits(dataset.ema_limits["ymin"], dataset.ema_limits["ymax"])

    def show_occlusions(offset=2):
        """Plot the occlusion frames and display the distance table for `offset`."""
        consonants_stats = {}
        for consonant, occlusions in dataset_occlusions.items():
            occlusions_start_ema = []
            occlusions_stop_ema = []
            for occlusion in occlusions:
                # occlusion[1] is the item name; occlusion[2] / occlusion[3] are
                # used as the frame indexes of the start / stop of the occlusion.
                item_ema = items_ema[occlusion[1]]
                # Clamp to the item: a negative index would silently wrap to the
                # end of the item and an index past the end raises IndexError.
                start_index = max(occlusion[2] - offset, 0)
                stop_index = min(occlusion[3] + offset, len(item_ema) - 1)
                occlusions_start_ema.append(item_ema[start_index])
                occlusions_stop_ema.append(item_ema[stop_index])

            if not occlusions_start_ema:
                # No occlusion detected for this consonant: an empty array is
                # 1-D and the column slicing below would fail.
                continue

            occlusions_start_ema = np.array(occlusions_start_ema)
            occlusions_stop_ema = np.array(occlusions_stop_ema)
            occlusions_by_type = (
                ("start", occlusions_start_ema),
                ("stop", occlusions_stop_ema),
            )

            occlusions_stats = consonants_stats[consonant] = {}
            for occlusions_type, occlusions_ema in occlusions_by_type:
                # Euclidean distance between the coordinate pairs in columns
                # 8:10 and 10:12 (used here as the lips measurement).
                lips_distance = np.sqrt(np.sum((occlusions_ema[:, 10:12] - occlusions_ema[:, 8:10]) ** 2, axis=1))
                occlusions_stats["%s_lips" % occlusions_type] = "%.2f ±%.2f" % (lips_distance.mean(), lips_distance.std())

                tongue_tip_distance = abx_utils.coil_distances_from_palate(occlusions_ema[:, 2:4], palate)
                occlusions_stats["%s_tongue_tip" % occlusions_type] = "%.2f ±%.2f" % (tongue_tip_distance.mean(), tongue_tip_distance.std())

                tongue_mid_distance = abx_utils.coil_distances_from_palate(occlusions_ema[:, 4:6], palate)
                occlusions_stats["%s_tongue_mid" % occlusions_type] = "%.2f ±%.2f" % (tongue_mid_distance.mean(), tongue_mid_distance.std())

            plt.figure(figsize=(12, 3), dpi=60)
            ax_start = plt.subplot(121, aspect="equal")
            ax_stop = plt.subplot(122, aspect="equal")
            for ax, (occlusions_type, occlusions_ema) in zip((ax_start, ax_stop), occlusions_by_type):
                ax.set_title("%s %s (PB original)" % (consonant, occlusions_type))
                ax.set_xlim(*display_xlim)
                ax.set_ylim(*display_ylim)
                ax.plot(palate[:, 0], palate[:, 1])
                ax.set_xticks([])
                ax.set_yticks([])
                # Even columns are plotted as x, odd columns as y.
                ax.scatter(occlusions_ema[:, 0::2], occlusions_ema[:, 1::2], c="tab:blue", s=2)

            plt.subplots_adjust(wspace=-.1)
            plt.show()

        consonants_stats = pd.DataFrame.from_dict(consonants_stats, orient="index")
        display(consonants_stats)

    ipw.interact(show_occlusions, offset=(0, 10))

ipw.interactive(show_dataset, dataset_name=datasets_occlusions.keys())

interactive(children=(Dropdown(description='dataset_name', options=(), value=None), Output()), _dom_classes=('…

In [6]:
def show_agent(agent_alias):
    """Show the EMA frames estimated by an agent around each consonant occlusion.

    For every dataset listed in the agent configuration and every consonant in
    ``datasets_occlusions[dataset_name]``, a two-panel figure is drawn (frames
    taken ``offset`` frames before ``occlusion[2]`` and ``offset`` frames after
    ``occlusion[3]``) on top of the synthesizer dataset's palate contour,
    followed by a table of mean ± std distances.

    Parameters
    ----------
    agent_alias : str
        Key of ``agents_alias`` identifying the agent to display.

    Returns
    -------
    None
        An ``offset`` slider is created through ``ipw.interact`` as a side effect.
    """
    agent_path = agents_alias[agent_alias]
    agent = ImitativeAgent.reload(agent_path, load_nn=False)
    synth_dataset = agent.synthesizer.dataset
    palate = synth_dataset.palate
    agent_ema = agents_ema[agent_path]
    dataset_names = agent.config["dataset"]["names"]
    jerk_loss_weight = agent.config["training"]["jerk_loss_weight"]

    def padded_limits(low, high, margin=0.05):
        # Pad outwards whatever the sign: `low * 0.95` moves the bound inwards
        # (cropping the data) as soon as `low` is negative.
        return (low - abs(low) * margin, high + abs(high) * margin)

    display_xlim = padded_limits(synth_dataset.ema_limits["xmin"], synth_dataset.ema_limits["xmax"])
    display_ylim = padded_limits(synth_dataset.ema_limits["ymin"], synth_dataset.ema_limits["ymax"])

    def show_occlusions(offset=2):
        """Plot the occlusion frames and display the distance table for `offset`."""
        consonants_stats = {}

        for dataset_name in dataset_names:
            items_ema = agent_ema[dataset_name]
            dataset_occlusions = datasets_occlusions[dataset_name]

            for consonant, occlusions in dataset_occlusions.items():
                occlusions_start_ema = []
                occlusions_stop_ema = []
                for occlusion in occlusions:
                    # occlusion[1] is the item name; occlusion[2] / occlusion[3]
                    # are used as the frame indexes of the occlusion start / stop.
                    item_ema = items_ema[occlusion[1]]
                    # Clamp to the item: a negative index would silently wrap to
                    # the end of the item and an index past the end raises
                    # IndexError.
                    start_index = max(occlusion[2] - offset, 0)
                    stop_index = min(occlusion[3] + offset, len(item_ema) - 1)
                    occlusions_start_ema.append(item_ema[start_index])
                    occlusions_stop_ema.append(item_ema[stop_index])

                if not occlusions_start_ema:
                    # No occlusion detected for this consonant: an empty array
                    # is 1-D and the column slicing below would fail.
                    continue

                occlusions_start_ema = np.array(occlusions_start_ema)
                occlusions_stop_ema = np.array(occlusions_stop_ema)
                occlusions_by_type = (
                    ("start", occlusions_start_ema),
                    ("stop", occlusions_stop_ema),
                )

                # With a single dataset the row label stays the bare consonant.
                # With several, the dataset name is added so that rows from one
                # dataset are not overwritten by the next.
                if len(dataset_names) == 1:
                    stats_key = consonant
                else:
                    stats_key = "%s (%s)" % (consonant, dataset_name)

                occlusions_stats = consonants_stats[stats_key] = {}
                for occlusions_type, occlusions_ema in occlusions_by_type:
                    # Euclidean distance between the coordinate pairs in columns
                    # 8:10 and 10:12 (used here as the lips measurement).
                    lips_distance = np.sqrt(np.sum((occlusions_ema[:, 10:12] - occlusions_ema[:, 8:10]) ** 2, axis=1))
                    occlusions_stats["%s_lips" % occlusions_type] = "%.2f ±%.2f" % (lips_distance.mean(), lips_distance.std())

                    tongue_tip_distance = abx_utils.coil_distances_from_palate(occlusions_ema[:, 2:4], palate)
                    occlusions_stats["%s_tongue_tip" % occlusions_type] = "%.2f ±%.2f" % (tongue_tip_distance.mean(), tongue_tip_distance.std())

                    tongue_mid_distance = abx_utils.coil_distances_from_palate(occlusions_ema[:, 4:6], palate)
                    occlusions_stats["%s_tongue_mid" % occlusions_type] = "%.2f ±%.2f" % (tongue_mid_distance.mean(), tongue_mid_distance.std())

                plt.figure(figsize=(12, 3), dpi=60)
                ax_start = plt.subplot(121, aspect="equal")
                ax_stop = plt.subplot(122, aspect="equal")
                for ax, (occlusions_type, occlusions_ema) in zip((ax_start, ax_stop), occlusions_by_type):
                    ax.set_title("%s %s (jerk=%s)" % (consonant, occlusions_type, jerk_loss_weight))
                    ax.set_xlim(*display_xlim)
                    ax.set_ylim(*display_ylim)
                    ax.plot(palate[:, 0], palate[:, 1])
                    ax.set_xticks([])
                    ax.set_yticks([])
                    # Even columns are plotted as x, odd columns as y.
                    ax.scatter(occlusions_ema[:, 0::2], occlusions_ema[:, 1::2], c="tab:blue", s=2)

                plt.subplots_adjust(wspace=-.1)
                plt.show()

        consonants_stats = pd.DataFrame.from_dict(consonants_stats, orient="index")
        display(consonants_stats)
    ipw.interact(show_occlusions, offset=(0, 10))

ipw.interactive(show_agent, agent_alias=sorted(agents_alias.keys()))

interactive(children=(Dropdown(description='agent_alias', options=(), value=None), Output()), _dom_classes=('w…