Prepare state space trajectories for a lexical analysis.

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os
# Set CUDA_VISIBLE_DEVICES to the empty string before torch is imported in the
# next cell (presumably to keep this notebook on CPU -- confirm).
os.environ["CUDA_VISIBLE_DEVICES"] = ""

In [None]:
from collections import Counter, defaultdict
import itertools
import json
from pathlib import Path
import pickle
from typing import Any

import datasets
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.linear_model import RidgeCV
from sklearn.model_selection import train_test_split, cross_val_score
import torch
import transformers
from tqdm.auto import tqdm

from src.analysis.state_space import StateSpaceAnalysisSpec

In [None]:
# use a word-level equivalence dataset regardless of model, so that we can look up cohort facts
# Pickled equivalence dataset; unpickled in a later cell.
equiv_dataset_path = "data/timit_equiv_phoneme_within_word_prefix_6_1.pkl"
# On-disk dataset directory, read with `datasets.load_from_disk` in a later cell.
timit_corpus_path = "data/timit_syllables"

# Destination of the pickled StateSpaceAnalysisSpec built by this notebook.
out = "out/state_space_specs/all_words.pkl"

In [None]:
# Unpickle the equivalence dataset.
# NOTE(review): pickle.load executes arbitrary code from the file; only point
# `equiv_dataset_path` at files produced by a trusted pipeline.
with open(equiv_dataset_path, "rb") as f:
    equiv_dataset = pickle.load(f)

In [None]:
# Load the corpus from `timit_corpus_path`; later cells index it by split
# name ("train") and also map over the whole object.
timit_corpus = datasets.load_from_disk(timit_corpus_path)

In [None]:
# Fail fast unless every class label is a tuple (word-prefix labels).
# isinstance is used rather than an exact type comparison so that tuple
# subclasses (e.g. namedtuples) are accepted as well.
assert all(isinstance(label, tuple) for label in equiv_dataset.class_labels), "Assumes dataset with word prefix labels"

## Prepare cohort data

In [None]:
# Indexed by item index in `process_item` below, where each entry is unpacked
# as a (start_frame, stop_frame) pair.
equiv_frames_by_item = equiv_dataset.hidden_state_dataset.frames_by_item

In [None]:
# Accumulators filled as a side effect of `process_item`.
frame_spans_by_word = defaultdict(list)  # word -> [(start_frame, stop_frame), ...]
cuts_df = []  # one dict per syllable / phoneme cut; turned into a DataFrame later


def process_item(item, idx):
    """Record word frame spans and syllable/phoneme cuts for one corpus item.

    Positions stored in the item's annotations are rescaled by the ratio of
    stored frames to ``len(item["input_values"])`` and offset by the item's
    first frame, so every recorded index is in the frame index space of
    ``equiv_frames_by_item``.
    (Presumably the annotation positions are sample offsets into
    ``input_values`` -- confirm against the corpus preprocessing.)

    Args:
        item: A corpus example providing ``input_values``,
            ``word_syllable_detail``, ``word_phonemic_detail`` and
            ``word_detail["utterance"]``.
        idx: Index of the item, used to look up its frame span.

    Returns:
        None. Results are appended to ``frame_spans_by_word`` and ``cuts_df``.
    """
    item_start, item_stop = equiv_frames_by_item[idx]
    # Stored frames per element of the raw input for this item.
    frames_per_input = (item_stop - item_start) / len(item["input_values"])

    def to_frame(position):
        # Truncating conversion from an annotation position to a frame index.
        return item_start + int(position * frames_per_input)

    for word_idx, syllables in enumerate(item["word_syllable_detail"]):
        # Words without syllable annotations are skipped entirely.
        if not syllables:
            continue

        word_span = (to_frame(syllables[0]["start"]), to_frame(syllables[-1]["stop"]))
        word = item["word_detail"]["utterance"][word_idx]

        # The instance index is the number of occurrences of this word seen so far.
        seen_spans = frame_spans_by_word[word]
        instance_idx = len(seen_spans)
        seen_spans.append(word_span)

        def add_cut(level, description, unit):
            cuts_df.append({
                "label": word,
                "instance_idx": instance_idx,
                "level": level,
                "description": description,
                "onset_frame_idx": to_frame(unit["start"]),
                "offset_frame_idx": to_frame(unit["stop"]),
                "item_idx": idx,
            })

        for syllable in syllables:
            add_cut("syllable", tuple(syllable["phones"]), syllable)

        for phoneme in item["word_phonemic_detail"][word_idx]:
            add_cut("phoneme", phoneme["phone"], phoneme)


# `map` is used only for its side effects on the accumulators above.
timit_corpus["train"].map(process_item, with_indices=True)
# Trailing None keeps the notebook cell from displaying the map result.
None

In [None]:
# Sanity check: we should have Q assignments for the final frame
# Maps each word to the Q value found at the stop frame of each of its instances.
Q_assignments = {
    word: [equiv_dataset.Q[stop_frame].item() for _, stop_frame in word_spans]
    for word, word_spans in frame_spans_by_word.items()
}

In [None]:
# Flatten the per-word Q values into one array and show the fraction that is
# non-negative (presumably negative marks "no assignment" -- confirm).
Q_assignments_flat = np.array([q for word_qs in Q_assignments.values() for q in word_qs])
np.mean(Q_assignments_flat >= 0)

In [None]:
# Parallel lists: words[i] is the label whose instance spans are spans[i].
words = list(frame_spans_by_word)
spans = [frame_spans_by_word[word] for word in words]

# Cut table indexed by (label, instance_idx, level) and sorted on that index.
cuts_frame = pd.DataFrame(cuts_df)
cuts_frame = cuts_frame.set_index(["label", "instance_idx", "level"])
cuts_frame = cuts_frame.sort_index()

spec = StateSpaceAnalysisSpec(
    total_num_frames=equiv_dataset.hidden_state_dataset.num_frames,
    labels=words,
    target_frame_spans=spans,
    cuts=cuts_frame,
)

In [None]:
# Create the output directory first so the dump does not fail on a fresh
# checkout where `out`'s parent directory has not been created yet.
Path(out).parent.mkdir(parents=True, exist_ok=True)

# Persist the analysis spec as a pickle at `out`.
with open(out, "wb") as f:
    pickle.dump(spec, f)

### Find word cohorts with interesting overlaps

In [None]:
# word -> tuple of phone labels, filled as a side effect of `process_item`.
timit_word_to_phon = {}


def process_item(item):
    """Record the phoneme sequence of every word in one corpus item.

    Words with an empty phonemic annotation are skipped. If a word occurs
    more than once, the most recently processed occurrence overwrites the
    earlier entry in ``timit_word_to_phon``.

    Args:
        item: A corpus example providing ``word_detail["utterance"]`` and
            ``word_phonemic_detail``, paired up position by position.

    Returns:
        None. Results are stored in ``timit_word_to_phon``.
    """
    utterance_words = item["word_detail"]["utterance"]
    for word, phoneme_entries in zip(utterance_words, item["word_phonemic_detail"]):
        if len(phoneme_entries) != 0:
            timit_word_to_phon[word] = tuple(entry["phone"] for entry in phoneme_entries)


# `map` is used only for its side effects on `timit_word_to_phon`.
timit_corpus.map(process_item)

In [None]:
# Length (in phonemes) of the suffix that two words must share.
k = 4

# Group words by their final k phonemes in one pass instead of comparing every
# pair of words. Only words strictly longer than k phonemes take part.
words_by_suffix = defaultdict(list)
for word, phons in timit_word_to_phon.items():
    if len(phons) > k:
        words_by_suffix[phons[-k:]].append(word)

# shared_suffixes[suffix]      -> number of distinct word pairs sharing the suffix
# shared_suffix_words[suffix]  -> set of words sharing the suffix
# Suffixes are inserted in order of first occurrence, which is the same order
# a pairwise scan over itertools.combinations would produce, so ties in
# `most_common` resolve identically.
shared_suffixes, shared_suffix_words = Counter(), defaultdict(set)
for suffix, cohort in words_by_suffix.items():
    if len(cohort) < 2:
        continue
    shared_suffixes[suffix] = len(cohort) * (len(cohort) - 1) // 2
    shared_suffix_words[suffix] = set(cohort)

In [None]:
# Display the ten suffixes shared by the most word pairs, with their words.
[(suffix, shared_suffix_words[suffix]) for suffix, _ in shared_suffixes.most_common(10)]

In [None]:
# Union of the words belonging to the ten most common shared suffixes.
suffix_overlap_words = {
    word
    for suffix, _ in shared_suffixes.most_common(10)
    for word in shared_suffix_words[suffix]
}

In [None]:
# Length (in phonemes) of the prefix that two words must share.
k = 4

# Group words by their first k phonemes in one pass instead of comparing every
# pair of words. Only words strictly longer than k phonemes take part.
words_by_prefix = defaultdict(list)
for word, phons in timit_word_to_phon.items():
    if len(phons) > k:
        words_by_prefix[phons[:k]].append(word)

# shared_prefixes[prefix]      -> number of distinct word pairs sharing the prefix
# shared_prefix_words[prefix]  -> set of words sharing the prefix
# Prefixes are inserted in order of first occurrence, which is the same order
# a pairwise scan over itertools.combinations would produce, so ties in
# `most_common` resolve identically.
shared_prefixes, shared_prefix_words = Counter(), defaultdict(set)
for prefix, cohort in words_by_prefix.items():
    if len(cohort) < 2:
        continue
    shared_prefixes[prefix] = len(cohort) * (len(cohort) - 1) // 2
    shared_prefix_words[prefix] = set(cohort)

In [None]:
# Union of the words belonging to the ten most common shared prefixes.
prefix_overlap_words = {
    word
    for prefix, _ in shared_prefixes.most_common(10)
    for word in shared_prefix_words[prefix]
}

In [None]:
# Words that are in both the suffix-overlap and the prefix-overlap sets.
multiple_overlap_words = suffix_overlap_words.intersection(prefix_overlap_words)

In [None]:
# Display the words found in both the prefix-overlap and suffix-overlap sets.
multiple_overlap_words

In [None]:
# Source some prefix overlaps and suffix overlaps for each case.
# Maps each word to (words sharing its first k phonemes, words sharing its
# last k phonemes), using the `k` set in the cells above.
# Keys and cohort lists are sorted because they come from sets, whose
# iteration order for strings is not stable across interpreter runs; sorting
# keeps the JSON written from this mapping reproducible.
complex_cohort_set = {
    word: (
        sorted(shared_prefix_words[timit_word_to_phon[word][:k]]),
        sorted(shared_suffix_words[timit_word_to_phon[word][-k:]]),
    )
    for word in sorted(multiple_overlap_words)
}
complex_cohort_set

In [None]:
# Write the cohort mapping as JSON into the current working directory.
with Path("complex_cohort_set.json").open("w") as cohort_file:
    cohort_file.write(json.dumps(complex_cohort_set))