In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from collections import defaultdict, Counter
import functools
import itertools
from pathlib import Path
import pickle
import re

from fastdist import fastdist
import lemminflect
import matplotlib.pyplot as plt
from matplotlib import transforms
import nltk
import numpy as np
import pandas as pd
from scipy.spatial.distance import cdist
import seaborn as sns
import torch
from tqdm import tqdm

from src.analysis import analogy
from src.analysis.state_space import StateSpaceAnalysisSpec, \
    prepare_state_trajectory, aggregate_state_trajectory, flatten_trajectory

In [None]:
# --- Model / dataset identifiers; these are interpolated into the paths below ---
base_model = "w2v2_8"

model_class = "discrim-rnn_32-mAP1"
model_name = "word_broad_10frames_fixedlen25"

# Alternative model configuration (disabled):
# model_class = "ff_32"
# model_name = "word_broad_10frames"

train_dataset = "librispeech-train-clean-100"
# --- Artefact locations, all relative to the working directory ---
model_dir = f"outputs/models/{train_dataset}/{base_model}/{model_class}/{model_name}"
# The false-friends cache and the serialized analysis results are written here.
output_dir = f"."
dataset_path = f"outputs/preprocessed_data/{train_dataset}"
equivalence_path = f"outputs/equivalence_datasets/{train_dataset}/{base_model}/{model_name}/equivalence.pkl"
hidden_states_path = f"outputs/hidden_states/{train_dataset}/{base_model}/{train_dataset}.h5"
state_space_specs_path = f"outputs/state_space_specs/{train_dataset}/{base_model}/state_space_specs.h5"
embeddings_path = f"outputs/model_embeddings/{train_dataset}/{base_model}/{model_class}/{model_name}/{train_dataset}.npy"

pos_counts_path = "data/pos_counts.pkl"

# NOTE(review): `seed` is not referenced in the visible cells (the experiment
# runner further down is called with a literal seed=42) -- confirm which is intended.
seed = 1234

# Passed to `state_space_spec.subsample_instances` below.
max_samples_per_word = 100

metric = "cosine"

# Aggregation functions applied to the state trajectory; only "mean" is read downstream.
agg_fns = [
    "mean",
]

In [None]:
# Load the precomputed model embeddings and the word-level state space spec,
# then fail fast if the two are not compatible with each other.
with open(embeddings_path, "rb") as f:
    model_representations: np.ndarray = np.load(f)
state_space_spec = StateSpaceAnalysisSpec.from_hdf5(state_space_specs_path, "word")
assert state_space_spec.is_compatible_with(model_representations)

In [None]:
# model_representations /= np.linalg.norm(model_representations, axis=1, keepdims=True)

In [None]:
# Subsample instances using the configured cap (presumably at most
# `max_samples_per_word` instances per word -- confirm in StateSpaceAnalysisSpec).
# NOTE: this rebinds `state_space_spec`, so re-running the cell subsamples the
# already-subsampled spec.
state_space_spec = state_space_spec.subsample_instances(max_samples_per_word)

In [None]:
# Build the state trajectory from the embeddings, using NaN as the pad value.
trajectory = prepare_state_trajectory(model_representations, state_space_spec, pad=np.nan)

In [None]:
# Aggregate the trajectory once per configured aggregation function,
# keyed by the function name.
trajectory_aggs = {agg_fn: aggregate_state_trajectory(trajectory, state_space_spec, agg_fn, keepdims=True)
                   for agg_fn in tqdm(agg_fns)}

In [None]:
# Flatten each aggregated trajectory; each value is unpacked below as a
# (representations, source-index) pair.
trajectory_aggs_flat = {k: flatten_trajectory(v) for k, v in trajectory_aggs.items()}

In [None]:
# `agg_src[:, 0]` is compared against label indices and `agg_src[:, 1]` is used
# as an instance index in the cross-instance loops further down.
agg, agg_src = trajectory_aggs_flat["mean"]

In [None]:
# Phoneme-level cuts, indexed by (label, instance_idx, frame_idx), where
# frame_idx is the running position of the phoneme within its word instance.
cuts_df = (
    state_space_spec.cuts
    .xs("phoneme", level="level")
    .drop(columns=["onset_frame_idx", "offset_frame_idx"])
    .assign(
        # integer index of each word label in the spec's label list
        label_idx=lambda df: df.index.get_level_values("label").map(
            {l: i for i, l in enumerate(state_space_spec.labels)}),
        # position of each phoneme within its (label, instance) group
        frame_idx=lambda df: df.groupby(["label", "instance_idx"]).cumcount(),
    )
    .reset_index()
    .set_index(["label", "instance_idx", "frame_idx"])
    .sort_index()
)

In [None]:
# Space-joined phoneme string for every (label, instance_idx) word instance.
cut_phonemic_forms = cuts_df.groupby(["label", "instance_idx"]).description.agg(' '.join)

In [None]:
# Word frequency norms, indexed by word.
word_freq_df = pd.read_csv("data/WorldLex_Eng_US.Freq.2.txt", sep="\t", index_col="Word")
# compute weighted average frequency across domains:
# 1. normalise each domain's counts to relative frequencies so that the
#    larger corpora do not dominate the average ...
word_freq_df["BlogFreq_rel"] = word_freq_df.BlogFreq / word_freq_df.BlogFreq.sum()
word_freq_df["TwitterFreq_rel"] = word_freq_df.TwitterFreq / word_freq_df.TwitterFreq.sum()
word_freq_df["NewsFreq_rel"] = word_freq_df.NewsFreq / word_freq_df.NewsFreq.sum()
# 2. ... then average the *relative* frequencies and rescale by the mean corpus
#    size to get back to a count-like scale. (Previously the raw counts were
#    averaged here, leaving the *_rel columns unused.)
word_freq_df["Freq"] = word_freq_df[["BlogFreq_rel", "TwitterFreq_rel", "NewsFreq_rel"]].mean(axis=1) \
    * word_freq_df[["BlogFreq", "TwitterFreq", "NewsFreq"]].sum().mean()
# NOTE(review): a word with zero count in all three domains yields -inf here.
word_freq_df["LogFreq"] = np.log10(word_freq_df.Freq)

## Experiment setup

### Pre-compute other auxiliary features

### Token-level features

#### Post-divergence analysis

In [None]:
# One row per inflected-word instance.
# NOTE: `inflection_instances` and `inflection_results_df` are produced by
# cells that precede this one.
inflection_instance_df = pd.DataFrame(inflection_instances)

# Now merge with type-level information.
# Left join: every instance row is kept even if it has no type-level match.
inflection_instance_df = pd.merge(inflection_instance_df,
                                  inflection_results_df.reset_index(),
                                  how="left",
                                  on=["inflection", "base", "inflected"])
inflection_instance_df

In [None]:
# return most-common-allomorph information to main df
# NB this may collapse across different orthographic inflected forms
# For each (inflection, base), take the modal `post_divergence` value across instances.
most_common_allomorphs = inflection_instance_df.groupby(["inflection", "base"]).post_divergence \
    .apply(lambda xs: xs.value_counts().idxmax()) \
    .rename("most_common_allomorph").reset_index()
# NOTE: the merged frame below is only displayed as the cell output; it is not
# assigned back to `inflection_results_df`.
pd.merge(inflection_results_df, most_common_allomorphs,
         on=["inflection", "base"], validate="m:1")

In [None]:
# Enumerate, for every (inflection, base, inflected) type, all available
# instances of the inflected form and of the base form, together with the
# phoneme string of each instance.
inflection_cross_instances = []
base_cross_instances = []

for inflection, row in tqdm(inflection_results_df.iterrows(), total=len(inflection_results_df)):
    # rows of the flattened aggregate that belong to the inflected word label
    inflected_flat_idxs = np.nonzero(agg_src[:, 0] == row.inflected_idx)[0]
    inflected_forms = cut_phonemic_forms.loc[row.inflected]
    for inflected_flat_idx in inflected_flat_idxs:
        inflected_instance_idx = agg_src[inflected_flat_idx, 1]
        inflection_cross_instances.append({
            "inflection": inflection,
            "base": row.base,
            "inflected": row.inflected,
            "inflected_instance_idx": inflected_instance_idx,
            "inflected_phones": inflected_forms.loc[inflected_instance_idx]
        })

    # same enumeration for the base word label
    base_flat_idxs = np.nonzero(agg_src[:, 0] == row.base_idx)[0]
    base_forms = cut_phonemic_forms.loc[row.base]
    for base_flat_idx in base_flat_idxs:
        base_instance_idx = agg_src[base_flat_idx, 1]
        base_cross_instances.append({
            "inflection": inflection,
            "base": row.base,
            "inflected": row.inflected,
            "base_instance_idx": base_instance_idx,
            "base_phones": base_forms.loc[base_instance_idx]
        })

In [None]:
# add in post-divergence information
inflection_cross_instances_df = pd.DataFrame(inflection_cross_instances)
merge_on = ["inflection", "base", "inflected", "inflected_instance_idx"]
inflection_cross_instances_df = pd.merge(inflection_cross_instances_df,
                                         inflection_instance_df[merge_on + ["post_divergence"]],
                                         on=merge_on)

all_cross_instances = pd.merge(pd.DataFrame(base_cross_instances),
         inflection_cross_instances_df,
         on=["inflection", "base", "inflected"],
         how="outer")

# Now merge with type-level information.
all_cross_instances = pd.merge(inflection_results_df.reset_index(),
                               all_cross_instances,
                               on=["inflection", "base", "inflected"],
                               validate="1:m")

all_cross_instances["exclude_main"] = False
all_cross_instances

### Add false friends

In [None]:
def compute_false_friends():
    """
    Build a false-friends frame for each (inflection, allomorph) pair.

    For every inflection other than "random" and "NOT-latin", the three most
    frequent values of `most_common_allomorph` are selected, and
    `analogy.prepare_false_friends` is called once per (inflection, allomorph).

    Reads the notebook globals `most_common_allomorphs`,
    `inflection_results_df`, `inflection_instance_df` and `cut_phonemic_forms`.

    Returns:
        dict mapping (inflection, post_divergence) to the value returned by
        `analogy.prepare_false_friends`. Pairs for which that call raises are
        reported on stdout and omitted from the result.
    """
    false_friends_dfs = {}
    inflection_allomorph_grouper = most_common_allomorphs \
        [~most_common_allomorphs.inflection.isin(("random", "NOT-latin"))] \
        .groupby("inflection").most_common_allomorph \
        .apply(lambda xs: xs.value_counts().head(3)).index
    for inflection, post_divergence in tqdm(inflection_allomorph_grouper):
        # Never draw false friends from possessives or from the inflection itself.
        avoid_inflections = {"POS", inflection}
        # NNS and VBZ share the same surface allomorphs, so each must avoid the other.
        if inflection == "NNS":
            avoid_inflections.add("VBZ")
        elif inflection == "VBZ":
            avoid_inflections.add("NNS")
        avoid_inflections = list(avoid_inflections)

        try:
            false_friends_dfs[inflection, post_divergence] = \
                analogy.prepare_false_friends(
                    inflection_results_df,
                    inflection_instance_df,
                    cut_phonemic_forms,
                    post_divergence,
                    avoid_inflections=avoid_inflections)
        except Exception as exc:
            # Best-effort: skip this pair but say why. Catching `Exception`
            # (not a bare `except`) lets KeyboardInterrupt / SystemExit
            # still stop the loop.
            print("Failed for", inflection, post_divergence, "-", repr(exc))
            continue

    return false_friends_dfs

# Cache the false-friends computation on disk; delete the file to force a recompute.
# NOTE(review): the cached object is a dict of non-tensor objects, so torch.load
# unpickles arbitrary Python objects here. Only load caches you produced yourself;
# newer torch releases may require `weights_only=False` for this -- confirm
# against the installed version.
false_friends_cache_path = Path(output_dir) / "false_friends.pt"
if false_friends_cache_path.exists():
    false_friends_dfs = torch.load(false_friends_cache_path)
else:
    false_friends_dfs = compute_false_friends()
    torch.save(false_friends_dfs, false_friends_cache_path)

In [None]:
# Stack the per-(inflection, allomorph) frames; the dict keys become the
# ("inflection", "post_divergence") index levels and the innermost level is dropped.
false_friends_df = pd.concat(false_friends_dfs, names=["inflection", "post_divergence"]).droplevel(-1)

# manually exclude some cases that don't get filtered out, often just because they're too
# low frequency for both true base and inflected form to appear

# share exclusion list for NNS and VBZ since we have experiments relating these two
# so this is any false-friend for which there is a phonologically identical "base"
# that could instantiate a VBZ or NNS inflection
exclude_NNS_VBZ = ("adds americans arabs assyrians berries carlyle's childs christians "
                   "counties cruise dares dealings delawares europeans excellencies "
                   "fins fours galleries gaze germans indians isles maids mary's negroes "
                   "nuns peas phrase pyes reflections rodgers romans russians simpsons "
                   "spaniards sundays vickers weeds wigwams williams "
                   "jews odds news hose dis yes ice cease peace s us "
                    
                   "greeks lapse mix philips trunks its "
                    
                   "breeches occurrences personages").split()
false_friends_manual_exclude = {
    "NNS": exclude_NNS_VBZ,
    "VBZ": exclude_NNS_VBZ,
    "VBD": ("armored bald bard counseled crude dared enquired healed knowed legged "
            "mourned natured renowned rude second ward wild willed withered "

            "tract wrapped fitted hearted heralded intrusted knitted wretched").split(),
    "VBG": ("ceiling daring fleeting morning roaming wasting weaving weighing "
            "whining willing chuckling kneeling sparkling startling").split()
}

# Drop the manually excluded "inflected" forms, per inflection. Inflections
# without an entry in the exclusion dict are left untouched.
false_friends_df = false_friends_df.groupby("inflection", as_index=False).apply(
    lambda xs: xs[~xs.inflected.isin(false_friends_manual_exclude.get(xs.name, []))]).droplevel(0)

# exclude the (quite interesting) cases where the "base" and "inflected" form are
# actually orthographically matched, and we're seeing the divergence due to a pronunciation
# variant (e.g. don't as D OW N vs D OW N T)
false_friends_df = false_friends_df[false_friends_df.base != false_friends_df.inflected]

false_friends_df

In [None]:
# Phone classes (CMUDICT symbols) used for the final-phoneme checks below.
SIBILANTS = {"S", "Z", "SH", "CH", "JH", "ZH"}
ALVEOLAR_STOPS = {"T", "D"}
# Voiceless consonants, defined ONCE and shared by both guessers. This used to
# be defined twice with different contents; the second definition (which lacked
# "T") shadowed the first at call time, so plural/3sg guesses for bases ending
# in /T/ came out as "Z" instead of "S". Each guesser tests its more specific
# class (sibilants / alveolar stops) before consulting this set, so one shared
# set is correct for both.
VOICELESS = {"P", "T", "K", "F", "TH", "S", "SH", "CH"}


def guess_nns_vbz_allomorph(base_phones):
    """
    Given a list of CMUDICT phones for a base form,
    return the 'expected' post-divergence allomorph
    ("S", "Z", or "IH Z") for the English plural / 3sg verb.

    The decision uses only the final phone of `base_phones`;
    an empty list raises IndexError.
    """
    last_phone = base_phones[-1]

    if last_phone in SIBILANTS:
        # e.g., 'CH' -> "IH Z"
        return "IH Z"
    elif last_phone in VOICELESS:
        # e.g., 'K', 'P', 'T' -> "S"
        return "S"
    else:
        # default to voiced => "Z"
        return "Z"


def guess_past_allomorph(base_phones):
    """
    Given a list of CMUDICT phones for a base form,
    return the 'expected' post-divergence allomorph
    ("T", "D", or "IH D") for the English past tense.

    The decision uses only the final phone of `base_phones`;
    an empty list raises IndexError.
    """
    last_phone = base_phones[-1]

    if last_phone in ALVEOLAR_STOPS:
        # E.g., "want" -> "wanted" => "IH D"
        return "IH D"
    elif last_phone in VOICELESS:
        # E.g., "jump" -> "jumped" => "T"
        return "T"
    else:
        # default to voiced => "D"
        return "D"


# Record the allomorph that regular morphology would predict for each false
# friend's base form. A false friend is "strong" when its actual
# post-divergence material equals that prediction.
# Only NNS, VBZ and VBD rows receive a prediction; rows of any other
# inflection keep a missing `strong_expected` and so compare as not strong.
# NOTE(review): the list-based .loc lookups raise KeyError if one of these
# inflections is absent from the index -- confirm all three are always present.
false_friends_df.loc[["NNS", "VBZ"], "strong_expected"] = false_friends_df.loc[["NNS", "VBZ"]].apply(lambda xs: guess_nns_vbz_allomorph(xs.base_form.split(" ")), axis=1)
false_friends_df.loc[["VBD"], "strong_expected"] = false_friends_df.loc[["VBD"]].apply(lambda xs: guess_past_allomorph(xs.base_form.split(" ")), axis=1)
false_friends_df["strong"] = false_friends_df.index.get_level_values("post_divergence") == false_friends_df.strong_expected

In [None]:
# Number of false friends per inflection / allomorph, split by strong vs. not strong.
false_friends_df.groupby(["inflection", "post_divergence", "strong"]).size()

In [None]:
# Spot-check: VBZ false friends whose post-divergence material is "S".
false_friends_df.query("inflection == 'VBZ' and post_divergence == 'S'")

In [None]:
# Spot-check: a 10% random sample from every (inflection, strong, strong_expected) group.
# No random_state is passed, so the sample differs between runs.
false_friends_df.groupby(["inflection", "strong", "strong_expected"]).sample(frac=.1)

In [None]:
# Expand each false-friend type into instance pairs: first attach every base
# instance whose phoneme string equals `base_form`, then every inflected
# instance whose phoneme string equals `inflected_form`.
# Both are left joins, so types without a matching instance are kept with
# missing instance indices.
cross_false_friends_df = pd.merge(false_friends_df.reset_index(),
         cut_phonemic_forms.reset_index().rename(
             columns={"label": "base", "description": "base_form",
                      "instance_idx": "base_instance_idx"}),
         on=["base", "base_form"], how="left")
cross_false_friends_df = pd.merge(cross_false_friends_df,
         cut_phonemic_forms.reset_index().rename(
             columns={"label": "inflected", "description": "inflected_form",
                      "instance_idx": "inflected_instance_idx"}),
         on=["inflected", "inflected_form"], how="left")

# update to match all_cross_instances schema
cross_false_friends_df = cross_false_friends_df.rename(
    columns={"base_form": "base_phones",
             "inflected_form": "inflected_phones"})
cross_false_friends_df["base_idx"] = cross_false_friends_df.base.map({l: i for i, l in enumerate(state_space_spec.labels)})
cross_false_friends_df["inflected_idx"] = cross_false_friends_df.inflected.map({l: i for i, l in enumerate(state_space_spec.labels)})
cross_false_friends_df["is_regular"] = True

# Relabel the inflection as "<inflection>-FF-<post_divergence>"; the experiment
# queries below select false friends by this label.
cross_false_friends_df["inflection"] = (cross_false_friends_df.inflection + "-FF-").str.cat(cross_false_friends_df.post_divergence, sep="")
# Keep false friends out of every experiment that filters on "not exclude_main".
cross_false_friends_df["exclude_main"] = True
cross_false_friends_df

In [None]:
# Append the false-friend instance pairs to the real inflection pairs.
# NOTE: not idempotent -- re-running this cell appends the false friends again.
all_cross_instances = pd.concat([all_cross_instances, cross_false_friends_df], axis=0)

## Behavioral tests

In [None]:
# general queries for all experiments to exclude special edge cases;
# logic doesn't make sense in most experiments
all_query = "not exclude_main"

# Experiment name -> configuration, passed to `analogy.run_experiment_equiv_level`.
# NOTE(review): the keys appear to mean the following (confirm in src/analysis/analogy):
#   group_by         -- columns defining the groups evaluated separately
#   base_query       -- pandas query selecting the "from" side of the analogy
#   inflected_query  -- pandas query selecting the "to" side of the analogy
#   all_query        -- pandas query applied to every row considered
#   equivalence_keys -- columns identifying an equivalent item
experiments = {
    "basic": {
        "group_by": ["inflection"],
        "all_query": all_query,
    },
    "regular": {
        "group_by": ["inflection", "is_regular"],
        "all_query": all_query,
    },
    # "NNS_to_VBZ": {
    #     "base_query": "inflection == 'NNS' and is_regular",
    #     "inflected_query": "inflection == 'VBZ' and is_regular",
    # },
    # "VBZ_to_NNS": {
    #     "base_query": "inflection == 'VBZ' and is_regular",
    #     "inflected_query": "inflection == 'NNS' and is_regular",
    # },
    "regular_to_irregular": {
        "group_by": ["inflection"],
        "base_query": "is_regular",
        "inflected_query": "not is_regular",
        "all_query": all_query,
    },
    "irregular_to_regular": {
        "group_by": ["inflection"],
        "base_query": "not is_regular",
        "inflected_query": "is_regular",
        "all_query": all_query,
    },
    "nn_vb_ambiguous": {
        "group_by": ["inflection", "base_ambig_NN_VB"],
        "base_query": "is_regular",
        "inflected_query": "is_regular",
        "all_query": all_query,
    },
    "random_to_NNS": {
        "base_query": "inflection == 'random'",
        "inflected_query": "inflection == 'NNS'",
        "all_query": all_query,
    },
    "random_to_VBZ": {
        "base_query": "inflection == 'random'",
        "inflected_query": "inflection == 'VBZ'",
        "all_query": all_query,
    },
    # The only experiment here that selects the false-friend rows, which all
    # carry exclude_main=True and are therefore removed by `all_query` elsewhere.
    "false_friends": {
        "all_query": "inflection.str.contains('FF')",
        "group_by": ["inflection"],
        "equivalence_keys": ["base", "inflected", "post_divergence"],
    }
}

In [None]:
# Register one "unambiguous-..." experiment for every ordered pairing of
# (inflection, allomorph) drawn from NNS and VBZ, restricted to each
# inflection's three most frequent allomorphs and to regular,
# NN/VB-unambiguous bases.
transfer_allomorphs = (
    most_common_allomorphs
    .groupby("inflection")
    .most_common_allomorph
    .apply(lambda xs: xs.value_counts().head(3).index.tolist())
    .to_dict()
)
study_unambiguous_transfer = ["NNS", "VBZ"]
for infl1, infl2 in itertools.product(study_unambiguous_transfer, repeat=2):
    # product() iterates allomorph2 fastest, preserving the insertion order
    # of the nested-loop formulation
    for allomorph1, allomorph2 in itertools.product(transfer_allomorphs[infl1],
                                                    transfer_allomorphs[infl2]):
        experiments[f"unambiguous-{infl1}_{allomorph1}_to_{infl2}_{allomorph2}"] = {
            "base_query": f"inflection == '{infl1}' and is_regular and base_ambig_NN_VB == False and post_divergence == '{allomorph1}'",
            "inflected_query": f"inflection == '{infl2}' and is_regular and base_ambig_NN_VB == False and post_divergence == '{allomorph2}'",
            "all_query": all_query,
        }

In [None]:
# generate experiments testing transfer from
# 1. false friend allomorph to matching inflection allomorph
# 2. false friend allomorph to non-matching inflection allomorph
# 3. inflection allomorph to matching false friend allomorph
# 4. inflection allomorph to non-matching false friend allomorph
# (same top-3-allomorph mapping as computed for the "unambiguous" experiments)
transfer_allomorphs = most_common_allomorphs.groupby("inflection").most_common_allomorph.apply(lambda xs: xs.value_counts().head(3).index.tolist()).to_dict()
study_false_friends = ["NNS", "VBZ"]
for (inflection, post_divergence), _ in false_friends_df.groupby(["inflection", "post_divergence"]):
    if inflection not in study_false_friends:
        continue
    for transfer_allomorph in transfer_allomorphs[inflection]:
        # No "all_query" in these configs: the "not exclude_main" filter would
        # remove the false-friend rows themselves.
        experiments[f"{inflection}-FF-{post_divergence}-to-{inflection}_{transfer_allomorph}"] = {
            "base_query": f"inflection == '{inflection}-FF-{post_divergence}'",
            "inflected_query": f"inflection == '{inflection}' and is_regular and base_ambig_NN_VB == False and post_divergence == '{transfer_allomorph}'",
        }
        experiments[f"{inflection}_{transfer_allomorph}-to-{inflection}-FF-{post_divergence}"] = {
            "base_query": f"inflection == '{inflection}' and is_regular and base_ambig_NN_VB == False and post_divergence == '{transfer_allomorph}'",
            "inflected_query": f"inflection == '{inflection}-FF-{post_divergence}'",
        }

# False friend -> false friend transfer between distinct allomorphs of the same
# inflection. Only one direction per unordered pair is generated here.
for inflection in study_false_friends:
    for t1, t2 in itertools.combinations(transfer_allomorphs[inflection], 2):
        experiments[f"{inflection}-FF-{t1}-to-{inflection}-FF-{t2}"] = {
            "base_query": f"inflection == '{inflection}-FF-{t1}'",
            "inflected_query": f"inflection == '{inflection}-FF-{t2}'",
        }

In [None]:
# Run every configured experiment and stack the per-experiment result frames
# under an outer "experiment" index level.
# NOTE(review): the device ("cuda:2") and the seed (42) are hard-coded here; the
# config cell's `seed` is not used. Adjust the device for the machine at hand.
experiment_results = pd.concat({
    experiment: analogy.run_experiment_equiv_level(
        experiment, config,
        state_space_spec, all_cross_instances,
        agg, agg_src,
        num_samples=1000,
        seed=42,
        device="cuda:2")
    for experiment, config in tqdm(experiments.items(), unit="experiment")
}, names=["experiment"])
# A trial is correct when the predicted label equals the ground-truth label.
experiment_results["correct"] = experiment_results.predicted_label == experiment_results.gt_label
experiment_results

### Serialize

In [None]:
# Persist the main frames so the plotting cells below can be run without
# recomputing the experiments. The output file name carries a fixed date stamp.
torch.save({
    "inflection_results_df": inflection_results_df,
    "inflection_instance_df": inflection_instance_df,
    "all_cross_instances": all_cross_instances,
    "experiment_results": experiment_results,
}, f"{output_dir}/analogy_data_20250210.pt")

In [None]:
# Reload the serialized analysis data.
# NOTE(review): this unpickles non-tensor objects; newer torch releases may
# require `weights_only=False` here -- confirm against the installed version.
ser = torch.load(f"{output_dir}/analogy_data_20250210.pt")

In [None]:
# Rebind the working frames from the serialized copy, replacing whatever was in memory.
inflection_results_df = ser["inflection_results_df"]
inflection_instance_df = ser["inflection_instance_df"]
all_cross_instances = ser["all_cross_instances"]
experiment_results = ser["experiment_results"]

In [None]:
# Mean accuracy per (experiment, group), best first.
experiment_results.groupby(["experiment", "group"]).correct.mean().sort_values(ascending=False)

In [None]:
# Mean rank of the ground-truth label per (experiment, group), ascending.
experiment_results.groupby(["experiment", "group"]).gt_label_rank.mean().sort_values()

In [None]:
# Mean `gt_distance` per (experiment, group), ascending.
experiment_results.groupby(["experiment", "group"]).gt_distance.mean().sort_values()

### Experiment plots

#### Regularity

In [None]:
# Number of types per inflection, split by regularity.
inflection_results_df.groupby(["inflection", "is_regular"]).size()

In [None]:
# Inflections to plot: those having at least one regular AND one irregular type ...
plot_regular_inflections = inflection_results_df.groupby(["inflection", "is_regular"]).size().unstack().fillna(0)
plot_regular_inflections = plot_regular_inflections[plot_regular_inflections.min(1) > 0]
# ... plus "VBG" and "random", which are appended unconditionally.
# NOTE: the name is rebound from a DataFrame to a list of inflection labels here.
plot_regular_inflections = sorted(plot_regular_inflections.index) + ["VBG", "random"]

In [None]:
# One 2x2 accuracy heatmap per inflection: rows are the regularity of the
# "train" side, columns the regularity of the "test" side.
f, axes = plt.subplots(1, len(plot_regular_inflections), figsize=(3 * len(plot_regular_inflections), 2.5),
                       squeeze=True)
# squeeze=True yields a bare Axes for a single panel; normalise to an array.
axes = np.atleast_1d(axes)

# .copy() throughout so the column assignments below never write into
# (or warn about writing into) `experiment_results`.
regular_df = experiment_results.loc["regular"].copy()
# Expand the (inflection, is_regular) group tuples into `group0` / `group1`.
# Reuse regular_df's index so the column-wise concat aligns row for row.
group_cols = pd.DataFrame(regular_df.group.tolist(), index=regular_df.index).add_prefix("group")
regular_df = pd.concat([regular_df, group_cols], axis=1)
regular_transfer_df = experiment_results.loc["regular_to_irregular"].copy()
regular_transfer_df["group"] = regular_transfer_df.group.str[0]
irregular_transfer_df = experiment_results.loc["irregular_to_regular"].copy()
irregular_transfer_df["group"] = irregular_transfer_df.group.str[0]
regular_results = {}
for inflection in plot_regular_inflections:
    regular_results[inflection] = np.array([
        [regular_df.query("group0 == @inflection and group1 == True").correct.mean(),
         regular_transfer_df.query("group == @inflection").correct.mean()],
        [irregular_transfer_df.query("group == @inflection").correct.mean(),
         regular_df.query("group0 == @inflection and group1 == False").correct.mean()],
    ])

# Shared colour scale across panels. Empty selections produce NaN cells, so
# use NaN-aware reductions; plain min()/max() over NaN-containing arrays
# gives an order-dependent (possibly NaN) result.
all_accuracies = np.concatenate([v.ravel() for v in regular_results.values()])
vmin = np.nanmin(all_accuracies)
vmax = np.nanmax(all_accuracies)
for ax, inflection in zip(axes, plot_regular_inflections):
    sns.heatmap(regular_results[inflection], annot=True, fmt=".2f", ax=ax,
                vmin=vmin, vmax=vmax, cbar=True,
                xticklabels=["Regular", "Irregular"],
                yticklabels=["Regular", "Irregular"])
    ax.set_title(inflection)
    ax.set_xlabel("Test")
    ax.set_ylabel("Train")

f.tight_layout()

#### Root NN/VB ambiguity

In [None]:
# Number of types per inflection, split by the `base_ambig_NN_VB` flag.
inflection_results_df.groupby(["inflection", "base_ambig_NN_VB"]).size()

In [None]:
# Spot-check: five random types from each group (unseeded; raises if a group
# has fewer than five rows).
inflection_results_df.groupby(["inflection", "base_ambig_NN_VB"]).sample(5)

In [None]:
# The ten largest (inflection, ambiguity, allomorph) cells by instance count, shown in index order.
inflection_instance_df.groupby(["inflection", "base_ambig_NN_VB", "post_divergence"]).size().sort_values(ascending=False).head(10).sort_index()

In [None]:
# For unambiguous bases, count allomorphs using the first instance row of each
# (inflection, base), restricted to the six NNS / VBZ allomorphs of interest.
inflection_instance_df.query("base_ambig_NN_VB == False").groupby(["inflection", "base"]).head(1).groupby(["inflection", "post_divergence"]).size() \
    .loc[[("NNS", "S"), ("NNS", "Z"), ("NNS", "IH Z"), ("VBZ", "Z"), ("VBZ", "S"), ("VBZ", "IH Z")]]

In [None]:
# Collect the results of all "unambiguous-..." experiments into one frame,
# annotated with the source / target inflection and allomorph parsed from the
# experiment name.
agg_nnvb_results = []

nnvb_expts = experiment_results.index.get_level_values("experiment").unique()
nnvb_expts = nnvb_expts[nnvb_expts.str.contains("unambiguous-")]

for expt in nnvb_expts:
    # Allomorphs may contain a space (e.g. "IH Z"), hence [\w\s]+.
    inflection_from, allomorph_from, inflection_to, allomorph_to = \
        re.findall(r"unambiguous-(\w+)_([\w\s]+)_to_(\w+)_([\w\s]+)", expt)[0]
    expt_df = experiment_results.loc[expt].copy()

    # Skip experiments with fewer than 10 distinct bases on either side.
    num_seen_words = min(len(expt_df.base_from.unique()), len(expt_df.base_to.unique()))
    if num_seen_words < 10:
        print(f"Skipping {expt} due to only {num_seen_words} seen words")
        continue

    # These assignments replace any same-named columns already in the results.
    expt_df["inflection_from"] = inflection_from
    expt_df["allomorph_from"] = allomorph_from
    expt_df["inflection_to"] = inflection_to
    expt_df["allomorph_to"] = allomorph_to

    agg_nnvb_results.append(expt_df)

all_nnvb_results = pd.concat(agg_nnvb_results)

In [None]:
# Trial count and mean accuracy for every (source, target) inflection x allomorph cell.
# NOTE: the "count >= 0" filter keeps every row (counts are never negative);
# raise the threshold to actually drop sparse cells.
nnvb_results_summary = all_nnvb_results.groupby(["inflection_from", "inflection_to",
                                                 "allomorph_from", "allomorph_to"]) \
    .correct.agg(["count", "mean"]) \
    .query("count >= 0") \
    .reset_index()

# Human-readable labels used as heatmap / category axes.
nnvb_results_summary["source_label"] = nnvb_results_summary.inflection_from + " " + nnvb_results_summary.allomorph_from
nnvb_results_summary["target_label"] = nnvb_results_summary.inflection_to + " " + nnvb_results_summary.allomorph_to

nnvb_results_summary["transfer_label"] = nnvb_results_summary.inflection_from + " -> " + nnvb_results_summary.inflection_to
nnvb_results_summary["phon_label"] = nnvb_results_summary.allomorph_from + " " + nnvb_results_summary.allomorph_to

# only retain cases where we have data in both transfer directions from source <-> target within inflection
nnvb_results_summary["complement_exists"] = nnvb_results_summary.apply(lambda row: len(nnvb_results_summary.query("source_label == @row.target_label and target_label == @row.source_label")), axis=1)
nnvb_results_summary = nnvb_results_summary.query("complement_exists > 0").drop(columns=["complement_exists"])

nnvb_results_summary

In [None]:
# Mean accuracy as a source (rows) x target (columns) heatmap.
sns.heatmap(nnvb_results_summary.set_index(["source_label", "target_label"]).sort_index()["mean"].unstack("target_label"), annot=True)

In [None]:
# sns.catplot(data=nnvb_results_summary, x="transfer_label", y="mean", hue="phon_label", kind="swarm", aspect=2)
# Box plot of per-cell mean accuracy by transfer direction, ordered from the
# highest to the lowest average.
order = nnvb_results_summary.groupby("transfer_label")["mean"].mean().sort_values(ascending=False).index
sns.catplot(data=nnvb_results_summary, x="transfer_label", y="mean", kind="box", order=order)

In [None]:
# all_nnvb_results["transfer_label"] = all_nnvb_results.inflection_from + " -> " + all_nnvb_results.inflection_to
# plot_df = all_nnvb_results.groupby(["transfer_label", "base_to"]).correct.mean().reset_index()
# order = plot_df.groupby("transfer_label")["correct"].mean().sort_values(ascending=False).index
# sns.catplot(data=plot_df, x="transfer_label", y="correct", kind="box", aspect=2)

In [None]:
# Coarser summary: mean accuracy, ground-truth rank and ground-truth distance
# per (source inflection, target inflection), pooling over allomorphs.
nnvb_results_summary2 = all_nnvb_results.groupby(["inflection_from", "inflection_to"]) \
    [["correct", "gt_label_rank", "gt_distance"]].mean().reset_index()

nnvb_results_summary2["transfer_label"] = nnvb_results_summary2.inflection_from + " -> " + nnvb_results_summary2.inflection_to
nnvb_results_summary2

In [None]:
# Mean accuracy by source inflection (rows, "Train") and target inflection (columns, "Test").
ax = sns.heatmap(nnvb_results_summary2.set_index(["inflection_from", "inflection_to"]).correct.unstack("inflection_to"),
            annot=True)
ax.set_xlabel("Test")
ax.set_ylabel("Train")
ax.set_title("Accuracy")

In [None]:
# Mean rank of the ground-truth label by source (rows) and target (columns) inflection.
ax = sns.heatmap(nnvb_results_summary2.set_index(["inflection_from", "inflection_to"]).gt_label_rank.unstack("inflection_to"),
            annot=True)
ax.set_xlabel("Test")
ax.set_ylabel("Train")
ax.set_title("Mean rank of GT")

In [None]:
# `gt_distance` by source (rows) and target (columns) inflection.
# `nnvb_results_summary2` aggregates with .mean(), so the title says "Mean"
# (it previously read "Median", which did not match the plotted statistic).
ax = sns.heatmap(nnvb_results_summary2.set_index(["inflection_from", "inflection_to"]).gt_distance.unstack("inflection_to"),
            annot=True)
ax.set_xlabel("Test")
ax.set_ylabel("Train")
ax.set_title("Mean distance to GT")

##### False friends

In [None]:
# Names of all experiments involving false friends (those containing "FF").
# The match is case-sensitive, so the aggregate "false_friends" experiment is
# not selected here.
false_friend_expts = experiment_results.index.get_level_values("experiment").unique()
false_friend_expts = false_friend_expts[false_friend_expts.str.contains("FF")]
# false_friend_expts = false_friend_expts.tolist() + ["false_friends"]
sorted(false_friend_expts)

In [None]:
# Collect the results of every false-friend transfer experiment into one frame,
# annotated with the source / target allomorph parsed from the experiment name.
all_ff_results = []

for expt_name in false_friend_expts:
    expt_df = experiment_results.loc[expt_name].copy()
    # Skip experiments with fewer than 10 distinct bases on either side.
    num_seen_words = min(len(expt_df.base_from.unique()), len(expt_df.base_to.unique()))
    if num_seen_words < 10:
        # Report the experiment actually being skipped (this used to print the
        # stale `expt` variable left over from an earlier cell).
        print(f"Skipping {expt_name} due to only {num_seen_words} seen words")
        continue

    # Experiment names come in three shapes:
    #   <infl>-FF-<a1>-to-<infl>-FF-<a2>   false friend -> false friend
    #   <infl>_<a1>-to-<infl>-FF-<a2>      real inflection -> false friend
    #   <infl>-FF-<a1>-to-<infl>_<a2>      false friend -> real inflection
    if expt_name.count("-FF-") == 2:
        allomorph_from, allomorph_to = re.findall(r"-FF-([\w\s]+)-to-.+FF-([\w\s]+)", expt_name)[0]
        ff_from, ff_to = True, True
    else:
        to_ff_match = re.findall(r"_([\w\s]+)-to-.+FF-([\w\s]+)", expt_name)
        if to_ff_match:
            allomorph_from, allomorph_to = to_ff_match[0]
            # the false friend is on the "to" side
            ff_from, ff_to = False, True
        else:
            # the false friend is on the "from" side
            allomorph_from, allomorph_to = re.findall(r".+FF-([\w\s]+)-to-.+_([\w\s]+)", expt_name)[0]
            ff_from, ff_to = True, False

    expt_df["allomorph_from"] = allomorph_from
    expt_df["allomorph_to"] = allomorph_to

    # Collapse "<infl>-FF-<allomorph>" to "<infl>-FF"; the allomorph now lives
    # in its own column.
    if ff_from:
        expt_df["inflection_from"] = expt_df.inflection_from.str.replace("-FF-.+", "-FF", regex=True)
    if ff_to:
        expt_df["inflection_to"] = expt_df.inflection_to.str.replace("-FF-.+", "-FF", regex=True)

    all_ff_results.append(expt_df)

all_ff_results = pd.concat(all_ff_results)

In [None]:
# Add the aggregate "false_friends" experiment, splitting its
# "<infl>-FF-<allomorph>" labels into an inflection and an allomorph column.
expt_df = experiment_results.loc["false_friends"].copy()
expt_df["allomorph_from"] = expt_df.inflection_from.str.extract(r"FF-(.+)$")
expt_df["allomorph_to"] = expt_df.inflection_to.str.extract(r"FF-(.+)$")
expt_df["inflection_from"] = expt_df.inflection_from.str.replace("-FF-.+", "-FF", regex=True)
expt_df["inflection_to"] = expt_df.inflection_to.str.replace("-FF-.+", "-FF", regex=True)

# Keep only the source inflections already present in the collected transfer results.
expt_df = expt_df[expt_df.inflection_from.isin(all_ff_results.inflection_from.unique())]

# NOTE: not idempotent -- re-running this cell appends these rows again.
all_ff_results = pd.concat([all_ff_results, expt_df])

# Attach log word frequency for both bases. These are inner joins, so trials
# whose base is missing from the frequency table are dropped.
all_ff_results = pd.merge(all_ff_results, word_freq_df["LogFreq"].rename("from_freq"), left_on="base_from", right_index=True)
all_ff_results = pd.merge(all_ff_results, word_freq_df["LogFreq"].rename("to_freq"), left_on="base_to", right_index=True)

# Frequency quintiles (Q1..Q5), computed over trials rather than over unique words.
all_ff_results["from_freq_bin"] = pd.qcut(all_ff_results.from_freq, 5, labels=[f"Q{i}" for i in range(1, 6)])
all_ff_results["to_freq_bin"] = pd.qcut(all_ff_results.to_freq, 5, labels=[f"Q{i}" for i in range(1, 6)])

In [None]:
# Flag each trial by whether its false friend is "strong", looked up in
# `false_friends_df` by (inflection, base).
# First pass: trials whose false friend is on the "from" side.
# NOTE(review): pd.merge on columns returns a frame with a fresh integer index,
# while the .loc assignment aligns the right-hand Series on index labels --
# confirm the merged index really lines up with `all_ff_results`' index.
# NOTE(review): if `false_friends_df` holds several rows for one
# (inflection, base), the merge multiplies rows -- confirm uniqueness.
all_ff_results.loc[all_ff_results.inflection_from.str.contains("-FF"), "is_strong_ff"] = \
    pd.merge(all_ff_results.loc[all_ff_results.inflection_from.str.contains("-FF")].assign(inflection_from_base=lambda xs: xs.inflection_from.str.replace("-FF", "")),
         false_friends_df.reset_index()[["inflection", "base", "strong_expected", "strong"]]
            .rename(columns={
                "inflection": "inflection_from_base",
                "base": "base_from",
                "strong_expected": "strong_ff_allomorph",
                "strong": "is_strong_ff"}),
         on=["inflection_from_base", "base_from"]).is_strong_ff

# Second pass: trials whose false friend is on the "to" side. For false friend
# -> false friend trials this overwrites the value set by the first pass.
all_ff_results.loc[all_ff_results.inflection_to.str.contains("-FF"), "is_strong_ff"] = \
    pd.merge(all_ff_results.loc[all_ff_results.inflection_to.str.contains("-FF")]
                .assign(inflection_to_base=lambda xs: xs.inflection_to.str.replace("-FF", ""))
                .drop(columns=["is_strong_ff"]),
            false_friends_df.reset_index()[["inflection", "base", "strong_expected", "strong"]]
                .rename(columns={
                    "inflection": "inflection_to_base",
                    "base": "base_to",
                    "strong_expected": "strong_ff_allomorph",
                    "strong": "is_strong_ff"}),
            on=["inflection_to_base", "base_to"]).is_strong_ff

In [None]:
# ONLY STRONG
all_ff_results = all_ff_results[all_ff_results.is_strong_ff]

In [None]:
# Reference table: false-friend counts per inflection / allomorph / strong flag.
false_friends_df.groupby(["inflection", "post_divergence", "strong"]).size()

In [None]:
# def sample_k(rows, k):
#     return rows.sample(k, replace=True) if len(rows) > k else rows
# false_friends_df.query("inflection == 'NNS'").groupby("post_divergence").apply(sample_k, 5)

In [None]:
# all_ff_results.groupby(["inflection_from", "inflection_to"]).apply(lambda xs: xs[["base_from", "base_to", "allomorph_from", "allomorph_to"]].sample(3)).droplevel(-1).reset_index()

In [None]:
# Trial count and mean accuracy for every (source, target) inflection x allomorph
# cell of the false-friend experiments.
# NOTE: the "count >= 0" filter keeps every row (counts are never negative).
ff_results_summary = all_ff_results.groupby(["inflection_from", "inflection_to",
                                             "allomorph_from", "allomorph_to"]) \
        .correct.agg(["count", "mean"]) \
        .query("count >= 0") \
        .reset_index()

# Human-readable labels, mirroring those built for `nnvb_results_summary`.
ff_results_summary["source_label"] = ff_results_summary.inflection_from + " " + ff_results_summary.allomorph_from
ff_results_summary["target_label"] = ff_results_summary.inflection_to + " " + ff_results_summary.allomorph_to

ff_results_summary["transfer_label"] = ff_results_summary.inflection_from + " -> " + ff_results_summary.inflection_to
ff_results_summary["phon_label"] = ff_results_summary.allomorph_from + " " + ff_results_summary.allomorph_to

In [None]:
# For FF1->FF2 results, fill out the upper triangle
extra_rows = []
for _, row in ff_results_summary[ff_results_summary.inflection_from.str.contains("-FF") & ff_results_summary.inflection_to.str.contains("-FF")].iterrows():
    if row.allomorph_to == row.allomorph_from:
        continue
    extra_rows.append(row.copy())
    extra_rows[-1].inflection_from, extra_rows[-1].inflection_to = extra_rows[-1].inflection_to, extra_rows[-1].inflection_from
    extra_rows[-1].allomorph_from, extra_rows[-1].allomorph_to = extra_rows[-1].allomorph_to, extra_rows[-1].allomorph_from
    extra_rows[-1].source_label, extra_rows[-1].target_label = extra_rows[-1].target_label, extra_rows[-1].source_label
    extra_rows[-1].transfer_label = extra_rows[-1].transfer_label.replace(" -> ", " <- ")

ff_results_summary = pd.concat([ff_results_summary, pd.DataFrame(extra_rows)])

In [None]:
# Display the summary ordered by source label, then target label
# (the trailing comment keeps the alternative ordering by `mean`).
ff_results_summary.sort_values(["source_label", "target_label"])#sort_values("mean", ascending=False)

In [None]:
# Heatmap of `mean` with one row per source label and one column per target label,
# pooling the NN/VB summary and the false-friend summary.
f, ax = plt.subplots(figsize=(10, 8))
combined_summary = pd.concat([nnvb_results_summary, ff_results_summary])
mean_by_label_pair = (
    combined_summary
    .set_index(["source_label", "target_label"])
    .sort_index()["mean"]
    .unstack("target_label")
)
sns.heatmap(mean_by_label_pair, annot=True)

In [None]:
# `ff_allomorph`: the source allomorph when the source inflection ends in "-FF",
# otherwise the target allomorph.
source_ends_with_ff = ff_results_summary.inflection_from.str.endswith("-FF")
ff_results_summary["ff_allomorph"] = np.where(
    source_ends_with_ff,
    ff_results_summary.allomorph_from,
    ff_results_summary.allomorph_to,
)

# `inflection_base`: the source inflection with the "-FF" marker stripped.
ff_results_summary["inflection_base"] = ff_results_summary.inflection_from.str.replace("-FF", "")

# `ff_direction`: "from" when the source inflection contains "-FF", else "to".
source_contains_ff = ff_results_summary.inflection_from.str.contains("-FF")
ff_results_summary["ff_direction"] = source_contains_ff.map({True: "from", False: "to"})

In [None]:
# Point plot of `mean` per false-friend allomorph, one panel per `ff_direction`;
# allomorphs are ordered by their average `mean`, highest first.
mean_by_ff_allomorph = ff_results_summary.groupby("ff_allomorph")["mean"].mean()
order = mean_by_ff_allomorph.sort_values(ascending=False).index
g = sns.catplot(
    data=ff_results_summary,
    kind="point",
    x="ff_allomorph",
    y="mean",
    order=order,
    hue="inflection_base",
    col="ff_direction",
)

In [None]:
# Same kind of point plot for the NN/VB summary, restricted to rows whose source and
# target inflection are identical.
same_inflection_summary = nnvb_results_summary.query("inflection_from == inflection_to")
g2 = sns.catplot(
    data=same_inflection_summary,
    kind="point",
    x="allomorph_from",
    y="mean",
    hue="inflection_from",
)
# Optional: match the y-range of the false-friend plot `g`.
# g2.axes.flat[0].set_ylim(g.axes.flat[0].get_ylim())

In [None]:
# Per-transfer means of `correct`, `gt_label_rank` and `gt_distance`.
ff_results_summary2 = (
    all_ff_results
    .groupby(["inflection_from", "inflection_to"])[["correct", "gt_label_rank", "gt_distance"]]
    .mean()
    .reset_index()
)
ff_results_summary2["transfer_label"] = (
    ff_results_summary2.inflection_from + " -> " + ff_results_summary2.inflection_to
)

# add in data for NNS->NNS and VBZ->VBZ (rows of the NN/VB summary with identical
# source and target inflection)
same_inflection_summary2 = nnvb_results_summary2.query("inflection_from == inflection_to")
ff_results_summary2 = pd.concat([ff_results_summary2, same_inflection_summary2], axis=0)

In [None]:
# Heatmap of mean `correct`: rows are source inflections, columns are target inflections.
correct_by_transfer = (
    ff_results_summary2
    .set_index(["inflection_from", "inflection_to"])
    .correct
    .unstack("inflection_to")
)
sns.heatmap(correct_by_transfer, annot=True)

###### Frequency effects

In [None]:
# Mean `correct` per source-word frequency bin, one facet per transfer type.
all_ff_results["transfer_label"] = all_ff_results.inflection_from + " -> " + all_ff_results.inflection_to
g = sns.catplot(data=all_ff_results, x="from_freq_bin", y="correct", kind="point",
                col="transfer_label", col_wrap=2,
                height=3, aspect=1.5)

# Add twin axes and histograms.
# `g.axes_dict` maps each facet's `transfer_label` value to its Axes, so the facet is
# identified directly instead of by parsing the Axes title text.
for transfer_label, ax in g.axes_dict.items():
    twin_ax = ax.twinx()  # Create a twin y-axis

    # Get subset of data for this facet
    subset = all_ff_results[all_ff_results["transfer_label"] == transfer_label]

    # Number of distinct source base words in each frequency bin, drawn as bars on the twin axis
    words_per_bin = subset.groupby("from_freq_bin").base_from.nunique()
    sns.barplot(data=words_per_bin, ax=twin_ax, alpha=0.3, color="blue")

    twin_ax.set_ylabel("Count")  # Label twin axis
    twin_ax.grid(False)  # Remove extra gridlines for clarity
    for spine in twin_ax.spines.values():
        spine.set_visible(False)

g.figure.suptitle("Frequency effect of source word")
g.tight_layout()

In [None]:
# Mean `correct` per target-word frequency bin, one facet per transfer type.
all_ff_results["transfer_label"] = all_ff_results.inflection_from + " -> " + all_ff_results.inflection_to
g = sns.catplot(data=all_ff_results, x="to_freq_bin", y="correct", kind="point",
                col="transfer_label", col_wrap=2,
                height=3, aspect=1.5)

# Add twin axes and histograms.
# `g.axes_dict` maps each facet's `transfer_label` value to its Axes, so the facet is
# identified directly instead of by parsing the Axes title text.
for transfer_label, ax in g.axes_dict.items():
    twin_ax = ax.twinx()  # Create a twin y-axis

    # Get subset of data for this facet
    subset = all_ff_results[all_ff_results["transfer_label"] == transfer_label]

    # Number of distinct target base words in each frequency bin, drawn as bars on the twin axis
    words_per_bin = subset.groupby("to_freq_bin").base_to.nunique()
    sns.barplot(data=words_per_bin, ax=twin_ax, alpha=0.3, color="blue")

    twin_ax.set_ylabel("Count")  # Label twin axis
    twin_ax.grid(False)  # Remove extra gridlines for clarity
    for spine in twin_ax.spines.values():
        spine.set_visible(False)

g.figure.suptitle("Frequency effect of target word")
g.tight_layout()

###### Additional breakdowns

In [None]:
# grouping by "to" word: one mean of `correct` per (transfer, target base word),
# shown as a box per transfer label, boxes ordered by their average value.
per_target_word = all_ff_results.groupby(["inflection_from", "inflection_to", "base_to"]).correct.mean()
plot_df = per_target_word.reset_index()
plot_df["transfer_label"] = plot_df.inflection_from + " -> " + plot_df.inflection_to

mean_by_transfer = plot_df.groupby("transfer_label")["correct"].mean()
order = mean_by_transfer.sort_values(ascending=False).index
sns.catplot(data=plot_df, kind="box", x="transfer_label", y="correct", order=order)

In [None]:
# Rows whose source inflection is a false friend ("-FF") and whose target is not.
source_is_ff = all_ff_results.inflection_from.str.contains("-FF")
target_is_ff = all_ff_results.inflection_to.str.contains("-FF")

# Mean `correct` per (source allomorph, target allomorph, source base word), best first.
from_ff_scores = (
    all_ff_results[source_is_ff & ~target_is_ff]
    .groupby(["allomorph_from", "allomorph_to", "base_from"])
    .correct.mean()
    .sort_values(ascending=False)
    .reset_index()
)

# Attach each source word's `LogFreq`; `pd.merge` defaults to an inner join, so words
# missing from `word_freq_df` are dropped.
from_ff_scores = pd.merge(from_ff_scores, word_freq_df["LogFreq"],
                          left_on="base_from", right_index=True)
from_ff_scores

In [None]:
# Scatter plot with a fitted regression line: per-word mean `correct` against `LogFreq`
# for the rows of `from_ff_scores`.
sns.regplot(data=from_ff_scores, x="LogFreq", y="correct")

In [None]:
# Rows whose target inflection is a false friend ("-FF") and whose source is not.
source_is_ff = all_ff_results.inflection_from.str.contains("-FF")
target_is_ff = all_ff_results.inflection_to.str.contains("-FF")

# Mean `correct` per (source allomorph, target allomorph, target base word), best first.
to_ff_scores = (
    all_ff_results[~source_is_ff & target_is_ff]
    .groupby(["allomorph_from", "allomorph_to", "base_to"])
    .correct.mean()
    .sort_values(ascending=False)
    .reset_index()
)

# Attach each target word's `LogFreq`; `pd.merge` defaults to an inner join, so words
# missing from `word_freq_df` are dropped.
to_ff_scores = pd.merge(to_ff_scores, word_freq_df["LogFreq"],
                        left_on="base_to", right_index=True)
to_ff_scores.head(40)

In [None]:
# Scatter plot with a fitted regression line: per-word mean `correct` against `LogFreq`
# for the rows of `to_ff_scores`.
sns.regplot(data=to_ff_scores, x="LogFreq", y="correct")

In [None]:
# grouping over transfer allomorphs: distribution of the per-allomorph-pair `mean`
# for each transfer label, boxes ordered by their median (highest first).
median_by_transfer = ff_results_summary.groupby("transfer_label")["mean"].median()
order = median_by_transfer.sort_values(ascending=False).index
sns.catplot(data=ff_results_summary, kind="box", x="transfer_label", y="mean", order=order)

##### Frequency analysis

#### Frequency

In [None]:
# Attach the `LogFreq` of the source and target base word as `from_freq` / `to_freq`.
# First drop any `from_freq` / `to_freq` columns left by a previous run of this cell:
# merging again would otherwise make pandas suffix the clashing names
# (`from_freq_x` / `from_freq_y`) and break the cells that read `from_freq` / `to_freq`.
all_nnvb_results = all_nnvb_results.drop(columns=["from_freq", "to_freq"], errors="ignore")

# `pd.merge` defaults to an inner join, so rows whose base word is missing from
# `word_freq_df` are dropped.
all_nnvb_results = pd.merge(all_nnvb_results, word_freq_df.LogFreq.rename("from_freq"),
                              left_on="base_from", right_index=True)
all_nnvb_results = pd.merge(all_nnvb_results, word_freq_df.LogFreq.rename("to_freq"),
                              left_on="base_to", right_index=True)

In [None]:
# Discretize target and source log-frequency into five labelled bins.
# NOTE(review): `pd.cut` with an integer `bins` produces equal-width intervals, not
# quantiles, even though the labels read Q1..Q5 -- confirm this is the intended binning.
freq_bin_labels = [f"Q{i}" for i in range(1, 6)]
for side in ("to", "from"):
    all_nnvb_results[f"{side}_freq_bin"] = pd.cut(
        all_nnvb_results[f"{side}_freq"], bins=5, labels=freq_bin_labels)

In [None]:
# Mean `correct` per (transfer, source frequency bin, source base word); combinations
# with no data (NaN mean) are discarded before counting the remaining rows per transfer.
per_source_word = (
    all_nnvb_results
    .groupby(["inflection_from", "inflection_to", "from_freq_bin", "base_from"])
    .correct.mean()
    .dropna()
    .reset_index()
)
per_source_word.groupby(["inflection_from", "inflection_to"]).size()

In [None]:
# Mean of `correct` for each (inflection_from, inflection_to, from_freq_bin) group,
# displayed in index order.
all_nnvb_results.groupby(["inflection_from", "inflection_to", "from_freq_bin"]).correct.mean().sort_index()

In [None]:
# Mean `correct` per source-word frequency bin; rows are source inflections and
# columns are target inflections.
g = sns.catplot(data=all_nnvb_results, x="from_freq_bin", y="correct", kind="point",
                row="inflection_from", col="inflection_to", height=3, aspect=1.5)

# Add twin axes and histograms.
# With both `row` and `col` set, `g.axes_dict` is keyed by (row level, col level), so
# each facet is identified directly instead of by regex-parsing the Axes title.
for (inflection_from, inflection_to), ax in g.axes_dict.items():
    twin_ax = ax.twinx()  # Create a twin y-axis

    # Get subset of data for this facet
    subset = all_nnvb_results.query("inflection_from == @inflection_from and inflection_to == @inflection_to")

    # Number of distinct source base words in each frequency bin, drawn as bars on the twin axis
    words_per_bin = subset.groupby("from_freq_bin").base_from.nunique()
    sns.barplot(data=words_per_bin, ax=twin_ax, alpha=0.3, color="blue")

    twin_ax.set_ylabel("Count")  # Label twin axis
    twin_ax.grid(False)  # Remove extra gridlines for clarity
    for spine in twin_ax.spines.values():
        spine.set_visible(False)

g.figure.suptitle("Frequency effect of source word")
g.tight_layout()

In [None]:
# Mean `correct` per target-word frequency bin; rows are source inflections and
# columns are target inflections.
g = sns.catplot(data=all_nnvb_results, x="to_freq_bin", y="correct", kind="point",
                row="inflection_from", col="inflection_to", height=3, aspect=1.5)

# Add twin axes and histograms.
# With both `row` and `col` set, `g.axes_dict` is keyed by (row level, col level), so
# each facet is identified directly instead of by regex-parsing the Axes title.
for (inflection_from, inflection_to), ax in g.axes_dict.items():
    twin_ax = ax.twinx()  # Create a twin y-axis

    # Get subset of data for this facet
    subset = all_nnvb_results.query("inflection_from == @inflection_from and inflection_to == @inflection_to")

    # Number of distinct target base words in each frequency bin, drawn as bars on the twin axis
    words_per_bin = subset.groupby("to_freq_bin").base_to.nunique()
    sns.barplot(data=words_per_bin, ax=twin_ax, alpha=0.3, color="blue")

    twin_ax.set_ylabel("Count")  # Label twin axis
    twin_ax.grid(False)  # Remove extra gridlines for clarity
    for spine in twin_ax.spines.values():
        spine.set_visible(False)

g.figure.suptitle("Frequency effect of target word")
g.tight_layout()

### Representational analysis

In [None]:
# Results of the "nn_vb_ambiguous" experiment. The `.copy()` detaches the selection
# from `experiment_results`, so the column assignments below neither raise
# SettingWithCopyWarning nor risk writing into the parent frame.
rep_spec_df = experiment_results.loc["nn_vb_ambiguous"].copy()

# `group` is indexed positionally: element 0 becomes `inflection`, element 1 becomes
# `ambiguous` (presumably a boolean flag, given the `~` mask below -- TODO confirm).
rep_spec_df["inflection"] = rep_spec_df.group.str[0]
rep_spec_df["ambiguous"] = rep_spec_df.group.str[1]

# Keep only the rows not flagged as ambiguous.
rep_spec_df = rep_spec_df[~rep_spec_df.ambiguous]
rep_spec_df

In [None]:
def _resample_to_size(idxs, size, rng):
    """Return exactly `size` entries of `idxs`, sampling only when the length differs.

    Samples without replacement when `idxs` is longer than `size`, and with replacement
    when it is shorter; an array that already has `size` entries is returned unchanged.
    """
    if len(idxs) == size:
        return idxs
    return rng.choice(idxs, size=size, replace=len(idxs) < size)


def get_inflection_vectors(results_df, correct_only=True, max_num_vector_samples=250, rng=None):
    """Compute one mean "inflected minus base" difference vector per base word.

    Reads the notebook globals `agg`, `agg_src`, `state_space_spec` and
    `inflection_results_df`.

    Args:
        results_df: frame with `inflection` and `base_from` columns (and a boolean
            `correct` column when `correct_only` is set).
        correct_only: if True, only rows where `correct` is truthy are used.
        max_num_vector_samples: number of base instances and of inflected instances
            drawn per word; words with more instances are subsampled, words with fewer
            are resampled with replacement.
        rng: optional random source exposing `choice(a, size=..., replace=...)`, e.g.
            `np.random.default_rng(seed)`. Defaults to the global `np.random` state,
            which matches the previous behaviour.

    Returns:
        dict mapping inflection -> {base word -> mean over the sampled pairs of
        `agg[inflected instance] - agg[base instance]`}.

    Raises:
        ValueError: from `list.index` if a word is missing from
            `state_space_spec.labels`, or from the sampler if a word has no instances.
        IndexError: if no inflected form is found for a base word.
    """
    if rng is None:
        rng = np.random
    if correct_only:
        results_df = results_df[results_df.correct]

    # First column of `agg_src` is compared against label indices to find a word's instances.
    instance_label_idxs = agg_src[:, 0]

    ret = {}
    for inflection, group in results_df.groupby("inflection"):
        inflection_forms = inflection_results_df.loc[inflection]
        ret[inflection] = {}
        for base_word in group.base_from.unique():
            # HACK look up inflected form which we didn't save
            inflected_word = inflection_forms.query("base == @base_word").inflected.iloc[0]

            base_flat_idxs = np.where(
                instance_label_idxs == state_space_spec.labels.index(base_word))[0]
            inflected_flat_idxs = np.where(
                instance_label_idxs == state_space_spec.labels.index(inflected_word))[0]

            # Base indices are drawn before inflected ones, keeping the order of draws
            # from the random source the same as before.
            base_flat_idxs = _resample_to_size(base_flat_idxs, max_num_vector_samples, rng)
            inflected_flat_idxs = _resample_to_size(inflected_flat_idxs, max_num_vector_samples, rng)

            # Pair the i-th sampled inflected instance with the i-th sampled base
            # instance and average the differences.
            ret[inflection][base_word] = (agg[inflected_flat_idxs] - agg[base_flat_idxs]).mean(0)

    return ret

In [None]:
# Per-word inflection difference vectors for the rows of `rep_spec_df`, using the
# function's defaults (`correct_only=True`, `max_num_vector_samples=250`).
infl_vectors = get_inflection_vectors(rep_spec_df)

In [None]:
analyze_infl_types = sorted(infl_vectors.keys())

# Stack each inflection's per-word difference vectors once, not once per (i, j) pair.
infl_matrices = {
    inflection: np.stack(list(infl_vectors[inflection].values()))
    for inflection in analyze_infl_types
}

# infl_distances[i, j]: mean cosine distance between the vectors of inflections i and j.
# infl_distances_all: one (inflection_i, inflection_j, distance) record per vector pair.
infl_distances = np.zeros((len(analyze_infl_types), len(analyze_infl_types)))
infl_distances_all = []
for i, inflection_i in enumerate(analyze_infl_types):
    for j, inflection_j in enumerate(analyze_infl_types):
        # 1 - cosine similarity = cosine distance
        distances = 1 - fastdist.cosine_matrix_to_matrix(
            infl_matrices[inflection_i], infl_matrices[inflection_j])

        if i == j:
            # Within one inflection the matrix is every word against every word,
            # including itself. Keep the strict upper triangle so that self-pairs
            # (distance ~0) are excluded and each unordered pair is counted once;
            # otherwise the mean is biased towards 0 and the sample passed to the
            # t-test below is artificially doubled.
            pair_distances = distances[np.triu_indices_from(distances, k=1)]
        else:
            pair_distances = distances.flatten()

        # A single-word inflection has no within-inflection pairs.
        infl_distances[i, j] = pair_distances.mean() if len(pair_distances) > 0 else np.nan
        infl_distances_all.extend((inflection_i, inflection_j, d) for d in pair_distances)

In [None]:
# Turn the collected (inflection_i, inflection_j, distance) records into a frame and
# label every record with its "<i> to <j>" relationship.
infl_distances_all = pd.DataFrame(
    infl_distances_all, columns=["inflection_i", "inflection_j", "distance"]
).assign(relationship=lambda df: df.inflection_i + " to " + df.inflection_j)

# Bar plot of the distances, one bar per relationship.
g = sns.catplot(data=infl_distances_all, kind="bar", x="relationship", y="distance")
first_ax = g.axes.flat[0]
first_ax.set_ylabel("Cosine distance")

In [None]:
from scipy.stats import ttest_ind

In [None]:
# Two-sample t-test (scipy defaults) comparing the NNS-to-NNS distances against the
# VBZ-to-VBZ distances.
nns_to_nns_distances = infl_distances_all.query("relationship == 'NNS to NNS'").distance
vbz_to_vbz_distances = infl_distances_all.query("relationship == 'VBZ to VBZ'").distance
ttest_ind(nns_to_nns_distances, vbz_to_vbz_distances)

In [None]:
# Heatmap of the mean distance between every pair of inflection types.
infl_distance_table = pd.DataFrame(
    infl_distances,
    index=analyze_infl_types,
    columns=analyze_infl_types,
)
sns.heatmap(infl_distance_table, annot=True)

#### PCA visualization of base and inflected word representations

In [None]:
from sklearn.decomposition import PCA
# Fit a 2-component PCA on all of `agg`, then project `agg` onto those components.
pca = PCA(n_components=2)
pca.fit(agg)
agg_pca = pca.transform(agg)

In [None]:
# f, ax = plt.subplots(figsize=(12, 12))
# plot_words = rep_spec_df.groupby("inflection").sample(3)
# cmap = sns.color_palette("tab10", len(plot_words))

# max_plot_points = 20

# for i, (_, row) in enumerate(plot_words.iterrows()):
#     base_flat_idxs = np.where(agg_src[:, 0] == state_space_spec.labels.index(row.base_from))[0]
#     base_reps = agg_pca[base_flat_idxs]
#     if len(base_reps) > max_plot_points:
#         base_reps = base_reps[np.random.choice(len(base_reps), size=max_plot_points, replace=False)]
#     ax.scatter(*base_reps.T, color=cmap[i], label=row.base_from)

#     if row.inflection == "NNS":
#         # get target inflected form
#         nns_form = inflection_results_df.loc["NNS"].query("base == @row.base_from").inflected.iloc[0]
#         inflected_flat_idxs = np.where(agg_src[:, 0] == state_space_spec.labels.index(nns_form))[0]
#         inflected_reps = agg_pca[inflected_flat_idxs]
#         if len(inflected_reps) > max_plot_points:
#             inflected_reps = inflected_reps[np.random.choice(len(inflected_reps), size=max_plot_points, replace=False)]
        
#         ax.scatter(*inflected_reps.T, color=cmap[i], marker="x")
#     elif row.inflection == "VBZ":
#         # get VBZ inflected form
#         vbz_form = inflection_results_df.loc["VBZ"].query("base == @row.base_from").inflected.iloc[0]
#         inflected_flat_idxs = np.where(agg_src[:, 0] == state_space_spec.labels.index(vbz_form))[0]
#         inflected_reps = agg_pca[inflected_flat_idxs]
#         if len(inflected_reps) > max_plot_points:
#             inflected_reps = inflected_reps[np.random.choice(len(inflected_reps), size=max_plot_points, replace=False)]
#         ax.scatter(*inflected_reps.T, color=cmap[i], marker="x")

#         # get VBD inflected form
#         vbd_form = inflection_results_df.loc["VBD"].query("base == @row.base_from").inflected.iloc[0]
#         inflected_flat_idxs = np.where(agg_src[:, 0] == state_space_spec.labels.index(vbd_form))[0]
#         inflected_reps = agg_pca[inflected_flat_idxs]
#         if len(inflected_reps) > max_plot_points:
#             inflected_reps = inflected_reps[np.random.choice(len(inflected_reps), size=max_plot_points, replace=False)]
#         ax.scatter(*inflected_reps.T, color=cmap[i], marker="*")

#         # get VBG inflected form
#         vbg_form = inflection_results_df.loc["VBG"].query("base == @row.base_from").inflected.iloc[0]
#         inflected_flat_idxs = np.where(agg_src[:, 0] == state_space_spec.labels.index(vbg_form))[0]
#         inflected_reps = agg_pca[inflected_flat_idxs]
#         if len(inflected_reps) > max_plot_points:
#             inflected_reps = inflected_reps[np.random.choice(len(inflected_reps), size=max_plot_points, replace=False)]
#         ax.scatter(*inflected_reps.T, color=cmap[i], marker="^")

# plt.legend()

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.decomposition import PCA

# Collect the vectors to be visualized
def draw(xs):
    """Pick up to 100 rows with distinct `base_from` values from one inflection group.

    The sample is seeded with the notebook-level `seed`, so the same words are selected
    (and plotted) on every run.
    """
    xs = xs.drop_duplicates("base_from")
    return xs.sample(min(100, len(xs)), random_state=seed)
plot_words = rep_spec_df.query("correct").groupby("inflection").apply(draw)
# When True, each word form is drawn as the mean of its instances instead of as a
# cloud of individual instances.
plot_means = True

# Extract indices of relevant embeddings
def _instance_idxs(word):
    """Return the row positions in `agg_src` whose first column is `word`'s label index."""
    return np.where(agg_src[:, 0] == state_space_spec.labels.index(word))[0]


def _inflected_form(inflection, base_word):
    """Look up the first `inflected` entry for `base_word` under `inflection`."""
    return inflection_results_df.loc[inflection].query("base == @base_word").inflected.iloc[0]


# Forms collected for each kind of row, in addition to the base word itself.
# A failed lookup of a required form is an error; optional forms are skipped when the
# lookup fails.
required_forms = {"NNS": ["NNS"], "VBZ": ["VBZ"]}
optional_forms = {"VBZ": ["VBD", "VBG"]}

all_indices = []
for _, row in plot_words.iterrows():
    all_indices.extend(_instance_idxs(row.base_from).tolist())

    for inflection in required_forms.get(row.inflection, []):
        form = _inflected_form(inflection, row.base_from)
        all_indices.extend(_instance_idxs(form).tolist())

    for inflection in optional_forms.get(row.inflection, []):
        try:
            form_idxs = _instance_idxs(_inflected_form(inflection, row.base_from))
        except (KeyError, IndexError, ValueError):
            # KeyError: the inflection is absent from `inflection_results_df`;
            # IndexError: no inflected form recorded for this base word;
            # ValueError: the form is not one of `state_space_spec.labels`.
            # Anything else is unexpected and is allowed to propagate.
            continue
        all_indices.extend(form_idxs.tolist())

# De-duplicate (and sort) the collected instance indices.
all_indices = np.unique(all_indices)

# Standardize every embedding dimension (zero mean, unit standard deviation over all of
# `agg`) and keep only the selected instances; these standardized vectors are the PCA input.
feature_mean, feature_std = agg.mean(0), agg.std(0)
agg_centered = (agg - feature_mean) / feature_std
selected_vectors = agg_centered[all_indices]

# Fit a 2-component PCA on the selected vectors only, then project those same vectors.
# (Alternative: fit on every embedding with `pca.fit(agg_centered)`.)
pca = PCA(n_components=2)
pca_transformed = pca.fit(selected_vectors).transform(selected_vectors)

# Instance index -> 2-D PCA coordinates.
pca_dict = dict(zip(all_indices, pca_transformed))

# Set up plot: one vertically stacked panel per inflection present in `plot_words`.
plot_inflections = sorted(plot_words.inflection.unique())
n_plot_inflections = len(plot_inflections)
f, axs = plt.subplots(figsize=(12, 12 * n_plot_inflections), nrows=n_plot_inflections)
# `plt.subplots` returns a bare Axes (not an array) when there is a single panel;
# normalise to a 1-D array so `axs.flat[...]` works for any number of inflections.
axs = np.atleast_1d(axs)

# Two-colour scheme keyed by row position in `plot_words`: the first tab10 colour for
# NNS rows, the second for every other inflection.
# cmap = sns.color_palette("tab10", len(plot_words))
palette = sns.color_palette("tab10", 2)
cmap = {i: palette[0 if inflection == "NNS" else 1]
        for i, inflection in enumerate(plot_words.inflection)}

# Upper bound on the number of instances of a word form that get plotted.
max_plot_points = 20

# Plot points after PCA transformation
for i, (_, row) in enumerate(plot_words.iterrows()):
    ax = axs.flat[plot_inflections.index(row.inflection)]

    base_flat_idxs = np.where(agg_src[:, 0] == state_space_spec.labels.index(row.base_from))[0]
    base_reps = np.array([pca_dict[idx] for idx in base_flat_idxs if idx in pca_dict])
    if plot_means:
        base_reps = base_reps.mean(0, keepdims=True)
    if len(base_reps) > max_plot_points:
        base_reps = base_reps[np.random.choice(len(base_reps), size=max_plot_points, replace=False)]
    ax.scatter(*base_reps.T, color=cmap[i], label=row.base_from, alpha=0.5, s=100)

    if row.inflection == "NNS":
        nns_form = inflection_results_df.loc["NNS"].query("base == @row.base_from").inflected.iloc[0]
        inflected_flat_idxs = np.where(agg_src[:, 0] == state_space_spec.labels.index(nns_form))[0]
        inflected_reps = np.array([pca_dict[idx] for idx in inflected_flat_idxs if idx in pca_dict])
        if plot_means:
            inflected_reps = inflected_reps.mean(0, keepdims=True)
        if len(inflected_reps) > max_plot_points:
            inflected_reps = inflected_reps[np.random.choice(len(inflected_reps), size=max_plot_points, replace=False)]
        ax.scatter(*inflected_reps.T, color=cmap[i], marker="x", alpha=0.5)
        if plot_means: 
            # plot a link between base and here
            ax.plot([base_reps[0, 0], inflected_reps[0, 0]], [base_reps[0, 1], inflected_reps[0, 1]], color=cmap[i], alpha=0.5)

    elif row.inflection == "VBZ":
        vbz_form = inflection_results_df.loc["VBZ"].query("base == @row.base_from").inflected.iloc[0]
        inflected_flat_idxs = np.where(agg_src[:, 0] == state_space_spec.labels.index(vbz_form))[0]
        inflected_reps = np.array([pca_dict[idx] for idx in inflected_flat_idxs if idx in pca_dict])
        if plot_means:
            inflected_reps = inflected_reps.mean(0, keepdims=True)
        if len(inflected_reps) > max_plot_points:
            inflected_reps = inflected_reps[np.random.choice(len(inflected_reps), size=max_plot_points, replace=False)]
        ax.scatter(*inflected_reps.T, color=cmap[i], marker="x", alpha=0.5)
        if plot_means:
            # plot a link between base and here
            ax.plot([base_reps[0, 0], inflected_reps[0, 0]], [base_reps[0, 1], inflected_reps[0, 1]], color=cmap[i], alpha=0.5)

        try:
            vbd_form = inflection_results_df.loc["VBD"].query("base == @row.base_from").inflected.iloc[0]
            inflected_flat_idxs = np.where(agg_src[:, 0] == state_space_spec.labels.index(vbd_form))[0]
            inflected_reps = np.array([pca_dict[idx] for idx in inflected_flat_idxs if idx in pca_dict])
            if plot_means:
                inflected_reps = inflected_reps.mean(0, keepdims=True)
            if len(inflected_reps) > max_plot_points:
                inflected_reps = inflected_reps[np.random.choice(len(inflected_reps), size=max_plot_points, replace=False)]
            ax.scatter(*inflected_reps.T, color=cmap[i], marker="*", alpha=0.5)
            if plot_means:
                # plot a link between base and here
                ax.plot([base_reps[0, 0], inflected_reps[0, 0]], [base_reps[0, 1], inflected_reps[0, 1]], color=cmap[i], alpha=0.5)
        except: pass

        try:
            vbg_form = inflection_results_df.loc["VBG"].query("base == @row.base_from").inflected.iloc[0]
            inflected_flat_idxs = np.where(agg_src[:, 0] == state_space_spec.labels.index(vbg_form))[0]
            inflected_reps = np.array([pca_dict[idx] for idx in inflected_flat_idxs if idx in pca_dict])
            if plot_means:
                inflected_reps = inflected_reps.mean(0, keepdims=True)
            if len(inflected_reps) > max_plot_points:
                inflected_reps = inflected_reps[np.random.choice(len(inflected_reps), size=max_plot_points, replace=False)]
            ax.scatter(*inflected_reps.T, color=cmap[i], marker="^", alpha=0.5)
            if plot_means:
                # plot a link between base and here
                ax.plot([base_reps[0, 0], inflected_reps[0, 0]], [base_reps[0, 1], inflected_reps[0, 1]], color=cmap[i], alpha=0.5)
        except: pass