In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from collections import defaultdict
import itertools
import re

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from tqdm.auto import tqdm

from src.analysis import analogy, analogy_pseudocausal
from src.analysis.state_space import StateSpaceAnalysisSpec, \
    prepare_state_trajectory, aggregate_state_trajectory, flatten_trajectory
from src.datasets.speech_equivalence import SpeechHiddenStateDataset


In [None]:
# Cap the number of CPU threads PyTorch uses for intra-op parallelism in this kernel.
torch.set_num_threads(8)

In [None]:
# --- Model selection ---------------------------------------------------------
base_model = "w2v2_pc_8"

model_class = "ffff_32-pc-mAP1"  # previously: discrim-rnn_32-pc-mAP1
model_name = "word_broad_10frames_fixedlen25"

# --- Input locations ---------------------------------------------------------
train_dataset = "librispeech-train-clean-100"
# Disabled alternative location:
#   f"outputs/hidden_states/{base_model}/{train_dataset}.h5"
hidden_states_path = f"/scratch/jgauthier/{base_model}_{train_dataset}.h5"
state_space_specs_path = "/".join(
    ["outputs/analogy/inputs", train_dataset, "w2v2_pc", "state_space_spec.h5"]
)
# Setting this to the sentinel "ID" makes the loading cell use raw hidden states instead.
embeddings_path = "/".join(
    ["outputs/model_embeddings", train_dataset, base_model, model_class, model_name,
     f"{train_dataset}.npy"]
)
pos_counts_path = "data/pos_counts.pkl"

# --- Output location ---------------------------------------------------------
output_dir = "."

# --- Analysis options --------------------------------------------------------
seed = 42
metric = "cosine"

# Aggregation specs; only the first entry is consumed when building the trajectory.
agg_fns = [
    ("mean_within_cut", "phoneme"),
]

## Prepare model representations

In [None]:
# Load the representations to analyse. The sentinel path "ID" selects the states
# stored in the hidden-state HDF5 file; any other path is read as a saved NumPy array.
if embeddings_path != "ID":
    with open(embeddings_path, "rb") as f:
        model_representations: np.ndarray = np.load(f)
else:
    model_representations = SpeechHiddenStateDataset.from_hdf5(hidden_states_path).states

# The state-space spec must be compatible with the representations; fail fast otherwise.
state_space_spec = StateSpaceAnalysisSpec.from_hdf5(state_space_specs_path)
assert state_space_spec.is_compatible_with(model_representations)

In [None]:
# Build the state trajectory and aggregate it with the first configured
# aggregation spec, ("mean_within_cut", "phoneme").
agg_spec = agg_fns[0]
trajectory = aggregate_state_trajectory(
    prepare_state_trajectory(model_representations, state_space_spec),
    state_space_spec,
    agg_spec,
)

In [None]:
# Flatten the aggregated trajectory. Each `agg_src` row is later read as a
# (label_idx, instance_idx, frame_idx) triple; presumably it is row-aligned
# with `agg` — confirm in `flatten_trajectory`.
agg, agg_src = flatten_trajectory(trajectory)

## Prepare metadata

In [None]:
# Phoneme-level cuts, re-keyed by (label_idx, instance_idx, frame_idx) so they can
# be joined against the flattened trajectory rows.
_cut_index_names = ["label_idx", "instance_idx", "frame_idx"]
_label_positions = {label: idx for idx, label in enumerate(state_space_spec.labels)}

cuts_df = (
    state_space_spec.cuts
    .xs("phoneme", level="level")
    .drop(columns=["onset_frame_idx", "offset_frame_idx"])
)
cuts_df["label_idx"] = cuts_df.index.get_level_values("label").map(_label_positions)
# Position of each cut within its (label, instance) group.
cuts_df["frame_idx"] = cuts_df.groupby(["label", "instance_idx"]).cumcount()
cuts_df = cuts_df.reset_index().set_index(_cut_index_names).sort_index()

# Row number in the flattened trajectory for every (label_idx, instance_idx, frame_idx).
agg_flat_idxs = pd.Series(
    list(range(len(agg_src))),
    index=pd.MultiIndex.from_tuples(
        [tuple(src_row) for src_row in agg_src],
        names=_cut_index_names,
    ),
)
cuts_df = cuts_df.merge(
    agg_flat_idxs.rename("traj_flat_idx"),
    left_index=True,
    right_index=True,
)

In [None]:
# Map each state-space label to its position in `state_space_spec.labels`.
label2idx = {label: idx for idx, label in enumerate(state_space_spec.labels)}

In [None]:
# One space-joined string of cut descriptions per (label, instance_idx);
# presumably the descriptions are phoneme labels — confirm against the spec.
cut_phonemic_forms = (
    cuts_df
    .groupby(["label", "instance_idx"])["description"]
    .agg(" ".join)
)

In [None]:
# Word frequency table: keep the first row for each duplicated word.
_freq_domains = ["BlogFreq", "TwitterFreq", "NewsFreq"]
word_freq_df = pd.read_csv("data/WorldLex_Eng_US.Freq.2.txt", sep="\t", index_col="Word")
word_freq_df = word_freq_df.loc[~word_freq_df.index.duplicated()]

# Normalise each domain to relative frequencies, average them across domains,
# then rescale by the mean domain total so `Freq` is back on a count-like scale.
for _domain in _freq_domains:
    word_freq_df[f"{_domain}_rel"] = word_freq_df[_domain] / word_freq_df[_domain].sum()
_mean_rel_freq = word_freq_df[[f"{_domain}_rel" for _domain in _freq_domains]].mean(axis=1)
_mean_domain_total = word_freq_df[_freq_domains].sum().mean()
word_freq_df["Freq"] = _mean_rel_freq * _mean_domain_total
word_freq_df["LogFreq"] = np.log10(word_freq_df["Freq"])

## Prepare items

In [None]:
# Phones studied as the continuation that follows a cohort prefix.
next_phon_set = {"AH", "ER", "IH", "L", "S", "Z", "T", "D", "M", "N"}
# Number of phones in each cohort prefix.
target_cohort_length = 2
# A "small" cohort is a prefix followed by exactly this many distinct phones,
# all of them drawn from `next_phon_set`.
target_small_cohort_size = 3
assert target_small_cohort_size < len(next_phon_set)

In [None]:
# Map every phone prefix to the set of full phonemic forms that start with it.
cohorts = defaultdict(set)
for form in tqdm(cut_phonemic_forms.unique()):
    segments = tuple(form.split())
    for prefix_len in range(1, len(segments) + 1):
        cohorts[segments[:prefix_len]].add(segments)

# One row per (prefix, longer form): record the phone that immediately follows the prefix.
_next_phone_rows = []
for prefix, members in cohorts.items():
    for member in members:
        if len(member) > len(prefix):
            _next_phone_rows.append((" ".join(prefix), " ".join(member), member[len(prefix)]))
csz_next = pd.DataFrame(_next_phone_rows, columns=["cohort", "item", "next_phoneme"])

In [None]:
# Cohorts are space-joined phones, so a prefix of N phones contains N - 1 spaces.
_has_target_length = csz_next.cohort.str.count(" ") == target_cohort_length - 1
# Keep cohorts whose continuations cover the whole studied phone set, then list
# each cohort's sorted distinct next phones.
expt_cohort = (
    csz_next[_has_target_length]
    .groupby("cohort").filter(lambda grp: next_phon_set <= set(grp.next_phoneme))
    .groupby("cohort").apply(lambda grp: sorted(set(grp.next_phoneme)))
)
expt_cohort

In [None]:
# Now search for type-small cohorts -- cohorts which only have N of the phone set
expt_cohort_small = csz_next[csz_next.cohort.str.count(" ") == target_cohort_length - 1].groupby("cohort").filter(lambda xs: len(set(xs.next_phoneme)) == target_small_cohort_size and set(xs.next_phoneme) <= next_phon_set) \
    .groupby("cohort").apply(lambda xs: sorted(set(xs.next_phoneme)))
expt_cohort_small

### Prepare instance-level metadata

In [None]:
# Accumulators shared with the small-cohort cell that follows.
all_instances = []
all_prediction_equivalences = {}

# Sample at most this many combinations of cohort + next phone
max_items_per_cohort_and_next_phone = 15

label2idx = {label: idx for idx, label in enumerate(state_space_spec.labels)}
for cohort, next_phons in tqdm(expt_cohort.items(), total=len(expt_cohort)):
    # Only continuations from the studied phone set become experimental items.
    for phon in (candidate for candidate in next_phons if candidate in next_phon_set):
        inflected_phones = f"{cohort} {phon}"
        # `str.match` anchors at the start: forms beginning with cohort + phone,
        # with the phone ending on a word boundary.
        starts_with_item = cut_phonemic_forms.str.match(f"{inflected_phones}\\b")
        instances = cut_phonemic_forms[starts_with_item].index

        # Pick the top K labels with the highest frequency from the cohort
        # (a trailing "'s" is stripped before the frequency lookup).
        coh_labels = instances.get_level_values("label").str.replace("'s$", "", regex=True)
        if len(coh_labels) > max_items_per_cohort_and_next_phone:
            # Labels missing from the frequency table get the table's minimum log-frequency.
            floor_log_freq = word_freq_df.LogFreq.min()
            label_freqs = word_freq_df.reindex(coh_labels.unique()).LogFreq.fillna(floor_log_freq)
            keep_labels = label_freqs.nlargest(max_items_per_cohort_and_next_phone).index
            instances = instances[coh_labels.isin(keep_labels)]
            print(cohort, phon, len(instances))

        # Prediction equivalences are computed once per cohort + phone string.
        equiv_key = (inflected_phones,)
        if equiv_key not in all_prediction_equivalences:
            all_prediction_equivalences[equiv_key] = analogy_pseudocausal.prepare_prediction_equivalences(
                cuts_df, cut_phonemic_forms, cohort, phon)

        all_instances.extend(
            {
                "base_phones": cohort,
                "inflected_phones": inflected_phones,
                "post_divergence": phon,

                "inflection": phon,
                "next_phoneme_in_restricted_set": phon in next_phon_set,

                "cohort_length": target_cohort_length,
                "next_phoneme_idx": target_cohort_length,

                "inflected": label,
                "inflected_idx": label2idx[label],
                "inflected_instance_idx": instance_idx,
            }
            for label, instance_idx in instances
        )

In [None]:
# Same item construction for the small cohorts; instances are tagged "small-<phone>".
# NOTE: this appends to `all_instances`, so re-running only this cell duplicates rows.
for cohort, next_phons in tqdm(expt_cohort_small.items(), total=len(expt_cohort_small)):
    for phon in (candidate for candidate in next_phons if candidate in next_phon_set):
        inflected_phones = f"{cohort} {phon}"
        # `str.match` anchors at the start: forms beginning with cohort + phone,
        # with the phone ending on a word boundary.
        starts_with_item = cut_phonemic_forms.str.match(f"{inflected_phones}\\b")
        instances = cut_phonemic_forms[starts_with_item].index

        # Pick the top K labels with the highest frequency from the cohort
        # (a trailing "'s" is stripped before the frequency lookup).
        coh_labels = instances.get_level_values("label").str.replace("'s$", "", regex=True)
        if len(coh_labels) > max_items_per_cohort_and_next_phone:
            # Labels missing from the frequency table get the table's minimum log-frequency.
            floor_log_freq = word_freq_df.LogFreq.min()
            label_freqs = word_freq_df.reindex(coh_labels.unique()).LogFreq.fillna(floor_log_freq)
            keep_labels = label_freqs.nlargest(max_items_per_cohort_and_next_phone).index
            instances = instances[coh_labels.isin(keep_labels)]

        # Prediction equivalences are computed once per cohort + phone string.
        equiv_key = (inflected_phones,)
        if equiv_key not in all_prediction_equivalences:
            all_prediction_equivalences[equiv_key] = analogy_pseudocausal.prepare_prediction_equivalences(
                cuts_df, cut_phonemic_forms, cohort, phon)

        all_instances.extend(
            {
                "base_phones": cohort,
                "inflected_phones": inflected_phones,
                "post_divergence": phon,

                "inflection": f"small-{phon}",
                "next_phoneme_in_restricted_set": phon in next_phon_set,

                "cohort_length": target_cohort_length,
                "next_phoneme_idx": target_cohort_length,

                "inflected": label,
                "inflected_idx": label2idx[label],
                "inflected_instance_idx": instance_idx,
            }
            for label, instance_idx in instances
        )

In [None]:
# One row per sampled word instance, with the keys of the dicts accumulated in
# `all_instances` as columns.
all_instances_df = pd.DataFrame(all_instances)
all_instances_df

In [None]:
# Persist the instance table (including its integer index) as CSV under `output_dir`.
all_instances_df.to_csv(f"{output_dir}/pseudocausal_broad_instances.csv")

In [None]:
# Serialise the prediction-equivalence dict (keyed by 1-tuples of the
# cohort + phone string) next to the instance table.
torch.save(all_prediction_equivalences, f"{output_dir}/pseudocausal_broad_prediction_equivalences.pt")

In [None]:
# Number of distinct word labels per (cohort, next phone) pair, smallest first.
(
    all_instances_df
    .groupby(["base_phones", "post_divergence"])
    .apply(lambda grp: len(grp["inflected"].unique()))
    .sort_values()
)

## Behavioral tests

In [None]:
# One experiment per ordered (source phone, target phone) pair, including a == b.
# Iterate the phones in sorted order: `next_phon_set` is a set of strings, whose
# iteration order changes between interpreter sessions (string hash randomisation),
# which would otherwise reorder the experiments and the rows of the saved results.
experiments = {
    f"{a}_to_{b}": {
        "base_query": f"inflection == '{a}'",
        "inflected_query": f"inflection == '{b}'",
        "equivalence_keys": ["inflected_phones", "inflected"],
        "prediction_equivalence_keys": ["to_inflected_phones"],
    }
    for a, b in itertools.product(sorted(next_phon_set), repeat=2)
}

In [None]:
# Target phones that occur in at least one small-cohort instance ("small-<phone>").
small_targets = all_instances_df[all_instances_df.inflection.str.startswith("small-")].inflection.str.split("small-").str[1].unique()
for phone in small_targets:
    # Sorted so the experiment order does not depend on set iteration order, which
    # varies between interpreter sessions for string elements.
    for source_phone in sorted(next_phon_set):
        # The "-to-small-" infix is what later separates these from the main experiments.
        experiments[f"{source_phone}-to-small-{phone}"] = {
            "base_query": f"inflection == '{source_phone}'",
            "inflected_query": f"inflection == 'small-{phone}'",
            "equivalence_keys": ["inflected_phones", "inflected"],
            "prediction_equivalence_keys": ["to_inflected_phones"],
        }

In [None]:
# Run every configured experiment and stack the per-experiment frames under an
# outer "experiment" index level.
_results_by_experiment = {}
for experiment, config in tqdm(experiments.items(), unit="experiment"):
    _results_by_experiment[experiment] = analogy_pseudocausal.run_experiment_equiv_level(
        experiment, config,
        state_space_spec, all_instances_df,
        agg, agg_src,
        cut_phonemic_forms=cut_phonemic_forms,
        prediction_equivalences=all_prediction_equivalences,
        num_samples=1000,
        max_num_vector_samples=100,
        seed=seed,
        device="cuda:2")
experiment_results = pd.concat(_results_by_experiment, names=["experiment"])
# Drop the per-experiment frames once they are concatenated.
del _results_by_experiment

In [None]:
# A trial is a control when the target phone (the part after any "small-" prefix)
# differs from the source phone.
_target_phone = experiment_results.inflection_to.str.split("-").str[-1]
experiment_results["control"] = _target_phone != experiment_results.inflection_from

# Rank 0 means the target was ranked first.
_cohort_rank = experiment_results.matches_cohort_target_rank
_next_phoneme_rank = experiment_results.matches_next_phoneme_target_rank
experiment_results["matches_cohort_correct"] = _cohort_rank == 0
experiment_results["matches_next_phoneme_correct"] = _next_phoneme_rank == 0

# Per-model-class copy of the results.
experiment_results.to_csv(
    f"{output_dir}/pseudocausal_broad_experiment_results-{model_class}.csv"
)

### Analyze

In [None]:
# post_div_set = experiment_results.groupby("to_base_phones").apply(lambda xs: frozenset(xs.to_post_divergence))
# experiment_results["post_div_set"] = experiment_results.to_base_phones.map(post_div_set)

In [None]:
# Small-cohort experiments are keyed "<source>-to-small-<phone>" in the outer index level.
_is_small_experiment = experiment_results.index.get_level_values(0).str.contains("to-small-")
# Take explicit copies: later cells add columns to both frames, and writing into a
# boolean-indexed slice of `experiment_results` raises SettingWithCopyWarning.
main_results = experiment_results[~_is_small_experiment].copy()
small_results = experiment_results[_is_small_experiment].copy()

### Main experiment results

In [None]:
# Joint counts of control status and the two rank-based correctness flags.
(
    main_results[["control", "matches_cohort_correct", "matches_next_phoneme_correct"]]
    .value_counts()
    .sort_index()
)

In [None]:
# Outcome counts normalised within each control condition.
(
    main_results[["control", "correct_base", "correct"]]
    .value_counts()
    .groupby("control")
    .apply(lambda counts: counts / counts.sum())
    .sort_index()
)

In [None]:
plot_all_phones = False

# Columns of the heatmap: the studied phones, optionally followed by every other predicted phone.
full_phone_list = sorted(next_phon_set)
if plot_all_phones:
    full_phone_list += sorted(set(main_results.predicted_phone.fillna("NA").unique()) - set(full_phone_list))
# Proportion of each predicted phone per (control, correct_base, source phone), with
# unobserved combinations filled in as 0.
heatmap_results = main_results \
    .groupby(["control", "correct_base", "inflection_from"]).predicted_phone.value_counts(normalize=True) \
    .reindex(pd.MultiIndex.from_product([[False, True], [False, True], sorted(next_phon_set), full_phone_list],
                                        names=["control", "correct_base", "inflection_from", "predicted_phone"])).fillna(0)

g = sns.FacetGrid(data=heatmap_results.reset_index(), row="control", col="correct_base", height=5, aspect=2 if plot_all_phones else 1.2, sharex=False, sharey=False)
def f(data, **kwargs):
    """Draw one facet: source phone (rows) x predicted phone (columns) proportions.

    Extra keyword arguments are forwarded to `sns.heatmap`; previously they were
    accepted but dropped, so `annot` and `cmap` had no effect.
    """
    # `map_dataframe` injects a per-facet `color` (and `label` when hue is used);
    # neither is a heatmap option, so remove them before forwarding.
    kwargs.pop("color", None)
    kwargs.pop("label", None)
    pivoted = data.pivot_table(index="inflection_from", columns="predicted_phone", values="proportion")
    sns.heatmap(pivoted.reindex(full_phone_list, axis=1), **kwargs)
g.map_dataframe(f, annot=True, cmap="Blues")

In [None]:
plot_all_phones = True

full_phone_list = sorted(next_phon_set)
if plot_all_phones:
    # full_phone_list += sorted(set(main_results.predicted_phone.fillna("NA").unique()) - set(full_phone_list))
    # DEV plot just the non-studied phones
    full_phone_list = sorted(set(main_results.predicted_phone.fillna("NA").unique()) - set(full_phone_list))
# Proportion of each predicted phone per (control, correct_base, source phone), with
# unobserved combinations filled in as 0.
heatmap_results = main_results \
    .groupby(["control", "correct_base", "inflection_from"]).predicted_phone.value_counts(normalize=True) \
    .reindex(pd.MultiIndex.from_product([[False, True], [False, True], sorted(next_phon_set), full_phone_list],
                                        names=["control", "correct_base", "inflection_from", "predicted_phone"])).fillna(0)

g = sns.FacetGrid(data=heatmap_results.reset_index(), row="control", col="correct_base", height=5, aspect=2 if plot_all_phones else 1, sharex=False, sharey=False)
def f(data, **kwargs):
    """Draw one facet: source phone (rows) x predicted phone (columns) proportions.

    Extra keyword arguments are forwarded to `sns.heatmap`; previously they were
    accepted but dropped, so `annot` and `cmap` had no effect.
    """
    # `map_dataframe` injects a per-facet `color` (and `label` when hue is used);
    # neither is a heatmap option, so remove them before forwarding.
    kwargs.pop("color", None)
    kwargs.pop("label", None)
    pivoted = data.pivot_table(index="inflection_from", columns="predicted_phone", values="proportion")
    sns.heatmap(pivoted.reindex(full_phone_list, axis=1), **kwargs)
g.map_dataframe(f, annot=True, cmap="Blues")

In [None]:
# Random sample of incorrect non-control predictions for manual inspection.
_main_errors = main_results.query("not control and not correct")[["from_inflected_phones", "gt_label", "to_base_phones", "correct", "correct_base", "predicted_label", "predicted_phones", "from_post_divergence", "predicted_phone"]]
# Seeded so the displayed examples are reproducible, and capped at the number of
# available rows because `sample` raises when asked for more rows than exist.
_main_errors.sample(n=min(20, len(_main_errors)), random_state=seed).sort_values(["correct", "correct_base"], ascending=False)

In [None]:
# Mean `correct` per target inflection, split by control status.
sns.barplot(
    data=main_results,
    x="inflection_to",
    y="correct",
    hue="control",
)

In [None]:
# Distribution of mean accuracy per source word ("from"), non-control trials only.
sns.displot(
    main_results.query("not control")
    .groupby(["from", "control"])
    .correct.mean()
    .sort_values()
)

In [None]:
# Distribution of mean accuracy per target word ("to"), non-control trials only.
sns.displot(
    main_results.query("not control")
    .groupby(["to", "control"])
    .correct.mean()
    .sort_values()
)

In [None]:
# Distribution of predicted base phones for each source cohort (non-control trials).
d = (
    main_results.query("not control")
    .groupby(["from_base_phones"])
    .predicted_base_phones.value_counts(normalize=True)
    .unstack()
    .fillna(0)
)
# Entropy in bits per source cohort. Zero cells give 0 * log2(0) = NaN, which the
# row sum skips.
e = -(d * np.log2(d)).sum(axis=1)
e.sort_values()

In [None]:
# Distribution of predicted phones for each source item (non-control trials).
diversity = (
    main_results.query("not control")
    .groupby(["from", "from_base_phones", "from_post_divergence"])
    .predicted_phone.value_counts(normalize=True)
    .unstack()
    .fillna(0)
)
# Entropy in bits per source item; NaN terms from zero cells are skipped by the sum.
entropy = -(diversity * np.log2(diversity)).sum(axis=1)
# Highest-entropy source items.
entropy.sort_values().tail(20)

In [None]:
# The 20 source items with the lowest predicted-phone entropy.
entropy.sort_values().head(20)

In [None]:
# Distribution of predicted *base phones* for each source item (non-control trials).
base_diversity = main_results.query("not control").groupby(["from", "from_base_phones", "from_post_divergence"]).predicted_base_phones.value_counts(normalize=True).unstack().fillna(0)
# Entropy in bits of that distribution. This previously reused `diversity` (the
# predicted-phone distribution), so `base_entropy` merely duplicated `entropy`.
# NaN terms from zero cells (0 * log2(0)) are skipped by the row sum.
base_entropy = - (base_diversity * np.log2(base_diversity)).sum(axis=1)
base_entropy.sort_values().tail(20)

In [None]:
# Distribution of predicted phones for each target item (non-control trials).
to_diversity = (
    main_results.query("not control")
    .groupby(["to", "to_base_phones", "to_post_divergence"])
    .predicted_phone.value_counts(normalize=True)
    .unstack()
    .fillna(0)
)
# Entropy in bits per target item; NaN terms from zero cells are skipped by the sum.
to_entropy = -(to_diversity * np.log2(to_diversity)).sum(axis=1)
# Highest-entropy target items.
to_entropy.sort_values().tail(20)

In [None]:
# Spot check: most frequent predicted phones for the target word "careering".
# Raises KeyError if that word is not among the target items.
to_diversity.loc["careering"].melt().sort_values("value", ascending=False).head(10)

In [None]:
# Spot check: most frequent predicted base phones for the source word "unseen".
# Raises KeyError if that word is not among the source items.
base_diversity.loc["unseen"].melt().sort_values("value", ascending=False).head(20)

In [None]:
# Final phone of each source cohort prefix. `assign` returns a new frame instead of
# writing a column into `main_results` in place, which avoids SettingWithCopyWarning
# when `main_results` is a boolean-indexed slice of `experiment_results`.
main_results = main_results.assign(
    from_base_final=main_results.from_base_phones.str.split(" ").str[-1]
)

In [None]:
# Spot check: all non-control trials whose target word is "licenses".
(
    main_results
    .query("not control and `to` == 'licenses'")
    [["from", "from_base_phones", "correct", "correct_base", "to", "predicted_label", "predicted_base_phones", "predicted_phone"]]
)

In [None]:
# Accuracy per target word, restricted to targets with at least 10 non-control trials.
(
    main_results.query("not control")
    .groupby(["to", "to_base_phones"])
    .correct.agg(["mean", "count"])
    .query("count >= 10")
    .sort_values("mean")
)

In [None]:
# Base-phone accuracy per source cohort, restricted to cohorts with at least
# 4 non-control trials.
(
    main_results.query("not control")
    .groupby(["from_base_phones"])
    .correct_base.agg(["mean", "count"])
    .query("count >= 4")
    .sort_values("mean")
)

In [None]:
# Attach log-frequency of the source and target words. Both merges are inner joins,
# so trials whose words are missing from the frequency table are dropped.
_from_freq = word_freq_df.LogFreq.rename("from_freq")
_to_freq = word_freq_df.LogFreq.rename("to_freq")
main_with_freq = (
    main_results
    .merge(_from_freq, left_on="from", right_index=True)
    .merge(_to_freq, left_on="to", right_index=True)
)

In [None]:
def get_mass(group):
    """Return each word's share of the group's total log-frequency.

    Words are de-duplicated on `inflected`; words absent from `word_freq_df` are
    dropped by the inner merge. The result is indexed by
    (inflected, post_divergence).
    """
    unique_words = group.drop_duplicates("inflected")
    with_freq = unique_words.merge(word_freq_df.LogFreq, left_on="inflected", right_index=True)
    log_freq = with_freq.set_index(["inflected", "post_divergence"]).LogFreq
    return log_freq / log_freq.sum()

# Masses are computed per cohort over the main (non-"small-") instances only.
_is_main_instance = ~all_instances_df.inflection.str.startswith("small-")
masses = all_instances_df[_is_main_instance].groupby("base_phones").apply(get_mass)

In [None]:
# Empirical CDF of the per-word frequency mass, one curve per cohort.
sns.displot(
    data=masses.reset_index(),
    x="LogFreq",
    hue="base_phones",
    kind="ecdf",
)

In [None]:
# Accuracy against source-word log-frequency, for source words with at least
# 20 non-control trials.
_from_accuracy = (
    main_with_freq.query("not control")
    .groupby(["from", "from_freq"])
    .correct.agg(["mean", "count"])
    .reset_index()
    .query("count >= 20")
)
sns.regplot(data=_from_accuracy, x="from_freq", y="mean")

In [None]:
# Accuracy against target-word log-frequency, for target words with at least
# 20 non-control trials.
_to_accuracy = (
    main_with_freq.query("not control")
    .groupby(["to", "to_freq"])
    .correct.agg(["mean", "count"])
    .reset_index()
    .query("count >= 20")
)
sns.regplot(data=_to_accuracy, x="to_freq", y="mean")

### Small cohorts

In [None]:
# Joint counts of control status and correctness for the small-cohort experiments.
(
    small_results[["control", "correct_base", "correct"]]
    .value_counts()
    .sort_index()
)

In [None]:
# Small-cohort outcome counts normalised within each control condition.
(
    small_results[["control", "correct_base", "correct"]]
    .value_counts()
    .groupby("control")
    .apply(lambda counts: counts / counts.sum())
    .sort_index()
)

In [None]:
def _source_phone_attested(row):
    """Return whether the row's source continuation is a member of its `post_div_set`."""
    post_div_set = row.post_div_set
    if isinstance(post_div_set, str):
        # NOTE(review): string values are parsed with eval(), which executes arbitrary
        # code; only run this on results files you produced yourself. Presumably the
        # strings are set reprs left by a CSV round-trip — confirm.
        post_div_set = eval(post_div_set)
    return row.from_post_divergence in post_div_set

# NOTE(review): no active cell in this notebook creates `post_div_set` (only a
# commented-out one does); confirm it is present in the experiment results.
_attested = small_results.apply(_source_phone_attested, axis=1)
# Non-control trials are "main"; control trials are split by attestation.
_condition = np.select(
    [small_results.control & _attested, small_results.control & ~_attested],
    ["control_attested", "control_unattested"],
    default="main",
)
# `assign` returns a new frame instead of writing columns into `small_results` in
# place, which avoids SettingWithCopyWarning when it is a slice of `experiment_results`.
small_results = small_results.assign(attested=_attested, condition=_condition)

In [None]:
# Mean accuracy per small-cohort target, split by main / control condition.
sns.catplot(
    data=small_results,
    x="inflection_to",
    y="correct",
    hue="condition",
    kind="bar",
    aspect=2.5,
)

In [None]:
plot_all_phones = False

# Columns of the heatmap: the studied phones, optionally followed by every other predicted phone.
full_phone_list = sorted(next_phon_set)
if plot_all_phones:
    full_phone_list += sorted(set(small_results.predicted_phone.fillna("NA").unique()) - set(full_phone_list))
# Proportion of each predicted phone per (control, correct_base, source phone), with
# unobserved combinations filled in as 0.
heatmap_results = small_results \
    .groupby(["control", "correct_base", "inflection_from"]).predicted_phone.value_counts(normalize=True) \
    .reindex(pd.MultiIndex.from_product([[False, True], [False, True], sorted(next_phon_set), full_phone_list],
                                        names=["control", "correct_base", "inflection_from", "predicted_phone"])).fillna(0)

g = sns.FacetGrid(data=heatmap_results.reset_index(), row="control", col="correct_base", height=5, aspect=2 if plot_all_phones else 1.2, sharex=False, sharey=False)
def f(data, **kwargs):
    """Draw one facet: source phone (rows) x predicted phone (columns) proportions.

    Extra keyword arguments are forwarded to `sns.heatmap`; previously they were
    accepted but dropped, so `annot` and `cmap` had no effect.
    """
    # `map_dataframe` injects a per-facet `color` (and `label` when hue is used);
    # neither is a heatmap option, so remove them before forwarding.
    kwargs.pop("color", None)
    kwargs.pop("label", None)
    pivoted = data.pivot_table(index="inflection_from", columns="predicted_phone", values="proportion")
    sns.heatmap(pivoted.reindex(full_phone_list, axis=1), **kwargs)
g.map_dataframe(f, annot=True, cmap="Blues")

In [None]:
# Random sample of non-control small-cohort trials for manual inspection.
_small_trials = small_results.query("not control")[["from_inflected_phones", "gt_label", "to_inflected_phones", "correct", "correct_base", "predicted_label", "predicted_phones"]]
# Seeded so the displayed examples are reproducible, and capped at the number of
# available rows because `sample` raises when asked for more rows than exist.
_small_trials.sample(n=min(20, len(_small_trials)), random_state=seed) \
    .sort_values(["correct", "correct_base"], ascending=False)

In [None]:
# Predicted-phone proportions per (target cohort, source phone) over all
# small-cohort target trials.
(
    experiment_results[experiment_results.inflection_to.str.startswith("small-")]
    .groupby(["to_base_phones", "inflection_from"])
    .predicted_phone.value_counts(normalize=True)
    .unstack()
    .fillna(0)
)

In [None]:
# TODO look at a bunch of individual prediction examples to get an intuition for what is happening here.

In [None]:
# Trial count and mean accuracy per `post_div_set` value, small-cohort targets only.
(
    experiment_results[experiment_results.inflection_to.str.startswith("small-")]
    .groupby("post_div_set")
    .correct.agg(["count", "mean"])
    .sort_values("mean")
)

In [None]:
# Spot check: every trial whose target is cohort "AA F" followed by "T".
experiment_results[(experiment_results.to_base_phones == "AA F") & (experiment_results.to_post_divergence == "T")][["from_inflected_phones", "from_post_divergence", "to", "predicted_label", "predicted_phones"]]

In [None]:
# Plot predicted phone distributions for predicted words with correct base
small_cohort_results = small_results.query("correct_base").groupby(["to_base_phones", "to_post_divergence"]).predicted_phone.value_counts(normalize=True).unstack().fillna(0).reindex(columns=next_phon_set).fillna(0).sort_index().sort_index(axis=1)
small_cohort_results

n_cols = 3
n_rows = int(np.ceil(small_cohort_results.index.get_level_values("to_base_phones").nunique() / n_cols))
fig, axes = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 4 * n_rows))

for i, (ax, ((base_phones, target), row)) in enumerate(zip(axes.flat, small_cohort_results.sample(n_rows * n_cols).sort_index().iterrows())):
    row = row.rename("accuracy").to_frame().reset_index()
    row["in_cohort"] = row.predicted_phone.isin(expt_cohort_small.loc[base_phones])
    sns.barplot(data=row, x="predicted_phone", y="accuracy", ax=ax)
    ax.set_title(f"{base_phones} + {target}")
    ax.set_xlabel("Predicted phone")
    ax.set_ylabel("Probability")
    ax.set_ylim(0, 1)
    ax.grid(axis="y")

plt.tight_layout()bhdvh

In [None]:
# Plot predicted phone distributions for predicted words with correct base
small_cohort_results = small_results.query("not correct_base").groupby(["to_base_phones", "to_post_divergence"]).predicted_phone.value_counts(normalize=True).unstack().fillna(0).reindex(columns=next_phon_set).fillna(0).sort_index().sort_index(axis=1)
small_cohort_results

n_cols = 3
n_rows = int(np.ceil(small_cohort_results.index.get_level_values("to_base_phones").nunique() / n_cols))
fig, axes = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 4 * n_rows))

for i, (ax, ((base_phones, target), row)) in enumerate(zip(axes.flat, small_cohort_results.sample(n_rows * n_cols).sort_index().iterrows())):
    row = row.rename("accuracy").to_frame().reset_index()
    row["in_cohort"] = row.predicted_phone.isin(expt_cohort_small.loc[base_phones])
    sns.barplot(data=row, x="predicted_phone", y="accuracy", hue="in_cohort", ax=ax)
    ax.set_title(f"{base_phones} + {target}")
    ax.set_xlabel("Predicted phone")
    ax.set_ylabel("Probability")
    ax.set_ylim(0, 1)
    ax.legend(title="In attested cohort?")
    ax.grid(axis="y")

plt.tight_layout()

In [None]:
# Mean non-control accuracy per target phone, small cohorts vs main cohorts.
_accuracy_by_size = pd.concat(
    {
        "small": small_results.query("not control").groupby("to_post_divergence").correct.mean(),
        "main": main_results.query("not control").groupby("to_post_divergence").correct.mean(),
    },
    names=["size"],
)
sns.barplot(
    data=_accuracy_by_size.reset_index(),
    x="to_post_divergence",
    y="correct",
    hue="size",
)

### Save

In [None]:
# Write the full results table (index included) to a CSV whose name carries no
# model-class suffix.
experiment_results.to_csv(f"{output_dir}/pseudocausal_broad_experiment_results.csv")

In [None]:
# Reload the saved results, rebuilding the two-level index from the first two CSV
# columns. This replaces the in-memory `experiment_results`.
experiment_results = pd.read_csv(f"{output_dir}/pseudocausal_broad_experiment_results.csv", index_col=[0, 1])

In [None]:
# Display the results table.
experiment_results