In [34]:
from collections import defaultdict

import numpy as np
import pandas as pd
import torch
from tqdm.auto import tqdm

from src.analysis import analogy_pseudocausal
from src.analysis.state_space import StateSpaceAnalysisSpec, \
    prepare_state_trajectory, aggregate_state_trajectory, flatten_trajectory

In [24]:
base_model = "w2v2_pc_8"

model_class = "ffff_32-pc-mAP1"#discrim-rnn_32-pc-mAP1"
model_name = "word_broad_10frames_fixedlen25"

train_dataset = "librispeech-train-clean-100"
# hidden_states_path = f"outputs/hidden_states/{base_model}/{train_dataset}.h5"
hidden_states_path = f"/scratch/jgauthier/{base_model}_{train_dataset}.h5"
state_space_specs_path = f"outputs/analogy/inputs/{train_dataset}/w2v2_pc/state_space_spec.h5"
embeddings_path = f"outputs/model_embeddings/{train_dataset}/{base_model}/{model_class}/{model_name}/{train_dataset}.npy"

output_dir = f"outputs/analogy_pseudocausal_broad/inputs/{train_dataset}/w2v2_pc/"

pos_counts_path = "data/pos_counts.pkl"

seed = 42

metric = "cosine"

agg_fns = [
    ("mean_within_cut", "phoneme")
]

## Load data

In [18]:
with open(embeddings_path, "rb") as f:
    model_representations: np.ndarray = np.load(f)
state_space_spec = StateSpaceAnalysisSpec.from_hdf5(state_space_specs_path)
assert state_space_spec.is_compatible_with(model_representations)

  value = self._g_getattr(self._v_node, name)


In [22]:
trajectory = prepare_state_trajectory(model_representations, state_space_spec)
trajectory = aggregate_state_trajectory(trajectory, state_space_spec, agg_fns[0])

  0%|          | 0/32046 [00:00<?, ?it/s]

Aggregating:   0%|          | 0/32046 [00:00<?, ?label/s]

In [23]:
agg, agg_src = flatten_trajectory(trajectory)

In [25]:
cuts_df = state_space_spec.cuts.xs("phoneme", level="level").drop(columns=["onset_frame_idx", "offset_frame_idx"])
cuts_df["label_idx"] = cuts_df.index.get_level_values("label").map({l: i for i, l in enumerate(state_space_spec.labels)})
cuts_df["frame_idx"] = cuts_df.groupby(["label", "instance_idx"]).cumcount()
cuts_df = cuts_df.reset_index().set_index(["label_idx", "instance_idx", "frame_idx"]).sort_index()

agg_flat_idxs = pd.Series(list(range(len(agg_src))),
                          index=pd.MultiIndex.from_tuples([tuple(xs) for xs in agg_src],
                                                          names=["label_idx", "instance_idx", "frame_idx"]))
cuts_df = pd.merge(cuts_df, agg_flat_idxs.rename("traj_flat_idx"), left_index=True, right_index=True)

In [26]:
cut_phonemic_forms = cuts_df.groupby(["label", "instance_idx"]).description.agg(' '.join)

In [27]:
word_freq_df = pd.read_csv("data/WorldLex_Eng_US.Freq.2.txt", sep="\t", index_col="Word")
word_freq_df = word_freq_df.loc[~word_freq_df.index.duplicated()]
# compute weighted average frequency across domains
word_freq_df["BlogFreq_rel"] = word_freq_df.BlogFreq / word_freq_df.BlogFreq.sum()
word_freq_df["TwitterFreq_rel"] = word_freq_df.TwitterFreq / word_freq_df.TwitterFreq.sum()
word_freq_df["NewsFreq_rel"] = word_freq_df.NewsFreq / word_freq_df.NewsFreq.sum()
word_freq_df["Freq"] = word_freq_df[["BlogFreq_rel", "TwitterFreq_rel", "NewsFreq_rel"]].mean(axis=1) \
    * word_freq_df[["BlogFreq", "TwitterFreq", "NewsFreq"]].sum().mean()
word_freq_df["LogFreq"] = np.log10(word_freq_df.Freq)

## Prepare cohorts

In [28]:
next_phon_set = set("AH ER IH L S Z T D M N".split())
target_cohort_length = 2
# defines an alternative "small" cohort: prefixes which have only N of the above phones
target_small_cohort_size = 3
assert target_small_cohort_size < len(next_phon_set)

In [29]:
cohorts = defaultdict(set)
for phones in tqdm(cut_phonemic_forms.unique()):
    phones = tuple(phones.split())
    for i in range(len(phones)):
        cohorts[phones[:i + 1]].add(phones)

csz_next = pd.DataFrame([(" ".join(coh), " ".join(item), item[len(coh)]) for coh, items in cohorts.items()
                            for item in items if len(item) > len(coh)],
                            columns=["cohort", "item", "next_phoneme"])

  0%|          | 0/32462 [00:00<?, ?it/s]

In [30]:
expt_cohort = csz_next[csz_next.cohort.str.count(" ") == target_cohort_length - 1] \
    .groupby("cohort").filter(lambda xs: set(xs.next_phoneme) >= next_phon_set) \
    .groupby("cohort").apply(lambda xs: sorted(set(xs.next_phoneme)))
expt_cohort

  .groupby("cohort").apply(lambda xs: sorted(set(xs.next_phoneme)))


cohort
AH N    [AA, AE, AH, AO, AW, AY, B, CH, D, EH, ER, EY,...
K ER    [AA, AE, AH, AO, AW, B, CH, D, EH, ER, EY, F, ...
K OW    [AH, B, CH, D, ER, HH, IH, IY, JH, K, L, M, N,...
L AY    [AE, AH, B, D, DH, ER, F, IH, K, L, M, N, P, R...
P ER    [AA, AE, AH, AY, B, CH, D, EH, ER, EY, F, G, H...
P EY    [AH, D, ER, G, IH, JH, L, M, N, P, S, SH, T, T...
R OW    [AH, B, D, ER, G, HH, IH, IY, K, L, M, N, P, S...
dtype: object

In [14]:
# Now search for type-small cohorts -- cohorts which only have N of the phone set
expt_cohort_small = csz_next[csz_next.cohort.str.count(" ") == target_cohort_length - 1].groupby("cohort").filter(lambda xs: len(set(xs.next_phoneme)) == target_small_cohort_size and set(xs.next_phoneme) <= next_phon_set) \
    .groupby("cohort").apply(lambda xs: sorted(set(xs.next_phoneme)))
expt_cohort_small

  .groupby("cohort").apply(lambda xs: sorted(set(xs.next_phoneme)))


cohort
AA F      [AH, L, T]
AO G     [AH, ER, M]
AY OW     [AH, L, T]
EH TH    [AH, IH, N]
ER EY      [D, N, Z]
ER JH    [AH, D, IH]
ER OW      [M, N, Z]
F UW       [D, L, Z]
K OY       [L, N, T]
TH AH      [D, M, N]
W AW       [N, T, Z]
Z AY      [AH, D, S]
dtype: object

### Prepare instance-level metadata

In [31]:
all_instances = []
all_prediction_equivalences = {}

# Sample at most this many combinations of cohort + next phone
max_items_per_cohort_and_next_phone = 15

label2idx = {l: i for i, l in enumerate(state_space_spec.labels)}
for cohort, next_phons in tqdm(expt_cohort.items(), total=len(expt_cohort)):
    for phon in next_phons:
        if phon not in next_phon_set:
            continue

        inflected_phones = f"{cohort} {phon}"
        instances = cut_phonemic_forms[cut_phonemic_forms.str.match(f"{inflected_phones}\\b")].index

        # Pick the top K labels with the highest frequency from the cohort.
        coh_labels = instances.get_level_values("label").str.replace("'s$", "", regex=True)
        if len(coh_labels) > max_items_per_cohort_and_next_phone:
            label_freqs = word_freq_df.reindex(coh_labels.unique()).LogFreq.fillna(word_freq_df.LogFreq.min())
            keep_labels = label_freqs.nlargest(max_items_per_cohort_and_next_phone).index
            instances = instances[coh_labels.isin(keep_labels)]
            print(cohort, phon, len(instances))
        
        equiv_key = (inflected_phones,)
        if equiv_key not in all_prediction_equivalences:
            all_prediction_equivalences[equiv_key] = \
                analogy_pseudocausal.prepare_prediction_equivalences(cuts_df, cut_phonemic_forms, cohort, phon)

        for label, instance_idx in instances:
            all_instances.append({
                "base_phones": cohort,
                "inflected_phones": inflected_phones,
                "post_divergence": phon,

                "inflection": phon,
                "next_phoneme_in_restricted_set": phon in next_phon_set,

                "cohort_length": target_cohort_length,
                "next_phoneme_idx": target_cohort_length,

                "inflected": label,
                "inflected_idx": label2idx[label],
                "inflected_instance_idx": instance_idx,
            })

  0%|          | 0/7 [00:00<?, ?it/s]

AH N AH 155
AH N D 456
AH N IH 89
AH N L 163
AH N M 32
AH N N 117
AH N S 67
AH N T 190
K ER AH 87
K ER IH 94
K ER L 43
K ER N 75
K ER S 26
K ER T 73
K OW D 23
K OW L 187
K OW M 21
K OW S 86
K OW T 103
L AY AH 76
L AY IH 98
L AY N 190
L AY S 23
L AY T 331
L AY Z 54
P ER D 16
P ER L 72
P ER M 137
P ER S 521
P ER T 225
P ER Z 39
P EY D 97
P EY IH 27
P EY L 120
P EY N 284
P EY S 72
P EY T 76
R OW D 215
R OW L 131
R OW M 223
R OW S 16
R OW T 87
R OW Z 195


In [37]:
for cohort, next_phons in tqdm(expt_cohort_small.items(), total=len(expt_cohort_small)):
    for phon in next_phons:
        if phon not in next_phon_set:
            continue
        inflected_phones = f"{cohort} {phon}"
        instances = cut_phonemic_forms[cut_phonemic_forms.str.match(f"{inflected_phones}\\b")].index

        # Pick the top K labels with the highest frequency from the cohort.
        coh_labels = instances.get_level_values("label").str.replace("'s$", "", regex=True)
        if len(coh_labels) > max_items_per_cohort_and_next_phone:
            label_freqs = word_freq_df.reindex(coh_labels.unique()).LogFreq.fillna(word_freq_df.LogFreq.min())
            keep_labels = label_freqs.nlargest(max_items_per_cohort_and_next_phone).index
            instances = instances[coh_labels.isin(keep_labels)]

        equiv_key = (inflected_phones,)
        if equiv_key not in all_prediction_equivalences:
            all_prediction_equivalences[equiv_key] = \
                analogy_pseudocausal.prepare_prediction_equivalences(cuts_df, cut_phonemic_forms,
                                                                     cohort, phon)

        for label, instance_idx in instances:
            all_instances.append({
                "base_phones": cohort,
                "inflected_phones": inflected_phones,
                "post_divergence": phon,

                "inflection": f"small-{phon}",
                "next_phoneme_in_restricted_set": phon in next_phon_set,

                "cohort_length": target_cohort_length,
                "next_phoneme_idx": target_cohort_length,

                "inflected": label,
                "inflected_idx": label2idx[label],
                "inflected_instance_idx": instance_idx,
            })

  0%|          | 0/12 [00:00<?, ?it/s]

In [38]:
all_instances_df = pd.DataFrame(all_instances)
all_instances_df

Unnamed: 0,base_phones,inflected_phones,post_divergence,inflection,next_phoneme_in_restricted_set,cohort_length,next_phoneme_idx,inflected,inflected_idx,inflected_instance_idx
0,AH N,AH N AH,AH,AH,True,2,2,another,2725,0
1,AH N,AH N AH,AH,AH,True,2,2,another,2725,1
2,AH N,AH N AH,AH,AH,True,2,2,another,2725,2
3,AH N,AH N AH,AH,AH,True,2,2,another,2725,3
4,AH N,AH N AH,AH,AH,True,2,2,another,2725,4
...,...,...,...,...,...,...,...,...,...,...
6305,W AW,W AW Z,Z,small-Z,True,2,2,wowzer,28693,2
6306,Z AY,Z AY AH,AH,small-AH,True,2,2,zion,25333,0
6307,Z AY,Z AY AH,AH,small-AH,True,2,2,zion,25333,1
6308,Z AY,Z AY D,D,small-D,True,2,2,zuyder,20487,0


In [39]:
all_instances_df.to_csv(f"{output_dir}/instances.csv")

In [40]:
torch.save(all_prediction_equivalences, f"{output_dir}/prediction_equivalences.pt")

In [41]:
torch.save([agg, agg_src], f"{output_dir}/trajectories.pt")