In [None]:
import os
# Hide every CUDA device from this kernel so torch/transformers run on CPU; it has to be set
# before torch initializes CUDA (torch is imported in the next cell).
os.environ.update(CUDA_VISIBLE_DEVICES="")

In [None]:
import datasets
import torch
import transformers
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

In [None]:
# Hugging Face tokenizer id, loaded with Wav2Vec2Tokenizer.from_pretrained further down.
tokenizer_name = "charsiu/tokenizer_en_cmu"
# Directory of the saved test-set results, read with datasets.load_from_disk.
test_dataset_path = "../out/rnn/w2v2base_rnn2_hidden128_drop6/test_result"
# Frame rate of the hidden-state sequence; used to convert frame offsets to time on the epoch plots.
model_sfreq = 50

In [None]:
# Load the dataset directory previously written with `Dataset.save_to_disk`.
test_dataset = datasets.load_from_disk(dataset_path=test_dataset_path)

In [None]:
# NOTE(review): `tokenizer` is not referenced by any cell shown here. Recent transformers releases
# also deprecate Wav2Vec2Tokenizer in favour of Wav2Vec2CTCTokenizer — confirm before removing/switching.
tokenizer = transformers.Wav2Vec2Tokenizer.from_pretrained(pretrained_model_name_or_path=tokenizer_name)

In [None]:
# Flatten the per-utterance RNN hidden states into one 2-D array (rows = frames, stacked
# utterance after utterance), keeping only the first item["real_frames"] frames of each item
# (presumably the rest is padding — TODO confirm). For every flat row we also record the
# (utterance index, frame index) it came from.
hidden_state_sources = []
per_item_states = []
for item_idx, example in enumerate(tqdm(test_dataset)):
    n_real = example["real_frames"]
    hidden_state_sources.extend((item_idx, frame_idx) for frame_idx in range(n_real))
    per_item_states.append(np.array(example["rnn_hidden_states"][:n_real]))

hidden_states = np.concatenate(per_item_states, axis=0)
assert len(hidden_state_sources) == hidden_states.shape[0]

# Reverse lookup: (utterance index, frame index) -> row of `hidden_states`.
hidden_state_source_to_flat_idx = {
    source: flat_idx for flat_idx, source in enumerate(hidden_state_sources)
}

In [None]:
boundary_event_names = ["file", "phoneme", "word"]
boundary_event_to_idx = {event_name: i for i, event_name in enumerate(boundary_event_names)}
# One row per flat hidden-state frame, one column per event type; 1 marks an onset at that frame.
boundary_matrix = np.zeros((hidden_states.shape[0], len(boundary_event_names)))


def add_boundaries(item, idx, boundary_matrix, boundary_event_to_idx, hidden_state_source_to_flat_idx):
    """Mark the file, word and phoneme onsets of one test item in ``boundary_matrix`` (in place).

    Args:
        item: one test-dataset row; reads ``compression_ratio`` and ``word_phonemic_detail``
            (a list of words, each a list of phoneme dicts carrying a ``start`` value).
        idx: position of ``item`` in the dataset, i.e. the utterance index used as the first
            element of the keys of ``hidden_state_source_to_flat_idx``.
        boundary_matrix: onset-indicator array written in place (rows = flat frames).
        boundary_event_to_idx: event name -> column of ``boundary_matrix``.
        hidden_state_source_to_flat_idx: (utterance index, frame index) -> row of ``boundary_matrix``.

    Returns:
        None.

    Raises:
        KeyError: if an onset maps to a frame that is not among the item's recorded frames.
    """
    compression_ratio = item["compression_ratio"]

    def mark(frame, event_name):
        # Set the indicator for `event_name` at this item's `frame`.
        row = hidden_state_source_to_flat_idx[(idx, frame)]
        boundary_matrix[row, boundary_event_to_idx[event_name]] = 1

    mark(0, "file")
    for word in item["word_phonemic_detail"]:
        if len(word) == 0:
            continue
        # `start` is scaled by compression_ratio onto the hidden-state frame grid and truncated
        # (presumably input samples -> model frames; TODO confirm units).
        mark(int(word[0]["start"] * compression_ratio), "word")
        for phoneme in word:
            mark(int(phoneme["start"] * compression_ratio), "phoneme")


# Use a plain loop rather than `Dataset.map`. The work here is purely a side effect on `boundary_matrix`.
# `map` keeps those writes only while it runs the function in this process (num_proc > 1 would
# lose them). It would also fingerprint-hash the large fn_kwargs and return an unused Dataset.
for item_idx, item in enumerate(tqdm(test_dataset, desc="marking boundaries")):
    add_boundaries(item, item_idx, boundary_matrix, boundary_event_to_idx, hidden_state_source_to_flat_idx)

# Every utterance contributes exactly one "file" onset (its frame 0).
assert boundary_matrix[:, boundary_event_to_idx["file"]].sum() == len(test_dataset)

## PCA decomposition and time series analysis

In [None]:
d_pca = 4
# z-score every hidden unit before PCA. Units with zero variance keep divisor 1 instead of
# turning into NaNs (0/0) that PCA would reject.
hidden_states_std = hidden_states.std(axis=0)
hidden_states_std[hidden_states_std == 0] = 1.0
hidden_states_z = (hidden_states - hidden_states.mean(axis=0)) / hidden_states_std
# Fixed random_state: for inputs of this size sklearn's "auto" solver may select randomized SVD,
# whose components would otherwise differ between runs.
pca = PCA(d_pca, random_state=42)
hidden_states_pca = pca.fit_transform(hidden_states_z)

In [None]:
# Fraction of the standardized hidden-state variance captured by each of the d_pca components.
pca.explained_variance_ratio_

### TRF model

In [None]:
from mne.decoding import ReceptiveField
from sklearn.model_selection import train_test_split

In [None]:
# Contiguous split (shuffle=False). ReceptiveField builds time-lagged copies of X along axis 0,
# so rows must stay in temporal order. Shuffling would pair each "lag" with an unrelated frame.
# The last 20% of frames are held out for testing.
X_train, X_test, y_train, y_test = train_test_split(boundary_matrix, hidden_states_pca, shuffle=False, test_size=0.2)

In [None]:
# With sfreq=1, tmin/tmax are frame offsets: the kernels span 5 frames before to 20 frames
# after each event onset.
trf = ReceptiveField(
    feature_names=boundary_event_names,
    sfreq=1,
    tmin=-5,
    tmax=20,
)
trf.fit(X_train, y_train)

In [None]:
# Held-out score of the onset -> hidden-state-PCA TRF (ReceptiveField's default scoring;
# presumably one R² per output dimension — confirm against the MNE docs).
trf.score(X_test, y_test)

In [None]:
# One panel per input event type, one kernel curve per PCA output dimension.
# trf.coef_ is indexed [output_dim, input_dim, delay].
n_outputs, n_inputs = trf.coef_.shape[:2]
# squeeze=False keeps `axs` 2-D, so `.ravel()` also works with a single event type.
f, axs = plt.subplots(n_inputs, 1, figsize=(10, 3 * n_inputs), sharex=True, squeeze=False)
for input_dim, (name, ax) in enumerate(zip(boundary_event_names, axs.ravel())):
    ax.set_title(name)
    for output_dim in range(n_outputs):
        ax.plot(trf.delays_, trf.coef_[output_dim, input_dim, :], label=output_dim)
    ax.set_ylabel("TRF weight")
    ax.legend(title="PC")
axs[-1, 0].set_xlabel("lag (frames)");

### Epoched analysis

In [None]:
def epoch_and_plot(epoch_event, hidden_states, boundary_matrix, epoch_window=(-5, 20), baseline_window=(-5, 0)):
    """Event-locked average of ``hidden_states`` around onsets of ``epoch_event``, plotted as mean ± SEM.

    Args:
        epoch_event: event name looked up in the notebook-level ``boundary_event_to_idx``
            (e.g. "word" or "phoneme").
        hidden_states: 2-D array (n_frames, n_dims), row-aligned with ``boundary_matrix``.
            It is only read, never modified.
        boundary_matrix: onset-indicator array; non-zero entries in the ``epoch_event`` column
            are the onsets to epoch around.
        epoch_window: (start, stop) frame offsets relative to each onset; stop is exclusive.
        baseline_window: (start, stop) frame offsets whose mean is subtracted from each epoch,
            or None to skip baseline correction.

    Returns:
        (ax, epoch_data): the matplotlib Axes and an array of shape
        (n_epochs, stop - start, n_dims). Window frames that fall outside ``hidden_states`` are
        NaN and are ignored by the mean and SEM.

    Raises:
        ValueError: if no onset of ``epoch_event`` yields a usable epoch.
    """
    start, stop = epoch_window
    target_width = stop - start
    n_frames, n_dims = hidden_states.shape

    epoch_data = []
    for event_idx in tqdm(boundary_matrix[:, boundary_event_to_idx[epoch_event]].nonzero()[0]):
        # Clip the window to valid rows explicitly. Plain slicing with a negative start would wrap
        # around to the end of the array and silently yield an empty window.
        window_start = event_idx + start
        lo, hi = max(window_start, 0), min(event_idx + stop, n_frames)

        # Fresh NaN-filled buffer. Baseline subtraction then can never write back into
        # `hidden_states` through a slice view, which would corrupt overlapping epochs and the caller's array.
        # Missing frames are also NaN instead of zeros that would bias the average.
        epoch_states = np.full((target_width, n_dims), np.nan)
        if hi > lo:
            epoch_states[lo - window_start:hi - window_start] = hidden_states[lo:hi]

        if baseline_window is not None:
            base_lo = max(event_idx + baseline_window[0], 0)
            base_hi = min(event_idx + baseline_window[1], n_frames)
            if base_hi <= base_lo:
                continue  # no baseline frames inside the data; cannot baseline-correct this epoch
            epoch_states -= hidden_states[base_lo:base_hi].mean(axis=0)

        epoch_data.append(epoch_states)

    if not epoch_data:
        raise ValueError(f"no usable {epoch_event!r} epochs")
    epoch_data = np.stack(epoch_data)

    # Per-(lag, dim) statistics across epochs, skipping NaN frames; the SEM uses the actual count.
    n_valid = np.sum(~np.isnan(epoch_data), axis=0)
    data_mean = np.nanmean(epoch_data, axis=0)
    data_sem = np.nanstd(epoch_data, axis=0) / np.sqrt(n_valid)
    xs = np.arange(start, stop) / model_sfreq

    f, ax = plt.subplots(figsize=(10, 5))
    ax.axvline(0, color="gray")
    ax.axhline(0, color="gray")
    for dim in range(n_dims):
        ax.plot(xs, data_mean[:, dim], label=dim)
        ax.fill_between(xs, data_mean[:, dim] - data_sem[:, dim], data_mean[:, dim] + data_sem[:, dim], alpha=0.3)
    ax.set_xlabel("time relative to onset (s)")
    ax.legend()

    return ax, epoch_data

In [None]:
# Word-onset-locked average of the hidden-state principal components.
ax, _ = epoch_and_plot(epoch_event="word", hidden_states=hidden_states_pca, boundary_matrix=boundary_matrix)
ax.set_title("Word onset")

In [None]:
# Phoneme-onset-locked average of the hidden-state principal components.
ax, _ = epoch_and_plot(epoch_event="phoneme", hidden_states=hidden_states_pca, boundary_matrix=boundary_matrix)
ax.set_title("Phoneme onset")

## Hidden-state norm time-series analysis

In [None]:
# Plot the hidden-state L2 norm for frames [offset, offset + lim). Mark word onsets (red) and
# phoneme onsets that do not start a word (green) inside that same window. Onsets are taken from
# the same row slice, so their positions line up with the plotted x-axis.
offset = 1000
lim = 400
hidden_state_norms = np.linalg.norm(hidden_states, axis=1)
window_onsets = boundary_matrix[offset:offset + lim]
word_col = boundary_event_to_idx["word"]
phoneme_col = boundary_event_to_idx["phoneme"]

f, ax = plt.subplots(figsize=(20, 8))
ax.plot(hidden_state_norms[offset:offset + lim])
for word_onset in window_onsets[:, word_col].nonzero()[0]:
    ax.axvline(word_onset, color="red", alpha=0.5)
phoneme_only = (window_onsets[:, phoneme_col] != 0) & (window_onsets[:, word_col] != 1)
for phon_onset in phoneme_only.nonzero()[0]:
    ax.axvline(phon_onset, color="green", alpha=0.5)
ax.set_xlabel(f"frame index - {offset}")
ax.set_ylabel("hidden-state L2 norm");

In [None]:
trf_norm = ReceptiveField(tmin=-5, tmax=20, sfreq=1, feature_names=boundary_event_names)
# Contiguous (unshuffled) split: ReceptiveField lags rows along axis 0, so they must stay in
# temporal order. The last 10% of frames are held out.
norm_X_train, norm_X_test, norm_y_train, norm_y_test = train_test_split(
    boundary_matrix, np.linalg.norm(hidden_states, axis=1), shuffle=False, test_size=0.1
)

In [None]:
# Fit the TRF that maps event onsets to the per-frame hidden-state norm.
trf_norm.fit(norm_X_train, norm_y_train)

In [None]:
# Held-out score of the norm TRF (ReceptiveField's default scoring; presumably R² — confirm in MNE docs).
trf_norm.score(norm_X_test, norm_y_test)

In [None]:
# Inspect the coefficient layout; with a 1-D target it is indexed [input_dim, delay] when plotted.
trf_norm.coef_.shape

In [None]:
# Kernels of the norm TRF, one panel per input event type. The target is 1-D, so coef_ is
# indexed [input_dim, delay] and each panel holds a single curve (no legend needed).
n_inputs = trf_norm.coef_.shape[0]
f, axs = plt.subplots(n_inputs, 1, figsize=(10, 3 * n_inputs), sharex=True, squeeze=False)
for input_dim, (name, ax) in enumerate(zip(boundary_event_names, axs.ravel())):
    ax.set_title(name)
    ax.plot(trf_norm.delays_, trf_norm.coef_[input_dim, :])
    ax.set_ylabel("TRF weight")
axs[-1, 0].set_xlabel("lag (frames)");

In [None]:
# Word-onset-locked average of the hidden-state norm (kept 2-D with one column).
ax, _ = epoch_and_plot("word", np.linalg.norm(hidden_states, axis=1, keepdims=True), boundary_matrix)
ax.set_title("Word onset (hidden-state norm)");

In [None]:
# Phoneme-onset-locked average of the hidden-state norm (kept 2-D with one column).
ax, _ = epoch_and_plot("phoneme", np.linalg.norm(hidden_states, axis=1, keepdims=True), boundary_matrix)
ax.set_title("Phoneme onset (hidden-state norm)");

## PCA decomposition of state transitions

In [None]:
# Frame-to-frame state transitions: row t is hidden_states[t + 1] - hidden_states[t]. The
# boundary matrix is shifted the same way, so row t of both refers to arriving at frame t + 1.
# NOTE(review): utterances are concatenated, so the transition into the first frame of each
# utterance spans two different files. Consider masking those rows (the "file" column marks them).
hidden_state_shifts = np.diff(hidden_states, axis=0)
boundary_matrix_shift = boundary_matrix[1:]

In [None]:
d_pca = 4
# z-score each transition dimension. Zero-variance dimensions keep divisor 1 instead of becoming
# NaNs (0/0) that PCA would reject.
hidden_state_shifts_std = hidden_state_shifts.std(axis=0)
hidden_state_shifts_std[hidden_state_shifts_std == 0] = 1.0
hidden_state_shifts_z = (hidden_state_shifts - hidden_state_shifts.mean(axis=0)) / hidden_state_shifts_std
# Fixed random_state so the components are reproducible if sklearn's "auto" solver picks randomized SVD.
shift_pca = PCA(d_pca, random_state=42)
hidden_state_shifts_pca = shift_pca.fit_transform(hidden_state_shifts_z)

In [None]:
# Fraction of the standardized transition variance captured by each component of the shift PCA
# (not the hidden-state PCA `pca`).
shift_pca.explained_variance_ratio_

In [None]:
trf_shift = ReceptiveField(tmin=-5, tmax=20, sfreq=1, feature_names=boundary_event_names)
# Contiguous (unshuffled) split: ReceptiveField lags rows along axis 0, so they must stay in
# temporal order. The last 20% of transitions are held out.
shift_X_train, shift_X_test, shift_y_train, shift_y_test = train_test_split(
    boundary_matrix_shift, hidden_state_shifts_pca, shuffle=False, test_size=0.2
)

In [None]:
# Fit the TRF that maps event onsets to the PCA of frame-to-frame hidden-state transitions.
trf_shift.fit(shift_X_train, shift_y_train)

In [None]:
# Held-out score of the transition TRF (ReceptiveField's default scoring; presumably R² per output — confirm in MNE docs).
trf_shift.score(shift_X_test, shift_y_test)

In [None]:
# Kernels of the state-transition TRF (trf_shift): one panel per input event type, one curve per
# output PCA dimension. coef_ is indexed [output_dim, input_dim, delay].
n_outputs, n_inputs = trf_shift.coef_.shape[:2]
f, axs = plt.subplots(n_inputs, 1, figsize=(10, 3 * n_inputs), sharex=True, squeeze=False)
for input_dim, (name, ax) in enumerate(zip(boundary_event_names, axs.ravel())):
    ax.set_title(name)
    for output_dim in range(n_outputs):
        ax.plot(trf_shift.delays_, trf_shift.coef_[output_dim, input_dim, :], label=output_dim)
    ax.set_ylabel("TRF weight")
    ax.legend(title="PC")
axs[-1, 0].set_xlabel("lag (frames)");

In [None]:
def epoch_and_plot_shift(epoch_event, epoch_window=(-5, 20), baseline_window=(-5, 0)):
    """Event-locked average of the transition PCA around onsets of ``epoch_event``, plotted as mean ± SEM.

    Reads the notebook-level ``hidden_state_shifts_pca`` and the row-aligned
    ``boundary_matrix_shift``; neither is modified.

    Args:
        epoch_event: event name looked up in the notebook-level ``boundary_event_to_idx``.
        epoch_window: (start, stop) frame offsets relative to each onset; stop is exclusive.
        baseline_window: (start, stop) frame offsets whose mean is subtracted from each epoch,
            or None to skip baseline correction.

    Returns:
        (ax, epoch_data): the matplotlib Axes and an array of shape
        (n_epochs, stop - start, n_components). Window frames outside the data are NaN and are
        ignored by the mean and SEM.

    Raises:
        ValueError: if no onset of ``epoch_event`` yields a usable epoch.
    """
    shifts = hidden_state_shifts_pca
    start, stop = epoch_window
    target_width = stop - start
    n_frames, n_dims = shifts.shape

    epoch_data = []
    for event_idx in tqdm(boundary_matrix_shift[:, boundary_event_to_idx[epoch_event]].nonzero()[0]):
        # Clip explicitly: a negative slice start would wrap around to the end of the array.
        window_start = event_idx + start
        lo, hi = max(window_start, 0), min(event_idx + stop, n_frames)

        # Fresh NaN buffer, so the baseline subtraction never writes back into
        # `hidden_state_shifts_pca` through a slice view, and missing frames don't bias the mean.
        epoch_states = np.full((target_width, n_dims), np.nan)
        if hi > lo:
            epoch_states[lo - window_start:hi - window_start] = shifts[lo:hi]

        if baseline_window is not None:
            base_lo = max(event_idx + baseline_window[0], 0)
            base_hi = min(event_idx + baseline_window[1], n_frames)
            if base_hi <= base_lo:
                continue  # no baseline frames inside the data; cannot baseline-correct this epoch
            epoch_states -= shifts[base_lo:base_hi].mean(axis=0)

        epoch_data.append(epoch_states)

    if not epoch_data:
        raise ValueError(f"no usable {epoch_event!r} epochs")
    epoch_data = np.stack(epoch_data)

    # Per-(lag, dim) statistics across epochs, skipping NaN frames; the SEM uses the actual count.
    n_valid = np.sum(~np.isnan(epoch_data), axis=0)
    data_mean = np.nanmean(epoch_data, axis=0)
    data_sem = np.nanstd(epoch_data, axis=0) / np.sqrt(n_valid)
    xs = np.arange(start, stop) / model_sfreq

    f, ax = plt.subplots(figsize=(10, 5))
    ax.axvline(0, color="gray")
    ax.axhline(0, color="gray")
    for dim in range(n_dims):
        ax.plot(xs, data_mean[:, dim], label=dim)
        ax.fill_between(xs, data_mean[:, dim] - data_sem[:, dim], data_mean[:, dim] + data_sem[:, dim], alpha=0.3)
    ax.set_xlabel("time relative to onset (s)")
    ax.legend()

    return ax, epoch_data

In [None]:
# Word-onset-locked average of the state-transition principal components.
ax, _ = epoch_and_plot_shift(epoch_event="word")
ax.set_title("Word onset")

In [None]:
# Phoneme-onset-locked average of the state-transition principal components.
ax, _ = epoch_and_plot_shift(epoch_event="phoneme")
ax.set_title("Phoneme onset")