In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import itertools
import re

import numpy as np
import pandas as pd
from tqdm import tqdm

from src.analysis import analogy
from src.analysis.state_space import StateSpaceAnalysisSpec, \
    prepare_state_trajectory, aggregate_state_trajectory, flatten_trajectory
from src.datasets.speech_equivalence import SpeechHiddenStateDataset


In [None]:
# --- Model identifiers ---------------------------------------------------------
base_model = "w2v2_pc_8"
model_class = "discrim-rnn_32-pc-mAP1"
model_name = "word_broad_10frames_fixedlen25"

# --- Stimulus / analysis tables (paths relative to the working directory) -------
inflection_results_path = "inflection_results.parquet"
all_cross_instances_path = "all_cross_instances.parquet"
most_common_allomorphs_path = "most_common_allomorphs.csv"
false_friends_path = "false_friends.csv"
pos_counts_path = "data/pos_counts.pkl"

# --- Model representations ----------------------------------------------------
train_dataset = "librispeech-train-clean-100"
# Raw hidden states; only read when `embeddings_path == "ID"` in the loading cell.
# Alternative in-repo location:
#   f"outputs/hidden_states/{base_model}/{train_dataset}.h5"
hidden_states_path = f"/scratch/jgauthier/{base_model}_{train_dataset}.h5"
state_space_specs_path = f"outputs/analogy/inputs/{train_dataset}/w2v2_pc/state_space_spec.h5"
embeddings_path = f"outputs/model_embeddings/{train_dataset}/{base_model}/{model_class}/{model_name}/{train_dataset}.npy"

# --- Outputs and analysis parameters ------------------------------------------
output_dir = "."
seed = 42
metric = "cosine"
# Aggregation specs; only agg_fns[0] is applied when building the trajectory.
agg_fns = [("mean_within_cut", "phoneme")]

In [None]:
# Query shared by most experiments: drop instances flagged `exclude_main`
# (special edge cases for which the analogy logic doesn't make sense).
all_query = "not exclude_main"

# Experiment name -> config. Keys as interpreted by `iter_equivalences` below:
#   group_by          columns to group instances by; each group is evaluated separately
#   base_query        pandas query selecting the rows the analogy is taken *from*
#   inflected_query   pandas query selecting the rows the analogy is applied *to*
#   all_query         pandas query applied to both of the above
#   equivalence_keys  columns defining which instances are pooled together
#                     (default: ["inflected", "inflected_phones"])
experiments = dict(
    basic=dict(
        group_by=["inflection"],
        all_query=all_query,
    ),
    regular=dict(
        group_by=["inflection", "is_regular"],
        all_query=all_query,
    ),
    # Disabled experiments, kept for reference:
    # NNS_to_VBZ=dict(base_query="inflection == 'NNS' and is_regular",
    #                 inflected_query="inflection == 'VBZ' and is_regular"),
    # VBZ_to_NNS=dict(base_query="inflection == 'VBZ' and is_regular",
    #                 inflected_query="inflection == 'NNS' and is_regular"),
    # regular_to_irregular=dict(group_by=["inflection"], base_query="is_regular == True",
    #                           inflected_query="is_regular == False", all_query=all_query),
    # irregular_to_regular=dict(group_by=["inflection"], base_query="is_regular == False",
    #                           inflected_query="is_regular == True", all_query=all_query),
    nn_vb_ambiguous=dict(
        group_by=["inflection", "base_ambig_NN_VB"],
        base_query="is_regular == True",
        inflected_query="is_regular == True",
        all_query=all_query,
    ),
    random_to_NNS=dict(
        base_query="inflection == 'random'",
        inflected_query="inflection == 'NNS'",
        all_query=all_query,
    ),
    random_to_VBZ=dict(
        base_query="inflection == 'random'",
        inflected_query="inflection == 'VBZ'",
        all_query=all_query,
    ),
    # NOTE(review): `iter_equivalences` below asserts that equivalence_keys include
    # "inflected_phones", which this list lacks -- check against the implementation
    # that actually runs these experiments.
    false_friends=dict(
        all_query="inflection.str.contains('FF')",
        group_by=["inflection"],
        equivalence_keys=["base", "inflected", "post_divergence"],
    ),
)

In [None]:
# Inflections whose top allomorphs are crossed against each other (every ordered
# pair of inflections, including an inflection with itself) in the
# "unambiguous-*" transfer experiments.
study_unambiguous_transfer = [
    "NNS",
    "VBZ",
]
# Inflections for which false-friend ("<inflection>-FF-*") transfer experiments
# are generated.
study_false_friends = [
    "NNS",
    "VBZ",
    "VBD",
]

In [None]:
# Forced-choice (FC) allomorph pairs to evaluate, each mapped to the inflections
# whose top allomorphs serve as the "from" side of the analogy.
fc_experiments = {
    fc_pair: {"source_inflections": sources}
    for fc_pair, sources in [
        (("Z", "S"), ["VBZ", "NNS"]),
        (("D", "T"), ["VBD"]),
        (("D", "IH D"), ["VBD"]),
        (("T", "IH D"), ["VBD"]),
    ]
}

## Prepare model representations

In [None]:
# Load the representations to analyze. The sentinel embeddings_path == "ID" selects
# the raw hidden states stored at `hidden_states_path`; otherwise a precomputed
# embedding matrix is read from the .npy file at `embeddings_path`.
if embeddings_path != "ID":
    model_representations: np.ndarray = np.load(embeddings_path)
else:
    model_representations = SpeechHiddenStateDataset.from_hdf5(hidden_states_path).states

# The state-space spec carries the labelled cuts (`.cuts`, with frame onset/offset
# indices) and the word labels (`.labels`) used to slice the representations.
state_space_spec = StateSpaceAnalysisSpec.from_hdf5(state_space_specs_path)
assert state_space_spec.is_compatible_with(model_representations)

In [None]:
trajectory = prepare_state_trajectory(model_representations, state_space_spec)
trajectory = aggregate_state_trajectory(trajectory, state_space_spec, agg_fns[0])

In [None]:
agg, agg_src = flatten_trajectory(trajectory)

## Prepare metadata

In [None]:
cuts_df = state_space_spec.cuts.xs("phoneme", level="level").drop(columns=["onset_frame_idx", "offset_frame_idx"])
cuts_df["label_idx"] = cuts_df.index.get_level_values("label").map({l: i for i, l in enumerate(state_space_spec.labels)})
cuts_df["frame_idx"] = cuts_df.groupby(["label", "instance_idx"]).cumcount()
cuts_df = cuts_df.reset_index().set_index(["label", "instance_idx", "frame_idx"]).sort_index()

In [None]:
cut_phonemic_forms = cuts_df.groupby(["label", "instance_idx"]).description.agg(' '.join)

In [None]:
word_freq_df = pd.read_csv("data/WorldLex_Eng_US.Freq.2.txt", sep="\t", index_col="Word")

# Compute a weighted average frequency across domains. Each domain's counts are first
# normalized to relative frequencies (so that larger corpora don't dominate), the
# relative frequencies are averaged, and the result is rescaled by the mean domain
# total so that `Freq` stays on a count-like scale.
freq_cols = ["BlogFreq", "TwitterFreq", "NewsFreq"]
rel_freq_cols = [f"{col}_rel" for col in freq_cols]
for col, rel_col in zip(freq_cols, rel_freq_cols):
    word_freq_df[rel_col] = word_freq_df[col] / word_freq_df[col].sum()
word_freq_df["Freq"] = word_freq_df[rel_freq_cols].mean(axis=1) \
    * word_freq_df[freq_cols].sum().mean()
# Note: words with zero frequency get LogFreq = -inf.
word_freq_df["LogFreq"] = np.log10(word_freq_df.Freq)

In [None]:
all_cross_instances = pd.read_parquet(all_cross_instances_path)

In [None]:
inflection_results_df = pd.read_parquet(inflection_results_path)

In [None]:
most_common_allomorphs = pd.read_csv(most_common_allomorphs_path)
false_friends_df = pd.read_csv(false_friends_path)

## Prepare inclusions / exclusions

### Include homophones of the target as valid predictions

In [None]:
from collections import defaultdict


# pron2label: pronunciation (space-joined phonemes) -> every label that has at
# least one instance pronounced that way.
pron2label = defaultdict(set)
for label, prons in cut_phonemic_forms.groupby("label"):
    for pron in prons.unique():
        pron2label[pron].add(label)

# homophone_map: label -> every label sharing any of its pronunciations
# (the label itself included).
homophone_map = defaultdict(set)
for label in state_space_spec.labels:
    for pron in cut_phonemic_forms.loc[label].unique():
        homophone_map[label].update(pron2label[pron])

In [None]:
# Label -> index lookup, built once instead of calling list.index (O(n) per call)
# inside the comprehension. setdefault keeps the first occurrence, matching list.index.
label_to_idx = {}
for idx, lbl in enumerate(state_space_spec.labels):
    label_to_idx.setdefault(lbl, idx)

# Map from inflected label idx -> idxs of every label sharing a pronunciation with it
# (the label itself included). Intended for `include_idxs_in_predictions`: retrieving
# any of these labels counts as a correct prediction, so homophones are not penalized.
include_inflected_map = {label_to_idx[label]: {label_to_idx[hom] for hom in homs}
                         for label, homs in homophone_map.items()}

### Exclude the base and all of its homophones from valid predictions

In [None]:
# Map from (inflection, inflected label idx) -> idxs of the base label and all of its
# homophones. Intended for `exclude_idxs_from_predictions`: these labels are removed
# from the candidate predictions before ranking.
label_to_idx = {}
for idx, lbl in enumerate(state_space_spec.labels):
    # keep the first occurrence, matching list.index semantics
    label_to_idx.setdefault(lbl, idx)

cross_keys = (
    all_cross_instances[["inflection", "base", "inflected", "base_idx", "inflected_idx"]]
    .dropna()  # mirror groupby's default of skipping rows with missing key values
    .drop_duplicates()
)
exclude_inflected_map = {}
for inflection, base, inflected_idx in cross_keys[["inflection", "base", "inflected_idx"]].itertuples(index=False, name=None):
    # One inflected form may be listed under several bases for the same inflection;
    # accumulate all of their exclusions instead of letting the last base win.
    exclude_inflected_map.setdefault((inflection, inflected_idx), set()).update(
        label_to_idx[hom] for hom in homophone_map.get(base, ())
    )

## Behavioral tests

In [None]:
# For each inflection, the (up to) 3 most frequent `most_common_allomorph` values,
# ordered from most to least frequent. These are the allomorphs studied below.
transfer_allomorphs = {
    inflection: allomorphs.value_counts().head(3).index.tolist()
    for inflection, allomorphs in most_common_allomorphs.groupby("inflection").most_common_allomorph
}

In [None]:
# Transfer between every top allomorph of each studied inflection and every top
# allomorph of every studied inflection (an inflection with itself included),
# restricted to regular forms whose base is not NN/VB-ambiguous.
for infl1, infl2 in itertools.product(study_unambiguous_transfer, repeat=2):
    for allomorph1, allomorph2 in itertools.product(transfer_allomorphs[infl1], transfer_allomorphs[infl2]):
        exp_name = f"unambiguous-{infl1}_{allomorph1}_to_{infl2}_{allomorph2}"
        experiments[exp_name] = dict(
            base_query=f"inflection == '{infl1}' and is_regular == True and base_ambig_NN_VB == False and post_divergence == '{allomorph1}'",
            inflected_query=f"inflection == '{infl2}' and is_regular == True and base_ambig_NN_VB == False and post_divergence == '{allomorph2}'",
            all_query=all_query,
        )

In [None]:
# False-friend experiments. For each (inflection, false-friend allomorph) stimulus set
# and each top allomorph of that inflection, transfer in both directions:
#   false friend -> regular inflection allomorph  (matching or non-matching)
#   regular inflection allomorph -> false friend  (matching or non-matching)
ff_groups = false_friends_df.groupby(["inflection", "post_divergence"])
for inflection, post_divergence in (key for key, _ in ff_groups):
    if inflection not in study_false_friends:
        continue

    # The NN/VB-ambiguity filter is only applied to NNS and VBZ.
    ambig_clause = "base_ambig_NN_VB == {ambig} and " if inflection in ["NNS", "VBZ"] else ""
    ambig_positive = ambig_clause.format(ambig="True")  # currently unused
    ambig_negative = ambig_clause.format(ambig="False")

    ff_inflection = f"{inflection}-FF-{post_divergence}"
    for transfer_allomorph in transfer_allomorphs[inflection]:
        regular_query = f"inflection == '{inflection}' and is_regular == True and {ambig_negative} post_divergence == '{transfer_allomorph}'"
        experiments[f"{ff_inflection}-to-{inflection}_{transfer_allomorph}"] = {
            "base_query": f"inflection == '{ff_inflection}'",
            "inflected_query": regular_query,
        }
        experiments[f"{inflection}_{transfer_allomorph}-to-{ff_inflection}"] = {
            "base_query": regular_query,
            "inflected_query": f"inflection == '{ff_inflection}'",
        }

# False friend -> false friend transfer between pairs of top allomorphs
# (one direction per unordered pair).
for inflection in study_false_friends:
    allomorph_pairs = itertools.combinations(transfer_allomorphs[inflection], 2)
    for t1, t2 in allomorph_pairs:
        experiments[f"{inflection}-FF-{t1}-to-{inflection}-FF-{t2}"] = dict(
            base_query=f"inflection == '{inflection}-FF-{t1}'",
            inflected_query=f"inflection == '{inflection}-FF-{t2}'",
        )

In [None]:
# Forced-choice (FC) stimuli carry inflection labels of the form "FC-<a>_<b>";
# recover the set of (a, b) allomorph pairs that are present. Both regex groups
# accept `\w`, which includes "_", so the split relies on backtracking; this is
# unambiguous as long as allomorph names contain no underscores.
fc_types = {
    re.findall(r"FC-([\w\s]+)_([\w\s]+)", infl)[0]
    for infl in all_cross_instances.inflection.unique()
    if infl.startswith("FC")
}

# For every configured FC pair, transfer from each top allomorph of each source
# inflection onto the FC stimuli.
for fc_pair, fc_config in fc_experiments.items():
    if fc_pair not in fc_types:
        raise ValueError(f"FC pair {fc_pair} not found in FC stimuli")
    fc_pair_name = "_".join(fc_pair)

    for source_inflection in fc_config["source_inflections"]:
        for source_allomorph in transfer_allomorphs[source_inflection]:
            experiments[f"FC-{fc_pair_name}-from_{source_inflection}-{source_allomorph}"] = {
                "base_query": f"inflection == '{source_inflection}' and post_divergence == '{source_allomorph}'",
                "inflected_query": f"inflection == 'FC-{fc_pair_name}'",
            }

In [None]:
instances_dev = all_cross_instances.drop(columns=["base", "base_idx", "base_phones", "base_instance_idx"]).drop_duplicates(["inflection", "inflected_idx", "inflected_instance_idx"])

In [None]:
instances_dev["divergence_phoneme_idx"] = instances_dev.inflected_phones.str.count(" ") - instances_dev.post_divergence.str.count(" ")
instances_dev["last_phoneme_idx"] = instances_dev.inflected_phones.str.count(" ")

In [None]:
instances_dev = instances_dev[instances_dev.divergence_phoneme_idx > 0]

In [None]:
from typing import Literal, Optional
import torch
from src.analysis.analogy import nxn_cos_sim
import logging

L = logging.getLogger(__name__)

def iter_equivalences(
        config, all_cross_instances, agg_src: np.ndarray,
        num_samples=100, max_num_vector_samples=250,
        divergence_index: Literal["first", "last"] = "last",
        seed=None,):
    """
    Sample pairs of inflected-form equivalence classes and yield the flat indices
    needed to evaluate a representational analogy between them.

    For each group (see ``config["group_by"]``), rows are split into a "from" set and
    a "to" set, and each set is partitioned into equivalence classes (by default one
    per ``(inflected, inflected_phones)`` form). Up to ``num_samples`` pairs of
    distinct (from, to) classes are sampled. For each pair, the same number of
    instances (at most ``max_num_vector_samples``) is drawn from both sides, with
    replacement when a side has fewer rows.

    Args:
        config: experiment config. Recognized keys:
            - ``group_by``: columns to group by.
            - ``base_query`` / ``inflected_query``: pandas queries selecting the
              "from" / "to" rows.
            - ``all_query``: query applied to both.
            - ``equivalence_keys``: columns defining equivalence classes; must
              include ``inflected`` and ``inflected_phones``.
        all_cross_instances: instance table. It must provide at least
            ``inflection``, ``inflected``, ``inflected_phones``, ``inflected_idx``,
            ``inflected_instance_idx`` and ``divergence_phoneme_idx``. With
            ``divergence_index="last"`` it also needs ``last_phoneme_idx``. Any
            columns referenced by the config's queries or keys are required too.
        agg_src: numpy array or torch tensor whose row ``i`` is
            ``(label_idx, instance_idx, phoneme_idx)`` for flat representation ``i``.
        num_samples: maximum number of class pairs sampled per group.
        max_num_vector_samples: maximum number of instances drawn per side of a pair.
        divergence_index: which phoneme supplies the "from" inflected representation:
            ``"first"`` (first post-divergence phoneme) or ``"last"`` (last phoneme).
        seed: if given, seeds NumPy's global RNG. ``DataFrame.sample`` also falls
            back to that RNG when no ``random_state`` is passed.

    Yields:
        A dict for each sampled pair, containing:
            - labels/idxs, phone strings, inflections and equivalence labels;
            - int64 tensors ``from_inflected_flat_idx``, ``from_base_flat_idx`` and
              ``to_base_flat_idx``, which index rows of the flat representation matrix.

    Raises:
        ValueError: if ``divergence_index`` is not ``"first"`` or ``"last"``.
    """
    # Column holding the phoneme position that represents the "from" inflected form.
    inflected_phoneme_cols = {"first": "divergence_phoneme_idx", "last": "last_phoneme_idx"}
    if divergence_index not in inflected_phoneme_cols:
        raise ValueError(
            f"divergence_index must be 'first' or 'last', got {divergence_index!r}")
    inflected_phoneme_col = inflected_phoneme_cols[divergence_index]

    # Pre-compute lookup from (label idx, instance idx, phoneme idx) to flat idx
    if isinstance(agg_src, torch.Tensor):
        agg_src = agg_src.cpu().numpy()
    flat_idx_lookup = {(label_idx, instance_idx, phoneme_idx): flat_idx
                       for flat_idx, (label_idx, instance_idx, phoneme_idx) in enumerate(agg_src)}

    def _flat_idxs(rows, phoneme_col, offset=0):
        # Flat idx of (inflected_idx, inflected_instance_idx, rows[phoneme_col] + offset)
        # for every row. A column-wise zip avoids iterrows' per-row Series construction
        # and dtype upcasting.
        return torch.tensor([
            flat_idx_lookup[(label_idx, instance_idx, phoneme_idx + offset)]
            for label_idx, instance_idx, phoneme_idx in zip(
                rows["inflected_idx"], rows["inflected_instance_idx"], rows[phoneme_col])
        ])

    def _resize(rows, n):
        # Return exactly n rows: upsample with replacement if short, subsample
        # without replacement if long, otherwise keep as-is.
        if len(rows) < n:
            return rows.sample(n, replace=True)
        if len(rows) > n:
            return rows.sample(n, replace=False)
        return rows

    if seed is not None:
        np.random.seed(seed)

    if "group_by" in config:
        grouper = all_cross_instances.groupby(config["group_by"])
    else:
        grouper = [("", all_cross_instances)]

    for group, rows in tqdm(grouper, leave=False):
        print(group)

        rows_from = rows.query(config["base_query"]) if "base_query" in config else rows
        rows_to = rows.query(config["inflected_query"]) if "inflected_query" in config else rows
        if "all_query" in config:
            rows_from = rows_from.query(config["all_query"])
            rows_to = rows_to.query(config["all_query"])

        # Nothing to transfer from / to in this group.
        if len(rows_from) == 0 or len(rows_to) == 0:
            continue

        inflection_from = rows_from.inflection.iloc[0]
        inflection_to = rows_to.inflection.iloc[0]

        # Equivalences define the sets of instances over which representations can be
        # pooled before computing the analogy.
        equivalence_keys = config.get("equivalence_keys", ["inflected", "inflected_phones"])
        # we must group on at least the forms themselves
        assert {"inflected", "inflected_phones"} <= set(equivalence_keys)

        from_equiv = rows_from.groupby(equivalence_keys)
        to_equiv = rows_to.groupby(equivalence_keys)
        # `count >= 1` drops empty groups (e.g. unobserved categorical combinations).
        from_equiv_labels = [k for k, count in from_equiv.size().items() if count >= 1]
        to_equiv_labels = [k for k, count in to_equiv.size().items() if count >= 1]

        if len(set(from_equiv_labels) | set(to_equiv_labels)) <= 1:
            # not enough labels to support transfer.
            L.error(f"Skipping {group} due to insufficient labels")
            continue

        # Make sure labels are tuples (single-key groupbys may yield scalar keys).
        if from_equiv_labels and not isinstance(from_equiv_labels[0], tuple):
            from_equiv_labels = [(label,) for label in from_equiv_labels]
        if to_equiv_labels and not isinstance(to_equiv_labels[0], tuple):
            to_equiv_labels = [(label,) for label in to_equiv_labels]

        # sample pairs of distinct equivalence classes
        candidate_pairs = [(x, y) for x, y in itertools.product(from_equiv_labels, to_equiv_labels) if x != y]
        num_samples_i = min(num_samples, len(candidate_pairs))
        samples = np.random.choice(len(candidate_pairs), num_samples_i, replace=False)

        for idx in tqdm(samples, leave=False):
            from_equiv_label_i, to_equiv_label_i = candidate_pairs[idx]
            rows_from_i = from_equiv.get_group(tuple(from_equiv_label_i))
            rows_to_i = to_equiv.get_group(tuple(to_equiv_label_i))

            # draw equally many instances from both sides for comparison
            n = min(max_num_vector_samples, max(len(rows_from_i), len(rows_to_i)))
            rows_from_i = _resize(rows_from_i, n)
            rows_to_i = _resize(rows_to_i, n)

            from_label = rows_from_i.inflected.iloc[0]
            from_idx = rows_from_i.inflected_idx.iloc[0]
            to_label = rows_to_i.inflected.iloc[0]
            to_idx = rows_to_i.inflected_idx.iloc[0]

            # split each inflected form into its "base" (pre-divergence) and
            # post-divergence phones
            from_inflected_phones = rows_from_i.inflected_phones.iloc[0].split(" ")
            from_divergence = rows_from_i.divergence_phoneme_idx.iloc[0]
            from_base_phones = from_inflected_phones[:from_divergence]
            from_post_divergence = from_inflected_phones[from_divergence:]
            to_inflected_phones = rows_to_i.inflected_phones.iloc[0].split(" ")
            to_divergence = rows_to_i.divergence_phoneme_idx.iloc[0]
            to_base_phones = to_inflected_phones[:to_divergence]
            to_post_divergence = to_inflected_phones[to_divergence:]

            # representation of the "from" inflected form at the chosen phoneme
            from_inflected_flat_idx = _flat_idxs(rows_from_i, inflected_phoneme_col)
            # TODO design choice: do we take repr from previous phoneme or averaged over all previous
            # phonemes?
            from_base_flat_idx = _flat_idxs(rows_from_i, "divergence_phoneme_idx", offset=-1)
            to_base_flat_idx = _flat_idxs(rows_to_i, "divergence_phoneme_idx", offset=-1)

            yield {
                "group": group,

                "from_label": from_label,
                "from_idx": from_idx,
                "to_label": to_label,
                "to_idx": to_idx,

                "from_inflected_phones": " ".join(from_inflected_phones),
                "from_base_phones": " ".join(from_base_phones),
                "from_post_divergence": " ".join(from_post_divergence),

                "to_inflected_phones": " ".join(to_inflected_phones),
                "to_base_phones": " ".join(to_base_phones),
                "to_post_divergence": " ".join(to_post_divergence),

                "inflection_from": inflection_from,
                "inflection_to": inflection_to,
                "from_equiv_label_i": from_equiv_label_i,
                "to_equiv_label_i": to_equiv_label_i,

                "from_inflected_flat_idx": from_inflected_flat_idx,
                "from_base_flat_idx": from_base_flat_idx,
                "to_base_flat_idx": to_base_flat_idx,
            }

def run_experiment_equiv_level(
        experiment_name, config,
        state_space_spec, all_cross_instances,
        agg, agg_src,
        device: str = "cpu",
        verbose=False,
        num_samples=100, max_num_vector_samples=250,
        seed=None,
        exclude_idxs_from_predictions: Optional[dict[tuple[str, int], set[int]]] = None,
        include_idxs_in_predictions: Optional[dict[int, set[int]]] = None,
        divergence_index: Literal["first", "last"] = "last"):
    """
    Run one analogy experiment and score each sampled pair by nearest-neighbour retrieval.

    For every pair yielded by ``iter_equivalences``, the prediction is
    ``to_base + (from_inflected - from_base)`` per sampled instance, L2-normalized.
    Cosine distances from the predictions to every flat reference are averaged over
    the pair's instances, and all references are ranked nearest first.

    Args:
        experiment_name: printed before running.
        config: experiment config passed through to ``iter_equivalences``.
        state_space_spec: object whose ``labels`` list maps label idxs to labels.
        all_cross_instances: instance table passed through to ``iter_equivalences``.
        agg: flat representation matrix (one row per reference).
        agg_src: rows ``(label_idx, instance_idx, phoneme_idx)`` aligned with ``agg``.
        device: torch device on which distances are computed.
        verbose: print the five nearest references for each pair.
        num_samples, max_num_vector_samples, seed, divergence_index: passed through
            to ``iter_equivalences``.
        exclude_idxs_from_predictions: optional map from
            ``(inflection_to, to_idx)`` to label idxs. References carrying any of
            these *labels* are dropped from the ranking.
        include_idxs_in_predictions: optional map from ``to_idx`` to label idxs.
            Any of these labels counts as the ground truth (e.g. homophones).
            Without this map, only ``to_idx`` itself counts.

    Returns:
        DataFrame with one row per sampled pair. Each row holds the pair metadata,
        the nearest neighbour's label and instance, and two values for the
        best-ranked ground-truth reference: its rank within the (possibly filtered)
        ranking and its distance.

    Raises:
        IndexError: if no ground-truth reference survives exclusion.
        KeyError: if a prediction map lacks an entry for a sample's target.
    """
    print(experiment_name)

    # move data to device
    agg = torch.tensor(agg).to(device)
    agg_src = torch.tensor(agg_src).to(device)

    results = []
    for sample in iter_equivalences(
            config, all_cross_instances, agg_src,
            num_samples=num_samples,
            max_num_vector_samples=max_num_vector_samples,
            divergence_index=divergence_index,
            seed=seed):

        from_inflected_flat_idx = sample["from_inflected_flat_idx"]
        from_base_flat_idx = sample["from_base_flat_idx"]
        to_base_flat_idx = sample["to_base_flat_idx"]

        # Critical analogy logic: predicted = to_base + (from_inflected - from_base)
        pair_difference = agg[from_inflected_flat_idx] - agg[from_base_flat_idx]
        pair_base = agg[to_base_flat_idx]

        pair_predicted = pair_base + pair_difference
        pair_predicted /= torch.norm(pair_predicted, dim=1, keepdim=True)

        references, references_src = agg, agg_src
        with torch.no_grad():
            # assumes nxn_cos_sim returns (n_predictions, n_references) similarities
            dists = 1 - nxn_cos_sim(pair_predicted, references)
            # mean over instances of pair
            dists = dists.mean(0)
        # Flat reference indices, nearest first.
        ranks = dists.argsort()

        if exclude_idxs_from_predictions is not None:
            # The map holds *label* indices, so filter on each ranked reference's
            # label (column 0 of agg_src), not on the flat reference index itself.
            excluded_label_idxs = torch.tensor(list(exclude_idxs_from_predictions[sample["inflection_to"], sample["to_idx"]]))
            ranks = ranks[~torch.isin(references_src[ranks, 0], excluded_label_idxs.to(references_src))]
        if include_idxs_in_predictions is not None:
            valid_idxs = torch.tensor(list(include_idxs_in_predictions[sample["to_idx"]]))
            gt_rank = torch.where(torch.isin(references_src[ranks, 0], valid_idxs.to(ranks)))[0][0].item()
        else:
            gt_rank = torch.where(references_src[ranks, 0] == sample["to_idx"])[0][0].item()

        # gt_rank is a position within `ranks`; map it back to the flat reference index
        # before looking up its distance / phoneme.
        gt_flat_idx = ranks[gt_rank]
        gt_distance = dists[gt_flat_idx].item()
        # NOTE(review): phoneme idx of the best-ranked ground-truth reference (not of
        # the nearest neighbour) -- confirm this is what "predicted_phoneme_idx" means.
        predicted_phoneme_idx = references_src[gt_flat_idx, 2].item()

        if verbose:
            for dist, (label_idx, instance_idx, _) in zip(dists[ranks[:5]], references_src[ranks[:5]]):
                print(f"{sample['group']} {sample['from_equiv_label_i']} -> {sample['to_equiv_label_i']}: {state_space_spec.labels[label_idx]} {instance_idx}")

        nearest_neighbor = references_src[ranks[0]]
        results.append({
            "group": sample["group"],

            "from": sample["from_label"],
            "to": sample["to_label"],

            "from_inflected_phones": sample["from_inflected_phones"],
            "from_base_phones": sample["from_base_phones"],
            "from_post_divergence": sample["from_post_divergence"],

            "to_inflected_phones": sample["to_inflected_phones"],
            "to_base_phones": sample["to_base_phones"],
            "to_post_divergence": sample["to_post_divergence"],

            "inflection_from": sample["inflection_from"],
            "inflection_to": sample["inflection_to"],
            "from_equiv_label": sample["from_equiv_label_i"],
            "to_equiv_label": sample["to_equiv_label_i"],

            "predicted_label_idx": nearest_neighbor[0].item(),
            "predicted_label": state_space_spec.labels[nearest_neighbor[0]],
            "predicted_instance_idx": nearest_neighbor[1].item(),
            "predicted_phoneme_idx": predicted_phoneme_idx,
            "gt_label": sample["to_label"],
            "gt_label_idx": sample["to_idx"],
            "gt_label_rank": gt_rank,
            "gt_distance": gt_distance,
        })

    return pd.DataFrame(results)

In [None]:
# experiment = "basic"
# config = experiments[experiment]
# ret = run_experiment_equiv_level(
#     experiment, config, state_space_spec, instances_dev,
#     agg, agg_src,
#     num_samples=20,
#     device="cpu",
#     include_idxs_in_predictions=include_inflected_map,
#     exclude_idxs_from_predictions=exclude_inflected_map,
# )

In [None]:
# NOTE(review): this calls the library `analogy.run_experiment_equiv_level` (not the
# local definition above) on `all_cross_instances` (not `instances_dev`), and passes
# neither `include_inflected_map` nor `exclude_inflected_map` -- confirm intended.
per_experiment_results = {}
for experiment, config in tqdm(experiments.items(), unit="experiment"):
    per_experiment_results[experiment] = analogy.run_experiment_equiv_level(
        experiment, config,
        state_space_spec, all_cross_instances,
        agg, agg_src,
        num_samples=1000,
        seed=seed,
        device="cuda")
experiment_results = pd.concat(per_experiment_results, names=["experiment"])
# A prediction is correct only when the nearest neighbour's label string equals the target's.
experiment_results["correct"] = experiment_results.predicted_label == experiment_results.gt_label
experiment_results

### Save

In [None]:
# Persist all results (index included, as pandas does by default) as CSV.
results_csv_path = f"{output_dir}/experiment_results.csv"
experiment_results.to_csv(results_csv_path)