In [None]:
# Automatically reload edited project modules (e.g. `src.analysis`) before each cell executes.
%load_ext autoreload
%autoreload 2

In [None]:
from collections import defaultdict
import itertools
import re

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy.spatial.distance import cdist
import seaborn as sns
from tqdm import tqdm

from src.analysis import analogy
from src.analysis.state_space import StateSpaceAnalysisSpec, \
    prepare_state_trajectory, flatten_trajectory
from src.datasets.speech_equivalence import SpeechHiddenStateDataset


In [None]:
# Base speech model and the trained embedding model (class + name) whose
# outputs are loaded from `embeddings_path`.
base_model = "w2v2_8"

model_class = "discrim-rnn_32-mAP1"
model_name = "word_broad_10frames_fixedlen25"

# Stimulus / metadata inputs.
inflection_results_path = "inflection_results.parquet"
all_cross_instances_path = "outputs/analogy/inputs/librispeech-train-clean-100/w2v2/all_cross_instances.parquet"
most_common_allomorphs_path = "most_common_allomorphs.csv"
false_friends_path = "false_friends.csv"

train_dataset = "librispeech-train-clean-100"
hidden_states_path = f"outputs/hidden_states/{base_model}/{train_dataset}.h5"
state_space_specs_path = "state_space_spec.h5"
# Set to the sentinel "ID" to analyze the raw hidden states at `hidden_states_path`
# instead of loading model embeddings from this .npy file.
embeddings_path = f"outputs/model_embeddings/{train_dataset}/{base_model}/{model_class}/{model_name}/{train_dataset}.npy"

output_dir = "."

pos_counts_path = "data/pos_counts.pkl"

seed = 42

metric = "cosine"

agg_fns = [
    "mean",
]

In [None]:
# General query applied to experiments to exclude special edge cases (rows
# flagged `exclude_main`), whose logic doesn't make sense in most experiments.
all_query = "not exclude_main"

# Experiment name -> config passed to `analogy.run_experiment_equiv_level`.
# Keys used below: group_by, base_query, inflected_query, all_query,
# equivalence_keys. Query strings appear to be pandas query expressions;
# their exact semantics are defined in `analogy` -- confirm there.
experiments = dict(
    basic=dict(
        group_by=["inflection"],
        all_query=all_query,
    ),
    regular=dict(
        group_by=["inflection", "is_regular"],
        all_query=all_query,
    ),
    # Disabled experiments, kept for reference:
    # NNS_to_VBZ=dict(
    #     base_query="inflection == 'NNS' and is_regular",
    #     inflected_query="inflection == 'VBZ' and is_regular",
    # ),
    # VBZ_to_NNS=dict(
    #     base_query="inflection == 'VBZ' and is_regular",
    #     inflected_query="inflection == 'NNS' and is_regular",
    # ),
    # regular_to_irregular=dict(
    #     group_by=["inflection"],
    #     base_query="is_regular == True",
    #     inflected_query="is_regular == False",
    #     all_query=all_query,
    # ),
    # irregular_to_regular=dict(
    #     group_by=["inflection"],
    #     base_query="is_regular == False",
    #     inflected_query="is_regular == True",
    #     all_query=all_query,
    # ),
    nn_vb_ambiguous=dict(
        group_by=["inflection", "base_ambig_NN_VB"],
        base_query="is_regular == True",
        inflected_query="is_regular == True",
        all_query=all_query,
    ),
    random_to_NNS=dict(
        base_query="inflection == 'random'",
        inflected_query="inflection == 'NNS'",
        all_query=all_query,
    ),
    random_to_VBZ=dict(
        base_query="inflection == 'random'",
        inflected_query="inflection == 'VBZ'",
        all_query=all_query,
    ),
    # NOTE(review): `.str.contains` inside a pandas query string usually needs
    # engine="python" -- confirm `analogy` evaluates it that way.
    false_friends=dict(
        all_query="inflection.str.contains('FF')",
        group_by=["inflection"],
        equivalence_keys=["base", "inflected", "post_divergence"],
    ),
)

In [None]:
# Inflections whose top allomorphs are crossed with each other -- every ordered
# pair of inflections, including each with itself -- to build the
# "unambiguous-*" transfer experiments.
study_unambiguous_transfer = ["NNS", "VBZ"]
# Inflections for which false-friend ("<inflection>-FF-<allomorph>") transfer
# experiments are generated.
study_false_friends = ["NNS", "VBZ", "VBD"]

In [None]:
# Forced-choice (FC) allomorph pairs to probe. Each pair maps to the inflections
# whose top allomorphs are used as analogy sources for that pair.
fc_experiments = {
    allomorph_pair: {"source_inflections": list(source_inflections)}
    for allomorph_pair, source_inflections in [
        (("Z", "S"), ("VBZ", "NNS")),
        (("D", "T"), ("VBD",)),
        (("D", "IH D"), ("VBD",)),
        (("T", "IH D"), ("VBD",)),
    ]
}

## Prepare model representations

In [None]:
# Load the representations to analyze. Any `embeddings_path` other than the
# sentinel "ID" is a saved NumPy array of model embeddings; "ID" (presumably
# "identity") selects the raw hidden states stored at `hidden_states_path`.
if embeddings_path != "ID":
    model_representations: np.ndarray = np.load(embeddings_path)
else:
    model_representations = SpeechHiddenStateDataset.from_hdf5(hidden_states_path).states
state_space_spec = StateSpaceAnalysisSpec.from_hdf5(state_space_specs_path)
# Fail fast if the spec does not match the loaded representations.
assert state_space_spec.is_compatible_with(model_representations)

In [None]:
# Build per-instance state trajectories, aggregated with a "mean" over
# dimension 1 (presumably the within-segment frame axis -- confirm in
# `prepare_state_trajectory`).
trajectory_agg = prepare_state_trajectory(
    model_representations,
    state_space_spec,
    agg_fn_spec="mean",
    agg_fn_dimension=1,
)

In [None]:
# Flatten the aggregated trajectories into `agg`; `agg_src` presumably records
# which (label, instance) each flattened row came from -- confirm in `flatten_trajectory`.
agg, agg_src = flatten_trajectory(trajectory_agg)

## Prepare metadata

In [None]:
# Phoneme-level cuts, re-indexed by (label, instance_idx, frame_idx):
#   label_idx -- position of the row's label in `state_space_spec.labels`
#   frame_idx -- 0-based position of the phoneme within its (label, instance_idx)
#                group, following the row order of the cuts table
cuts_df = (
    state_space_spec.cuts
    .xs("phoneme", level="level")
    .drop(columns=["onset_frame_idx", "offset_frame_idx"])
    .assign(
        label_idx=lambda frame: frame.index.get_level_values("label").map(
            {label: idx for idx, label in enumerate(state_space_spec.labels)}),
        frame_idx=lambda frame: frame.groupby(["label", "instance_idx"]).cumcount(),
    )
    .reset_index()
    .set_index(["label", "instance_idx", "frame_idx"])
    .sort_index()
)

In [None]:
# Phonemic form of each (label, instance_idx): its phoneme descriptions joined
# by single spaces. Within each group, rows keep cuts_df's sorted frame_idx order.
cut_phonemic_forms = (
    cuts_df["description"]
    .groupby(level=["label", "instance_idx"])
    .agg(" ".join)
)

In [None]:
word_freq_df = pd.read_csv("data/WorldLex_Eng_US.Freq.2.txt", sep="\t", index_col="Word")
# Domain-balanced average frequency: normalize each domain's frequency column to
# relative frequencies (so the domain with the largest totals doesn't dominate),
# average those across domains, then rescale by the mean domain total so the
# result stays on the original frequency scale.
word_freq_df["BlogFreq_rel"] = word_freq_df.BlogFreq / word_freq_df.BlogFreq.sum()
word_freq_df["TwitterFreq_rel"] = word_freq_df.TwitterFreq / word_freq_df.TwitterFreq.sum()
word_freq_df["NewsFreq_rel"] = word_freq_df.NewsFreq / word_freq_df.NewsFreq.sum()
word_freq_df["Freq"] = word_freq_df[["BlogFreq_rel", "TwitterFreq_rel", "NewsFreq_rel"]].mean(axis=1) \
    * word_freq_df[["BlogFreq", "TwitterFreq", "NewsFreq"]].sum().mean()
# NOTE: a word with zero frequency in every domain gets LogFreq == -inf.
word_freq_df["LogFreq"] = np.log10(word_freq_df.Freq)

In [None]:
# Cross-instance stimulus table; its `inflection` column is queried by the
# experiment configs, and it is passed to every experiment run below.
all_cross_instances = pd.read_parquet(all_cross_instances_path)

In [None]:
# NOTE(review): `inflection_results_df` is not referenced in the cells shown here.
inflection_results_df = pd.read_parquet(inflection_results_path)

In [None]:
# Allomorph table: used to pick the top allomorphs of each inflection.
most_common_allomorphs = pd.read_csv(most_common_allomorphs_path)
# False-friend stimuli: grouped by (inflection, post_divergence) to generate FF experiments.
false_friends_df = pd.read_csv(false_friends_path)

## Homophone preparation

In [None]:
from collections import defaultdict


pron2label = defaultdict(set)
for label, rows in cut_phonemic_forms.groupby("label"):
    for pron in set(rows):
        pron2label[pron].add(label)

homophone_map = defaultdict(set)
for label_idx, label in enumerate(state_space_spec.labels):
    for pron in set(cut_phonemic_forms.loc[label]):
        homophone_map[label] |= pron2label[pron]

In [None]:
# Prepare to exclude predictions of homophones from analogy evaluations.
# create a map from label idx -> all label idxs which should be ignored.
homophone_map = {state_space_spec.labels.index(label): {state_space_spec.labels.index(hom) for hom in homs}
                 for label, homs in homophone_map.items()}

## Behavioral tests

In [None]:
# Study the 3 most frequent allomorphs of each inflection: rank the values of
# `most_common_allomorph` within each inflection by row count, keep the top 3.
transfer_allomorphs = (
    most_common_allomorphs
    .groupby("inflection")["most_common_allomorph"]
    .apply(lambda allomorphs: allomorphs.value_counts().head(3).index.tolist())
    .to_dict()
)

In [None]:
# Generate transfer experiments between the top allomorphs of every ordered pair
# of inflections in `study_unambiguous_transfer` (each inflection with itself
# included), restricted to regular forms whose base is not NN/VB-ambiguous.
for source_infl, target_infl in itertools.product(study_unambiguous_transfer, repeat=2):
    for source_allomorph in transfer_allomorphs[source_infl]:
        for target_allomorph in transfer_allomorphs[target_infl]:
            experiment_name = f"unambiguous-{source_infl}_{source_allomorph}_to_{target_infl}_{target_allomorph}"
            experiments[experiment_name] = {
                "base_query": f"inflection == '{source_infl}' and is_regular == True and base_ambig_NN_VB == False and post_divergence == '{source_allomorph}'",
                "inflected_query": f"inflection == '{target_infl}' and is_regular == True and base_ambig_NN_VB == False and post_divergence == '{target_allomorph}'",
                "all_query": all_query,
            }

In [None]:
# Generate experiments testing transfer from
# 1. false friend allomorph to matching inflection allomorph
# 2. false friend allomorph to non-matching inflection allomorph
# 3. inflection allomorph to matching false friend allomorph
# 4. inflection allomorph to non-matching false friend allomorph
# Every (inflection, post_divergence) false-friend group is paired, in both
# directions, with each top allomorph of the same inflection.
for (inflection, post_divergence), _ in false_friends_df.groupby(["inflection", "post_divergence"]):
    if inflection not in study_false_friends:
        continue

    # For NNS and VBZ, restrict the real-inflection side to bases that are not
    # NN/VB-ambiguous. This depends only on the inflection, so build it once here.
    ambig_negative = "base_ambig_NN_VB == False and " if inflection in ["NNS", "VBZ"] else ""
    false_friend_query = f"inflection == '{inflection}-FF-{post_divergence}'"

    # NOTE(review): unlike the main experiments, these configs set no `all_query`
    # -- confirm that is intended.
    for transfer_allomorph in transfer_allomorphs[inflection]:
        inflection_query = f"inflection == '{inflection}' and is_regular == True and {ambig_negative} post_divergence == '{transfer_allomorph}'"
        experiments[f"{inflection}-FF-{post_divergence}-to-{inflection}_{transfer_allomorph}"] = {
            "base_query": false_friend_query,
            "inflected_query": inflection_query,
        }
        experiments[f"{inflection}_{transfer_allomorph}-to-{inflection}-FF-{post_divergence}"] = {
            "base_query": inflection_query,
            "inflected_query": false_friend_query,
        }

# False friend -> false friend transfer for each unordered pair of top allomorphs
# (one direction only, in `itertools.combinations` order).
# NOTE(review): pairs are generated whether or not `false_friends_df` contains a
# false-friend group for each allomorph; confirm the experiment runner handles
# queries that match no rows.
for inflection in study_false_friends:
    for t1, t2 in itertools.combinations(transfer_allomorphs[inflection], 2):
        experiments[f"{inflection}-FF-{t1}-to-{inflection}-FF-{t2}"] = {
            "base_query": f"inflection == '{inflection}-FF-{t1}'",
            "inflected_query": f"inflection == '{inflection}-FF-{t2}'",
        }

In [None]:
# Generate experiments for forced-choice (FC) analysis.
# FC stimuli carry inflection labels of the form "FC-<allomorph1>_<allomorph2>"
# (the same form used for `inflected_query` below). Collect the allomorph pairs
# that actually have stimuli in `all_cross_instances`.
# NOTE: `\w` also matches "_", so the pattern splits on the LAST underscore via
# backtracking; allomorph strings are assumed not to contain "_".
fc_label_pattern = re.compile(r"FC-([\w\s]+)_([\w\s]+)")
fc_types = set()
for fc_label in all_cross_instances.inflection.unique():
    if not fc_label.startswith("FC"):
        continue
    fc_match = fc_label_pattern.search(fc_label)
    if fc_match is None:
        raise ValueError(f"Could not parse forced-choice stimulus label {fc_label!r}")
    fc_types.add(fc_match.groups())

for fc_pair, config in fc_experiments.items():
    fc_pair_name = "_".join(fc_pair)
    if fc_pair not in fc_types:
        raise ValueError(f"FC pair {fc_pair} not found in FC stimuli")

    # Analogy sources: each top allomorph of every source inflection for this pair.
    for source_inflection in config["source_inflections"]:
        for source_allomorph in transfer_allomorphs[source_inflection]:
            experiments[f"FC-{fc_pair_name}-from_{source_inflection}-{source_allomorph}"] = {
                "base_query": f"inflection == '{source_inflection}' and post_divergence == '{source_allomorph}'",
                "inflected_query": f"inflection == 'FC-{fc_pair_name}'",
            }

In [None]:
# experiment = "unambiguous-NNS_Z_to_NNS_Z"
# config = experiments[experiment]
# ret = analogy.run_experiment_equiv_level(
#     experiment, config, state_space_spec, all_cross_instances,
#     agg, agg_src,
#     num_samples=20,
#     device="cpu",
#     include_idxs_in_predictions=homophone_map,
#     verbose=True,
# )

In [None]:
# Run every configured experiment and stack the per-experiment results under an
# "experiment" index level.
# `homophone_map` (label idx -> homophone label idxs) was built so homophone
# predictions are ignored in analogy evaluations. Pass it here as the debug call
# above does; previously it was computed but unused in the main run.
# NOTE(review): confirm `include_idxs_in_predictions` is still accepted by
# `analogy.run_experiment_equiv_level` and has the intended semantics.
experiment_results = pd.concat({
    experiment: analogy.run_experiment_equiv_level(
        experiment, config,
        state_space_spec, all_cross_instances,
        agg, agg_src,
        num_samples=1000,
        seed=seed,
        device="cuda",
        include_idxs_in_predictions=homophone_map)
    for experiment, config in tqdm(experiments.items(), unit="experiment")
}, names=["experiment"])
# Exact-label match between prediction and ground truth.
experiment_results["correct"] = experiment_results.predicted_label == experiment_results.gt_label
experiment_results

### Save

In [None]:
# Persist all per-experiment results (including the `correct` column) as CSV under `output_dir`.
experiment_results.to_csv(f"{output_dir}/experiment_results.csv")