In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import re

import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
import statsmodels.api as sm
from statsmodels.formula.api import logit

from src.analysis.state_space import StateSpaceAnalysisSpec
from src.utils import concat_csv_with_indices

In [None]:
# Paper-sized plotting context with enlarged fonts, applied to every figure below.
sns.set_context(context="paper", font_scale=1.5)

In [None]:
# Directory that all saved figures are written into.
output_dir = "analogy_figures"
# Inputs produced by the analogy pipeline (LibriSpeech train-clean-100, w2v2 features).
false_friends_path = "outputs/analogy/inputs/librispeech-train-clean-100/w2v2/false_friends.csv"
state_space_path = "outputs/analogy/inputs/librispeech-train-clean-100/w2v2/state_space_spec.h5"

In [None]:
# Columns of the experiment-results frame that together identify a single run.
run_groupers = ["base_model_name", "model_name", "equivalence"]

# Runs used in the layer-wise comparison: the ff_32 word model and the "id" model, for each of 12 w2v2 layers.
plot_runs = [(f"w2v2_{i}", "ff_32", "word_broad_10frames_fixedlen25") for i in range(12)]
plot_runs += [(f"w2v2_{i}", "id", "id") for i in range(12)]
# Alternative run set (disabled):
# [(f"w2v2_{i}", "discrim-ff_32", "word_broad_10frames_fixedlen25") for i in range(12)] + \

# Run used for the focused (main) figures.
main_plot_run = ("w2v2_8", "ff_32", "word_broad_10frames_fixedlen25")
# Shared color scale so the main heatmaps are directly comparable.
main_plot_vmin, main_plot_vmax = 0.4, 0.9

# Source inflections shown in the layer-wise plot.
plot_inflections = ["NNS", "VBZ", "VBD"]

## Load metadata

In [None]:
word_freq_df = pd.read_csv("data/WorldLex_Eng_US.Freq.2.txt", sep="\t", index_col="Word")

# Average word frequency across the blog / Twitter / news domains.
# Each domain is first normalized to relative frequencies so that domains with larger corpora
# do not dominate the average. The mean relative frequency is then rescaled by the mean corpus
# size so that it is back on a count-like scale before taking log10. (Previously the raw counts
# were averaged, and the *_rel columns were computed but never used.)
freq_domains = ["BlogFreq", "TwitterFreq", "NewsFreq"]
rel_freq_columns = [f"{domain}_rel" for domain in freq_domains]
for domain, rel_column in zip(freq_domains, rel_freq_columns):
    word_freq_df[rel_column] = word_freq_df[domain] / word_freq_df[domain].sum()
word_freq_df["Freq"] = word_freq_df[rel_freq_columns].mean(axis=1) \
    * word_freq_df[freq_domains].sum().mean()
# NOTE(review): a word with zero frequency in every domain gets LogFreq = -inf.
word_freq_df["LogFreq"] = np.log10(word_freq_df.Freq)

In [None]:
# False-friend word list produced by the analogy pipeline.
false_friends_df = pd.read_csv(filepath_or_buffer=false_friends_path)

In [None]:
# Past-tense (VBD) false friends that are not marked as strong.
false_friends_df[(false_friends_df["inflection"] == "VBD") & ~false_friends_df["strong"]]

## Theoretical generalization matrices

In [None]:
inflections = ["NNS", "VBZ"]
allomorphs = ["Z", "S", "IH Z"]
# Every (source inflection, source allomorph) -> (target inflection, target allomorph) transfer pair.
grid_columns = ["from_inflection", "from_allomorph", "to_inflection", "to_allomorph"]
rows = [dict(zip(grid_columns, (src_infl, src_allo, tgt_infl, tgt_allo)))
        for src_infl in inflections
        for src_allo in allomorphs
        for tgt_infl in inflections
        for tgt_allo in allomorphs]

generalization_df = pd.DataFrame(rows)
# Human-readable "<INFLECTION> <ALLOMORPH>" labels for heatmap axes.
generalization_df["source_label"] = generalization_df.from_inflection.str.cat(generalization_df.from_allomorph, sep=" ")
generalization_df["target_label"] = generalization_df.to_inflection.str.cat(generalization_df.to_allomorph, sep=" ")
generalization_df

In [None]:
# Phonetic generalization hypothesis: a transfer is correct exactly when the allomorph matches,
# regardless of inflection. `assign` returns a new frame, leaving generalization_df untouched.
phonetic_generalization_df = generalization_df.assign(
    correct=lambda grid: grid.from_allomorph == grid.to_allomorph)

In [None]:
# Source x target matrix of the phonetic hypothesis.
sns.heatmap(phonetic_generalization_df.pivot(index="source_label", columns="target_label", values="correct"))

In [None]:
# Morphological generalization hypothesis: a transfer is correct exactly when source and target
# share the inflection, regardless of allomorph. The mask is built from this frame itself
# (previously it came from phonetic_generalization_df, which only worked because the indices coincide).
morphological_generalization_df = generalization_df.copy()
morphological_generalization_df["correct"] = (
    morphological_generalization_df.from_inflection == morphological_generalization_df.to_inflection)

In [None]:
# Source x target matrix of the morphological hypothesis.
sns.heatmap(morphological_generalization_df.pivot(index="source_label", columns="target_label", values="correct"))

In [None]:
# Metaphonological hypothesis: every transfer is expected to succeed.
metaphon_generalization_df = generalization_df.assign(correct=True)

In [None]:
# Fixed 0..1 color range; without it a constant matrix would get a degenerate scale.
sns.heatmap(metaphon_generalization_df.pivot(index="source_label", columns="target_label", values="correct"),
            vmin=0, vmax=1)

### With false friends

In [None]:
inflections = ["NNS", "VBZ"]
allomorphs = ["Z", "S", "IH Z"]
# All transfer pairs, additionally crossed with whether each side is a false friend (FF).
ff_grid_columns = ["from_inflection_base", "from_allomorph", "to_inflection_base", "to_allomorph", "ff_from", "ff_to"]
rows = [dict(zip(ff_grid_columns, combo))
        for combo in ((src_infl, src_allo, tgt_infl, tgt_allo, src_ff, tgt_ff)
                      for src_infl in inflections
                      for src_allo in allomorphs
                      for tgt_infl in inflections
                      for tgt_allo in allomorphs
                      for src_ff in [False, True]
                      for tgt_ff in [False, True])]

ff_generalization_df = pd.DataFrame(rows)
# False-friend sides get an "-FF" suffix on their inflection label.
ff_generalization_df["from_inflection"] = ff_generalization_df.from_inflection_base \
    + ff_generalization_df.ff_from.map(lambda is_ff: "-FF" if is_ff else "")
ff_generalization_df["to_inflection"] = ff_generalization_df.to_inflection_base \
    + ff_generalization_df.ff_to.map(lambda is_ff: "-FF" if is_ff else "")
ff_generalization_df["source_label"] = ff_generalization_df.from_inflection + " " + ff_generalization_df.from_allomorph
ff_generalization_df["target_label"] = ff_generalization_df.to_inflection + " " + ff_generalization_df.to_allomorph
ff_generalization_df

In [None]:
# Phonetic hypothesis with false friends: correct exactly when the allomorph matches,
# whether or not either side is a false friend.
ff_phonetic_generalization_df = ff_generalization_df.assign(
    correct=lambda grid: grid.from_allomorph == grid.to_allomorph)

In [None]:
# Source x target matrix of the phonetic hypothesis including false friends.
sns.heatmap(ff_phonetic_generalization_df.pivot(index="source_label", columns="target_label", values="correct"))

In [None]:
# Morphological hypothesis with false friends: correct only when source and target share the
# inflection label (including any "-FF" marker) AND the source is a real inflection rather than
# a false friend. The mask is built from this frame itself (previously it came from
# ff_phonetic_generalization_df, which only worked because the indices coincide).
ff_morphological_generalization_df = ff_generalization_df.copy()
ff_morphological_generalization_df["correct"] = (
    (ff_morphological_generalization_df.from_inflection == ff_morphological_generalization_df.to_inflection)
    & ~ff_morphological_generalization_df.from_inflection.str.contains("-FF"))

In [None]:
# Source x target matrix of the morphological hypothesis including false friends.
sns.heatmap(ff_morphological_generalization_df.pivot(index="source_label", columns="target_label", values="correct"))

In [None]:
# Metaphonological hypothesis with false friends: full credit (1.0) when neither side is a
# false friend, partial credit (0.3) otherwise.
ff_metaphon_generalization_df = ff_generalization_df.assign(
    correct=lambda grid: np.where(~(grid.ff_from | grid.ff_to), 1.0, 0.3))

In [None]:
# Fixed 0..1 color range so the 0.3 / 1.0 credit levels are shown on an absolute scale.
sns.heatmap(ff_metaphon_generalization_df.pivot(index="source_label", columns="target_label", values="correct"),
            vmin=0, vmax=1)

## Load results

In [None]:
# Results for the trained word models. The directories above each experiment_results.csv
# supply the index values: parent -> equivalence, grandparent -> model_name,
# great-grandparent -> base_model_name.
all_results = (
    concat_csv_with_indices(
        "outputs/analogy/runs/**/experiment_results.csv",
        [lambda path: path.parents[0].name,
         lambda path: path.parents[1].name,
         lambda path: path.parents[2].name],
        ["equivalence", "model_name", "base_model_name"],
    )
    .droplevel(-1)  # NOTE(review): presumably drops the per-file row-index level
    .reset_index()
)

In [None]:
# Results for the "id" model runs. The parent directory of each results file supplies
# base_model_name, and model_name / equivalence are both fixed to "id".
all_id_results = (
    concat_csv_with_indices(
        "outputs/analogy/runs_id/**/experiment_results.csv",
        [lambda path: path.parents[0].name],
        ["base_model_name"],
    )
    .droplevel(-1)
    .reset_index()
    .assign(model_name="id", equivalence="id")
)

In [None]:
# Stack the word-model and id-model results into one table.
all_results = pd.concat([all_results, all_id_results], ignore_index=True)
# `group` is stored as a stringified Python literal (NaN when absent); parse it back, mapping NaN to None.
# NOTE(review): eval on file contents; ast.literal_eval would be safer if the values are plain literals.
all_results["group"] = all_results["group"].apply(
    lambda raw: None if (isinstance(raw, float) and np.isnan(raw)) else eval(raw))

In [None]:
# Inspect the combined results table (word-model runs plus "id" runs).
all_results

In [None]:
# NOTE(review): leftover shell check of a saved results pickle's size and date; none of the cells shown here read this file.
!ls -lh analogy_results_20250217.pkl

### Layer-wise

In [None]:
# Layer-wise accuracy on the "regular" analogy experiment, one value per (run, group, source inflection).
plot_lw = (
    all_results.query("experiment == 'regular'")
    .groupby(run_groupers + ["group", "inflection_from"])["correct"]
    .mean()
)
# Expand to the full (group, inflection, plot run) grid so every plotted run has a row for every
# observed group/inflection; combinations without data become NaN.
observed_groups = plot_lw.index.get_level_values("group").unique()
observed_inflections = plot_lw.index.get_level_values("inflection_from").unique()
full_grid = [(*plot_run, group, inflection_from)
             for group in observed_groups
             for inflection_from in observed_inflections
             for plot_run in plot_runs]
plot_lw = plot_lw.reindex(full_grid).reset_index()

# Split the group tuple into its first two components.
plot_lw["group0"] = plot_lw["group"].apply(lambda grp: None if grp is None else grp[0])
plot_lw["group1"] = plot_lw["group"].apply(lambda grp: None if grp is None else grp[1])
# Layer index parsed from the trailing digits of base_model_name (e.g. "w2v2_8" -> 8).
plot_lw["layer"] = plot_lw["base_model_name"].str.extract(r"_(\d+)$", expand=False).astype(int)

# Random-group baseline per (inflection, layer), computed before the inflection filter below.
lw_random = (
    plot_lw[plot_lw["group0"] == "random"]
    .groupby(["inflection_from", "layer"])["correct"].mean()
    .reset_index()
    .dropna()
)

plot_lw = plot_lw[plot_lw["inflection_from"].isin(plot_inflections) & (plot_lw["group1"] == True)]

In [None]:
# Random-group baseline accuracy per source inflection and layer.
lw_random

In [None]:
g = sns.catplot(data=plot_lw, x="layer", y="correct", hue="model_name", row="inflection_from",
                kind="point", height=3, aspect=3)

for ax, row_name in zip(g.axes.flat, g.row_names):
    # Overlay the random baseline for THIS facet's inflection only. Without this filter every
    # facet showed the same baseline averaged over all inflections, with a CI band across them.
    # NOTE(review): the point plot places `layer` at categorical positions 0..n-1; the numeric
    # lineplot lines up with it only because the plotted layers are exactly 0..11.
    facet_random = lw_random[lw_random.inflection_from == row_name]
    sns.lineplot(data=facet_random,
                 x="layer", y="correct", ax=ax, color="gray", linestyle="--",
                 legend=None)
    ax.set_title(row_name)
    ax.set_ylabel("Correct")
    if ax.get_xlabel() == "layer":
        ax.set_xlabel("Layer")

g.figure.savefig(f"{output_dir}/layer_wise.pdf")

In [None]:
# Load most_common_allomorphs.csv and show its VBD (past-tense) rows.
mca = pd.read_csv("outputs/analogy/inputs/librispeech-train-clean-100/w2v2/most_common_allomorphs.csv", index_col=0)
mca[mca["inflection"] == "VBD"]

## Compute controlled NNVB results

In [None]:
all_nnvb_results = []

for _, run_results in all_results.groupby(run_groupers):
    run_results = run_results.set_index("experiment")
    # Controlled noun/verb transfer experiments are named
    # "unambiguous-<inflection>_<allomorph>_to_<inflection>_<allomorph>".
    nnvb_expts = run_results.index.unique()
    nnvb_expts = nnvb_expts[nnvb_expts.str.contains("unambiguous-")]

    for expt in nnvb_expts:
        inflection_from, allomorph_from, inflection_to, allomorph_to = \
            re.findall(r"unambiguous-(\w+)_([\w\s]+)_to_(\w+)_([\w\s]+)", expt)[0]
        # Index with a list so that a single-row experiment still yields a DataFrame.
        # A scalar label would return a Series and break the column assignments below.
        expt_df = run_results.loc[[expt]].copy()

        # DEV: optionally skip experiments with few seen word types, e.g.
        # num_seen_words = min(expt_df.base_from.nunique(), expt_df.base_to.nunique())
        # if num_seen_words < 10:
        #     print(f"Skipping {expt} due to only {num_seen_words} seen words")
        #     continue

        expt_df["inflection_from"] = inflection_from
        expt_df["allomorph_from"] = allomorph_from
        expt_df["inflection_to"] = inflection_to
        expt_df["allomorph_to"] = allomorph_to

        all_nnvb_results.append(expt_df)

all_nnvb_results = pd.concat(all_nnvb_results)

# Equivalence labels are stored as stringified tuples; element 1 is taken as the inflected form.
# NOTE(review): eval on file contents; ast.literal_eval would be safer if these are plain literals.
all_nnvb_results["inflected_from"] = all_nnvb_results.from_equiv_label.apply(lambda x: eval(x)[1])
all_nnvb_results["inflected_to"] = all_nnvb_results.to_equiv_label.apply(lambda x: eval(x)[1])

# Attach log frequencies for base and inflected forms on both sides. These are inner merges,
# so rows whose words are missing from word_freq_df are dropped.
for word_column, freq_column in [("base_from", "from_base_freq"), ("inflected_from", "from_inflected_freq"),
                                 ("base_to", "to_base_freq"), ("inflected_to", "to_inflected_freq")]:
    all_nnvb_results = pd.merge(all_nnvb_results, word_freq_df.LogFreq.rename(freq_column),
                                left_on=word_column, right_index=True)

all_nnvb_results["from_freq"] = all_nnvb_results[["from_base_freq", "from_inflected_freq"]].mean(axis=1)
all_nnvb_results["to_freq"] = all_nnvb_results[["to_base_freq", "to_inflected_freq"]].mean(axis=1)

# Quantile bins shared by source and target frequencies, so the bins are comparable.
n_freq_bins = 5
freq_bin_labels = [f"Q{i}" for i in range(1, n_freq_bins + 1)]
_, frequency_bins = pd.qcut(pd.concat([all_nnvb_results.to_freq, all_nnvb_results.from_freq]),
                            q=n_freq_bins, retbins=True)
# include_lowest=True: the first qcut edge equals the minimum frequency. pd.cut's right-closed
# intervals would otherwise exclude that value and leave it un-binned (NaN).
all_nnvb_results["to_freq_bin"] = pd.cut(all_nnvb_results.to_freq, bins=frequency_bins,
                                         labels=freq_bin_labels, include_lowest=True)
all_nnvb_results["from_freq_bin"] = pd.cut(all_nnvb_results.from_freq, bins=frequency_bins,
                                           labels=freq_bin_labels, include_lowest=True)

In [None]:
def summarize_nnvb_run(rows):
    """Add display labels to one run's NNVB summary and keep only reciprocal transfer cells.

    Parameters
    ----------
    rows : pd.DataFrame
        Summary rows for a single run, with string columns ``inflection_from``,
        ``allomorph_from``, ``inflection_to`` and ``allomorph_to``.

    Returns
    -------
    pd.DataFrame
        A new frame (the input is not modified) with ``source_label``, ``target_label``,
        ``transfer_label`` and ``phon_label`` columns added. It is restricted to rows whose
        reverse transfer (target -> source) is also present. Empty input yields an empty frame.
    """
    rows = rows.assign(
        source_label=rows.inflection_from + " " + rows.allomorph_from,
        target_label=rows.inflection_to + " " + rows.allomorph_to,
        transfer_label=rows.inflection_from + " -> " + rows.inflection_to,
        phon_label=rows.allomorph_from + " " + rows.allomorph_to,
    )

    # Only retain cases with data in both transfer directions (source <-> target).
    # A set lookup replaces a per-row `query`, which was quadratic in the number of rows.
    observed_pairs = set(zip(rows.source_label, rows.target_label))
    has_complement = [(target, source) in observed_pairs
                      for source, target in zip(rows.source_label, rows.target_label)]
    # Positional boolean mask: robust to duplicated index labels and to empty input.
    return rows.loc[pd.Series(has_complement, dtype=bool).to_numpy()]

summary_groupers = ["inflection_from", "inflection_to", "allomorph_from", "allomorph_to"]
# Per-run trial count and accuracy for each (inflection, allomorph) transfer cell, then labelled and
# restricted to reciprocal cells run by run.
nnvb_results_summary = (
    all_nnvb_results
    .groupby(run_groupers + summary_groupers)["correct"]
    .agg(["count", "mean"])
    .query("count >= 0")  # NOTE(review): always true; placeholder for a minimum-count threshold
    .reset_index(summary_groupers)
    .groupby(run_groupers, group_keys=False)
    .apply(summarize_nnvb_run)
    .reset_index()
)

nnvb_results_summary

In [None]:
# Collect the summary rows for each plotted run, in plot_runs order, skipping runs without data.
summary_by_run = dict(list(nnvb_results_summary.groupby(run_groupers)))
plot_results = [summary_by_run[run_key] for run_key in plot_runs if run_key in summary_by_run]
num_plot_runs = len(plot_results)

ncols = 2
nrows = int(np.ceil(num_plot_runs / ncols))
f, axs = plt.subplots(nrows, ncols, figsize=(ncols * 4, nrows * 4))

# One source x target accuracy heatmap per run.
for ax, run_summary in zip(axs.flat, plot_results):
    transfer_matrix = run_summary.set_index(["source_label", "target_label"])["mean"].unstack()
    sns.heatmap(transfer_matrix, vmin=0, vmax=1, ax=ax)
    first_row = run_summary.iloc[0]
    ax.set_title(f"{first_row.base_model_name} -> {first_row.model_name} ({first_row.equivalence})")

### Focused plots

In [None]:
focus_base_model, focus_model, focus_equivalence = main_plot_run
# Foil: the "id" model on the same base layer as the focus run.
foil_base_model, foil_model, foil_equivalence = focus_base_model, "id", "id"

# `.assign` returns new frames, avoiding chained-assignment writes onto query results.
nnvb_focus = all_nnvb_results.query(
    "base_model_name == @focus_base_model and model_name == @focus_model and equivalence == @focus_equivalence"
).assign(model_label="Word")
nnvb_foil = all_nnvb_results.query(
    "base_model_name == @foil_base_model and model_name == @foil_model and equivalence == @foil_equivalence"
).assign(model_label="Wav2Vec")

nnvb_focus = pd.concat([nnvb_focus, nnvb_foil])

# Display allomorphs in IPA.
allomorph_labels = {"Z": "z", "S": "s", "IH Z": "ɪz"}
nnvb_focus["allomorph_from"] = nnvb_focus.allomorph_from.map(allomorph_labels)
nnvb_focus["allomorph_to"] = nnvb_focus.allomorph_to.map(allomorph_labels)
nnvb_focus

In [None]:
# Allomorph-level transfer accuracy for the focus ("Word") and foil ("Wav2Vec") models.
nnvb_results_summary = nnvb_focus.groupby(["model_label", "inflection_from", "inflection_to",
                                             "allomorph_from", "allomorph_to"]) \
    .correct.agg(["count", "mean"]) \
    .query("count >= 0") \
    .reset_index()
# NOTE(review): "count >= 0" is always true; raise it to require a minimum number of trials.

# Two-line heatmap tick labels: inflection on the first line, IPA allomorph on the second.
nnvb_results_summary["source_label"] = nnvb_results_summary.inflection_from + "\n" + nnvb_results_summary.allomorph_from
nnvb_results_summary["target_label"] = nnvb_results_summary.inflection_to + "\n" + nnvb_results_summary.allomorph_to

nnvb_results_summary["transfer_label"] = nnvb_results_summary.inflection_from + " -> " + nnvb_results_summary.inflection_to
nnvb_results_summary["phon_label"] = nnvb_results_summary.allomorph_from + " " + nnvb_results_summary.allomorph_to

# Only retain cells that also have data in the reverse direction (target -> source).
# Pairs are pooled across model_label. A set lookup replaces the quadratic per-row query.
label_pairs = list(zip(nnvb_results_summary.source_label, nnvb_results_summary.target_label))
observed_pairs = set(label_pairs)
has_complement = [(target, source) in observed_pairs for source, target in label_pairs]
nnvb_results_summary = nnvb_results_summary.loc[pd.Series(has_complement, dtype=bool).to_numpy()]

# Drop VBZ /ɪz/, which only has 4 word types.
rare_label = "VBZ\nɪz"
nnvb_results_summary = nnvb_results_summary[(nnvb_results_summary.source_label != rare_label)
                                            & (nnvb_results_summary.target_label != rare_label)]

nnvb_results_summary

In [None]:
nnvb_focus_bar = nnvb_focus.assign(source_label=lambda xs: xs.inflection_from + " " + xs.allomorph_from)
# Exclude VBZ /ɪz/ (too few word types).
nnvb_focus_bar = nnvb_focus_bar[(nnvb_focus_bar.source_label != "VBZ ɪz")]
# Order the hue levels (train inflection + allomorph) by overall accuracy. This ordering was
# previously computed but never passed to the plot.
order = nnvb_focus_bar.groupby("source_label").correct.mean().sort_values().index
g = sns.catplot(data=nnvb_focus_bar, x="inflection_to", hue="source_label", hue_order=order,
                y="correct", col="model_label", kind="bar")
g._legend.set_title("Train inflection\nand allomorph")

for ax in g.axes.flat:
    # Facet titles look like "model_label = Word"; keep only the value.
    ax.set_title(ax.get_title().split("=")[1].strip())
    ax.set_xlabel("Test inflection")
    ax.set_ylabel("Accuracy")

In [None]:
# Allomorph-level transfer heatmaps, one panel per model, sharing a single colorbar drawn
# in the narrow third axis.
f, axs = plt.subplots(1, 3, figsize=(7 * 2, 6), gridspec_kw={'width_ratios': [1, 1, 0.04]})
colorbar_ax = axs.flat[-1]
for panel_idx, (ax, (model_label, rows)) in enumerate(zip(axs, nnvb_results_summary.groupby("model_label"))):
    draw_cbar = panel_idx == 1  # only the second panel draws the shared colorbar

    ax.set_title(model_label)
    transfer_matrix = rows.set_index(["source_label", "target_label"]).sort_index()["mean"].unstack("target_label")
    sns.heatmap(transfer_matrix,
                vmin=main_plot_vmin, vmax=main_plot_vmax, annot=True, ax=ax,
                cbar=draw_cbar, cbar_ax=colorbar_ax if draw_cbar else None)

    ax.set_yticklabels(ax.get_yticklabels(), rotation=0)
    ax.set_ylabel("Train")
    ax.set_xlabel("Test")

f.tight_layout()
f.savefig(f"{output_dir}/nnvb_allomorphs.pdf")

In [None]:
# Inflection-level (allomorphs pooled) means of accuracy, gold-label rank and gold distance per model.
nnvb_results_summary2 = (
    nnvb_focus
    .groupby(["model_label", "inflection_from", "inflection_to"])[["correct", "gt_label_rank", "gt_distance"]]
    .mean()
    .reset_index()
    .assign(transfer_label=lambda summary: summary.inflection_from + " -> " + summary.inflection_to)
)
nnvb_results_summary2

In [None]:
# Inflection-level transfer accuracy heatmaps, one panel per model, with a shared colorbar.
f, axs = plt.subplots(1, 3, figsize=(4 * 2, 3), gridspec_kw={'width_ratios': [1, 1, 0.04]})
colorbar_ax = axs.flat[-1]

for panel_idx, (ax, (model_label, rows)) in enumerate(zip(axs, nnvb_results_summary2.groupby("model_label"))):
    draw_cbar = panel_idx == 1

    ax.set_title(model_label)
    accuracy_matrix = rows.set_index(["inflection_from", "inflection_to"])["correct"].unstack()
    sns.heatmap(accuracy_matrix,
                annot=True, vmin=main_plot_vmin, vmax=main_plot_vmax, ax=ax,
                cbar=draw_cbar, cbar_ax=colorbar_ax if draw_cbar else None)
    ax.set_xlabel("Test")
    ax.set_ylabel("Train")

# `f` is still the current figure (nothing above creates a new one).
f.tight_layout()
f.savefig(f"{output_dir}/nnvb_results.pdf")

In [None]:
# Per-target-word accuracy for each transfer condition and model.
nnvb_plot = (
    nnvb_focus
    .groupby(["model_label", "inflection_from", "inflection_to", "base_to"])["correct"].mean()
    .reset_index()
    .assign(transfer_label=lambda per_word: per_word.inflection_from + " -> " + per_word.inflection_to)
)
# Order conditions by accuracy pooled over both models.
order = nnvb_plot.groupby("transfer_label")["correct"].mean().sort_values().index
g = sns.catplot(data=nnvb_plot, x="transfer_label", y="correct", hue="model_label",
                kind="bar", order=order, errorbar="se", height=4, aspect=2)
g._legend.set_title("Model")
ax = g.axes.flat[0]

ax.set_xlabel("Evaluation")
ax.set_ylabel("Accuracy")

In [None]:
# Same per-target-word table as above; the plot shows only the "Word" model, while the
# condition order is still computed from both models.
nnvb_plot = (
    nnvb_focus
    .groupby(["model_label", "inflection_from", "inflection_to", "base_to"])["correct"].mean()
    .reset_index()
    .assign(transfer_label=lambda per_word: per_word.inflection_from + " -> " + per_word.inflection_to)
)
order = nnvb_plot.groupby("transfer_label")["correct"].mean().sort_values().index
word_only = nnvb_plot[nnvb_plot["model_label"] == "Word"]
g = sns.catplot(data=word_only, x="transfer_label", y="correct", hue="model_label",
                kind="bar", order=order, errorbar="se", height=4, aspect=2)
g._legend.set_title("Model")
ax = g.axes.flat[0]

ax.set_xlabel("Evaluation")
ax.set_ylabel("Accuracy")

In [None]:
# Within-inflection transfers only (train inflection == test inflection).
nnvb_phase1 = nnvb_plot.loc[nnvb_plot["inflection_from"] == nnvb_plot["inflection_to"]]
order = nnvb_phase1.groupby("transfer_label")["correct"].mean().sort_values().index
g = sns.catplot(data=nnvb_phase1, x="transfer_label", y="correct", hue="inflection_to",
                kind="bar", order=order, errorbar="se", height=4, aspect=1)
# Hue only recolors the x categories here, so the legend adds no information.
g._legend.remove()
ax = g.axes.flat[0]

ax.set_xlabel("Evaluation")
ax.set_ylabel("Accuracy")

In [None]:
# Per-cell mean accuracy for each transfer condition. Conditions are sorted by the text after
# the first four characters, i.e. "-> <test inflection>".
order = sorted(nnvb_results_summary.transfer_label.unique(), key=lambda label: label[4:])
g = sns.catplot(data=nnvb_results_summary, x="transfer_label", y="mean", hue="inflection_to",
                order=order, kind="swarm", errorbar="se", height=5, aspect=1.5, size=11)
ax = g.axes.flat[0]
ax.set_xlabel("Evaluation")
ax.set_ylabel("Mean accuracy")

### Regression analysis (currently disabled: the cells below are commented out)

In [None]:
# def get_interaction_strength(rows):
#     rows["correct"] = rows.correct.astype(int)
    
#     # exclude rare
#     rows = rows[~((rows.inflection_from == "VBZ") & rows.inflection_from == "IH Z") &
#                 ~((rows.inflection_to == "VBZ") & rows.inflection_to == "IH Z")]
    
#     # standardize frequency
#     rows["from_freq"] = (rows.from_freq - rows.from_freq.mean()) / rows.from_freq.std()
#     rows["to_freq"] = (rows.to_freq - rows.to_freq.mean()) / rows.to_freq.std()
    
#     formula = "correct ~ C(inflection_from, Treatment(reference='NNS')) * C(inflection_to, Treatment(reference='NNS')) + " \
#               "C(allomorph_from, Treatment(reference='Z')) * C(allomorph_to, Treatment(reference='Z')) +" \
#               "from_freq + to_freq"
    
#     model = logit(formula, data=rows).fit()

#     return model.params

In [None]:
# interaction_strengths = all_nnvb_results.groupby(run_groupers).apply(get_interaction_strength) \
#     .reset_index().melt(id_vars=run_groupers, value_name="coef_norm")
# interaction_strengths = interaction_strengths[interaction_strengths.variable.str.contains(":", regex=True)]
# interaction_strengths = interaction_strengths.groupby(run_groupers).coef_norm.apply(lambda xs: np.linalg.norm(xs, ord=1)).sort_values()

In [None]:
# plot_accuracy = all_nnvb_results.groupby(run_groupers).correct.mean().reindex(plot_runs).reset_index()
# plot_accuracy["layer"] = plot_accuracy.base_model_name.str.extract(r"_(\d+)$").astype(int)

# g = sns.catplot(data=plot_accuracy, x="layer", y="correct", hue="model_name", height=3, aspect=2, kind="point")
# # g.figure.tight_layout()
# g.figure.savefig(f"{output_dir}/nnvb_layer_wise.pdf")

In [None]:
# plot_is = interaction_strengths.reindex(plot_runs).reset_index()
# plot_is["layer"] = plot_is.base_model_name.str.extract(r"_(\d+)$").astype(int)

# g = sns.catplot(data=plot_is, x="layer", y="coef_norm", hue="model_name", kind="point", height=3, aspect=2)
# g.axes.flat[0].set_ylabel("Allomorph/inflection\ninteraction strength")
# g.axes.flat[0].set_xlabel("Layer")

# g.savefig(f"{output_dir}/interaction_strength.pdf")

### Digression: within-inflection predictions for VBZ targets

In [None]:
# Focus/foil rows whose test inflection is VBZ (copied so the annotations below don't touch nnvb_focus).
study_df = nnvb_focus.loc[nnvb_focus["inflection_to"] == "VBZ"].copy()

In [None]:
# Crude stem of the prediction: strip a final "s", "ed", "ing" or "ings".
study_df["predicted_stem"] = study_df.predicted_label.str.replace(r"s$|ed$|ings?$", "", regex=True)
# Alternative stem of the target base: drop a final "e", then map a final "y" to "i" (e.g. "carry" -> "carri").
study_df["base_to_stem"] = (
    study_df.base_to
    .str.replace(r"e$", "", regex=True)
    .str.replace(r"y$", "i", regex=True)
)
# A prediction counts as within the target's inflection if its stem matches the base or its alternative stem...
study_df["predicted_within_inflection"] = \
    (study_df.predicted_stem == study_df.base_to) | (study_df.predicted_stem == study_df.base_to_stem)
vb_irregulars = [("do", "did"), ("do", "does"), ("begin", "began"), ("learn", "learnt"), ("send", "sent"), ("shine", "shone"), ("seem", "seem'd"), ("read", "red"),
                 ("possess", "possesses"), ("bring", "brings"), ("carry", "carries"), ("occur", "occurred"), ("think", "thinkest"), ("grow", "grew"),
                 ("put", "putting"), ("begin", "beginning"), ("give", "givest"),
                 # homophones
                 ("allow", "aloud"), ("write", "rights"), ("write", "wright's"), ("depend", "dependent")]
# ...or if the (base, prediction) pair is one of the hand-listed irregular forms / homophones.
irregular_pairs = set(vb_irregulars)
is_listed_irregular = pd.Series(
    [(base_word, predicted_word) in irregular_pairs
     for base_word, predicted_word in zip(study_df.base_to, study_df.predicted_label)],
    index=study_df.index)
study_df["predicted_within_inflection"] = study_df.predicted_within_inflection | is_listed_irregular

In [None]:
# study_df[~study_df.predicted_within_inflection].groupby("base_to").predicted_label.value_counts().sort_values(ascending=False).iloc[60:80]

In [None]:
# Inspect the focus/foil NNVB rows before the within-inflection flags are merged in.
nnvb_focus

In [None]:
# Attach the VBZ-only `predicted_within_inflection` flag to all focus rows. Rows with no match
# (non-VBZ targets) get NaN from the left merge.
# NOTE(review): not idempotent; re-running this cell adds duplicate/suffixed columns.
merge_keys = ["experiment", "equivalence", "model_label", "model_name", "base_model_name", "group", "inflection_from", "inflection_to", "base_from", "inflected_from", "base_to", "inflected_to"]
within_inflection_flags = study_df.reset_index()[merge_keys + ["predicted_within_inflection"]]
nnvb_focus = nnvb_focus.reset_index().merge(within_inflection_flags, on=merge_keys, how="left")

In [None]:
# Credit a trial if it is correct OR the prediction stays within the target's inflection.
# The flag is missing (NaN) for non-VBZ targets; `.eq(True)` treats those as False explicitly.
nnvb_focus["correct_or_predicted_within_inflection"] = \
    nnvb_focus.correct | nnvb_focus.predicted_within_inflection.eq(True)

In [None]:
# Inflection-level means, now also including the lenient "correct or within-inflection" score.
nnvb_results_summary2 = (
    nnvb_focus
    .groupby(["model_label", "inflection_from", "inflection_to"])[["correct_or_predicted_within_inflection", "correct", "gt_label_rank", "gt_distance"]]
    .mean()
    .reset_index()
    .assign(transfer_label=lambda summary: summary.inflection_from + " -> " + summary.inflection_to)
)
nnvb_results_summary2

In [None]:
# Heatmaps of the lenient score, one panel per model, with a shared colorbar.
f, axs = plt.subplots(1, 3, figsize=(4 * 2, 3), gridspec_kw={'width_ratios': [1, 1, 0.04]})
colorbar_ax = axs.flat[-1]

for panel_idx, (ax, (model_label, rows)) in enumerate(zip(axs, nnvb_results_summary2.groupby("model_label"))):
    draw_cbar = panel_idx == 1

    ax.set_title(model_label)
    lenient_matrix = rows.set_index(["inflection_from", "inflection_to"])["correct_or_predicted_within_inflection"].unstack()
    sns.heatmap(lenient_matrix,
                vmin=main_plot_vmin, vmax=main_plot_vmax,
                annot=True, ax=ax, cbar=draw_cbar, cbar_ax=colorbar_ax if draw_cbar else None)
    ax.set_xlabel("Test")
    ax.set_ylabel("Train")

# `f` is still the current figure (nothing above creates a new one).
f.tight_layout()
f.savefig(f"{output_dir}/nnvb_results-correct_inflection.pdf")

In [None]:
# Per-target-word lenient score for each transfer condition and model.
nnvb_plot = (
    nnvb_focus
    .groupby(["model_label", "inflection_from", "inflection_to", "base_to"])["correct_or_predicted_within_inflection"].mean()
    .reset_index()
    .assign(transfer_label=lambda per_word: per_word.inflection_from + " -> " + per_word.inflection_to)
)
# Fixed presentation order of the four transfer conditions.
order = ["NNS -> VBZ", "VBZ -> VBZ", "VBZ -> NNS", "NNS -> NNS"]
g = sns.catplot(data=nnvb_plot, x="transfer_label", y="correct_or_predicted_within_inflection",
                hue="model_label", kind="bar", order=order, errorbar="se", height=4, aspect=2)
g._legend.set_title("Model")
ax = g.axes.flat[0]

ax.set_xlabel("Evaluation")
ax.set_ylabel("Accuracy")

### Frequency analysis

In [None]:
# Layer-8 accuracy by source-frequency quintile, faceted by train (rows) and test (columns) inflection.
sns.catplot(
    data=all_nnvb_results[all_nnvb_results["base_model_name"] == "w2v2_8"].reset_index(),
    x="from_freq_bin", y="correct", hue="model_name",
    row="inflection_from", col="inflection_to", units="base_from", kind="point",
)

In [None]:
# Layer-8 accuracy by target-frequency quintile, faceted by train (rows) and test (columns) inflection.
sns.catplot(
    data=all_nnvb_results[all_nnvb_results["base_model_name"] == "w2v2_8"].reset_index(),
    x="to_freq_bin", y="correct", hue="model_name",
    row="inflection_from", col="inflection_to", units="base_to", kind="point",
)

## False friend analysis

In [None]:
all_ff_results = []
# Cross (regular <-> false-friend) experiments with fewer distinct words than this are skipped.
min_seen_words = 10

for _, run_results in all_results.groupby(run_groupers):
    run_results = run_results.set_index("experiment")
    false_friend_expts = run_results.index.unique()
    false_friend_expts = false_friend_expts[false_friend_expts.str.contains("FF")]

    for expt_name in false_friend_expts:
        # Index with a list so a single-row experiment still yields a DataFrame, not a Series.
        expt_df = run_results.loc[[expt_name]].copy()
        # nunique(dropna=False) counts the same values as len(.unique()).
        num_seen_words = min(expt_df.base_from.nunique(dropna=False), expt_df.base_to.nunique(dropna=False))

        if num_seen_words < min_seen_words:
            # Report the experiment actually being skipped. This previously printed `expt`,
            # a stale variable left over from the NNVB loop.
            print(f"Skipping {expt_name} due to only {num_seen_words} seen words")
            continue

        # Recover the allomorphs, and which side is the false friend, from the experiment name.
        if expt_name.count("-FF-") == 2:
            # false friends on both sides
            allomorph_from, allomorph_to = re.findall(r"-FF-([\w\s]+)-to-.+FF-([\w\s]+)", expt_name)[0]
            ff_from, ff_to = True, True
        else:
            ff_on_to_side = re.findall(r"_([\w\s]+)-to-.+FF-([\w\s]+)", expt_name)
            if ff_on_to_side:
                # regular inflection on the "from" side, false friend on the "to" side
                allomorph_from, allomorph_to = ff_on_to_side[0]
                ff_from, ff_to = False, True
            else:
                # false friend on the "from" side, regular inflection on the "to" side
                allomorph_from, allomorph_to = re.findall(r".+FF-([\w\s]+)-to-.+_([\w\s]+)", expt_name)[0]
                ff_from, ff_to = True, False

        expt_df["allomorph_from"] = allomorph_from
        expt_df["allomorph_to"] = allomorph_to

        # Collapse "...-FF-<allomorph>" inflection labels to "...-FF" on false-friend sides.
        if ff_from:
            expt_df["inflection_from"] = expt_df.inflection_from.str.replace("-FF-.+", "-FF", regex=True)
        if ff_to:
            expt_df["inflection_to"] = expt_df.inflection_to.str.replace("-FF-.+", "-FF", regex=True)

        expt_df["ff_from"] = ff_from
        expt_df["ff_to"] = ff_to

        all_ff_results.append(expt_df)

    # Within-false-friend tests: the allomorph is encoded at the end of the inflection label ("...FF-<allomorph>").
    expt_df = run_results.loc[["false_friends"]].copy()
    expt_df["allomorph_from"] = expt_df.inflection_from.str.extract(r"FF-(.+)$")
    expt_df["allomorph_to"] = expt_df.inflection_to.str.extract(r"FF-(.+)$")
    expt_df["inflection_from"] = expt_df.inflection_from.str.replace("-FF-.+", "-FF", regex=True)
    expt_df["inflection_to"] = expt_df.inflection_to.str.replace("-FF-.+", "-FF", regex=True)
    expt_df["ff_from"] = True
    expt_df["ff_to"] = True

    all_ff_results.append(expt_df)

all_ff_results = pd.concat(all_ff_results).reset_index()

# Manually excluded false-friend base forms and inflected forms.
ff_exclude = "wreck e eh wandering lo chiu ha hahn meek jew"
ff_exclude_inflected = "bunce los"
ff_exclude_bases = ff_exclude.split()
ff_exclude_inflected_forms = ff_exclude_inflected.split()

# exclude FF bases
all_ff_results = all_ff_results[~(all_ff_results.inflection_from.str.endswith("-FF") & all_ff_results.base_from.isin(ff_exclude_bases))]
all_ff_results = all_ff_results[~(all_ff_results.inflection_to.str.endswith("-FF") & all_ff_results.base_to.isin(ff_exclude_bases))].copy()

# Equivalence labels are stored as stringified tuples; element 1 is taken as the inflected form.
# NOTE(review): eval on file contents; ast.literal_eval would be safer if these are plain literals.
all_ff_results["inflected_from"] = all_ff_results.from_equiv_label.apply(lambda x: eval(x)[1])
all_ff_results["inflected_to"] = all_ff_results.to_equiv_label.apply(lambda x: eval(x)[1])

# exclude FF inflected
all_ff_results = all_ff_results[~(all_ff_results.inflection_from.str.endswith("-FF") & all_ff_results.inflected_from.isin(ff_exclude_inflected_forms))]
all_ff_results = all_ff_results[~(all_ff_results.inflection_to.str.endswith("-FF") & all_ff_results.inflected_to.isin(ff_exclude_inflected_forms))]

# Add frequency information. These are inner merges, so rows whose words are missing from
# word_freq_df are dropped.
for word_column, freq_column in [("base_from", "from_base_freq"), ("base_to", "to_base_freq"),
                                 ("inflected_from", "from_inflected_freq"), ("inflected_to", "to_inflected_freq")]:
    all_ff_results = pd.merge(all_ff_results, word_freq_df["LogFreq"].rename(freq_column),
                              left_on=word_column, right_index=True)
all_ff_results["from_freq"] = all_ff_results[["from_base_freq", "from_inflected_freq"]].mean(axis=1)
all_ff_results["to_freq"] = all_ff_results[["to_base_freq", "to_inflected_freq"]].mean(axis=1)

all_ff_results["transfer_label"] = all_ff_results.inflection_from + " -> " + all_ff_results.inflection_to

In [None]:
# Post-hoc corrections: mark these specific (target base, prediction) combinations as correct.
# "tho" -> "though" is only accepted when the gold label is ranked first.
tho_fix_mask = (all_ff_results.base_to == "tho") & (all_ff_results.predicted_label == "though") & (all_ff_results.gt_label_rank == 1)
all_ff_results.loc[tho_fix_mask, "correct"] = True
for fix_base, fix_prediction in [("philip", "philip's"), ("adam", "adam's"), ("who", "who's")]:
    fix_mask = (all_ff_results.base_to == fix_base) & (all_ff_results.predicted_label == fix_prediction)
    all_ff_results.loc[fix_mask, "correct"] = True

In [None]:
# (base, inflected, post_divergence) -> value of the false friend's `strong` column.
false_friend_strong_lookup = (
    false_friends_df
    .set_index(["base", "inflected", "post_divergence"])["strong"]
    .to_dict()
)

In [None]:
def get_is_strong(rows):
    """Return True if every false-friend side of this group is marked strong.

    Only the first row is inspected: rows are grouped by all the fields read here. For each side
    ("from" / "to") whose inflection label contains "-FF", the key
    (base, inflected, allomorph) is looked up in `false_friend_strong_lookup`. A group with no
    false-friend side is vacuously strong. A missing key raises KeyError.
    """
    first_row = rows.iloc[0]
    lookup_keys = [
        (first_row[f"base_{side}"], first_row[f"inflected_{side}"], first_row[f"allomorph_{side}"])
        for side in ("from", "to")
        if "-FF" in first_row[f"inflection_{side}"]
    ]
    # Materialize every lookup (not a short-circuiting generator) so missing keys always raise.
    return all([false_friend_strong_lookup[key] for key in lookup_keys])

# One strength flag per (inflections, words, allomorphs) combination, merged back onto every trial.
strong_grouper = ["inflection_from", "inflection_to", "inflected_from", "inflected_to", "base_from", "base_to", "allomorph_from", "allomorph_to"]
strong_values = all_ff_results.groupby(strong_grouper).apply(get_is_strong).rename("is_strong")
all_ff_results = all_ff_results.merge(strong_values, left_on=strong_grouper, right_index=True)

In [None]:
# Set weak false friends aside for the sub-analysis, then restrict the main
# false-friend analysis to strong cases only.
is_strong_mask = all_ff_results["is_strong"]
weak_ff_results = all_ff_results[~is_strong_mask]
all_ff_results = all_ff_results[is_strong_mask]

### Weak false-friend sub-analysis

In [None]:
# Load the state-space spec; its `labels` are used below to check which alternate words are in the vocabulary.
ss = StateSpaceAnalysisSpec.from_hdf5(state_space_path)

In [None]:
# These base forms also take a real inflection, so the model has a competing
# (distractor) inflected form. Values are space-separated alternates.
weak_alternates = {
    "barbara": "barbara's",
    "bay": "bays",
    "den": "dens",
    "dew": "dews",
    "fall": "falls",
    "fear": "fears",
    "flee": "flees",
    "hen": "hens",
    "her": "hers",
    "jew": "jews",
    "joy": "joys",
    "lay": "lays",
    "one": "ones one's",
    "patricia": "patricia's",
    "peer": "peers",
    "per": "purrs",
    "river": "rivers river's",
    "saw": "saws",
    "scare": "scares",
    "sin": "sins",
    "syria": "syria's",
    "victoria": "victoria's"
}

# Only retain bases for which at least one alternate is in the state-space vocabulary.
alternate_in_vocab = {
    base_word: any(word in ss.labels for word in alternates.split(" "))
    for base_word, alternates in weak_alternates.items()
}
drop_cases = [base_word for base_word, in_vocab in alternate_in_vocab.items() if not in_vocab]
print(f"Dropping {len(drop_cases)} cases: {drop_cases}")
weak_alternates = {base_word: alternates for base_word, alternates in weak_alternates.items()
                   if alternate_in_vocab[base_word]}

In [None]:
# Weak false-friend target bases that DO have an in-vocabulary competitor.
weak_ff_results.loc[weak_ff_results.base_to.map(weak_alternates).notnull() & weak_ff_results.ff_to, "base_to"].unique()

In [None]:
# Weak false-friend target bases WITHOUT an in-vocabulary competitor.
weak_ff_results.loc[weak_ff_results.base_to.map(weak_alternates).isna() & weak_ff_results.ff_to, "base_to"].unique()

In [None]:
# Weak false-friend targets that have a real-inflection competitor, for the focus run only.
weak_sub_df = weak_ff_results[weak_ff_results.ff_to].copy()
weak_sub_df["competitor_to"] = weak_sub_df.base_to.map(weak_alternates)
weak_sub_df = weak_sub_df[weak_sub_df.competitor_to.notnull()]
weak_sub_df = weak_sub_df.query("base_model_name == @focus_base_model and model_name == @focus_model and equivalence == @focus_equivalence").copy()
# The prediction "matches the alternate" if it equals ANY of the space-separated competitor words.
# (`re.match` only tested the first competitor and treated the prediction as an unescaped regex.)
weak_sub_df["alternate_match"] = [
    predicted in competitors.split(" ")
    for predicted, competitors in zip(weak_sub_df.predicted_label, weak_sub_df.competitor_to)
]
weak_sub_df["source_label"] = weak_sub_df.inflection_from + " " + weak_sub_df.allomorph_from
# just study /S/ false friends right now
weak_sub_df = weak_sub_df.query("inflection_to in ['VBZ-FF', 'NNS-FF'] and allomorph_to == 'S' and allomorph_from in ['Z', 'S']")
weak_sub_df

In [None]:
def get_props(xs):
    """Summarize one group of weak-false-friend analogies.

    Returns a Series with the counts of `correct` and `alternate_match` predictions
    and their proportions over all rows in the group.
    """
    n_rows = len(xs)
    n_correct = xs.correct.sum()
    n_alternate = xs.alternate_match.sum()
    return pd.Series({
        "correct": n_correct,
        "alternate_match": n_alternate,
        "prop_correct": n_correct / n_rows,
        "prop_alternate_match": n_alternate / n_rows,
    })
weak_sub_plot = (
    weak_sub_df
    .groupby(["source_label", "inflection_to", "base_to"])
    .apply(get_props)
    .reset_index()
)
# Strip the inflection name from "<inflection> <allomorph>", leaving just the allomorph.
weak_sub_plot["source_label"] = weak_sub_plot.source_label.str.replace(r"(NNS|VBZ)-?\s*", "", regex=True)

In [None]:
# Proportion of weak-false-friend analogies resolved to the competitor (/z/) form.
f, ax = plt.subplots(figsize=(8, 4))

# Order sources by /s/ accuracy so this plot and the next share one hue order.
hue_order = weak_sub_plot.groupby("source_label").prop_correct.mean().sort_values().index
ax = sns.barplot(data=weak_sub_plot, x="inflection_to", hue="source_label", hue_order=hue_order,
                 y="prop_alternate_match", errorbar="se", ax=ax)
ax.set_ylabel("Proportion\nchoices of /z/", rotation=0, labelpad=70)
ax.set_xlabel("Test inflection")
ax.legend(title="Source inflection", loc="upper right", bbox_to_anchor=(1.325, 1));

In [None]:
# Proportion of weak-false-friend analogies resolved to the ground-truth (/s/) form.
f, ax2 = plt.subplots(figsize=(8, 4))

hue_order = weak_sub_plot.groupby("source_label").prop_correct.mean().sort_values().index
ax2 = sns.barplot(data=weak_sub_plot, x="inflection_to", hue="source_label", hue_order=hue_order,
                  y="prop_correct", errorbar="se", ax=ax2)
ax2.set_ylabel("Proportion\nchoices of /s/", rotation=0, labelpad=70)
ax2.set_xlabel("Test inflection")
# Reuse the y-range of the /z/ plot so the two figures are directly comparable.
# NOTE: `ax` comes from the previous cell, which must have been run just before this one.
ax2.set_ylim(ax.get_ylim())
ax2.legend(title="Source inflection", loc="upper right", bbox_to_anchor=(1.325, 1));

In [None]:
# Long-format per-base accuracy (`correct`) and alternate-match rate, one panel per measure.
# Kept under its own name so the wide `weak_sub_plot` used by the bar plots above is not
# overwritten; re-running those cells after this one previously failed.
weak_sub_long = (
    weak_sub_df
    .groupby(["source_label", "inflection_to", "base_to"])[["correct", "alternate_match"]]
    .mean()
    .reset_index()
    .melt(id_vars=["source_label", "inflection_to", "base_to"])
)
weak_sub_long["source_label"] = weak_sub_long.source_label.str.replace(r"(NNS|VBZ)-?\s*", "", regex=True)

hue_order = weak_sub_long.groupby("source_label").value.mean().sort_values().index
sns.catplot(data=weak_sub_long, x="inflection_to", hue="source_label", y="value", col="variable",
            hue_order=hue_order, kind="bar", errorbar="se")

In [None]:
# Accuracy per source label, sorted within each test inflection.
# group_keys=False: under pandas>=2 a transform-like apply otherwise prepends a
# duplicate `inflection_to` index level.
weak_sub_df.groupby(["inflection_to", "source_label"]).correct.mean() \
    .groupby("inflection_to", group_keys=False).apply(lambda xs: xs.sort_values())

In [None]:
# Alternate-match rate per source label, sorted within each test inflection.
# group_keys=False: avoid a duplicate `inflection_to` index level under pandas>=2.
weak_sub_df.groupby(["inflection_to", "source_label"]).alternate_match.mean() \
    .groupby("inflection_to", group_keys=False).apply(lambda xs: xs.sort_values())

We first consider strong false friends: word pairs whose longer member ends in a concatenated /z/, /s/, or /ɪz/ that could have been produced by the same surface alternation pattern as the real inflection.
For example, "beside" -- "besides" does not instantiate the relevant morphological pattern, but the added /z/ is consistent with the pattern on the surface.
Results for strong false friends are roughly similar to those for real morphological pairs.
This indicates that the learned alternation extends beyond real morphological pairs: the representation here captures something at the level of surface alternation rather than at the level of morphology.

We next consider the case of weak false friends, which have a concatenated /z/ /s/ or /Iz/ but do not respect the surface alternate pattern.
For example, "flee" -- "fleece" has the right concatenative relationship but does not respect the pattern (here we would expect to see /z/ following the vowel).
Generalization here is quite poor, indicating that the model is not appropriately applying the surface alternate pattern.
Error analysis reveals that the model is respecting the surface alternate pattern just as we expect.

We zoom in on weak false friends with concatenated /s/, for example "flee" "fleece." Here there is a competitor "flees" which respects the surface alternation pattern.
Call these items "weak false friends with alternates" (WFFA).
We evaluate the model's performance on WFFA from two sources: valid instances of plurals with /z/, and valid instances of plurals with /s/.
Computing analogies from the former category leads the model to predict the alternate form /z/, unsurprisingly.
However, computing analogies from the latter category leads the model to predict the alternate form /z/ at roughly the same rate!
This indicates that the latter category is not generating a context-free phonological representation, but rather a more abstract representation of "give me the right kind of surface alternate."

In [None]:
# Weak false friends whose target base has NO competitor in `weak_alternates`,
# restricted to the focus run.
weak_sub2_df = weak_ff_results[weak_ff_results.ff_to].assign(
    competitor_to=lambda xs: xs.base_to.map(weak_alternates))
weak_sub2_df = weak_sub2_df[weak_sub2_df.competitor_to.isnull()] \
    .query("base_model_name == @focus_base_model and model_name == @focus_model and equivalence == @focus_equivalence") \
    .copy()  # explicit copy so the column assignment below does not target a view
weak_sub2_df["source_label"] = weak_sub2_df.inflection_from + " " + weak_sub2_df.allomorph_from
# just study /S/ false friends right now
weak_sub2_df = weak_sub2_df.query("inflection_to in ['VBZ-FF', 'NNS-FF'] and allomorph_to == 'S' and allomorph_from in ['Z', 'S']")
weak_sub2_df

In [None]:
# Accuracy per source label for competitor-free weak false friends, sorted within each
# test inflection. group_keys=False avoids a duplicate `inflection_to` level under pandas>=2.
weak_sub2_df.groupby(["inflection_to", "source_label"]).correct.mean() \
    .groupby("inflection_to", group_keys=False).apply(lambda xs: xs.sort_values())

### Main FF analysis

In [None]:
# Tercile edges shared by source and target words, so "Q1" covers the same range on both sides.
ff_frequency_bins = pd.qcut(pd.concat([all_ff_results.to_freq, all_ff_results.from_freq]), q=3, retbins=True)[1]
# include_lowest=True: the first qcut edge equals the minimum frequency. With pd.cut's
# default right-closed intervals, the least-frequent word(s) would otherwise get a NaN bin.
all_ff_results["from_freq_bin"] = pd.cut(all_ff_results.from_freq, bins=ff_frequency_bins, include_lowest=True,
                                         labels=[f"Q{i}" for i in range(1, 4)])
all_ff_results["to_freq_bin"] = pd.cut(all_ff_results.to_freq, bins=ff_frequency_bins, include_lowest=True,
                                       labels=[f"Q{i}" for i in range(1, 4)])

In [None]:
# Gather the distinct words in each condition so their frequency distributions can be
# compared: are false-friend words rarer than the real NNS / VBZ words?
false_friend_words = pd.concat([
    all_ff_results.query("ff_from").base_from,
    all_ff_results.query("ff_to").base_to,
]).unique()
nn_words = pd.concat([
    all_nnvb_results.query("inflection_from == 'NNS'").base_from,
    all_nnvb_results.query("inflection_to == 'NNS'").base_to,
]).unique()
vb_words = pd.concat([
    all_nnvb_results.query("inflection_from == 'VBZ'").base_from,
    all_nnvb_results.query("inflection_to == 'VBZ'").base_to,
]).unique()

In [None]:
def _log_freqs(words):
    """Return `word_freq_df.LogFreq` for `words`, in order, skipping words absent from the norms.

    `.loc[list]` raises KeyError if any label is missing, so missing words are reported
    and dropped instead.
    """
    words = pd.Index(words)
    missing = words.difference(word_freq_df.index)
    if len(missing) > 0:
        print(f"Skipping {len(missing)} words without frequency norms: {list(missing)[:10]}")
    return word_freq_df.loc[words[words.isin(word_freq_df.index)]].LogFreq

expt_word_freqs = pd.concat({
    "false_friends": _log_freqs(false_friend_words),
    "NN": _log_freqs(nn_words),
    "VB": _log_freqs(vb_words),
}, names=["type"])

In [None]:
# One log-frequency histogram per word type; the y-axes are independent per row.
sns.displot(
    data=expt_word_freqs.reset_index(),
    x="LogFreq",
    row="type",
    kind="hist",
    bins=15,
    height=1,
    aspect=3,
    facet_kws={"sharey": False},
)

In [None]:
# NOTE(review): `ff_results_summary2` is (re)built in the next cell. On a fresh kernel this
# shows a NameError, or a stale value from an earlier run, depending on execution order.
# Consider moving this cell below the next one.
ff_results_summary2

In [None]:
focus_ff_results = all_ff_results.query(
    "base_model_name == @focus_base_model and model_name == @focus_model and equivalence == @focus_equivalence")

# Mean accuracy / rank / distance for each (source, target) inflection pair in the focus run.
ff_results_summary2 = (
    focus_ff_results
    .groupby(["inflection_from", "inflection_to"])[["correct", "gt_label_rank", "gt_distance"]]
    .mean()
    .reset_index()
)
ff_results_summary2["transfer_label"] = (
    ff_results_summary2.inflection_from + " -> " + ff_results_summary2.inflection_to)

# Append the within-inflection rows (e.g. NNS -> NNS) from the real-inflection summary.
ff_results_summary2 = pd.concat(
    [ff_results_summary2,
     nnvb_results_summary2.query("inflection_from == inflection_to and model_label == 'Word'")],
    axis=0)

# Facet on the base inflection: dropping "-FF" groups e.g. NNS-FF rows with NNS.
ff_results_summary2["base_inflection"] = ff_results_summary2.inflection_from.str.replace("-FF", "")
ff_results_summary2 = ff_results_summary2[ff_results_summary2.base_inflection.isin(plot_inflections)]

g = sns.FacetGrid(ff_results_summary2, col="base_inflection", sharex=False, sharey=False)

def mapfn(data, **kwargs):
    """Draw the source x target accuracy heatmap for one facet on the current axes."""
    sns.heatmap(
        data.set_index(["inflection_from", "inflection_to"]).correct.unstack("inflection_to"),
        vmin=main_plot_vmin, vmax=main_plot_vmax, annot=True, ax=plt.gca())

g.map_dataframe(mapfn)

for i, ax in enumerate(g.axes.flat):
    ax.set_title(ax.get_title().replace("base_inflection = ", ""))
    if i:
        ax.set_ylabel("")
    if i != g.axes.size - 1:
        # All panels share vmin/vmax, so only the right-most colorbar is kept.
        ax.collections[0].colorbar.remove()

g.figure.tight_layout()
g.figure.savefig(f"{output_dir}/ff_results.pdf")

In [None]:
focus_weak_ff_results = weak_ff_results.query(
    "base_model_name == @focus_base_model and model_name == @focus_model and equivalence == @focus_equivalence")

# Mean accuracy / rank / distance for each (source, target) inflection pair, weak false friends.
weak_ff_results_summary2 = (
    focus_weak_ff_results
    .groupby(["inflection_from", "inflection_to"])[["correct", "gt_label_rank", "gt_distance"]]
    .mean()
    .reset_index()
)
weak_ff_results_summary2["transfer_label"] = (
    weak_ff_results_summary2.inflection_from + " -> " + weak_ff_results_summary2.inflection_to)

# Append the within-inflection rows (e.g. NNS -> NNS) from the real-inflection summary.
weak_ff_results_summary2 = pd.concat(
    [weak_ff_results_summary2,
     nnvb_results_summary2.query("inflection_from == inflection_to and model_label == 'Word'")],
    axis=0)

# Facet on the base inflection: dropping "-FF" groups e.g. NNS-FF rows with NNS.
weak_ff_results_summary2["base_inflection"] = weak_ff_results_summary2.inflection_from.str.replace("-FF", "")
weak_ff_results_summary2 = weak_ff_results_summary2[weak_ff_results_summary2.base_inflection.isin(plot_inflections)]

g = sns.FacetGrid(weak_ff_results_summary2, col="base_inflection", sharex=False, sharey=False)

def mapfn(data, **kwargs):
    """Draw the source x target accuracy heatmap for one facet on the current axes."""
    sns.heatmap(
        data.set_index(["inflection_from", "inflection_to"]).correct.unstack("inflection_to"),
        vmin=main_plot_vmin, vmax=main_plot_vmax, annot=True, ax=plt.gca())

g.map_dataframe(mapfn)

for i, ax in enumerate(g.axes.flat):
    ax.set_title(ax.get_title().replace("base_inflection = ", ""))
    if i:
        ax.set_ylabel("")
    if i != g.axes.size - 1:
        # All panels share vmin/vmax, so only the right-most colorbar is kept.
        ax.collections[0].colorbar.remove()

g.figure.tight_layout()
g.figure.savefig(f"{output_dir}/ff_results_weak.pdf")

In [None]:
# Export how often each label was predicted for every NNS-FF target base
# (pooled over all runs in `weak_ff_results`).
(
    weak_ff_results[weak_ff_results.inflection_to == "NNS-FF"]
    .groupby("base_to")
    .predicted_label.value_counts()
    .to_csv("weak_results.csv")
)

In [None]:
weak_ff_results.loc[weak_ff_results.base_to == "den",
                    ["base_from", "allomorph_from", "predicted_label", "gt_label"]]

In [None]:
# Per-target-base accuracy for each transfer direction in the focus run.
sns.catplot(
    data=(focus_ff_results
          .groupby(["transfer_label", "base_to"])
          .correct.mean()
          .reset_index()),
    x="transfer_label",
    y="correct",
    aspect=3,
)

In [None]:
focus_ff_results["allomorph_to"].value_counts()

In [None]:
focus_ff_results.loc[focus_ff_results.base_to == "to", ["gt_label", "allomorph_to"]].iloc[0]

In [None]:
focus_ff_results[focus_ff_results.base_to == "why"]

In [None]:
# Per-base accuracy for NNS -> NNS-FF, least accurate first.
(focus_ff_results[focus_ff_results.transfer_label == "NNS -> NNS-FF"]
    .groupby("base_to")
    .correct.mean()
    .sort_values())

In [None]:
# Accuracy by source-frequency tercile for base model w2v2_8, one panel per transfer direction.
sns.catplot(
    data=all_ff_results.query("base_model_name == 'w2v2_8'").reset_index(),
    kind="point",
    x="from_freq_bin", y="correct", hue="model_name",
    col="transfer_label", col_wrap=2,
)

In [None]:
# Per-base count and accuracy for NNS -> NNS-FF (w2v2_8 / ff_32), grouped by target
# frequency tercile and sorted by accuracy within each tercile.
# group_keys=False avoids a duplicated `to_freq_bin` index level under pandas>=2.
all_ff_results.query("base_model_name == 'w2v2_8' and model_name == 'ff_32' and transfer_label == 'NNS -> NNS-FF'") \
    .groupby(["to_freq_bin", "base_to"]).correct.agg(["count", "mean"]).dropna() \
    .groupby("to_freq_bin", group_keys=False).apply(lambda xs: xs.sort_values("mean"))

In [None]:
# Accuracy by target-frequency tercile for base model w2v2_8, one panel per transfer direction.
sns.catplot(
    data=all_ff_results.query("base_model_name == 'w2v2_8'").reset_index(),
    kind="point",
    x="to_freq_bin", y="correct", hue="model_name",
    col="transfer_label", col_wrap=2,
)

## Controlled VBD analysis

In [None]:
# Regular VBD analogies, annotated with the most common allomorph of the source and target
# (left joins: rows without an allomorph entry are kept with NaN).
all_vbd_results = (
    all_results
    .query("experiment == 'regular' and inflection_from == 'VBD'")
    .merge(mca.rename(columns={"base": "base_from", "inflection": "inflection_from",
                               "most_common_allomorph": "allomorph_from"}),
           on=["base_from", "inflection_from"], how="left")
    .merge(mca.rename(columns={"base": "base_to", "inflection": "inflection_to",
                               "most_common_allomorph": "allomorph_to"}),
           on=["base_to", "inflection_to"], how="left")
)
all_vbd_results[["allomorph_from", "allomorph_to"]].value_counts()

In [None]:
# Keep the three most frequent source allomorphs, and require both sides to use one of them.
keep_vbd_allomorphs = all_vbd_results.allomorph_from.value_counts().index[:3]
all_vbd_results = all_vbd_results[
    all_vbd_results.allomorph_from.isin(keep_vbd_allomorphs)
    & all_vbd_results.allomorph_to.isin(keep_vbd_allomorphs)
]

In [None]:
# Add frequency information

# The inflected form is element [1] of the evaluated equivalence label (presumably a
# tuple repr — confirm). `.assign` avoids SettingWithCopyWarning on the filtered frame.
# NOTE(review): `eval` executes arbitrary code read from the results CSV. These are locally
# generated files, but `ast.literal_eval` would be safer if the labels are plain tuple reprs.
all_vbd_results = all_vbd_results.assign(
    inflected_from=lambda xs: xs.from_equiv_label.apply(lambda x: eval(x)[1]),
    inflected_to=lambda xs: xs.to_equiv_label.apply(lambda x: eval(x)[1]),
)

# Inner merges: rows whose base or inflected form is missing from `word_freq_df` are dropped.
all_vbd_results = pd.merge(all_vbd_results, word_freq_df["LogFreq"].rename("from_base_freq"),
                           left_on="base_from", right_index=True)
all_vbd_results = pd.merge(all_vbd_results, word_freq_df["LogFreq"].rename("to_base_freq"),
                           left_on="base_to", right_index=True)
all_vbd_results = pd.merge(all_vbd_results, word_freq_df["LogFreq"].rename("from_inflected_freq"),
                           left_on="inflected_from", right_index=True)
all_vbd_results = pd.merge(all_vbd_results, word_freq_df["LogFreq"].rename("to_inflected_freq"),
                           left_on="inflected_to", right_index=True)

all_vbd_results["from_freq"] = all_vbd_results[["from_base_freq", "from_inflected_freq"]].mean(axis=1)
all_vbd_results["to_freq"] = all_vbd_results[["to_base_freq", "to_inflected_freq"]].mean(axis=1)

# Tercile edges shared by source and target words. Bound to a name rather than `_`,
# which IPython uses for the last cell output.
vbd_frequency_bins = pd.qcut(pd.concat([all_vbd_results.to_freq, all_vbd_results.from_freq]), q=3, retbins=True)[1]
# include_lowest=True: the first edge equals the minimum frequency; without it the
# least-frequent word(s) fall outside every right-closed bin and get NaN.
all_vbd_results["from_freq_bin"] = pd.cut(all_vbd_results.from_freq, bins=vbd_frequency_bins, include_lowest=True,
                                          labels=[f"Q{i}" for i in range(1, 4)])
all_vbd_results["to_freq_bin"] = pd.cut(all_vbd_results.to_freq, bins=vbd_frequency_bins, include_lowest=True,
                                        labels=[f"Q{i}" for i in range(1, 4)])

In [None]:
def summarize_vbd_run(rows):
    """Add display labels to one run's VBD summary rows.

    Adds `source_label` / `target_label` ("<inflection> <allomorph>"), `transfer_label`
    ("<from> -> <to>" inflections) and `phon_label` ("<from> -> <to>" allomorphs).

    Returns a new frame instead of mutating `rows`: pandas warns that functions passed
    to `groupby.apply` should not mutate the group they receive.
    """
    return rows.assign(
        source_label=rows.inflection_from + " " + rows.allomorph_from,
        target_label=rows.inflection_to + " " + rows.allomorph_to,
        transfer_label=rows.inflection_from + " -> " + rows.inflection_to,
        phon_label=rows.allomorph_from + " -> " + rows.allomorph_to,
    )

summary_groupers = ["inflection_from", "inflection_to", "allomorph_from", "allomorph_to"]

# Count and accuracy per (run, transfer cell); display labels are then added run by run.
vbd_results_summary = (
    all_vbd_results
    .groupby(run_groupers + summary_groupers)
    .correct.agg(["count", "mean"])
    .reset_index(summary_groupers)
    .groupby(run_groupers, group_keys=False)
    .apply(summarize_vbd_run)
    .reset_index()
)

vbd_results_summary

In [None]:
# Collect the VBD summary for every configured run that actually has results.
plot_results = []
for base_model_name, model_name, equivalence in plot_runs:
    results_i = vbd_results_summary.query("base_model_name == @base_model_name and model_name == @model_name and equivalence == @equivalence")
    if len(results_i) > 0:
        plot_results.append(results_i)
num_plot_runs = len(plot_results)

ncols = 2
# At least one row, so plt.subplots does not fail when no run has results.
nrows = max(1, int(np.ceil(num_plot_runs / ncols)))
# squeeze=False keeps `axs` 2-D for any nrows, so `.flat` behaves uniformly.
f, axs = plt.subplots(nrows, ncols, figsize=(ncols * 4, nrows * 4), squeeze=False)

for ax, results_i in zip(axs.flat, plot_results):
    sns.heatmap(results_i.set_index(["source_label", "target_label"])["mean"].unstack(),
                vmin=0, vmax=1, ax=ax)
    key_row = results_i.iloc[0]
    ax.set_title(f"{key_row.base_model_name} -> {key_row.model_name} ({key_row.equivalence})")

# Hide leftover panels (e.g. with an odd number of runs) instead of drawing empty axes.
for ax in axs.flat[num_plot_runs:]:
    ax.set_visible(False)

### Focused plots

In [None]:
focus_base_model, focus_model, focus_equivalence = main_plot_run
# Identity ("id") model on w2v2_8, used as the comparison.
foil_base_model, foil_model, foil_equivalence = "w2v2_8", "id", "id"

# `.assign` returns a fresh frame, so labelling the query results no longer triggers
# SettingWithCopyWarning (pandas flags query results as possible copies of all_vbd_results).
vbd_focus = all_vbd_results.query("base_model_name == @focus_base_model and model_name == @focus_model and equivalence == @focus_equivalence") \
    .assign(model_label="Word")
vbd_foil = all_vbd_results.query("base_model_name == @foil_base_model and model_name == @foil_model and equivalence == @foil_equivalence") \
    .assign(model_label="Wav2Vec")

vbd_focus = pd.concat([vbd_focus, vbd_foil])

# Display allomorphs in IPA; allomorphs outside this mapping become NaN.
allomorph_labels = {"D": "d", "T": "t", "IH D": "ɪd"}
vbd_focus["allomorph_from"] = vbd_focus.allomorph_from.map(allomorph_labels)
vbd_focus["allomorph_to"] = vbd_focus.allomorph_to.map(allomorph_labels)
vbd_focus

In [None]:
# NOTE: `count >= 0` keeps every cell; raise the threshold to drop sparsely populated cells.
vbd_results_summary = vbd_focus.groupby(["model_label", "inflection_from", "inflection_to",
                                         "allomorph_from", "allomorph_to"]) \
    .correct.agg(["count", "mean"]) \
    .query("count >= 0") \
    .reset_index()

vbd_results_summary["source_label"] = vbd_results_summary.inflection_from + "\n" + vbd_results_summary.allomorph_from
vbd_results_summary["target_label"] = vbd_results_summary.inflection_to + "\n" + vbd_results_summary.allomorph_to

vbd_results_summary["transfer_label"] = vbd_results_summary.inflection_from + " -> " + vbd_results_summary.inflection_to
vbd_results_summary["phon_label"] = vbd_results_summary.allomorph_from + " " + vbd_results_summary.allomorph_to

# only retain cases where we have data in both transfer directions from source <-> target within inflection.
# A set lookup replaces one DataFrame.query per row (quadratic in the number of cells).
# Observed pairs are pooled across model labels.
_observed_pairs = set(zip(vbd_results_summary.source_label, vbd_results_summary.target_label))
_has_complement = [(target, source) in _observed_pairs
                   for source, target in zip(vbd_results_summary.source_label, vbd_results_summary.target_label)]
vbd_results_summary = vbd_results_summary.loc[_has_complement]

vbd_results_summary

In [None]:
# Accuracy by test allomorph, split by train inflection + allomorph, one panel per model.
vbd_focus_bar = vbd_focus.assign(source_label=lambda xs: xs.inflection_from + " " + xs.allomorph_from)
# Order hues by overall accuracy (previously computed but never passed to the plot).
order = vbd_focus_bar.groupby("source_label").correct.mean().sort_values().index
g = sns.catplot(data=vbd_focus_bar, x="allomorph_to", hue="source_label", hue_order=order,
                y="correct", col="model_label", kind="bar")
g.legend.set_title("Train inflection\nand allomorph")

for ax in g.axes.flat:
    # "model_label = Word" -> "Word"
    ax.set_title(ax.get_title().split("=")[1].strip())
    # The x-axis is the test allomorph (`allomorph_to`), not the inflection.
    ax.set_xlabel("Test allomorph")
    ax.set_ylabel("Accuracy")

In [None]:
# Side-by-side accuracy heatmaps (train x test), one per model label, sharing a single
# colorbar drawn in the narrow third column.
f, axs = plt.subplots(1, 3, figsize=(7 * 2, 6), gridspec_kw={'width_ratios': [1, 1, 0.04]})
for i, (ax, (model_label, rows)) in enumerate(zip(axs, vbd_results_summary.groupby("model_label"))):
    # Only the second panel draws a colorbar, into the dedicated third axis.
    cbar_ax = axs.flat[-1] if i == 1 else None

    ax.set_title(model_label)
    sns.heatmap(
        rows.set_index(["source_label", "target_label"]).sort_index()["mean"].unstack("target_label"),
        vmin=main_plot_vmin,
        vmax=main_plot_vmax,
        annot=True,
        ax=ax,
        cbar=(i == 1),
        cbar_ax=cbar_ax,
    )

    ax.set_yticklabels(ax.get_yticklabels(), rotation=0)
    ax.set(ylabel="Train", xlabel="Test")

f.tight_layout()
f.savefig(f"{output_dir}/vbd_allomorphs.pdf")

## V3

In [None]:
# Morph categories (values of `inflection_from`) kept in the V3 layer-wise plots.
plot_inflections_v3 = ["non", "agent", "comp"]

In [None]:
# V3 cross-instance table; used below to pick each base's most frequent pronunciation.
v3_cross_instances = pd.read_parquet("outputs/analogy_v3/inputs/librispeech-train-clean-100/w2v2/all_cross_instances.parquet")

### Pre-compute metadata

In [None]:
from src.utils import syllabifier

# One pronunciation per base: the most frequent `base_phones` variant.
base_prons = (
    v3_cross_instances
    .groupby("base").base_phones.value_counts()
    .groupby("base").head(1)
    .reset_index()
    .drop(columns=["count"])
)
base_prons["num_syllables"] = base_prons.base_phones.apply(
    lambda phones: len(syllabifier.syllabify(syllabifier.English, phones)))

### Load results

In [None]:
# Trained-model results. Run metadata comes from the directory path:
# .../<base_model_name>/<model_name>/<equivalence>/experiment_results.csv
all_results_v3 = concat_csv_with_indices(
    "outputs/analogy_v3/runs/**/experiment_results.csv",
    [lambda p, depth=depth: p.parents[depth].name for depth in range(3)],
    ["equivalence", "model_name", "base_model_name"],
).droplevel(-1).reset_index()

In [None]:
# Identity-model results: only the base model is encoded in the path, so the model name
# and equivalence are filled in as "id".
all_id_results_v3 = (
    concat_csv_with_indices(
        "outputs/analogy_v3/runs_id/**/experiment_results.csv",
        [lambda p: p.parent.name],
        ["base_model_name"])
    .droplevel(-1)
    .reset_index()
    .assign(model_name="id", equivalence="id")
)

In [None]:
# Stack trained-model and identity-model results into one table with a fresh index.
all_results_v3 = pd.concat([all_results_v3, all_id_results_v3], ignore_index=True)

In [None]:
# Attach syllable counts for the source and target bases (inner joins: bases without a
# pronunciation entry are dropped).
all_results_v3 = (
    all_results_v3
    .merge(base_prons.rename(columns={"base": "base_from",
                                      "num_syllables": "base_from_num_syllables"})
                     .drop(columns=["base_phones"]),
           on="base_from")
    .merge(base_prons.rename(columns={"base": "base_to",
                                      "num_syllables": "base_to_num_syllables"})
                     .drop(columns=["base_phones"]),
           on="base_to")
)

### Load manual annotation fixes

In [None]:
# Manually annotated morph category per base, lower-cased.
v3_fixes = pd.read_csv("20250218 annot.csv", index_col=0) \
    .assign(morph=lambda d: d.morph.str.lower())

In [None]:
# Keep only analogies whose source and target bases were both annotated.
all_results_v3 = all_results_v3.loc[
    all_results_v3.base_from.isin(v3_fixes.base) & all_results_v3.base_to.isin(v3_fixes.base)]

In [None]:
# Attach the annotated morph category of each base; missing annotations become "non".
all_results_v3 = (
    all_results_v3
    .merge(v3_fixes[["base", "morph"]].rename(columns={"base": "base_from", "morph": "morph_from"}).fillna("non"),
           on=["base_from"], how="inner")
    .merge(v3_fixes[["base", "morph"]].rename(columns={"base": "base_to", "morph": "morph_to"}).fillna("non"),
           on=["base_to"], how="inner")
)

In [None]:
# The annotated morph category stands in for the "inflection" columns used by later cells.
all_results_v3 = all_results_v3.assign(inflection_from=all_results_v3.morph_from,
                                       inflection_to=all_results_v3.morph_to)

In [None]:
# DEV: restrict to analogies between monosyllabic bases.
all_results_v3 = all_results_v3.query("base_from_num_syllables == 1 and base_to_num_syllables == 1")

### Layer-wise

In [None]:
# Tag each row with the group ("morph" / "non") of its source and target words, derived
# from the experiment name and, for morph_related, the `group` flag.
_morph_related = all_results_v3.experiment == "morph_related"
_group_assignments = [
    (_morph_related & (all_results_v3.group == "(True,)"), "morph", "morph"),
    (_morph_related & (all_results_v3.group == "(False,)"), "non", "non"),
    (all_results_v3.experiment == "non_to_morph", "non", "morph"),
    (all_results_v3.experiment == "morph_to_non", "morph", "non"),
]
for _mask, _group_from, _group_to in _group_assignments:
    all_results_v3.loc[_mask, "group_from"] = _group_from
    all_results_v3.loc[_mask, "group_to"] = _group_to
del _morph_related, _group_assignments, _mask, _group_from, _group_to

In [None]:
# Per-base accuracy for each run and morph transfer direction.
plot_lw_v3 = all_results_v3 \
    .groupby(run_groupers + ["inflection_from", "inflection_to", "base_to"]).correct.mean() \
    .reset_index(["inflection_from", "inflection_to", "base_to"])
# get just the relevant runs from plot_runs, skipping runs with no V3 results
# (`.loc[plot_run]` raises KeyError for a missing run)
plot_lw_v3 = pd.concat([plot_lw_v3.loc[plot_run] for plot_run in plot_runs
                        if plot_run in plot_lw_v3.index]).reset_index()
plot_lw_v3["layer"] = plot_lw_v3.base_model_name.str.extract(r"_(\d+)$").astype(int)
plot_lw_v3["label"] = plot_lw_v3.inflection_from + " -> " + plot_lw_v3.inflection_to

plot_lw_v3 = plot_lw_v3[plot_lw_v3.inflection_from.isin(plot_inflections_v3)]

In [None]:
# Layer-wise accuracy per morph transfer, one row of panels per model.
sns.catplot(
    data=plot_lw_v3,
    kind="point",
    x="layer", y="correct", hue="label", row="model_name",
    errorbar="se", height=3, aspect=2,
)

In [None]:
# DEV: set to an int (e.g. 10) to skip transfer experiments with fewer distinct source or
# target bases than this; None keeps every experiment.
min_seen_words = None

# Per run, collect every "<from>_to_<to>" transfer experiment plus the "morph_related"
# experiment, keeping `experiment` as the index.
all_er_results = []

for run, run_results in all_results_v3.groupby(run_groupers):
    transfer_results = pd.Series(run_results.experiment[run_results.experiment.str.contains("_to_")]).unique()
    run_results = run_results.set_index("experiment")

    for expt in transfer_results:
        # A list key always yields a DataFrame. `.loc[expt]` returns a Series when an
        # experiment has a single row, which would corrupt the concat below.
        expt_df = run_results.loc[[expt]].copy()

        if min_seen_words is not None:
            num_seen_words = min(len(expt_df.base_from.unique()), len(expt_df.base_to.unique()))
            if num_seen_words < min_seen_words:
                print(f"Skipping {expt} due to only {num_seen_words} seen words")
                continue

        all_er_results.append(expt_df)

    # Guard against runs that lack the morph_related experiment (plain `.loc` raises KeyError).
    if "morph_related" in run_results.index:
        all_er_results.append(run_results.loc[["morph_related"]].copy())

all_er_results = pd.concat(all_er_results)

In [None]:
# Inflected forms, transfer labels and word frequencies for the V3 results.
# NOTE(review): `eval` executes arbitrary code read from the results CSV. These are locally
# generated files, but `ast.literal_eval` would be safer if the labels are plain tuple reprs.
all_er_results["inflected_from"] = all_er_results.from_equiv_label.apply(lambda x: eval(x)[1])
all_er_results["inflected_to"] = all_er_results.to_equiv_label.apply(lambda x: eval(x)[1])
all_er_results["transfer_label"] = all_er_results.inflection_from + " -> " + all_er_results.inflection_to

# Inner merges: rows whose base or inflected form is missing from `word_freq_df` are dropped.
all_er_results = pd.merge(all_er_results, word_freq_df.LogFreq.rename("from_base_freq"),
                          left_on="base_from", right_index=True)
all_er_results = pd.merge(all_er_results, word_freq_df.LogFreq.rename("from_inflected_freq"),
                          left_on="inflected_from", right_index=True)
all_er_results = pd.merge(all_er_results, word_freq_df.LogFreq.rename("to_base_freq"),
                          left_on="base_to", right_index=True)
all_er_results = pd.merge(all_er_results, word_freq_df.LogFreq.rename("to_inflected_freq"),
                          left_on="inflected_to", right_index=True)

all_er_results["from_freq"] = all_er_results[["from_base_freq", "from_inflected_freq"]].mean(axis=1)
all_er_results["to_freq"] = all_er_results[["to_base_freq", "to_inflected_freq"]].mean(axis=1)

# Quintile edges shared by source and target words. Bound to a name rather than `_`,
# which IPython uses for the last cell output.
er_frequency_bins = pd.qcut(pd.concat([all_er_results.to_freq, all_er_results.from_freq]), q=5, retbins=True)[1]
er_bin_labels = [f"Q{i}" for i in range(1, len(er_frequency_bins))]
# Bin with the edges computed here (not a `frequency_bins` left over from another analysis).
# include_lowest=True keeps the least-frequent words in Q1 instead of NaN.
all_er_results["to_freq_bin"] = pd.cut(all_er_results.to_freq, bins=er_frequency_bins,
                                       include_lowest=True, labels=er_bin_labels)
all_er_results["from_freq_bin"] = pd.cut(all_er_results.from_freq, bins=er_frequency_bins,
                                         include_lowest=True, labels=er_bin_labels)

In [None]:
# DEV: focus on the identity model; the commented line restores the main run.
# NOTE(review): this overwrites the `focus_*` variables that earlier cells also read.
focus_base_model, focus_model, focus_equivalence = "w2v2_8", "id", "id"
# focus_base_model, focus_model, focus_equivalence = "w2v2_8", "ff_32", "word_broad_10frames_fixedlen25"

focus_er_results = all_er_results.query(
    "base_model_name == @focus_base_model and model_name == @focus_model and equivalence == @focus_equivalence")

focus_er_results_summary = (
    focus_er_results
    .groupby(["inflection_from", "inflection_to"])[["correct", "gt_label_rank", "gt_distance"]]
    .mean()
    .reset_index()
)

g = sns.FacetGrid(focus_er_results_summary, sharex=False, sharey=False, height=4, aspect=1.25)

def mapfn(data, **kwargs):
    """Draw the source x target accuracy heatmap on the current axes (automatic color scale)."""
    sns.heatmap(
        data.set_index(["inflection_from", "inflection_to"]).correct.unstack("inflection_to"),
        annot=True, ax=plt.gca())

g.map_dataframe(mapfn)

for i, ax in enumerate(g.axes.flat):
    ax.set_title(ax.get_title().replace("inflection_from = ", ""))
    if i:
        ax.set_ylabel("")
    if i != g.axes.size - 1:
        ax.collections[0].colorbar.remove()

g.figure.tight_layout()
g.figure.savefig(f"{output_dir}/er_results.pdf")

In [None]:
# Item count and accuracy per transfer cell, least accurate first.
(focus_er_results
    .groupby(["inflection_from", "inflection_to"])
    .correct.agg(["count", "mean"])
    .sort_values("mean"))

In [None]:
# One row per (morph category, base) with its frequency, pooled over source and target roles.
freq_stats_df = pd.concat([
    all_er_results[[f"inflection_{side}", f"base_{side}", f"{side}_freq"]]
        .reset_index(drop=True)
        .set_axis(["inflection", "base", "freq"], axis=1)
    for side in ("from", "to")
]).drop_duplicates(["inflection", "base"])

In [None]:
g = sns.displot(data=freq_stats_df, x="freq", row="inflection", kind="hist", bins=10,
                height=2, aspect=3, facet_kws={"sharey": False})

# Mark each category's median frequency with a red vertical line.
for ax, row_name in zip(g.axes.flat, g.row_names):
    ax.axvline(freq_stats_df.loc[freq_stats_df.inflection == row_name, "freq"].median(),
               color="red", linewidth=2)

In [None]:
from scipy.stats import ttest_ind

# Do bases in the 'comp' and 'agent' categories differ in frequency?
ttest_ind(
    freq_stats_df.loc[freq_stats_df.inflection == "comp", "freq"],
    freq_stats_df.loc[freq_stats_df.inflection == "agent", "freq"],
)

In [None]:
# Accuracy by source-frequency quintile, one panel per source category.
sns.catplot(
    data=focus_er_results.reset_index(),
    kind="point", errorbar="se",
    x="from_freq_bin", y="correct",
    hue="inflection_to", col="inflection_from", col_wrap=2,
)

In [None]:
# Accuracy by target-frequency quintile, one panel per target category.
sns.catplot(
    data=focus_er_results.reset_index(),
    kind="point", errorbar="se",
    x="to_freq_bin", y="correct",
    hue="inflection_from", col="inflection_to", col_wrap=2,
)

In [None]:
# The 40 least accurate 'agent' target bases.
(focus_er_results[focus_er_results.inflection_to == "agent"]
    .groupby("base_to")
    .correct.mean()
    .sort_values()
    .head(40))

In [None]:
focus_er_results.loc[focus_er_results.base_to == "sin", "predicted_label"].value_counts()

In [None]:
# The 40 least accurate 'comp' target bases.
(focus_er_results[focus_er_results.inflection_to == "comp"]
    .groupby("base_to")
    .correct.mean()
    .sort_values()
    .head(40))

In [None]:
focus_er_results.loc[focus_er_results.base_to == "sad", "predicted_label"].value_counts()