Prepare state space trajectories for a syllabic analysis.

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = ""

In [3]:
from collections import Counter, defaultdict
import itertools
from pathlib import Path
import pickle
from typing import Any

import datasets
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.linear_model import RidgeCV
from sklearn.model_selection import train_test_split, cross_val_score
import torch
import transformers

from src.analysis.state_space import StateSpaceAnalysisSpec

In [19]:
# use a word-level equivalence dataset regardless of model, so that we can look up cohort facts
equiv_dataset_path = "data/timit_equiv_phoneme_6_1.pkl"
timit_corpus_path = "data/timit_syllables"

out_by_identity = "out/state_space_specs/all_syllables.pkl"
# out_by_onset = "out/state_space_specs/all_syllables_by_onset.pkl"
# out_by_nucleus = "out/state_space_specs/all_syllables_by_nucleus.pkl"
out_by_ordinal = "out/state_space_specs/all_syllables_by_ordinal.pkl"
out_by_identity_and_ordinal = "out/state_space_specs/all_syllables_by_identity_and_ordinal.pkl"

In [5]:
with open(equiv_dataset_path, "rb") as f:
    equiv_dataset = pickle.load(f)

In [6]:
timit_corpus = datasets.load_from_disk(timit_corpus_path)

## Prepare cohort data

In [8]:
equiv_frames_by_item = equiv_dataset.hidden_state_dataset.frames_by_item

In [9]:
timit_corpus["train"][0]["word_syllable_detail"]

[[{'idx': 0,
   'phoneme_end_idx': 2,
   'phoneme_start_idx': 0,
   'phones': ['SH', 'IH'],
   'start': 3050,
   'stop': 5723,
   'stress': None}],
 [{'idx': 0,
   'phoneme_end_idx': 4,
   'phoneme_start_idx': 0,
   'phones': ['HH', 'EH', 'D', 'JH'],
   'start': 5723,
   'stop': 10337,
   'stress': None}],
 [{'idx': 0,
   'phoneme_end_idx': 2,
   'phoneme_start_idx': 0,
   'phones': ['JH', 'IH'],
   'start': 9190,
   'stop': 11517,
   'stress': None}],
 [{'idx': 0,
   'phoneme_end_idx': 3,
   'phoneme_start_idx': 0,
   'phones': ['D', 'AH', 'K'],
   'start': 11517,
   'stop': 16334,
   'stress': None}],
 [{'idx': 0,
   'phoneme_end_idx': 3,
   'phoneme_start_idx': 0,
   'phones': ['S', 'UW', 'T'],
   'start': 16334,
   'stop': 21199,
   'stress': None}],
 [{'idx': 0,
   'phoneme_end_idx': 2,
   'phoneme_start_idx': 0,
   'phones': ['AH', 'N'],
   'start': 21199,
   'stop': 22560,
   'stress': None}],
 [{'idx': 0,
   'phoneme_end_idx': 3,
   'phoneme_start_idx': 0,
   'phones': ['G', 'R

In [12]:
frame_spans_by_syllable = defaultdict(list)
frame_spans_by_syllable_ordinal = defaultdict(list)
frame_spans_by_syllable_and_ordinal = defaultdict(list)

def process_item(item, idx):
    # How many frames do we have stored for this item?
    start_frame, stop_frame = equiv_frames_by_item[idx]
    num_frames = stop_frame - start_frame

    compression_ratio = num_frames / len(item["input_values"])

    for word in item["word_syllable_detail"]:
        for syllable in word:
            syllable_start_frame = start_frame + int(syllable["start"] * compression_ratio)
            syllable_stop_frame = start_frame + int(syllable["stop"] * compression_ratio)

            syllable_phones = tuple(syllable["phones"])
            span = (syllable_start_frame, syllable_stop_frame)
            frame_spans_by_syllable[syllable_phones].append(span)
            frame_spans_by_syllable_ordinal[syllable["idx"]].append(span)
            frame_spans_by_syllable_and_ordinal[(syllable_phones, syllable["idx"])].append(span)

timit_corpus["train"].map(process_item, with_indices=True)

Map:   0%|          | 0/500 [00:00<?, ? examples/s]

Dataset({
    features: ['file', 'audio', 'text', 'phonetic_detail', 'word_detail', 'dialect_region', 'sentence_type', 'speaker_id', 'id', 'phonemic_detail', 'word_phonetic_detail', 'word_phonemic_detail', 'word_syllable_detail', 'input_values'],
    num_rows: 500
})

In [13]:
# Sanity check: we should have Q assignments for the final frame
Q_assignments = {syll: [equiv_dataset.Q[end].item() for start, end in spans]
                 for syll, spans in frame_spans_by_syllable.items()}

In [14]:
Q_assignments_flat = np.array(list(itertools.chain.from_iterable(Q_assignments.values())))
(Q_assignments_flat >= 0).mean()

1.0

In [17]:
spec = StateSpaceAnalysisSpec(
    total_num_frames=equiv_dataset.hidden_state_dataset.num_frames,
    labels=list(frame_spans_by_syllable.keys()),
    target_frame_spans=list(frame_spans_by_syllable.values()),
)

with open(out_by_identity, "wb") as f:
    pickle.dump(spec, f)

In [20]:
spec = StateSpaceAnalysisSpec(
    total_num_frames=equiv_dataset.hidden_state_dataset.num_frames,
    labels=list(frame_spans_by_syllable_ordinal.keys()),
    target_frame_spans=list(frame_spans_by_syllable_ordinal.values()),
)

with open(out_by_ordinal, "wb") as f:
    pickle.dump(spec, f)

In [None]:
spec = StateSpaceAnalysisSpec(
    total_num_frames=equiv_dataset.hidden_state_dataset.num_frames,
    labels=list(frame_spans_by_syllable_and_ordinal.keys()),
    target_frame_spans=list(frame_spans_by_syllable_and_ordinal.values()),
)

with open(out_by_identity_and_ordinal, "wb") as f:
    pickle.dump(spec, f)