In [1]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# reloaded before each cell executes.
%load_ext autoreload
%autoreload 2

In [2]:
from collections import defaultdict, Counter
import functools
import itertools
from pathlib import Path
import pickle
import re

from fastdist import fastdist
import lemminflect
import matplotlib.pyplot as plt
from matplotlib import transforms
import nltk
import numpy as np
import pandas as pd
from scipy.spatial.distance import cdist
import seaborn as sns
import torch
from tqdm import tqdm

from src.analysis import analogy
from src.analysis.state_space import StateSpaceAnalysisSpec, \
    prepare_state_trajectory, aggregate_state_trajectory, flatten_trajectory
from src.datasets.speech_equivalence import SpeechHiddenStateDataset


In [7]:
# --- Model identifiers ---
base_model = "w2v2_8"
model_class = "discrim-rnn_32-mAP1"
model_name = "word_broad_10frames_fixedlen25"

# --- Dataset ---
train_dataset = "librispeech-train-clean-100"

# --- Input tables read further down in this notebook ---
inflection_results_path = "inflection_results.parquet"
all_cross_instances_path = "all_cross_instances.parquet"
most_common_allomorphs_path = "most_common_allomorphs.csv"
false_friends_path = "false_friends.csv"
pos_counts_path = "data/pos_counts.pkl"

# --- Model state inputs ---
# hidden_states_path = f"outputs/hidden_states/{base_model}/{train_dataset}.h5"
hidden_states_path = f"/scratch/jgauthier/{base_model}_{train_dataset}.h5"
state_space_specs_path = "state_space_spec.h5"
# The sentinel "ID" makes the loading cell use the hidden states directly;
# any other value is treated as a path to a saved numpy array.
embeddings_path = "ID"

# --- Outputs ---
output_dir = "."

# --- Analysis settings ---
seed = 42
metric = "cosine"
agg_fns = ["mean"]

## Prepare model representations

In [None]:
# Pick the representation source: a saved numpy array when a real path is
# configured, otherwise the hidden states stored in the HDF5 dataset.
if embeddings_path != "ID":
    with open(embeddings_path, "rb") as f:
        model_representations: np.ndarray = np.load(f)
else:
    hidden_state_dataset = SpeechHiddenStateDataset.from_hdf5(hidden_states_path)
    model_representations = hidden_state_dataset.states

# The spec must match the representations before any trajectory is built.
state_space_spec = StateSpaceAnalysisSpec.from_hdf5(state_space_specs_path)
assert state_space_spec.is_compatible_with(model_representations)

In [None]:
# Build state trajectories, aggregated with the "mean" spec along dimension 1.
trajectory_agg = prepare_state_trajectory(
    model_representations,
    state_space_spec,
    agg_fn_spec="mean",
    agg_fn_dimension=1,
)

In [22]:
# Flatten the aggregated trajectory; the second return value is passed along
# with the first to the experiment runner below.
flattened_trajectory = flatten_trajectory(trajectory_agg)
agg, agg_src = flattened_trajectory

## Prepare metadata

In [23]:
# Phoneme-level cuts re-indexed by (label, instance_idx, frame_idx), where
# frame_idx is a running counter of rows within each (label, instance_idx).
label_to_idx = {label: idx for idx, label in enumerate(state_space_spec.labels)}
phoneme_cuts = state_space_spec.cuts.xs("phoneme", level="level") \
    .drop(columns=["onset_frame_idx", "offset_frame_idx"])
cuts_df = phoneme_cuts.assign(
    label_idx=phoneme_cuts.index.get_level_values("label").map(label_to_idx),
    frame_idx=phoneme_cuts.groupby(["label", "instance_idx"]).cumcount(),
)
cuts_df = (
    cuts_df
    .reset_index()
    .set_index(["label", "instance_idx", "frame_idx"])
    .sort_index()
)

In [24]:
# One space-joined string of the `description` values per (label, instance_idx).
cuts_by_instance = cuts_df.groupby(["label", "instance_idx"])
cut_phonemic_forms = cuts_by_instance["description"].agg(" ".join)

In [25]:
word_freq_df = pd.read_csv("data/WorldLex_Eng_US.Freq.2.txt", sep="\t", index_col="Word")

# Compute a weighted average frequency across domains.
# Each domain's counts are first normalised to relative frequencies so that
# no single domain dominates because of its size.
freq_domains = ["BlogFreq", "TwitterFreq", "NewsFreq"]
for domain in freq_domains:
    word_freq_df[f"{domain}_rel"] = word_freq_df[domain] / word_freq_df[domain].sum()

# Average the *relative* frequencies (the original averaged the raw counts and
# left the `_rel` columns unused), then rescale by the mean domain total so the
# result is back on a count-like scale.
rel_columns = [f"{domain}_rel" for domain in freq_domains]
word_freq_df["Freq"] = word_freq_df[rel_columns].mean(axis=1) \
    * word_freq_df[freq_domains].sum().mean()
word_freq_df["LogFreq"] = np.log10(word_freq_df.Freq)

In [26]:
# Table passed to the experiment runner below.
all_cross_instances = pd.read_parquet(path=all_cross_instances_path)

In [27]:
# Inflection table used for the group-size checks in the plotting sections.
inflection_results_df = pd.read_parquet(path=inflection_results_path)

In [28]:
# Allomorph and false-friend tables used to generate the transfer experiments.
most_common_allomorphs, false_friends_df = (
    pd.read_csv(most_common_allomorphs_path),
    pd.read_csv(false_friends_path),
)

## Behavioral tests

In [29]:
# Query applied in (almost) every experiment to drop special edge cases whose
# logic doesn't make sense in most experiments.
all_query = "not exclude_main"

# Experiment name -> configuration. Insertion order is preserved and is the
# order in which experiments are run.
experiments = {}

experiments["basic"] = dict(
    group_by=["inflection"],
    all_query=all_query,
)
experiments["regular"] = dict(
    group_by=["inflection", "is_regular"],
    all_query=all_query,
)

# Disabled experiments, kept for reference:
# "NNS_to_VBZ": {
#     "base_query": "inflection == 'NNS' and is_regular",
#     "inflected_query": "inflection == 'VBZ' and is_regular",
# },
# "VBZ_to_NNS": {
#     "base_query": "inflection == 'VBZ' and is_regular",
#     "inflected_query": "inflection == 'NNS' and is_regular",
# },

# Transfer between regular and irregular forms, in both directions.
experiments["regular_to_irregular"] = dict(
    group_by=["inflection"],
    base_query="is_regular",
    inflected_query="not is_regular",
    all_query=all_query,
)
experiments["irregular_to_regular"] = dict(
    group_by=["inflection"],
    base_query="not is_regular",
    inflected_query="is_regular",
    all_query=all_query,
)

experiments["nn_vb_ambiguous"] = dict(
    group_by=["inflection", "base_ambig_NN_VB"],
    base_query="is_regular",
    inflected_query="is_regular",
    all_query=all_query,
)

# Transfer from the 'random' inflection to NNS / VBZ.
experiments["random_to_NNS"] = dict(
    base_query="inflection == 'random'",
    inflected_query="inflection == 'NNS'",
    all_query=all_query,
)
experiments["random_to_VBZ"] = dict(
    base_query="inflection == 'random'",
    inflected_query="inflection == 'VBZ'",
    all_query=all_query,
)

# The false-friend experiment uses its own `all_query` instead of the shared one.
experiments["false_friends"] = dict(
    all_query="inflection.str.contains('FF')",
    group_by=["inflection"],
    equivalence_keys=["base", "inflected", "post_divergence"],
)

In [30]:
# Add experiments that test transfer between each pair of the top allomorphs
# of NNS and VBZ (all four inflection pairings, including within-inflection).

# inflection -> the three most frequent values of `most_common_allomorph`
transfer_allomorphs = {
    infl: allomorphs.value_counts().head(3).index.tolist()
    for infl, allomorphs in most_common_allomorphs.groupby("inflection").most_common_allomorph
}
study_unambiguous_transfer = ["NNS", "VBZ"]

# Shared query shape: regular forms whose base is not NN/VB ambiguous and
# whose post-divergence material equals the given allomorph.
unambiguous_query = ("inflection == '{inflection}' and is_regular "
                     "and base_ambig_NN_VB == False and post_divergence == '{allomorph}'")

for infl1, infl2 in itertools.product(study_unambiguous_transfer, repeat=2):
    allomorph_pairs = itertools.product(transfer_allomorphs[infl1], transfer_allomorphs[infl2])
    for allomorph1, allomorph2 in allomorph_pairs:
        experiment_name = f"unambiguous-{infl1}_{allomorph1}_to_{infl2}_{allomorph2}"
        experiments[experiment_name] = dict(
            base_query=unambiguous_query.format(inflection=infl1, allomorph=allomorph1),
            inflected_query=unambiguous_query.format(inflection=infl2, allomorph=allomorph2),
            all_query=all_query,
        )

In [31]:
# Add experiments testing transfer between false friends (FF) and true
# inflections, in both directions:
# 1. false friend allomorph to matching inflection allomorph
# 2. false friend allomorph to non-matching inflection allomorph
# 3. inflection allomorph to matching false friend allomorph
# 4. inflection allomorph to non-matching false friend allomorph
transfer_allomorphs = {
    infl: allomorphs.value_counts().head(3).index.tolist()
    for infl, allomorphs in most_common_allomorphs.groupby("inflection").most_common_allomorph
}
study_false_friends = ["NNS", "VBZ"]

for (inflection, post_divergence), _ in false_friends_df.groupby(["inflection", "post_divergence"]):
    if inflection in study_false_friends:
        ff_name = f"{inflection}-FF-{post_divergence}"
        ff_query = f"inflection == '{inflection}-FF-{post_divergence}'"
        for transfer_allomorph in transfer_allomorphs[inflection]:
            regular_name = f"{inflection}_{transfer_allomorph}"
            regular_query = f"inflection == '{inflection}' and is_regular and base_ambig_NN_VB == False and post_divergence == '{transfer_allomorph}'"
            # false friend -> true inflection
            experiments[f"{ff_name}-to-{regular_name}"] = {
                "base_query": ff_query,
                "inflected_query": regular_query,
            }
            # true inflection -> false friend
            experiments[f"{regular_name}-to-{ff_name}"] = {
                "base_query": regular_query,
                "inflected_query": ff_query,
            }

# False friend -> false friend, for each unordered pair of top allomorphs.
for inflection in study_false_friends:
    ff_prefix = f"{inflection}-FF-"
    for t1, t2 in itertools.combinations(transfer_allomorphs[inflection], 2):
        experiments[f"{ff_prefix}{t1}-to-{ff_prefix}{t2}"] = {
            "base_query": f"inflection == '{ff_prefix}{t1}'",
            "inflected_query": f"inflection == '{ff_prefix}{t2}'",
        }

In [None]:
# Run every configured experiment and stack the per-experiment result frames
# under a new outermost index level named "experiment".
# Use the GPU when one is available and fall back to the CPU otherwise, so the
# notebook also runs on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

experiment_results = pd.concat({
    experiment: analogy.run_experiment_equiv_level(
        experiment, config,
        state_space_spec, all_cross_instances,
        agg, agg_src,
        num_samples=1000,
        seed=seed,
        device=device)
    for experiment, config in tqdm(experiments.items(), unit="experiment")
}, names=["experiment"])
# A trial is correct when the predicted label equals the ground-truth label.
experiment_results["correct"] = experiment_results.predicted_label == experiment_results.gt_label
experiment_results

### Save

In [35]:
# Persist the raw per-trial results.
experiment_results_path = f"{output_dir}/experiment_results.csv"
experiment_results.to_csv(experiment_results_path)

## Plots

### Regularity

In [None]:
# Number of rows per (inflection, is_regular) combination.
inflection_results_df.groupby(
    by=["inflection", "is_regular"],
).size()

In [37]:
# Inflections that have at least one regular AND one irregular row, in sorted
# order, followed by two inflections that are always plotted.
regularity_counts = (
    inflection_results_df
    .groupby(["inflection", "is_regular"])
    .size()
    .unstack()
    .fillna(0)
)
has_both_kinds = regularity_counts.min(1) > 0
plot_regular_inflections = sorted(regularity_counts[has_both_kinds].index) + ["VBG", "random"]

In [None]:
# One 2x2 train/test accuracy heatmap (regular vs. irregular) per inflection.
n_panels = len(plot_regular_inflections)
# squeeze=False always returns a 2-D array of axes, so a single panel still
# yields something iterable.
f, axes = plt.subplots(1, n_panels, figsize=(3 * n_panels, 2.5), squeeze=False)
axes = axes.ravel()

# Split the `group` tuples of the "regular" experiment into group0, group1, ...
# columns. Values are attached positionally so the result does not depend on
# what index `experiment_results` carries.
regular_df = experiment_results.loc["regular"]
group_values = pd.DataFrame(regular_df.group.tolist()).add_prefix("group")
regular_df = regular_df.assign(**{col: group_values[col].to_numpy() for col in group_values.columns})

# `.assign` returns new frames, so no slice of `experiment_results` is
# written to.
regular_transfer_df = experiment_results.loc["regular_to_irregular"] \
    .assign(group=lambda df: df.group.str[0])
irregular_transfer_df = experiment_results.loc["irregular_to_regular"] \
    .assign(group=lambda df: df.group.str[0])

# Rows = train condition, columns = test condition (regular, irregular).
regular_results = {}
for inflection in plot_regular_inflections:
    regular_results[inflection] = np.array([
        [regular_df.query("group0 == @inflection and group1 == True").correct.mean(),
         regular_transfer_df.query("group == @inflection").correct.mean()],
        [irregular_transfer_df.query("group == @inflection").correct.mean(),
         regular_df.query("group0 == @inflection and group1 == False").correct.mean()],
    ])

# Shared colour scale across panels. Cells without data are NaN, so use the
# NaN-aware reductions; plain min()/max() would give an order-dependent or NaN
# range.
stacked_results = np.stack(list(regular_results.values()))
vmin, vmax = np.nanmin(stacked_results), np.nanmax(stacked_results)

for ax, inflection in zip(axes, plot_regular_inflections):
    sns.heatmap(regular_results[inflection], annot=True, fmt=".2f", ax=ax,
                vmin=vmin, vmax=vmax, cbar=True,
                xticklabels=["Regular", "Irregular"],
                yticklabels=["Regular", "Irregular"])
    ax.set_title(inflection)
    ax.set_xlabel("Test")
    ax.set_ylabel("Train")

f.tight_layout()

### Root NN/VB ambiguity

In [None]:
# Number of rows per (inflection, base_ambig_NN_VB) combination.
inflection_results_df.groupby(
    by=["inflection", "base_ambig_NN_VB"],
).size()

In [None]:
# Show five example rows per (inflection, base_ambig_NN_VB) group.
# Seeded with the notebook-wide `seed` so the displayed rows are the same on
# every run.
inflection_results_df.groupby(["inflection", "base_ambig_NN_VB"]).sample(5, random_state=seed)

In [41]:
# Collect the results of all "unambiguous-*" transfer experiments into one
# frame, annotated with the source/target inflection and allomorph parsed from
# the experiment name.
nnvb_expts = experiment_results.index.get_level_values("experiment").unique()
nnvb_expts = nnvb_expts[nnvb_expts.str.contains("unambiguous-")]

# Mirrors the naming scheme used when the experiments were generated:
# unambiguous-{infl1}_{allomorph1}_to_{infl2}_{allomorph2}
nnvb_name_pattern = re.compile(r"unambiguous-(\w+)_([\w\s]+)_to_(\w+)_([\w\s]+)")

agg_nnvb_results = []
for expt in nnvb_expts:
    name_match = nnvb_name_pattern.search(expt)
    if name_match is None:
        # Fail with a readable message instead of an IndexError on `[0]`.
        raise ValueError(f"Cannot parse unambiguous-transfer experiment name: {expt!r}")
    inflection_from, allomorph_from, inflection_to, allomorph_to = name_match.groups()

    agg_nnvb_results.append(experiment_results.loc[expt].assign(
        inflection_from=inflection_from,
        allomorph_from=allomorph_from,
        inflection_to=inflection_to,
        allomorph_to=allomorph_to,
    ))

all_nnvb_results = pd.concat(agg_nnvb_results)

In [70]:
# Attach the log frequency of the source and target base words, then bin both.
# Drop any frequency columns left by a previous execution of this cell first.
# Otherwise a re-run merges a second copy, pandas renames the clashing columns
# with _x/_y suffixes, and the `.to_freq` / `.from_freq` lookups below fail.
all_nnvb_results = all_nnvb_results.drop(
    columns=["from_freq", "to_freq", "from_freq_bin", "to_freq_bin"], errors="ignore")

# Inner merges: trials whose base word has no entry in `word_freq_df` are dropped.
all_nnvb_results = pd.merge(all_nnvb_results, word_freq_df.LogFreq.rename("from_freq"),
                              left_on="base_from", right_index=True)
all_nnvb_results = pd.merge(all_nnvb_results, word_freq_df.LogFreq.rename("to_freq"),
                              left_on="base_to", right_index=True)

# NOTE(review): pd.cut makes five equal-width bins, whereas the false-friend
# section bins with pd.qcut (quantiles) under the same Q1..Q5 labels.
# Confirm the difference is intended.
freq_bin_labels = [f"Q{i}" for i in range(1, 6)]
all_nnvb_results["to_freq_bin"] = pd.cut(all_nnvb_results.to_freq, bins=5, labels=freq_bin_labels)
all_nnvb_results["from_freq_bin"] = pd.cut(all_nnvb_results.from_freq, bins=5, labels=freq_bin_labels)

In [None]:
# Trial count and accuracy per (source inflection/allomorph -> target
# inflection/allomorph) combination.
transfer_keys = ["inflection_from", "inflection_to", "allomorph_from", "allomorph_to"]
nnvb_results_summary = (
    all_nnvb_results
    .groupby(transfer_keys)["correct"]
    .agg(["count", "mean"])
    .query("count >= 0")
    .reset_index()
)

# Human-readable labels used by the plots below.
nnvb_results_summary = nnvb_results_summary.assign(
    source_label=nnvb_results_summary.inflection_from + " " + nnvb_results_summary.allomorph_from,
    target_label=nnvb_results_summary.inflection_to + " " + nnvb_results_summary.allomorph_to,
    transfer_label=nnvb_results_summary.inflection_from + " -> " + nnvb_results_summary.inflection_to,
    phon_label=nnvb_results_summary.allomorph_from + " " + nnvb_results_summary.allomorph_to,
)

# only retain cases where we have data in both transfer directions from source <-> target within inflection
observed_pairs = set(zip(nnvb_results_summary.source_label, nnvb_results_summary.target_label))
has_complement = np.array([
    (target, source) in observed_pairs
    for source, target in zip(nnvb_results_summary.source_label, nnvb_results_summary.target_label)
], dtype=bool)
nnvb_results_summary = nnvb_results_summary.loc[has_complement]

nnvb_results_summary

In [None]:
# Accuracy matrix: rows = source label, columns = target label.
nnvb_accuracy_matrix = (
    nnvb_results_summary
    .set_index(["source_label", "target_label"])
    .sort_index()["mean"]
    .unstack("target_label")
)
sns.heatmap(nnvb_accuracy_matrix, annot=True)

In [None]:
# Box plot of per-allomorph-pair accuracy for each inflection transfer,
# ordered by descending average accuracy.
# sns.catplot(data=nnvb_results_summary, x="transfer_label", y="mean", hue="phon_label", kind="swarm", aspect=2)
mean_accuracy_by_transfer = nnvb_results_summary.groupby("transfer_label")["mean"].mean()
order = mean_accuracy_by_transfer.sort_values(ascending=False).index
sns.catplot(
    data=nnvb_results_summary,
    x="transfer_label",
    y="mean",
    kind="box",
    order=order,
)

In [None]:
# Mean accuracy, ground-truth rank and ground-truth distance per
# (source inflection, target inflection) pair.
summary_metrics = ["correct", "gt_label_rank", "gt_distance"]
nnvb_results_summary2 = (
    all_nnvb_results
    .groupby(["inflection_from", "inflection_to"])[summary_metrics]
    .mean()
    .reset_index()
)
nnvb_results_summary2 = nnvb_results_summary2.assign(
    transfer_label=nnvb_results_summary2.inflection_from + " -> " + nnvb_results_summary2.inflection_to,
)
nnvb_results_summary2

In [None]:
# Mean accuracy: rows = source ("train") inflection, columns = target ("test").
accuracy_by_inflection = (
    nnvb_results_summary2
    .set_index(["inflection_from", "inflection_to"])
    .correct
    .unstack("inflection_to")
)
ax = sns.heatmap(accuracy_by_inflection, annot=True)
ax.set(xlabel="Test", ylabel="Train")
ax.set_title("Accuracy")

In [None]:
# Mean rank of the ground-truth label: rows = source ("train") inflection,
# columns = target ("test") inflection.
rank_by_inflection = (
    nnvb_results_summary2
    .set_index(["inflection_from", "inflection_to"])
    .gt_label_rank
    .unstack("inflection_to")
)
ax = sns.heatmap(rank_by_inflection, annot=True)
ax.set(xlabel="Test", ylabel="Train")
ax.set_title("Mean rank of GT")

In [None]:
# Distance to the ground truth: rows = source ("train") inflection,
# columns = target ("test") inflection.
ax = sns.heatmap(nnvb_results_summary2.set_index(["inflection_from", "inflection_to"]).gt_distance.unstack("inflection_to"),
            annot=True)
ax.set_xlabel("Test")
ax.set_ylabel("Train")
# `nnvb_results_summary2` is built with `.mean()`, so the plotted value is the
# mean distance; the title previously said "Median".
ax.set_title("Mean distance to GT")

### False friends

In [None]:
# Names of all experiments involving false friends ("FF" in the name).
all_experiment_names = experiment_results.index.get_level_values("experiment").unique()
false_friend_expts = all_experiment_names[all_experiment_names.str.contains("FF")]
# false_friend_expts = false_friend_expts.tolist() + ["false_friends"]
sorted(false_friend_expts)

In [52]:
# Collect the results of all false-friend transfer experiments, annotated
# with the source/target allomorph parsed from the experiment name.
# Experiment names follow one of three shapes:
#   {infl}-FF-{a}-to-{infl}-FF-{b}   false friend -> false friend
#   {infl}_{a}-to-{infl}-FF-{b}      true inflection -> false friend
#   {infl}-FF-{a}-to-{infl}_{b}      false friend -> true inflection
ff_result_frames = []

for expt_name in false_friend_expts:
    expt_df = experiment_results.loc[expt_name].copy()

    if expt_name.count("-FF-") == 2:
        name_match = re.search(r"-FF-([\w\s]+)-to-.+FF-([\w\s]+)", expt_name)
        ff_from, ff_to = True, True
    else:
        # Try "false friend on the target side" first, then the source side.
        # Explicit match checks replace the former bare `except:`, which would
        # also have hidden unrelated errors.
        name_match = re.search(r"_([\w\s]+)-to-.+FF-([\w\s]+)", expt_name)
        ff_from, ff_to = False, True
        if name_match is None:
            name_match = re.search(r".+FF-([\w\s]+)-to-.+_([\w\s]+)", expt_name)
            ff_from, ff_to = True, False

    if name_match is None:
        raise ValueError(f"Cannot parse false-friend experiment name: {expt_name!r}")
    allomorph_from, allomorph_to = name_match.groups()

    expt_df["allomorph_from"] = allomorph_from
    expt_df["allomorph_to"] = allomorph_to

    # Collapse "{infl}-FF-{allomorph}" to "{infl}-FF" on the false-friend
    # side(s); the allomorph is stored in its own column above.
    if ff_from:
        expt_df["inflection_from"] = expt_df.inflection_from.str.replace("-FF-.+", "-FF", regex=True)
    if ff_to:
        expt_df["inflection_to"] = expt_df.inflection_to.str.replace("-FF-.+", "-FF", regex=True)

    ff_result_frames.append(expt_df)

all_ff_results = pd.concat(ff_result_frames)

In [53]:
# Add the pooled "false_friends" experiment to the per-pair results, then
# attach word log-frequencies and quintile bins.
pooled_ff = experiment_results.loc["false_friends"].copy()
for side in ("from", "to"):
    # The allomorph is the suffix after "FF-" in the inflection label.
    pooled_ff[f"allomorph_{side}"] = pooled_ff[f"inflection_{side}"].str.extract(r"FF-(.+)$")
for side in ("from", "to"):
    # Collapse "{infl}-FF-{allomorph}" to "{infl}-FF".
    pooled_ff[f"inflection_{side}"] = pooled_ff[f"inflection_{side}"].str.replace("-FF-.+", "-FF", regex=True)

# Keep only source inflections that also occur in the per-pair experiments.
known_source_inflections = all_ff_results.inflection_from.unique()
pooled_ff = pooled_ff[pooled_ff.inflection_from.isin(known_source_inflections)]

all_ff_results = pd.concat([all_ff_results, pooled_ff])

# Inner merges: trials whose base word has no entry in `word_freq_df` are dropped.
log_freq = word_freq_df["LogFreq"]
all_ff_results = pd.merge(all_ff_results, log_freq.rename("from_freq"), left_on="base_from", right_index=True)
all_ff_results = pd.merge(all_ff_results, log_freq.rename("to_freq"), left_on="base_to", right_index=True)

quintile_labels = [f"Q{i}" for i in range(1, 6)]
all_ff_results["from_freq_bin"] = pd.qcut(all_ff_results.from_freq, 5, labels=quintile_labels)
all_ff_results["to_freq_bin"] = pd.qcut(all_ff_results.to_freq, 5, labels=quintile_labels)

In [54]:
# Flag each trial with whether its false-friend word is a "strong" false
# friend, using the `strong` column of `false_friends_df` keyed by
# (inflection, base).
#
# The previous implementation assigned the column of a `pd.merge` result
# through `.loc[mask, ...]`. A column-on-column merge returns a fresh 0..k-1
# index, so that assignment aligned by label and put flags on the wrong rows.
# The lookup below is purely positional instead.
strong_by_word = (
    false_friends_df
    .drop_duplicates(subset=["inflection", "base"])
    .set_index(["inflection", "base"])["strong"]
)

is_strong_ff = np.zeros(len(all_ff_results), dtype=bool)
# Source side first, then target side: as before, the target-side flag wins
# when both sides of a trial are false friends.
for side in ("from", "to"):
    inflection_col = all_ff_results[f"inflection_{side}"]
    side_is_ff = inflection_col.str.contains("-FF", regex=False, na=False).to_numpy(dtype=bool)
    lookup_keys = pd.MultiIndex.from_arrays([
        inflection_col[side_is_ff].str.replace("-FF", "", regex=False).to_numpy(),
        all_ff_results[f"base_{side}"].to_numpy()[side_is_ff],
    ])
    # Words missing from `false_friends_df` count as not strong; `.eq(True)`
    # also maps a missing `strong` value to False.
    is_strong_ff[side_is_ff] = strong_by_word.reindex(lookup_keys, fill_value=False).eq(True).to_numpy()

all_ff_results["is_strong_ff"] = is_strong_ff

In [55]:
# ONLY STRONG: keep just the trials flagged as strong false friends.
strong_ff_mask = all_ff_results["is_strong_ff"]
all_ff_results = all_ff_results[strong_ff_mask]

In [56]:
# Trial count and accuracy per (source inflection/allomorph -> target
# inflection/allomorph) combination of the false-friend experiments.
ff_transfer_keys = ["inflection_from", "inflection_to", "allomorph_from", "allomorph_to"]
ff_results_summary = (
    all_ff_results
    .groupby(ff_transfer_keys)["correct"]
    .agg(["count", "mean"])
    .query("count >= 0")
    .reset_index()
)

# Human-readable labels used by the plots below.
ff_results_summary = ff_results_summary.assign(
    source_label=ff_results_summary.inflection_from + " " + ff_results_summary.allomorph_from,
    target_label=ff_results_summary.inflection_to + " " + ff_results_summary.allomorph_to,
    transfer_label=ff_results_summary.inflection_from + " -> " + ff_results_summary.inflection_to,
    phon_label=ff_results_summary.allomorph_from + " " + ff_results_summary.allomorph_to,
)

In [57]:
# For FF1->FF2 results, fill out the upper triangle: append a mirrored copy of
# every FF -> FF row whose two allomorphs differ, with the from/to fields
# swapped and the arrow in the transfer label reversed.
both_sides_ff = ff_results_summary.inflection_from.str.contains("-FF") \
    & ff_results_summary.inflection_to.str.contains("-FF")
swapped_fields = [
    ("inflection_from", "inflection_to"),
    ("allomorph_from", "allomorph_to"),
    ("source_label", "target_label"),
]

extra_rows = []
for _, row in ff_results_summary[both_sides_ff].iterrows():
    if row.allomorph_to != row.allomorph_from:
        mirrored = row.copy()
        for left, right in swapped_fields:
            mirrored[left], mirrored[right] = row[right], row[left]
        mirrored["transfer_label"] = row["transfer_label"].replace(" -> ", " <- ")
        extra_rows.append(mirrored)

ff_results_summary = pd.concat([ff_results_summary, pd.DataFrame(extra_rows)])

In [None]:
# Summary rows from highest to lowest accuracy.
ff_results_summary.sort_values(by="mean", ascending=False)

In [None]:
# Accuracy matrix over the true-inflection and false-friend summaries
# together: rows = source label, columns = target label.
f, ax = plt.subplots(figsize=(10, 8))
combined_summary = pd.concat([nnvb_results_summary, ff_results_summary])
combined_accuracy_matrix = (
    combined_summary
    .set_index(["source_label", "target_label"])
    .sort_index()["mean"]
    .unstack("target_label")
)
sns.heatmap(combined_accuracy_matrix, annot=True, ax=ax)

In [60]:
# Per-row descriptors of the false-friend side of each transfer.
source_is_ff = ff_results_summary.inflection_from.str.endswith("-FF")
# Allomorph of the false-friend side (the source side when it is a false friend).
ff_results_summary["ff_allomorph"] = np.where(
    source_is_ff, ff_results_summary.allomorph_from, ff_results_summary.allomorph_to)
# Source inflection with the "-FF" marker stripped.
ff_results_summary["inflection_base"] = ff_results_summary.inflection_from.str.replace("-FF", "")
# "from" when the source side is a false friend, otherwise "to".
direction_names = {True: "from", False: "to"}
ff_results_summary["ff_direction"] = ff_results_summary.inflection_from.str.contains("-FF").map(direction_names)

In [None]:
# Accuracy per false-friend allomorph, ordered by descending average
# accuracy, split by base inflection (hue) and false-friend direction (column).
mean_accuracy_by_ff_allomorph = ff_results_summary.groupby("ff_allomorph")["mean"].mean()
order = mean_accuracy_by_ff_allomorph.sort_values(ascending=False).index
g = sns.catplot(
    data=ff_results_summary,
    x="ff_allomorph",
    y="mean",
    hue="inflection_base",
    col="ff_direction",
    order=order,
    kind="point",
)

In [None]:
# Accuracy of within-inflection transfer (source inflection == target
# inflection) per source allomorph.
within_inflection_summary = nnvb_results_summary.query("inflection_from == inflection_to")
g2 = sns.catplot(
    data=within_inflection_summary,
    x="allomorph_from",
    y="mean",
    hue="inflection_from",
    kind="point",
)
# g2.axes.flat[0].set_ylim(g.axes.flat[0].get_ylim())

In [63]:
# Mean accuracy, ground-truth rank and ground-truth distance per
# (source inflection, target inflection) pair of the false-friend results.
ff_summary_metrics = ["correct", "gt_label_rank", "gt_distance"]
ff_results_summary2 = (
    all_ff_results
    .groupby(["inflection_from", "inflection_to"])[ff_summary_metrics]
    .mean()
    .reset_index()
)
ff_results_summary2["transfer_label"] = ff_results_summary2.inflection_from + " -> " + ff_results_summary2.inflection_to

# add in data for NNS->NNS and VBZ->VBZ
within_inflection_rows = nnvb_results_summary2.query("inflection_from == inflection_to")
ff_results_summary2 = pd.concat([ff_results_summary2, within_inflection_rows], axis=0)

In [None]:
# Mean accuracy: rows = source inflection, columns = target inflection.
ff_accuracy_by_inflection = (
    ff_results_summary2
    .set_index(["inflection_from", "inflection_to"])
    .correct
    .unstack("inflection_to")
)
sns.heatmap(ff_accuracy_by_inflection, annot=True)

In [None]:
# Accuracy by frequency bin of the SOURCE word, one facet per transfer type,
# with the number of distinct source words per bin overlaid as bars.
all_ff_results["transfer_label"] = all_ff_results.inflection_from + " -> " + all_ff_results.inflection_to
g = sns.catplot(
    data=all_ff_results,
    x="from_freq_bin", y="correct",
    col="transfer_label", col_wrap=2,
    kind="point",
    height=3, aspect=1.5,
)

for facet_ax in g.axes.flat:
    # Recover the facet's transfer label from its title.
    facet_label = facet_ax.get_title().replace("transfer_label = ", "")
    facet_rows = all_ff_results[all_ff_results["transfer_label"] == facet_label]
    unique_sources_per_bin = facet_rows.groupby("from_freq_bin").apply(lambda rows: rows.base_from.nunique())

    # Bars go on a secondary y-axis so they do not distort the accuracy scale.
    count_ax = facet_ax.twinx()
    sns.barplot(data=unique_sources_per_bin,
                ax=count_ax, alpha=0.3, color="blue")

    count_ax.set_ylabel("Count")
    count_ax.grid(False)
    for spine in count_ax.spines.values():
        spine.set_visible(False)

g.figure.suptitle("Frequency effect of source word")
g.tight_layout()

In [None]:
# Accuracy by frequency bin of the TARGET word, one facet per transfer type,
# with the number of distinct target words per bin overlaid as bars.
all_ff_results["transfer_label"] = all_ff_results.inflection_from + " -> " + all_ff_results.inflection_to
g = sns.catplot(
    data=all_ff_results,
    x="to_freq_bin", y="correct",
    col="transfer_label", col_wrap=2,
    kind="point",
    height=3, aspect=1.5,
)

for facet_ax in g.axes.flat:
    # Recover the facet's transfer label from its title.
    facet_label = facet_ax.get_title().replace("transfer_label = ", "")
    facet_rows = all_ff_results[all_ff_results["transfer_label"] == facet_label]
    unique_targets_per_bin = facet_rows.groupby("to_freq_bin").apply(lambda rows: rows.base_to.nunique())

    # Bars go on a secondary y-axis so they do not distort the accuracy scale.
    count_ax = facet_ax.twinx()
    sns.barplot(data=unique_targets_per_bin,
                ax=count_ax, alpha=0.3, color="blue")

    count_ax.set_ylabel("Count")
    count_ax.grid(False)
    for spine in count_ax.spines.values():
        spine.set_visible(False)

g.figure.suptitle("Frequency effect of target word")
g.tight_layout()

### Frequency effects

In [None]:
# Accuracy by frequency bin of the SOURCE word, one facet per
# (source inflection, target inflection) pair, with the number of distinct
# source words per bin overlaid as bars.
g = sns.catplot(
    data=all_nnvb_results,
    x="from_freq_bin", y="correct",
    row="inflection_from", col="inflection_to",
    kind="point",
    height=3, aspect=1.5,
)

for facet_ax in g.axes.flat:
    # Recover the facet's inflection pair from its title.
    inflection_from, inflection_to = re.findall(
        r"inflection_from = (.+) \| inflection_to = (.+)", facet_ax.get_title())[0]
    facet_rows = all_nnvb_results.query("inflection_from == @inflection_from and inflection_to == @inflection_to")
    unique_sources_per_bin = facet_rows.groupby("from_freq_bin").apply(lambda rows: rows.base_from.nunique())

    # Bars go on a secondary y-axis so they do not distort the accuracy scale.
    count_ax = facet_ax.twinx()
    sns.barplot(data=unique_sources_per_bin,
                ax=count_ax, alpha=0.3, color="blue")

    count_ax.set_ylabel("Count")
    count_ax.grid(False)
    for spine in count_ax.spines.values():
        spine.set_visible(False)

g.figure.suptitle("Frequency effect of source word")
g.tight_layout()

In [None]:
# Accuracy by frequency bin of the TARGET word, one facet per
# (source inflection, target inflection) pair, with the number of distinct
# target words per bin overlaid as bars.
g = sns.catplot(
    data=all_nnvb_results,
    x="to_freq_bin", y="correct",
    row="inflection_from", col="inflection_to",
    kind="point",
    height=3, aspect=1.5,
)

for facet_ax in g.axes.flat:
    # Recover the facet's inflection pair from its title.
    inflection_from, inflection_to = re.findall(
        r"inflection_from = (.+) \| inflection_to = (.+)", facet_ax.get_title())[0]
    facet_rows = all_nnvb_results.query("inflection_from == @inflection_from and inflection_to == @inflection_to")
    unique_targets_per_bin = facet_rows.groupby("to_freq_bin").apply(lambda rows: rows.base_to.nunique())

    # Bars go on a secondary y-axis so they do not distort the accuracy scale.
    count_ax = facet_ax.twinx()
    sns.barplot(data=unique_targets_per_bin,
                ax=count_ax, alpha=0.3, color="blue")

    count_ax.set_ylabel("Count")
    count_ax.grid(False)
    for spine in count_ax.spines.values():
        spine.set_visible(False)

g.figure.suptitle("Frequency effect of target word")
g.tight_layout()

## Save all analyses

In [73]:
# Write the trial-level result tables, one CSV per table, named after it.
trial_level_tables = (
    ("all_ff_results", all_ff_results),
    ("all_nnvb_results", all_nnvb_results),
)
for table_name, table in trial_level_tables:
    table.to_csv(f"{output_dir}/{table_name}.csv")

In [74]:
# Write the summary tables, one CSV per table, named after it.
summary_tables = (
    ("ff_results_summary", ff_results_summary),
    ("nnvb_results_summary", nnvb_results_summary),
)
for table_name, table in summary_tables:
    table.to_csv(f"{output_dir}/{table_name}.csv")