A trial has four defining factors. Note that a trial is now more than a single item: it is the conjunction of two items (a source and a target) plus the actual target phone we are trying to reach:

- source inflected phones (e.g. P ER P)
- target base phones (e.g. D AW)
- target GT next phone (e.g. D)
- target actual desired next phone (e.g. B)

From boolean statements relating these values we can derive the critical conditions:

- **Control**: source next phone != target next phone. Tests how reachable the target is from a source whose next phone differs (a baseline).
- **Weak experiment**: Can we reach the GT next phone? True when target GT == target actual
- **Strong experiment**: Can we reach non-GT next phones which are attested in the lexicon? True when target GT != target actual

In [None]:
# Automatically reload edited Python modules (e.g. the project's `src` package)
# before each cell runs, so code changes are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [96]:
from collections import defaultdict
import itertools
import re

import matplotlib.pyplot as plt
import numpy as np
from omegaconf import OmegaConf
import pandas as pd
import seaborn as sns
import torch
from tqdm.auto import tqdm

from src.analysis import analogy, analogy_pseudocausal
from src.analysis.state_space import StateSpaceAnalysisSpec, \
    prepare_state_trajectory, aggregate_state_trajectory, flatten_trajectory
from src.datasets.speech_equivalence import SpeechHiddenStateDataset


In [None]:
# Limit PyTorch's intra-op CPU thread pool.
num_torch_threads = 8
torch.set_num_threads(num_torch_threads)

In [97]:
# --- Run parameters ---------------------------------------------------------
base_model = "w2v2_pc_8"

# Previously also used: "discrim-rnn_32-pc-mAP1".
model_class = "ffff_32-pc-mAP1"
model_name = "word_broad_10frames_fixedlen25"

train_dataset = "librispeech-train-clean-100"
# Evaluate on the same dataset the embedding model was trained on.
dataset = train_dataset
# Also selects conf/experiments/analogy_pseudocausal/<experiment>.yaml.
experiment = "syllable_at_0"

# Alternative location: f"outputs/hidden_states/{base_model}/{train_dataset}.h5"
hidden_states_path = f"/scratch/jgauthier/{base_model}_{train_dataset}.h5"
# Set to "ID" to analyze the raw hidden states instead of learned embeddings.
embeddings_path = f"outputs/model_embeddings/{train_dataset}/{base_model}/{model_class}/{model_name}/{dataset}.npy"

inputs_dir = f"outputs/analogy_pseudocausal_broad/inputs/{dataset}/w2v2_pc/{experiment}"
instances_path = f"{inputs_dir}/instances.csv"
state_space_specs_path = f"{inputs_dir}/state_space_spec.h5"

# Plain string: there is nothing to interpolate.
output_dir = "."

pos_counts_path = "data/pos_counts.pkl"

seed = 42

# NOTE: `metric` is rebound to a result-column name in the analysis cells.
metric = "cosine"

In [100]:
# Experiment-specific settings (e.g. `unit_level`) live in a YAML file named
# after `experiment`.
experiment_config_path = f"conf/experiments/analogy_pseudocausal/{experiment}.yaml"
config = OmegaConf.load(experiment_config_path)

In [None]:
# Trajectory aggregation spec(s): (aggregation name, unit level). Only the
# first entry is used below; "mean_within_cut" presumably averages frames
# within each cut at `config.unit_level` (TODO confirm in src.analysis).
agg_fns = [("mean_within_cut", config.unit_level)]

## Prepare model representations

In [None]:
# Load the frame-level representations to analyze. The sentinel
# embeddings_path == "ID" selects the raw hidden states; otherwise a saved
# embedding array (.npy) is loaded.
if embeddings_path == "ID":
    model_representations = SpeechHiddenStateDataset.from_hdf5(hidden_states_path).states
else:
    # np.load opens and closes the file itself.
    model_representations: np.ndarray = np.load(embeddings_path)

state_space_spec = StateSpaceAnalysisSpec.from_hdf5(state_space_specs_path)
# Explicit check instead of `assert`, so it is not stripped under `python -O`
# and reports which inputs disagree.
if not state_space_spec.is_compatible_with(model_representations):
    raise ValueError(
        f"State space spec {state_space_specs_path!r} is not compatible with "
        f"representations from {embeddings_path!r} "
        f"(hidden states: {hidden_states_path!r})")

In [94]:
# Build state trajectories from the representations, then reduce them with
# the first configured aggregation. Nested so the un-aggregated trajectory is
# not kept alive in a notebook variable.
trajectory = aggregate_state_trajectory(
    prepare_state_trajectory(model_representations, state_space_spec),
    state_space_spec,
    agg_fns[0],
)

  0%|          | 0/32046 [00:00<?, ?it/s]

Aggregating:   0%|          | 0/32046 [00:00<?, ?label/s]

In [95]:
# `agg`: flattened aggregated vectors; `agg_src`: for each row, its
# (label_idx, instance_idx, frame_idx) source coordinates.
flattened = flatten_trajectory(trajectory)
agg, agg_src = flattened

## Prepare metadata

In [101]:
# Cuts at the configured unit level; onset/offset frame columns aren't needed.
cuts_df = (
    state_space_spec.cuts
    .xs(config.unit_level, level="level")
    .drop(columns=["onset_frame_idx", "offset_frame_idx"])
)
# Integer position of each cut's label within state_space_spec.labels.
cuts_df["label_idx"] = cuts_df.index.get_level_values("label").map(
    {label: pos for pos, label in enumerate(state_space_spec.labels)})
# Ordinal of each cut within its (label, instance).
# NOTE(review): assumes rows are already in temporal order within an instance.
cuts_df["frame_idx"] = cuts_df.groupby(["label", "instance_idx"]).cumcount()
cuts_df = (
    cuts_df
    .reset_index()
    .set_index(["label_idx", "instance_idx", "frame_idx"])
    .sort_index()
)

# Row number in the flattened trajectory arrays, indexed by source coordinates.
traj_flat_idxs = pd.Series(
    list(range(len(agg_src))),
    index=pd.MultiIndex.from_tuples(
        [tuple(coords) for coords in agg_src],
        names=["label_idx", "instance_idx", "frame_idx"]),
    name="traj_flat_idx",
)
# Inner join on the index: cuts with no trajectory row are dropped.
cuts_df = cuts_df.merge(traj_flat_idxs, left_index=True, right_index=True)

In [102]:
# Label -> integer position in state_space_spec.labels.
label2idx = dict(zip(state_space_spec.labels, itertools.count()))

In [103]:
# Descriptions may be stored as tuples of symbols; collapse each such tuple
# into one string. Checked per element (not just row 0) so empty frames and
# mixed tuple/str columns are handled.
cuts_df["description"] = cuts_df.description.map(
    lambda desc: "".join(desc) if isinstance(desc, tuple) else desc)
# Space-separated form of each (label, instance), one token per cut.
cut_forms = cuts_df.groupby(["label", "instance_idx"]).description.agg(" ".join)

In [None]:
# WorldLex US-English word frequencies, combined across blog, Twitter, and news.
# keep_default_na=False: pandas would otherwise parse real words such as "null"
# and "nan" as missing values, corrupting the word index. Empty fields are
# still read as NaN via na_values=[""].
word_freq_df = pd.read_csv("data/WorldLex_Eng_US.Freq.2.txt", sep="\t", index_col="Word",
                           keep_default_na=False, na_values=[""])
word_freq_df = word_freq_df.loc[~word_freq_df.index.duplicated()].copy()

freq_cols = ["BlogFreq", "TwitterFreq", "NewsFreq"]
# Default NA strings are no longer recognized, so coerce any non-numeric
# placeholders in the frequency columns to NaN explicitly.
word_freq_df[freq_cols] = word_freq_df[freq_cols].apply(pd.to_numeric, errors="coerce")

# Per-domain relative frequencies (each column sums to 1).
for col in freq_cols:
    word_freq_df[f"{col}_rel"] = word_freq_df[col] / word_freq_df[col].sum()
# Mean relative frequency across domains, rescaled by the mean column total so
# Freq is on a count-like scale.
word_freq_df["Freq"] = word_freq_df[[f"{col}_rel" for col in freq_cols]].mean(axis=1) \
    * word_freq_df[freq_cols].sum().mean()
# NOTE: words with Freq == 0 get LogFreq == -inf.
word_freq_df["LogFreq"] = np.log10(word_freq_df.Freq)

## Prepare inputs

In [104]:
# One row per experimental instance. Missing base_phones (presumably an empty
# prefix) are read as NaN; normalize them to "".
all_instances_df = pd.read_csv(instances_path).fillna({"base_phones": ""})

In [106]:
# Every distinct post-divergence unit seen across the instances.
next_unit_set = {unit for unit in all_instances_df.post_divergence}

In [107]:
# Map each base prefix to the set of post-divergence units attested after it.
# A groupby avoids the per-row overhead of iterrows. sort=False keeps the
# first-occurrence order of prefixes, and each set is filled in row order,
# matching the previous row-by-row loop.
expt_cohort = {
    base_phones: set(units)
    for base_phones, units in all_instances_df.groupby("base_phones", sort=False)["post_divergence"]
}

In [105]:
# Prediction equivalences: per inflected form, a set of evaluations that can be
# applied to any single prediction trial to decide which outputs are "correct".
all_prediction_equivalences = {}

grouping_cols = ["base_phones", "inflected_phones", "inflection"]
for group_key, _ in tqdm(all_instances_df.groupby(grouping_cols)):
    # NOTE(review): `next_unit` is the value of the "inflection" column.
    base_phones, inflected_phones, next_unit = group_key
    # Keyed by inflected form alone: if two groups share inflected_phones, the
    # later one (in sorted group order) silently replaces the earlier entry.
    all_prediction_equivalences[(inflected_phones,)] = \
        analogy_pseudocausal.prepare_prediction_equivalences(
            cuts_df, cut_forms, base_phones, next_unit)

  0%|          | 0/92 [00:00<?, ?it/s]

## Generate trials

In [108]:
# One candidate completion per (cohort prefix, attested next unit) pair.
counterfactual_inflections = [
    dict(base_phones=prefix,
         inflected_phones=f"{prefix} {next_unit}".strip(),
         counterfactual_inflection=next_unit,
         post_divergence=next_unit)
    for prefix, attested_units in expt_cohort.items()
    for next_unit in attested_units
]

In [109]:
# Pair every real instance with every counterfactual completion of its base
# prefix. Keep only completions that differ from the instance's observed
# inflection, and label them "ctf-<unit>".
ctf_trials = pd.merge(
    all_instances_df[["base_phones", "inflection", "cohort_length", "next_phoneme_idx",
                      "inflected", "inflected_idx", "inflected_instance_idx"]],
    pd.DataFrame(counterfactual_inflections),
    # Explicit key (previously inferred as the only shared column), so a future
    # schema change can't silently alter the join.
    on="base_phones")
# .copy(): the boolean filter may return a view-like frame; assigning a column
# to it would trigger SettingWithCopyWarning.
ctf_trials = ctf_trials[ctf_trials.inflection != ctf_trials.counterfactual_inflection].copy()
ctf_trials["inflection"] = "ctf-" + ctf_trials.counterfactual_inflection
ctf_trials = ctf_trials.drop(columns=["counterfactual_inflection"])

In [110]:
# Counterfactual trials followed by the real instances. Indexes are not reset,
# so the combined frame may contain duplicate index labels.
all_trials = pd.concat(objs=[ctf_trials, all_instances_df], axis=0)

## Behavioral tests

In [112]:
# Ground-truth experiments: each source inflection is used to try to reach the
# next phoneme actually observed after the target prefix in word tokens.
gt_experiments = {
    f"gt-{source}_{prefix}_{target}": {
        "base_query": f"inflection == '{source}'",
        "inflected_query": f"base_phones == '{prefix}' and inflection == '{target}'",
        "equivalence_keys": ["inflected_phones", "inflected"],
        "prediction_equivalence_keys": ["to_inflected_phones"],
    }
    for source, (prefix, targets) in itertools.product(next_unit_set, expt_cohort.items())
    for target in targets
}

In [113]:
# Counterfactual experiments: each source inflection is used to try to produce
# a next phoneme that is NOT the one observed in the word token, but that still
# continues an attested word prefix in the lexicon (trials tagged "ctf-<unit>").
ctf_experiments = {
    f"ctf-{source}_{prefix}_{target}": {
        "base_query": f"inflection == '{source}'",
        "inflected_query": f"base_phones == '{prefix}' and inflection == 'ctf-{target}'",
        "equivalence_keys": ["inflected_phones", "inflected"],
        "prediction_equivalence_keys": ["to_inflected_phones"],
    }
    for source, (prefix, targets) in itertools.product(next_unit_set, expt_cohort.items())
    for target in targets
}

In [114]:
# Experiments to run. The ground-truth set is currently disabled; use
# {**gt_experiments, **ctf_experiments} to include it again.
experiments = dict(ctf_experiments)

In [115]:
# TODO reinstate this
# NOTE(review): the source-unit set is now called `next_unit_set` (formerly
# `next_phon_set`); the commented code below has been updated to match.
# small_targets = all_instances_df[all_instances_df.inflection.str.startswith("small-")].inflection.str.split("small-").str[1].unique()
# for phone in small_targets:
#     for source_phone in next_unit_set:
#         experiments[f"{source_phone}-to-small-{phone}"] = {
#             "base_query": f"inflection == '{source_phone}'",
#             "inflected_query": f"inflection == 'small-{phone}'",
#             "equivalence_keys": ["inflected_phones", "inflected"],
#             "prediction_equivalence_keys": ["to_inflected_phones"],
#         }

In [None]:
# name = "ctf-AH__SEH"
# ret = analogy_pseudocausal.run_experiment_equiv_level(
#     name,
#     ctf_experiments[name],
#     state_space_spec, all_trials,
#     agg, agg_src,
#     cut_phonemic_forms=cut_forms,
#     prediction_equivalences=all_prediction_equivalences,
#     verbose=True,
#     num_samples=5,
#     max_num_vector_samples=100,
#     seed=seed,
#     device="cpu")

ctf-AH__SEH


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

 ('AH', "other's") -> ('SEH', 'involved') invited IHN VAY TIHD 0.3593381481586116
 ('AH', "other's") -> ('SEH', 'involved') involuntarily IHN VAA LAHN TER AH LIY 0.3602719551589702
 ('AH', "other's") -> ('SEH', 'involved') unavailing AH NAH VEY LIHNG 0.36736555376849994
 ('AH', "other's") -> ('SEH', 'involved') understood AHN DER STUHD 0.37438648830618393
 ('AH', "other's") -> ('SEH', 'involved') involve IHN VAALV 0.37769021418356064
 ('AH', "other's") -> ('SEH', 'lincoln') england IHNG GLAHND 0.3784921365883768
 ('AH', "other's") -> ('SEH', 'lincoln') olenin AA LIH NIHN 0.3803331381943457
 ('AH', "other's") -> ('SEH', 'lincoln') olenin AA LIH NIHN 0.38131541685301873
 ('AH', "other's") -> ('SEH', 'lincoln') english IHNG GLIHSH 0.38221096715512803
 ('AH', "other's") -> ('SEH', 'lincoln') enquired IHN KWAYRD 0.38841173498850334
 ('AH', "another's") -> ('SEH', 'solution') conversation KAAN VER SEY SHAHN 0.4585924603634903
 ('AH', "another's") -> ('SEH', 'solution') affection AH FEHK SHAH

In [None]:
# Put the flattened representations and their source coordinates on the
# fastest available device.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
t_agg, t_agg_src = (torch.tensor(array, device=device) for array in (agg, agg_src))

# Reverse index: (label_idx, instance_idx, frame_idx) -> row of `agg`, built
# once and shared by every experiment.
flat_idx_lookup = {
    (label_idx, instance_idx, frame_idx): row
    for row, (label_idx, instance_idx, frame_idx) in enumerate(agg_src)
}

# Run every experiment and stack the per-experiment result frames under an
# outer "experiment" index level. The loop names avoid reusing the
# notebook-level `experiment` / `config` names.
experiment_results = pd.concat(
    {
        expt_name: analogy_pseudocausal.run_experiment_equiv_level(
            expt_name, expt_spec,
            state_space_spec, all_trials,
            t_agg, t_agg_src,
            flat_idx_lookup=flat_idx_lookup,
            cut_phonemic_forms=cut_forms,
            prediction_equivalences=all_prediction_equivalences,
            num_samples=250,
            max_num_vector_samples=100,
            seed=seed,
            device=device,
        )
        for expt_name, expt_spec in tqdm(experiments.items(), unit="experiment")
    },
    names=["experiment"],
)

In [None]:
target_inflections = experiment_results["inflection_to"]
# control: the source inflection differs from the target's unit (the text after
# the last "-" of inflection_to).
experiment_results["control"] = \
    target_inflections.str.split("-").str[-1] != experiment_results["inflection_from"]
# ctf: the target is a counterfactual completion.
experiment_results["ctf"] = target_inflections.str.startswith("ctf-")
# Target unit with a leading lowercase "<tag>-" prefix removed.
experiment_results["inflection_to_clean"] = target_inflections.str.replace("^[a-z]+-", "", regex=True)
experiment_results.to_csv(f"{output_dir}/experiment_results.csv")

In [None]:
metric = "matches_next_phoneme_weak_target_rank"
# Mean `metric` per (prefix, target unit) over non-control counterfactual trials.
(
    experiment_results
    .query("not control and ctf")
    .groupby(["ctf", "to_base_phones", "inflection_to_clean"])[metric]
    .mean()
    .sort_values()
)

In [None]:
metric = "matches_next_phoneme_weak_target_rank"
advantage_group_cols = ["ctf", "to_base_phones", "inflection_to_clean"]
# Per (prefix, target unit): mean metric over non-control trials minus mean over
# control trials. Groups present in only one condition come out as NaN.
experimental_mean, control_mean = (
    experiment_results.query(condition).groupby(advantage_group_cols)[metric].mean()
    for condition in ("not control and ctf", "control and ctf")
)
advantage_df = experimental_mean - control_mean
advantage_df.sort_values()

In [None]:
# Mean `metric` by condition (ctf / control) and target unit.
experiment_results.groupby(by=["ctf", "control", "inflection_to_clean"])[metric].agg("mean")

In [None]:
# Number of counterfactual result rows per (source, target) inflection pair.
ctf_pair_cols = ["inflection_from", "inflection_to"]
experiment_results.query("ctf").loc[:, ctf_pair_cols].value_counts()

In [None]:
# Mean target rank for each (source inflection, counterfactual target) pair.
advantage_mat = experiment_results.query("ctf").pivot_table(
    index=["inflection_from"],
    columns=["inflection_to"],
    values="matches_next_phoneme_weak_target_rank",
)

# Normalize each target column by its "matched" cell, i.e. the row whose source
# inflection equals the target unit with its "<tag>-" prefix stripped (same
# regex as inflection_to_clean). Matching by label instead of np.diag keeps
# this correct when rows and columns differ in count or order. Columns with no
# matching source row become NaN instead of being divided by an unrelated cell.
matched_ranks = {}
for target_col in advantage_mat.columns:
    target_unit = re.sub(r"^[a-z]+-", "", target_col)
    matched_ranks[target_col] = (
        advantage_mat.at[target_unit, target_col]
        if target_unit in advantage_mat.index else np.nan)
matched_ranks = pd.Series(matched_ranks, dtype=float)

sns.heatmap(advantage_mat.div(matched_ranks, axis="columns"), cmap="coolwarm", center=1)