Study the dynamics of morphemic processing, and determine whether and how they relate to the geometries computed from static word-level embeddings.

In [1]:
# IPython magics: enable the autoreload extension so that edited modules (e.g. the
# project's src.* packages imported below) are re-imported before each cell runs.
%load_ext autoreload
%autoreload 2

In [13]:
from collections import defaultdict, Counter

import lemminflect
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy.spatial.distance import cdist
import seaborn as sns
from sklearn.decomposition import PCA
import torch
from tqdm.auto import tqdm

from src.analysis.state_space import StateSpaceAnalysisSpec, \
    prepare_state_trajectory, aggregate_state_trajectory, flatten_trajectory
from src.datasets.speech_equivalence import SpeechEquivalenceDataset

In [3]:
# --- Experiment identifiers ----------------------------------------------------
train_dataset = "librispeech-train-clean-100"
base_model, model_class, model_name = (
    "w2v2_8",
    "discrim-rnn_32-mAP1",
    "word_broad_10frames_fixedlen25",
)

# --- Input / output locations (derived from the identifiers above) -------------
output_dir = "."
model_dir = f"outputs/models/{train_dataset}/{base_model}/{model_class}/{model_name}"
dataset_path = f"outputs/preprocessed_data/{train_dataset}"
equivalence_path = f"outputs/equivalence_datasets/{train_dataset}/{base_model}/{model_name}/equivalence.pkl"
hidden_states_path = f"outputs/hidden_states/{train_dataset}/{base_model}/{train_dataset}.h5"
state_space_specs_path = f"outputs/state_space_specs/{train_dataset}/{base_model}/state_space_specs.h5"
embeddings_path = f"outputs/model_embeddings/{train_dataset}/{base_model}/{model_class}/{model_name}/{train_dataset}.npy"

# --- Analysis hyperparameters --------------------------------------------------
seed = 1234
max_samples_per_word = 100  # passed to state_space_spec.subsample_instances
metric = "cosine"  # distance metric passed to scipy's cdist throughout
agg_fn = ("mean_within_cut", "phoneme")  # passed to aggregate_state_trajectory

In [4]:
# Precomputed model embeddings (.npy) and the word-level state-space spec over them.
model_representations: np.ndarray = np.load(embeddings_path)
state_space_spec = StateSpaceAnalysisSpec.from_hdf5(state_space_specs_path, "word")

# Sanity check: the spec must be compatible with this representation matrix.
assert state_space_spec.is_compatible_with(model_representations)

In [5]:
# Limit the number of instances per word to max_samples_per_word (presumed semantics of
# subsample_instances -- TODO confirm in src.analysis.state_space).
state_space_spec = state_space_spec.subsample_instances(max_samples_per_word)

In [6]:
# Difference vectors saved by the static analogy analysis, used as references below.
# The sibling output directory is found by swapping "analogy_dynamic" -> "analogy".
# NOTE(review): torch>=2.6 defaults torch.load to weights_only=True, which may refuse
# non-tensor payloads (e.g. numpy arrays); if loading fails, pass weights_only=False
# for this trusted, locally generated file.
analogy_output_dir = output_dir.replace('analogy_dynamic', 'analogy')
difference_vectors = torch.load(f"{analogy_output_dir}/analogy_difference_vectors.pt")

# Regular noun-plural difference vectors, stacked into one array.
nns_difference_vectors = np.concatenate([
    entry['difference_vectors']
    for entry in difference_vectors
    if "noun - plural_reg" in entry["prefix"]
])

In [7]:
# Arrange the flat model representations into per-(word, instance) frame trajectories
# according to state_space_spec; pad=np.nan presumably fills ragged instances with NaN.
trajectory = prepare_state_trajectory(model_representations, state_space_spec, pad=np.nan)

In [None]:
# Aggregate each trajectory with agg_fn (("mean_within_cut", "phoneme"): presumably one
# mean state per phoneme cut), keeping dims, then flatten. Downstream, agg_src[i] is used
# as the (label_idx, instance_idx, frame_idx) key of row i of agg_flat.
traj_agg = aggregate_state_trajectory(trajectory, state_space_spec, agg_fn, keepdims=True)
agg_flat, agg_src = flatten_trajectory(traj_agg)

In [9]:
# Phoneme-level cuts for every word instance; frame onset/offset columns are not needed.
cuts_df = state_space_spec.cuts.xs("phoneme", level="level").drop(columns=["onset_frame_idx", "offset_frame_idx"])
# Integer position of each word label within state_space_spec.labels.
cuts_df["label_idx"] = cuts_df.index.get_level_values("label").map({l: i for i, l in enumerate(state_space_spec.labels)})
# Ordinal of each phoneme within its (label, instance). NOTE(review): cumcount follows row
# order, so this assumes cuts are already ordered by onset within each instance.
cuts_df["frame_idx"] = cuts_df.groupby(["label", "instance_idx"]).cumcount()
cuts_df = cuts_df.reset_index().set_index(["label_idx", "instance_idx", "frame_idx"]).sort_index()

# Map each aggregated-state key (label_idx, instance_idx, frame_idx) to its row in agg_flat.
agg_flat_idxs = pd.Series({tuple(agg_src_i): i for i, agg_src_i in enumerate(agg_src)})
agg_flat_idxs.index.names = ["label_idx", "instance_idx", "frame_idx"]
# Inner join (pd.merge default) on the shared index: cuts without an aggregated state are dropped.
cuts_df = pd.merge(cuts_df, agg_flat_idxs.rename("traj_flat_idx"), left_index=True, right_index=True)

In [None]:
# Put the word label as the outermost index level so cuts can be selected by word
# (cuts_df.loc[word]); display the result.
cuts_df = cuts_df.set_index("label", append=True).reorder_levels(["label", "label_idx", "instance_idx", "frame_idx"]).sort_index()
cuts_df

In [32]:
# Words whose "in-" (or assimilated "il-"/"im-"/"ir-") forms are attested but are not
# negations, e.g. "come" -> "income". Stored as a set of whole words: a substring test
# against the joined string would also exclude any word that merely occurs inside it
# (e.g. "trust" inside "trusted", "ten" inside "tend").
_NOT_LATIN_EXCEPTIONS = frozenset((
    "come comes deed diana dies doors fancy form formation formed forming "
    "habit jury justice k l laid land most n part parted port press pressed "
    "prove proved pulse pulses side sight stead sure tend tended tending tent "
    "to trusted vent ward"
).split())


def get_inflection(word, all_labels, target) -> set[str]:
    """Return the inflected/derived forms of ``word`` that are attested in ``all_labels``.

    Args:
        word: base form (a word label).
        all_labels: set of attested labels; only candidate forms in this set are returned.
        target: "VBD", "VBZ", "VBG" or "NNS" (inflected via ``lemminflect.getInflection``
            with that tag), or "NOT-latin" (Latinate negative prefix "in-", plus the
            assimilated "il-" / "im-" / "ir-" variant chosen by the first letter).

    Returns:
        The set of candidate forms that are members of ``all_labels`` (possibly empty).

    Raises:
        ValueError: if ``target`` is not recognized.
    """
    if target in ("VBD", "VBZ", "VBG", "NNS"):
        inflections = set(lemminflect.getInflection(word, tag=target, inflect_oov=False))
        # don't include zero-derived form
        inflections.discard(word)
    elif target == "NOT-latin":
        # catch exceptional cases -- these predicted forms are attested, but don't count
        # as a negative inflection. e.g. "come" -> "income"
        if word in _NOT_LATIN_EXCEPTIONS:
            return set()
        inflections = {"in" + word}
        # startswith (not word[0]) so an empty word does not raise IndexError
        if word.startswith("l"):
            inflections.add("il" + word)
        elif word.startswith(("p", "b", "m")):
            inflections.add("im" + word)
        elif word.startswith("r"):
            inflections.add("ir" + word)
    else:
        raise ValueError(f"Unknown target: {target}")

    return inflections & all_labels

In [None]:
# (target tag for get_inflection, substring selecting the matching analogy
#  difference-vector entries -- None when there is no reference)
inflection_targets = [
    ("VBD", "verb_inf - Ved"),
    ("VBZ", "verb_inf - 3pSg"),
    ("VBG", "verb_inf - Ving"),
    ("NNS", "noun - plural_reg"),
    ("NOT-latin", None),
]

# Restrict to words with a count of at least 5 in state_space_spec.label_counts.
label_counts = state_space_spec.label_counts
labels = set(label_counts[label_counts >= 5].index)

# inflection_results[target][base] -> attested inflected forms of base;
# inflection_reverse[inflected] -> {(base, target), ...}
inflection_results = {target: {} for target, _ in inflection_targets}
inflection_reverse = defaultdict(set)
for target, _ in tqdm(inflection_targets):
    results_for_target = inflection_results[target]
    for label in labels:
        label_inflections = get_inflection(label, labels, target)
        if not label_inflections:
            continue
        results_for_target[label] = label_inflections
        for infl in label_inflections:
            inflection_reverse[infl].add((label, target))

from pprint import pprint
pprint({target: len(forms) for target, forms in inflection_results.items()})

# Inflected forms reachable from more than one (base, target) pair.
ambiguous_inflected_forms = {
    form: sources for form, sources in inflection_reverse.items() if len(sources) > 1
}
print(f"Ambiguous inflected forms ({len(ambiguous_inflected_forms)} total):")
print(" ".join(ambiguous_inflected_forms))

In [None]:
# Phonemic transcription of each word token: the phoneme-cut descriptions of each
# (label, instance_idx) joined with spaces; displayed for inspection.
cut_phonemic_forms = cuts_df.groupby(["label", "instance_idx"]).description.agg(' '.join)
cut_phonemic_forms

In [218]:
def is_regular(inflection, base, inflected):
    """Heuristically decide whether ``base -> inflected`` follows a regular spelling rule.

    Args:
        inflection: one of "NNS", "VBZ", "VBG", "VBD", "NOT-latin".
        base: base (uninflected) orthographic form.
        inflected: inflected / derived orthographic form.

    Returns:
        True if ``inflected`` matches one of the regular patterns for ``inflection``.
        Suffix/prefix checks use ``str.endswith``/``startswith`` so an empty ``base``
        yields False instead of raising IndexError.

    Raises:
        ValueError: if ``inflection`` is not recognized.
    """
    consonants = tuple("bcdfghjklmnpqrstvwxz")

    if inflection == "NNS":
        # NOTE(review): the prefix test is permissive, e.g. "child" -> "children" counts as regular.
        return (inflected.startswith(base)
                or (inflected.endswith("ies") and base.endswith("y"))
                or (inflected.endswith("ves") and base.endswith(("f", "fe"))))
    elif inflection == "VBZ":
        return (inflected in (base + "s", base + "es")
                or (base.endswith("y") and inflected == base[:-1] + "ies"))
    elif inflection == "VBG":
        return (inflected == base + "ing"
                or (base.endswith("e") and inflected == base[:-1] + "ing")
                # consonant doubling, e.g. run -> running
                or (base.endswith(consonants) and inflected == base + base[-1] + "ing")
                or (base.endswith("ie") and inflected == base[:-2] + "ying"))
    elif inflection == "VBD":
        return (inflected in (base + "ed", base + "d", base + "t")
                or (base.endswith("y") and inflected == base[:-1] + "ied")
                # pay -> paid, lay -> laid
                or (base.endswith("ay") and inflected == base[:-1] + "id")
                # consonant doubling, e.g. stop -> stopped
                or (base.endswith(consonants) and inflected == base + base[-1] + "ed"))
    elif inflection == "NOT-latin":
        # only the unassimilated "in-" prefix counts as regular
        return inflected.startswith("in")
    else:
        raise ValueError(f"Unknown inflection {inflection}")


def _phono_divergence_point(inflected_phones, base_phono_forms):
    """Return the index of the first phone of ``inflected_phones`` not shared with a base form.

    For each base pronunciation, count the leading phones it shares with
    ``inflected_phones``; return the maximum over pronunciations. When the whole
    inflected form is shared, the value is clamped to its last index so it remains a
    valid position into ``inflected_phones``.
    """
    points = []
    for base_phones in base_phono_forms:
        shared = 0
        for inflected_phone, base_phone in zip(inflected_phones, base_phones):
            if inflected_phone != base_phone:
                break
            shared += 1
        points.append(min(shared, len(inflected_phones) - 1))
    return max(points)


def run_inflection_study(study_inflection, study_inflection_difference_vector_prefix,
                         smoke_test=False):
    """Collect state updates across the inflectional material of each base form.

    For every base form with attested forms of type ``study_inflection`` (global
    ``inflection_results``) and every token of each such form, three rows of the global
    ``agg_flat`` are taken (via ``cuts_df.traj_flat_idx``): the phoneme just before the
    phonological divergence point, the first diverging phoneme, and the final phoneme.
    These are averaged over tokens per base form.

    Args:
        study_inflection: key into ``inflection_results`` (e.g. "NNS", "VBD").
        study_inflection_difference_vector_prefix: substring selecting the matching
            entries of the global ``difference_vectors``; ``None`` skips that comparison.
        smoke_test: if True, only process the first three base forms.

    Returns:
        ``None`` if no usable tokens were found; otherwise
        ``(inflection_updates, reference_difference_vectors,
        counterfactual_difference_vectors, study_triple_metadata)``, where
        ``inflection_updates[i]`` is final state minus pre-divergence state for base
        form i, and ``study_triple_metadata`` has one row per base form with columns
        ``label``, ``is_regular`` and ``post_divergence`` (most common post-divergence
        phones, space-joined). The two difference-vector arrays are ``None`` when no
        prefix is given.
    """
    study_triples, study_metadata = defaultdict(list), {}

    base_items = list(inflection_results[study_inflection].items())
    if smoke_test:
        base_items = base_items[:3]

    for base_form, inflected_forms in tqdm(base_items):
        # all attested phonological forms of base (independent of the inflected form)
        base_phono_forms = set(cuts_df.loc[base_form].groupby("instance_idx").apply(
            lambda xs: tuple(xs.description)))

        # Sorted so the per-base metadata (set from the first inflected form seen) does not
        # depend on set iteration order, which varies between runs for strings.
        # NOTE(review): tokens of all inflected forms of a base are pooled, and is_regular
        # reflects only the alphabetically first form.
        for inflected_form in sorted(inflected_forms):
            inflected_cuts = cuts_df.loc[inflected_form]

            for _, inflected_instance in inflected_cuts.groupby("instance_idx"):
                # phonological divergence point: latest point at which the inflected form
                # overlaps with any pronunciation of the base form
                inflected_phones = tuple(inflected_instance.description)
                phono_divergence_point = _phono_divergence_point(inflected_phones, base_phono_forms)

                if phono_divergence_point == 0:
                    print(f"Ignoring {base_form} -> {inflected_form} due to phono_divergence_point == 0")
                    continue

                frame_idxs = inflected_instance.traj_flat_idx
                study_triples[base_form].append((
                    frame_idxs.iloc[phono_divergence_point - 1],  # pre-diverging frame
                    frame_idxs.iloc[phono_divergence_point],      # diverging frame
                    frame_idxs.iloc[-1],                          # final frame
                ))

                post_divergence = inflected_phones[phono_divergence_point:]
                if base_form not in study_metadata:
                    study_metadata[base_form] = {
                        "label": base_form,
                        "is_regular": is_regular(study_inflection, base_form, inflected_form),
                        "post_divergence": Counter(),
                    }
                study_metadata[base_form]["post_divergence"][post_divergence] += 1

    # for each base -> inflected, get average in each of the three states
    study_triple_means = []
    study_triple_metadata = []
    for label, label_triples in tqdm(study_triples.items()):
        # An explicit (n_tokens, 3) integer array keeps this an unambiguous fancy index
        # (older NumPy could read a list of tuples as a multi-dimensional index).
        # Assuming agg_flat is (n_states, model_dim), this yields (n_tokens, 3, model_dim).
        study_triple_means.append(agg_flat[np.asarray(label_triples)].mean(axis=0))

        label_metadata = study_metadata[label]
        # get most common post-divergence phonological form
        label_metadata["post_divergence"] = " ".join(label_metadata["post_divergence"].most_common(1)[0][0])
        study_triple_metadata.append(label_metadata)

    if not study_triple_means:
        return None

    study_triple_metadata = pd.DataFrame(study_triple_metadata)

    # get num_lemmata * 3 * model_dim representations
    inflected_states = np.stack(study_triple_means)
    assert len(study_triple_metadata) == len(inflected_states)

    # update from the pre-divergence state to the final state
    inflection_updates = inflected_states[:, 2, :] - inflected_states[:, 0, :]

    if study_inflection_difference_vector_prefix is None:
        reference_difference_vectors = None
        counterfactual_difference_vectors = None
    else:
        reference_difference_vectors = np.stack([x["difference_vectors"].mean(0) for x in difference_vectors
                                                 if study_inflection_difference_vector_prefix in x["prefix"]])
        counterfactual_difference_vectors = np.stack([x['difference_vectors'].mean(0) for x in difference_vectors
                                                      if study_inflection_difference_vector_prefix not in x["prefix"]])

    return inflection_updates, \
        reference_difference_vectors, counterfactual_difference_vectors, \
        study_triple_metadata

In [None]:
def _mean_pairwise_distance(xs, ys=None):
    """Return the mean ``metric`` distance between the rows of ``xs`` and ``ys``.

    If ``ys`` is None, average over distinct pairs of rows within ``xs`` only, so the
    zero self-distances on the diagonal of ``cdist(xs, xs)`` do not bias the mean
    downwards. Returns NaN when there is no pair to average over.
    """
    if ys is None:
        if len(xs) < 2:
            return np.nan
        pairwise = cdist(xs, xs, metric=metric)
        return pairwise[np.triu_indices_from(pairwise, k=1)].mean()
    if len(xs) == 0 or len(ys) == 0:
        return np.nan
    return cdist(xs, ys, metric=metric).mean()


study_results = {}
study_results_df = {}
infl_updates = {}

for target, prefix in tqdm(inflection_targets):
    ret = run_inflection_study(target, prefix)
    if ret is None:
        # no usable tokens for this inflection
        continue
    infl_updates[target], reference_diff_vectors, reference_diff_vectors_counterfactual, metadata = ret

    updates = infl_updates[target]
    regular = metadata.is_regular.to_numpy(dtype=bool)
    study_results_df[target] = {
        # within-group statistics average over distinct pairs (self-pairs excluded)
        "within_inflection": _mean_pairwise_distance(updates),
        "within_inflection_reg-reg": _mean_pairwise_distance(updates[regular]),
        "within_inflection_irreg-irreg": _mean_pairwise_distance(updates[~regular]),
        "within_inflection_reg-irreg": _mean_pairwise_distance(updates[regular], updates[~regular]),
        # NaN (rather than None) keeps these columns numeric when there is no reference
        "reference_diff": np.nan if reference_diff_vectors is None
                          else _mean_pairwise_distance(updates, reference_diff_vectors),
        "reference_diff_counterfactual": np.nan if reference_diff_vectors_counterfactual is None
                                         else _mean_pairwise_distance(updates, reference_diff_vectors_counterfactual),
    }

    study_results[target] = {
        "infl_updates": updates,
        "metadata": metadata,

        "reference_diff_vectors": reference_diff_vectors,
        "reference_diff_vectors_counterfactual": reference_diff_vectors_counterfactual,
    }

# Counterfactual for each inflection: the pooled updates of every *other* studied inflection.
for target in study_results_df:
    other_updates = [infl_updates[t] for t in study_results if t != target]
    if other_updates:
        vectors_counterfactual = np.concatenate(other_updates)
    else:
        # only one inflection produced results; there is nothing to compare against
        vectors_counterfactual = np.empty((0, infl_updates[target].shape[1]))
    study_results_df[target]["counterfactual"] = _mean_pairwise_distance(
        infl_updates[target], vectors_counterfactual)
    study_results[target]["vectors_counterfactual"] = vectors_counterfactual

In [228]:
# With update vectors computed for every inflection, compare each inflection's updates
# against the pooled updates of all *other* inflections.
for target in study_results_df:
    pooled_other_updates = np.concatenate([infl_updates[t] for t in study_results if t != target])
    between_distances = cdist(infl_updates[target], pooled_other_updates, metric=metric)
    study_results_df[target]["between_inflection"] = between_distances.mean()

In [None]:
# One row per studied inflection (dict keys, orient="index"), one column per distance
# statistic; written to CSV and displayed.
study_results_df = pd.DataFrame.from_dict(study_results_df, orient="index")
study_results_df.to_csv(f"{output_dir}/inflection_study.csv")
study_results_df

In [None]:
# Long format via melt: x = inflection (former index), hue = statistic name, y = value.
sns.barplot(data=study_results_df.reset_index().melt(id_vars=["index"]),
            x="index", y="value", hue="variable")

### Controlled sub-studies by post-divergence phonology (NNS and VBD)

In [231]:
def run_controlled_study(target, compare_divergence_categories, smoke_test=False):
    """Compare update vectors of one inflection grouped by post-divergence phones.

    Categories are the requested post-divergence phone strings (e.g. "IH Z"), plus
    "<irreg>" (updates of irregular forms) and "<counterfactual>" (the pooled updates of
    the other studied inflections, ``study_results[target]["vectors_counterfactual"]``).

    Args:
        target: key into the global ``study_results`` (e.g. "NNS").
        compare_divergence_categories: post-divergence phone strings to compare; each
            must occur in the target's metadata. The list is not modified.
        smoke_test: if True, disable bootstrapped confidence intervals in the bar plots.

    Returns:
        ``(distance_df, distance_averages)``: long-format pairwise distances with columns
        ``distance``, ``from``, ``to`` and ``type`` (rows from "<counterfactual>" are
        dropped), and the matrix of mean distances between categories in the same order
        as the heatmap tick labels.
    """
    updates = study_results[target]["infl_updates"]
    metadata = study_results[target]["metadata"]
    reference_diff_vectors = study_results[target]["vectors_counterfactual"]

    assert set(compare_divergence_categories) <= set(metadata.post_divergence)

    # Build a fresh mapping; extending compare_divergence_categories in place would
    # mutate the caller's list. Boolean masks select rows positionally regardless of
    # the metadata index.
    post_divergence = metadata.post_divergence.to_numpy()
    category_updates = {category: updates[post_divergence == category]
                        for category in compare_divergence_categories}
    category_updates["<irreg>"] = updates[~metadata.is_regular.to_numpy(dtype=bool)]
    category_updates["<counterfactual>"] = reference_diff_vectors
    categories = list(category_updates)

    distances = {}
    distance_averages = np.zeros((len(categories), len(categories)))
    for i, (cat1, cat1_updates) in enumerate(category_updates.items()):
        for j, (cat2, cat2_updates) in enumerate(category_updates.items()):
            distances_ij = cdist(cat1_updates, cat2_updates, metric=metric)
            if cat1 == cat2:
                # within a category: distinct unordered pairs only (drops the zero
                # diagonal and the symmetric duplicates)
                distances_ij = distances_ij[np.triu_indices_from(distances_ij, k=1)]
            else:
                # between categories the matrix is neither symmetric nor (generally)
                # square: every (cat1 item, cat2 item) pair is a distinct observation
                distances_ij = distances_ij.ravel()

            distances[cat1, cat2] = distances_ij
            distance_averages[i, j] = distances_ij.mean()

    distance_df = pd.concat([
        pd.Series(distances_i, name="distance").to_frame().assign(**{"from": cat1, "to": cat2})
        for (cat1, cat2), distances_i in distances.items()
    ])
    distance_df.loc[distance_df["from"] == distance_df["to"], "type"] = "within"
    distance_df.loc[distance_df["from"] != distance_df["to"], "type"] = "between"
    distance_df.loc[distance_df["to"] == "<irreg>", "type"] = "irreg"
    distance_df.loc[distance_df["to"] == "<counterfactual>", "type"] = "counterfactual"

    # drop irrelevant counterfactual sources
    distance_df = distance_df[~(distance_df["from"] == "<counterfactual>")]

    ####

    nrows = 4
    f, axs = plt.subplots(nrows, 1, figsize=(10, 5 * nrows))

    sns.heatmap(distance_averages, annot=True,
                xticklabels=categories,
                yticklabels=categories,
                ax=axs[0])

    errorbar = None if smoke_test else ("ci", 95)
    sns.barplot(data=distance_df, x="from", hue="to", y="distance",
                errorbar=errorbar, ax=axs[1])

    sns.barplot(data=distance_df, x="from", hue="type", y="distance",
                errorbar=errorbar, ax=axs[2])

    sns.barplot(data=distance_df, x="type", y="distance",
                errorbar=errorbar, ax=axs[3])

    f.tight_layout()

    return distance_df, distance_averages

In [None]:
# NNS: compare plural update vectors grouped by post-divergence phones, together with
# irregular plurals and the other inflections' pooled updates; save long-format distances.
nns_distance_df, _ = run_controlled_study("NNS", ["S", "Z", "IH Z", "IH N"])
nns_distance_df.to_csv(f"{output_dir}/nns_distance.csv")

In [None]:
# VBD: same comparison for past-tense update vectors grouped by post-divergence phones.
vbd_distance_df, _ = run_controlled_study("VBD", ["D", "T", "IH D", "UW"])
vbd_distance_df.to_csv(f"{output_dir}/vbd_distance.csv")