### Extracting the top 50 pairs of canonical variables (CCs) between Wav2Vec2.0 and GPT-2 XL

In [9]:
import sys
sys.path.append("..")
from brain_encoding.util import scale_and_pca_one_dataset
from sklearn.cross_decomposition import CCA
import time 
import numpy as np
import pickle  

In [14]:
def _load_pickle(path):
    """Return the object stored in the pickle file at `path`."""
    # NOTE(review): unpickling can run arbitrary code; only load trusted files.
    with open(path, "rb") as pkl_file:
        return pickle.load(pkl_file)


# Language-model features, one entry per GPT-2 XL layer:
#   key   -> GPT-2 XL layer name
#   value -> feature matrix, shape [num_words_in_TIMIT, feature_dim]
lg_feat_dict = _load_pickle("../data/gpt2xl_wemb_TIMIT.pkl")

# Speech-model features, one entry per Wav2Vec2.0 layer:
#   key   -> Wav2Vec2.0 layer name
#   value -> feature matrix, shape [num_words_in_TIMIT, feature_dim];
#            each row is the mean embedding over the span of one word
sp_feat_dict = _load_pickle("../data/wav2vec2_mean_wemb_TIMIT.pkl")

In [8]:
# --- Configuration ---
# Passed as `variance_ratio` to scale_and_pca_one_dataset
# (presumably the fraction of variance the PCA step retains -- confirm in brain_encoding.util).
pca_variance_ratio = 0.95
# Number of canonical component pairs requested from CCA.
cca_comps = 50
# Created empty here; nothing in this cell fills them.
sp_ccs_dict, lg_ccs_dict = {}, {}

# Layer names used as keys into sp_feat_dict (speech) and lg_feat_dict (language).
sp_layer_name, lg_layer_name = "encoder7", "decoder8"

# --- Select the layer features ---
sp_stim = sp_feat_dict[sp_layer_name]
print(sp_layer_name)

lg_stim = lg_feat_dict[lg_layer_name]
print(lg_layer_name)

# CCA pairs the two views row by row, so fail early (before the expensive
# PCA/CCA steps) if the two feature matrices are not aligned.
if len(sp_stim) != len(lg_stim):
    raise ValueError(
        "Speech and language features must have the same number of rows, "
        "got {} and {}".format(len(sp_stim), len(lg_stim))
    )

# --- Scale + PCA each view independently ---
sp_stim = scale_and_pca_one_dataset(sp_stim, variance_ratio=pca_variance_ratio, scale=True)
lg_stim = scale_and_pca_one_dataset(lg_stim, variance_ratio=pca_variance_ratio, scale=True)

print(sp_stim.shape, lg_stim.shape)

# The dimensionality kept by PCA is data dependent, so check that the requested
# number of components is feasible: scikit-learn requires n_components to lie in
# [1, min(n_samples, n_features, n_targets)].
max_cca_comps = min(sp_stim.shape[0], sp_stim.shape[1], lg_stim.shape[1])
if not 1 <= cca_comps <= max_cca_comps:
    raise ValueError(
        "cca_comps={} is out of range; it must be between 1 and {} for inputs "
        "of shape {} and {}".format(cca_comps, max_cca_comps, sp_stim.shape, lg_stim.shape)
    )
print("cca_comps:", cca_comps)

# --- Fit CCA and project both views ---
# perf_counter() is monotonic and high resolution, unlike wall-clock time.time(),
# so the elapsed time cannot be skewed by system clock adjustments.
cca_st = time.perf_counter()
my_cca = CCA(n_components=cca_comps, max_iter=20000)

# Apply CCA (Canonical Correlation Analysis) to transform the speech and language stimuli.
# - 'sp_ccs': transformed speech features in the CCA space
# - 'lg_ccs': transformed language features in the CCA space
# These features ('sp_ccs' and 'lg_ccs') can then be used for neural encoding.
sp_ccs, lg_ccs = my_cca.fit_transform(sp_stim, lg_stim)

cca_et = time.perf_counter()
print("CCA consume time: {} s".format(cca_et - cca_st))


encoder7
decoder8
(3575, 462) (3575, 705)
cca_comps: 50
CCA consume time: 47.14986062049866 s
