A trial has four defining factors. Note that a trial is now more than a single item: it is the conjunction of two items (a source and a target) and the actual target we are trying to reach:

- source inflected phones (e.g. P ER P)
- target base phones (e.g. D AW)
- target GT next phone (e.g. D)
- target actual desired next phone (e.g. B)

from boolean statements relating these values we can derive critical conditions:

- **Control**: source next phone != target next phone. Tests how reachable the target is from a source whose next phone differs from it.
- **Weak experiment**: Can we reach the GT next phone? True when target GT == target actual
- **Strong experiment**: Can we reach non-GT next phones which are attested in the lexicon? True when target GT != target actual

In [None]:
# Reload edited project modules before each cell runs, and enable the
# line profiler (%lprun) for ad-hoc profiling.
%load_ext autoreload
%autoreload 2
%load_ext line_profiler

In [None]:
from collections import defaultdict
import itertools
import re

import matplotlib.pyplot as plt
import numpy as np
from omegaconf import OmegaConf
import pandas as pd
import seaborn as sns
import torch
from tqdm.auto import tqdm

from src.analysis import analogy, analogy_pseudocausal
from src.analysis.state_space import StateSpaceAnalysisSpec, \
    prepare_state_trajectory, aggregate_state_trajectory, flatten_trajectory
from src.datasets.speech_equivalence import SpeechHiddenStateDataset


In [None]:
# Cap the number of CPU threads torch uses for intra-op parallelism.
torch.set_num_threads(8)

In [None]:
# Run parameters: model, dataset, experiment and derived input/output paths.
base_model = "w2v2_pc_8"

model_class = "ffff_32-pc-mAP1"  # alternative: "discrim-rnn_32-pc-mAP1"
model_name = "word_broad_10frames_fixedlen25"

train_dataset = "librispeech-train-clean-100"
dataset = train_dataset
experiment = "phoneme_at_1"

# hidden_states_path = f"outputs/hidden_states/{base_model}/{train_dataset}.h5"
hidden_states_path = f"/scratch/jgauthier/{base_model}_{train_dataset}.h5"
# Set to the sentinel "ID" to analyze the raw hidden states instead of
# precomputed model embeddings.
embeddings_path = f"outputs/model_embeddings/{train_dataset}/{base_model}/{model_class}/{model_name}/{dataset}.npy"

inputs_dir = f"outputs/analogy_pseudocausal_broad/inputs/{dataset}/w2v2_pc/{experiment}"
instances_path = f"{inputs_dir}/instances.csv"
state_space_specs_path = f"{inputs_dir}/state_space_spec.h5"

output_dir = "."

# NOTE(review): pos_counts_path and metric do not appear to be used in this
# notebook — confirm before removing.
pos_counts_path = "data/pos_counts.pkl"

seed = 42

metric = "cosine"

In [None]:
# load OmegaConf from yaml with `experiment`
# Experiment-level settings (e.g. `unit_level`, used below) loaded from the
# YAML file named after `experiment`.
config = OmegaConf.load(f"conf/experiments/analogy_pseudocausal/{experiment}.yaml")

In [None]:
# Aggregation spec(s) as (aggregation name, unit level); only the first entry
# is used below. Presumably "mean_within_cut" averages frames within each cut
# at `config.unit_level` — confirm in aggregate_state_trajectory.
agg_fns = [
    ("mean_within_cut", config.unit_level)
]

## Prepare model representations

In [None]:
# Pick the representations to analyze: precomputed model embeddings from
# disk, or (when embeddings_path is the "ID" sentinel) the raw hidden states.
if embeddings_path != "ID":
    with open(embeddings_path, "rb") as f:
        model_representations: np.ndarray = np.load(f)
else:
    model_representations = SpeechHiddenStateDataset.from_hdf5(hidden_states_path).states

state_space_spec = StateSpaceAnalysisSpec.from_hdf5(state_space_specs_path)
# Fail fast if the state-space spec does not fit these representations.
assert state_space_spec.is_compatible_with(model_representations)

In [None]:
# Build the state trajectory from the representations, then aggregate it with
# the (single) configured aggregation.
trajectory = aggregate_state_trajectory(
    prepare_state_trajectory(model_representations, state_space_spec),
    state_space_spec,
    agg_fns[0],
)

In [None]:
# Flatten the trajectory. `agg_src` gives, for each flattened row, its
# (label_idx, instance_idx, frame_idx) source; `agg` presumably holds the
# corresponding aggregated vectors — confirm in flatten_trajectory.
agg, agg_src = flatten_trajectory(trajectory)

## Prepare metadata

In [None]:
# Cut metadata at the configured unit level; the frame-boundary columns are
# not needed below.
cuts_df = state_space_spec.cuts.xs(config.unit_level, level="level").drop(columns=["onset_frame_idx", "offset_frame_idx"])
# Position of each cut's label within state_space_spec.labels.
cuts_df["label_idx"] = cuts_df.index.get_level_values("label").map({l: i for i, l in enumerate(state_space_spec.labels)})
# Ordinal of each unit within its (label, instance_idx). Assumes rows are
# already in temporal order within each group — TODO confirm.
cuts_df["frame_idx"] = cuts_df.groupby(["label", "instance_idx"]).cumcount()
cuts_df = cuts_df.reset_index().set_index(["label_idx", "instance_idx", "frame_idx"]).sort_index()

# Row position in the flattened trajectory of every
# (label_idx, instance_idx, frame_idx) entry of agg_src.
agg_flat_idxs = pd.Series(list(range(len(agg_src))),
                          index=pd.MultiIndex.from_tuples([tuple(xs) for xs in agg_src],
                                                          names=["label_idx", "instance_idx", "frame_idx"]))
# Inner join on the three-level index: keeps only cuts that have a trajectory
# row, and records that row as `traj_flat_idx`.
cuts_df = pd.merge(cuts_df, agg_flat_idxs.rename("traj_flat_idx"), left_index=True, right_index=True)

In [None]:
# Reverse lookup: label string -> its position in state_space_spec.labels.
label2idx = {label: position for position, label in enumerate(state_space_spec.labels)}

In [None]:
# Descriptions may be stored as tuples of string pieces; collapse each into a
# single string so the descriptions can be space-joined below. isinstance
# (rather than an exact type comparison) also accepts tuple subclasses, and
# the emptiness guard avoids an IndexError on `.iloc[0]`.
if not cuts_df.empty and isinstance(cuts_df.description.iloc[0], tuple):
    cuts_df["description"] = cuts_df.description.apply(''.join)
# Space-separated sequence of unit descriptions for each (label, instance_idx).
cut_forms = cuts_df.groupby(["label", "instance_idx"]).description.agg(' '.join)

In [None]:
word_freq_df = pd.read_csv("data/WorldLex_Eng_US.Freq.2.txt", sep="\t", index_col="Word")
# Keep only the first entry for words listed more than once.
word_freq_df = word_freq_df.loc[~word_freq_df.index.duplicated()]

# Domain-averaged frequency: convert each domain's raw counts to relative
# frequencies, average those per word (every domain weighted equally), then
# rescale by the mean domain total so the result is on a raw-count scale.
freq_domains = ["BlogFreq", "TwitterFreq", "NewsFreq"]
for domain in freq_domains:
    word_freq_df[domain + "_rel"] = word_freq_df[domain] / word_freq_df[domain].sum()
word_freq_df["Freq"] = (
    word_freq_df[[domain + "_rel" for domain in freq_domains]].mean(axis=1)
    * word_freq_df[freq_domains].sum().mean()
)
word_freq_df["LogFreq"] = np.log10(word_freq_df.Freq)

## Prepare inputs

In [None]:
# Trial instances. Missing base_phones values (read by pandas as NaN) are
# restored as empty strings.
all_instances_df = pd.read_csv(instances_path).assign(
    base_phones=lambda df: df["base_phones"].fillna(""),
)

In [None]:
# All distinct post-divergence units across instances; used below as the set
# of source inflections.
next_unit_set = set(all_instances_df.post_divergence)

In [None]:
# Study cohorts: among base-phone cohorts with at least 5 distinct
# inflections, keep the 20 with the most instances.
study_cohorts = (
    all_instances_df.groupby("base_phones").inflection.value_counts()
    .groupby("base_phones").filter(lambda counts: len(counts) >= 5)
    .groupby("base_phones").sum()
    .sort_values()
    .tail(20)
)

In [None]:
# Attested next units (post_divergence values) for each study cohort, keyed by
# base phones in order of first appearance. Iterating groups instead of
# calling iterrows avoids building a Series for every row; sort=False keeps the
# same first-appearance key order the row-by-row loop produced.
expt_cohort = {
    base_phones: set(next_units)
    for base_phones, next_units in all_instances_df[
        all_instances_df.base_phones.isin(study_cohorts.index)
    ].groupby("base_phones", sort=False)["post_divergence"]
}

In [None]:
# Number of leading units that make up a cohort (the shared base prefix); the
# unit at index `cohort_length` is the next unit to be predicted.
# NOTE(review): all_instances_df also carries a per-row `cohort_length` column;
# this global constant is used instead — confirm the two agree.
# (The prediction-equivalence table is built in a later cell; an unused
# initialization that was overwritten before any read has been removed.)
cohort_length = 2

In [None]:
# Flat copy of cuts_df with label_idx / instance_idx / frame_idx as columns,
# kept in the index-sorted row order of cuts_df.
cdf = cuts_df.reset_index()

In [None]:
# Cohort of each (label_idx, instance_idx) token: its first `cohort_length`
# unit descriptions, space-joined and broadcast to all of the token's rows;
# None for tokens that are too short. Rows within a group follow cuts_df's
# sorted (frame_idx) order, so positional `.iloc` slicing picks the leading
# units. `.iloc` and `cohort_length` are used instead of a hard-coded `[:2]`
# so that changing `cohort_length` changes the cohort consistently with the
# next-unit computation.
# NOTE(review): this requires more than cohort_length + 1 units, whereas the
# next-unit and inflected-form checks require only more than cohort_length —
# confirm the stricter length requirement is intended.
cdf["base_phones"] = cdf.groupby(["label_idx", "instance_idx"]).description.transform(
    lambda units: " ".join(units.iloc[:cohort_length]) if len(units) > cohort_length + 1 else None
)

In [None]:
# Next unit of each (label_idx, instance_idx) token: the description at
# position `cohort_length`, broadcast to all of the token's rows; None for
# tokens with no unit at that position.
cdf["next_unit"] = cdf.groupby(["label_idx", "instance_idx"]).description.transform(lambda xs: xs.iloc[cohort_length] if len(xs) > cohort_length else None)

In [None]:
# Distinct inflected forms, as whitespace-normalized tuples of phones.
infl_phones = {tuple(inflected_phones.strip().split()) for inflected_phones in all_instances_df["inflected_phones"].unique()}

# Prediction equivalences per inflected form, keyed by a 1-tuple holding the
# space-joined phones. Values are sets of flattened-trajectory row indices:
#   matches_next_phoneme:      rows at frame_idx == cohort_length of tokens
#                              whose next unit equals the form's next phone
#   matches_next_phoneme_weak: all rows of such tokens
#   matches_cohort:            rows at frame_idx >= cohort_length of tokens
#                              whose cohort equals the form's first
#                              `cohort_length` phones
all_prediction_equivalences = {
    (" ".join(inflected_phones),): {
        "matches_next_phoneme": set(),
        "matches_next_phoneme_weak": set(),
        "matches_cohort": set(),
    }
    for inflected_phones in infl_phones
}

# Index every form long enough to have a next phone by that phone and by its
# cohort prefix, so each group of `cdf` is matched by lookup instead of by
# scanning all forms.
keys_by_next_unit = defaultdict(list)
keys_by_cohort = defaultdict(list)
for inflected_phones in infl_phones:
    if len(inflected_phones) > cohort_length:
        equiv_key = (" ".join(inflected_phones),)
        keys_by_next_unit[inflected_phones[cohort_length]].append(equiv_key)
        keys_by_cohort[" ".join(inflected_phones[:cohort_length])].append(equiv_key)

for next_unit, rows in tqdm(cdf.groupby("next_unit")):
    matching_keys = keys_by_next_unit.get(next_unit, [])
    if not matching_keys:
        continue
    # Same for every matching form: build once per group.
    weak_idxs = set(rows.traj_flat_idx)
    strong_idxs = set(rows[rows.frame_idx == cohort_length].traj_flat_idx)
    for equiv_key in matching_keys:
        all_prediction_equivalences[equiv_key]["matches_next_phoneme_weak"] |= weak_idxs
        all_prediction_equivalences[equiv_key]["matches_next_phoneme"] |= strong_idxs
for cohort, rows in tqdm(cdf.groupby("base_phones")):
    matching_keys = keys_by_cohort.get(cohort, [])
    if not matching_keys:
        continue
    cohort_idxs = set(rows[rows.frame_idx >= cohort_length].traj_flat_idx)
    for equiv_key in matching_keys:
        all_prediction_equivalences[equiv_key]["matches_cohort"] |= cohort_idxs

In [None]:
# Joint criteria: a row must lie in the form's cohort set AND in the
# corresponding next-phoneme set (strict or weak).
for (inflected_phones,), equivs in all_prediction_equivalences.items():
    in_cohort = equivs["matches_cohort"]
    equivs["matches_cohort_and_next_phoneme"] = in_cohort & equivs["matches_next_phoneme"]
    equivs["matches_cohort_and_next_phoneme_weak"] = in_cohort & equivs["matches_next_phoneme_weak"]

In [None]:
# Materialize every equivalence set as a 1-D index tensor. The dtype is pinned
# to int64: torch.tensor infers int64 for non-empty lists of Python ints but
# falls back to the default float dtype for an empty list, which would give
# empty sets an inconsistent (non-index) dtype.
all_prediction_equivalences = {
    key: {name: torch.tensor(list(idxs), dtype=torch.long) for name, idxs in equivs.items()}
    for key, equivs in all_prediction_equivalences.items()
}

In [None]:
# cdf = cuts_df.reset_index()
# for inflected_phones, rows in tqdm(all_instances_df.groupby("inflected_phones")):
#     phoneme_match_instances = all_instances_df[all_instances_df.post_divergence == rows.post_divergence.iloc[0]]
#     cohort_match_instances = all_instances_df[all_instances_df.base_phones == rows.base_phones.iloc[0]]

#     phoneme_match_cuts = cdf.merge(phoneme_match_instances[["inflected_idx", "inflected_instance_idx"]],
#                         left_on=["label_idx", "instance_idx"],
#                         right_on=["inflected_idx", "inflected_instance_idx"])
#     cohort_match_cuts = cdf.merge(cohort_match_instances[["inflected_idx", "inflected_instance_idx"]],
#                         left_on=["label_idx", "instance_idx"],
#                         right_on=["inflected_idx", "inflected_instance_idx"]) \
#         .query("frame_idx >= @cohort_length")

#     strong_phoneme_match_cuts = phoneme_match_cuts[phoneme_match_cuts.frame_idx == cohort_length]
#     weak_phoneme_match_cuts = phoneme_match_cuts

#     all_prediction_equivalences[(inflected_phones,)]['matches_next_phoneme'].update(
#         strong_phoneme_match_cuts.traj_flat_idx)
#     all_prediction_equivalences[(inflected_phones,)]['matches_next_phoneme_weak'].update(
#         weak_phoneme_match_cuts.traj_flat_idx)
#     all_prediction_equivalences[(inflected_phones,)]['matches_cohort'].update(
#         cohort_match_cuts.traj_flat_idx)

In [None]:
# # Prepare prediction equivalences: effectively a set of evaluations which 
# # can be run on any individual prediction trial, establishing which outputs
# # are "correct" or incorrect
# all_prediction_equivalences = {}

# for (base_phones, inflected_phones, next_unit), _ in tqdm(all_instances_df.groupby(["base_phones", "inflected_phones", "inflection"])):
#     equiv_key = (inflected_phones,)
#     all_prediction_equivalences[equiv_key] = \
#         analogy_pseudocausal.prepare_prediction_equivalences(cuts_df, cut_forms, base_phones, next_unit)

## Generate trials

In [None]:
# One counterfactual inflection per (cohort, attested next unit) pair: the
# cohort phones extended by that unit.
counterfactual_inflections = [
    {
        "base_phones": base,
        "inflected_phones": f"{base} {continuation}".strip(),
        "counterfactual_inflection": continuation,
        "post_divergence": continuation,
    }
    for base, continuations in expt_cohort.items()
    for continuation in continuations
]

In [None]:
# Pair every instance with every counterfactual continuation of its cohort.
# The join key is explicit; pd.merge would otherwise silently join on all
# shared column names, which changes if either side gains a column. Pairings
# that reproduce the instance's own inflection are dropped, and the rest are
# relabeled "ctf-<unit>".
ctf_trials = pd.merge(
    all_instances_df[["base_phones", "inflection", "cohort_length", "next_phoneme_idx", "inflected", "inflected_idx", "inflected_instance_idx"]],
    pd.DataFrame(counterfactual_inflections),
    on="base_phones",
)
ctf_trials = ctf_trials[ctf_trials.inflection != ctf_trials.counterfactual_inflection]
ctf_trials = ctf_trials.assign(
    inflection="ctf-" + ctf_trials.counterfactual_inflection,
).drop(columns=["counterfactual_inflection"])

In [None]:
# Combined trial table: counterfactual trials followed by the original instances.
# NOTE(review): index labels are not reset, so they may repeat across the two
# parts; columns present in only one part are filled with NaN in the other.
all_trials = pd.concat([ctf_trials, all_instances_df])

## Behavioral tests

In [None]:
# Ground-truth experiments: each source inflection is used to try to predict
# every attested (cohort, next phone) completion observed in word tokens.
gt_experiments = {
    f"gt-{source_inflection}_{prefix}_{target_inflection}": {
        "base_query": f"inflection == '{source_inflection}'",
        "inflected_query": f"base_phones == '{prefix}' and inflection == '{target_inflection}'",
        "equivalence_keys": ["inflected_phones", "inflected"],
        "prediction_equivalence_keys": ["to_inflected_phones"],
    }
    for source_inflection, (prefix, valid_next_phones) in itertools.product(
        next_unit_set, expt_cohort.items()
    )
    for target_inflection in valid_next_phones
}

In [None]:
# Counterfactual experiments: each source inflection is used to try to
# produce phone completions that are NOT the ground-truth next phone of the
# word tokens, but that still extend a word prefix attested in the lexicon.
ctf_experiments = {
    f"ctf-{source_inflection}_{prefix}_{target_inflection}": {
        "base_query": f"inflection == '{source_inflection}'",
        "inflected_query": f"base_phones == '{prefix}' and inflection == 'ctf-{target_inflection}'",
        "equivalence_keys": ["inflected_phones", "inflected"],
        "prediction_equivalence_keys": ["to_inflected_phones"],
    }
    for source_inflection, (prefix, valid_next_phones) in itertools.product(
        next_unit_set, expt_cohort.items()
    )
    for target_inflection in valid_next_phones
}

In [None]:
# Experiment families to run; ground-truth experiments are currently disabled
# (uncomment the line below to include them).
experiments = {
    # **gt_experiments,
    **ctf_experiments,
}

In [None]:
# TODO reinstate this
# NOTE(review): `next_phon_set` below no longer exists; the current equivalent
# appears to be `next_unit_set` — update it when reinstating.
# small_targets = all_instances_df[all_instances_df.inflection.str.startswith("small-")].inflection.str.split("small-").str[1].unique()
# for phone in small_targets:
#     for source_phone in next_phon_set:
#         experiments[f"{source_phone}-to-small-{phone}"] = {
#             "base_query": f"inflection == '{source_phone}'",
#             "inflected_query": f"inflection == 'small-{phone}'",
#             "equivalence_keys": ["inflected_phones", "inflected"],
#             "prediction_equivalence_keys": ["to_inflected_phones"],
#         }

In [None]:
# Display the counterfactual experiment specifications for inspection.
ctf_experiments

In [None]:
# def go():
#     name = "ctf-DH_AH N_M"
#     ret = analogy_pseudocausal.run_experiment_equiv_level(
#         name,
#         experiments[name],
#         state_space_spec, all_trials,
#         agg, agg_src,
#         cut_phonemic_forms=cut_forms,
#         prediction_equivalences=all_prediction_equivalences,
#         verbose=True,
#         num_samples=5,
#         max_num_vector_samples=100,
#         seed=seed,
#     device="cuda:1")

In [None]:
# %load_ext line_profiler
# %lprun -f analogy_pseudocausal.run_experiment_equiv_level go()

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"
# Device copies of the flattened trajectory and of its per-row source triples.
t_agg = torch.tensor(agg, device=device)
t_agg_src = torch.tensor(agg_src, device=device)

# Reverse lookup: (label_idx, instance_idx, phoneme_idx) -> flattened row.
flat_idx_lookup = {
    (label_i, instance_i, unit_i): row
    for row, (label_i, instance_i, unit_i) in enumerate(agg_src)
}

# Run every selected experiment and stack the per-experiment result frames
# under an "experiment" index level. The comprehension variables are named so
# they do not read as the notebook-level `experiment` / `config` globals.
experiment_results = pd.concat(
    {
        experiment_name: analogy_pseudocausal.run_experiment_equiv_level(
            experiment_name, experiment_spec,
            state_space_spec, all_trials,
            t_agg, t_agg_src,
            flat_idx_lookup=flat_idx_lookup,
            cut_phonemic_forms=cut_forms,
            prediction_equivalences=all_prediction_equivalences,
            num_samples=50,
            max_num_vector_samples=100,
            seed=seed,
            device=device,
        )
        for experiment_name, experiment_spec in tqdm(experiments.items(), unit="experiment")
    },
    names=["experiment"],
)

In [None]:
# Annotate each result row, then save:
#   control:             last "-"-separated piece of the target inflection
#                        differs from the source inflection
#   ctf:                 the target is a counterfactual ("ctf-") inflection
#   inflection_to_clean: target inflection with a leading lowercase prefix removed
experiment_results = experiment_results.assign(
    control=experiment_results.inflection_to.str.split("-").str[-1] != experiment_results.inflection_from,
    ctf=experiment_results.inflection_to.str.startswith("ctf-"),
    inflection_to_clean=experiment_results.inflection_to.str.replace("^[a-z]+-", "", regex=True),
)
experiment_results.to_csv(f"{output_dir}/experiment_results.csv")