Prepare state space trajectories for a lexical analysis.

In [None]:
# IPython magics: re-import edited modules (e.g. src.*) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
import os

# Expose no CUDA devices to this process, so everything below runs on the CPU.
os.environ.update({"CUDA_VISIBLE_DEVICES": ""})

In [None]:
from collections import Counter, defaultdict
import itertools
from pathlib import Path
import pickle
from typing import Any

import datasets
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.decomposition import PCA
from sklearn.linear_model import RidgeCV
from sklearn.model_selection import train_test_split, cross_val_score
import torch
import transformers

from src.models import get_best_checkpoint
from src.analysis.state_space import StateSpaceAnalysisSpec
from src.models.integrator import ContrastiveEmbeddingModel, compute_embeddings

In [None]:
# model_dir = "out/ce_model_phoneme_within_word_prefix_6_32"
model_dir = "out/ce_model_random_32"

# Use this phoneme-level equivalence dataset regardless of which model is loaded;
# its class labels must be phoneme strings (asserted further below).
equiv_dataset_path = "data/timit_equiv_phoneme_6_1.pkl"
timit_corpus_path = "data/timit_phonemes"

# (left, right) frame offsets relative to each phone's end frame: the response to a phone
# is the model frames [phone_end + left, phone_end + right).
phoneme_response_window = (0, 3)

In [None]:
# Load the best checkpoint found under `model_dir` and put the model in evaluation mode.
model = ContrastiveEmbeddingModel.from_pretrained(get_best_checkpoint(model_dir))
model.eval()

In [None]:
# Deserialize the precomputed equivalence dataset (a locally produced, trusted pickle).
equiv_dataset = pickle.loads(Path(equiv_dataset_path).read_bytes())

In [None]:
# Hugging Face dataset saved with `save_to_disk`; indexed by split name (e.g. "train") below.
timit_corpus = datasets.load_from_disk(timit_corpus_path)

In [None]:
# Every distinct phone label occurring in the training split.
all_phonemes = {
    phone["phone"]
    for utterance_words in timit_corpus["train"]["word_phonemic_detail"]
    for word in utterance_words
    for phone in word
}

In [None]:
# Articulatory description of each CMUdict (ARPAbet) phone as space-separated features.
# Vowels: height, backness, rounding. Consonants: voicing, place, manner.
# NOTE(review): diphthongs (AW, AY, OY, and arguably EY, OW) are described by a single
# target, and AH (CMUdict's wedge/schwa) is coded "low"; confirm these coding choices.
cmudict_features = {
    "AA": "low back unrounded",
    "AE": "low front unrounded",
    "AH": "low central unrounded",
    "AO": "low back rounded",
    "AW": "mid back rounded",
    "AY": "high front unrounded",
    "B": "voiced bilabial plosive",
    "CH": "voiceless palato-alveolar affricate",
    "D": "voiced alveolar plosive",
    "DH": "voiced dental fricative",
    "EH": "mid front unrounded",
    "ER": "mid central unrounded",
    "EY": "mid front unrounded",  # as in "bait": an unrounded vowel (was "rounded")
    "F": "voiceless labiodental fricative",
    "G": "voiced velar plosive",
    "HH": "voiceless glottal fricative",
    "IH": "high front unrounded",
    "IY": "high front unrounded",  # as in "beat": an unrounded vowel (was "rounded")
    "JH": "voiced palato-alveolar affricate",
    "K": "voiceless velar plosive",
    "L": "voiced alveolar lateral approximant",
    "M": "voiced bilabial nasal",
    "N": "voiced alveolar nasal",
    "NG": "voiced velar nasal",
    "OW": "mid back rounded",
    "OY": "mid back rounded",
    "P": "voiceless bilabial plosive",
    "R": "voiced alveolar approximant",
    "S": "voiceless alveolar fricative",
    "SH": "voiceless palato-alveolar fricative",
    "T": "voiceless alveolar plosive",
    "TH": "voiceless dental fricative",
    "UH": "high back rounded",
    "UW": "high back rounded",
    "V": "voiced labiodental fricative",
    "W": "voiced labio-velar approximant",
    "Y": "voiced palatal approximant",
    "Z": "voiced alveolar fricative",
    "ZH": "voiced palato-alveolar fricative",
}
# Split each description into its list of individual feature names.
cmudict_features = {k: v.split() for k, v in cmudict_features.items()}

In [None]:
# Alphabetical inventory of distinct feature names, and each name's position in it.
all_features = sorted(set(itertools.chain.from_iterable(cmudict_features.values())))
feature2idx = dict(zip(all_features, range(len(all_features))))

In [None]:
# Phone -> list of indices (into `all_features`) of that phone's features.
cmudict_feature_idxs = {
    phoneme: list(map(feature2idx.__getitem__, phoneme_features))
    for phoneme, phoneme_features in cmudict_features.items()
}

In [None]:
# Inverse mapping: feature name -> phones carrying it (in `cmudict_features` order),
# with keys in `all_features` order.
feature_to_phonemes = {
    feature: [
        phoneme
        for phoneme, phoneme_features in cmudict_features.items()
        if feature in phoneme_features
    ]
    for feature in all_features
}

In [None]:
# Sanity check: the analysis below treats class labels as phoneme strings.
# `isinstance` (unlike `type(label) == str`) also accepts str subclasses such as numpy.str_.
assert all(isinstance(label, str) for label in equiv_dataset.class_labels), "Assumes dataset with phoneme labels"

In [None]:
# Model representations for the equivalence dataset; sliced by frame index on axis 0 below.
# NOTE(review): `load_or_compute_embeddings` is not imported in the imports cell (only
# `compute_embeddings` is); confirm where it is defined before running this notebook.
model_representations = load_or_compute_embeddings(model, equiv_dataset, model_dir, equiv_dataset_path)

## Extract representations

In [None]:
# Per-item (start_frame, end_frame) spans into `model_representations`, looked up by item
# index in get_phoneme_responses. NOTE(review): end presumably exclusive — confirm.
equiv_frames_by_item = equiv_dataset.hidden_state_dataset.frames_by_item

In [None]:
# Per-dimension mean and std over all frames (axis 0), with keepdims so they broadcast
# against frame slices when z-scoring below.
# NOTE(review): a zero-variance dimension would make the z-scored responses inf/NaN.
mean_rep, std_rep = (
    stat_fn(model_representations, axis=0, keepdims=True)
    for stat_fn in (np.mean, np.std)
)

In [None]:
phoneme_responses = defaultdict(list)
phoneme_agg_fn = np.mean
zscore = True
# Split of `timit_corpus` whose row indices line up with `equiv_frames_by_item`.
# NOTE(review): assumed to be "train" (the split used for `all_phonemes` above) — confirm
# against how the equivalence dataset was built.
phoneme_response_split = "train"


def get_phoneme_responses(item, idx):
    """Append windowed model responses for every phone of one corpus item.

    For each phone in ``item["word_phonemic_detail"]``, take the model frames in
    ``phoneme_response_window`` relative to the phone's end frame, optionally z-score
    them with ``mean_rep``/``std_rep``, aggregate over frames with ``phoneme_agg_fn`` and
    append the result to ``phoneme_responses[phone["phone"]]``.

    The window is clipped to this item's span ``[start_frame, end_frame)`` so it never
    reads frames of a neighbouring item; phones whose clipped window is empty are
    skipped instead of contributing a NaN mean of an empty slice.
    """
    start_frame, end_frame = equiv_frames_by_item[idx]
    # Model frames per input sample for this item.
    compression_ratio = (end_frame - start_frame) / len(item["input_values"])

    window_left, window_right = phoneme_response_window

    for word in item["word_phonemic_detail"]:
        for phone in word:
            phone_end = start_frame + int(phone["stop"] * compression_ratio)

            window_start = max(phone_end + window_left, start_frame)
            window_end = min(phone_end + window_right, end_frame)
            if window_end <= window_start:
                continue

            response = model_representations[window_start:window_end]

            if zscore:
                response = (response - mean_rep) / std_rep

            phoneme_responses[phone["phone"]].append(phoneme_agg_fn(response, axis=0))


# Iterate explicitly rather than with `.map`: this runs purely for its side effect on
# `phoneme_responses`, which a cached `map` could skip, and `DatasetDict.map` would feed
# every split's 0-based row indices into the same `equiv_frames_by_item` table.
for idx, item in enumerate(timit_corpus[phoneme_response_split]):
    get_phoneme_responses(item, idx)

## Aggregate by feature

In [None]:
# Pool the per-phoneme responses by phonetic feature: each response counts toward every
# feature of its phoneme, and each feature's responses are stacked along a new first axis.
feature_responses = {}
for feature, phonemes in feature_to_phonemes.items():
    # `.get` avoids inserting empty entries into the `phoneme_responses` defaultdict.
    pooled = [
        response
        for phoneme in phonemes
        for response in phoneme_responses.get(phoneme, [])
    ]
    if not pooled:
        raise ValueError(
            f"No phoneme responses for feature {feature!r} (phonemes {phonemes}); "
            "check that corpus phone labels match the cmudict_features keys"
        )
    feature_responses[feature] = np.stack(pooled)

In [None]:
# `pca` is not defined anywhere else in this notebook, so fit a 2-component PCA here on
# the two populations being compared (only the first two components are plotted).
pca = PCA(n_components=2).fit(
    np.concatenate([feature_responses["voiced"], feature_responses["voiceless"]])
)
plot_voiced = pca.transform(feature_responses["voiced"])
plot_voiceless = pca.transform(feature_responses["voiceless"])

In [None]:
# Voiced vs. voiceless responses in the space of the first two principal components.
plt.scatter(plot_voiced[:, 0], plot_voiced[:, 1], label="voiced", alpha=0.3)
plt.scatter(plot_voiceless[:, 0], plot_voiceless[:, 1], label="voiceless", alpha=0.3)
plt.xlabel("PC 1")
plt.ylabel("PC 2")
# The `label=` arguments above only show up once a legend is drawn.
plt.legend()

## Feature selectivity

In [None]:
# num_features * num_dimensions
# Mean response per feature: rows follow `feature_responses` iteration order
# (num_features x num_dimensions).
feature_responses_mat = np.array(
    [pooled.mean(axis=0) for pooled in feature_responses.values()]
)

In [None]:
# Display the (num_features x num_dimensions) matrix of mean feature responses.
feature_responses_mat

In [None]:
# Feature selectivity index (FSI), shape (num_features, num_dimensions). For feature j and
# hidden unit i, FSI[j, i] counts the OTHER features whose mean response on unit i differs
# from unit i's mean response to feature j by more than `feature_selectivity_threshold`.
# By default the difference is absolute (stronger and weaker responses both count); set
# `feature_selectivity_directional = True` to count only features that drive unit i more
# weakly than feature j does.
feature_selectivity_threshold = 0.4
feature_selectivity_directional = False

# response_diffs[j, k, i] = response(feature j, unit i) - response(feature k, unit i)
response_diffs = feature_responses_mat[:, None, :] - feature_responses_mat[None, :, :]
if not feature_selectivity_directional:
    response_diffs = np.abs(response_diffs)
exceeds_threshold = response_diffs > feature_selectivity_threshold

# A feature is never compared against itself.
num_features = feature_responses_mat.shape[0]
exceeds_threshold[np.arange(num_features), np.arange(num_features), :] = False

# Same dtype as the per-element loop's `np.zeros_like(feature_responses_mat)` result.
feature_selectivity = exceeds_threshold.sum(axis=1).astype(feature_responses_mat.dtype)

In [None]:
# Rows of `feature_selectivity` follow `feature_responses` key order, so label them from
# that same dict rather than from a separately ordered one.
sns.clustermap(feature_selectivity, yticklabels=list(feature_responses.keys()), xticklabels=False)