In [1]:
# Reload edited project modules (e.g. `src.*`) automatically before each cell executes.
%load_ext autoreload
%autoreload 2

In [2]:
from collections import defaultdict
import itertools
import re

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from tqdm.auto import tqdm

from src.analysis import analogy
from src.analysis.state_space import StateSpaceAnalysisSpec, \
    prepare_state_trajectory, aggregate_state_trajectory, flatten_trajectory
from src.datasets.speech_equivalence import SpeechHiddenStateDataset


In [3]:
# Limit torch's intra-op CPU parallelism to 8 threads.
torch.set_num_threads(8)

In [4]:
base_model = "w2v2_pc_8"

# Alternative model class used previously: "discrim-rnn_32-pc-mAP1"
model_class = "ff_32"
model_name = "word_broad_10frames_fixedlen25"

train_dataset = "librispeech-train-clean-100"
# hidden_states_path = f"outputs/hidden_states/{base_model}/{train_dataset}.h5"
# NOTE(review): absolute, user-specific path; it is only read when embeddings_path == "ID".
hidden_states_path = f"/scratch/jgauthier/{base_model}_{train_dataset}.h5"
state_space_specs_path = f"outputs/analogy/inputs/{train_dataset}/w2v2_pc/state_space_spec.h5"
embeddings_path = f"outputs/model_embeddings/{train_dataset}/{base_model}/{model_class}/{model_name}/{train_dataset}.npy"

# Directory where experiment results are written.
output_dir = "."

pos_counts_path = "data/pos_counts.pkl"

# Seed for the experiment samplers.
seed = 42

metric = "cosine"

# (aggregation function, cut level) used to aggregate state trajectories.
agg_fns = [
    ("mean_within_cut", "phoneme")
]

## Prepare model representations

In [5]:
# Representations to analyze: raw hidden states when embeddings_path == "ID",
# otherwise precomputed model embeddings stored as a .npy array.
if embeddings_path != "ID":
    with open(embeddings_path, "rb") as f:
        model_representations: np.ndarray = np.load(f)
else:
    model_representations = SpeechHiddenStateDataset.from_hdf5(hidden_states_path).states

# Cut metadata for the analysis; checked for compatibility with the representations.
state_space_spec = StateSpaceAnalysisSpec.from_hdf5(state_space_specs_path)
assert state_space_spec.is_compatible_with(model_representations)

  value = self._g_getattr(self._v_node, name)


In [6]:
# Build per-instance state trajectories, then aggregate them according to
# agg_fns[0] (here ("mean_within_cut", "phoneme")).
trajectory = aggregate_state_trajectory(
    prepare_state_trajectory(model_representations, state_space_spec),
    state_space_spec,
    agg_fns[0],
)

  0%|          | 0/32046 [00:00<?, ?it/s]

Aggregating:   0%|          | 0/32046 [00:00<?, ?label/s]

In [7]:
# Flatten the aggregated trajectory into a 2-D stack of representations (`agg`)
# and their sources (`agg_src`); elsewhere in this notebook each `agg_src` row is
# unpacked as (label_idx, instance_idx, frame_idx).
agg, agg_src = flatten_trajectory(trajectory)

## Prepare metadata

In [8]:
# Phoneme-level cuts (one row per phoneme token); frame-boundary columns are not needed here.
cuts_df = state_space_spec.cuts.xs("phoneme", level="level").drop(columns=["onset_frame_idx", "offset_frame_idx"])
# Integer index of each row's word label in state_space_spec.labels.
cuts_df["label_idx"] = cuts_df.index.get_level_values("label").map({l: i for i, l in enumerate(state_space_spec.labels)})
# Position of the phoneme within its word instance.
# NOTE(review): assumes cuts are already ordered by onset within each (label, instance_idx) — confirm.
cuts_df["frame_idx"] = cuts_df.groupby(["label", "instance_idx"]).cumcount()
cuts_df = cuts_df.reset_index().set_index(["label_idx", "instance_idx", "frame_idx"]).sort_index()

# Row number in `agg` for every (label_idx, instance_idx, frame_idx) key in agg_src.
agg_flat_idxs = pd.Series(list(range(len(agg_src))),
                          index=pd.MultiIndex.from_tuples([tuple(xs) for xs in agg_src],
                                                          names=["label_idx", "instance_idx", "frame_idx"]))
# Inner join (pd.merge default): keep only cuts that have an aggregated representation,
# recording its row in `agg` as `traj_flat_idx`.
cuts_df = pd.merge(cuts_df, agg_flat_idxs.rename("traj_flat_idx"), left_index=True, right_index=True)

In [9]:
# Word label -> its position in state_space_spec.labels (last occurrence wins on duplicates).
label2idx = dict(zip(state_space_spec.labels, itertools.count()))

In [10]:
# Space-joined phoneme transcription of every (label, instance_idx) word token, in
# frame order (cuts_df is sorted with frame_idx as its innermost index level).
cut_phonemic_forms = cuts_df.groupby(["label", "instance_idx"]).description.agg(' '.join)

In [11]:
word_freq_df = pd.read_csv("data/WorldLex_Eng_US.Freq.2.txt", sep="\t", index_col="Word")
# Keep the first row for each word.
word_freq_df = word_freq_df.loc[~word_freq_df.index.duplicated()]

# Cross-domain frequency: normalize each domain's counts to relative frequencies
# (so the largest corpus does not dominate), average them with EQUAL weight, then
# rescale by the mean corpus size so `Freq` is expressed in count-like units.
word_freq_df = word_freq_df.assign(**{
    f"{domain}_rel": word_freq_df[domain] / word_freq_df[domain].sum()
    for domain in ["BlogFreq", "TwitterFreq", "NewsFreq"]
})
word_freq_df["Freq"] = word_freq_df[["BlogFreq_rel", "TwitterFreq_rel", "NewsFreq_rel"]].mean(axis=1) \
    * word_freq_df[["BlogFreq", "TwitterFreq", "NewsFreq"]].sum().mean()
# NOTE: words with zero frequency in every domain get LogFreq == -inf.
word_freq_df["LogFreq"] = np.log10(word_freq_df.Freq)

## Prepare items

In [12]:
# Phonemes that may follow a cohort prefix in the experiments.
next_phon_set = {"AH", "ER", "IH", "L", "S", "Z", "T", "D", "M", "N"}
# Number of phonemes in each cohort prefix.
target_cohort_length = 2
# Alternative "small" cohorts: prefixes whose continuations are exactly this many
# distinct phonemes, all drawn from `next_phon_set`.
target_small_cohort_size = 3
assert len(next_phon_set) > target_small_cohort_size

In [13]:
# Map every word-initial phoneme prefix (tuple of phones) to the set of distinct
# word forms (tuples of phones) that begin with it.
cohorts = defaultdict(set)
for phones in tqdm(cut_phonemic_forms.unique()):
    phones = tuple(phones.split())
    for i in range(len(phones)):
        cohorts[phones[:i + 1]].add(phones)

# One row per (cohort prefix, strictly longer word form), with the phoneme that
# immediately follows the prefix in that form.
csz_next = pd.DataFrame([(" ".join(coh), " ".join(item), item[len(coh)]) for coh, items in cohorts.items()
                            for item in items if len(item) > len(coh)],
                            columns=["cohort", "item", "next_phoneme"])

  0%|          | 0/32462 [00:00<?, ?it/s]

In [14]:
# Cohorts of exactly `target_cohort_length` phonemes whose observed continuations
# include every phoneme in `next_phon_set`; maps cohort -> sorted continuations.
# Aggregating the `next_phoneme` column directly avoids the pandas deprecation of
# DataFrameGroupBy.apply operating on the grouping columns.
expt_cohort = csz_next[csz_next.cohort.str.count(" ") == target_cohort_length - 1] \
    .groupby("cohort").filter(lambda xs: set(xs.next_phoneme) >= next_phon_set) \
    .groupby("cohort").next_phoneme.apply(lambda phones: sorted(set(phones)))
expt_cohort

  .groupby("cohort").apply(lambda xs: sorted(set(xs.next_phoneme)))


cohort
AH N    [AA, AE, AH, AO, AW, AY, B, CH, D, EH, ER, EY,...
K ER    [AA, AE, AH, AO, AW, B, CH, D, EH, ER, EY, F, ...
K OW    [AH, B, CH, D, ER, HH, IH, IY, JH, K, L, M, N,...
L AY    [AE, AH, B, D, DH, ER, F, IH, K, L, M, N, P, R...
P ER    [AA, AE, AH, AY, B, CH, D, EH, ER, EY, F, G, H...
P EY    [AH, D, ER, G, IH, JH, L, M, N, P, S, SH, T, T...
R OW    [AH, B, D, ER, G, HH, IH, IY, K, L, M, N, P, S...
dtype: object

In [15]:
# Now search for "small" cohorts: cohorts whose continuations are exactly
# `target_small_cohort_size` distinct phonemes, all drawn from `next_phon_set`.
expt_cohort_small = csz_next[csz_next.cohort.str.count(" ") == target_cohort_length - 1] \
    .groupby("cohort").filter(lambda xs: len(set(xs.next_phoneme)) == target_small_cohort_size
                              and set(xs.next_phoneme) <= next_phon_set) \
    .groupby("cohort").next_phoneme.apply(lambda phones: sorted(set(phones)))
expt_cohort_small

  .groupby("cohort").apply(lambda xs: sorted(set(xs.next_phoneme)))


cohort
AA F      [AH, L, T]
AO G     [AH, ER, M]
AY OW     [AH, L, T]
EH TH    [AH, IH, N]
ER EY      [D, N, Z]
ER JH    [AH, D, IH]
ER OW      [M, N, Z]
F UW       [D, L, Z]
K OY       [L, N, T]
TH AH      [D, M, N]
W AW       [N, T, Z]
Z AY      [AH, D, S]
dtype: object

### Prepare instance-level metadata

In [16]:
def prepare_equivalences(cohort, next_phoneme, cuts=None, phonemic_forms=None):
    """
    Equivalence-class vocabulary for evaluating predictions on the given
    cohort + next phoneme.

    Args:
        cohort: space-separated phoneme prefix, e.g. ``"AH N"``.
        next_phoneme: phoneme immediately following ``cohort``.
        cuts: phoneme cuts indexed by (label_idx, instance_idx, frame_idx) with
            ``label``, ``description`` and ``traj_flat_idx`` columns.
            Defaults to the notebook-level ``cuts_df``.
        phonemic_forms: space-joined phoneme transcriptions indexed by
            (label, instance_idx). Defaults to ``cut_phonemic_forms``.

    Returns:
        dict from evaluation name to an array of flat indices (``traj_flat_idx``)
        counted as correct:

        - ``matches_next_phoneme``: the frame right after the cohort position, in
          any word whose phoneme there is ``next_phoneme``.
        - ``matches_next_phoneme_weak``: that frame or any later frame of those words.
        - ``matches_cohort``: any post-cohort frame of a word starting with ``cohort``.
        - ``matches_cohort_and_next_phoneme[_weak]``: the intersections.
    """
    if cuts is None:
        cuts = cuts_df
    if phonemic_forms is None:
        phonemic_forms = cut_phonemic_forms

    cohort_length = cohort.count(" ") + 1

    # next-phoneme strict: match next phoneme
    matches_next_phoneme = cuts.xs(cohort_length, level="frame_idx") \
        .query("description == @next_phoneme")

    # next-phoneme weak: match next phoneme, but allow predicting that phoneme frame
    # or any future frame of the word
    matches_next_phoneme_weak = cuts.query("frame_idx >= @cohort_length").merge(
        matches_next_phoneme.traj_flat_idx.rename("next_phoneme_flat_idx"),
        how="inner", left_index=True, right_index=True
    )

    # matches cohort. Require a phoneme boundary after the prefix so that e.g.
    # cohort "AH N" does not also match words beginning with "AH NG".
    cohort_pattern = rf"^{re.escape(cohort)}\b"
    matches_cohort = phonemic_forms[phonemic_forms.str.match(cohort_pattern)].index
    matches_cohort = pd.merge(cuts.reset_index(),
                              matches_cohort.to_frame(index=False),
                              how="inner",
                              on=["label", "instance_idx"]) \
        .query("frame_idx >= @cohort_length")

    matches_cohort_and_next_phoneme = matches_next_phoneme[
        matches_next_phoneme.traj_flat_idx.isin(matches_cohort.traj_flat_idx)]

    matches_cohort_and_next_phoneme_weak = matches_next_phoneme_weak[
        matches_next_phoneme_weak.traj_flat_idx.isin(matches_cohort.traj_flat_idx)]

    return {
        "matches_next_phoneme": matches_next_phoneme.traj_flat_idx.values,
        "matches_next_phoneme_weak": matches_next_phoneme_weak.traj_flat_idx.values,
        "matches_cohort": matches_cohort.traj_flat_idx.values,
        "matches_cohort_and_next_phoneme": matches_cohort_and_next_phoneme.traj_flat_idx.values,
        "matches_cohort_and_next_phoneme_weak": matches_cohort_and_next_phoneme_weak.traj_flat_idx.values,
    }

In [17]:
all_instances = []
all_prediction_equivalences = {}

# For each cohort + next phone, keep at most this many word types: the most
# frequent by `LogFreq` (words missing from the frequency table rank lowest),
# with possessive "'s" forms grouped with their base word. All tokens of every
# kept word type are retained.
max_items_per_cohort_and_next_phone = 15

# `label2idx` (word label -> index into state_space_spec.labels) is defined earlier in the notebook.
for cohort, next_phons in tqdm(expt_cohort.items(), total=len(expt_cohort)):
    for phon in next_phons:
        # Only continuations in the restricted phone set are evaluated.
        if phon not in next_phon_set:
            continue

        inflected_phones = f"{cohort} {phon}"
        instances = cut_phonemic_forms[cut_phonemic_forms.str.match(f"{inflected_phones}\\b")].index

        # Pick the top K labels with the highest frequency from the cohort.
        coh_labels = instances.get_level_values("label").str.replace("'s$", "", regex=True)
        if len(coh_labels) > max_items_per_cohort_and_next_phone:
            label_freqs = word_freq_df.reindex(coh_labels.unique()).LogFreq.fillna(word_freq_df.LogFreq.min())
            keep_labels = label_freqs.nlargest(max_items_per_cohort_and_next_phone).index
            instances = instances[coh_labels.isin(keep_labels)]
            # Report how many tokens survive the frequency filter.
            print(cohort, phon, len(instances))

        equiv_key = (inflected_phones,)
        if equiv_key not in all_prediction_equivalences:
            all_prediction_equivalences[equiv_key] = prepare_equivalences(cohort, phon)

        for label, instance_idx in instances:
            all_instances.append({
                "base_phones": cohort,
                "inflected_phones": inflected_phones,
                "post_divergence": phon,

                "inflection": phon,
                # Always True here (see the filter above); kept for a uniform schema.
                "next_phoneme_in_restricted_set": phon in next_phon_set,

                "cohort_length": target_cohort_length,
                "next_phoneme_idx": target_cohort_length,

                "inflected": label,
                "inflected_idx": label2idx[label],
                "inflected_instance_idx": instance_idx,
            })

  0%|          | 0/7 [00:00<?, ?it/s]

AH N AH 155
AH N D 456
AH N IH 89
AH N L 163
AH N M 32
AH N N 117
AH N S 67
AH N T 190
K ER AH 87
K ER IH 94
K ER L 43
K ER N 75
K ER S 26
K ER T 73
K OW D 23
K OW L 187
K OW M 21
K OW S 86
K OW T 103
L AY AH 76
L AY IH 98
L AY N 190
L AY S 23
L AY T 331
L AY Z 54
P ER D 16
P ER L 72
P ER M 137
P ER S 521
P ER T 225
P ER Z 39
P EY D 97
P EY IH 27
P EY L 120
P EY N 284
P EY S 72
P EY T 76
R OW D 215
R OW L 131
R OW M 223
R OW S 16
R OW T 87
R OW Z 195


In [18]:
# Same selection procedure as for the full cohorts, applied to the "small" cohorts.
# These instances are tagged with inflection "small-<phone>" so experiments can
# target them separately. Reuses max_items_per_cohort_and_next_phone and label2idx
# from earlier cells.
for cohort, next_phons in tqdm(expt_cohort_small.items(), total=len(expt_cohort_small)):
    for phon in next_phons:
        # Small cohorts were already restricted to next_phon_set, so this should not trigger.
        if phon not in next_phon_set:
            continue
        inflected_phones = f"{cohort} {phon}"
        instances = cut_phonemic_forms[cut_phonemic_forms.str.match(f"{inflected_phones}\\b")].index

        # Pick the top K labels with the highest frequency from the cohort.
        coh_labels = instances.get_level_values("label").str.replace("'s$", "", regex=True)
        if len(coh_labels) > max_items_per_cohort_and_next_phone:
            label_freqs = word_freq_df.reindex(coh_labels.unique()).LogFreq.fillna(word_freq_df.LogFreq.min())
            keep_labels = label_freqs.nlargest(max_items_per_cohort_and_next_phone).index
            instances = instances[coh_labels.isin(keep_labels)]

        # Prediction equivalence classes are shared with the full-cohort items by key.
        equiv_key = (inflected_phones,)
        if equiv_key not in all_prediction_equivalences:
            all_prediction_equivalences[equiv_key] = prepare_equivalences(cohort, phon)

        for label, instance_idx in instances:
            all_instances.append({
                "base_phones": cohort,
                "inflected_phones": inflected_phones,
                "post_divergence": phon,

                "inflection": f"small-{phon}",
                "next_phoneme_in_restricted_set": phon in next_phon_set,

                "cohort_length": target_cohort_length,
                "next_phoneme_idx": target_cohort_length,

                "inflected": label,
                "inflected_idx": label2idx[label],
                "inflected_instance_idx": instance_idx,
            })

  0%|          | 0/12 [00:00<?, ?it/s]

In [19]:
# One row per selected word token used in the analogy experiments.
all_instances_df = pd.DataFrame(all_instances)
all_instances_df

Unnamed: 0,base_phones,inflected_phones,post_divergence,inflection,next_phoneme_in_restricted_set,cohort_length,next_phoneme_idx,inflected,inflected_idx,inflected_instance_idx
0,AH N,AH N AH,AH,AH,True,2,2,another,2725,0
1,AH N,AH N AH,AH,AH,True,2,2,another,2725,1
2,AH N,AH N AH,AH,AH,True,2,2,another,2725,2
3,AH N,AH N AH,AH,AH,True,2,2,another,2725,3
4,AH N,AH N AH,AH,AH,True,2,2,another,2725,4
...,...,...,...,...,...,...,...,...,...,...
6305,W AW,W AW Z,Z,small-Z,True,2,2,wowzer,28693,2
6306,Z AY,Z AY AH,AH,small-AH,True,2,2,zion,25333,0
6307,Z AY,Z AY AH,AH,small-AH,True,2,2,zion,25333,1
6308,Z AY,Z AY D,D,small-D,True,2,2,zuyder,20487,0


In [20]:
# Spot-check: tokens selected for the word "awful".
all_instances_df.query("inflected == 'awful'")

Unnamed: 0,base_phones,inflected_phones,post_divergence,inflection,next_phoneme_in_restricted_set,cohort_length,next_phoneme_idx,inflected,inflected_idx,inflected_instance_idx
5548,AA F,AA F AH,AH,small-AH,True,2,2,awful,3571,4
5549,AA F,AA F AH,AH,small-AH,True,2,2,awful,3571,5
5550,AA F,AA F AH,AH,small-AH,True,2,2,awful,3571,16
5551,AA F,AA F AH,AH,small-AH,True,2,2,awful,3571,24
5552,AA F,AA F AH,AH,small-AH,True,2,2,awful,3571,33
5553,AA F,AA F AH,AH,small-AH,True,2,2,awful,3571,34
5554,AA F,AA F AH,AH,small-AH,True,2,2,awful,3571,35
5555,AA F,AA F AH,AH,small-AH,True,2,2,awful,3571,36
5556,AA F,AA F AH,AH,small-AH,True,2,2,awful,3571,39
5557,AA F,AA F AH,AH,small-AH,True,2,2,awful,3571,42


In [21]:
# Number of distinct words per (cohort, next phoneme) pair, ascending.
# `nunique` replaces `apply(len(unique))`, which triggered a groupby deprecation warning.
all_instances_df.groupby(["base_phones", "post_divergence"]).inflected.nunique().sort_values()

  all_instances_df.groupby(["base_phones", "post_divergence"]).apply(lambda xs: len(xs.inflected.unique())).sort_values()


base_phones  post_divergence
AA F         AH                  1
             L                   1
AH N         Z                   1
AY OW        AH                  1
R OW         IH                  1
                                ..
L AY         T                  16
AH N         AH                 16
R OW         Z                  16
P EY         T                  16
K OW         L                  17
Length: 106, dtype: int64

## Behavioral tests

In [30]:
from typing import Optional, TypeAlias
import torch
from src.analysis.analogy import nxn_cos_sim
import logging

L = logging.getLogger(__name__)

PredictionEquivalenceKey: TypeAlias = tuple
PredictionEquivalenceCollection: TypeAlias = dict[PredictionEquivalenceKey, dict[str, set[int]]]


def _resample_rows(rows: pd.DataFrame, n: int) -> pd.DataFrame:
    """Return exactly `n` rows: upsample with replacement, downsample without, else as-is."""
    if len(rows) < n:
        return rows.sample(n, replace=True)
    if len(rows) > n:
        return rows.sample(n, replace=False)
    return rows


def _phoneme_flat_idxs(rows: pd.DataFrame, flat_idx_lookup: dict, offset: int = 0) -> torch.Tensor:
    """Flat `agg` index of each row's phoneme at position ``next_phoneme_idx + offset``."""
    return torch.tensor([
        flat_idx_lookup[(label_idx, instance_idx, phoneme_idx + offset)]
        for label_idx, instance_idx, phoneme_idx in zip(
            rows.inflected_idx, rows.inflected_instance_idx, rows.next_phoneme_idx)
    ])


def iter_equivalences(
        config, all_cross_instances, agg_src: np.ndarray,
        num_samples=100, max_num_vector_samples=250,
        seed=None,):
    """
    Sample (from, to) analogy pairs and yield the `agg` indices needed for each.

    Args:
        config: experiment config. Optional keys: ``group_by`` (columns to
            group instances by), ``base_query`` / ``inflected_query`` (queries
            selecting "from" / "to" rows), ``all_query`` (applied to both), and
            ``equivalence_keys`` (columns defining equivalence classes; must
            include ``inflected_phones``; default ``["inflected_phones"]``).
        all_cross_instances: instance table with at least ``inflection``,
            ``inflected_phones``, ``inflected``, ``inflected_idx``,
            ``inflected_instance_idx``, ``next_phoneme_idx``, ``base_phones`` and
            ``post_divergence`` columns.
        agg_src: (label_idx, instance_idx, phoneme_idx) source of each `agg`
            row; numpy array or torch tensor.
        num_samples: maximum number of equivalence-class pairs sampled per group.
        max_num_vector_samples: maximum number of instances drawn per class.
        seed: if given, seeds numpy's *global* RNG. pandas ``sample`` without a
            ``random_state`` draws from it too.

    Yields:
        dict with pair metadata and the index tensors ``from_inflected_flat_idx``
        (next-phoneme frame of "from" instances) and ``from_base_flat_idx`` /
        ``to_base_flat_idx`` (the preceding frame of "from" / "to" instances).
    """
    # Pre-compute lookup from (label idx, instance idx, phoneme idx) to flat idx
    if isinstance(agg_src, torch.Tensor):
        agg_src = agg_src.cpu().numpy()
    flat_idx_lookup = {(label_idx, instance_idx, phoneme_idx): flat_idx
                       for flat_idx, (label_idx, instance_idx, phoneme_idx) in enumerate(agg_src)}

    if seed is not None:
        np.random.seed(seed)

    if "group_by" in config:
        grouper = all_cross_instances.groupby(config["group_by"])
    else:
        grouper = [("", all_cross_instances)]

    # Equivalences define the sets of instances whose representations are pooled
    # into one side of an analogy. We must group on at least the forms themselves.
    from_equivalence_keys = config.get("equivalence_keys", ["inflected_phones"])
    to_equivalence_keys = config.get("equivalence_keys", ["inflected_phones"])
    assert {"inflected_phones"} <= set(from_equivalence_keys)
    assert {"inflected_phones"} <= set(to_equivalence_keys)

    for group, rows in tqdm(grouper, leave=False):
        rows_from = rows.query(config["base_query"]) if "base_query" in config else rows
        rows_to = rows.query(config["inflected_query"]) if "inflected_query" in config else rows
        if "all_query" in config:
            rows_from = rows_from.query(config["all_query"])
            rows_to = rows_to.query(config["all_query"])

        if len(rows_from) == 0 or len(rows_to) == 0:
            continue
        inflection_from = rows_from.inflection.iloc[0]
        inflection_to = rows_to.inflection.iloc[0]

        from_equiv = rows_from.groupby(from_equivalence_keys)
        to_equiv = rows_to.groupby(to_equivalence_keys)
        from_equiv_labels = [k for k, count in from_equiv.size().items() if count >= 1]
        to_equiv_labels = [k for k, count in to_equiv.size().items() if count >= 1]

        # groupby drops NaN keys, so either side can still end up empty.
        if not from_equiv_labels or not to_equiv_labels:
            L.error(f"Skipping {group} due to insufficient labels")
            continue
        if len(set(from_equiv_labels) | set(to_equiv_labels)) <= 1:
            # not enough labels to support transfer.
            L.error(f"Skipping {group} due to insufficient labels")
            continue

        # Make sure labels are tuples
        if not isinstance(from_equiv_labels[0], tuple):
            from_equiv_labels = [(label,) for label in from_equiv_labels]
        if not isinstance(to_equiv_labels[0], tuple):
            to_equiv_labels = [(label,) for label in to_equiv_labels]

        # sample pairs of distinct equivalence classes
        candidate_pairs = [(x, y) for x, y in itertools.product(from_equiv_labels, to_equiv_labels) if x != y]
        num_samples_i = min(num_samples, len(candidate_pairs))
        samples = np.random.choice(len(candidate_pairs), num_samples_i, replace=False)

        for idx in tqdm(samples, leave=False):
            from_equiv_label_i, to_equiv_label_i = candidate_pairs[idx]
            rows_from_i = from_equiv.get_group(tuple(from_equiv_label_i))
            rows_to_i = to_equiv.get_group(tuple(to_equiv_label_i))

            # Draw the same number of instances from both classes (the order of
            # the two draws matters for RNG reproducibility).
            n = min(max_num_vector_samples, max(len(rows_from_i), len(rows_to_i)))
            rows_from_i = _resample_rows(rows_from_i, n)
            rows_to_i = _resample_rows(rows_to_i, n)

            # what are the "base" and "inflected" forms?
            from_base_phones = rows_from_i.base_phones.iloc[0].split()
            from_post_divergence = rows_from_i.post_divergence.iloc[0].split()
            to_base_phones = rows_to_i.base_phones.iloc[0].split()
            to_post_divergence = rows_to_i.post_divergence.iloc[0].split()

            # TODO design choice: do we take repr from previous phoneme or averaged over all previous
            # phonemes?
            yield {
                "group": group,

                "from_label": rows_from_i.inflected.iloc[0],
                "from_idx": rows_from_i.inflected_idx.iloc[0],
                "to_label": rows_to_i.inflected.iloc[0],
                "to_idx": rows_to_i.inflected_idx.iloc[0],

                "from_inflected_phones": rows_from_i.inflected_phones.iloc[0],
                "from_base_phones": " ".join(from_base_phones),
                "from_post_divergence": " ".join(from_post_divergence),

                "to_inflected_phones": rows_to_i.inflected_phones.iloc[0],
                "to_base_phones": " ".join(to_base_phones),
                "to_post_divergence": " ".join(to_post_divergence),

                "inflection_from": inflection_from,
                "inflection_to": inflection_to,
                "from_equiv_label_i": from_equiv_label_i,
                "to_equiv_label_i": to_equiv_label_i,

                "from_inflected_flat_idx": _phoneme_flat_idxs(rows_from_i, flat_idx_lookup, 0),
                "from_base_flat_idx": _phoneme_flat_idxs(rows_from_i, flat_idx_lookup, -1),
                "to_base_flat_idx": _phoneme_flat_idxs(rows_to_i, flat_idx_lookup, -1),
            }

def run_experiment_equiv_level(
        experiment_name, config,
        state_space_spec, all_cross_instances,
        agg, agg_src,
        device: str = "cpu",
        verbose=False,
        num_samples=100, max_num_vector_samples=250,
        seed=None,
        prediction_equivalences: Optional[PredictionEquivalenceCollection] = None,
        exclude_idxs_from_predictions: Optional[dict[tuple, list[int]]] = None,
        include_idxs_in_predictions: Optional[dict[int, list[int]]] = None):
    """
    Run one analogy experiment over pairs sampled by `iter_equivalences`.

    For each sample, the "to" word's next-phoneme representation is predicted as
    ``to_base + (from_inflected - from_base)`` (unit-normalized). Every row of `agg`
    is ranked by ``1 - nxn_cos_sim`` to the prediction, averaged over the sample's
    instances.

    Args:
        experiment_name: printed at the start of the run.
        config: experiment configuration forwarded to `iter_equivalences`.
        state_space_spec: provides ``labels`` (label index -> word label).
        all_cross_instances: instance table forwarded to `iter_equivalences`.
        agg, agg_src: flattened representations and their
            (label_idx, instance_idx, phoneme_idx) sources.
        device: torch device on which distances are computed.
        verbose: print the 5 nearest neighbors of every prediction.
        num_samples, max_num_vector_samples, seed: see `iter_equivalences`.
        prediction_equivalences: defines a collection of equivalence classes
            instantiating different prediction evaluations. Each equivalence
            class specifies, for a given prediction instance, a set of flat
            indices (indices into `agg` which should be counted as "correct")
            for that prediction instance. The config item `prediction_equivalence_keys`
            determines which properties of a sample returned by `iter_equivalences`
            are used to map from item to prediction instance.

            If this is `None`, then prediction success is determined based on
            `include_idxs_in_predictions` (label indices), with a backup to
            simply matching the single ground truth inflected label.
        exclude_idxs_from_predictions: maps ``(inflection_to, to_idx)`` to label
            indices whose representations may never be predicted.
        include_idxs_in_predictions: maps ``to_idx`` to label indices counted as
            correct when `prediction_equivalences` is None.

    Returns:
        DataFrame with one row per evaluated sample. With `prediction_equivalences`,
        target_*/predicted_* columns are prefixed with the evaluation name.

    Raises:
        ValueError: on inconsistent `prediction_equivalences` arguments.
    """
    print(experiment_name)

    prediction_equivalences_tensor = None
    if prediction_equivalences is not None:
        if include_idxs_in_predictions is not None:
            raise ValueError("Cannot specify both `prediction_equivalences` and `include_idxs_in_predictions`")
        if "prediction_equivalence_keys" not in config:
            raise ValueError("`prediction_equivalence_keys` must be specified in `config` if `prediction_equivalences` is provided")

        # Convert once, directly on `device`, instead of transferring per sample.
        prediction_equivalences_tensor = {
            key: {
                equiv: torch.tensor(flat_idxs, device=device)
                for equiv, flat_idxs in items.items()
            }
            for key, items in prediction_equivalences.items()
        }

    # move data to device (`as_tensor` also accepts inputs that already are tensors)
    agg = torch.as_tensor(agg).to(device)
    agg_src = torch.as_tensor(agg_src).to(device)

    results = []
    for sample in iter_equivalences(
            config, all_cross_instances, agg_src,
            num_samples=num_samples,
            max_num_vector_samples=max_num_vector_samples,
            seed=seed):

        ### Prepare evaluations (first, so unmatched samples skip the distance computation)

        # Map from evaluation name to a tensor of valid flat idxs for this prediction problem
        evaluations: dict[str, torch.Tensor] = {}
        if prediction_equivalences_tensor is not None:
            equivalence_key = tuple(sample[key] for key in config["prediction_equivalence_keys"])
            if equivalence_key not in prediction_equivalences_tensor:
                continue
            evaluations.update(prediction_equivalences_tensor[equivalence_key])
        else:
            if include_idxs_in_predictions is not None:
                valid_label_idxs = torch.tensor(list(include_idxs_in_predictions[sample["to_idx"]]),
                                                dtype=agg_src.dtype, device=device)
            else:
                valid_label_idxs = torch.tensor([sample["to_idx"]], dtype=agg_src.dtype, device=device)
            evaluations[""] = torch.where(torch.isin(agg_src[:, 0], valid_label_idxs))[0]

        # Critical analogy logic
        pair_difference = agg[sample["from_inflected_flat_idx"]] - agg[sample["from_base_flat_idx"]]
        pair_predicted = agg[sample["to_base_flat_idx"]] + pair_difference
        pair_predicted /= torch.norm(pair_predicted, dim=1, keepdim=True)

        ### Compute ranks over entire set of word tokens
        with torch.no_grad():
            dists = 1 - nxn_cos_sim(pair_predicted, agg)
            # mean over instances of pair
            dists = dists.mean(0)

        if exclude_idxs_from_predictions is not None:
            # Must live on the same device as `agg_src` for `torch.isin`.
            invalid_idxs = torch.tensor(
                list(exclude_idxs_from_predictions[sample["inflection_to"], sample["to_idx"]]),
                dtype=agg_src.dtype, device=device)
            invalid_flat_idxs = torch.where(torch.isin(agg_src[:, 0], invalid_idxs))[0]
            dists[invalid_flat_idxs] = torch.inf

        sorted_indices = dists.argsort()
        ranks = torch.zeros_like(sorted_indices)
        ranks[sorted_indices] = torch.arange(len(sorted_indices)).to(sorted_indices)

        ### Nearest neighbor: identical for every evaluation, so compute it once
        nearest_flat_idx = sorted_indices[0]
        nearest_neighbor = agg_src[nearest_flat_idx]
        predicted_label_idx = nearest_neighbor[0].item()
        predicted_instance_idx = nearest_neighbor[1].item()
        predicted_label = state_space_spec.labels[predicted_label_idx]
        predicted_fields = {
            "predicted_distance": dists[nearest_flat_idx].item(),
            "predicted_label_idx": predicted_label_idx,
            "predicted_instance_idx": predicted_instance_idx,
            "predicted_phoneme_idx": nearest_neighbor[2].item(),
            "predicted_label": predicted_label,
            "predicted_phones": cut_phonemic_forms.loc[predicted_label].loc[predicted_instance_idx],
        }

        ### Run evaluations
        evaluation_results = {}
        for evaluation, valid_flat_idxs in evaluations.items():
            if valid_flat_idxs.numel() == 0:
                # Nothing counts as correct for this evaluation; there is no target to rank.
                continue

            # target: highest-ranked valid embedding for this evaluation
            target_rank, target_subidx = torch.min(ranks[valid_flat_idxs], dim=0)
            target_rank, target_idx = target_rank.item(), valid_flat_idxs[target_subidx].item()
            target_label_idx = agg_src[target_idx, 0].item()
            target_instance_idx = agg_src[target_idx, 1].item()
            target_label = state_space_spec.labels[target_label_idx]

            evaluation_results[evaluation] = {
                "target_rank": target_rank,
                "target_distance": dists[target_idx].item(),
                "target_label_idx": target_label_idx,
                "target_instance_idx": target_instance_idx,
                "target_phoneme_idx": agg_src[target_idx, 2].item(),
                "target_label": target_label,
                "target_phones": cut_phonemic_forms.loc[target_label].loc[target_instance_idx],
                **predicted_fields,
            }

        if verbose:
            top_k = sorted_indices[:5]
            for dist, (label_idx, instance_idx, _) in zip(dists[top_k], agg_src[top_k]):
                print(f"{sample['group']} {sample['from_equiv_label_i']} -> {sample['to_equiv_label_i']}: "
                      f"{state_space_spec.labels[label_idx.item()]} {instance_idx.item()} ({dist.item():.4f})")

        # Merge into a single flat dictionary
        results_i = {}
        for evaluation_name, evaluation in evaluation_results.items():
            for key, value in evaluation.items():
                output_key = f"{evaluation_name}_{key}" if evaluation_name else key
                results_i[output_key] = value
        results_i.update({
            "group": sample["group"],

            "from": sample["from_label"],
            "to": sample["to_label"],

            "from_inflected_phones": sample["from_inflected_phones"],
            "from_base_phones": sample["from_base_phones"],
            "from_post_divergence": sample["from_post_divergence"],

            "to_inflected_phones": sample["to_inflected_phones"],
            "to_base_phones": sample["to_base_phones"],
            "to_post_divergence": sample["to_post_divergence"],

            "inflection_from": sample["inflection_from"],
            "inflection_to": sample["inflection_to"],
            "from_equiv_label": sample["from_equiv_label_i"],
            "to_equiv_label": sample["to_equiv_label_i"],
        })

        results.append(results_i)

    return pd.DataFrame(results)

In [23]:
# One experiment per ordered pair (a -> b) of restricted next phonemes, including a == b.
# Iterate in sorted order: iteration order of a set of strings depends on
# PYTHONHASHSEED, so sorting keeps experiment (and output row) order reproducible.
experiments = {
    f"{a}_to_{b}": {
        "base_query": f"inflection == '{a}'",
        "inflected_query": f"inflection == '{b}'",
        "equivalence_keys": ["inflected_phones", "inflected"],
        "prediction_equivalence_keys": ["to_inflected_phones"],
    }
    for a, b in itertools.product(sorted(next_phon_set), repeat=2)
}

In [24]:
# Transfer onto the "small" cohorts: each restricted-set phone as the source
# inflection, each small-cohort continuation as the target. The "-to-small-"
# part of the key is used later to split main vs. small results. Sorted iteration
# keeps the experiment order reproducible.
small_targets = all_instances_df[all_instances_df.inflection.str.startswith("small-")].inflection.str.split("small-").str[1].unique()
for phone in small_targets:
    for source_phone in sorted(next_phon_set):
        experiments[f"{source_phone}-to-small-{phone}"] = {
            "base_query": f"inflection == '{source_phone}'",
            "inflected_query": f"inflection == 'small-{phone}'",
            "equivalence_keys": ["inflected_phones", "inflected"],
            "prediction_equivalence_keys": ["to_inflected_phones"],
        }

In [None]:
# ret = run_experiment_equiv_level(
#     "tst",
#     # config={"base_query": "inflection == 'D'",
#     #         "inflected_query": "inflection == 'T'"},
#     config=experiments["D_to_D"],
#     state_space_spec=state_space_spec,
#     all_cross_instances=all_instances_df,
#     prediction_equivalences=all_prediction_equivalences,
#     agg=agg,
#     agg_src=agg_src,
#     num_samples=20,
#     device="cpu",
# )

tst


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

In [None]:
# These runs were done on GPU 2. Fall back to CPU when that device does not exist,
# so the notebook still runs end to end on other machines (more slowly).
experiment_device = "cuda:2" if torch.cuda.device_count() > 2 else "cpu"
experiment_results = pd.concat({
    experiment: run_experiment_equiv_level(
        experiment, config,
        state_space_spec, all_instances_df,
        agg, agg_src,
        prediction_equivalences=all_prediction_equivalences,
        num_samples=1000,
        max_num_vector_samples=100,
        seed=seed,
        device=experiment_device)
    for experiment, config in tqdm(experiments.items(), unit="experiment")
}, names=["experiment"])

In [None]:
# Save the raw per-sample results before any analysis columns are added below.
experiment_results.to_csv(f"{output_dir}/pseudocausal_broad_experiment_results.csv")

### Analyze

In [None]:
# With `prediction_equivalences`, `run_experiment_equiv_level` prefixes every
# target_*/predicted_* column with the evaluation name (e.g.
# "matches_next_phoneme_predicted_phones"). The predicted_* (nearest-neighbor)
# fields are identical across evaluations, so read them from one of them.
nn_column_prefix = "matches_next_phoneme_"


def _nn_column(columns, field):
    """Name of nearest-neighbor column `field`, with or without an evaluation prefix."""
    return field if field in columns else nn_column_prefix + field


def get_predicted_phone(row):
    """Predicted word's phoneme right after the cohort, or None if the word is too short."""
    predicted_phones = row[_nn_column(row.index, "predicted_phones")].split(" ")
    predicted_phone_idx = target_cohort_length
    if len(predicted_phones) - 1 < predicted_phone_idx:
        return None
    return predicted_phones[predicted_phone_idx]


def get_predicted_base_phones(row):
    """Predicted word's first `target_cohort_length` phonemes, space-joined."""
    predicted_phones = row[_nn_column(row.index, "predicted_phones")].split(" ")
    return " ".join(predicted_phones[:target_cohort_length])


experiment_results["predicted_phone"] = experiment_results.apply(get_predicted_phone, axis=1)
experiment_results["predicted_base_phones"] = experiment_results.apply(get_predicted_base_phones, axis=1)
experiment_results["correct"] = experiment_results.predicted_phone == experiment_results.to_post_divergence
experiment_results["correct_base"] = experiment_results.predicted_base_phones == experiment_results.to_base_phones
# Correct cohort but a different word than the ground-truth target. Results record
# the target word only as the label `to` (no ground-truth label index column), so
# compare word labels.
predicted_label_column = _nn_column(experiment_results.columns, "predicted_label")
experiment_results["correct_distinct_base"] = experiment_results.correct_base \
    & (experiment_results[predicted_label_column] != experiment_results["to"])
experiment_results["correct_or_correct_base"] = experiment_results.correct | experiment_results.correct_base
# Control: the target inflection (minus any "small-" prefix) differs from the source inflection.
experiment_results["control"] = experiment_results.inflection_to.str.split("-").str[-1] != experiment_results.inflection_from

# Set of post-cohort phonemes observed for each target cohort.
post_div_set = experiment_results.groupby("to_base_phones").apply(lambda xs: frozenset(xs.to_post_divergence))
experiment_results["post_div_set"] = experiment_results.to_base_phones.map(post_div_set)

In [None]:
# Split results into main experiments and small-cohort ("<source>-to-small-<phone>") experiments.
# `.copy()` gives each subset its own data, so later cells can add columns to them without
# triggering pandas' SettingWithCopyWarning.
is_small_experiment = experiment_results.index.get_level_values(0).str.contains("to-small-", regex=False)
main_results = experiment_results[~is_small_experiment].copy()
small_results = experiment_results[is_small_experiment].copy()

### Main experiment results

In [None]:
# Joint counts of (control, correct_base, correct) outcomes in the main experiments.
main_results.value_counts(subset=["control", "correct_base", "correct"]).sort_index()

In [None]:
# Proportion of each (correct_base, correct) outcome within control and within non-control trials.
# Normalize with a groupby-transform: since pandas 2.0, `groupby(...).apply(lambda xs: xs / xs.sum())`
# re-prepends the group key to a like-indexed result, duplicating the "control" index level.
(
    main_results[["control", "correct_base", "correct"]]
    .value_counts()
    .pipe(lambda counts: counts / counts.groupby(level="control").transform("sum"))
    .sort_index()
)

In [None]:
plot_all_phones = False

full_phone_list = sorted(next_phon_set)
if plot_all_phones:
    full_phone_list += sorted(set(main_results.predicted_phone.fillna("NA").unique()) - set(full_phone_list))
# Proportion of each predicted phone per (control, correct_base, inflection_from), with every
# combination present (missing combinations filled with 0).
heatmap_results = main_results \
    .groupby(["control", "correct_base", "inflection_from"]).predicted_phone.value_counts(normalize=True) \
    .reindex(pd.MultiIndex.from_product([[False, True], [False, True], sorted(next_phon_set), full_phone_list],
                                        names=["control", "correct_base", "inflection_from", "predicted_phone"])).fillna(0)

g = sns.FacetGrid(data=heatmap_results.reset_index(), row="control", col="correct_base", height=5, aspect=2 if plot_all_phones else 1.2, sharex=False, sharey=False)


def draw_heatmap(data, color=None, label=None, **kwargs):
    """Draw one facet's inflection_from x predicted_phone proportion heatmap.

    `FacetGrid.map_dataframe` passes the facet rows as `data` plus its own `color` (and, with a
    hue, `label`) keyword; those are not heatmap options and are dropped. All other keywords
    (e.g. `annot`, `cmap`) are forwarded to `sns.heatmap`.
    """
    proportions = data.pivot_table(index="inflection_from", columns="predicted_phone", values="proportion")
    sns.heatmap(proportions.reindex(full_phone_list, axis=1), **kwargs)


g.map_dataframe(draw_heatmap, annot=True, cmap="Blues");

In [None]:
plot_all_phones = True

full_phone_list = sorted(next_phon_set)
if plot_all_phones:
    # full_phone_list += sorted(set(main_results.predicted_phone.fillna("NA").unique()) - set(full_phone_list))
    # DEV: plot only the predicted phones outside the studied set.
    full_phone_list = sorted(set(main_results.predicted_phone.fillna("NA").unique()) - set(full_phone_list))
# Proportion of each predicted phone per (control, correct_base, inflection_from), with every
# combination present (missing combinations filled with 0).
heatmap_results = main_results \
    .groupby(["control", "correct_base", "inflection_from"]).predicted_phone.value_counts(normalize=True) \
    .reindex(pd.MultiIndex.from_product([[False, True], [False, True], sorted(next_phon_set), full_phone_list],
                                        names=["control", "correct_base", "inflection_from", "predicted_phone"])).fillna(0)

g = sns.FacetGrid(data=heatmap_results.reset_index(), row="control", col="correct_base", height=5, aspect=2 if plot_all_phones else 1, sharex=False, sharey=False)


def draw_heatmap(data, color=None, label=None, **kwargs):
    """Draw one facet's inflection_from x predicted_phone proportion heatmap.

    `FacetGrid.map_dataframe` passes the facet rows as `data` plus its own `color` (and, with a
    hue, `label`) keyword; those are not heatmap options and are dropped. All other keywords
    (e.g. `annot`, `cmap`) are forwarded to `sns.heatmap`.
    """
    proportions = data.pivot_table(index="inflection_from", columns="predicted_phone", values="proportion")
    sns.heatmap(proportions.reindex(full_phone_list, axis=1), **kwargs)


g.map_dataframe(draw_heatmap, annot=True, cmap="Blues");

In [None]:
# Sample of non-control trials with a wrong predicted phone, correct-base ones first.
# A fixed seed keeps the displayed rows reproducible. The size is capped at the number of
# available rows so `.sample` cannot raise on small result sets.
(
    main_results.query("not control and not correct")
    .loc[:, ["from_inflected_phones", "gt_label", "to_base_phones", "correct", "correct_base",
             "predicted_label", "predicted_phones", "from_post_divergence", "predicted_phone"]]
    .pipe(lambda df: df.sample(min(20, len(df)), random_state=seed))
    .sort_values(["correct", "correct_base"], ascending=False)
)

In [None]:
# Accuracy per target inflection, split by control vs. experimental trials.
sns.barplot(
    x="inflection_to",
    y="correct",
    hue="control",
    data=main_results,
)

In [None]:
# Distribution over source words of their mean accuracy (non-control trials only).
sns.displot(
    main_results
    .query("not control")
    .groupby(["from", "control"])["correct"]
    .mean()
    .sort_values()
)

In [None]:
# Distribution over target words of their mean accuracy (non-control trials only).
sns.displot(
    main_results
    .query("not control")
    .groupby(["to", "control"])["correct"]
    .mean()
    .sort_values()
)

In [None]:
# Entropy (bits) of the predicted-base distribution for each source base (non-control trials).
# Zero-probability cells yield NaN terms, which `sum` skips.
d = (
    main_results
    .query("not control")
    .groupby(["from_base_phones"])["predicted_base_phones"]
    .value_counts(normalize=True)
    .unstack()
    .fillna(0)
)
e = -(d * np.log2(d)).sum(axis=1)
e.sort_values()

In [None]:
# Per source context (word, base, post-divergence phone): distribution of predicted next phones
# and its entropy in bits. Zero-probability cells yield NaN terms, which `sum` skips.
diversity = (
    main_results
    .query("not control")
    .groupby(["from", "from_base_phones", "from_post_divergence"])["predicted_phone"]
    .value_counts(normalize=True)
    .unstack()
    .fillna(0)
)
entropy = -(diversity * np.log2(diversity)).sum(axis=1)
entropy.sort_values().tail(20)

In [None]:
# The 20 source contexts with the least diverse predicted next phone.
entropy.sort_values().iloc[:20]

In [None]:
# Per source context: distribution of predicted base phones and its entropy in bits.
base_diversity = (
    main_results
    .query("not control")
    .groupby(["from", "from_base_phones", "from_post_divergence"])["predicted_base_phones"]
    .value_counts(normalize=True)
    .unstack()
    .fillna(0)
)
# Entropy of `base_diversity` (not the predicted-phone `diversity`). Zero cells are masked to
# NaN so they contribute nothing to the sum (0 * log2(0) := 0) without log warnings.
base_entropy = -(base_diversity * np.log2(base_diversity.where(base_diversity > 0))).sum(axis=1)
base_entropy.sort_values().tail(20)

In [None]:
# Per target context (word, base, post-divergence phone): distribution of predicted next phones
# and its entropy in bits. Zero-probability cells yield NaN terms, which `sum` skips.
to_diversity = (
    main_results
    .query("not control")
    .groupby(["to", "to_base_phones", "to_post_divergence"])["predicted_phone"]
    .value_counts(normalize=True)
    .unstack()
    .fillna(0)
)
to_entropy = -(to_diversity * np.log2(to_diversity)).sum(axis=1)
to_entropy.sort_values().tail(20)

In [None]:
# Most probable predicted phones for target word "careering" (stacked across its contexts).
(
    to_diversity.loc["careering"]
    .melt()
    .sort_values(by="value", ascending=False)
    .head(10)
)

In [None]:
# Most probable predicted bases for source word "unseen" (stacked across its contexts).
(
    base_diversity.loc["unseen"]
    .melt()
    .sort_values(by="value", ascending=False)
    .head(20)
)

In [None]:
# Final phone of each source base. `.assign` returns a new frame instead of setting a column on
# what may be a filtered slice of `experiment_results` (SettingWithCopyWarning).
main_results = main_results.assign(
    from_base_final=main_results.from_base_phones.str.split(" ").str[-1]
)

In [None]:
# Inspect non-control trials whose target word is "licenses".
main_results.query("not control and `to` == 'licenses'").loc[
    :,
    ["from", "from_base_phones", "correct", "correct_base",
     "to", "predicted_label", "predicted_base_phones", "predicted_phone"],
]

In [None]:
# Per-target accuracy for targets with at least 10 non-control trials, worst first.
(
    main_results
    .query("not control")
    .groupby(["to", "to_base_phones"])["correct"]
    .agg(["mean", "count"])
    .query("count >= 10")
    .sort_values(by="mean")
)

In [None]:
# Per-source-base rate of predicting the correct base (bases with >= 4 trials), worst first.
(
    main_results
    .query("not control")
    .groupby(["from_base_phones"])["correct_base"]
    .agg(["mean", "count"])
    .query("count >= 4")
    .sort_values(by="mean")
)

In [None]:
# Attach source- and target-word log frequencies. Inner joins: trials whose words are missing
# from `word_freq_df` are dropped.
main_with_freq = (
    main_results
    .merge(word_freq_df.LogFreq.rename("from_freq"), left_on="from", right_index=True)
    .merge(word_freq_df.LogFreq.rename("to_freq"), left_on="to", right_index=True)
)

In [None]:
def get_mass(group):
    """Normalize word log-frequencies within one base-phone group.

    Keeps the first row per inflected word, inner-joins `word_freq_df.LogFreq` on it (words
    missing from `word_freq_df` are dropped), and returns each word's LogFreq divided by the
    group total, indexed by (inflected, post_divergence).
    """
    unique_words = group.drop_duplicates("inflected")
    with_freq = (
        unique_words
        .merge(word_freq_df.LogFreq, left_on="inflected", right_index=True)
        .set_index(["inflected", "post_divergence"])
    )
    # Alternative (sharper) weighting previously tried: LogFreq ** 10, renormalized.
    log_freq = with_freq.LogFreq
    return log_freq / log_freq.sum()


is_small_inflection = all_instances_df.inflection.str.startswith("small-")
masses = all_instances_df.loc[~is_small_inflection].groupby("base_phones").apply(get_mass)

In [None]:
# ECDF of the normalized log-frequency mass of each word, one curve per base-phone cohort.
sns.displot(
    data=masses.reset_index(),
    kind="ecdf",
    x="LogFreq",
    hue="base_phones",
)

In [None]:
# Per-source-word accuracy (words with >= 20 non-control trials) vs. source-word log frequency.
sns.regplot(
    data=(
        main_with_freq
        .query("not control")
        .groupby(["from", "from_freq"])["correct"]
        .agg(["mean", "count"])
        .reset_index()
        .query("count >= 20")
    ),
    x="from_freq",
    y="mean",
)

In [None]:
# Per-target-word accuracy (words with >= 20 non-control trials) vs. target-word log frequency.
sns.regplot(
    data=(
        main_with_freq
        .query("not control")
        .groupby(["to", "to_freq"])["correct"]
        .agg(["mean", "count"])
        .reset_index()
        .query("count >= 20")
    ),
    x="to_freq",
    y="mean",
)

### Small cohorts

In [None]:
# Joint counts of (control, correct_base, correct) outcomes in the small-cohort experiments.
small_results.value_counts(subset=["control", "correct_base", "correct"]).sort_index()

In [None]:
# Proportion of each (correct_base, correct) outcome within control and within non-control trials.
# Normalize with a groupby-transform: since pandas 2.0, `groupby(...).apply(lambda xs: xs / xs.sum())`
# re-prepends the group key to a like-indexed result, duplicating the "control" index level.
(
    small_results[["control", "correct_base", "correct"]]
    .value_counts()
    .pipe(lambda counts: counts / counts.groupby(level="control").transform("sum"))
    .sort_index()
)

In [None]:
def _post_div_contains_source(row):
    """Whether the source's post-divergence phone is in the target base's `post_div_set`.

    `post_div_set` is a frozenset when computed in this session, but a string repr after the
    results are reloaded from CSV, in which case it is evaluated back into a frozenset first.
    """
    post_divs = row.post_div_set
    if isinstance(post_divs, str):
        # NOTE(review): `eval` executes file contents. This is acceptable only because the CSV is
        # written by this notebook; never point it at untrusted files.
        post_divs = eval(post_divs)
    return row.from_post_divergence in post_divs


attested = small_results.apply(_post_div_contains_source, axis=1)
# Non-control trials are "main". Control trials split by whether the source's post-divergence
# phone is attested for the target base. `.assign` returns a new frame instead of writing into
# what may be a filtered slice of `experiment_results` (SettingWithCopyWarning).
small_results = small_results.assign(
    attested=attested,
    condition=np.select(
        [~small_results.control, attested],
        ["main", "control_attested"],
        default="control_unattested",
    ),
)

In [None]:
# Accuracy per small-cohort target, split into main / attested-control / unattested-control trials.
sns.catplot(
    data=small_results,
    kind="bar",
    x="inflection_to",
    y="correct",
    hue="condition",
    aspect=2.5,
)

In [None]:
plot_all_phones = False

full_phone_list = sorted(next_phon_set)
if plot_all_phones:
    full_phone_list += sorted(set(small_results.predicted_phone.fillna("NA").unique()) - set(full_phone_list))
# Proportion of each predicted phone per (control, correct_base, inflection_from), with every
# combination present (missing combinations filled with 0).
heatmap_results = small_results \
    .groupby(["control", "correct_base", "inflection_from"]).predicted_phone.value_counts(normalize=True) \
    .reindex(pd.MultiIndex.from_product([[False, True], [False, True], sorted(next_phon_set), full_phone_list],
                                        names=["control", "correct_base", "inflection_from", "predicted_phone"])).fillna(0)

g = sns.FacetGrid(data=heatmap_results.reset_index(), row="control", col="correct_base", height=5, aspect=2 if plot_all_phones else 1.2, sharex=False, sharey=False)


def draw_heatmap(data, color=None, label=None, **kwargs):
    """Draw one facet's inflection_from x predicted_phone proportion heatmap.

    `FacetGrid.map_dataframe` passes the facet rows as `data` plus its own `color` (and, with a
    hue, `label`) keyword; those are not heatmap options and are dropped. All other keywords
    (e.g. `annot`, `cmap`) are forwarded to `sns.heatmap`.
    """
    proportions = data.pivot_table(index="inflection_from", columns="predicted_phone", values="proportion")
    sns.heatmap(proportions.reindex(full_phone_list, axis=1), **kwargs)


g.map_dataframe(draw_heatmap, annot=True, cmap="Blues");

In [None]:
# Sample of non-control small-cohort trials, correct predictions first.
# A fixed seed keeps the displayed rows reproducible. The size is capped at the number of
# available rows so `.sample` cannot raise on small result sets.
(
    small_results.query("not control")
    .loc[:, ["from_inflected_phones", "gt_label", "to_inflected_phones", "correct", "correct_base",
             "predicted_label", "predicted_phones"]]
    .pipe(lambda df: df.sample(min(20, len(df)), random_state=seed))
    .sort_values(["correct", "correct_base"], ascending=False)
)

In [None]:
# Predicted-phone distribution for small-cohort experiments, per (target base, source inflection).
(
    experiment_results.loc[experiment_results.inflection_to.str.startswith("small-")]
    .groupby(["to_base_phones", "inflection_from"])["predicted_phone"]
    .value_counts(normalize=True)
    .unstack()
    .fillna(0)
)  # .reindex(columns=next_phon_set).fillna(0).sort_index().sort_index(axis=1)

In [None]:
# TODO look at a bunch of individual prediction examples to get an intuition for what is happening here.

In [None]:
# Small-cohort accuracy grouped by the set of attested post-divergence phones, worst first.
(
    experiment_results.loc[experiment_results.inflection_to.str.startswith("small-")]
    .groupby("post_div_set")["correct"]
    .agg(["count", "mean"])
    .sort_values(by="mean")
)

In [None]:
# Trials whose target is base "AA F" followed by post-divergence phone "T".
experiment_results.loc[
    (experiment_results.to_base_phones == "AA F")
    & (experiment_results.to_post_divergence == "T"),
    ["from_inflected_phones", "from_post_divergence", "to", "predicted_label", "predicted_phones"],
]

In [None]:
# Plot predicted phone distributions for predicted words with correct base
small_cohort_results = small_results.query("correct_base").groupby(["to_base_phones", "to_post_divergence"]).predicted_phone.value_counts(normalize=True).unstack().fillna(0).reindex(columns=next_phon_set).fillna(0).sort_index().sort_index(axis=1)
small_cohort_results

n_cols = 3
n_rows = int(np.ceil(small_cohort_results.index.get_level_values("to_base_phones").nunique() / n_cols))
fig, axes = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 4 * n_rows))

for i, (ax, ((base_phones, target), row)) in enumerate(zip(axes.flat, small_cohort_results.sample(n_rows * n_cols).sort_index().iterrows())):
    row = row.rename("accuracy").to_frame().reset_index()
    row["in_cohort"] = row.predicted_phone.isin(expt_cohort_small.loc[base_phones])
    sns.barplot(data=row, x="predicted_phone", y="accuracy", ax=ax)
    ax.set_title(f"{base_phones} + {target}")
    ax.set_xlabel("Predicted phone")
    ax.set_ylabel("Probability")
    ax.set_ylim(0, 1)
    ax.grid(axis="y")

plt.tight_layout()bhdvh

In [None]:
# Plot predicted-phone distributions for small-cohort trials whose predicted word has an INCORRECT base.
small_cohort_results = (
    small_results.query("not correct_base")
    .groupby(["to_base_phones", "to_post_divergence"]).predicted_phone
    .value_counts(normalize=True).unstack().fillna(0)
    .reindex(columns=next_phon_set).fillna(0)
    .sort_index().sort_index(axis=1)
)

n_cols = 3
n_rows = int(np.ceil(small_cohort_results.index.get_level_values("to_base_phones").nunique() / n_cols))
# One panel per sampled (base, target) row. Never ask `.sample` for more rows than exist.
n_panels = min(n_rows * n_cols, len(small_cohort_results))
fig, axes = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 4 * n_rows), squeeze=False)

panel_rows = small_cohort_results.sample(n_panels, random_state=seed).sort_index()
for ax, ((base_phones, target), probs) in zip(axes.ravel(), panel_rows.iterrows()):
    panel_data = probs.rename("probability").to_frame().reset_index()
    # Flag predicted phones listed in `expt_cohort_small` for this base.
    panel_data["in_cohort"] = panel_data.predicted_phone.isin(expt_cohort_small.loc[base_phones])
    sns.barplot(data=panel_data, x="predicted_phone", y="probability", hue="in_cohort", ax=ax)
    ax.set_title(f"{base_phones} + {target}")
    ax.set_xlabel("Predicted phone")
    ax.set_ylabel("Probability")
    ax.set_ylim(0, 1)
    ax.legend(title="In attested cohort?")
    ax.grid(axis="y")

# Hide grid cells that received no panel.
for ax in axes.ravel()[n_panels:]:
    ax.set_visible(False)

plt.tight_layout()

In [None]:
# Non-control accuracy per target post-divergence phone: small-cohort vs. main experiments.
sns.barplot(
    data=pd.concat(
        {
            "small": small_results.query("not control").groupby("to_post_divergence")["correct"].mean(),
            "main": main_results.query("not control").groupby("to_post_divergence")["correct"].mean(),
        },
        names=["size"],
    ).reset_index(),
    x="to_post_divergence",
    y="correct",
    hue="size",
)

### Save

In [None]:
# Persist the annotated experiment results (including the derived analysis columns).
experiment_results.to_csv(path_or_buf=f"{output_dir}/pseudocausal_broad_experiment_results.csv")

In [None]:
# Reload saved results; the first two CSV columns form the (experiment, row) MultiIndex.
# Object columns such as `post_div_set` come back as their string repr.
experiment_results = pd.read_csv(
    filepath_or_buffer=f"{output_dir}/pseudocausal_broad_experiment_results.csv",
    index_col=[0, 1],
)

In [None]:
# Display the (possibly reloaded) experiment results; pandas truncates long frames in the output.
experiment_results