Prepare state space trajectories for a lexical analysis.

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = ""

In [3]:
from collections import Counter, defaultdict
import itertools
import json
from pathlib import Path
import pickle
from typing import Any

import datasets
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.linear_model import RidgeCV
from sklearn.model_selection import train_test_split, cross_val_score
import torch
import transformers
from tqdm.auto import tqdm

from src.analysis.state_space import StateSpaceAnalysisSpec

In [4]:
# use a word-level equivalence dataset regardless of model, so that we can look up cohort facts
equiv_dataset_path = "data/timit_equiv_phoneme_within_word_prefix_6_1.pkl"
timit_corpus_path = "data/timit_phonemes"

out = "out/state_space_specs/all_words.pkl"

In [5]:
with open(equiv_dataset_path, "rb") as f:
    equiv_dataset = pickle.load(f)

In [6]:
timit_corpus = datasets.load_from_disk(timit_corpus_path)

In [7]:
assert all(type(label) == tuple for label in equiv_dataset.class_labels), "Assumes dataset with word prefix labels"

## Prepare cohort data

In [8]:
equiv_frames_by_item = equiv_dataset.hidden_state_dataset.frames_by_item

In [9]:
frame_spans_by_word = defaultdict(list)

def process_item(item, idx):
    # How many frames do we have stored for this item?
    start_frame, stop_frame = equiv_frames_by_item[idx]
    num_frames = stop_frame - start_frame

    compression_ratio = num_frames / len(item["input_values"])

    for start, stop, word in zip(item["word_detail"]["start"],
                                 item["word_detail"]["stop"],
                                 item["word_detail"]["utterance"]):
        word_start_frame = start_frame + int(start * compression_ratio)
        word_stop_frame = start_frame + int(stop * compression_ratio)

        frame_spans_by_word[word].append((word_start_frame, word_stop_frame))

timit_corpus["train"].map(process_item, with_indices=True)
None

Map:   0%|          | 0/4620 [00:00<?, ? examples/s]

In [10]:
# Sanity check: we should have Q assignments for the final frame
Q_assignments = {word: [equiv_dataset.Q[end].item() for start, end in spans]
                 for word, spans in frame_spans_by_word.items()}

In [11]:
Q_assignments_flat = np.array(list(itertools.chain.from_iterable(Q_assignments.values())))
(Q_assignments_flat >= 0).mean()

0.7989657076869007

In [12]:
words = list(frame_spans_by_word.keys())
spans = list(frame_spans_by_word.values())

spec = StateSpaceAnalysisSpec(
    total_num_frames=equiv_dataset.hidden_state_dataset.num_frames,
    labels=words,
    target_frame_spans=spans,
)

In [13]:
with open(out, "wb") as f:
    pickle.dump(spec, f)

### Find word cohorts with interesting overlaps

In [14]:
timit_word_to_phon = {}

def process_item(item):
    for word, word_phons in zip(item["word_detail"]["utterance"], item["word_phonemic_detail"]):
        if len(word_phons) == 0:
            continue

        timit_word_to_phon[word] = tuple(phone["phone"] for phone in word_phons)
timit_corpus.map(process_item)

Map:   0%|          | 0/4620 [00:00<?, ? examples/s]

Map:   0%|          | 0/1680 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['file', 'audio', 'text', 'phonetic_detail', 'word_detail', 'dialect_region', 'sentence_type', 'speaker_id', 'id', 'phonemic_detail', 'word_phonetic_detail', 'word_phonemic_detail', 'input_values'],
        num_rows: 4620
    })
    test: Dataset({
        features: ['file', 'audio', 'text', 'phonetic_detail', 'word_detail', 'dialect_region', 'sentence_type', 'speaker_id', 'id', 'phonemic_detail', 'word_phonetic_detail', 'word_phonemic_detail', 'input_values'],
        num_rows: 1680
    })
})

In [15]:
k = 4
shared_suffixes, shared_suffix_words = Counter(), defaultdict(set)
for w1, w2 in itertools.combinations(timit_word_to_phon.keys(), 2):
    phons1, phons2 = timit_word_to_phon[w1], timit_word_to_phon[w2]
    if len(phons1) > k and len(phons2) > k and phons1[-k:] == phons2[-k:]:
        shared_suffixes[phons1[-k:]] += 1
        shared_suffix_words[phons1[-k:]].add(w1)
        shared_suffix_words[phons1[-k:]].add(w2)

In [16]:
[(key, shared_suffix_words[key]) for key, count in shared_suffixes.most_common(10)]

[(('EY', 'SH', 'AH', 'N'),
  {'administration',
   'agglomeration',
   'application',
   'approximation',
   'clarification',
   'compensation',
   'cooperation',
   'creation',
   'demineralization',
   'denunciation',
   'desegregation',
   'determination',
   'education',
   'evaluation',
   'formation',
   'graduation',
   'imagination',
   'information',
   'interpretation',
   'irradiation',
   'justification',
   'operation',
   'panelization',
   'preparation',
   'preservation',
   'radiation',
   'recommendation',
   'recreation',
   'rehabilitation',
   'renunciation',
   'reorganization',
   'situation',
   'vacation',
   'vaporization'}),
 (('IH', 'K', 'AH', 'L'),
  {'article',
   'atypical',
   'biblical',
   'chemical',
   'critical',
   'cyclical',
   'ecumenical',
   'empirical',
   'hypothetical',
   'hysterical',
   'identical',
   'ideological',
   'logical',
   'mechanical',
   'medical',
   'morphological',
   'musical',
   'optical',
   'periodical',
   'physical

In [17]:
suffix_overlap_words = set(itertools.chain.from_iterable([shared_suffix_words[key] for key, count in shared_suffixes.most_common(10)]))

In [18]:
k = 4
shared_prefixes, shared_prefix_words = Counter(), defaultdict(set)
for w1, w2 in itertools.combinations(timit_word_to_phon.keys(), 2):
    phons1, phons2 = timit_word_to_phon[w1], timit_word_to_phon[w2]
    if len(phons1) > k and len(phons2) > k and phons1[:k] == phons2[:k]:
        shared_prefixes[phons1[:k]] += 1
        shared_prefix_words[phons1[:k]].add(w1)
        shared_prefix_words[phons1[:k]].add(w2)

In [19]:
prefix_overlap_words = set(itertools.chain.from_iterable([shared_prefix_words[key] for key, count in shared_prefixes.most_common(10)]))

In [20]:
multiple_overlap_words = suffix_overlap_words & prefix_overlap_words

In [21]:
multiple_overlap_words

{'comparable',
 'compensation',
 'complicated',
 'compositions',
 'conservatism',
 'consolidation',
 'distance'}

In [22]:
# Source some prefix overlaps and suffix overlaps for each case
complex_cohort_set = {word: (list(shared_prefix_words[timit_word_to_phon[word][:k]]),
                             list(shared_suffix_words[timit_word_to_phon[word][-k:]]))
                      for word in multiple_overlap_words}
complex_cohort_set

{'consolidation': (['consumers',
   'consists',
   'consider',
   'consolidation',
   'construction',
   'consume',
   'conceived',
   'constructions',
   'conservatism',
   'considerably',
   'considered',
   'consistently'],
  ['taxation',
   'pronunciation',
   'realization',
   'representation',
   'location',
   'confirmation',
   'sophistication',
   'formulation',
   'civilization',
   'demonstration',
   'population',
   'accreditation',
   'radiosterilization',
   'congregation',
   'elongation',
   'confabulation',
   'depreciation',
   'desolation',
   'sterilization',
   'consolidation',
   'legislation',
   'investigation',
   'configuration',
   'salvation',
   'modernization',
   'explanation',
   'infuriation',
   'hospitalization']),
 'comparable': (['complicity',
   'compile',
   'compare',
   'comparable',
   'compositions',
   'compose',
   'competitive',
   'compliance',
   'competitors',
   "company's",
   'complexity',
   'complete',
   'completely',
   'competin

In [23]:
with open("complex_cohort_set.json", "w") as f:
    json.dump(complex_cohort_set, f)