Prepare stimuli shared across all analogy evaluations.

In [2]:
from collections import Counter, defaultdict
import functools
import pickle

import numpy as np
import pandas as pd
import torch
from tqdm.auto import tqdm

from src.analysis.state_space import StateSpaceAnalysisSpec
from src.analysis import analogy

In [43]:
# --- Inputs ---------------------------------------------------------------
# HDF5 file from which the word-level state space spec is loaded.
state_space_specs_path = "outputs/state_space_specs/librispeech-train-clean-100/w2v2_8/state_space_specs.h5"

# Pickle holding, for each word, a mapping keyed by part-of-speech tag.
pos_counts_path = "data/pos_counts.pkl"

# --- Outputs --------------------------------------------------------------
# Directory that receives every file written in the "Save" section.
output_dir = "."

# --- Sampling -------------------------------------------------------------
seed = 1234

# A word is kept only if its instance count is strictly greater than this.
min_samples_per_word = 5
# Passed to `subsample_instances` when the spec is loaded.
max_samples_per_word = 100

# Inflection categories requested from `analogy.get_inflection_df`.
inflection_targets = ["VBD", "VBZ", "VBG", "NNS", "NOT-latin"]

In [5]:
# Load the "word" spec and subsample it in one step.
# presumably this caps each word at max_samples_per_word instances — confirm
# against StateSpaceAnalysisSpec.subsample_instances.
state_space_spec = (
    StateSpaceAnalysisSpec
    .from_hdf5(state_space_specs_path, "word")
    .subsample_instances(max_samples_per_word)
)

In [19]:
# NOTE(review): unpickling can execute arbitrary code; only point
# pos_counts_path at a trusted, locally produced file.
with open(pos_counts_path, mode="rb") as pos_counts_file:
    pos_counts = pickle.load(pos_counts_file)

In [6]:
# Phoneme-level cuts only, without the frame onset/offset columns.
phoneme_cuts = state_space_spec.cuts.xs("phoneme", level="level")
cuts_df = phoneme_cuts.drop(columns=["onset_frame_idx", "offset_frame_idx"])

# Position of each label in the spec's label list.
label_to_idx = {label: idx for idx, label in enumerate(state_space_spec.labels)}
cuts_df["label_idx"] = cuts_df.index.get_level_values("label").map(label_to_idx)

# Running position of each phoneme row within its (label, instance_idx) group.
cuts_df["frame_idx"] = cuts_df.groupby(["label", "instance_idx"]).cumcount()

# Re-key by (label, instance_idx, frame_idx) and sort for fast .loc lookups.
cuts_df = (
    cuts_df
    .reset_index()
    .set_index(["label", "instance_idx", "frame_idx"])
    .sort_index()
)

In [7]:
# One space-joined string of phoneme descriptions per (label, instance_idx).
cut_phonemic_forms = cuts_df.groupby(["label", "instance_idx"])["description"].agg(" ".join)

In [28]:
# Frame-span table of the spec; the loops below query it by `label` and read
# its `instance_idx` column to enumerate the instances of a word.
ss_spans = state_space_spec.target_frame_spans_df

## Set up main stimuli

In [11]:
# Keep only the words that occur often enough.
# NOTE(review): the comparison is strict, so a word with exactly
# min_samples_per_word instances is dropped — confirm this is intended.
label_counts = state_space_spec.label_counts
labels = set(label_counts[label_counts > min_samples_per_word].index)

# Type-level table of (inflected, base) pairs for the requested inflections.
inflection_results_df = analogy.get_inflection_df(inflection_targets, labels)

# Attach the position of each word in the spec's label list.
label_to_idx = {label: idx for idx, label in enumerate(state_space_spec.labels)}
inflection_results_df["base_idx"] = inflection_results_df["base"].map(label_to_idx)
inflection_results_df["inflected_idx"] = inflection_results_df["inflected"].map(label_to_idx)
inflection_results_df

Unnamed: 0_level_0,inflected,base,is_regular,base_idx,inflected_idx
inflection,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
VBD,honored,honor,True,1254,10249
VBD,halted,halt,True,7053,9015
VBD,eyed,eye,True,1024,3838
VBD,humored,humor,True,7899,6691
VBD,felt,feel,False,630,1482
...,...,...,...,...,...
NOT-latin,indistinct,distinct,True,2207,8621
NOT-latin,indefinitely,definitely,True,5070,52
NOT-latin,insufficient,sufficient,True,1424,8630
NOT-latin,involuntary,voluntary,True,1938,8174


In [12]:
# Add on random word pair baseline: as many pairs as the largest real
# inflection category has.
num_random_word_pairs = inflection_results_df.groupby("inflection").size().max()

# Draw from the frequency-filtered `labels` themselves. Sampling a position in
# range(len(labels)) and decoding it through state_space_spec.labels would pick
# from the first len(labels) entries of the *unfiltered* label list instead.
# The set is sorted because set iteration order is not stable across runs.
candidate_labels = sorted(labels)
num_candidates = len(candidate_labels)

# Seeded generator so the baseline is reproducible.
rng = np.random.default_rng(seed)
base_positions = rng.integers(num_candidates, size=num_random_word_pairs)
# A non-zero offset (mod num_candidates) guarantees the second word differs
# from the first, so no degenerate "word paired with itself" rows are produced.
# Requires at least two candidate labels.
offsets = rng.integers(1, num_candidates, size=num_random_word_pairs)
inflected_positions = (base_positions + offsets) % num_candidates

base_words = pd.Series([candidate_labels[pos] for pos in base_positions])
inflected_words = pd.Series([candidate_labels[pos] for pos in inflected_positions])

# Same schema as inflection_results_df; *_idx index into state_space_spec.labels.
label_to_idx = {label: idx for idx, label in enumerate(state_space_spec.labels)}
random_word_pairs = pd.DataFrame({
    "base_idx": base_words.map(label_to_idx),
    "inflected_idx": inflected_words.map(label_to_idx),
    "base": base_words,
    "inflected": inflected_words,
    "is_regular": False,
    "inflection": "random",
}).set_index("inflection")
random_word_pairs

Unnamed: 0_level_0,base_idx,inflected_idx,base,inflected,is_regular
inflection,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
random,8272,5559,wholeheartedly,akin,False
random,6987,3822,turkey,ached,False
random,2588,8844,faltering,gap,False
random,8946,9033,shimmering,whittled,False
random,2981,8851,fox's,majestic,False
...,...,...,...,...,...
random,9199,1198,reigned,intention,False
random,6086,6066,ethereal,metaphors,False
random,8662,7271,sterile,affect,False
random,2285,9184,cast,reed,False


In [13]:
# Append the "random" baseline rows below the real inflection pairs so that
# both go through the same downstream steps.
inflection_results_df = pd.concat([inflection_results_df, random_word_pairs])

## Prepare token-level features

### NNS/VBZ ambiguity

In [20]:
def is_noun_ambiguous(row):
    """
    Return True if the base or the inflected word of `row` has a "VERB"
    entry in `pos_counts`.

    `row` must expose `.base` and `.inflected`; both words must be keys of
    `pos_counts` (a missing word raises KeyError).
    """
    base_tags = pos_counts[row.base]
    inflected_tags = pos_counts[row.inflected]
    return "VERB" in base_tags or "VERB" in inflected_tags
# Flag NNS pairs whose base or plural form is also attested as a verb.
nns_rows = inflection_results_df.loc["NNS"]
inflection_results_df.loc["NNS", "base_ambig_NN_VB"] = nns_rows.apply(is_noun_ambiguous, axis=1)
# Spot check: ten random pairs from each ambiguity class.
inflection_results_df.loc["NNS"].groupby("base_ambig_NN_VB").sample(10)

Unnamed: 0_level_0,inflected,base,is_regular,base_idx,inflected_idx,base_ambig_NN_VB
inflection,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
NNS,gowns,gown,True,11452,1265,False
NNS,weapons,weapon,True,11000,7153,False
NNS,games,game,True,2764,3388,False
NNS,husbands,husband,True,3405,11695,False
NNS,luxuries,luxury,True,6154,13965,False
NNS,yards,yard,True,3377,10792,False
NNS,impulses,impulse,True,1742,12185,False
NNS,passions,passion,True,1588,7961,False
NNS,agonies,agony,True,1909,15665,False
NNS,orphans,orphan,True,14353,12057,False


In [21]:
def is_verb_ambiguous(row):
    """
    Return True if the base or the inflected word of `row` has a "NOUN"
    entry in `pos_counts`.

    `row` must expose `.base` and `.inflected`; both words must be keys of
    `pos_counts` (a missing word raises KeyError).
    """
    base_tags = pos_counts[row.base]
    inflected_tags = pos_counts[row.inflected]
    return "NOUN" in base_tags or "NOUN" in inflected_tags
# Flag VBZ pairs whose base or 3sg form is also attested as a noun.
vbz_rows = inflection_results_df.loc["VBZ"]
inflection_results_df.loc["VBZ", "base_ambig_NN_VB"] = vbz_rows.apply(is_verb_ambiguous, axis=1)
# Spot check: ten random pairs from each ambiguity class.
inflection_results_df.loc["VBZ"].groupby("base_ambig_NN_VB").sample(10)

Unnamed: 0_level_0,inflected,base,is_regular,base_idx,inflected_idx,base_ambig_NN_VB
inflection,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
VBZ,serves,serve,True,3209,4609,False
VBZ,involves,involve,True,8454,4761,False
VBZ,belongs,belong,True,1615,1616,False
VBZ,describes,describe,True,3330,4341,False
VBZ,follows,follow,True,695,4468,False
VBZ,speaks,speak,True,2224,6952,False
VBZ,owns,own,True,1222,13193,False
VBZ,pleases,please,True,3692,3087,False
VBZ,reminds,remind,True,7425,8557,False
VBZ,eats,eat,True,2919,11348,False


### Post-divergence analysis

In [22]:
@functools.lru_cache
def _get_base_forms(base_label: str) -> frozenset[tuple[str, ...]]:
    """
    Collect the distinct phoneme sequences attested for `base_label`.

    Each instance of the word in `cuts_df` contributes the tuple of its
    `description` values; duplicates collapse in the returned frozenset.
    Raises KeyError when `base_label` is not present in `cuts_df`.
    """
    instance_groups = cuts_df.loc[base_label].groupby("instance_idx")
    return frozenset(tuple(frames.description) for _, frames in instance_groups)

In [23]:
@functools.lru_cache
def _get_phonological_divergence(base_forms: frozenset[tuple[str, ...]],
                                 inflected_form: tuple[str, ...]) -> tuple[int, tuple[str, ...]]:
    """
    Locate where an inflected pronunciation stops matching its base.

    Args:
        base_forms: attested pronunciations of the base word, each a tuple of
            phones. Must be non-empty (an empty set raises ValueError).
        inflected_form: pronunciation of one inflected token.

    Returns:
        `(divergence_point, post_divergence)`. `divergence_point` is the
        length of the longest prefix that `inflected_form` shares with any
        base pronunciation, and `post_divergence` is `inflected_form` from
        that position onward. When the inflected form is identical to, or a
        prefix of, a base pronunciation, nothing follows the shared prefix
        and `post_divergence` is the empty tuple.
    """
    def shared_prefix_length(base_phones: tuple[str, ...]) -> int:
        # Count matching phones from the left; zip stops at the shorter form,
        # so a full-prefix match yields the whole shorter length rather than
        # falling one short.
        length = 0
        for base_phone, inflected_phone in zip(base_phones, inflected_form):
            if base_phone != inflected_phone:
                break
            length += 1
        return length

    # The best-matching base pronunciation defines the divergence point.
    phono_divergence_point = max(map(shared_prefix_length, base_forms))

    post_divergence = inflected_form[phono_divergence_point:]
    return phono_divergence_point, post_divergence

In [24]:
def get_phonological_divergence(base_label, inflected_label, inflected_instance_idx):
    """
    Return the phones of one inflected token and the part that follows its
    divergence from the base word.

    Args:
        base_label: the base word; its pronunciations come from `_get_base_forms`.
        inflected_label: the inflected word.
        inflected_instance_idx: which instance of the inflected word to use.

    Returns:
        `(inflected_phones, post_divergence)`, both tuples of phone strings.
        If the base word or the requested inflected instance is missing from
        `cuts_df`, both tuples are empty, so callers that unpack the result
        keep working.
    """
    try:
        base_phon_forms = _get_base_forms(base_label)
        inflected_phones = tuple(cuts_df.loc[inflected_label].loc[inflected_instance_idx].description)
    except KeyError:
        # Keep the two-element shape of the normal return value. An empty
        # Counter here would make `a, b = get_phonological_divergence(...)` raise.
        return (), ()

    _, div_content = _get_phonological_divergence(base_phon_forms, inflected_phones)
    return inflected_phones, div_content

In [31]:
# One record per token of every inflected word, carrying its phones and the
# material after the divergence point (both as space-joined strings).
inflection_instances = []

for inflection, row in tqdm(inflection_results_df.iterrows(), total=len(inflection_results_df)):
    token_idxs = ss_spans.query("label == @row.inflected").instance_idx
    for inflected_instance_idx in token_idxs:
        phones, suffix = get_phonological_divergence(
            row.base, row.inflected, inflected_instance_idx)

        inflection_instances.append(dict(
            inflection=inflection,
            base=row.base,
            inflected=row.inflected,
            inflected_instance_idx=inflected_instance_idx,
            inflected_phones=" ".join(phones),
            post_divergence=" ".join(suffix),
        ))

  0%|          | 0/4632 [00:00<?, ?it/s]

In [32]:
# Token-level table, left-joined with the type-level columns of its
# (inflection, base, inflected) pair.
inflection_instance_df = pd.DataFrame(inflection_instances).merge(
    inflection_results_df.reset_index(),
    on=["inflection", "base", "inflected"],
    how="left",
)
inflection_instance_df

Unnamed: 0,inflection,base,inflected,inflected_instance_idx,inflected_phones,post_divergence,is_regular,base_idx,inflected_idx,base_ambig_NN_VB
0,VBD,honor,honored,0,AA N ER D,D,True,1254,10249,
1,VBD,honor,honored,1,AA N ER D,D,True,1254,10249,
2,VBD,honor,honored,2,AA N AW ER D,AW ER D,True,1254,10249,
3,VBD,honor,honored,3,AA N ER D,D,True,1254,10249,
4,VBD,honor,honored,4,AA N ER D,D,True,1254,10249,
...,...,...,...,...,...,...,...,...,...,...
120144,random,shouted,daring,17,D EH R IH NG,D EH R IH NG,False,2739,343,
120145,random,shouted,daring,18,D EH R IH NG,D EH R IH NG,False,2739,343,
120146,random,shouted,daring,19,D EH R IH NG,D EH R IH NG,False,2739,343,
120147,random,shouted,daring,20,D EH R IH NG,D EH R IH NG,False,2739,343,


In [33]:
# For every (inflection, base) group, the post-divergence string observed
# most often among its tokens.
most_common_allomorphs = (
    inflection_instance_df
    .groupby(["inflection", "base"])["post_divergence"]
    .apply(lambda suffixes: suffixes.value_counts().idxmax())
    .rename("most_common_allomorph")
    .reset_index()
)

## Build full cross product of stimuli

In [34]:
# Token records for both sides of every pair: one list for the inflected
# word's instances, one for the base word's instances.
inflection_cross_instances = []
base_cross_instances = []

for inflection, row in tqdm(inflection_results_df.iterrows(), total=len(inflection_results_df)):
    pair_info = {"inflection": inflection, "base": row.base, "inflected": row.inflected}

    # Inflected side.
    inflected_instance_idxs = ss_spans.query("label == @row.inflected").instance_idx
    inflected_forms = cut_phonemic_forms.loc[row.inflected]
    inflection_cross_instances.extend(
        {**pair_info,
         "inflected_instance_idx": instance_idx,
         "inflected_phones": inflected_forms.loc[instance_idx]}
        for instance_idx in inflected_instance_idxs
    )

    # Base side.
    base_instance_idxs = ss_spans.query("label == @row.base").instance_idx
    base_forms = cut_phonemic_forms.loc[row.base]
    base_cross_instances.extend(
        {**pair_info,
         "base_instance_idx": instance_idx,
         "base_phones": base_forms.loc[instance_idx]}
        for instance_idx in base_instance_idxs
    )

  0%|          | 0/4632 [00:00<?, ?it/s]

In [35]:
# Attach each inflected token's post-divergence string.
merge_on = ["inflection", "base", "inflected", "inflected_instance_idx"]
inflection_cross_instances_df = pd.DataFrame(inflection_cross_instances).merge(
    inflection_instance_df[merge_on + ["post_divergence"]],
    on=merge_on,
)

# Cross every base token with every inflected token of the same pair.
pair_keys = ["inflection", "base", "inflected"]
all_cross_instances = pd.DataFrame(base_cross_instances).merge(
    inflection_cross_instances_df,
    on=pair_keys,
    how="outer",
)

# Bring in the type-level columns; validate that each pair occurs once in
# the type-level table.
all_cross_instances = inflection_results_df.reset_index().merge(
    all_cross_instances,
    on=pair_keys,
    validate="1:m",
)

# Rows built here are the main stimuli (not excluded).
all_cross_instances["exclude_main"] = False
all_cross_instances

Unnamed: 0,inflection,inflected,base,is_regular,base_idx,inflected_idx,base_ambig_NN_VB,base_instance_idx,base_phones,inflected_instance_idx,inflected_phones,post_divergence,exclude_main
0,VBD,honored,honor,True,1254,10249,,0,AA N ER,0,AA N ER D,D,False
1,VBD,honored,honor,True,1254,10249,,0,AA N ER,1,AA N ER D,D,False
2,VBD,honored,honor,True,1254,10249,,0,AA N ER,2,AA N AW ER D,AW ER D,False
3,VBD,honored,honor,True,1254,10249,,0,AA N ER,3,AA N ER D,D,False
4,VBD,honored,honor,True,1254,10249,,0,AA N ER,4,AA N ER D,D,False
...,...,...,...,...,...,...,...,...,...,...,...,...,...
6582067,random,daring,shouted,False,2739,343,,64,SH AW T AH D,17,D EH R IH NG,D EH R IH NG,False
6582068,random,daring,shouted,False,2739,343,,64,SH AW T AH D,18,D EH R IH NG,D EH R IH NG,False
6582069,random,daring,shouted,False,2739,343,,64,SH AW T AH D,19,D EH R IH NG,D EH R IH NG,False
6582070,random,daring,shouted,False,2739,343,,64,SH AW T AH D,20,D EH R IH NG,D EH R IH NG,False


## False friend production

In [37]:
def compute_false_friends():
    """
    Build false-friend tables for the most common allomorphs of each inflection.

    For every inflection other than "random" and "NOT-latin", the three most
    frequent values of `most_common_allomorph` are selected, and
    `analogy.prepare_false_friends` is called once per
    (inflection, allomorph) combination.

    Returns:
        dict keyed by `(inflection, post_divergence)`. Combinations whose
        preparation raised are reported on stdout and left out of the dict.
    """
    false_friends_dfs = {}
    real_inflections = most_common_allomorphs[
        ~most_common_allomorphs.inflection.isin(("random", "NOT-latin"))]
    # MultiIndex of (inflection, allomorph) for the top three allomorphs.
    inflection_allomorph_grouper = real_inflections \
        .groupby("inflection").most_common_allomorph \
        .apply(lambda xs: xs.value_counts().head(3)).index
    for inflection, post_divergence in tqdm(inflection_allomorph_grouper):
        avoid_inflections = {"POS", inflection}
        # NNS and VBZ are each added to the other's avoid list.
        if inflection == "NNS":
            avoid_inflections.add("VBZ")
        elif inflection == "VBZ":
            avoid_inflections.add("NNS")
        # sorted() rather than list(): set order varies between interpreter runs.
        avoid_inflections = sorted(avoid_inflections)

        try:
            false_friends_dfs[inflection, post_divergence] = \
                analogy.prepare_false_friends(
                    inflection_results_df,
                    inflection_instance_df,
                    cut_phonemic_forms,
                    post_divergence,
                    avoid_inflections=avoid_inflections)
        except Exception as exc:
            # Best effort per combination, but do not swallow KeyboardInterrupt
            # or SystemExit, and report why the combination failed.
            print("Failed for", inflection, post_divergence, "--", repr(exc))
            continue

    return false_friends_dfs

false_friends_dfs = compute_false_friends()

  0%|          | 0/12 [00:00<?, ?it/s]

In [38]:
# Stack the per-(inflection, allomorph) tables; the dict keys become the index
# levels "inflection" and "post_divergence", and the innermost (original row)
# index level is dropped.
false_friends_df = pd.concat(false_friends_dfs, names=["inflection", "post_divergence"]).droplevel(-1)

# Manually exclude some cases that don't get filtered out, often just because they're too
# low frequency for both the true base and the inflected form to appear.

# Share one exclusion list between NNS and VBZ, since we have experiments relating these two:
# it holds any false friend for which there is a phonologically identical "base"
# that could instantiate a VBZ or NNS inflection.
exclude_NNS_VBZ = ("adds americans arabs assyrians berries carlyle's childs christians "
                   "counties cruise dares dealings delawares europeans excellencies "
                   "fins fours galleries gaze germans indians isles maids mary's negroes "
                   "nuns peas phrase pyes reflections rodgers romans russians simpsons "
                   "spaniards sundays vickers weeds wigwams williams "
                   "jews odds news hose dis yes ice cease peace s us "
                    
                   "greeks lapse mix philips trunks its "
                    
                   "breeches occurrences personages").split()
# Inflection -> inflected words to drop; inflections without an entry drop nothing.
false_friends_manual_exclude = {
    "NNS": exclude_NNS_VBZ,
    "VBZ": exclude_NNS_VBZ,
    "VBD": ("armored bald bard counseled crude dared enquired healed knowed legged "
            "mourned natured renowned rude second ward wild willed withered "

            "tract wrapped fitted hearted heralded intrusted knitted wretched").split(),
    "VBG": ("ceiling daring fleeting morning roaming wasting weaving weighing "
            "whining willing chuckling kneeling sparkling startling").split()
}

# Within each inflection group (xs.name is the group's inflection), keep only the
# rows whose inflected word is not on that inflection's exclusion list.
false_friends_df = false_friends_df.groupby("inflection", as_index=False).apply(
    lambda xs: xs[~xs.inflected.isin(false_friends_manual_exclude.get(xs.name, []))]).droplevel(0)

# Exclude the (quite interesting) cases where the "base" and "inflected" form are
# actually orthographically matched, and we're seeing the divergence due to a pronunciation
# variant (e.g. don't as D OW N vs D OW N T).
false_friends_df = false_friends_df[false_friends_df.base != false_friends_df.inflected]

false_friends_df

Unnamed: 0_level_0,Unnamed: 1_level_0,base,base_form,inflected,inflected_form
inflection,post_divergence,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
NNS,Z,adam,AE D AH M,adams,AE D AH M Z
NNS,Z,afterward,AE F T ER W ER D,afterwards,AE F T ER W ER D Z
NNS,Z,an,AE N,anne's,AE N Z
NNS,Z,eh,EH,as,EH Z
NNS,Z,backward,B AE K W ER D,backwards,B AE K W ER D Z
...,...,...,...,...,...
VBZ,S,victoria,V IH K T AO R IY AH,victorious,V IH K T AO R IY AH S
VBZ,S,when,W EH N,whence,W EH N S
VBZ,S,when,HH W EH N,whence,HH W EH N S
VBZ,S,we're,W ER,worse,W ER S


In [39]:
# Phone classes used for the final-phoneme checks below.
SIBILANTS = {"S", "Z", "SH", "CH", "JH", "ZH"}
ALVEOLAR_STOPS = {"T", "D"}
# Voiceless obstruents, defined ONCE and shared by both guessers. Two separate
# definitions under the same name let the later one (which lacked "T") silently
# replace the earlier one, so bases ending in T were predicted as "Z".
# Sharing one set is safe because each guesser tests its special class
# (sibilants / alveolar stops) before consulting this set.
VOICELESS = {"P", "T", "K", "F", "TH", "S", "SH", "CH"}


def guess_nns_vbz_allomorph(base_phones):
    """
    Given a list of CMUDICT phones for a base form,
    return the 'expected' post-divergence allomorph
    ("S", "Z", or "IH Z") for the English plural / 3sg verb.

    Only the final phone is inspected; an empty list raises IndexError.
    """
    last_phone = base_phones[-1]

    if last_phone in SIBILANTS:
        # e.g., final 'CH' -> "IH Z"
        return "IH Z"
    if last_phone in VOICELESS:
        # e.g., final 'K', 'P', 'T' -> "S"
        return "S"
    # everything else is treated as voiced => "Z"
    return "Z"


def guess_past_allomorph(base_phones):
    """
    Given a list of CMUDICT phones for a base form,
    return the 'expected' post-divergence allomorph
    ("T", "D", or "IH D") for the English past tense.

    Only the final phone is inspected; an empty list raises IndexError.
    """
    last_phone = base_phones[-1]

    if last_phone in ALVEOLAR_STOPS:
        # e.g., "want" -> "wanted" => "IH D"
        return "IH D"
    if last_phone in VOICELESS:
        # e.g., "jump" -> "jumped" => "T"
        return "T"
    # everything else is treated as voiced => "D"
    return "D"


# Expected regular allomorph, predicted from the final phone of the base form.
plural_like = ["NNS", "VBZ"]
false_friends_df.loc[plural_like, "strong_expected"] = false_friends_df.loc[plural_like].apply(
    lambda pair: guess_nns_vbz_allomorph(pair.base_form.split(" ")), axis=1)
past_like = ["VBD"]
false_friends_df.loc[past_like, "strong_expected"] = false_friends_df.loc[past_like].apply(
    lambda pair: guess_past_allomorph(pair.base_form.split(" ")), axis=1)

# "strong" rows are those whose observed post-divergence equals the prediction;
# rows of inflections without a prediction hold NaN there and compare False.
observed_allomorph = false_friends_df.index.get_level_values("post_divergence")
false_friends_df["strong"] = observed_allomorph == false_friends_df.strong_expected

### Prepare false-friends cross product and merge

In [40]:
# Token-level view of the phonemic forms, renamed once for each side of a pair.
token_forms = cut_phonemic_forms.reset_index()
base_tokens = token_forms.rename(
    columns={"label": "base", "description": "base_form",
             "instance_idx": "base_instance_idx"})
inflected_tokens = token_forms.rename(
    columns={"label": "inflected", "description": "inflected_form",
             "instance_idx": "inflected_instance_idx"})

# Expand every false-friend pair to the tokens whose phonemic form matches
# exactly, first on the base side and then on the inflected side.
cross_false_friends_df = (
    false_friends_df.reset_index()
    .merge(base_tokens, on=["base", "base_form"], how="left")
    .merge(inflected_tokens, on=["inflected", "inflected_form"], how="left")
)

# update to match all_cross_instances schema
cross_false_friends_df = cross_false_friends_df.rename(
    columns={"base_form": "base_phones",
             "inflected_form": "inflected_phones"})
label_to_idx = {label: idx for idx, label in enumerate(state_space_spec.labels)}
cross_false_friends_df["base_idx"] = cross_false_friends_df.base.map(label_to_idx)
cross_false_friends_df["inflected_idx"] = cross_false_friends_df.inflected.map(label_to_idx)
cross_false_friends_df["is_regular"] = True

# Relabel the inflection as "<inflection>-FF-<allomorph>" and mark the rows as
# excluded from the main stimuli.
cross_false_friends_df["inflection"] = (cross_false_friends_df.inflection + "-FF-").str.cat(cross_false_friends_df.post_divergence, sep="")
cross_false_friends_df["exclude_main"] = True
cross_false_friends_df

Unnamed: 0,inflection,post_divergence,base,base_phones,inflected,inflected_phones,strong_expected,strong,base_instance_idx,inflected_instance_idx,base_idx,inflected_idx,is_regular,exclude_main
0,NNS-FF-Z,Z,adam,AE D AH M,adams,AE D AH M Z,Z,True,0,0,6076,14620,True,True
1,NNS-FF-Z,Z,adam,AE D AH M,adams,AE D AH M Z,Z,True,0,1,6076,14620,True,True
2,NNS-FF-Z,Z,adam,AE D AH M,adams,AE D AH M Z,Z,True,0,2,6076,14620,True,True
3,NNS-FF-Z,Z,adam,AE D AH M,adams,AE D AH M Z,Z,True,0,3,6076,14620,True,True
4,NNS-FF-Z,Z,adam,AE D AH M,adams,AE D AH M Z,Z,True,0,4,6076,14620,True,True
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
406070,VBZ-FF-IH Z,IH Z,rich,R IH CH,riches,R IH CH IH Z,IH Z,True,99,14,2806,4671,True,True
406071,VBZ-FF-IH Z,IH Z,rich,R IH CH,riches,R IH CH IH Z,IH Z,True,99,15,2806,4671,True,True
406072,VBZ-FF-IH Z,IH Z,rich,R IH CH,riches,R IH CH IH Z,IH Z,True,99,16,2806,4671,True,True
406073,VBZ-FF-IH Z,IH Z,rich,R IH CH,riches,R IH CH IH Z,IH Z,True,99,17,2806,4671,True,True


In [41]:
# Stack the false-friend rows (exclude_main=True) under the main rows
# (exclude_main=False); columns present in only one frame are filled with NaN.
all_cross_instances = pd.concat([all_cross_instances, cross_false_friends_df], axis=0)

## Save

In [44]:
# Write the (subsampled) spec used throughout this notebook next to the other outputs.
state_space_spec.to_hdf5(f"{output_dir}/state_space_spec.h5")

  value = self._g_getattr(self._v_node, name)
your performance may suffer as PyTables will pickle object types that it cannot
map directly to c-types [inferred_type->mixed,key->block0_values] [items->Index(['description'], dtype='object')]

  self.cuts.to_hdf(path, key=cuts_key, mode="a")


In [None]:
# Type-level pairs, token-level inflected instances, and the full
# base x inflected cross product (including false friends).
inflection_results_df.to_parquet(f"{output_dir}/inflection_results.parquet")
inflection_instance_df.to_parquet(f"{output_dir}/inflection_instances.parquet")
all_cross_instances.to_parquet(f"{output_dir}/all_cross_instances.parquet")

In [48]:
# Filtered false-friend pairs and the per-(inflection, base) most common allomorphs.
false_friends_df.to_csv(f"{output_dir}/false_friends.csv")
most_common_allomorphs.to_csv(f"{output_dir}/most_common_allomorphs.csv")