In [65]:
import glob
import os
import nibabel as nib
import numpy as np
from tqdm import tqdm
from sklearn.decomposition import PCA


In [66]:
# Seed NumPy's global RNG once, up front, so the shuffles later in the
# notebook are reproducible across Restart & Run All.
SEED = 42
np.random.seed(SEED)

In [67]:
def flatten_to_2d(data):
    """Turn a (..., T) array into a (T, n_voxels) matrix, one row per time step.

    Every axis except the last is collapsed (C order) into a single voxel
    axis; the result is then transposed so time becomes the row axis.
    The input shape and the output shape are printed for progress tracking.
    """
    print(data.shape)
    *spatial_dims, n_time = data.shape
    n_voxels = np.prod(spatial_dims)
    by_time = data.reshape(n_voxels, n_time).T
    print(by_time.shape)
    return by_time

In [68]:
def pca(flat_data, n_components=200):
    """Project ``flat_data`` onto its first ``n_components`` principal components.

    Args:
        flat_data: 2-D array (n_samples, n_features). In this notebook each row
            is one time step of a single run (output of ``flatten_to_2d``).
        n_components: Number of components to keep (default 200).

    Returns:
        (pca_data, expl_var): the projected data, shape
        (n_samples, n_components), and the summed explained-variance ratio of
        the kept components (fraction of total variance retained).

    NOTE(review): a fresh PCA is fit on every call, i.e. once per run, so
    component k of one run is generally not the same axis as component k of
    another run. Confirm this is intended before comparing features across
    runs/subjects.
    """
    model = PCA(n_components=n_components)
    pca_data = model.fit_transform(flat_data)
    expl_var = sum(model.explained_variance_ratio_)
    return pca_data, expl_var

In [69]:
def process_unzipped_files(root_dir):
    """Load every ``*.gz`` image under ``root_dir`` and reduce each run with PCA.

    Files are found recursively and processed in **sorted** path order, so the
    runs of one subject stay adjacent in the returned lists. The stacking cell
    later groups consecutive runs into subjects of 9, so this ordering matters.

    Each file is loaded with nibabel, flattened with ``flatten_to_2d`` and
    reduced with ``pca``. A run goes to the English list when the second
    '-'-separated field of its path (relative to ``root_dir``) contains 'EN';
    every other run goes to the Chinese list.

    Args:
        root_dir: Directory searched recursively for ``*.gz`` files.

    Returns:
        (en_subj_data, ch_subj_data, vars): per-run PCA arrays for English and
        for Chinese, plus the explained-variance sum of every run whose PCA
        succeeded (in processing order).

    Files that raise during processing are reported with ``print`` and skipped.
    NOTE(review): a skipped run breaks the 9-runs-per-subject grouping done
    later, so check for printed errors before using the output.
    """
    gz_files = sorted(glob.glob(os.path.join(root_dir, '**', '*.gz'), recursive=True))
    en_subj_data = []
    ch_subj_data = []
    explained_vars = []
    for gz_file in tqdm(gz_files):
        try:
            img = nib.load(gz_file)
            data = flatten_to_2d(img.get_fdata())
            data, var = pca(data)
            explained_vars.append(var)
            # Split the path relative to root_dir so '-' characters in the
            # root directory itself cannot shift the subject-ID field.
            fields = os.path.relpath(gz_file, root_dir).split('-')
            if 'EN' in fields[1]:
                en_subj_data.append(data)
            else:
                ch_subj_data.append(data)
        except Exception as e:
            print(f"Error processing {gz_file}: {e}")
    return en_subj_data, ch_subj_data, explained_vars

In [70]:
# Run the slow load + PCA step over every run under LittlePrince/derivatives.
root_directory = 'LittlePrince'
derivatives_dir = os.path.join(root_directory, 'derivatives')
en_subj_data, ch_subj_data, vars = process_unzipped_files(derivatives_dir)

python(54679) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
  0%|          | 0/657 [00:00<?, ?it/s]

(73, 90, 74, 322)
(322, 486180)


  0%|          | 0/657 [00:04<?, ?it/s]


KeyboardInterrupt: 

In [37]:
# Reading in all the data takes several hours, so cache each language's runs
# in a zip of NumPy arrays (np.savez stores them as arr_0, arr_1, ... in order).
ch_file_path = 'ch_runs.npz'
en_file_path = 'en_runs.npz'

np.savez(ch_file_path, *ch_subj_data)
np.savez(en_file_path, *en_subj_data)

In [13]:
# Reload the cached per-run PCA arrays. np.savez(*arrays) names entries
# arr_0, arr_1, ...; read them back by index so run order (which the subject
# grouping below depends on) is explicit. The `with` closes both archives.
with np.load('ch_runs.npz') as ch, np.load('en_runs.npz') as en:
    print(len(ch))
    print(len(en))
    ch_subj_data = [ch[f'arr_{i}'] for i in range(len(ch))]
    en_subj_data = [en[f'arr_{i}'] for i in range(len(en))]

print(ch_subj_data[0].shape)
print(en_subj_data[0].shape)



297
360
(322, 200)
(343, 200)


In [14]:
# Create stacks of 9 runs to store data by subject.
# There are 9 runs per subject, and each language's run list is assumed to
# keep a subject's runs adjacent, so each consecutive block of 9 runs is
# concatenated along time into one subject. Total time points per subject are
# taken from the data (2816 for EN and 2977 for CH in the recorded output).
RUNS_PER_SUBJ = 9


def stack_runs_by_subject(runs, runs_per_subj=RUNS_PER_SUBJ):
    """Stack per-run (T_run, n_components) arrays into (n_subjects, T_total, n_components).

    Args:
        runs: list of 2-D arrays, ordered so each subject's runs are adjacent.
        runs_per_subj: runs belonging to one subject (default 9).

    Raises:
        ValueError: if len(runs) is not a multiple of runs_per_subj (a missing
            or failed run would otherwise misalign every later subject). Also
            raised by np.stack when subjects have different total time points
            or when ``runs`` is empty.
    """
    n_runs = len(runs)
    if n_runs % runs_per_subj != 0:
        raise ValueError(
            f"{n_runs} runs cannot be split evenly into subjects of "
            f"{runs_per_subj} runs; a run is probably missing or failed to load"
        )
    subjects = [
        np.vstack(runs[start:start + runs_per_subj])
        for start in range(0, n_runs, runs_per_subj)
    ]
    return np.stack(subjects)


en_stacked = stack_runs_by_subject(en_subj_data)
en_num_subjs = en_stacked.shape[0]

# Repeat for Chinese
ch_stacked = stack_runs_by_subject(ch_subj_data)
ch_num_subjs = ch_stacked.shape[0]

print(en_stacked.shape)
print(ch_stacked.shape)

# To keep approximately the same number of images for each language, select only 35 of the 40 EN participants
en_stacked = en_stacked[:35, :, :]
print(en_stacked.shape)

(40, 2816, 200)
(33, 2977, 200)
(35, 2816, 200)


In [50]:
# Split each language by participant, so no participant appears in two splits:
#   EN (35 participants): 27 train / 3 dev / 5 test
#   CH (33 participants): 26 train / 3 dev / 4 test
en_train_pps, en_dev_pps, en_test_pps = en_stacked[:27], en_stacked[27:30], en_stacked[30:]
ch_train_pps, ch_dev_pps, ch_test_pps = ch_stacked[:26], ch_stacked[26:29], ch_stacked[29:]

(5, 2816, 200)
(3, 2816, 200)


In [55]:
from torch.nn.utils.rnn import pad_sequence
import torch

def pad(pp_data, time_steps=5):
    """Cut participant data into fixed-length windows of ``time_steps`` rows.

    All axes except the last are flattened (C order), so e.g.
    (participants, time, n_components) becomes
    (participants * time, n_components). Consecutive rows are then grouped
    into windows of ``time_steps``; the final window is zero-padded when the
    row count is not a multiple of ``time_steps``. The output is always padded
    to exactly ``time_steps`` rows, even when there are fewer rows than that.

    NOTE(review): because participants are flattened together, a window can
    contain the end of one participant and the start of the next. Confirm
    this is acceptable for the classifier.

    Args:
        pp_data: NumPy array with at least 2 dims; the last axis is the feature axis.
        time_steps: rows per window (default 5). Must be positive.

    Returns:
        float32 tensor of shape (ceil(n_rows / time_steps), time_steps, n_features).
        The shape is also printed.

    Raises:
        ValueError: if time_steps is not positive.
    """
    if time_steps <= 0:
        raise ValueError(f"time_steps must be positive, got {time_steps}")
    n_features = pp_data.shape[-1]
    n_rows = int(np.prod(pp_data.shape[:-1]))
    flat = torch.as_tensor(pp_data.reshape(n_rows, n_features), dtype=torch.float32)

    # Ceil division: a trailing partial window still gets its own (padded) slot.
    n_windows = -(-n_rows // time_steps)
    padded = torch.zeros(n_windows * time_steps, n_features, dtype=torch.float32)
    padded[:n_rows] = flat
    padded = padded.reshape(n_windows, time_steps, n_features)
    print(padded.shape)
    return padded

# Window every split into (n_windows, 5, n_components) tensors (EN first, then CH).
padded_en_train, padded_en_dev, padded_en_test = [pad(pps) for pps in (en_train_pps, en_dev_pps, en_test_pps)]
padded_ch_train, padded_ch_dev, padded_ch_test = [pad(pps) for pps in (ch_train_pps, ch_dev_pps, ch_test_pps)]


torch.Size([15207, 5, 200])
torch.Size([1690, 5, 200])
torch.Size([2816, 5, 200])
torch.Size([15481, 5, 200])
torch.Size([1787, 5, 200])
torch.Size([2382, 5, 200])


In [56]:
# Create labels, one per window: 0 = English, 1 = Chinese.
en_train_labels, en_dev_labels, en_test_labels = (
    torch.zeros(len(windows)) for windows in (padded_en_train, padded_en_dev, padded_en_test)
)
ch_train_labels, ch_dev_labels, ch_test_labels = (
    torch.ones(len(windows)) for windows in (padded_ch_train, padded_ch_dev, padded_ch_test)
)



In [59]:
# Combine EN and CH into the final train/dev/test sets: English windows first,
# then Chinese, with labels concatenated in the same order.
Xtrain, Xdev, Xtest = (
    torch.cat((en_part, ch_part), dim=0)
    for en_part, ch_part in (
        (padded_en_train, padded_ch_train),
        (padded_en_dev, padded_ch_dev),
        (padded_en_test, padded_ch_test),
    )
)
ytrain, ydev, ytest = (
    torch.cat((en_part, ch_part))
    for en_part, ch_part in (
        (en_train_labels, ch_train_labels),
        (en_dev_labels, ch_dev_labels),
        (en_test_labels, ch_test_labels),
    )
)

print(*(t.shape for t in (Xtrain, Xdev, Xtest, ytrain, ydev, ytest)), sep='\n')

torch.Size([30688, 5, 200])
torch.Size([3477, 5, 200])
torch.Size([5198, 5, 200])
torch.Size([30688])
torch.Size([3477])
torch.Size([5198])


In [62]:
# Shuffle each split, applying the same permutation to examples and labels.
# Permutations come from NumPy's global (seeded) RNG, drawn in train, dev, test order.
train_ndxs, dev_ndxs, test_ndxs = (np.random.permutation(len(X)) for X in (Xtrain, Xdev, Xtest))

Xtrain, ytrain = Xtrain[train_ndxs], ytrain[train_ndxs]
Xdev, ydev = Xdev[dev_ndxs], ydev[dev_ndxs]
Xtest, ytest = Xtest[test_ndxs], ytest[test_ndxs]

In [64]:
# Save the preprocessed tensors, creating the output directory if it does not
# exist yet (torch.save does not create parent directories).
save_dir = 'preprocessed_data'
os.makedirs(save_dir, exist_ok=True)

for tensor_name, tensor in (
    ('Xtrain', Xtrain), ('ytrain', ytrain),
    ('Xdev', Xdev), ('ydev', ydev),
    ('Xtest', Xtest), ('ytest', ytest),
):
    torch.save(tensor, os.path.join(save_dir, f'{tensor_name}.pt'))