Prepare stimuli shared across all analogy evaluations.

In [None]:
from collections import Counter, defaultdict
import functools
import pickle

import numpy as np
import pandas as pd
import torch
from tqdm.auto import tqdm

from src.analysis.state_space import StateSpaceAnalysisSpec
from src.analysis import analogy

In [None]:
# Input paths. (Plain string: the original f-prefix had no placeholders.)
state_space_specs_path = "outputs/state_space_specs/librispeech-train-clean-100/w2v2_8/state_space_specs.h5"

pos_counts_path = "data/pos_counts.pkl"
output_dir = "."

# Random seed for sampling done while building stimuli.
seed = 1234

# Per-word instance-count bounds used for vocabulary filtering and subsampling.
min_samples_per_word = 5
max_samples_per_word = 100

# Inflection categories passed to analogy.get_inflection_df.
inflection_targets = [
    "VBD",
    "VBZ",
    "VBG",
    "NNS",
    "NOT-latin",
]

In [None]:
# Load the word-level state-space spec, then cap how many instances are kept per word.
state_space_spec = (
    StateSpaceAnalysisSpec
    .from_hdf5(state_space_specs_path, "word")
    .subsample_instances(max_samples_per_word)
)

In [None]:
# Per-word part-of-speech counts (presumably word -> {POS tag: count}; the
# ambiguity checks below look up "VERB"/"NOUN" keys). pickle.load can execute
# arbitrary code, so this must only ever point at a trusted, locally built file.
with open(pos_counts_path, "rb") as pos_counts_file:
    pos_counts = pickle.load(pos_counts_file)

In [None]:
# Phoneme-level cuts only, without the frame-boundary columns.
cuts_df = (
    state_space_spec.cuts
    .xs("phoneme", level="level")
    .drop(columns=["onset_frame_idx", "offset_frame_idx"])
)
cuts_df["label_idx"] = cuts_df.index.get_level_values("label").map(
    {label: position for position, label in enumerate(state_space_spec.labels)})
# Ordinal of each phoneme within its (label, instance_idx) word token, in the
# existing row order (assumed to be temporal order -- TODO confirm upstream).
cuts_df["frame_idx"] = cuts_df.groupby(["label", "instance_idx"]).cumcount()
cuts_df = (
    cuts_df
    .reset_index()
    .set_index(["label", "instance_idx", "frame_idx"])
    .sort_index()
)

In [None]:
# Space-joined phoneme string per (label, instance_idx) token, in frame_idx order.
cut_phonemic_forms = cuts_df.groupby(["label", "instance_idx"])["description"].agg(" ".join)

In [None]:
# Per-instance target frame spans; queried below by label to enumerate instance_idx values.
ss_spans = state_space_spec.target_frame_spans_df

## Helper functions

In [None]:
def guess_nns_vbz_allomorph(base_phones):
    """
    Predict the regular plural / 3sg-present suffix for a base form.

    Parameters:
    - base_phones: sequence of CMUDICT phones for the base form

    Returns the expected allomorph as a space-separated phone string:
    - "IH Z" after a sibilant (S, Z, SH, CH, JH, ZH)
    - "S" after any other voiceless consonant (P, T, K, F, TH)
    - "Z" otherwise (voiced consonants and vowels)

    Raises IndexError if ``base_phones`` is empty.
    """
    final_phone = base_phones[-1]
    if final_phone in {"S", "Z", "SH", "CH", "JH", "ZH"}:
        return "IH Z"
    if final_phone in {"P", "T", "K", "F", "TH"}:
        return "S"
    return "Z"


def guess_past_allomorph(base_phones):
    """
    Predict the regular past-tense suffix for a base form.

    Parameters:
    - base_phones: sequence of CMUDICT phones for the base form

    Returns the expected allomorph as a space-separated phone string:
    - "IH D" after an alveolar stop (T, D), e.g. "want" -> "wanted"
    - "T" after another voiceless consonant (P, F, K, S, SH, CH, TH),
      e.g. "jump" -> "jumped"
    - "D" otherwise (voiced consonants and vowels)

    Raises IndexError if ``base_phones`` is empty.
    """
    final_phone = base_phones[-1]
    if final_phone in ("T", "D"):
        return "IH D"
    return "T" if final_phone in ("P", "F", "K", "S", "SH", "CH", "TH") else "D"

## Set up main stimuli

In [None]:
# Vocabulary: words with at least `min_samples_per_word` instances. (A strict
# `>` would silently require min_samples_per_word + 1 instances.)
labels = state_space_spec.label_counts
labels = set(labels[labels >= min_samples_per_word].index)

# Position of each label in state_space_spec.labels, built once and reused.
label_to_idx = {l: i for i, l in enumerate(state_space_spec.labels)}

inflection_results_df = analogy.get_inflection_df(inflection_targets, labels)
inflection_results_df["base_idx"] = inflection_results_df.base.map(label_to_idx)
inflection_results_df["inflected_idx"] = inflection_results_df.inflected.map(label_to_idx)
inflection_results_df

In [None]:
# Add on random word pair baseline.
# Both words of every pair are drawn (with replacement, seeded for
# reproducibility) from the same frequency-filtered vocabulary `labels` that the
# real inflection pairs use.
num_random_word_pairs = inflection_results_df.groupby("inflection").size().max()
label_to_idx = {l: i for i, l in enumerate(state_space_spec.labels)}
# Positions in state_space_spec.labels of the filtered words. Sampling
# range(len(labels)) directly would index into the *unfiltered* label list.
# Sorting makes the draw independent of set iteration order, which varies
# across runs with string hash randomization.
candidate_idxs = np.array(sorted(label_to_idx[l] for l in labels if l in label_to_idx))
rng = np.random.default_rng(seed)
random_word_pairs = rng.choice(candidate_idxs, size=(num_random_word_pairs, 2))
random_word_pairs = pd.DataFrame(random_word_pairs, columns=["base_idx", "inflected_idx"])
idx_to_label = dict(enumerate(state_space_spec.labels))
random_word_pairs["base"] = random_word_pairs.base_idx.map(idx_to_label)
random_word_pairs["inflected"] = random_word_pairs.inflected_idx.map(idx_to_label)
random_word_pairs["is_regular"] = False
random_word_pairs["inflection"] = "random"
random_word_pairs = random_word_pairs.set_index("inflection")
random_word_pairs

In [None]:
# Append the random-baseline pairs (indexed by inflection == "random") to the real pairs.
inflection_results_df = pd.concat([inflection_results_df, random_word_pairs])

## Prepare token-level features

### NNS/VBZ ambiguity

In [None]:
def is_noun_ambiguous(row):
    """Return True if either word of this NNS pair has a "VERB" entry in `pos_counts`.

    Raises KeyError if either word is missing from `pos_counts`.
    """
    base_tags = pos_counts[row.base].keys()
    inflected_tags = pos_counts[row.inflected].keys()
    return "VERB" in base_tags or "VERB" in inflected_tags
inflection_results_df.loc["NNS", "base_ambig_NN_VB"] = inflection_results_df.loc["NNS"].apply(is_noun_ambiguous, axis=1)
# inflection_results_df.loc["NNS"].groupby("base_ambig_NN_VB").sample(10)

In [None]:
def is_verb_ambiguous(row):
    """Return True if either word of this VBZ pair has a "NOUN" entry in `pos_counts`.

    Raises KeyError if either word is missing from `pos_counts`.
    """
    base_tags = pos_counts[row.base].keys()
    inflected_tags = pos_counts[row.inflected].keys()
    return "NOUN" in base_tags or "NOUN" in inflected_tags
inflection_results_df.loc["VBZ", "base_ambig_NN_VB"] = inflection_results_df.loc["VBZ"].apply(is_verb_ambiguous, axis=1)
# inflection_results_df.loc["VBZ"].groupby("base_ambig_NN_VB").sample(10)

### Post-divergence analysis

In [None]:
@functools.lru_cache
def _get_base_forms(base_label: str) -> frozenset[tuple[str, ...]]:
    """
    Return the distinct phone sequences attested for `base_label`.

    Each element is the tuple of phoneme descriptions of one instance of the
    word in `cuts_df`, in frame_idx order. Memoized per label. Raises KeyError
    if `base_label` has no rows in `cuts_df`.
    """
    label_cuts = cuts_df.loc[base_label]
    return frozenset(
        tuple(instance_cuts.description)
        for _, instance_cuts in label_cuts.groupby("instance_idx")
    )

In [None]:
@functools.lru_cache
def _get_phonological_divergence(base_forms: frozenset[tuple[str, ...]],
                                 inflected_form: tuple[str, ...]) -> tuple[int, tuple[str, ...]]:
    """
    Locate where `inflected_form` departs from its best-matching base pronunciation.

    For each base pronunciation, find the length of the prefix it shares with
    `inflected_form`; the divergence point is the maximum over all base
    pronunciations (i.e. the base variant sharing the longest prefix).
    Arguments are a frozenset / tuple so they are hashable for lru_cache.

    Returns `(divergence_point, post_divergence)`, where `post_divergence`
    is `inflected_form[divergence_point:]`.

    NOTE(review): if `inflected_form` equals a base pronunciation (or is a
    prefix of one), the inner loop never breaks, so the point recorded is
    `len(inflected_form) - 1` and `post_divergence` is the final phone rather
    than `()`. Confirm this is intended.
    Raises ValueError if `base_forms` is empty (max of an empty list).
    """
    phono_divergence_points = []
    for base_phones in base_forms:
        # Grow the prefix until it stops matching; on break, `idx - 1` is the
        # shared-prefix length, i.e. the index of the first differing phone.
        for idx in range(len(inflected_form) + 1):
            if inflected_form[:idx] != base_phones[:idx]:
                break
        phono_divergence_points.append(idx - 1)
    # Use the base variant that matches the inflected form the furthest.
    phono_divergence_point = max(phono_divergence_points)

    post_divergence = inflected_form[phono_divergence_point:]
    return phono_divergence_point, post_divergence

In [None]:
def get_phonological_divergence(base_label, inflected_label, inflected_instance_idx):
    """
    Compare one token of an inflected word against all pronunciations of its base.

    Returns a pair `(inflected_phones, post_divergence)` of phone tuples: the
    token's full phone sequence, and the part of it remaining after the point
    where it stops matching the best-matching base pronunciation (as computed
    by `_get_phonological_divergence`).

    If the base label or the inflected token has no phoneme cuts in `cuts_df`,
    returns `((), ())`. Keeping the normal two-tuple shape means callers that
    unpack the result still work for these missing-data cases.
    """
    try:
        base_phon_forms = _get_base_forms(base_label)
        inflected_phones = tuple(cuts_df.loc[inflected_label].loc[inflected_instance_idx].description)
    except KeyError:
        return (), ()

    _, post_divergence = _get_phonological_divergence(base_phon_forms, inflected_phones)
    return inflected_phones, post_divergence

In [None]:
inflection_instances = []

for inflection, row in tqdm(inflection_results_df.iterrows(), total=len(inflection_results_df)):
    # One record per state-space token of the inflected word.
    for inflected_instance_idx in ss_spans.query("label == @row.inflected").instance_idx:
        phone_seq, post_div_seq = get_phonological_divergence(
            row.base, row.inflected, inflected_instance_idx)

        inflection_instances.append(dict(
            inflection=inflection,
            base=row.base,
            inflected=row.inflected,
            inflected_instance_idx=inflected_instance_idx,
            inflected_phones=" ".join(phone_seq),
            post_divergence=" ".join(post_div_seq),
        ))

In [None]:
inflection_instance_df = pd.DataFrame(inflection_instances)

# Attach the type-level columns of inflection_results_df (base_idx,
# inflected_idx, ...) to every token row; the left join keeps all tokens.
inflection_instance_df = inflection_instance_df.merge(
    inflection_results_df.reset_index(),
    on=["inflection", "base", "inflected"],
    how="left",
)
inflection_instance_df

In [None]:
# compute most frequent allomorph of each inflection
most_common_allomorphs = inflection_instance_df.groupby(["inflection", "base"]).post_divergence \
    .apply(lambda xs: xs.value_counts().idxmax()) \
    .rename("most_common_allomorph").reset_index()

## Build full cross product of stimuli

In [None]:
inflection_cross_instances = []
base_cross_instances = []

for inflection, row in tqdm(inflection_results_df.iterrows(), total=len(inflection_results_df)):
    pair_info = {"inflection": inflection, "base": row.base, "inflected": row.inflected}

    # One record per state-space token of the inflected word, with its phonemic form.
    inflected_idxs = ss_spans.query("label == @row.inflected").instance_idx
    inflected_token_forms = cut_phonemic_forms.loc[row.inflected]
    inflection_cross_instances.extend(
        {**pair_info,
         "inflected_instance_idx": token_idx,
         "inflected_phones": inflected_token_forms.loc[token_idx]}
        for token_idx in inflected_idxs
    )

    # One record per state-space token of the base word, with its phonemic form.
    base_idxs = ss_spans.query("label == @row.base").instance_idx
    base_token_forms = cut_phonemic_forms.loc[row.base]
    base_cross_instances.extend(
        {**pair_info,
         "base_instance_idx": token_idx,
         "base_phones": base_token_forms.loc[token_idx]}
        for token_idx in base_idxs
    )

In [None]:
# add in post-divergence information
inflection_cross_instances_df = pd.DataFrame(inflection_cross_instances)
merge_on = ["inflection", "base", "inflected", "inflected_instance_idx"]
inflection_cross_instances_df = pd.merge(inflection_cross_instances_df,
                                         inflection_instance_df[merge_on + ["post_divergence"]],
                                         on=merge_on)

all_cross_instances = pd.merge(pd.DataFrame(base_cross_instances),
         inflection_cross_instances_df,
         on=["inflection", "base", "inflected"],
         how="outer")

# Now merge with type-level information.
all_cross_instances = pd.merge(inflection_results_df.reset_index(),
                               all_cross_instances,
                               on=["inflection", "base", "inflected"],
                               validate="1:m")

all_cross_instances["exclude_main"] = False
all_cross_instances

## Forced-choice experiment materials

In [None]:
def get_forced_choice_cross_instances(fc_pair: tuple[str, str], allomorph_guesser,
                                      min_frequency=2):
    """
    Build stimuli for the "forced choice" allomorphy experiment.

    The "forced choice" experiment asks whether a model prefers to make predictions
    which are consistent with an allomorphy structure or not.

    For example, allomorphs of the plural morpheme in English can be /s/, /z/
    or /Iz/ depending on the final phoneme of the base form. `allomorph_guesser`
    specifies this allomorphy rule for the given request.

    Candidates are found purely from attested phonemic forms. A word token
    (label1) qualifies when three conditions hold:
    - its form ends in ``fc_pair[0]``;
    - the form with that suffix removed (label0, the "base") is attested;
    - the form with ``fc_pair[1]`` substituted (label2) is also attested.
    When several labels share a form, the alphabetically first label is used.

    Parameters:
    - fc_pair: pair of phoneme strings (space-separated CMUDICT phonemes) which
        form the forced choice pair
    - allomorph_guesser: function taking a list of CMUDICT phonemes and returning
        the appropriate allomorph for the forced choice setup
    - min_frequency: minimum number of instances required of each of label0,
        label1 and label2 for a triple to be included in the output

    Returns:
        DataFrame with one row per (base token, inflected token) pair, with columns:
        - base, inflected, inflected2
        - base_phones, inflected_phones, inflected2_phones
        - base_instance_idx, inflected_instance_idx
        - inflection ("FC-" + "_".join(fc_pair))
        - post_divergence (fc_pair[0])
        - strong: whether `allomorph_guesser` predicts fc_pair[0] for the base
        - base_idx, inflected_idx
        - exclude_main (always True)
    """
    observed_suffix, alternative_suffix = fc_pair
    suffix_len = len(observed_suffix)

    label_counts = cut_phonemic_forms.groupby("label").size()
    label_to_idx = {l: i for i, l in enumerate(state_space_spec.labels)}
    # Representative label for every attested phonemic form: the alphabetically
    # first label realizing it. Built once instead of rescanning all tokens for
    # every candidate description.
    form_to_label = cut_phonemic_forms.reset_index().groupby("description").label.min()

    def strip_suffix(forms: pd.Series) -> pd.Series:
        # Drop the observed suffix (character-wise) and the separating space.
        return forms.str[:-suffix_len].str.strip()

    step0 = cut_phonemic_forms.loc[cut_phonemic_forms.str[-suffix_len:] == observed_suffix]
    # if you remove post-div content, it's still attested
    step1 = step0.loc[strip_suffix(step0).isin(cut_phonemic_forms)]
    # and the alternative post-div is also attested
    step2 = step1.loc[(strip_suffix(step1) + (" " + alternative_suffix)).isin(cut_phonemic_forms)]
    step2 = step2.reset_index()

    step2["inferred_form0"] = strip_suffix(step2.description)
    step2["inferred_form2"] = step2.inferred_form0 + " " + alternative_suffix
    step2["inferred_label0"] = step2.inferred_form0.map(form_to_label)
    step2["inferred_label2"] = step2.inferred_form2.map(form_to_label)

    # ignore where label2 == label
    step3 = step2.loc[step2.inferred_label2 != step2.label]

    # filter by frequency
    step4 = pd.merge(step3, label_counts.rename("label_count"),
                     left_on="label", right_index=True, how="inner")
    step4 = pd.merge(step4, label_counts.rename("label0_count"),
                     left_on="inferred_label0", right_index=True, how="inner")
    step4 = pd.merge(step4, label_counts.rename("label2_count"),
                     left_on="inferred_label2", right_index=True, how="inner")
    step4 = step4[(step4.label_count >= min_frequency)
                  & (step4.label0_count >= min_frequency)
                  & (step4.label2_count >= min_frequency)]

    step4 = step4.rename(columns={
        "label": "label1",
        "inferred_label0": "label0",
        "inferred_label2": "label2",

        "description": "form1",
        "inferred_form0": "form0",
        "inferred_form2": "form2",

        "instance_idx": "instance_idx1",
    }).drop(columns=["label0_count", "label_count", "label2_count"])

    # Retrieve all instances of the base variant. label2 tokens are not
    # cross-joined: only label2/form2 are reported, and joining its instances
    # would just repeat each (base token, inflected token) row once per
    # label2 token, unevenly weighting the stimuli.
    base_tokens = cut_phonemic_forms.reset_index().rename(
        columns={"label": "label0", "description": "form0",
                 "instance_idx": "instance_idx0"})
    fc_cross = pd.merge(step4, base_tokens, on=["label0", "form0"], how="left")

    fc_cross = fc_cross[["label0", "label1", "label2",
                         "form0", "form1", "form2",
                         "instance_idx0", "instance_idx1"]]

    # now prepare a single flat structure relating base (label0) to inflected
    # (label1 w.l.o.g.)
    # we will record frequency of inflection to label1 vs label2
    fc_cross = fc_cross.rename(columns={
        "label0": "base",
        "label1": "inflected",
        "label2": "inflected2",

        "form0": "base_phones",
        "form1": "inflected_phones",
        "form2": "inflected2_phones",

        "instance_idx0": "base_instance_idx",
        "instance_idx1": "inflected_instance_idx",
    })
    fc_cross["inflection"] = "FC-" + "_".join(fc_pair)
    fc_cross["post_divergence"] = observed_suffix
    allomorph_map = {base_phones: allomorph_guesser(base_phones.split())
                     for base_phones in fc_cross.base_phones.unique()}
    # "strong": the observed suffix is the one the allomorphy rule predicts for this base.
    fc_cross["strong"] = fc_cross.base_phones.map(allomorph_map) == observed_suffix

    fc_cross["base_idx"] = fc_cross.base.map(label_to_idx)
    fc_cross["inflected_idx"] = fc_cross.inflected.map(label_to_idx)
    fc_cross["exclude_main"] = True

    return fc_cross

In [None]:
# Forced-choice suffix pairs, each paired with the allomorphy rule used to
# decide which member the base form predicts.
nns_vbz_choices = [("Z", "S"), ("Z", "IH Z"), ("S", "IH Z")]
past_choices = [("D", "T"), ("D", "IH D"), ("T", "IH D")]
fc_pairs = ([(choice, guess_nns_vbz_allomorph) for choice in nns_vbz_choices]
            + [(choice, guess_past_allomorph) for choice in past_choices])

fc_cross_instances = pd.concat(
    [get_forced_choice_cross_instances(choice, guesser) for choice, guesser in fc_pairs])

In [None]:
# Append the forced-choice rows (flagged exclude_main=True) to the main cross instances.
all_cross_instances = pd.concat([all_cross_instances, fc_cross_instances], axis=0)

## False friend production

In [None]:
def compute_false_friends():
    """
    Build false-friend stimuli for the most common allomorphs of each inflection.

    The "random" baseline and "NOT-latin" are skipped. For each remaining
    inflection, up to three of its most frequent `most_common_allomorph`
    values are taken. `analogy.prepare_false_friends` is called once per
    (inflection, allomorph), avoiding words with "POS" or the inflection
    itself. NNS and VBZ additionally avoid each other, since both use
    the same suffix rule.

    Returns a dict mapping (inflection, post_divergence) to the returned frame.
    Failures are reported (with the exception) and skipped, so one bad
    allomorph does not abort the rest.
    """
    top_allomorphs = (
        most_common_allomorphs
        [~most_common_allomorphs.inflection.isin(("random", "NOT-latin"))]
        .groupby("inflection").most_common_allomorph
        # value_counts is sorted by descending count; keep the top three.
        .apply(lambda xs: xs.value_counts().head(3))
        .index
    )

    false_friends_dfs = {}
    for inflection, post_divergence in tqdm(top_allomorphs):
        avoid_inflections = {"POS", inflection}
        if inflection == "NNS":
            avoid_inflections.add("VBZ")
        elif inflection == "VBZ":
            avoid_inflections.add("NNS")

        try:
            false_friends_dfs[inflection, post_divergence] = \
                analogy.prepare_false_friends(
                    inflection_results_df,
                    inflection_instance_df,
                    cut_phonemic_forms,
                    post_divergence,
                    # sorted: deterministic order regardless of string hashing
                    avoid_inflections=sorted(avoid_inflections))
        except Exception as exc:
            # Best-effort per allomorph, but never swallow KeyboardInterrupt /
            # SystemExit (as a bare `except:` would), and say what went wrong.
            print("Failed for", inflection, post_divergence,
                  f"({type(exc).__name__}: {exc})")

    return false_friends_dfs

false_friends_dfs = compute_false_friends()

In [None]:
# Stack the per-(inflection, post_divergence) frames; the dict keys become the
# outer index levels, and droplevel(-1) discards each frame's own row index.
# NOTE(review): pd.concat raises ValueError if every prepare_false_friends call failed.
false_friends_df = pd.concat(false_friends_dfs, names=["inflection", "post_divergence"]).droplevel(-1)

# manually exclude some cases that don't get filtered out, often just because they're too
# low frequency for both true base and inflected form to appear

# share exclusion list for NNS and VBZ since we have experiments relating these two,
# so this is any false-friend for which there is a phonologically identical "base"
# that could instantiate a VBZ or NNS inflection
exclude_NNS_VBZ = ("adds americans arabs assyrians berries carlyle's childs christians "
                   "counties cruise dares dealings delawares europeans excellencies "
                   "fins fours galleries gaze germans indians isles maids mary's negroes "
                   "nuns peas phrase pyes reflections rodgers romans russians simpsons "
                   "spaniards sundays vickers weeds wigwams williams "
                   "jews odds news hose dis yes ice cease peace s us "
                    
                   "greeks lapse mix philips trunks its "
                    
                   "breeches occurrences personages").split()
# Per-inflection lists of `inflected` words to drop (inflections not listed keep everything).
false_friends_manual_exclude = {
    "NNS": exclude_NNS_VBZ,
    "VBZ": exclude_NNS_VBZ,
    "VBD": ("armored bald bard counseled crude dared enquired healed knowed legged "
            "mourned natured renowned rude second ward wild willed withered hauled "

            "tract wrapped fitted hearted heralded intrusted knitted wretched").split(),
    "VBG": ("ceiling daring fleeting morning roaming wasting weaving weighing "
            "whining willing chuckling kneeling sparkling startling").split()
}

# Within each inflection group (xs.name is the group's inflection), drop rows whose
# inflected word is on that inflection's exclusion list; droplevel(0) removes the
# extra outer index level that groupby-apply adds.
false_friends_df = false_friends_df.groupby("inflection", as_index=False).apply(
    lambda xs: xs[~xs.inflected.isin(false_friends_manual_exclude.get(xs.name, []))]).droplevel(0)

# exclude the (quite interesting) cases where the "base" and "inflected" form are
# actually orthographically matched, and we're seeing the divergence due to a pronunciation
# variant (e.g. don't as D OW N vs D OW N T)
false_friends_df = false_friends_df[false_friends_df.base != false_friends_df.inflected]

false_friends_df

In [None]:
# Expected allomorph of each false friend's base form under the regular rule:
# NNS/VBZ use the plural/3sg rule, VBD the past-tense rule. Rows of other
# inflections (e.g. VBG) get NaN here, so their `strong` flag below is False.
false_friends_df.loc[["NNS", "VBZ"], "strong_expected"] = false_friends_df.loc[["NNS", "VBZ"]].apply(lambda xs: guess_nns_vbz_allomorph(xs.base_form.split(" ")), axis=1)
false_friends_df.loc[["VBD"], "strong_expected"] = false_friends_df.loc[["VBD"]].apply(lambda xs: guess_past_allomorph(xs.base_form.split(" ")), axis=1)
# "strong": the observed post-divergence string is what the rule predicts for the base.
false_friends_df["strong"] = false_friends_df.index.get_level_values("post_divergence") == false_friends_df.strong_expected

### Prepare false-friends cross product and merge

In [None]:
# Expand each false-friend type pair into all (base token, inflected token)
# combinations whose phonemic forms match the pair's base_form / inflected_form.
phonemic_tokens = cut_phonemic_forms.reset_index()
base_tokens = phonemic_tokens.rename(
    columns={"label": "base", "description": "base_form",
             "instance_idx": "base_instance_idx"})
inflected_tokens = phonemic_tokens.rename(
    columns={"label": "inflected", "description": "inflected_form",
             "instance_idx": "inflected_instance_idx"})

cross_false_friends_df = (
    false_friends_df.reset_index()
    .merge(base_tokens, on=["base", "base_form"], how="left")
    .merge(inflected_tokens, on=["inflected", "inflected_form"], how="left")
    # update to match all_cross_instances schema
    .rename(columns={"base_form": "base_phones",
                     "inflected_form": "inflected_phones"})
)
label_positions = {label: position for position, label in enumerate(state_space_spec.labels)}
cross_false_friends_df["base_idx"] = cross_false_friends_df.base.map(label_positions)
cross_false_friends_df["inflected_idx"] = cross_false_friends_df.inflected.map(label_positions)
cross_false_friends_df["is_regular"] = True

# Relabel rows as "<inflection>-FF-<post_divergence>" so they form their own conditions.
ff_prefix = cross_false_friends_df.inflection + "-FF-"
cross_false_friends_df["inflection"] = ff_prefix.str.cat(cross_false_friends_df.post_divergence, sep="")
cross_false_friends_df["exclude_main"] = True
cross_false_friends_df

In [None]:
# Append the false-friend rows (flagged exclude_main=True) to the cross instances.
all_cross_instances = pd.concat([all_cross_instances, cross_false_friends_df], axis=0)

## Save

In [None]:
# Persist the subsampled state-space spec so downstream evaluations use the same instances.
state_space_spec.to_hdf5(f"{output_dir}/state_space_spec.h5")

In [None]:
# Write the type-, token- and cross-level stimulus tables as parquet files.
for table, file_stem in ((inflection_results_df, "inflection_results"),
                         (inflection_instance_df, "inflection_instances"),
                         (all_cross_instances, "all_cross_instances")):
    table.to_parquet(f"{output_dir}/{file_stem}.parquet")

In [None]:
# Human-readable CSV exports of the false-friend types and per-pair majority allomorphs.
false_friends_df.to_csv(f"{output_dir}/false_friends.csv")
most_common_allomorphs.to_csv(f"{output_dir}/most_common_allomorphs.csv")