Prepare state space trajectories for a phoneme position analysis.

In [1]:
# Enable IPython's autoreload extension in mode 2: imported modules are
# reloaded before each cell executes, so edits to project code (e.g. under
# src/) take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os

# Expose no CUDA devices to this process. The variable is set here, in its own
# cell, so that it is already in place when `torch` is imported further down.
os.environ.update({"CUDA_VISIBLE_DEVICES": ""})

In [3]:
from collections import Counter, defaultdict
import itertools
from pathlib import Path
import pickle
from typing import Any

import datasets
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.linear_model import RidgeCV
from sklearn.model_selection import train_test_split, cross_val_score
import torch
import transformers

from src.analysis.state_space import StateSpaceAnalysisSpec

In [4]:
# Inputs. The equivalence dataset is always the word-level one, regardless of
# the model under analysis, so that cohort facts can be looked up.
equiv_dataset_path, timit_corpus_path = (
    "data/timit_equiv_phoneme_within_word_prefix_6_1.pkl",
    "data/timit_syllables",
)

# Outputs: one pickled state space spec per way of grouping the phoneme spans.
out1, out2 = (
    "out/state_space_specs/all_phonemes_by_position.pkl",  # grouped by position
    "out/state_space_specs/all_phonemes_by_identity.pkl",  # grouped by identity
)

out_syllable_position, out_identity_syllable_position = (
    "out/state_space_specs/all_phonemes_by_syllable_position.pkl",
    "out/state_space_specs/all_phonemes_by_identity_and_syllable_position.pkl",
)

In [5]:
# Load the pickled equivalence dataset from disk.
# NOTE(review): unpickling can execute arbitrary code, so equiv_dataset_path
# should only ever point at a file produced by a trusted source.
with open(equiv_dataset_path, mode="rb") as equiv_file:
    equiv_dataset = pickle.load(equiv_file)

In [6]:
# Load the TIMIT corpus that was saved to disk in `datasets` format; its
# "train" split is the one iterated further down.
timit_corpus = datasets.load_from_disk(timit_corpus_path)

In [7]:
# Fail fast unless every class label is a tuple, which is what the rest of the
# notebook assumes (word prefix labels). isinstance() is the idiomatic type
# test and, unlike `type(x) == tuple`, also accepts tuple subclasses such as
# namedtuples.
assert all(
    isinstance(label, tuple) for label in equiv_dataset.class_labels
), "Assumes dataset with word prefix labels"

## Prepare cohort data

In [8]:
# Per-item frame bounds: process_item indexes this by item index and unpacks
# each entry as (start_frame, stop_frame).
equiv_frames_by_item = equiv_dataset.hidden_state_dataset.frames_by_item

In [9]:
# Accumulators filled in by process_item. Each maps a grouping key to the list
# of (start_frame, stop_frame) spans of all phonemes that fall under that key.
frame_spans_by_phoneme = defaultdict(list)  # key: phoneme label
frame_spans_by_phoneme_position = defaultdict(list)  # key: index of the phoneme within its word
frame_spans_by_syllable_index = defaultdict(list)  # key: the phone's "syllable_idx"
frame_spans_by_phoneme_and_syllable_index = defaultdict(list)  # key: (phoneme label, syllable_idx)


def process_item(item, idx):
    """Record the frame span of every phoneme in one corpus item.

    Args:
        item: corpus example. Only the length of ``item["input_values"]`` is
            used, together with ``item["word_phonemic_detail"]``: an iterable
            of words, each a sequence of phone dicts with the keys
            ``"phone"``, ``"start"``, ``"stop"`` and ``"syllable_idx"``.
        idx: index of the item, used to look up its frame bounds in
            ``equiv_frames_by_item``.

    Returns:
        None. Spans are appended to the four module-level ``frame_spans_by_*``
        dictionaries.
    """
    # Frame bounds stored for this item, and how many frames that makes.
    first_frame, end_frame = equiv_frames_by_item[idx]
    num_frames = end_frame - first_frame

    # Number of stored frames per element of the raw input.
    frames_per_input = num_frames / len(item["input_values"])

    def to_frame(offset):
        # Convert an offset inside the item into a global frame index; int()
        # truncates towards zero.
        # Presumably offsets are sample indices into input_values — TODO confirm.
        return first_frame + int(offset * frames_per_input)

    for word in item["word_phonemic_detail"]:
        for position, phone in enumerate(word):
            # NOTE(review): a phone whose "stop" equals len(input_values) maps
            # to end_frame itself; whether spans are meant to be end-inclusive
            # is not visible here — confirm against the consumers of the spans.
            span = (to_frame(phone["start"]), to_frame(phone["stop"]))

            label = phone["phone"]
            frame_spans_by_phoneme[label].append(span)
            frame_spans_by_phoneme_position[position].append(span)

            syllable = phone["syllable_idx"]
            frame_spans_by_syllable_index[syllable].append(span)
            frame_spans_by_phoneme_and_syllable_index[(label, syllable)].append(span)

# Run process_item over every training item together with its index. map() is
# used here only for process_item's side effects (it returns nothing); the
# Dataset returned by map() is not stored, just echoed as the cell output.
timit_corpus["train"].map(process_item, with_indices=True)

Map:   0%|          | 0/4620 [00:00<?, ? examples/s]

Dataset({
    features: ['file', 'audio', 'text', 'phonetic_detail', 'word_detail', 'dialect_region', 'sentence_type', 'speaker_id', 'id', 'word_phonetic_detail', 'input_values', 'phone_targets', 'phonemic_detail', 'word_phonemic_detail', 'word_syllable_detail'],
    num_rows: 4620
})

In [10]:
# Sanity check: the final frame of every phoneme span should carry a Q
# assignment. For each phoneme, look up Q at the stop frame of each span.
Q_assignments = {
    phoneme: [equiv_dataset.Q[stop_frame].item() for _, stop_frame in phoneme_spans]
    for phoneme, phoneme_spans in frame_spans_by_phoneme.items()
}

In [11]:
# Pool the per-phoneme Q values into one array and display the fraction that
# is non-negative (presumably a negative Q marks a frame without an
# assignment — TODO confirm).
Q_assignments_flat = np.array(
    [q_value for q_values in Q_assignments.values() for q_value in q_values]
)
np.mean(Q_assignments_flat >= 0)

0.94518458590025

In [12]:
# Spec grouping phoneme spans by the phoneme's index within its word, with the
# labels in ascending order and the span lists aligned to them.
phoneme_positions = sorted(frame_spans_by_phoneme_position)

position_spec = StateSpaceAnalysisSpec(
    total_num_frames=equiv_dataset.hidden_state_dataset.num_frames,
    labels=phoneme_positions,
    target_frame_spans=[
        frame_spans_by_phoneme_position[position]
        for position in phoneme_positions
    ],
)

In [13]:
# Spec grouping phoneme spans by phoneme identity, with the labels sorted and
# the span lists aligned to them.
phonemes = sorted(frame_spans_by_phoneme)

phoneme_spec = StateSpaceAnalysisSpec(
    total_num_frames=equiv_dataset.hidden_state_dataset.num_frames,
    labels=phonemes,
    target_frame_spans=[
        frame_spans_by_phoneme[phoneme_label] for phoneme_label in phonemes
    ],
)

In [14]:
# Persist the by-position and by-identity specs as pickles.
# open(..., "wb") does not create missing parent directories, so make sure
# they exist first; otherwise this cell fails on a fresh checkout.
Path(out1).parent.mkdir(parents=True, exist_ok=True)
with open(out1, "wb") as f:
    pickle.dump(position_spec, f)

Path(out2).parent.mkdir(parents=True, exist_ok=True)
with open(out2, "wb") as f:
    pickle.dump(phoneme_spec, f)

In [15]:
# Persist a spec grouping phoneme spans by syllable index.
# Labels and span lists are read from the same dict, so they stay aligned;
# they follow the order in which keys were first inserted, not sorted order.
# The parent directory is created first because open(..., "wb") does not
# create it and would fail on a fresh checkout.
Path(out_syllable_position).parent.mkdir(parents=True, exist_ok=True)
with open(out_syllable_position, "wb") as f:
    pickle.dump(StateSpaceAnalysisSpec(
        total_num_frames=equiv_dataset.hidden_state_dataset.num_frames,
        labels=list(frame_spans_by_syllable_index.keys()),
        target_frame_spans=list(frame_spans_by_syllable_index.values()),
    ), f)

In [16]:
# Persist a spec grouping phoneme spans by (phoneme label, syllable index).
# Labels and span lists are read from the same dict, so they stay aligned;
# they follow the order in which keys were first inserted, not sorted order.
# The parent directory is created first because open(..., "wb") does not
# create it and would fail on a fresh checkout.
Path(out_identity_syllable_position).parent.mkdir(parents=True, exist_ok=True)
with open(out_identity_syllable_position, "wb") as f:
    pickle.dump(StateSpaceAnalysisSpec(
        total_num_frames=equiv_dataset.hidden_state_dataset.num_frames,
        labels=list(frame_spans_by_phoneme_and_syllable_index.keys()),
        target_frame_spans=list(frame_spans_by_phoneme_and_syllable_index.values()),
    ), f)