In [70]:
from collections import defaultdict

from loguru import logger as L
import numpy as np
from omegaconf import OmegaConf
import pandas as pd
import torch
from tqdm.auto import tqdm

from src.analysis.state_space import StateSpaceAnalysisSpec

In [87]:
base_model = "w2v2_pc_8"

model_class = "ffff_32-pc-mAP1"#discrim-rnn_32-pc-mAP1"
model_name = "word_broad_10frames_fixedlen25"

train_dataset = "librispeech-train-clean-100"
dataset = train_dataset
# hidden_states_path = f"outputs/hidden_states/{base_model}/{train_dataset}.h5"
state_space_specs_path = f"outputs/state_space_specs/{train_dataset}/{base_model}/state_space_specs.h5"

pos_counts_path = "data/pos_counts.pkl"

experiment = "syllable_at_0"

output_dir = f"outputs/analogy_pseudocausal_broad/inputs/{train_dataset}/w2v2_pc/{experiment}"

seed = 42
max_samples_per_word = 100

metric = "cosine"

agg_fns = [
    ("mean_within_cut", "phoneme")
]

In [56]:
# load OmegaConf from yaml with `experiment`
config = OmegaConf.load(f"conf/experiments/analogy_pseudocausal/{experiment}.yaml")
config.unit_level = "syllable"
# DEV
config.next_units = None

## Load data

In [8]:
state_space_spec = StateSpaceAnalysisSpec.from_hdf5(state_space_specs_path, "word")
state_space_spec = state_space_spec.subsample_instances(max_samples_per_word)

In [None]:
cuts_df = state_space_spec.cuts.xs(config.unit_level, level="level").drop(columns=["onset_frame_idx", "offset_frame_idx"])
cuts_df["label_idx"] = cuts_df.index.get_level_values("label").map({l: i for i, l in enumerate(state_space_spec.labels)})
cuts_df["frame_idx"] = cuts_df.groupby(["label", "instance_idx"]).cumcount()
cuts_df = cuts_df.reset_index().set_index(["label_idx", "instance_idx", "frame_idx"]).sort_index()

In [50]:
if type(cuts_df.description.iloc[0]) == tuple:
    cuts_df["description"] = cuts_df.description.apply(''.join)
cut_forms = cuts_df.groupby(["label", "instance_idx"]).description.agg(' '.join)

In [51]:
word_freq_df = pd.read_csv("data/WorldLex_Eng_US.Freq.2.txt", sep="\t", index_col="Word")
word_freq_df = word_freq_df.loc[~word_freq_df.index.duplicated()]
# compute weighted average frequency across domains
word_freq_df["BlogFreq_rel"] = word_freq_df.BlogFreq / word_freq_df.BlogFreq.sum()
word_freq_df["TwitterFreq_rel"] = word_freq_df.TwitterFreq / word_freq_df.TwitterFreq.sum()
word_freq_df["NewsFreq_rel"] = word_freq_df.NewsFreq / word_freq_df.NewsFreq.sum()
word_freq_df["Freq"] = word_freq_df[["BlogFreq_rel", "TwitterFreq_rel", "NewsFreq_rel"]].mean(axis=1) \
    * word_freq_df[["BlogFreq", "TwitterFreq", "NewsFreq"]].sum().mean()
word_freq_df["LogFreq"] = np.log10(word_freq_df.Freq)

## Prepare cohorts

In [71]:
if config.next_units is None or not config.next_units.strip():
    all_next_units = cuts_df.description.value_counts()
    if len(all_next_units) > 100:
        L.warning("Next unit set is large, taking the top 100")
        all_next_units = all_next_units[:100]
    next_unit_set = set(all_next_units.index)
else:
    next_unit_set = set(config.next_units.strip().split())
assert config.target_small_cohort_size < len(next_unit_set)



In [73]:
cohorts = defaultdict(set)
for units in tqdm(cut_forms.unique()):
    units = tuple(units.split())
    for i in range(len(units) + 1):
        cohorts[units[:i]].add(units)

csz_next = pd.DataFrame([(" ".join(coh), " ".join(item), item[len(coh)]) for coh, items in cohorts.items()
                            for item in items if len(item) > len(coh)],
                            columns=["cohort", "item", "next_unit"])

  0%|          | 0/32460 [00:00<?, ?it/s]

In [74]:
if config.target_cohort_length == 0:
    expt_cohort = csz_next[csz_next.cohort == ""]
else:
    expt_cohort = csz_next[csz_next.cohort.str.count(" ") == config.target_cohort_length - 1]

# removed constraint from below -- don't only include cohorts which cover all next units
# .groupby("cohort").filter(lambda xs: set(xs.next_unit) >= next_unit_set) \

expt_cohort = expt_cohort \
    .groupby("cohort").apply(lambda xs: sorted(set(xs.next_unit)))
expt_cohort

  .groupby("cohort").apply(lambda xs: sorted(set(xs.next_unit)))


cohort
    [AA, AAB, AAD, AADZ, AAFT, AAG, AAK, AAKS, AAL...
dtype: object

In [75]:
# Now search for type-small cohorts -- cohorts which only have N of the phone set
if config.target_cohort_length == 0:
    expt_cohort_small = csz_next[csz_next.cohort == ""]
else:
    expt_cohort_small = csz_next[csz_next.cohort.str.count(" ") == config.target_cohort_length - 1]
expt_cohort_small = expt_cohort_small \
    .groupby("cohort").filter(lambda xs: len(set(xs.next_unit)) == config.target_small_cohort_size and set(xs.next_unit) <= next_unit_set) \
    .groupby("cohort").apply(lambda xs: sorted(set(xs.next_unit)))
expt_cohort_small

  .groupby("cohort").apply(lambda xs: sorted(set(xs.next_unit)))


Unnamed: 0_level_0,cohort,item,next_unit
cohort,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1


### Prepare instance-level metadata

In [76]:
all_instances = []

# Sample at most this many combinations of cohort + next unit
max_items_per_cohort_and_next_unit = 15

label2idx = {l: i for i, l in enumerate(state_space_spec.labels)}
for cohort, next_units in tqdm(expt_cohort.items(), total=len(expt_cohort)):
    for unit in next_units:
        if unit not in next_unit_set:
            continue

        inflected_phones = f"{cohort} {unit}" if cohort else unit
        instances = cut_forms[cut_forms.str.match(f"{inflected_phones}\\b")].index

        # Pick the top K labels with the highest frequency from the cohort.
        coh_labels = instances.get_level_values("label").str.replace("'s$", "", regex=True)
        if len(coh_labels) > max_items_per_cohort_and_next_unit:
            label_freqs = word_freq_df.reindex(coh_labels.unique()).LogFreq.fillna(word_freq_df.LogFreq.min())
            keep_labels = label_freqs.nlargest(max_items_per_cohort_and_next_unit).index
            instances = instances[coh_labels.isin(keep_labels)]
            print(cohort, unit, len(instances))

        # equiv_key = (inflected_phones,)
        # if equiv_key not in all_prediction_equivalences:
        #     all_prediction_equivalences[equiv_key] = \
        #         analogy_pseudocausal.prepare_prediction_equivalences(cuts_df, cut_phonemic_forms, cohort, phon)

        for label, instance_idx in instances:
            all_instances.append({
                "base_phones": cohort,
                "inflected_phones": inflected_phones,
                "post_divergence": unit,

                "inflection": unit,
                "next_unit_in_restricted_set": unit in next_unit_set,

                "cohort_length": config.target_cohort_length,
                "next_phoneme_idx": config.target_cohort_length,

                "inflected": label,
                "inflected_idx": label2idx[label],
                "inflected_instance_idx": instance_idx,
            })

  0%|          | 0/1 [00:00<?, ?it/s]

 AA 544
 AE 559
 AH 1441
 AHN 807
 AHS 100
 AY 690
 BAHL 16
 BER 126
 BIH 971
 BIY 725
 CHER 45
 DAH 326
 DAHN 130
 DER 329
 DIH 690
 DIHD 100
 DIY 218
 EH 713
 EHN 415
 ER 596
 EY 305
 FAH 244
 FAOR 558
 FER 591
 FIH 489
 HHAE 703
 IH 524
 IHK 861
 IHM 451
 IHN 904
 IHNG 255
 IY 570
 JHAH 199
 KAAN 356
 KAH 688
 KAHL 82
 KAHM 572
 KAHN 742
 KIHNG 180
 LAH 213
 LER 72
 LEY 605
 LIH 640
 LIHNG 84
 LIY 339
 LOW 342
 MAE 639
 MAH 483
 MAHN 35
 MEH 860
 MER 232
 MIH 604
 MIY 345
 NAH 265
 NER 109
 NEY 550
 NIH 120
 NIY 275
 OW 739
 PAA 514
 PAH 524
 PER 790
 PIY 468
 PRAH 394
 RAH 228
 REH 311
 RIH 646
 RIY 471
 SAH 751
 SAHN 390
 SAY 370
 SEH 397
 SER 728
 SIH 873
 SIHNG 279
 SIHZ 17
 SIY 589
 TAH 316
 TEH 373
 TER 243
 TEY 389
 TIH 117
 TIHNG 17
 TIY 265
 VER 215
 VIH 483
 WIH 579
 YUW 560


In [77]:
for cohort, next_phons in tqdm(expt_cohort_small.items(), total=len(expt_cohort_small)):
    for phon in next_phons:
        if phon not in next_unit_set:
            continue
        inflected_phones = f"{cohort} {phon}" if cohort else phon
        instances = cut_forms[cut_forms.str.match(f"{inflected_phones}\\b")].index

        # Pick the top K labels with the highest frequency from the cohort.
        coh_labels = instances.get_level_values("label").str.replace("'s$", "", regex=True)
        if len(coh_labels) > max_items_per_cohort_and_next_unit:
            label_freqs = word_freq_df.reindex(coh_labels.unique()).LogFreq.fillna(word_freq_df.LogFreq.min())
            keep_labels = label_freqs.nlargest(max_items_per_cohort_and_next_unit).index
            instances = instances[coh_labels.isin(keep_labels)]

        # equiv_key = (inflected_phones,)
        # if equiv_key not in all_prediction_equivalences:
        #     all_prediction_equivalences[equiv_key] = \
        #         analogy_pseudocausal.prepare_prediction_equivalences(cuts_df, cut_phonemic_forms,
        #                                                              cohort, phon)

        for label, instance_idx in instances:
            all_instances.append({
                "base_phones": cohort,
                "inflected_phones": inflected_phones,
                "post_divergence": phon,

                "inflection": f"small-{phon}",
                "next_phoneme_in_restricted_set": phon in next_unit_set,

                "cohort_length": config.target_cohort_length,
                "next_phoneme_idx": config.target_cohort_length,

                "inflected": label,
                "inflected_idx": label2idx[label],
                "inflected_instance_idx": instance_idx,
            })

0it [00:00, ?it/s]

In [78]:
all_instances_df = pd.DataFrame(all_instances)

In [89]:
state_space_spec.to_hdf5(f"{output_dir}/state_space_spec.h5")

  value = self._g_getattr(self._v_node, name)
your performance may suffer as PyTables will pickle object types that it cannot
map directly to c-types [inferred_type->mixed,key->block0_values] [items->Index(['description'], dtype='object')]

  self.cuts.to_hdf(path, key=cuts_key, mode="a")


In [90]:
all_instances_df.to_csv(f"{output_dir}/instances.csv")