In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import re

import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
import statsmodels.api as sm
from statsmodels.formula.api import logit, ols

from src.analysis.state_space import StateSpaceAnalysisSpec
from src.utils import concat_csv_with_indices

In [None]:
# Use seaborn's "paper" context with doubled font scaling for every figure in this notebook
sns.set_context("paper", font_scale=2)

In [None]:
# Every analogy input for this dataset/model combination sits in one directory;
# only the file name differs between artifacts.
_analogy_inputs_dir = "outputs/analogy/inputs/librispeech-test-clean/w2v2"

false_friends_path = f"{_analogy_inputs_dir}/false_friends.csv"
state_space_path = f"{_analogy_inputs_dir}/state_space_spec.h5"
most_common_allomorphs_path = f"{_analogy_inputs_dir}/most_common_allomorphs.csv"
cross_instances_path = f"{_analogy_inputs_dir}/all_cross_instances.parquet"

# Directory that every figure in this notebook is saved into
output_dir = "analogy_figures_heldout"

In [None]:
# Columns of the experiment-results dataframe that together identify a single run
run_groupers = ["base_model_name", "model_name", "equivalence"]

# Runs shown in the layer-wise plots: the "ffff_32" word model and the identity
# ("id") baseline, once per base-model index 0..11.
# Other model names that were plotted previously: "ff_32", "discrim-ff_32", e.g.
#   [(f"w2v2_{i}", "ff_32", "word_broad_10frames_fixedlen25") for i in range(12)]
#   [(f"w2v2_{i}", "discrim-ff_32", "word_broad_10frames_fixedlen25") for i in range(12)]
_word_equivalence = "word_broad_10frames_fixedlen25"
_base_model_indices = range(12)
plot_runs = [
    *[(f"w2v2_{i}", "ffff_32", _word_equivalence) for i in _base_model_indices],
    *[(f"w2v2_{i}", "id", "id") for i in _base_model_indices],
]

# The run highlighted in the focused plots, and the baseline it is compared against
# (alternative main run: ("w2v2_8", "discrim-ff_32", "word_broad_10frames_fixedlen25"))
main_plot_run = ("w2v2_8", "ffff_32", _word_equivalence)
foil_plot_run = ("w2v2_8", "id", "id")
main_plot_name = "Word"
foil_plot_name = "Wav2Vec"

# choose a vmin, vmax so that all heatmaps have the same color scale
main_plot_vmin, main_plot_vmax = 0.4, 0.9

plot_inflections = ["NNS", "VBZ"]
plot_metrics = ["correct", "gt_label_rank", "gt_distance"]

## Load metadata

In [None]:
word_freq_df = pd.read_csv("data/WorldLex_Eng_US.Freq.2.txt", sep="\t", index_col="Word")

# Relative frequency within each domain: that domain's counts normalised to sum to 1
_freq_domains = ["Blog", "Twitter", "News"]
for _domain in _freq_domains:
    _domain_counts = word_freq_df[f"{_domain}Freq"]
    word_freq_df[f"{_domain}Freq_rel"] = _domain_counts / _domain_counts.sum()

# Average the relative frequencies across domains, then scale back up by the
# mean of the three domain totals so the result is on a count-like scale
_rel_columns = [f"{_domain}Freq_rel" for _domain in _freq_domains]
_count_columns = [f"{_domain}Freq" for _domain in _freq_domains]
_mean_domain_total = word_freq_df[_count_columns].sum().mean()
word_freq_df["Freq"] = word_freq_df[_rel_columns].mean(axis=1) * _mean_domain_total
word_freq_df["LogFreq"] = np.log10(word_freq_df.Freq)

In [None]:
# Read the false-friends CSV from false_friends_path into a dataframe
false_friends_df = pd.read_csv(false_friends_path)

In [None]:
# Read the most-common-allomorph CSV; the VBD analysis later joins on its
# `base`, `inflection` and `most_common_allomorph` columns
most_common_allomorphs = pd.read_csv(most_common_allomorphs_path)

In [None]:
# Read the cross-instance table from its parquet file
all_cross_instances = pd.read_parquet(cross_instances_path)

## Load results

In [None]:
# Run identity is recovered from each result file's ancestor directories:
# parents[1] -> equivalence, parents[2] -> model_name, parents[3] -> base_model_name.
# NOTE(review): presumably concat_csv_with_indices adds these as index levels on top
# of each CSV's own row index (dropped again below) -- confirm against src.utils.
_run_index_extractors = [
    lambda p: p.parents[1].name,
    lambda p: p.parents[2].name,
    lambda p: p.parents[3].name,
]
_run_index_names = ["equivalence", "model_name", "base_model_name"]

all_results = concat_csv_with_indices(
    "outputs/analogy/runs/**/librispeech-test-clean/experiment_results.csv",
    _run_index_extractors,
    _run_index_names,
)
# Drop the innermost index level and turn the remaining levels into ordinary columns
all_results = all_results.droplevel(-1).reset_index()

In [None]:
# Identity-baseline runs are stored one directory per base model, so only the
# immediate parent directory name is needed to identify a run.
_id_results_raw = concat_csv_with_indices(
    "outputs/analogy/runs_id/librispeech-test-clean/**/experiment_results.csv",
    [lambda p: p.parents[0].name],
    ["base_model_name"],
)

# Give the baseline the same run-identifying columns as the trained models,
# using the placeholder value "id" for both model and equivalence.
all_id_results = (
    _id_results_raw
    .droplevel(-1)
    .reset_index()
    .assign(model_name="id", equivalence="id")
)

In [None]:
# Stack the identity-baseline results under the trained-model results, with a fresh row index
all_results = pd.concat([all_results, all_id_results], ignore_index=True)

In [None]:
def _parse_group(raw_group):
    """Turn the stringified `group` cell read from CSV back into a Python object.

    Missing cells (float NaN) become None.
    """
    if isinstance(raw_group, float) and np.isnan(raw_group):
        return None
    # NOTE(review): eval() executes whatever text is in the results CSV. That is only
    # acceptable while these files are produced by our own pipeline; consider
    # ast.literal_eval if the cells are always plain literals -- confirm first.
    return eval(raw_group)


all_results["group"] = all_results.group.apply(_parse_group)
# Keeps the previous row index as a new "index" column
all_results = all_results.reset_index()

### Exclusions

In [None]:
# Post-hoc exclusions: verbs whose base or inflected form is homophonous with an
# item of another category. Keys are the excluded bases; values are the homophone
# that motivates the exclusion.
_vbz_homophones = {
    "know": "nose",
    "seem": "seam",
    "please": "pleas",
    "write": "right",
    "meet": "meat",
    "read": "reed",
}

# (inflection, base) pairs, in the shape expected by the anti-join that follows
exclusions = [("VBZ", base) for base in _vbz_homophones]

In [None]:
def _drop_excluded(results, inflection_col, base_col):
    """Anti-join: keep only rows of `results` whose (inflection, base) pair is not in `exclusions`."""
    excluded_pairs = pd.DataFrame(exclusions, columns=[inflection_col, base_col])
    # The indicator column records, per row, whether a matching exclusion was found
    tagged = results.merge(excluded_pairs, on=[inflection_col, base_col],
                           how="left", indicator=True)
    return tagged[tagged["_merge"] == "left_only"].drop(columns="_merge")


# Remove excluded items first on the source side, then on the target side
all_results = _drop_excluded(all_results, "inflection_from", "base_from")
all_results = _drop_excluded(all_results, "inflection_to", "base_to")

### Layer-wise

In [None]:
# Layer-wise summary: mean of each metric per run, group and source inflection,
# restricted to the 'regular' experiment.
plot_lw = all_results.query("experiment == 'regular'").copy()
plot_lw = plot_lw.groupby(run_groupers + ["group", "inflection_from"])[["correct", "gt_label_rank", "gt_distance"]].mean()
# Keep only the runs listed in plot_runs, crossed with every observed group and
# inflection; combinations missing from the results come back as all-NaN rows.
plot_lw = plot_lw.reindex([(*plot_run, group, inflection_from)
                           for group in plot_lw.index.get_level_values("group").unique()
                           for inflection_from in plot_lw.index.get_level_values("inflection_from").unique()
                           for plot_run in plot_runs]).reset_index()
# `group` is indexed as a 2-element sequence below; split it into two columns
plot_lw["group0"] = plot_lw.group.apply(lambda x: x[0] if x is not None else None)
plot_lw["group1"] = plot_lw.group.apply(lambda x: x[1] if x is not None else None)
# 1-based layer number, parsed from the trailing integer of base_model_name
plot_lw["layer"] = plot_lw.base_model_name.str.extract(r"_(\d+)$").astype(int) + 1
# Display names for the legend; model names not in this mapping become NaN
plot_lw["model_name"] = plot_lw["model_name"].map({"id": "Wav2Vec", main_plot_run[1]: "Word"})

# Baseline: mean accuracy/rank over rows whose first group element is "random",
# per inflection and layer. Computed before the inflection filter below, so it
# covers every inflection in the results.
lw_random = plot_lw[plot_lw.group0 == "random"].groupby(["inflection_from", "layer"])[["correct", "gt_label_rank"]].mean().reset_index().dropna()

# Restrict the plotted data to the inflections of interest and to groups whose
# second element is True
plot_lw = plot_lw[plot_lw.inflection_from.isin(plot_inflections)]
plot_lw = plot_lw[(plot_lw.group1 == True)]

In [None]:
# Layer-wise accuracy, one facet per source inflection, with a per-inflection
# random-group baseline drawn as a dashed gray line.
g = sns.catplot(data=plot_lw, x="layer", y="correct", hue="model_name", row="inflection_from",
                kind="point", height=3, aspect=3)

for ax, row_name in zip(g.axes.flat, g.row_names):
    # Use only the baseline rows for this facet's inflection. Previously the whole
    # lw_random table was passed to every facet, so each panel showed the baseline
    # pooled over all inflections instead of its own.
    # NOTE(review): the point plot uses a categorical x axis while this line uses the
    # numeric layer value; check that the baseline lines up with the layer ticks.
    sns.lineplot(data=lw_random[lw_random.inflection_from == row_name],
                 x="layer", y="correct", ax=ax, color="gray", linestyle="--",
                 legend=None)
    # Facet titles look like "inflection_from = NNS"; keep only the value
    ax.set_title(ax.get_title().split("=")[1].strip())
    ax.set_ylabel("Correct")
    if ax.get_xlabel() == "layer":
        ax.set_xlabel("Layer")

# g.figure.tight_layout()
g.figure.savefig(f"{output_dir}/layer_wise.pdf")

In [None]:
# Layer-wise rank of the ground-truth label, one facet per plotted inflection,
# each with an inset that zooms in on layers 7-9.
hue_order = ["Word", "Wav2Vec"]
g = sns.catplot(data=plot_lw, x="layer", y="gt_label_rank",
                hue="model_name", hue_order=hue_order,
                row="inflection_from", row_order=plot_inflections,
                kind="point", sharey=False, height=3, aspect=1.25)

# facet_data() yields the (row, col, hue) indices of each facet together with
# the subset of the data that belongs to it
for (row, col, hue), data in g.facet_data():
    ax = g.axes[row, col]
    # Facet titles look like "inflection_from = NNS"; keep only the value
    ax.set_title(ax.get_title().split("=")[1].strip())
    ax.set_ylabel("Rank")
    if ax.get_xlabel() == "layer":
        ax.set_xlabel("Layer")

    # Add an inset showing the middle layers
    # (position and size are in axes-fraction coordinates: x0, y0, width, height)
    inset_ax = ax.inset_axes([0.45, 0.45, 0.4, 0.5])
    sns.pointplot(data=data[data.layer.between(7, 9)], x="layer", y="gt_label_rank",
                  hue="model_name", hue_order=hue_order,
                  ax=inset_ax, legend=None)
    inset_ax.set_xlabel(None)
    inset_ax.set_ylabel(None)
    # Gray frame and ticks so the inset reads as secondary to the main axes
    for spine in inset_ax.spines.values():
        spine.set_edgecolor("gray")
    inset_ax.tick_params(color="gray")

# Replace the default legend with one placed inside the figure
g.legend.remove()
g.add_legend(title="Model", bbox_to_anchor=(0.19, 0.99), loc="upper left")
g.figure.tight_layout()
g.figure.savefig(f"{output_dir}/layer_wise-rank.pdf")

In [None]:
# Layer-wise distance to the ground-truth item, one facet per source inflection
g = sns.catplot(data=plot_lw, x="layer", y="gt_distance", hue="model_name", row="inflection_from",
                kind="point", height=3, aspect=3)

for ax in g.axes.flat:
    # Facet titles look like "inflection_from = NNS"; keep only the value
    ax.set_title(ax.get_title().split("=")[1].strip())
    # This figure plots gt_distance, so label it as such (it was copy-pasted as "Correct")
    ax.set_ylabel("Distance")
    if ax.get_xlabel() == "layer":
        ax.set_xlabel("Layer")

# g.figure.tight_layout()
g.figure.savefig(f"{output_dir}/layer_wise-distance.pdf")

## Compute controlled NNVB results

In [None]:
# Collect, for every run, the rows of the "unambiguous-<infl>_<allomorph>_to_<infl>_<allomorph>"
# experiments and annotate them with the inflection/allomorph encoded in the experiment name.
all_nnvb_results = []

for run, run_results in all_results.groupby(run_groupers):
    run_results = run_results.set_index("experiment")
    nnvb_expts = run_results.index.unique()
    # na=False: a missing experiment name must count as "no match" rather than
    # produce a NaN entry in the boolean mask
    nnvb_expts = nnvb_expts[nnvb_expts.str.contains("unambiguous-", na=False)]

    for expt in nnvb_expts:
        inflection_from, allomorph_from, inflection_to, allomorph_to = \
            re.findall(r"unambiguous-(\w+)_([\w\s]+)_to_(\w+)_([\w\s]+)", expt)[0]
        expt_df = run_results.loc[[expt]].copy()

        # The experiment name is authoritative for these four columns
        expt_df["inflection_from"] = inflection_from
        expt_df["allomorph_from"] = allomorph_from
        expt_df["inflection_to"] = inflection_to
        expt_df["allomorph_to"] = allomorph_to

        all_nnvb_results.append(expt_df)

all_nnvb_results = pd.concat(all_nnvb_results)

# The equivalence labels are stringified sequences whose second element is used as the inflected form.
# NOTE(review): eval() executes whatever text is in the results files; acceptable only
# while they come from our own pipeline.
all_nnvb_results["inflected_from"] = all_nnvb_results.from_equiv_label.apply(lambda x: eval(x)[1])
all_nnvb_results["inflected_to"] = all_nnvb_results.to_equiv_label.apply(lambda x: eval(x)[1])

# Attach log frequencies for the base and inflected form on both sides. These are
# inner joins, so rows whose word is absent from word_freq_df are dropped.
for freq_column, word_column in [("from_base_freq", "base_from"),
                                 ("from_inflected_freq", "inflected_from"),
                                 ("to_base_freq", "base_to"),
                                 ("to_inflected_freq", "inflected_to")]:
    all_nnvb_results = pd.merge(all_nnvb_results, word_freq_df.LogFreq.rename(freq_column),
                                left_on=word_column, right_index=True)

all_nnvb_results["from_freq"] = all_nnvb_results[["from_base_freq", "from_inflected_freq"]].mean(axis=1)
all_nnvb_results["to_freq"] = all_nnvb_results[["to_base_freq", "to_inflected_freq"]].mean(axis=1)

# Quintile edges are computed over source and target frequencies pooled together so
# that both sides share the same bins.
_, frequency_bins = pd.qcut(pd.concat([all_nnvb_results.to_freq, all_nnvb_results.from_freq]), q=5, retbins=True)
frequency_bin_labels = [f"Q{i}" for i in range(1, 6)]
# include_lowest=True: the first edge is the pooled minimum, and pd.cut's intervals are
# left-open by default, so without it the lowest-frequency rows would get a NaN bin.
all_nnvb_results["to_freq_bin"] = pd.cut(all_nnvb_results.to_freq, bins=frequency_bins,
                                         labels=frequency_bin_labels, include_lowest=True)
all_nnvb_results["from_freq_bin"] = pd.cut(all_nnvb_results.from_freq, bins=frequency_bins,
                                           labels=frequency_bin_labels, include_lowest=True)

In [None]:
def summarize_nnvb_run(rows):
    """Label one run's summary rows and keep only pairings observed in both directions.

    Parameters
    ----------
    rows : pandas.DataFrame
        Rows for a single run with string columns `inflection_from`,
        `allomorph_from`, `inflection_to` and `allomorph_to`.

    Returns
    -------
    pandas.DataFrame
        A new frame (the input is left untouched) with added `source_label`,
        `target_label`, `transfer_label` and `phon_label` columns, restricted to
        rows whose reversed (target -> source) pairing also occurs in `rows`.
    """
    # assign() returns a new frame, so the group handed in by groupby.apply is not mutated
    rows = rows.assign(
        source_label=rows.inflection_from + " " + rows.allomorph_from,
        target_label=rows.inflection_to + " " + rows.allomorph_to,
        transfer_label=rows.inflection_from + " -> " + rows.inflection_to,
        phon_label=rows.allomorph_from + " " + rows.allomorph_to,
    )

    # only retain cases where we have data in both transfer directions from source <-> target within inflection
    # (set lookup instead of one DataFrame.query per row; also well-defined on an empty frame)
    observed_pairs = set(zip(rows.source_label, rows.target_label))
    has_complement = np.asarray(
        [(target, source) in observed_pairs
         for source, target in zip(rows.source_label, rows.target_label)],
        dtype=bool,
    )
    return rows[has_complement]

summary_groupers = ["inflection_from", "inflection_to", "allomorph_from", "allomorph_to"]

# Item count and mean accuracy per run and (inflection, allomorph) pairing. The
# pairing keys are moved back into columns so only the run keys remain in the index.
_nnvb_accuracy = (
    all_nnvb_results
    .groupby(run_groupers + summary_groupers)
    .correct.agg(["count", "mean"])
    .query("count >= 0")
    .reset_index(summary_groupers)
)

# Label and filter the pairings separately within each run
nnvb_results_summary = (
    _nnvb_accuracy
    .groupby(run_groupers, group_keys=False)
    .apply(summarize_nnvb_run)
    .reset_index()
)

nnvb_results_summary

In [None]:
# plot_results = []
# for base_model_name, model_name, equivalence in plot_runs:
#     results_i = nnvb_results_summary.query("base_model_name == @base_model_name and model_name == @model_name and equivalence == @equivalence")
#     if len(results_i) > 0:
#         plot_results.append(results_i)
# num_plot_runs = len(plot_results)

# ncols = 4
# nrows = int(np.ceil(num_plot_runs / ncols))
# f, axs = plt.subplots(nrows, ncols, figsize=(ncols * 4, nrows * 4))

# for ax, results_i in zip(axs.flat, plot_results):
#     sns.heatmap(results_i.set_index(["source_label", "target_label"])["mean"].unstack(),
#                 vmin=0, vmax=1, ax=ax)
#     key_row = results_i.iloc[0]
#     ax.set_title(f"{key_row.base_model_name} -> {key_row.model_name} ({key_row.equivalence})")

### Focused plots

In [None]:
# Select the highlighted run and its baseline from the configuration cell rather
# than repeating the run tuple and display names here.
focus_base_model, focus_model, focus_equivalence = main_plot_run
foil_base_model, foil_model, foil_equivalence = foil_plot_run

# assign() returns new frames, so no column is written into the result of query()
# (which triggers SettingWithCopyWarning).
nnvb_focus = all_nnvb_results \
    .query("base_model_name == @focus_base_model and model_name == @focus_model and equivalence == @focus_equivalence") \
    .assign(model_label=main_plot_name)
nnvb_foil = all_nnvb_results \
    .query("base_model_name == @foil_base_model and model_name == @foil_model and equivalence == @foil_equivalence") \
    .assign(model_label=foil_plot_name)

nnvb_focus = pd.concat([nnvb_focus, nnvb_foil])

# Display symbols for the allomorph codes; codes missing from this mapping become NaN
allomorph_labels = {"Z": "z", "S": "s", "IH Z": "ɪz"}
nnvb_focus["allomorph_from"] = nnvb_focus.allomorph_from.map(allomorph_labels)
nnvb_focus["allomorph_to"] = nnvb_focus.allomorph_to.map(allomorph_labels)
nnvb_focus

In [None]:
# Mean of every plotted metric per model and (inflection, allomorph) train/test pairing
_allomorph_groupers = ["model_label", "inflection_from", "inflection_to",
                       "allomorph_from", "allomorph_to"]
nnvb_results_summary = (
    nnvb_focus.groupby(_allomorph_groupers)[plot_metrics]
    .mean()
    .reset_index()
    .astype({"correct": float})
)

# Two-line tick labels for the heatmaps: inflection on top, allomorph underneath
nnvb_results_summary["source_label"] = (
    nnvb_results_summary.inflection_from + "\n" + nnvb_results_summary.allomorph_from
)
nnvb_results_summary["target_label"] = (
    nnvb_results_summary.inflection_to + "\n" + nnvb_results_summary.allomorph_to
)

nnvb_results_summary["transfer_label"] = (
    nnvb_results_summary.inflection_from + " -> " + nnvb_results_summary.inflection_to
)
nnvb_results_summary["phon_label"] = (
    nnvb_results_summary.allomorph_from + " " + nnvb_results_summary.allomorph_to
)

# only retain cases where we have data in both transfer directions from source <-> target within inflection
# (the reverse pairing is searched for across all rows, i.e. across both models)
nnvb_results_summary["complement_exists"] = nnvb_results_summary.apply(
    lambda row: len(nnvb_results_summary.query(
        "source_label == @row.target_label and target_label == @row.source_label")),
    axis=1,
)
nnvb_results_summary = (
    nnvb_results_summary
    .query("complement_exists > 0")
    .drop(columns=["complement_exists"])
)

# drop VBZ IH Z, which only has 4 word types
_rare_label = "VBZ\nɪz"
_involves_rare = (
    (nnvb_results_summary.source_label == _rare_label)
    | (nnvb_results_summary.target_label == _rare_label)
)
nnvb_results_summary = nnvb_results_summary[~_involves_rare]

nnvb_results_summary

In [None]:
# Long format: one row per (item, metric), then label each row with its source
# inflection + allomorph for the hue
_nnvb_long = nnvb_focus.melt(
    id_vars=["inflection_from", "allomorph_from", "inflection_to", "model_label"],
    value_vars=plot_metrics,
)
_nnvb_long = _nnvb_long.assign(
    source_label=_nnvb_long.inflection_from + " " + _nnvb_long.allomorph_from
)

# One row of panels per metric, one column per model; the y axis is shared within a row
g = sns.catplot(
    data=_nnvb_long,
    x="inflection_to", hue="source_label", y="value",
    col="model_label", row="variable",
    kind="bar", sharey="row",
)

In [None]:
# One figure per metric: a train x test heatmap for each model, side by side,
# sharing a single colour scale and colour bar.
for metric in plot_metrics:
    # Accuracy is anchored at 0; other metrics start at their observed minimum
    vmin = 0 if metric == "correct" else nnvb_results_summary[metric].min()
    vmax = nnvb_results_summary[metric].max()

    # Two heatmap panels plus a narrow third axis reserved for the colour bar
    f, axs = plt.subplots(1, 3, figsize=(7 * 2, 6), gridspec_kw={'width_ratios': [1, 1, 0.04]})
    # f.suptitle(f"{metric}", fontsize=20)
    # zip stops after the model groups are exhausted, so the colour-bar axis is
    # never used as a heatmap panel
    for i, (ax, (model_label, rows)) in enumerate(zip(axs, nnvb_results_summary.groupby("model_label"))):
        # Only the second panel draws the colour bar, into the reserved axis
        cbar_ax = None
        if i == 1:
            cbar_ax = axs.flat[-1]

        ax.set_title(model_label)
        # Rows are source (train) labels, columns are target (test) labels
        sns.heatmap(rows.set_index(["source_label", "target_label"]).sort_index()[metric].unstack("target_label"),
                    vmin=vmin, vmax=vmax,
                    annot=True, fmt=".2g", ax=ax,
                    cbar=i == 1, cbar_ax=cbar_ax)

        ax.set_yticklabels(ax.get_yticklabels(), rotation=0)
        ax.set_ylabel("Train")
        ax.set_xlabel("Test")

    f.tight_layout()
    f.savefig(f"{output_dir}/nnvb_allomorphs-{metric}.pdf", bbox_inches="tight")

In [None]:
# Same metrics as the allomorph-level summary, but averaged per model and
# train/test inflection only (allomorphs pooled)
_transfer_groupers = ["model_label", "inflection_from", "inflection_to"]
nnvb_results_summary2 = nnvb_focus.groupby(_transfer_groupers)[plot_metrics].mean().reset_index()
nnvb_results_summary2 = nnvb_results_summary2.astype({"correct": float})

nnvb_results_summary2["transfer_label"] = (
    nnvb_results_summary2.inflection_from + " -> " + nnvb_results_summary2.inflection_to
)
nnvb_results_summary2

In [None]:
from matplotlib.gridspec import GridSpec


# One figure per metric: a train x test inflection heatmap for each model, side by
# side, sharing a single colour scale and colour bar.
for metric in plot_metrics:
    # Both colour limits come from the table being plotted. The lower limit used to be
    # read from the allomorph-level summary, a different (and filtered) table, so values
    # here could fall outside the colour range.
    vmin = 0 if metric == "correct" else nnvb_results_summary2[metric].min()
    vmax = nnvb_results_summary2[metric].max()

    f = plt.figure(figsize=(3 * 2 + 2, 3))
    # Two heatmap panels plus a narrow third column reserved for the colour bar
    gs = GridSpec(1, 3, width_ratios=[1, 1, 0.08],
                  wspace=0.55)

    axs = [f.add_subplot(gs[0]), f.add_subplot(gs[1])]
    cbar_ax = f.add_subplot(gs[2])

    for i, (ax, (model_label, rows)) in enumerate(zip(axs, nnvb_results_summary2.groupby("model_label"))):
        # Only the second panel draws the colour bar, into the reserved axis
        draw_cbar = i == 1

        ax.set_title(model_label)
        sns.heatmap(rows.set_index(["inflection_from", "inflection_to"])[metric].unstack(),
                    annot=True, vmin=vmin, vmax=vmax, ax=ax,
                    cbar=draw_cbar, cbar_ax=cbar_ax if draw_cbar else None)
        ax.set_xlabel("Test")
        ax.set_ylabel("Train")

    f.tight_layout()
    f.savefig(f"{output_dir}/nnvb_results-{metric}.pdf", bbox_inches="tight")

In [None]:
# Per-target-word means of every metric, melted to long format so that each
# metric gets its own row of panels.
nnvb_plot = nnvb_focus.groupby(["model_label", "inflection_from", "inflection_to", "base_to"])[plot_metrics].mean().reset_index() \
    .melt(id_vars=["model_label", "inflection_from", "inflection_to", "base_to"],
          value_vars=plot_metrics, var_name="metric", value_name="value")
nnvb_plot["transfer_label"] = nnvb_plot.inflection_from + " -> " + nnvb_plot.inflection_to
g = sns.catplot(data=nnvb_plot, x="transfer_label", y="value", kind="bar", hue="model_label", row="metric", sharey="row",
                errorbar="se", height=3, aspect=2.5)
g.legend.set_title("Model")

# Label every row with the metric it actually shows. Previously only the first row
# was labelled, and it was called "Rank" although it holds the first plotted metric.
metric_axis_labels = {"correct": "Correct", "gt_label_rank": "Rank", "gt_distance": "Distance"}
for row_ax, metric_name in zip(g.axes[:, 0], g.row_names):
    row_ax.set_ylabel(metric_axis_labels.get(metric_name, metric_name))

# The x label belongs on the bottom panel, next to the tick labels
g.axes[-1, 0].set_xlabel("Evaluation")

In [None]:
f, ax = plt.subplots(figsize=(9, 8))

# Typographic arrows for the tick labels
transfer_label_pretty = {
    "NNS -> NNS": "NNS → NNS",
    "VBZ -> NNS": "VBZ → NNS",
    "NNS -> VBZ": "NNS → VBZ",
    "VBZ -> VBZ": "VBZ → VBZ",
}

# Back to wide format: one column per metric, one row per model / transfer / target word
_id_columns = sorted(set(nnvb_plot.columns) - {"value"})
_nnvb_plot_wide = nnvb_plot.set_index(_id_columns).value.unstack("metric").reset_index()
_nnvb_plot_wide["transfer_label"] = _nnvb_plot_wide.transfer_label.map(transfer_label_pretty)

# Horizontal bars of mean rank per transfer direction, one bar per model
ax = sns.barplot(
    data=_nnvb_plot_wide,
    x="gt_label_rank", y="transfer_label", hue="model_label",
    errorbar="se", ax=ax,
)
ax.set_ylabel(None)
ax.set_xlabel("Rank")
ax.legend(title="Model")

### Regression analysis

In [None]:
def get_interaction_strength(rows):
    """Fit an OLS model of ground-truth rank for one run and return its coefficients.

    The model crosses source/target inflection with source/target allomorph (full
    four-way interaction) and adds standardized source and target frequency as
    covariates.

    Parameters
    ----------
    rows : pandas.DataFrame
        Rows for a single run with columns `correct`, `gt_label_rank`,
        `inflection_from`, `inflection_to`, `allomorph_from`, `allomorph_to`,
        `from_freq` and `to_freq`. The frame is not modified.

    Returns
    -------
    pandas.Series
        The fitted coefficients (`model.params`), indexed by term name.
    """
    # exclude rare: the VBZ "IH Z" allomorph, on either side of the analogy.
    # The original mask lacked parentheses around the second comparison (`&` binds
    # tighter than `==`) and tested the inflection column against "IH Z", so it did
    # not express this exclusion.
    rare_from = (rows.inflection_from == "VBZ") & (rows.allomorph_from == "IH Z")
    rare_to = (rows.inflection_to == "VBZ") & (rows.allomorph_to == "IH Z")
    # Work on a copy so the group handed in by groupby.apply is left untouched
    rows = rows[~rare_from & ~rare_to].copy()

    rows["correct"] = rows.correct.astype(int)

    # standardize frequency
    rows["from_freq"] = (rows.from_freq - rows.from_freq.mean()) / rows.from_freq.std()
    rows["to_freq"] = (rows.to_freq - rows.to_freq.mean()) / rows.to_freq.std()

    # Alternative kept for reference: logistic regression on `correct` with two-way
    # interactions only.
    # formula = "correct ~ C(inflection_from, Treatment(reference='NNS')) * C(inflection_to, Treatment(reference='NNS')) + " \
    #           "C(allomorph_from, Treatment(reference='Z')) * C(allomorph_to, Treatment(reference='Z')) +" \
    #           "from_freq + to_freq"
    # model = logit(formula, data=rows).fit()

    formula = "gt_label_rank ~ C(inflection_from, Treatment(reference='NNS')) * C(inflection_to, Treatment(reference='NNS')) * " \
              "C(allomorph_from, Treatment(reference='Z')) * C(allomorph_to, Treatment(reference='Z')) +" \
              "from_freq + to_freq"
    # fit OLS, remove outliers (ranks at or above the 90th percentile)
    rank_cutoff = np.percentile(rows.gt_label_rank, 90)
    model = ols(formula, data=rows[rows.gt_label_rank < rank_cutoff]).fit()

    return model.params

In [None]:
# One regression per run; melt the coefficient table to long format
# (one row per run and model term)
interaction_model_fits = (
    all_nnvb_results
    .groupby(run_groupers)
    .apply(get_interaction_strength)
    .reset_index()
    .melt(id_vars=run_groupers, value_name="coef_norm")
)

# Interaction terms are the ones whose name contains ":"
_is_interaction_term = interaction_model_fits.variable.str.contains(":", regex=True)

# Per-run interaction strength: L1 norm of all interaction coefficients, ascending
interaction_strengths = (
    interaction_model_fits[_is_interaction_term]
    .groupby(run_groupers)
    .coef_norm
    .apply(lambda xs: np.linalg.norm(xs, ord=1))
    .sort_values()
)

In [None]:
# Coefficients for base model w2v2_8, with an `inter` flag marking interaction
# terms and the Treatment(...) boilerplate stripped from the term names
_treatment_suffix = r", Treatment\(reference='Z'\)|, Treatment\(reference='NNS'\)"
(
    interaction_model_fits
    .query("base_model_name == 'w2v2_8' and model_name in ['ffff_32', 'ff_32', 'id']")
    .assign(
        inter=lambda xs: xs.variable.str.contains(":"),
        variable=lambda xs: xs.variable.str.replace(_treatment_suffix, r"", regex=True),
    )
)

In [None]:
# Interaction coefficients of the two w2v2_8 models, one row per interaction term
# and one column per model
_is_spaghetti_row = (
    interaction_model_fits.model_name.isin(("ffff_32", "id"))
    & (interaction_model_fits.base_model_name == "w2v2_8")
    & interaction_model_fits.variable.str.contains(":")
)
spaghetti_data = interaction_model_fits[_is_spaghetti_row].pivot_table(
    index="variable", columns="model_name", values="coef_norm")
# reorder
assert set(spaghetti_data.columns) == {"ffff_32", "id"}
spaghetti_data = spaghetti_data[["id", "ffff_32"]]
spaghetti_data.columns = spaghetti_data.columns.map({"id": "Wav2Vec", main_plot_run[1]: "Word"})
# Only the magnitude of each coefficient is plotted
spaghetti_data = spaghetti_data.abs()

f, ax = plt.subplots(figsize=(5, 3))
# One line per interaction term, joining its magnitude under each model
_model_positions = np.arange(spaghetti_data.shape[1])
_line_color = sns.color_palette()[0]
for _, _coefs in spaghetti_data.iterrows():
    ax.plot(_model_positions, _coefs, color=_line_color, marker="o", alpha=0.5)
ax.set_xticks(_model_positions)
ax.set_xticklabels(spaghetti_data.columns)
ax.set_xlim((-0.25, 1.25))
ax.set_xlabel("Model")
ax.set_ylabel("Interaction\nstrength")

f.tight_layout()
f.savefig(f"{output_dir}/interaction_strength_spaghetti.pdf", bbox_inches="tight")

In [None]:
# Interaction strength for the plotted runs, in plot_runs order
plot_is = interaction_strengths.reindex(plot_runs).reset_index()
# Base-model index parsed from the name (no +1 offset here, unlike the layer-wise plots)
plot_is["layer"] = plot_is.base_model_name.str.extract(r"_(\d+)$").astype(int)
plot_is["model_name"] = plot_is["model_name"].map({"id": "Wav2Vec", main_plot_run[1]: "Word"})

# Compare the two models at base-model index 8 only
_plot_is_focus = plot_is[plot_is.layer == 8]
g = sns.catplot(
    data=_plot_is_focus,
    x="model_name", y="coef_norm", order=["Wav2Vec", "Word"],
    kind="bar", height=3, aspect=2,
)
_bar_ax = g.axes.flat[0]
_bar_ax.set_ylabel("Allomorph/inflection\ninteraction strength", fontsize=14)
_bar_ax.set_xlabel("Model")

g.tight_layout()
g.savefig(f"{output_dir}/interaction_strength.pdf", bbox_inches="tight")

### Correct for verb paradigm size

In [None]:
# study_df = nnvb_focus[(nnvb_focus.inflection_to == "VBZ")].copy()

In [None]:
# study_df["predicted_stem"] = study_df.predicted_label.str.replace(r"s$|ed$|ings?$", "", regex=True)
# study_df["base_to_stem"] = study_df.base_to.str.replace(r"e$", "", regex=True).replace(r"y$", "i", regex=True)
# study_df["predicted_within_inflection"] = \
#     (study_df.predicted_stem == study_df.base_to) | (study_df.predicted_stem == study_df.base_to_stem)
# vb_irregulars = [("do", "did"), ("do", "does"), ("begin", "began"), ("learn", "learnt"), ("send", "sent"), ("shine", "shone"), ("seem", "seem'd"), ("read", "red"),
#                  ("possess", "possesses"), ("bring", "brings"), ("carry", "carries"), ("occur", "occurred"), ("think", "thinkest"), ("grow", "grew"),
#                  ("put", "putting"), ("begin", "beginning"), ("give", "givest"),
#                  # homophones
#                  ("allow", "aloud"), ("write", "rights"), ("write", "wright's"), ("depend", "dependent")]
# for base, predicted in vb_irregulars:
#     study_df.loc[study_df.base_to == base, "predicted_within_inflection"] |= study_df.loc[study_df.base_to == base].predicted_label == predicted

In [None]:
# merge_keys = ["experiment", "equivalence", "model_label", "model_name", "base_model_name", "group", "inflection_from", "inflection_to", "base_from", "inflected_from", "base_to", "inflected_to"]
# nnvb_focus = pd.merge(nnvb_focus.reset_index(), study_df.reset_index()[merge_keys + ["predicted_within_inflection"]],
#          on=merge_keys, how="left")

In [None]:
# nnvb_focus["correct_or_predicted_within_inflection"] = nnvb_focus.correct | nnvb_focus.predicted_within_inflection

In [None]:
# nnvb_results_summary2 = nnvb_focus.groupby(["model_label", "inflection_from", "inflection_to"]) \
#     [["correct_or_predicted_within_inflection", "correct", "gt_label_rank", "gt_distance"]].mean().reset_index()

# nnvb_results_summary2["transfer_label"] = nnvb_results_summary2.inflection_from + " -> " + nnvb_results_summary2.inflection_to
# nnvb_results_summary2

In [None]:
# f, axs = plt.subplots(1, 3, figsize=(4 * 2, 3), gridspec_kw={'width_ratios': [1, 1, 0.04]})

# vmin = 0
# vmax = nnvb_results_summary2.correct_or_predicted_within_inflection.max()

# for i, (ax, (model_label, rows)) in enumerate(zip(axs, nnvb_results_summary2.groupby("model_label"))):
#     cbar_ax = None
#     if i == 1:
#         cbar_ax = axs.flat[-1]
    
#     ax.set_title(model_label)
#     sns.heatmap(rows.set_index(["inflection_from", "inflection_to"]).correct_or_predicted_within_inflection.unstack(),
#                 vmin=vmin, vmax=vmax,
#                 annot=True, ax=ax, cbar=i == 1, cbar_ax=cbar_ax)
#     ax.set_xlabel("Test")
#     ax.set_ylabel("Train")

# f = plt.gcf()
# f.tight_layout()
# f.savefig(f"{output_dir}/nnvb_results-correct_inflection.pdf")

In [None]:
# nnvb_plot = nnvb_focus.groupby(["model_label", "inflection_from", "inflection_to", "base_to"]).correct_or_predicted_within_inflection.mean().reset_index()
# nnvb_plot["transfer_label"] = nnvb_plot.inflection_from + " -> " + nnvb_plot.inflection_to
# order = ["NNS -> VBZ", "VBZ -> VBZ", "VBZ -> NNS", "NNS -> NNS"]
# g = sns.catplot(data=nnvb_plot, x="transfer_label", y="correct_or_predicted_within_inflection", kind="bar", hue="model_label", order=order, errorbar="se", height=4, aspect=2)
# g._legend.set_title("Model")
# ax = g.axes.flat[0]

# ax.set_xlabel("Evaluation")
# ax.set_ylabel("Accuracy")

### Frequency analysis

In [None]:
# Accuracy by source-word frequency quintile, one panel per train/test inflection
# pair; error bars are resampled over source base words (units)
_freq_plot_data = nnvb_focus.reset_index()
sns.catplot(
    data=_freq_plot_data,
    x="from_freq_bin", y="correct", hue="model_name",
    row="inflection_from", col="inflection_to",
    units="base_from", kind="point",
)

In [None]:
# Accuracy by target-word frequency quintile, one panel per train/test inflection
# pair; error bars are resampled over target base words (units)
_freq_plot_data = nnvb_focus.reset_index()
sns.catplot(
    data=_freq_plot_data,
    x="to_freq_bin", y="correct", hue="model_name",
    row="inflection_from", col="inflection_to",
    units="base_to", kind="point",
)

#### On rank

In [None]:
# Ground-truth rank by source-word frequency quintile, for the focus model only.
# Panels do not share a y axis.
_focus_model_rows = nnvb_focus.reset_index().query("model_name == @focus_model")
sns.catplot(
    data=_focus_model_rows,
    x="from_freq_bin", y="gt_label_rank", hue="model_name",
    row="inflection_from", col="inflection_to",
    units="base_from", kind="point",
    sharey=False,
)

In [None]:
# Ground-truth rank by target-word frequency quintile, for the focus model only.
# Panels do not share a y axis.
_focus_model_rows = nnvb_focus.reset_index().query("model_name == @focus_model")
sns.catplot(
    data=_focus_model_rows,
    x="to_freq_bin", y="gt_label_rank", hue="model_name",
    row="inflection_from", col="inflection_to",
    sharey=False,
    units="base_to", kind="point",
)

## Controlled VBD analysis

In [None]:
# VBD rows of the 'regular' experiment, annotated with the most common allomorph
# of the source and of the target item
all_vbd_results = all_results.query("experiment == 'regular' and inflection_from == 'VBD'")
for _side in ("from", "to"):
    # Rename the lookup table's columns so they line up with this side's key columns
    _side_allomorphs = most_common_allomorphs.rename(columns={
        "base": f"base_{_side}",
        "inflection": f"inflection_{_side}",
        "most_common_allomorph": f"allomorph_{_side}",
    })
    # Left join: items without a known allomorph are kept, with NaN
    all_vbd_results = pd.merge(all_vbd_results, _side_allomorphs,
                               on=[f"base_{_side}", f"inflection_{_side}"], how="left")

# How often each source/target allomorph combination occurs
all_vbd_results[["allomorph_from", "allomorph_to"]].value_counts()

In [None]:
exclude_vbd_base = [
    # drop do -> did; how did these irregulars get in here?
    "do", "hide", "make",
    # exclude stress shift items; we don't know which token items have which stress
    "record", "permit", "protest", "reject", "subject", "conduct", "contract", "conflict",
    "increase", "decrease", "contest", "insult", "impact", "address", "escort",
    # multiple possible
    "dream", "leap",
]

# Drop every row that involves an excluded base on either side
_involves_excluded = (
    all_vbd_results.base_from.isin(exclude_vbd_base)
    | all_vbd_results.base_to.isin(exclude_vbd_base)
)
all_vbd_results = all_vbd_results[~_involves_excluded]

In [None]:
# Keep only the three most frequent source allomorphs, and require both sides
# of every row to use one of them
keep_vbd_allomorphs = all_vbd_results.allomorph_from.value_counts().head(3).index
_from_is_kept = all_vbd_results.allomorph_from.isin(keep_vbd_allomorphs)
_to_is_kept = all_vbd_results.allomorph_to.isin(keep_vbd_allomorphs)
all_vbd_results = all_vbd_results[_from_is_kept & _to_is_kept]

In [None]:
# Add frequency information

# all_vbd_results is the result of boolean filtering; take an explicit copy before
# adding columns so the writes cannot raise SettingWithCopyWarning.
all_vbd_results = all_vbd_results.copy()

# The equivalence labels are stringified sequences whose second element is used as the inflected form.
# NOTE(review): eval() executes whatever text is in the results files; acceptable only
# while they come from our own pipeline.
all_vbd_results["inflected_from"] = all_vbd_results.from_equiv_label.apply(lambda x: eval(x)[1])
all_vbd_results["inflected_to"] = all_vbd_results.to_equiv_label.apply(lambda x: eval(x)[1])

# Attach log frequencies for the base and inflected form on both sides. These are
# inner joins, so rows whose word is absent from word_freq_df are dropped.
for freq_column, word_column in [("from_base_freq", "base_from"),
                                 ("to_base_freq", "base_to"),
                                 ("from_inflected_freq", "inflected_from"),
                                 ("to_inflected_freq", "inflected_to")]:
    all_vbd_results = pd.merge(all_vbd_results, word_freq_df["LogFreq"].rename(freq_column),
                               left_on=word_column, right_index=True)

all_vbd_results["from_freq"] = all_vbd_results[["from_base_freq", "from_inflected_freq"]].mean(axis=1)
all_vbd_results["to_freq"] = all_vbd_results[["to_base_freq", "to_inflected_freq"]].mean(axis=1)

# Tercile edges are computed over source and target frequencies pooled together so
# that both sides share the same bins.
_, vbd_frequency_bins = pd.qcut(pd.concat([all_vbd_results.to_freq, all_vbd_results.from_freq]), q=3, retbins=True)
vbd_frequency_bin_labels = [f"Q{i}" for i in range(1, 4)]
# include_lowest=True: the first edge is the pooled minimum, and pd.cut's intervals are
# left-open by default, so without it the lowest-frequency rows would get a NaN bin.
all_vbd_results["from_freq_bin"] = pd.cut(all_vbd_results.from_freq, bins=vbd_frequency_bins,
                                          labels=vbd_frequency_bin_labels, include_lowest=True)
all_vbd_results["to_freq_bin"] = pd.cut(all_vbd_results.to_freq, bins=vbd_frequency_bins,
                                        labels=vbd_frequency_bin_labels, include_lowest=True)

In [None]:
def summarize_vbd_run(rows):
    """Return `rows` with four readable label columns appended.

    `source_label` / `target_label` join inflection and allomorph with a space;
    `transfer_label` and `phon_label` join the from/to inflections and the
    from/to allomorphs with " -> ".
    """
    return rows.assign(
        source_label=rows.inflection_from + " " + rows.allomorph_from,
        target_label=rows.inflection_to + " " + rows.allomorph_to,
        transfer_label=rows.inflection_from + " -> " + rows.inflection_to,
        phon_label=rows.allomorph_from + " -> " + rows.allomorph_to,
    )

# Mean of every plotted metric per run and per (inflection, allomorph) transfer
# cell, then label columns added run by run.
summary_groupers = ["inflection_from", "inflection_to", "allomorph_from", "allomorph_to"]
vbd_results_summary = (
    all_vbd_results
    .groupby(run_groupers + summary_groupers)[plot_metrics]
    .mean()
    .reset_index(summary_groupers)
    .groupby(run_groupers, group_keys=False)
    .apply(summarize_vbd_run)
    .reset_index()
)

vbd_results_summary

In [None]:
# plot_results = []
# for base_model_name, model_name, equivalence in plot_runs:
#     results_i = vbd_results_summary.query("base_model_name == @base_model_name and model_name == @model_name and equivalence == @equivalence")
#     if len(results_i) > 0:
#         plot_results.append(results_i)
# num_plot_runs = len(plot_results)

# ncols = 4
# nrows = int(np.ceil(num_plot_runs / ncols))
# f, axs = plt.subplots(nrows, ncols, figsize=(ncols * 4, nrows * 4))

# for ax, results_i in zip(axs.flat, plot_results):
#     sns.heatmap(results_i.set_index(["source_label", "target_label"]).correct.unstack(),
#                 vmin=0, vmax=1, ax=ax)
#     key_row = results_i.iloc[0]
#     ax.set_title(f"{key_row.base_model_name} -> {key_row.model_name} ({key_row.equivalence})")

In [None]:
# Select the focus and foil runs and tag each with a display label.
# NOTE(review): focus_*/foil_* are not assigned in this cell; the cell that unpacks
# them must have been run first.
# .assign returns a new frame, so no column is written into a query() slice.
vbd_focus = all_vbd_results.query("base_model_name == @focus_base_model and model_name == @focus_model and equivalence == @focus_equivalence") \
    .assign(model_label="Word")
vbd_foil = all_vbd_results.query("base_model_name == @foil_base_model and model_name == @foil_model and equivalence == @foil_equivalence") \
    .assign(model_label="Wav2Vec")
vbd_focus = pd.concat([vbd_focus, vbd_foil])

# Metric means per model and train/test inflection pair.
vbd_results_summary2 = vbd_focus.groupby(["model_label", "inflection_from", "inflection_to"]) \
    [plot_metrics].mean().reset_index()

vbd_results_summary2["transfer_label"] = vbd_results_summary2.inflection_from + " -> " + vbd_results_summary2.inflection_to
vbd_results_summary2

In [None]:
# One figure per metric: train-by-test inflection heatmaps for the two models
# side by side, sharing one colour scale and one colour bar.
for metric in plot_metrics:
    # Both colour limits are taken from the VBD summary that is being plotted.
    vmin = 0 if metric == "correct" else vbd_results_summary2[metric].min()
    vmax = vbd_results_summary2[metric].max()

    # Two heatmap axes plus a narrow third axis reserved for the colour bar.
    f, axs = plt.subplots(1, 3, figsize=(4 * 2, 3), gridspec_kw={'width_ratios': [1, 1, 0.04]})

    for i, (ax, (model_label, rows)) in enumerate(zip(axs, vbd_results_summary2.groupby("model_label"))):
        # Only the second heatmap draws the colour bar, into the last axis.
        cbar_ax = None
        if i == 1:
            cbar_ax = axs.flat[-1]
        
        ax.set_title(model_label)
        sns.heatmap(rows.set_index(["inflection_from", "inflection_to"])[metric].unstack(),
                    annot=True, vmin=vmin, vmax=vmax, ax=ax,
                    cbar=i == 1, cbar_ax=cbar_ax)
        ax.set_xlabel("Test")
        ax.set_ylabel("Train")

    f.tight_layout()
    f.savefig(f"{output_dir}/vbd_results-{metric}.pdf")

### Focused plots

In [None]:
# Unpack the focus and foil runs from the notebook-level configuration rather than
# repeating the foil tuple as literals.
focus_base_model, focus_model, focus_equivalence = main_plot_run
foil_base_model, foil_model, foil_equivalence = foil_plot_run

# .assign returns a new frame, so no column is written into a query() slice.
vbd_focus = all_vbd_results.query("base_model_name == @focus_base_model and model_name == @focus_model and equivalence == @focus_equivalence") \
    .assign(model_label="Word")
vbd_foil = all_vbd_results.query("base_model_name == @foil_base_model and model_name == @foil_model and equivalence == @foil_equivalence") \
    .assign(model_label="Wav2Vec")

vbd_focus = pd.concat([vbd_focus, vbd_foil])

# Replace the phone-string allomorph codes with display symbols; codes that are
# not keys of this dict become NaN under Series.map.
allomorph_labels = {"D": "d", "T": "t", "IH D": "ɪd"}
vbd_focus["allomorph_from"] = vbd_focus.allomorph_from.map(allomorph_labels)
vbd_focus["allomorph_to"] = vbd_focus.allomorph_to.map(allomorph_labels)
vbd_focus

In [None]:
# VBD targets with the "ɪd" allomorph: mean ground-truth rank per (model, target
# base), showing the 20 largest values.
(
    vbd_focus
    .query("inflection_to == 'VBD' and allomorph_to == 'ɪd'")
    .groupby(["model_name", "base_to"])["gt_label_rank"]
    .mean()
    .sort_values()
    .tail(20)
)

In [None]:
# VBD targets with the "t" allomorph: mean ground-truth rank per (model, target
# base), showing the 20 largest values.
(
    vbd_focus
    .query("inflection_to == 'VBD' and allomorph_to == 't'")
    .groupby(["model_name", "base_to"])["gt_label_rank"]
    .mean()
    .sort_values()
    .tail(20)
)

In [None]:
# VBD targets with the "d" allomorph: mean ground-truth rank per (model, target
# base), sorted ascending (full listing).
(
    vbd_focus
    .query("inflection_to == 'VBD' and allomorph_to == 'd'")
    .groupby(["model_name", "base_to"])["gt_label_rank"]
    .mean()
    .sort_values()
)

In [None]:
# Metric means per model and per (inflection, allomorph) transfer cell.
vbd_results_summary = vbd_focus.groupby(["model_label", "inflection_from", "inflection_to",
                                             "allomorph_from", "allomorph_to"]) \
    [plot_metrics].mean() \
    .reset_index()

vbd_results_summary["source_label"] = vbd_results_summary.inflection_from + "\n" + vbd_results_summary.allomorph_from
vbd_results_summary["target_label"] = vbd_results_summary.inflection_to + "\n" + vbd_results_summary.allomorph_to

vbd_results_summary["transfer_label"] = vbd_results_summary.inflection_from + " -> " + vbd_results_summary.inflection_to
vbd_results_summary["phon_label"] = vbd_results_summary.allomorph_from + " " + vbd_results_summary.allomorph_to

# only retain cases where we have data in both transfer directions from source <-> target within inflection
# A row is kept when the reversed (target, source) pair also occurs somewhere in the
# summary. One set lookup per row replaces a DataFrame.query per row.
vbd_results_summary = vbd_results_summary[[
    reversed_pair in set(zip(vbd_results_summary.source_label, vbd_results_summary.target_label))
    for reversed_pair in zip(vbd_results_summary.target_label, vbd_results_summary.source_label)
]] if len(vbd_results_summary) < 2 else vbd_results_summary[
    pd.Series(list(zip(vbd_results_summary.target_label, vbd_results_summary.source_label)),
              index=vbd_results_summary.index)
    .isin(set(zip(vbd_results_summary.source_label, vbd_results_summary.target_label)))
]

vbd_results_summary

In [None]:
# Bar chart of every metric by target inflection, coloured by the source
# inflection/allomorph; one column per model, one row per metric.
sns.catplot(
    data=(
        vbd_focus
        .melt(id_vars=["inflection_from", "allomorph_from", "inflection_to", "model_label"],
              value_vars=plot_metrics)
        .assign(source_label=lambda melted: melted.inflection_from + " " + melted.allomorph_from)
    ),
    x="inflection_to", y="value", hue="source_label",
    col="model_label", row="variable",
    kind="bar", sharey="row",
)

In [None]:
# One figure per metric: source-by-target (inflection + allomorph) heatmaps for the
# two models side by side, sharing one colour scale and one colour bar.
for metric in plot_metrics:
    # Both colour limits are taken from the VBD allomorph summary that is being plotted.
    vmin = 0 if metric == "correct" else vbd_results_summary[metric].min()
    vmax = vbd_results_summary[metric].max()

    # Two heatmap axes plus a narrow third axis reserved for the colour bar.
    f, axs = plt.subplots(1, 3, figsize=(7 * 2, 6), gridspec_kw={'width_ratios': [1, 1, 0.04]})
    f.suptitle(f"{metric}", fontsize=20)
    for i, (ax, (model_label, rows)) in enumerate(zip(axs, vbd_results_summary.groupby("model_label"))):
        # Only the second heatmap draws the colour bar, into the last axis.
        cbar_ax = None
        if i == 1:
            cbar_ax = axs.flat[-1]

        ax.set_title(model_label)
        sns.heatmap(rows.set_index(["source_label", "target_label"]).sort_index()[metric].unstack("target_label"),
                    vmin=vmin, vmax=vmax,
                    annot=True, ax=ax,
                    cbar=i == 1, cbar_ax=cbar_ax)

        ax.set_yticklabels(ax.get_yticklabels(), rotation=0)
        ax.set_ylabel("Train")
        ax.set_xlabel("Test")

    f.tight_layout()
    f.savefig(f"{output_dir}/vbd_allomorphs-{metric}.pdf")

### Regression analysis

In [None]:
def get_vbd_interaction_strength(rows):
    """Fit an OLS model of ground-truth rank for one run and return its coefficients.

    The model is `gt_label_rank ~ allomorph_from * allomorph_to + from_freq + to_freq`
    with "T" as the reference level for both allomorph factors. Frequencies are
    z-scored within the run and rows at or above the 90th percentile of
    `gt_label_rank` are left out of the fit.

    Parameters
    ----------
    rows : pandas.DataFrame
        Results of a single run; must provide `correct`, `inflection_from/to`,
        `allomorph_from/to`, `from_freq`, `to_freq` and `gt_label_rank`.

    Returns
    -------
    pandas.Series
        Fitted parameters (`model.params`), indexed by term name.
    """
    # exclude rare
    # Each comparison is parenthesised because `&` binds tighter than `==`, and the
    # "IH Z" test is made against the allomorph column, not the inflection column.
    rare_from = (rows.inflection_from == "VBZ") & (rows.allomorph_from == "IH Z")
    rare_to = (rows.inflection_to == "VBZ") & (rows.allomorph_to == "IH Z")
    # Work on a copy so the group handed in by groupby.apply is not modified.
    rows = rows[~rare_from & ~rare_to].copy()

    rows["correct"] = rows.correct.astype(int)

    # standardize frequency
    rows["from_freq"] = (rows.from_freq - rows.from_freq.mean()) / rows.from_freq.std()
    rows["to_freq"] = (rows.to_freq - rows.to_freq.mean()) / rows.to_freq.std()
    
    # formula = "correct ~ C(inflection_from, Treatment(reference='NNS')) * C(inflection_to, Treatment(reference='NNS')) + " \
    #           "C(allomorph_from, Treatment(reference='Z')) * C(allomorph_to, Treatment(reference='Z')) +" \
    #           "from_freq + to_freq"
    # model = logit(formula, data=rows).fit()

    formula = "gt_label_rank ~ C(allomorph_from, Treatment(reference='T')) * C(allomorph_to, Treatment(reference='T')) +" \
              "from_freq + to_freq"
    # fit OLS, remove outliers
    model = ols(formula, data=rows[rows.gt_label_rank < np.percentile(rows.gt_label_rank, 90)]).fit()

    return model.params

In [None]:
# Per run: fit the regression, keep only the interaction terms (names containing
# ":"), and collapse them to their L1 norm; runs are sorted by that norm.
vbd_interaction_model_fits = (
    all_vbd_results
    .groupby(run_groupers)
    .apply(get_vbd_interaction_strength)
    .reset_index()
    .melt(id_vars=run_groupers, value_name="coef_norm")
    .loc[lambda fits: fits.variable.str.contains(":", regex=True)]
    .groupby(run_groupers)
    .coef_norm
    .apply(lambda coefs: np.linalg.norm(coefs, ord=1))
    .sort_values()
)

In [None]:
# Align the interaction norms with the configured plot runs and recover the layer
# number from the trailing digits of the base model name.
plot_is = vbd_interaction_model_fits.reindex(plot_runs).reset_index()
plot_is["layer"] = plot_is.base_model_name.str.extract(r"_(\d+)$").astype(int)

# Bar chart for layer 8 only, one bar per model.
g = sns.catplot(
    data=plot_is.loc[plot_is.layer == 8],
    x="model_name", y="coef_norm",
    kind="bar", height=3, aspect=2,
)
g.axes.flat[0].set(ylabel="Allomorph/inflection\ninteraction strength", xlabel="Model")

g.savefig(f"{output_dir}/vbd_interaction_strength.pdf")

## False friend analysis

In [None]:
# Collect every false-friend ("FF") experiment of every run into one frame, with the
# allomorphs and the side(s) holding the false friend parsed out of the experiment name.
all_ff_results = []

for run, run_results in all_results.groupby(run_groupers):
    run_results = run_results.set_index("experiment")
    false_friend_expts = run_results.index.unique()
    false_friend_expts = false_friend_expts[false_friend_expts.str.contains("FF")]

    for expt_name in false_friend_expts:
        # List-based .loc always yields a DataFrame, even for a one-row experiment.
        expt_df = run_results.loc[[expt_name]].copy()
        num_seen_words = min(len(expt_df.base_from.unique()), len(expt_df.base_to.unique()))

        if num_seen_words < 10:
            print(f"Skipping {expt_name} due to only {num_seen_words} seen words")
            continue

        if expt_name.count("-FF-") == 2:
            # false friends on both sides
            allomorph_from, allomorph_to = re.findall(r"-FF-([\w\s]+)-to-.+FF-([\w\s]+)", expt_name)[0]
            ff_from, ff_to = True, True
        else:
            try:
                allomorph_from, allomorph_to = re.findall(r"_([\w\s]+)-to-.+FF-([\w\s]+)", expt_name)[0]
                # is the false friend on the "from" side?
                ff_from, ff_to = False, True
            except IndexError:
                # findall returned no match for the "FF on the to side" pattern,
                # so the false friend must be on the "from" side.
                allomorph_from, allomorph_to = re.findall(r".+FF-([\w\s]+)-to-.+_([\w\s]+)", expt_name)[0]
                ff_from, ff_to = True, False

        expt_df["allomorph_from"] = allomorph_from
        expt_df["allomorph_to"] = allomorph_to

        # Strip the allomorph suffix so FF inflections read e.g. "<inflection>-FF".
        if ff_from:
            expt_df["inflection_from"] = expt_df.inflection_from.str.replace("-FF-.+", "-FF", regex=True)
        if ff_to:
            expt_df["inflection_to"] = expt_df.inflection_to.str.replace("-FF-.+", "-FF", regex=True)

        expt_df["ff_from"] = ff_from
        expt_df["ff_to"] = ff_to

        all_ff_results.append(expt_df)

    # add within-false-friend tests
    expt_df = run_results.loc[["false_friends"]].copy()
    expt_df["allomorph_from"] = expt_df.inflection_from.str.extract(r"FF-(.+)$")
    expt_df["allomorph_to"] = expt_df.inflection_to.str.extract(r"FF-(.+)$")
    expt_df["inflection_from"] = expt_df.inflection_from.str.replace("-FF-.+", "-FF", regex=True)
    expt_df["inflection_to"] = expt_df.inflection_to.str.replace("-FF-.+", "-FF", regex=True)
    expt_df["ff_from"] = True
    expt_df["ff_to"] = True

    all_ff_results.append(expt_df)

    # expt_df = expt_df[expt_df.inflection_from.isin(all_ff_results.inflection_from.unique())]

all_ff_results = pd.concat(all_ff_results).reset_index()

# Space-separated exclusion lists: base forms and inflected forms respectively.
ff_exclude = "a b c wreck d e eh wandering lo chiu ha hahn meek jew shew ah co des re san ol der k la ye ll"
ff_exclude_inflected = "bunce los oft mast hauled sward"

# exclude FF bases
all_ff_results = all_ff_results[~(all_ff_results.inflection_from.str.endswith("-FF") & all_ff_results.base_from.isin(ff_exclude.split()))]
all_ff_results = all_ff_results[~(all_ff_results.inflection_to.str.endswith("-FF") & all_ff_results.base_to.isin(ff_exclude.split()))]

# NOTE(review): eval() executes arbitrary code if the results file is not trusted;
# ast.literal_eval would be the safe equivalent if the labels are plain tuple literals.
all_ff_results["inflected_from"] = all_ff_results.from_equiv_label.apply(lambda x: eval(x)[1])
all_ff_results["inflected_to"] = all_ff_results.to_equiv_label.apply(lambda x: eval(x)[1])

# exclude FF inflected
all_ff_results = all_ff_results[~(all_ff_results.inflection_from.str.endswith("-FF") & all_ff_results.inflected_from.isin(ff_exclude_inflected.split()))]
all_ff_results = all_ff_results[~(all_ff_results.inflection_to.str.endswith("-FF") & all_ff_results.inflected_to.isin(ff_exclude_inflected.split()))]

# add frequency information
# Inner merges (the pd.merge default): rows whose word is missing from word_freq_df are dropped.
all_ff_results = pd.merge(all_ff_results, word_freq_df["LogFreq"].rename("from_base_freq"), left_on="base_from", right_index=True)
all_ff_results = pd.merge(all_ff_results, word_freq_df["LogFreq"].rename("to_base_freq"), left_on="base_to", right_index=True)
all_ff_results = pd.merge(all_ff_results, word_freq_df["LogFreq"].rename("from_inflected_freq"), left_on="inflected_from", right_index=True)
all_ff_results = pd.merge(all_ff_results, word_freq_df["LogFreq"].rename("to_inflected_freq"), left_on="inflected_to", right_index=True)
all_ff_results["from_freq"] = all_ff_results[["from_base_freq", "from_inflected_freq"]].mean(axis=1)
all_ff_results["to_freq"] = all_ff_results[["to_base_freq", "to_inflected_freq"]].mean(axis=1)

all_ff_results["transfer_label"] = all_ff_results.inflection_from + " -> " + all_ff_results.inflection_to

In [None]:
# Post-hoc corrections: mark specific (target base, predicted label) combinations
# as correct. The "tho"/"though" case additionally requires gt_label_rank == 1.
all_ff_results.loc[
    all_ff_results.base_to.eq("tho")
    & all_ff_results.predicted_label.eq("though")
    & all_ff_results.gt_label_rank.eq(1),
    "correct",
] = True
all_ff_results.loc[
    all_ff_results.base_to.eq("philip") & all_ff_results.predicted_label.eq("philip's"),
    "correct",
] = True
all_ff_results.loc[
    all_ff_results.base_to.eq("adam") & all_ff_results.predicted_label.eq("adam's"),
    "correct",
] = True
all_ff_results.loc[
    all_ff_results.base_to.eq("who") & all_ff_results.predicted_label.eq("who's"),
    "correct",
] = True

In [None]:
# Drop rows whose VBD side (source or target) uses a base listed in exclude_vbd_base,
# reporting the row count before and after.
print("before filtering: ", len(all_ff_results))
all_ff_results = all_ff_results[
    ~((all_ff_results.inflection_to == "VBD") & all_ff_results.base_to.isin(exclude_vbd_base))
    & ~((all_ff_results.inflection_from == "VBD") & all_ff_results.base_from.isin(exclude_vbd_base))
]
print("after filtering: ", len(all_ff_results))

In [None]:
# Map (base, inflected, post_divergence) -> value of the `strong` column.
false_friend_strong_lookup = (
    false_friends_df
    .set_index(["base", "inflected", "post_divergence"])["strong"]
    .to_dict()
)

In [None]:
def get_is_strong(rows):
    """Return True when every false-friend side of this group is marked strong.

    Only the first row of `rows` is inspected. For each side ("from", "to") whose
    inflection contains "-FF", the (base, inflected, allomorph) triple is looked
    up in `false_friend_strong_lookup`; a missing triple raises KeyError. A group
    with no false-friend side yields True, since `all([])` is True.
    """
    first = rows.iloc[0]
    lookup_keys = []
    for side in ("from", "to"):
        if "-FF" in first[f"inflection_{side}"]:
            lookup_keys.append(
                (first[f"base_{side}"], first[f"inflected_{side}"], first[f"allomorph_{side}"])
            )

    return all([false_friend_strong_lookup[key] for key in lookup_keys])

# Compute the strong flag once per distinct item pair and join it back onto every row.
strong_grouper = [
    "inflection_from", "inflection_to",
    "inflected_from", "inflected_to",
    "base_from", "base_to",
    "allomorph_from", "allomorph_to",
]
strong_values = all_ff_results.groupby(strong_grouper).apply(get_is_strong).rename("is_strong")
all_ff_results = all_ff_results.merge(strong_values, left_on=strong_grouper, right_index=True)

In [None]:
# Set the non-strong rows aside first, because all_ff_results is narrowed next.
weak_ff_results = all_ff_results.loc[~all_ff_results.is_strong]

# ONLY STRONG
all_ff_results = all_ff_results.loc[all_ff_results.is_strong]

In [None]:
# VBD-FF -> VBD transfer for the ffff_32 model on w2v2_8: mean ground-truth rank
# per target base, showing the 20 largest values.
(
    all_ff_results
    .query("inflection_from == 'VBD-FF' & inflection_to == 'VBD' and model_name == 'ffff_32' and base_model_name == 'w2v2_8'")
    .groupby("base_to")["gt_label_rank"]
    .mean()
    .sort_values()
    .tail(20)
)

### Main FF analysis

In [None]:
# Tercile edges are computed over source and target frequencies pooled together.
ff_frequency_bins = pd.qcut(pd.concat([all_ff_results.to_freq, all_ff_results.from_freq]), q=3, retbins=True)[1]
# pd.cut builds right-closed intervals; include_lowest=True keeps the rows that sit
# exactly on the lowest edge (the pooled minimum) in Q1 instead of giving them NaN.
all_ff_results["from_freq_bin"] = pd.cut(all_ff_results.from_freq, bins=ff_frequency_bins,
                                         labels=[f"Q{i}" for i in range(1, 4)], include_lowest=True)
all_ff_results["to_freq_bin"] = pd.cut(all_ff_results.to_freq, bins=ff_frequency_bins,
                                       labels=[f"Q{i}" for i in range(1, 4)], include_lowest=True)

In [None]:
# Gather the distinct base words of three groups so their frequency distributions
# can be compared (are false friends disproportionately rare words?):
# false-friend bases, NNS bases and VBZ bases, each pooled over both transfer sides.
false_friend_words = pd.concat([
    all_ff_results.loc[all_ff_results.ff_from, "base_from"],
    all_ff_results.loc[all_ff_results.ff_to, "base_to"],
]).unique()
nn_words = pd.concat([
    all_nnvb_results.loc[all_nnvb_results.inflection_from == "NNS", "base_from"],
    all_nnvb_results.loc[all_nnvb_results.inflection_to == "NNS", "base_to"],
]).unique()
vb_words = pd.concat([
    all_nnvb_results.loc[all_nnvb_results.inflection_from == "VBZ", "base_from"],
    all_nnvb_results.loc[all_nnvb_results.inflection_to == "VBZ", "base_to"],
]).unique()

In [None]:
# Log-frequencies of each word group, stacked under an outer index level "type".
expt_word_freqs = pd.concat(
    {
        group_name: word_freq_df.LogFreq.loc[group_words]
        for group_name, group_words in (
            ("false_friends", false_friend_words),
            ("NN", nn_words),
            ("VB", vb_words),
        )
    },
    names=["type"],
)

In [None]:
# One histogram row per word group; y axes are independent so groups of different
# size remain readable.
sns.displot(
    data=expt_word_freqs.reset_index(),
    x="LogFreq", row="type",
    kind="hist", bins=15,
    height=3, aspect=3,
    facet_kws={"sharey": False},
)

In [None]:
# Strong false-friend results restricted to the focus run and to the foil run.
focus_ff_results = all_ff_results.query(
    "base_model_name == @focus_base_model and model_name == @focus_model and equivalence == @focus_equivalence"
)
foil_ff_results = all_ff_results.query(
    "base_model_name == @foil_base_model and model_name == @foil_model and equivalence == @foil_equivalence"
)

In [None]:
# Metric means per model and train/test inflection pair, focus and foil together.
ff_results_summary2 = (
    pd.concat([focus_ff_results, foil_ff_results])
    .groupby(["model_name", "inflection_from", "inflection_to"])[plot_metrics]
    .mean()
    .reset_index()
)
ff_results_summary2["transfer_label"] = ff_results_summary2.inflection_from + " -> " + ff_results_summary2.inflection_to
# Swap the model name for its display label; the assert guards the mapping below.
assert set(ff_results_summary2.model_name) == {"ffff_32", "id"}
ff_results_summary2["model_label"] = ff_results_summary2["model_name"].map({"id": "Wav2Vec", "ffff_32": "Word"})
ff_results_summary2 = ff_results_summary2.drop(columns=["model_name"])

# Append the within-inflection cells from the regular NNS/VBZ and VBD summaries.
ff_results_summary2 = pd.concat([
    ff_results_summary2,
    nnvb_results_summary2.query("inflection_from == inflection_to"),
    vbd_results_summary2.query("inflection_from == inflection_to"),
])

# Group FF and non-FF rows under the same base inflection and keep the plotted ones.
ff_results_summary2["base_inflection"] = ff_results_summary2.inflection_from.str.replace("-FF", "")
ff_results_summary2 = ff_results_summary2.loc[ff_results_summary2.base_inflection.isin(plot_inflections)]

In [None]:
# Grid of heatmaps (rows: model, columns: base inflection) for a single metric,
# all drawn on one colour scale with one shared colour bar.
plot_variable = "gt_label_rank"
vmin = ff_results_summary2[plot_variable].min()
vmax = ff_results_summary2[plot_variable].max()

g = sns.FacetGrid(data=ff_results_summary2, row="model_label", col="base_inflection",
                  row_order=["Wav2Vec", "Word"], sharex=False, sharey=False)

def mapfn(data, **kwargs):
    """Draw the train-by-test heatmap of `plot_variable` for one facet on the current axes."""
    facet_table = data.set_index(["inflection_from", "inflection_to"])[plot_variable].unstack("inflection_to")
    sns.heatmap(facet_table, vmin=vmin, vmax=vmax, annot=True, fmt=".2g", ax=plt.gca(), cbar=False)

g.map_dataframe(mapfn)

for i, row in enumerate(g.axes):
    is_bottom_row = i == len(g.axes) - 1
    for j, ax in enumerate(row):
        ax.set_title(ax.get_title().replace("base_inflection = ", "").replace("model_label = ", ""))
        # y label only on the first column, x label only on the bottom row
        ax.set_ylabel("Train" if j == 0 else "")
        ax.set_xlabel("Test" if is_bottom_row else "")

# add colorbar
cbar_ax = g.figure.add_axes([0.99, 0.18, 0.03, 0.7])
cbar = plt.colorbar(g.axes.flat[0].collections[0], cax=cbar_ax)
cbar.outline.set_visible(False)

g.tight_layout()
g.savefig(f"{output_dir}/ff_results-merged.pdf", bbox_inches="tight")

In [None]:
# Focus run, strong false friends: metric means per train/test inflection pair.
ff_results_summary2 = (
    focus_ff_results
    .groupby(["inflection_from", "inflection_to"])[plot_metrics]
    .mean()
    .reset_index()
)

ff_results_summary2["transfer_label"] = ff_results_summary2.inflection_from + " -> " + ff_results_summary2.inflection_to

# add in data for NNS->NNS and VBZ->VBZ
ff_results_summary2 = pd.concat(
    [
        ff_results_summary2,
        nnvb_results_summary2.query("inflection_from == inflection_to and model_label == 'Word'"),
        vbd_results_summary2.query("inflection_from == inflection_to and model_label == 'Word'"),
    ],
    axis=0,
)

ff_results_summary2["base_inflection"] = ff_results_summary2.inflection_from.str.replace("-FF", "")
ff_results_summary2 = ff_results_summary2.loc[ff_results_summary2.base_inflection.isin(plot_inflections)]

# g = sns.FacetGrid(ff_results_summary2, col="base_inflection", sharex=False, sharey=False)
# One heatmap per (metric row, base-inflection column).
g = sns.FacetGrid(
    data=ff_results_summary2.melt(
        id_vars=["inflection_from", "inflection_to", "transfer_label", "base_inflection"],
        value_vars=plot_metrics,
    ),
    col="base_inflection", row="variable",
    sharex=False, sharey=False,
    height=3, aspect=1.25,
)

def mapfn(data, **kwargs):
    """Draw one facet's train-by-test heatmap, with colour limits taken from that facet."""
    facet_metric = data.variable.iloc[0]
    lower = 0 if facet_metric == "correct" else data.value.min()
    upper = data.value.max()

    facet_table = data.set_index(["inflection_from", "inflection_to"]).value.unstack("inflection_to")
    sns.heatmap(facet_table, vmin=lower, vmax=upper, annot=True, ax=plt.gca())

g.map_dataframe(mapfn)

for i, row in enumerate(g.axes):
    last_column = len(row) - 1
    for j, ax in enumerate(row):
        ax.set_title(ax.get_title().replace("base_inflection = ", ""))
        if j > 0:
            ax.set_ylabel("")
        # keep the colour bar only on the right-most heatmap of each row
        if j < last_column:
            ax.collections[0].colorbar.remove()

g.figure.tight_layout()
g.figure.savefig(f"{output_dir}/ff_results.pdf")

In [None]:
# Foil run, strong false friends: metric means per train/test inflection pair.
# The aggregated columns come from plot_metrics, the same list that melt() below
# uses as value_vars, so the two cannot drift apart.
ff_foil_results_summary2 = foil_ff_results.groupby(["inflection_from", "inflection_to"]) \
    [plot_metrics].mean().reset_index()

ff_foil_results_summary2["transfer_label"] = ff_foil_results_summary2.inflection_from + " -> " + ff_foil_results_summary2.inflection_to

# add in data for NNS->NNS and VBZ->VBZ
ff_foil_results_summary2 = pd.concat([
    ff_foil_results_summary2,
    nnvb_results_summary2.query("inflection_from == inflection_to and model_label == 'Wav2Vec'"),
    vbd_results_summary2.query("inflection_from == inflection_to and model_label == 'Wav2Vec'")], axis=0)

ff_foil_results_summary2["base_inflection"] = ff_foil_results_summary2.inflection_from.str.replace("-FF", "")

ff_foil_results_summary2 = ff_foil_results_summary2[ff_foil_results_summary2.base_inflection.isin(plot_inflections)]

# One heatmap per (metric row, base-inflection column).
g = sns.FacetGrid(data=ff_foil_results_summary2.melt(id_vars=["inflection_from", "inflection_to", "transfer_label", "base_inflection"],
                                                value_vars=plot_metrics),
                    col="base_inflection", row="variable", sharex=False, sharey=False,
                    height=3, aspect=1.25)
def mapfn(data, **kwargs):
    """Draw one facet's train-by-test heatmap, with colour limits taken from that facet."""
    ax = plt.gca()
    metric = data.variable.iloc[0]
    vmin = 0 if metric == "correct" else data.value.min()
    vmax = data.value.max()

    sns.heatmap(data.set_index(["inflection_from", "inflection_to"]).value.unstack("inflection_to"),
                vmin=vmin, vmax=vmax, annot=True, ax=ax)
g.map_dataframe(mapfn)

for i, row in enumerate(g.axes):
    for j, ax in enumerate(row):
        ax.set_title(ax.get_title().replace("base_inflection = ", ""))
        if j > 0:
            ax.set_ylabel("")
        # keep the colour bar only on the right-most heatmap of each row
        if j < len(row) - 1:
            ax.collections[0].colorbar.remove()

# `figure` is the attribute other cells of this notebook use on seaborn grids.
g.figure.tight_layout()
g.figure.savefig(f"{output_dir}/ff_results-foil.pdf")

In [None]:
# Focus run, weak (non-strong) false friends.
focus_weak_ff_results = weak_ff_results.query("base_model_name == @focus_base_model and model_name == @focus_model and equivalence == @focus_equivalence")

# The aggregated columns come from plot_metrics, the same list that melt() below
# uses as value_vars, so the two cannot drift apart.
weak_ff_results_summary2 = focus_weak_ff_results.groupby(["inflection_from", "inflection_to"]) \
    [plot_metrics].mean().reset_index()

weak_ff_results_summary2["transfer_label"] = weak_ff_results_summary2.inflection_from + " -> " + weak_ff_results_summary2.inflection_to

# add in data for NNS->NNS and VBZ->VBZ
weak_ff_results_summary2 = pd.concat([weak_ff_results_summary2, nnvb_results_summary2.query("inflection_from == inflection_to and model_label == 'Word'")], axis=0)
# add in data for VBD->VBD
weak_ff_results_summary2 = pd.concat([weak_ff_results_summary2, vbd_results_summary2.query("inflection_from == inflection_to and model_label == 'Word'")], axis=0)

weak_ff_results_summary2["base_inflection"] = weak_ff_results_summary2.inflection_from.str.replace("-FF", "")

weak_ff_results_summary2 = weak_ff_results_summary2[weak_ff_results_summary2.base_inflection.isin(plot_inflections)]

# One heatmap per (metric row, base-inflection column).
g = sns.FacetGrid(data=weak_ff_results_summary2.melt(id_vars=["inflection_from", "inflection_to", "transfer_label", "base_inflection"],
                                                value_vars=plot_metrics),
                    col="base_inflection", row="variable", sharex=False, sharey=False,
                    height=3, aspect=1.25)
def mapfn(data, **kwargs):
    """Draw one facet's train-by-test heatmap, with colour limits taken from that facet."""
    ax = plt.gca()
    metric = data.variable.iloc[0]
    vmin = 0 if metric == "correct" else data.value.min()
    vmax = data.value.max()

    sns.heatmap(data.set_index(["inflection_from", "inflection_to"]).value.unstack("inflection_to"),
                vmin=vmin, vmax=vmax, annot=True, ax=ax)
    
g.map_dataframe(mapfn)

for i, row in enumerate(g.axes):
    for j, ax in enumerate(row):
        ax.set_title(ax.get_title().replace("base_inflection = ", ""))
        if j > 0:
            ax.set_ylabel("")
        # keep the colour bar only on the right-most heatmap of each row
        if j < len(row) - 1:
            ax.collections[0].colorbar.remove()

# `figure` is the attribute other cells of this notebook use on seaborn grids.
g.figure.tight_layout()
g.figure.savefig(f"{output_dir}/ff_results_weak.pdf")

In [None]:
# Foil run, weak (non-strong) false friends.
foil_weak_ff_results = weak_ff_results.query("base_model_name == @foil_base_model and model_name == @foil_model and equivalence == @foil_equivalence")

# The aggregated columns come from plot_metrics, the same list that melt() below
# uses as value_vars, so the two cannot drift apart.
weak_ff_foil_results_summary2 = foil_weak_ff_results.groupby(["inflection_from", "inflection_to"]) \
    [plot_metrics].mean().reset_index()

weak_ff_foil_results_summary2["transfer_label"] = weak_ff_foil_results_summary2.inflection_from + " -> " + weak_ff_foil_results_summary2.inflection_to

# add in data for NNS->NNS and VBZ->VBZ
weak_ff_foil_results_summary2 = pd.concat([weak_ff_foil_results_summary2, nnvb_results_summary2.query("inflection_from == inflection_to and model_label == 'Wav2Vec'")], axis=0)
# add in data for VBD->VBD
weak_ff_foil_results_summary2 = pd.concat([weak_ff_foil_results_summary2, vbd_results_summary2.query("inflection_from == inflection_to and model_label == 'Wav2Vec'")], axis=0)

weak_ff_foil_results_summary2["base_inflection"] = weak_ff_foil_results_summary2.inflection_from.str.replace("-FF", "")

weak_ff_foil_results_summary2 = weak_ff_foil_results_summary2[weak_ff_foil_results_summary2.base_inflection.isin(plot_inflections)]

# One heatmap per (metric row, base-inflection column).
g = sns.FacetGrid(data=weak_ff_foil_results_summary2.melt(id_vars=["inflection_from", "inflection_to", "transfer_label", "base_inflection"],
                                                value_vars=plot_metrics),
                    col="base_inflection", row="variable", sharex=False, sharey=False,
                    height=3, aspect=1.25)
def mapfn(data, **kwargs):
    """Draw one facet's train-by-test heatmap, with colour limits taken from that facet."""
    ax = plt.gca()
    metric = data.variable.iloc[0]
    vmin = 0 if metric == "correct" else data.value.min()
    vmax = data.value.max()

    sns.heatmap(data.set_index(["inflection_from", "inflection_to"]).value.unstack("inflection_to"),
                vmin=vmin, vmax=vmax, annot=True, ax=ax)
    
g.map_dataframe(mapfn)
for i, row in enumerate(g.axes):
    for j, ax in enumerate(row):
        ax.set_title(ax.get_title().replace("base_inflection = ", ""))
        if j > 0:
            ax.set_ylabel("")
        # keep the colour bar only on the right-most heatmap of each row
        if j < len(row) - 1:
            ax.collections[0].colorbar.remove()

# `figure` is the attribute other cells of this notebook use on seaborn grids.
g.figure.tight_layout()
g.figure.savefig(f"{output_dir}/ff_results_weak-foil.pdf")

### Frequency analysis

In [None]:
# Accuracy by source-frequency tercile, one panel per target inflection,
# one line per source inflection.
g = sns.catplot(
    data=focus_ff_results,
    x="from_freq_bin", y="correct", hue="inflection_from",
    col="inflection_to", col_wrap=4,
    units="base_from", kind="point", height=3,
)
g.figure.suptitle("Effect of source frequency")
# Reduce each panel title to the part after the "=" sign.
for panel in g.axes.flat:
    panel.set_title(panel.get_title().split("=")[1].strip())
g.tight_layout()

In [None]:
# Ground-truth distance by source-frequency tercile, one panel per target
# inflection, one line per source inflection.
g = sns.catplot(
    data=focus_ff_results,
    x="from_freq_bin", y="gt_distance", hue="inflection_from",
    col="inflection_to", col_wrap=4,
    units="base_from", kind="point", height=3,
)
g.figure.suptitle("Effect of source frequency")
# Reduce each panel title to the part after the "=" sign.
for panel in g.axes.flat:
    panel.set_title(panel.get_title().split("=")[1].strip())
g.tight_layout()

In [None]:
# Ground-truth rank by target-frequency tercile, one column per target inflection
# and one row per model, one line per source inflection.
g = sns.catplot(data=focus_ff_results,
            x="to_freq_bin", y="gt_label_rank", hue="inflection_from",
            col="inflection_to", row="model_name", units="base_to", kind="point", height=3)
g.figure.suptitle("Effect of target frequency")
for ax in g.axes.flat:
    # With both row and col facets the title contains two "var = value" parts, so
    # take the text after the LAST "=" (the inflection_to value); index 1 would
    # return the model name followed by "| inflection_to".
    ax.set_title(ax.get_title().split("=")[-1].strip())
g.tight_layout()

In [None]:
# NOTE(review): this cell plots exactly the same variables as the preceding one
# (to_freq_bin vs gt_label_rank) — confirm whether a different metric was intended.
g = sns.catplot(data=focus_ff_results,
            x="to_freq_bin", y="gt_label_rank", hue="inflection_from",
            col="inflection_to", row="model_name", units="base_to", kind="point", height=3)
g.figure.suptitle("Effect of target frequency")
for ax in g.axes.flat:
    # With both row and col facets the title contains two "var = value" parts, so
    # take the text after the LAST "=" (the inflection_to value); index 1 would
    # return the model name followed by "| inflection_to".
    ax.set_title(ax.get_title().split("=")[-1].strip())
g.tight_layout()

## Forced-choice analysis

In [None]:
# Load the state-space specification from the HDF5 file at `state_space_path`.
state_space = StateSpaceAnalysisSpec.from_hdf5(state_space_path)

In [None]:
def guess_nns_vbz_allomorph(base_phones):
    """
    Predict the plural / 3sg suffix allomorph from a base form's final phone.

    `base_phones` is a non-empty sequence of CMUDICT phones. Returns "IH Z"
    after a sibilant, "S" after one of the listed voiceless consonants, and
    "Z" in every other case.
    """
    sibilant_finals = ("S", "Z", "SH", "CH", "JH", "ZH")
    voiceless_finals = ("P", "T", "K", "F", "TH")

    final_phone = base_phones[-1]

    # Sibilants are tested first: "S" is voiceless but still takes "IH Z".
    if final_phone in sibilant_finals:
        return "IH Z"
    # Anything that is neither sibilant nor listed as voiceless is treated as voiced.
    return "S" if final_phone in voiceless_finals else "Z"


def guess_past_allomorph(base_phones):
    """
    Predict the past-tense suffix allomorph from a base form's final phone.

    `base_phones` is a non-empty sequence of CMUDICT phones. Returns "IH D"
    after an alveolar stop ("T" or "D"), "T" after one of the listed voiceless
    consonants, and "D" in every other case.
    """
    alveolar_stop_finals = ("T", "D")
    # Voiceless finals other than "T", which is handled as an alveolar stop.
    voiceless_finals = ("P", "F", "K", "S", "SH", "CH", "TH")

    final_phone = base_phones[-1]

    # e.g. a base ending in "T" takes the syllabic suffix "IH D"
    if final_phone in alveolar_stop_finals:
        return "IH D"
    # Anything not listed as voiceless is treated as voiced.
    return "T" if final_phone in voiceless_finals else "D"
    
# Forced-choice pair name -> function predicting the expected suffix allomorph.
# "Z_S" uses the plural / 3sg rule; the three past-tense pairs share one rule.
STRONG_GUESSERS = dict(
    [("Z_S", guess_nns_vbz_allomorph)]
    + [(past_pair, guess_past_allomorph) for past_pair in ("D_T", "T_IH D", "D_IH D")]
)

# Forced-choice pair name -> full set of candidate suffixes for that pair.
# Each pair gets its own set object.
WEAK_SPACES = {
    pair_name: ({"Z", "S", "IH Z"} if pair_name == "Z_S" else {"D", "T", "IH D"})
    for pair_name in STRONG_GUESSERS
}

In [None]:
# Build a per-instance phonemic transcription from the phoneme-level cuts.
# Take the "phoneme" cross-section of the cuts table and drop the frame-boundary columns.
cuts_df = state_space.cuts.xs("phoneme", level="level").drop(columns=["onset_frame_idx", "offset_frame_idx"])
# Position of each row's label within state_space.labels.
cuts_df["label_idx"] = cuts_df.index.get_level_values("label").map({l: i for i, l in enumerate(state_space.labels)})
# Running counter within each (label, instance) group, in the table's current row order.
# NOTE(review): despite the name this counts rows (phonemes), not frames — confirm.
cuts_df["frame_idx"] = cuts_df.groupby(["label", "instance_idx"]).cumcount()
cuts_df = cuts_df.reset_index().set_index(["label", "instance_idx", "frame_idx"]).sort_index()
# Space-joined `description` values per (label, instance) — presumably the phone
# sequence of that instance; verify against the cuts schema.
cut_phonemic_forms = cuts_df.groupby(["label", "instance_idx"]).description.agg(' '.join)

In [None]:
# Base forms excluded from the forced-choice analysis.
fc_exclude = [
    "ay", "ba", "c", "des", "eh", "p", "k", "pa",
    "ca", "na", "b", "co", "ben", "been", "shun", "own",
]
# Inflected forms excluded from the forced-choice analysis.
fc_exclude_inflected = ["look'd", "push'd", "los", "ince"]

In [None]:
# Forced-choice ("FC-") experiments: rows whose target inflection starts with "FC-",
# minus the excluded base forms.
fc_results = all_results.loc[all_results.inflection_to.str.startswith("FC-")].copy()
fc_results = fc_results[~fc_results.base_to.isin(fc_exclude)]
# Pair name = the text following "FC-" in the target inflection.
fc_results["fc_pair"] = fc_results.inflection_to.str.extract(r"FC-([\w\s_]+)")
# NOTE(review): eval() executes arbitrary code if the results file is not trusted;
# ast.literal_eval would be the safe equivalent if the labels are plain tuple literals.
fc_results["inflected_from"] = fc_results.from_equiv_label.apply(lambda x: eval(x)[1])
fc_results["inflected_to"] = fc_results.to_equiv_label.apply(lambda x: eval(x)[1])
# Layer number = trailing digits of the base model name.
fc_results["layer"] = fc_results.base_model_name.str.extract(r"_(\d+)$").astype(int)
fc_results = fc_results[~fc_results.inflected_to.isin(fc_exclude_inflected)]

In [None]:
# One metadata row per forced-choice (inflection, base, inflected) item, with the
# columns renamed to their "_to" counterparts for merging onto the results.
fc_metadata = (
    all_cross_instances[all_cross_instances.inflection.str.startswith("FC-")]
    .groupby(["inflection", "base", "inflected"])
    .head(1)
    [[
        "inflection", "base", "base_phones",
        "inflected", "inflected_phones", "inflected2", "inflected2_phones",
        "post_divergence",
    ]]
    .rename(columns={
        "base": "base_to",
        "base_phones": "base_to_phones",
        "inflected": "inflected_to",
        "inflected_phones": "inflected_to_phones",
        "inflected2": "inflected2_to",
        "inflected2_phones": "inflected2_to_phones",
        "post_divergence": "post_divergence_to",
    })
)
# Same pair-name extraction as on the results side; `inflection` is then redundant.
fc_metadata["fc_pair"] = fc_metadata.inflection.str.extract(r"FC-([\w\s_]+)")
fc_metadata = fc_metadata.drop(columns=["inflection"])

In [None]:
# Attach the target-item metadata (inner join on pair, base and inflected form).
fc_results = fc_results.merge(fc_metadata, on=["fc_pair", "base_to", "inflected_to"])

In [None]:
# Expected suffix for each target base under the pair's rule.
fc_results["strong_expected"] = fc_results.apply(
    lambda row: STRONG_GUESSERS[row.fc_pair](row.base_to_phones.split(" ")),
    axis=1,
)
# Base phones followed by the expected suffix.
fc_results["strong_phones"] = fc_results.apply(
    lambda row: " ".join([row.base_to_phones, row.strong_expected]),
    axis=1,
)
# An option is "strong" when what follows the base phones (and one separating
# space) in its phone string equals the expected suffix.
fc_results["inflected_to_strong"] = fc_results.apply(
    lambda row: row.inflected_to_phones[len(row.base_to_phones) + 1:] == row.strong_expected,
    axis=1,
)
fc_results["inflected2_to_strong"] = fc_results.apply(
    lambda row: row.inflected2_to_phones[len(row.base_to_phones) + 1:] == row.strong_expected,
    axis=1,
)

In [None]:
# merge in metadata about "from" item
# For every source (inflection, base, inflected) item, take its most frequent
# post_divergence value.
fc_from_metadata = (
    all_cross_instances[all_cross_instances.inflection.isin(fc_results.inflection_from)]
    .groupby(["inflection", "base", "inflected"])
    .apply(lambda instances: instances.post_divergence.value_counts().index[0])
    .rename("post_divergence")
    .reset_index()
    .rename(columns={
        "inflection": "inflection_from",
        "base": "base_from",
        "inflected": "inflected_from",
        "post_divergence": "post_divergence_from",
    })
)

fc_results = fc_results.merge(fc_from_metadata, on=["inflection_from", "base_from", "inflected_from"])

In [None]:
# keep only the most frequent suffixes
# For each forced-choice pair, keep rows whose source suffix is among that pair's
# `keep_post_divergence_n` most frequent source suffixes.
keep_post_divergence_n = 4
keep_post_divergence = fc_results.groupby("fc_pair").apply(lambda xs: xs.post_divergence_from.value_counts().head(keep_post_divergence_n))
# NOTE(review): the lookup below relies on the value_counts index level being named
# "post_divergence_from"; confirm the installed pandas names it that way.
fc_results = pd.concat([
    fc_results[(fc_results.fc_pair == fc_pair) & fc_results.post_divergence_from.isin(rows.index.get_level_values("post_divergence_from"))]
    for fc_pair, rows in keep_post_divergence.groupby("fc_pair")
])

In [None]:
def _first_form_per_base(row_filter, form_column):
    """First `form_column` value per (fc_pair, base_to) among rows of fc_results matching `row_filter`."""
    return (
        fc_results.query(row_filter)
        .groupby(["fc_pair", "base_to"])[form_column]
        .apply(lambda forms: forms.head(1))
        .droplevel(-1)
    )

# Orthographic forms flagged strong, drawn from either forced-choice option.
strong_items = pd.concat([
    _first_form_per_base("inflected_to_strong", "inflected_to"),
    _first_form_per_base("inflected2_to_strong", "inflected2_to"),
], axis=0)
# Orthographic forms flagged not strong, drawn from either forced-choice option.
weak_items = pd.concat([
    _first_form_per_base("not inflected_to_strong", "inflected_to"),
    _first_form_per_base("not inflected2_to_strong", "inflected2_to"),
], axis=0)

In [None]:
import functools
import itertools

@functools.lru_cache(maxsize=None)
def _get_strong_items(fc_pair: str, base_phones: tuple[str, ...]):
    """Return the space-joined labels whose phonemic form equals the strong form of `base_phones`.

    The strong form is the base phones followed by the suffix produced by `STRONG_GUESSERS[fc_pair]`,
    joined with single spaces. Labels are read from the "label" index level of the notebook-level
    `cut_phonemic_forms`. Results are memoised, which is why `base_phones` must be a hashable tuple.
    """
    guessed_suffix = STRONG_GUESSERS[fc_pair](base_phones)
    target_form = " ".join((*base_phones, guessed_suffix))
    matching_forms = cut_phonemic_forms[cut_phonemic_forms == target_form]
    labels = matching_forms.index.get_level_values("label").unique()
    return " ".join(labels)
@functools.lru_cache(maxsize=None)
def _get_weak_items(fc_pair, base_phones: tuple[str, ...]):
    """Return the space-joined labels whose phonemic form is a weak inflection of `base_phones`.

    Weak suffixes are the members of `WEAK_SPACES[fc_pair]` other than the suffix chosen by
    `STRONG_GUESSERS[fc_pair]`. Each candidate form is the base phones plus one weak suffix, joined with
    single spaces; labels come from the "label" index level of the notebook-level `cut_phonemic_forms`.

    The labels are returned in sorted order so the string is the same across kernel restarts (joining a
    set directly depends on hash ordering). Returns "" when no label matches. Results are memoised, which
    is why `base_phones` must be a hashable tuple.
    """
    weak_suffixes = WEAK_SPACES[fc_pair] - set([STRONG_GUESSERS[fc_pair](base_phones)])
    weak_forms = [" ".join([*base_phones, suffix]) for suffix in weak_suffixes]
    # One membership scan over all candidate forms instead of one equality scan per suffix.
    homophones = cut_phonemic_forms[cut_phonemic_forms.isin(weak_forms)].index.get_level_values("label").unique()
    return " ".join(sorted(homophones))
    
def get_strong_items(ser):
    """Row adapter for `_get_strong_items`.

    `ser` must hold exactly two values: the fc_pair key and a space-separated phone string. The phone
    string is split into a tuple so the cached helper receives hashable arguments.
    """
    fc_pair, phone_string = ser
    return _get_strong_items(fc_pair, tuple(phone_string.split(" ")))
def get_weak_items(ser):
    """Row adapter for `_get_weak_items`.

    `ser` must hold exactly two values: the fc_pair key and a space-separated phone string. The phone
    string is split into a tuple so the cached helper receives hashable arguments.
    """
    fc_pair, phone_string = ser
    return _get_weak_items(fc_pair, tuple(phone_string.split(" ")))
# Look up, for every row, the orthographic candidates that realise the strong / weak form of its target base.
_fc_target_key_columns = ["fc_pair", "base_to_phones"]
fc_results["strong_item_to"] = fc_results[_fc_target_key_columns].apply(get_strong_items, axis="columns")
fc_results["weak_item_to"] = fc_results[_fc_target_key_columns].apply(get_weak_items, axis="columns")

In [None]:
# exclude any items for which the orthographic strong/weak forms overlap
# (both columns are space-joined label strings, so they are compared as sets of whitespace-split tokens)
_forms_are_disjoint = fc_results.apply(lambda xs: set(xs.strong_item_to.split()).isdisjoint(xs.weak_item_to.split()), axis=1)
# .copy() detaches the filtered frame from its parent so that the column assignments made in later cells
# operate on an independent frame instead of triggering SettingWithCopyWarning.
fc_results = fc_results[_forms_are_disjoint].copy()

In [None]:
# Record whether the predicted label is one of the strong / weak candidate forms for the row.
# The candidates are stored as space-joined label strings, so membership is tested against the
# whitespace-split tokens -- the same tokenisation used by the strong/weak overlap filter. A `\b`-delimited
# regex search on the raw string is not equivalent: it also fires on partial tokens (label "it" inside
# candidate "it's") and interprets any regex metacharacter in the label as pattern syntax.
# The `is not None` guards are kept so a missing candidate string yields None rather than an error.
fc_results["chose_strong"] = fc_results.apply(lambda xs: (xs.predicted_label in xs.strong_item_to.split()) if xs.strong_item_to is not None else None, axis=1)
fc_results["chose_weak"] = fc_results.apply(lambda xs: (xs.predicted_label in xs.weak_item_to.split()) if xs.weak_item_to is not None else None, axis=1)
# True when either kind of candidate was chosen; missing flags count as "not chosen".
fc_results["chose_strong_or_weak"] = fc_results.chose_strong.fillna(False) | fc_results.chose_weak.fillna(False)

### Overall plots

In [None]:
# Number of distinct target bases (`base_to`) remaining for each fc_pair.
fc_results.groupby("fc_pair").base_to.nunique()

In [None]:
# Mean strong-choice rate per target base for one hard-coded run (w2v2_8 / ff_32) and the Z_S pair.
# NOTE(review): the run configuration at the top of the notebook lists only 'ffff_32' and 'id' models in
# `plot_runs`; if `fc_results` is restricted to those runs this query selects nothing -- confirm 'ff_32'.
fc_results.query("base_model_name == 'w2v2_8' and model_name == 'ff_32' and fc_pair == 'Z_S'").groupby("base_to").chose_strong.mean()

In [None]:
# For every run x layer x fc_pair x target base: the share of strong (resp. weak) choices among the trials
# where the prediction was a strong or a weak candidate. Groups with no such trial give NaN and are dropped.
_choice_rates_by_word = (
    fc_results
    .groupby(run_groupers + ["layer", "fc_pair", "base_to"])
    .apply(lambda xs: pd.Series({
        "chose_strong_norm": xs.chose_strong.sum() / xs.chose_strong_or_weak.sum(),
        "chose_weak_norm": xs.chose_weak.sum() / xs.chose_strong_or_weak.sum()
    }))
    .dropna()
    .reset_index()
)
# Point plot over layers, one panel per model, one line per fc_pair; bars show the standard error across bases.
sns.catplot(
    data=_choice_rates_by_word,
    x="layer", col="model_name", y="chose_strong_norm", hue="fc_pair",
    kind="point", errorbar="se",
)

### Exploration

In [None]:
# Restrict the results to the single "focus" run; the three focus_* selectors are defined elsewhere in the notebook.
focus_fc_results = fc_results.query(
    "base_model_name == @focus_base_model and model_name == @focus_model and equivalence == @focus_equivalence"
)

In [None]:
# Spot-check four trials per fc_pair. The fixed random_state makes the displayed sample reproducible
# across kernel restarts.
focus_fc_results.groupby("fc_pair").sample(4, random_state=0)[["fc_pair", "base_to", "inflected_to", "strong_item_to", "weak_item_to", "base_from", "inflected_from", "predicted_label", "chose_strong", "chose_weak"]]

In [None]:
# Mean strong-choice rate for each (fc_pair, source suffix, target base) in the focus run.
_strong_rate_by_suffix = (
    focus_fc_results
    .groupby(["fc_pair", "post_divergence_from", "base_to"])["chose_strong"]
    .mean()
    .dropna()
    .reset_index()
)
# Bars average over target bases; error bars are the standard error across bases.
sns.catplot(
    data=_strong_rate_by_suffix,
    x="fc_pair", hue="post_divergence_from", y="chose_strong",
    kind="bar", errorbar="se",
)

In [None]:
# fc_results_by_word = fc_results.groupby(run_groupers + ["layer", "fc_pair", "base_to"]).apply(
#     lambda xs: pd.Series({
#         "chose_strong_norm": xs.chose_strong.sum() / xs.chose_strong_or_weak.sum(),
#         "chose_weak_norm": xs.chose_weak.sum() / xs.chose_strong_or_weak.sum(),
#         "chose_strong_or_weak_proportion": xs.chose_strong_or_weak.mean(),
#     }))
# fc_results_by_word = fc_results_by_word[fc_results_by_word.chose_strong_or_weak_proportion > 0.1] \
#     .dropna().reset_index().sort_values("chose_strong_norm").assign(run_name=lambda xs: xs.model_name + " " + xs.layer.map("{:02d}".format))

# g = sns.FacetGrid(data=fc_results_by_word, col="fc_pair", row="run_name",
#                   row_order=sorted(fc_results_by_word.run_name.unique()),
#                   height=12, aspect=0.6, sharey=False)
# def f(data, **kwargs):
#     sns.heatmap(data=data.set_index("base_to")[["chose_strong_norm", "chose_weak_norm"]],
#                 vmin=0, vmax=1, cbar=False, **kwargs)
# g.map_dataframe(f)

In [None]:
# Focus run only: for each fc_pair x source suffix x target base, the share of strong (resp. weak) choices
# among trials whose prediction was a strong or weak candidate. Groups without such trials are dropped.
fc_results_by_word_and_source = (
    focus_fc_results
    .groupby(run_groupers + ["fc_pair", "post_divergence_from", "base_to"])
    .apply(lambda xs: pd.Series({
        "chose_strong_norm": xs.chose_strong.sum() / xs.chose_strong_or_weak.sum(),
        "chose_weak_norm": xs.chose_weak.sum() / xs.chose_strong_or_weak.sum()
    }))
    .dropna()
    .reset_index()
    .sort_values("chose_strong_norm")
)

# One panel per fc_pair (independent x axes), one bar per source suffix, standard-error bars across bases.
sns.catplot(
    data=fc_results_by_word_and_source,
    col="fc_pair", x="post_divergence_from", y="chose_strong_norm",
    kind="bar", errorbar="se", sharex=False,
)

In [None]:
# ECDF, over target word types, of the normalised strong-choice rate for the Z_S pair, one curve per
# source allomorph. Suffix codes are relabelled for display; "IH Z" and "AH Z" share one label, and any
# code missing from the mapping becomes NaN.
_allomorph_display_names = {"S": "s", "Z": "z", "IH Z": "ɪz", "AH Z": "ɪz"}
_z_s_rates = fc_results_by_word_and_source.query("fc_pair == 'Z_S'")
_z_s_rates = _z_s_rates.assign(post_divergence_from=_z_s_rates.post_divergence_from.map(_allomorph_display_names))

g = sns.displot(
    data=_z_s_rates,
    x="chose_strong_norm", hue="post_divergence_from",
    kind="ecdf", height=4, aspect=1.5, linewidth=4,
)

ax = g.axes.flat[0]
ax.set(
    xlabel="Proportion of phonologically\nconsistent choices",
    ylabel="Proportion of\nword types",
)
# Retitle the legend and position it in axes coordinates.
g.legend.set_title("Source\nallomorph")
g.legend.set_bbox_to_anchor((0.5, 0.6), transform=ax.transAxes)

g.tight_layout()
g.savefig(f"{output_dir}/fc_results-ecdf.pdf", bbox_inches="tight")

In [None]:
# sns.catplot(data=fc_results_by_word_and_source.query("fc_pair == 'Z_S'").assign(post_divergence_from=lambda xs: xs.post_divergence_from.map({"S": "s", "Z": "z", "IH Z": "ɪz", "AH Z": "ɪz"})),
#             col="fc_pair", y="post_divergence_from", x="chose_strong_norm",
#             kind="swarm", height=7, sharex=False)

In [None]:
# Bar plot of the normalised strong-choice rate for the Z_S pair, one bar per source allomorph, with
# standard-error bars. Suffix codes are relabelled for display; "IH Z" and "AH Z" share one label, and any
# code missing from the mapping becomes NaN.
_allomorph_display_names = {"S": "s", "Z": "z", "IH Z": "ɪz", "AH Z": "ɪz"}
_z_s_rates = fc_results_by_word_and_source.query("fc_pair == 'Z_S'")
_z_s_rates = _z_s_rates.assign(post_divergence_from=_z_s_rates.post_divergence_from.map(_allomorph_display_names))

g = sns.catplot(
    data=_z_s_rates,
    x="post_divergence_from", y="chose_strong_norm",
    kind="bar", errorbar="se", sharex=False, height=4,
)

ax = g.axes.flat[0]
ax.set(
    ylabel="Proportion of\nphonologically\nconsistent choices",
    xlabel="Source allomorph",
)

g.savefig(f"{output_dir}/fc_results_by_word_and_source-Z_S.pdf", bbox_inches="tight")

In [None]:
# g = sns.FacetGrid(data=fc_results_by_word_and_source, col="fc_pair", row="post_divergence_from",
#                   height=12, aspect=0.6, sharey=False)
# def f(data, **kwargs):
#     sns.heatmap(data=data.set_index("base_to")[["chose_strong_norm", "chose_weak_norm"]],
#                 vmin=0, vmax=1, cbar=False, **kwargs)
# g.map_dataframe(f)

### Frequency analysis

In [None]:
# Attach `word_freq_df.LogFreq` (looked up by word via the frame's index) for the four words of every
# trial. pd.merge defaults to an inner join, so rows whose word has no frequency entry are dropped.
for _freq_column, _word_column in [
    ("from_base_freq", "base_from"),
    ("from_inflected_freq", "inflected_from"),
    ("to_base_freq", "base_to"),
    ("to_inflected_freq", "inflected_to"),
]:
    fc_results = pd.merge(
        fc_results,
        word_freq_df.LogFreq.rename(_freq_column),
        left_on=_word_column,
        right_index=True,
    )

In [None]:
# Frequency of the source / target item: the mean of its base-form and inflected-form frequency columns.
fc_results["from_freq"] = fc_results[["from_base_freq", "from_inflected_freq"]].mean(axis=1)
fc_results["to_freq"] = fc_results[["to_base_freq", "to_inflected_freq"]].mean(axis=1)

# Quartile edges are computed once over the pooled source + target frequencies so both sides share bins.
_, fc_frequency_bins = pd.qcut(pd.concat([fc_results.to_freq, fc_results.from_freq]), q=4, retbins=True)
fc_frequency_bin_labels = [f"Q{i}" for i in range(1, 5)]
# include_lowest=True: the first edge returned by qcut is the pooled minimum, and pd.cut's first interval
# is left-open by default, which would leave the lowest-frequency item(s) with a NaN bin instead of Q1.
fc_results["from_freq_bin"] = pd.cut(fc_results.from_freq, bins=fc_frequency_bins, labels=fc_frequency_bin_labels, include_lowest=True)
fc_results["to_freq_bin"] = pd.cut(fc_results.to_freq, bins=fc_frequency_bins, labels=fc_frequency_bin_labels, include_lowest=True)

In [None]:
# Re-select the focus run so that `focus_fc_results` picks up the frequency columns added to `fc_results`.
focus_fc_results = fc_results.query(
    "base_model_name == @focus_base_model and model_name == @focus_model and equivalence == @focus_equivalence"
)

In [None]:
# Normalised strong / weak choice rates per fc_pair x source suffix x target base x source-frequency bin.
# observed=True: `from_freq_bin` is categorical, and without it pandas also materialises group
# combinations that never occur (all-NaN rows that the dropna() below discards anyway) and warns about
# the changing default.
_rates_by_from_freq = (
    focus_fc_results
    .groupby(["fc_pair", "post_divergence_from", "base_to", "from_freq_bin"], observed=True)
    .apply(lambda xs: pd.Series({
        "chose_strong_norm": xs.chose_strong.sum() / xs.chose_strong_or_weak.sum(),
        "chose_weak_norm": xs.chose_weak.sum() / xs.chose_strong_or_weak.sum()
    }))
    .dropna()  # groups with no strong-or-weak choice give 0/0 = NaN
    .reset_index()
    .sort_values("chose_strong_norm")
)
# One panel per fc_pair; x = frequency quartile of the source item, one line per source suffix.
sns.catplot(data=_rates_by_from_freq,
            x="from_freq_bin", y="chose_strong_norm", hue="post_divergence_from", col="fc_pair", kind="point")

In [None]:
# Normalised strong / weak choice rates per fc_pair x source suffix x target base x target-frequency bin.
# observed=True: `to_freq_bin` is categorical, and without it pandas also materialises group
# combinations that never occur (all-NaN rows that the dropna() below discards anyway) and warns about
# the changing default.
_rates_by_to_freq = (
    focus_fc_results
    .groupby(["fc_pair", "post_divergence_from", "base_to", "to_freq_bin"], observed=True)
    .apply(lambda xs: pd.Series({
        "chose_strong_norm": xs.chose_strong.sum() / xs.chose_strong_or_weak.sum(),
        "chose_weak_norm": xs.chose_weak.sum() / xs.chose_strong_or_weak.sum()
    }))
    .dropna()  # groups with no strong-or-weak choice give 0/0 = NaN
    .reset_index()
    .sort_values("chose_strong_norm")
)
# One panel per fc_pair; x = frequency quartile of the target item, one line per source suffix.
sns.catplot(data=_rates_by_to_freq,
            x="to_freq_bin", y="chose_strong_norm", hue="post_divergence_from", col="fc_pair", kind="point")

## Exploratory for viz

In [None]:
# Cross-inflection trials (VBZ -> NNS and NNS -> VBZ) of the 'ffff_32' model.
expl = (
    nnvb_focus
    .query("model_name == 'ffff_32' and ((inflection_from == 'VBZ' and inflection_to == 'NNS') or (inflection_from == 'NNS' and inflection_to == 'VBZ'))")
    .reset_index()
)

# Record the noun side of every trial (its base and allomorph), whichever end of the analogy it is on.
_noun_is_target = expl.inflection_to == "NNS"
_noun_is_source = expl.inflection_from == "NNS"
for _involved_column, _to_column, _from_column in [
    ("involved_noun", "base_to", "base_from"),
    ("involved_allomorph", "allomorph_to", "allomorph_from"),
]:
    expl.loc[_noun_is_target, _involved_column] = expl[_to_column]
    expl.loc[_noun_is_source, _involved_column] = expl[_from_column]

# "target": the noun is the item being predicted (source is VBZ); "source": the noun is the analogy source.
expl["direction"] = expl.apply(
    lambda xs: "target" if xs.inflection_from == "VBZ" else "source",
    axis=1,
)

# only include stable counts: keep (noun, direction) cells backed by more than 5 trials
expl = expl.groupby(["involved_noun", "direction"]).filter(lambda xs: len(xs) > 5)

# Mean ground-truth rank per noun/allomorph with one column per direction; nouns lacking either direction are dropped.
expl = (
    expl
    .groupby(["involved_noun", "involved_allomorph", "direction"])["gt_label_rank"]
    .mean()
    .unstack()
    .dropna()
)
# +1 on both sides keeps the ratio finite when a mean rank is 0.
expl["ratio"] = (expl["target"] + 1) / (expl["source"] + 1)

In [None]:
# The ten nouns with the smallest `ratio`, i.e. (target mean rank + 1) / (source mean rank + 1).
expl.sort_values("ratio").head(10)

In [None]:
# The ten nouns with the largest `ratio`, i.e. (target mean rank + 1) / (source mean rank + 1).
expl.sort_values("ratio").tail(10)

In [None]:
# The six nouns with the lowest mean ground-truth rank in the "target" direction.
expl.sort_values("target").head(6)

In [None]:
# The six nouns with the lowest mean ground-truth rank in the "source" direction.
expl.sort_values("source").head(6)

In [None]:
# Spot check: mean ground-truth rank over all 'ffff_32' trials whose target base is "cedar".
nnvb_focus.query("model_name == 'ffff_32' and base_to == 'cedar'").gt_label_rank.mean()

In [None]:
# Heatmap of the per-noun mean ranks in both directions, rows ordered by ascending `ratio`.
sns.heatmap(expl.sort_values("ratio")[["source", "target"]])

### Exploratory for viz 2

Find a noun N and a verb V for which performance is high within-inflection but for which transfer in one or both directions is poor.

In [None]:
# Preview of the within-inflection trials of 'ffff_32' (source and target share the same inflection).
# Two successive filters keep only trials whose (inflection, target base) group, and then whose
# (inflection, source base) group, has more than 5 rows.
nnvb_focus.query("model_name == 'ffff_32' and inflection_from == inflection_to") \
    .groupby(["inflection_from", "base_to"]).filter(lambda xs: len(xs) > 5) \
    .groupby(["inflection_from", "base_from"]).filter(lambda xs: len(xs) > 5)

In [None]:
# Within-inflection trials of 'ffff_32', keeping only (inflection, target base) groups and then
# (inflection, source base) groups that have more than 5 rows.
_within_inflection_trials = (
    nnvb_focus
    .query("model_name == 'ffff_32' and inflection_from == inflection_to")
    .groupby(["inflection_from", "base_to"]).filter(lambda xs: len(xs) > 5)
    .groupby(["inflection_from", "base_from"]).filter(lambda xs: len(xs) > 5)
)

# Mean ground-truth rank of every word when it is the analogy target vs. when it is the analogy source.
_within_rank_as_target = _within_inflection_trials.groupby(["inflection_from", "base_to"]).gt_label_rank.mean()
_within_rank_as_source = _within_inflection_trials.groupby(["inflection_from", "base_from"]).gt_label_rank.mean()

# Words missing from either role are dropped.
expl2 = pd.DataFrame({
    "target_rank": _within_rank_as_target,
    "source_rank": _within_rank_as_source,
}).dropna()
# +1 on both sides keeps the ratio finite when a mean rank is 0.
expl2["ratio"] = (expl2["target_rank"] + 1) / (expl2["source_rank"] + 1)
expl2.index.names = ["inflection", "label"]

In [None]:
# The ten words with the smallest `ratio`, i.e. (target_rank + 1) / (source_rank + 1).
expl2.sort_values("ratio").head(10)

In [None]:
# The ten words with the largest `ratio`, i.e. (target_rank + 1) / (source_rank + 1).
expl2.sort_values("ratio").tail(10)

In [None]:
# Words whose mean ground-truth rank is below 1 both as analogy target and as analogy source.
expl2[(expl2.target_rank < 1) & (expl2.source_rank < 1)]

In [None]:
# Cross-inflection ("transfer") trials of 'ffff_32', keeping only (source inflection, target base) groups
# and then (source inflection, source base) groups that have more than 5 rows.
_transfer_trials = (
    nnvb_focus
    .query("model_name == 'ffff_32' and inflection_from != inflection_to")
    .groupby(["inflection_from", "base_to"]).filter(lambda xs: len(xs) > 5)
    .groupby(["inflection_from", "base_from"]).filter(lambda xs: len(xs) > 5)
)

# Mean ground-truth rank of every word as transfer source (keyed by its own inflection, `inflection_from`)
# and as transfer target (keyed by its own inflection, `inflection_to`).
_rank_as_transfer_source = _transfer_trials.groupby(["inflection_from", "base_from"]).gt_label_rank.mean()
_rank_as_transfer_target = _transfer_trials.groupby(["inflection_to", "base_to"]).gt_label_rank.mean()

# Words missing from either role are dropped.
transfer_ranks = pd.DataFrame({
    "source_transfer_rank": _rank_as_transfer_source,
    "target_transfer_rank": _rank_as_transfer_target,
}).dropna()
transfer_ranks.index.names = ["inflection", "label"]
transfer_ranks

In [None]:
# Join within-inflection and transfer ranks on the shared (inflection, label) index, then summarise each
# setting by the mean of its source and target ranks. `ratio` (carried over from expl2) is overwritten with
# within / transfer; the +1 on both sides keeps it finite when a mean rank is 0.
expl3 = expl2.merge(transfer_ranks, left_index=True, right_index=True).assign(
    mean_within=lambda df: (df.source_rank + df.target_rank) / 2,
    mean_transfer=lambda df: (df.source_transfer_rank + df.target_transfer_rank) / 2,
    ratio=lambda df: (df.mean_within + 1) / (df.mean_transfer + 1),
)

In [None]:
# Words whose within-inflection mean rank is below 1 in both roles. The trailing comment preserves a
# stricter, currently disabled variant that additionally requires both transfer ranks to exceed 1.
expl3[(expl3.target_rank < 1) & (expl3.source_rank < 1)]# & (expl3.source_transfer_rank > 1) & (expl3.target_transfer_rank > 1)].sort_values("ratio")