Study the dynamics of morphemic processing, and whether (and how) they relate to the geometries computed from static word-level embeddings.

In [None]:
# IPython autoreload: re-import edited modules (e.g. the project's `src.*` helpers)
# before each cell runs, so library changes are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:
from collections import defaultdict, Counter
import itertools

from lemminflect import getInflection
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy.spatial.distance import cdist, pdist
import seaborn as sns
from sklearn.decomposition import PCA
import torch
from tqdm.auto import tqdm

from src.analysis.state_space import StateSpaceAnalysisSpec, \
    prepare_state_trajectory, aggregate_state_trajectory, flatten_trajectory
from src.datasets.speech_equivalence import SpeechEquivalenceDataset

In [None]:
# Identifiers of the base model, trained model and dataset; every artifact path
# below is derived from them.
base_model = "w2v2_8"
model_class = "rnn_32-hinge-mAP4"
model_name = "word_broad"
train_dataset = "librispeech-train-clean-100"
model_dir = f"outputs/models/{train_dataset}/{base_model}/{model_class}/{model_name}_10frames"
# Directory the result CSVs of this notebook are written to.
output_dir = "."
dataset_path = f"outputs/preprocessed_data/{train_dataset}"
equivalence_path = f"outputs/equivalence_datasets/{train_dataset}/{base_model}/{model_name}_10frames/equivalence.pkl"
hidden_states_path = f"outputs/hidden_states/{train_dataset}/{base_model}/{train_dataset}.h5"
state_space_specs_path = f"outputs/state_space_specs/{train_dataset}/{base_model}/state_space_specs.pkl"
embeddings_path = f"outputs/model_embeddings/{train_dataset}/{base_model}/{model_class}/{model_name}_10frames/{train_dataset}.npy"

# NOTE(review): `seed` is not referenced by any other cell in this notebook
# (in particular it is not passed to `subsample_instances`); confirm whether the
# subsampling step is meant to be seeded with it.
seed = 1234

# Cap passed to `state_space_spec.subsample_instances`.
max_samples_per_word = 100

# Distance metric used for every pairwise-distance comparison in this notebook.
metric = "cosine"

# Aggregation spec passed to `aggregate_state_trajectory`
# (presumably: mean state within each phoneme cut).
agg_fn = ("mean_within_cut", "phoneme")

In [None]:
# Model embeddings (a NumPy array saved with np.save).
with open(embeddings_path, "rb") as embeddings_file:
    model_representations: np.ndarray = np.load(embeddings_file)

# The specs file holds several state-space specs; this notebook uses the "word" one.
# torch.load unpickles its input, so only point it at trusted, locally produced files.
with open(state_space_specs_path, "rb") as specs_file:
    state_space_spec: StateSpaceAnalysisSpec = torch.load(specs_file)["word"]

assert state_space_spec.is_compatible_with(model_representations)

In [None]:
# Presumably caps the number of instances kept per word label at `max_samples_per_word`.
# NOTE(review): the config `seed` is not passed here; if `subsample_instances` samples
# randomly, confirm how it is seeded so that results are reproducible.
state_space_spec = state_space_spec.subsample_instances(max_samples_per_word)

In [None]:
# Analogy difference vectors from the geometry/analogy outputs for this model, kept for
# comparison. Each entry is later accessed through its "prefix" and "difference_vectors" keys.
difference_vectors = torch.load(
    f"outputs/notebooks/{train_dataset}/{base_model}/{model_class}/"
    f"{model_name}_10frames/geometry/analogy/analogy_difference_vectors.pt"
)

In [None]:
# Pool the difference vectors of every analogy entry whose prefix mentions
# regular noun plurals into one array.
nns_difference_vectors = np.concatenate(
    [entry["difference_vectors"] for entry in difference_vectors
     if "noun - plural_reg" in entry["prefix"]]
)

In [None]:
# Arrange model representations into per-instance state trajectories according to
# the state-space spec; pad=np.nan presumably fills positions past an instance's end.
trajectory = prepare_state_trajectory(
    model_representations,
    state_space_spec,
    pad=np.nan,
)

In [None]:
# Aggregate the trajectory according to `agg_fn`, then flatten it: `agg_flat` holds one
# state per row and `agg_src[i]` is the (label_idx, instance_idx, frame_idx) key of row i.
traj_agg = aggregate_state_trajectory(
    trajectory,
    state_space_spec,
    agg_fn,
    keepdims=True,
)
agg_flat, agg_src = flatten_trajectory(traj_agg)

In [None]:
# Phoneme-level cuts of every word instance, without the frame-span columns.
cuts_df = (
    state_space_spec.cuts
    .xs("phoneme", level="level")
    .drop(columns=["onset_frame_idx", "offset_frame_idx"])
)
# Integer id of each word label = its position in `state_space_spec.labels`.
label_to_idx = {lbl: idx for idx, lbl in enumerate(state_space_spec.labels)}
cuts_df["label_idx"] = cuts_df.index.get_level_values("label").map(label_to_idx)
# Ordinal position of each phoneme within its word instance
# (assumes cuts are already in temporal order within an instance -- TODO confirm).
cuts_df["frame_idx"] = cuts_df.groupby(["label", "instance_idx"]).cumcount()
cuts_df = cuts_df.reset_index().set_index(["label_idx", "instance_idx", "frame_idx"]).sort_index()

# Row number in `agg_flat` for every (label_idx, instance_idx, frame_idx) key.
agg_flat_idxs = pd.Series({tuple(src): row for row, src in enumerate(agg_src)})
agg_flat_idxs.index.names = ["label_idx", "instance_idx", "frame_idx"]
# Inner merge (the default): keep only cuts that have an aggregated state.
cuts_df = pd.merge(
    cuts_df,
    agg_flat_idxs.rename("traj_flat_idx"),
    left_index=True,
    right_index=True,
)

In [None]:
# Make the word label the outermost index level, so that `cuts_df.loc[word]`
# selects every phoneme cut of that word.
cuts_df = (
    cuts_df
    .set_index("label", append=True)
    .reorder_levels(["label", "label_idx", "instance_idx", "frame_idx"])
    .sort_index()
)
cuts_df

In [None]:
# (inflection tag, analogy difference-vector prefix) for each inflection studied.
inflection_targets = [
    ("VBD", "verb_inf - Ved"),
    ("VBZ", "verb_inf - 3pSg"),
    ("VBG", "verb_inf - Ving"),
    ("NNS", "noun - plural_reg"),
]

# Only word labels whose `label_counts` entry exceeds 15 are considered.
label_counts = state_space_spec.label_counts
labels = set(label_counts[label_counts > 15].index)

# inflection_results[tag][base]  -> inflected forms of `base` for `tag` that are also labels.
# inflection_reverse[inflected]  -> set of (base, tag) pairs that produce `inflected`.
inflection_results = {tag: {} for tag, _ in inflection_targets}
inflection_reverse = defaultdict(set)
for target, _ in tqdm(inflection_targets):
    for label in labels:
        # Drop zero-derived forms (spelled identically to the base), then keep only
        # inflections that are themselves frequent labels.
        candidates = set(getInflection(label, tag=target, inflect_oov=False)) - {label}
        covered_inflections = candidates & labels
        if not covered_inflections:
            continue

        inflection_results[target][label] = covered_inflections
        for inflected in covered_inflections:
            inflection_reverse[inflected].add((label, target))

from pprint import pprint
pprint({tag: len(bases) for tag, bases in inflection_results.items()})

# Inflected forms reachable from more than one (base, tag) pair.
ambiguous_inflected_forms = {
    form: sources for form, sources in inflection_reverse.items() if len(sources) > 1
}
print(f"Ambiguous inflected forms ({len(ambiguous_inflected_forms)} total):")
print(" ".join(ambiguous_inflected_forms))

In [None]:
def _is_regular_inflection(base_form, inflected_form):
    """Return True if `inflected_form` is an orthographically regular inflection of `base_form`.

    Regular means the inflected spelling extends the base spelling, or follows the
    y -> ies or f/fe -> ves spelling alternations.
    """
    return (
        inflected_form.startswith(base_form)
        or (inflected_form.endswith("ies") and base_form.endswith("y"))
        or (inflected_form.endswith("ves") and base_form.endswith(("f", "fe")))
    )


def _common_prefix_len(xs, ys):
    """Return the length of the longest common prefix of sequences `xs` and `ys`."""
    length = 0
    for x, y in zip(xs, ys):
        if x != y:
            break
        length += 1
    return length


def run_inflection_study(study_inflection, study_inflection_difference_vector_prefix):
    """Measure how model states move across the morphological boundary of inflected words.

    For each base form with attested inflections of type `study_inflection` (a key of the
    global `inflection_results`), every spoken instance of each inflected form contributes
    a triple of row indices into the global `agg_flat`:

    - the last phoneme shared with some attested pronunciation of the base form,
    - the first phoneme after that phonological divergence point, and
    - the final phoneme of the inflected word.

    The triples are averaged per base form. The "inflection update" is the final-phoneme
    state minus the pre-divergence state.

    Args:
        study_inflection: Inflection tag to study; a key of `inflection_results`.
        study_inflection_difference_vector_prefix: Substring matched against the "prefix"
            of each entry of the global `difference_vectors`. Matching entries give the
            reference vectors; non-matching entries give the counterfactual vectors.

    Returns:
        A tuple of:
        - inflection_updates: one update vector per base form.
        - regular_inflection_updates: the rows of `inflection_updates` whose contributing
          inflected forms are all orthographically regular.
        - reference_difference_vectors: the mean difference vector of each matching entry.
        - counterfactual_difference_vectors: the mean difference vector of each
          non-matching entry.
        - study_triple_metadata: a DataFrame (RangeIndex aligned with `inflection_updates`)
          with columns "label", "is_regular" and "post_divergence".

    Raises:
        ValueError: if no instance of the requested inflection survives filtering.
    """
    study_triples, study_metadata = defaultdict(list), {}
    for base_form, inflected_forms in tqdm(inflection_results[study_inflection].items()):
        # All attested phonological forms of the base (independent of the inflected form).
        base_phono_forms = set(cuts_df.loc[base_form].groupby("instance_idx").apply(
            lambda xs: tuple(xs.description)))

        # Sorted so the processing order is deterministic (set order depends on str hashing).
        for inflected_form in sorted(inflected_forms):
            is_regular = _is_regular_inflection(base_form, inflected_form)
            inflected_cuts = cuts_df.loc[inflected_form]

            for _, inflected_instance in inflected_cuts.groupby("instance_idx"):
                inflected_phones = tuple(inflected_instance.description)
                # Phonological divergence point: index of the first phone of this instance
                # that is not shared with the best-matching pronunciation of the base form.
                phono_divergence_point = max(
                    _common_prefix_len(inflected_phones, base_phones)
                    for base_phones in base_phono_forms)

                if phono_divergence_point == 0:
                    print(f"Ignoring {base_form} -> {inflected_form} due to phono_divergence_point == 0")
                    continue
                if phono_divergence_point >= len(inflected_phones):
                    # The entire instance matches a base pronunciation: no phone diverges.
                    print(f"Ignoring {base_form} -> {inflected_form}: no phonological divergence from base")
                    continue

                study_triples[base_form].append((
                    inflected_instance.iloc[phono_divergence_point - 1].traj_flat_idx,
                    inflected_instance.iloc[phono_divergence_point].traj_flat_idx,
                    inflected_instance.iloc[-1].traj_flat_idx,
                ))

                metadata = study_metadata.setdefault(base_form, {
                    "label": base_form,
                    "is_regular": True,
                    "post_divergence": Counter(),
                })
                # All instances of a base are pooled into one mean triple, so the base
                # counts as regular only if every contributing inflected form is regular.
                metadata["is_regular"] = metadata["is_regular"] and is_regular
                metadata["post_divergence"][inflected_phones[phono_divergence_point:]] += 1

    if not study_triples:
        raise ValueError(f"No usable instances found for inflection {study_inflection!r}")

    # For each base -> inflected mapping, average each of the three states over instances.
    study_triple_means = []
    study_triple_metadata = []
    for label, label_triples in tqdm(study_triples.items()):
        # Explicit (num_instances, 3) integer index array: a plain list of tuples is
        # interpreted as a multidimensional index by older NumPy versions.
        triple_idxs = np.asarray(label_triples, dtype=int)
        study_triple_means.append(agg_flat[triple_idxs].mean(axis=0))

        label_metadata = study_metadata[label]
        # Keep only the most common post-divergence phonological form, space-joined.
        label_metadata["post_divergence"] = " ".join(label_metadata["post_divergence"].most_common(1)[0][0])
        study_triple_metadata.append(label_metadata)

    study_triple_metadata = pd.DataFrame(study_triple_metadata)

    # num_base_forms * 3 * model_dim representations
    inflected_states = np.stack(study_triple_means)
    assert len(study_triple_metadata) == len(inflected_states)

    # Update from the last shared (pre-divergence) phoneme to the final phoneme.
    inflection_updates = inflected_states[:, 2, :] - inflected_states[:, 0, :]
    regular_mask = study_triple_metadata["is_regular"].to_numpy(dtype=bool)
    regular_inflection_updates = inflection_updates[regular_mask]

    reference_difference_vectors = np.stack([
        x["difference_vectors"].mean(0) for x in difference_vectors
        if study_inflection_difference_vector_prefix in x["prefix"]])
    counterfactual_difference_vectors = np.stack([
        x["difference_vectors"].mean(0) for x in difference_vectors
        if study_inflection_difference_vector_prefix not in x["prefix"]])

    return (inflection_updates, regular_inflection_updates,
            reference_difference_vectors, counterfactual_difference_vectors,
            study_triple_metadata)

In [None]:
# Run the study for every inflection type and summarise mean distances.
# Within-set means use `pdist` (distinct pairs only). The square `cdist(X, X)` matrix
# has zero self-distances on its diagonal, which would pull within-set means below the
# between-set and reference comparisons, by a larger margin for small sets.
study_results = {}
study_results_df = {}
infl_updates = {}

for target, prefix in tqdm(inflection_targets):
    infl_updates[target], infl_updates_regular, reference_diff_vectors, reference_diff_vectors_counterfactual, metadata = \
        run_inflection_study(target, prefix)

    study_results_df[target] = {
        "within_inflection": pdist(infl_updates[target], metric=metric).mean(),
        "within_inflection_regular": pdist(infl_updates_regular, metric=metric).mean(),
        "reference_diff": cdist(infl_updates[target], reference_diff_vectors, metric=metric).mean(),
        "reference_diff_counterfactual": cdist(infl_updates[target], reference_diff_vectors_counterfactual, metric=metric).mean(),
    }

    study_results[target] = {
        "infl_updates": infl_updates[target],
        "metadata": metadata,
    }

In [None]:
# With every inflection's update vectors available, compare each inflection's
# updates against the pooled updates of all the other inflections.
for target in study_results_df:
    other_updates = np.concatenate([infl_updates[other] for other in study_results if other != target])
    study_results_df[target]["between_inflection"] = cdist(
        infl_updates[target], other_updates, metric=metric
    ).mean()

In [None]:
# One row per inflection tag, one column per distance summary; saved alongside the notebook.
study_results_df = pd.DataFrame.from_dict(data=study_results_df, orient="index")
study_results_df.to_csv(path_or_buf=f"{output_dir}/inflection_study.csv")
study_results_df

In [None]:
# Grouped bars: one group per inflection tag, one bar per distance summary.
sns.barplot(
    data=study_results_df.reset_index().melt(id_vars="index"),
    x="index",
    y="value",
    hue="variable",
)

### NNS (plural noun) sub-study

In [None]:
# Update vectors and per-base metadata from the NNS run of the inflection study.
nns_updates, nns_metadata = (study_results["NNS"][key] for key in ("infl_updates", "metadata"))

In [None]:
# How many NNS base forms have each (most common) post-divergence phone string.
nns_metadata["post_divergence"].value_counts()

In [None]:
# Post-divergence phone strings to compare. `nns_metadata` has a RangeIndex aligned
# with the rows of `nns_updates`, so a boolean mask selects the matching update rows.
compare_divergence_categories = ["Z", "S", "IH Z"]
nns_category_updates = {
    category: nns_updates[(nns_metadata["post_divergence"] == category).to_numpy()]
    for category in compare_divergence_categories
}

In [None]:
# Mean distance among NNS updates that share a post-divergence suffix, over distinct
# pairs only (`pdist`, so the zero self-distances on a square `cdist` diagonal do not
# bias it downward), and from each suffix group to the pooled updates of the others.
nns_distances_within = {category: pdist(nns_category_updates[category], metric=metric).mean()
                        for category in compare_divergence_categories}
nns_distances_between = {category: cdist(nns_category_updates[category], np.concatenate([nns_category_updates[c] for c in compare_divergence_categories if c != category]), metric=metric).mean()
                         for category in compare_divergence_categories}

In [None]:
# Counterfactual baseline for NNS: the mean difference vector of every analogy entry
# whose prefix does NOT match the NNS prefix. Computed explicitly here instead of
# reusing `reference_diff_vectors_counterfactual`, which holds whichever inflection the
# study loop processed last and is correct only while NNS happens to come last.
nns_difference_vector_prefix = dict(inflection_targets)["NNS"]
nns_counterfactual_vectors = np.stack([
    x["difference_vectors"].mean(0) for x in difference_vectors
    if nns_difference_vector_prefix not in x["prefix"]])

nns_study_df = pd.DataFrame({
    "distance_within": nns_distances_within,
    "distance_between": nns_distances_between,
})
nns_study_df["distance_counterfactual"] = cdist(nns_updates, nns_counterfactual_vectors, metric=metric).mean()
nns_study_df.to_csv(f"{output_dir}/nns_study.csv")
nns_study_df

In [None]:
# Grouped bars: one group per post-divergence suffix, one bar per distance measure.
sns.barplot(
    data=nns_study_df.reset_index().melt(id_vars="index"),
    x="index",
    y="value",
    hue="variable",
)