### Semantic features as word embeddings

In [1]:
import numpy as np
import os
import torch
import json
from tqdm import tqdm
import scipy
import matplotlib.pyplot as plt

from sklearn.model_selection import train_test_split
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.cross_decomposition import PLSRegression
from sklearn.linear_model import Ridge, Lasso, LinearRegression
from sklearn.ensemble import StackingRegressor
from sklearn.neural_network import MLPRegressor

from tqdm import tqdm
from numpy import dot
from numpy.linalg import norm
from sklearn.model_selection import LeavePOut
from sklearn.multioutput import MultiOutputRegressor
from sklearn.metrics import r2_score
from scipy import spatial
from sklearn.metrics.pairwise import cosine_similarity

In [2]:
def cosim(a, b):
    return 1 - spatial.distance.cosine(a, b)

### Semantic features

In [3]:
min_len_line = 5
N_SEMANTIC_FEATURES = 25
semantic_features = {}

def dump_mitchell_web_semantic_features(raw_file = os.path.join("data","mitchell_semantic_raw.txt")):
    with open(raw_file, "r") as datafile:
        lines = datafile.readlines()
        word = None

        for line in lines:

            # Skip empty
            if len(line) >= min_len_line:

                # New feature
                if "Features for" in line:

                    # Discard invalid ones (once fully parsed)
                    if word and len(semantic_features[word]['features']) < N_SEMANTIC_FEATURES: del semantic_features[word] 
                        
                    word = line.split("<a name=\"")[1].split("\"")[0]
                    semantic_features[word] = { "features": [], "values": []}

                elif word:
                    feature_name = line.split("(")[0]
                    val = float(line.split("(")[1].split(")")[0])
                    semantic_features[word]["features"].append(feature_name)
                    semantic_features[word]["values"].append(val)

    # Save to file
    #with open(os.path.join('data', 'mitchell_semantic_features.json'), 'w') as fp:
    #    json.dump(semantic_features, fp)

    return semantic_features


def load_sorted_semantic_features(file = os.path.join("data","mitchell_semantic_features.json")):
    with open(file) as f:
        semantic_features = json.load(f)
        for word in semantic_features.keys():
            # Sort all features
            sorted_features = sorted(enumerate(semantic_features[word]["features"]), key=lambda x:x[1])
            sorted_indices = [i[0] for i in sorted_features]
            sorted_values = [semantic_features[word]["values"][i] for i in sorted_indices]
            # Re-store them
            semantic_features[word]["features"] = [x[1] for x in sorted_features]
            semantic_features[word]["values"] = sorted_values

        # Sanity check: all samples have same order of feature
        feats = None
        for word in semantic_features.keys():
            if feats is None: feats = semantic_features[word]["features"]
            else: assert feats == semantic_features[word]["features"]

    return semantic_features
            

### 1. Mitchell semantic features/fMRI

In [4]:
def get_mitchell_original_data(subject = 1):
    mdata = scipy.io.loadmat(os.path.join("data", "mitchell", f"data-science-P{subject}.mat"))
    subject_data = {}

    # 6 x 60 trials
    for i in range(mdata["data"][:].shape[0]):
        cond, cond_number, word, word_number, epoch = [x[0] for x in mdata["info"][0][i]]

        # Set trial data
        if epoch[0] not in subject_data: subject_data[epoch[0]] = {}
        subject_data[epoch[0]][word] = mdata["data"][i][0][0]

    return subject_data

### Voxel selection methods

In [5]:
def r2_best_voxels(scores, K, threshold = 0.2):
    scores = np.array(scores)
    r2_selected_voxels = np.where(scores > threshold)[0]
    return np.array(
        sorted( # sort by score, pick first K indices
            list(zip(scores[r2_selected_voxels], r2_selected_voxels)), 
            key = lambda x: x[0]
        )
    )[:K, 1].astype(np.int32).tolist()

In [6]:
def mitchell_stable_voxels(voxel_matrices, train_split_indices, K = 500):

    # Get scores of the voxels
    scores = []
    for vx in voxel_matrices:
        u, s, vh = np.linalg.svd(vx[:, train_split_indices], full_matrices=True) # SVD, take first eigenvalue
        scores.append((s**2)[0])
        
    # indices of the most stable voxels
    return np.argpartition(scores, -K)[-K:]

In [7]:
def get_voxels_matrices(data):
    voxels = data[1]["bell"].shape[0]

    repetitions = []
    for vx in (range(voxels)):
        repetitions.append(np.array(
            [[data[epoch][word][vx] for word in data[epoch].keys()] for epoch in data.keys()]
        ))
    return np.array(repetitions)

In [8]:
def epoch_normalize_fmri(data):
    newData = {}
    # (word, epoch, voxels)
    for epoch in data.keys():
        for word in data[epoch].keys():
            if word not in newData: newData[word] = []
            newData[word].append(data[epoch][word])
    # Mean across epochs per word
    for word in newData.keys():
        newData[word] = np.mean(np.array(newData[word]), axis=0)
    # Normalize by global mean
    mean_vector = np.array(list(newData.values())).mean(axis=0)
    for word in newData.keys(): newData[word] -= mean_vector
    return newData

In [9]:
%%time
completeFmriData = get_mitchell_original_data(subject=1)
voxels = get_voxels_matrices(completeFmriData)
voxels = mitchell_stable_voxels(voxels, list(range(58)))

CPU times: total: 2.09 s
Wall time: 2.1 s


### Training loop

In [10]:
def leave_2_out_accuracy(y_pred, y_test):

    p1, p2 = y_pred
    i1, i2 = y_test

    pair1_score = cosim(p1, i1) + cosim(p2, i2)
    pair2_score = cosim(p1, i2) + cosim(p2, i1)

    return int(pair1_score > pair2_score)

In [11]:
subjects = 1
leave2out = LeavePOut(2)
K = 500
VOXELWISE_ACC_THRESHOLD = 0.2

semantic_features = load_sorted_semantic_features()
N_words = len(semantic_features.keys())

In [12]:
accuracies = np.zeros((subjects, 3))

for subject in range(1, subjects+1):

    print(f"**** Subject {subject} ****")

    completeFmriData = get_mitchell_original_data(subject=subject)
    voxel_matrices = get_voxels_matrices(completeFmriData) # pre-compute voxel 6x58 matrices for each voxel
    
    data = epoch_normalize_fmri(completeFmriData)
    n_voxels = data["bell"].shape[0]

    # Training
    X = []
    Y = []
    for word in semantic_features.keys():
        if word in data.keys():
            x = np.array(semantic_features[word]["values"])
            y = np.array(data[word])
            X.append(x)
            Y.append(y)

    X = np.array(X)
    Y = np.array(Y)

    # leave 2 out cross validation
    accuracies_r2 = []
    accuracies_most_stable = []
    accuracies_mitchell = []
    progress_bar = tqdm(range(leave2out.get_n_splits(X)))
    
    for i, (train_index, test_index) in enumerate(leave2out.split(X)):
        
        # Train-test split
        X_train, X_test, y_train, y_test = X[train_index], X[test_index], Y[train_index], Y[test_index]
            
        # Early voxel selection
        # this is the most expensive operation! the overhead lies in fetching
        mitchell_voxels = mitchell_stable_voxels(voxel_matrices, train_index, K=K) # extract svds, score voxels, pick best 500
        
        # Predicting & scoring
        #predictors = make_pipeline(StandardScaler(), MultiOutputRegressor(LinearRegression(), n_jobs=32))
        predictors = MultiOutputRegressor(LinearRegression(), n_jobs=32)
        predictors.fit(X_train, y_train[:, mitchell_voxels])
        
        # Mitchell stable voxels
        y_pred = predictors.predict(X_test)
        accuracies_mitchell.append(
            leave_2_out_accuracy(y_pred, y_test[:, mitchell_voxels])
        )

        if i % 10 == 0: print(np.mean(accuracies_mitchell))

        """
        # R2 best voxels
        scores = [r2_score(y_pred[:, i], y_test[:, vx]) for i,vx in enumerate(mitchell_voxels)]
        r2_best = r2_best_voxels(scores, K = K)
        print(r2_best)
        accuracies_r2.append(
            leave_2_out_accuracy(y_pred, y_test[:, r2_best])
        )
        """

        progress_bar.update(1)
            

    # Subject mean accuracies
    accuracies[subject-1] = np.array(
        [np.mean(accuracies_mitchell), 0, 0]
    )
    print(f"Accuracy: {np.mean(accuracies_mitchell):.2f}")
    
    with open('accuracies_semantic2fmri.npy', 'wb') as f:
        np.save(f, accuracies)


**** Subject 1 ****


  0%|          | 1/1770 [00:05<2:38:23,  5.37s/it]

0.0


  1%|          | 11/1770 [00:16<32:33,  1.11s/it] 

0.5454545454545454


  1%|          | 21/1770 [00:26<31:16,  1.07s/it]

0.5714285714285714


  2%|▏         | 31/1770 [00:37<31:26,  1.08s/it]

0.6451612903225806


  2%|▏         | 41/1770 [00:48<31:03,  1.08s/it]

0.5365853658536586


  3%|▎         | 51/1770 [00:59<31:31,  1.10s/it]

0.5294117647058824


  3%|▎         | 61/1770 [01:10<30:27,  1.07s/it]

0.5737704918032787


  4%|▍         | 71/1770 [01:21<30:45,  1.09s/it]

0.6056338028169014


  5%|▍         | 81/1770 [01:32<30:28,  1.08s/it]

0.654320987654321


  5%|▌         | 91/1770 [01:42<30:38,  1.09s/it]

0.6593406593406593


  6%|▌         | 101/1770 [01:53<30:05,  1.08s/it]

0.6732673267326733


  6%|▋         | 111/1770 [02:04<29:42,  1.07s/it]

0.6936936936936937


  7%|▋         | 121/1770 [02:15<29:25,  1.07s/it]

0.71900826446281


  7%|▋         | 131/1770 [02:26<29:46,  1.09s/it]

0.7251908396946565


  8%|▊         | 141/1770 [02:37<29:30,  1.09s/it]

0.7375886524822695


  9%|▊         | 151/1770 [02:48<30:01,  1.11s/it]

0.7284768211920529


  9%|▉         | 161/1770 [02:59<29:18,  1.09s/it]

0.7142857142857143


 10%|▉         | 171/1770 [03:09<28:34,  1.07s/it]

0.6900584795321637


 10%|█         | 181/1770 [03:20<28:30,  1.08s/it]

0.7016574585635359


 11%|█         | 191/1770 [03:31<29:03,  1.10s/it]

0.7120418848167539


 11%|█▏        | 201/1770 [03:42<28:32,  1.09s/it]

0.7164179104477612


 12%|█▏        | 211/1770 [03:53<27:48,  1.07s/it]

0.7109004739336493


 12%|█▏        | 221/1770 [04:04<27:50,  1.08s/it]

0.7149321266968326


 13%|█▎        | 231/1770 [04:14<27:31,  1.07s/it]

0.7229437229437229


 14%|█▎        | 241/1770 [04:25<27:06,  1.06s/it]

0.7302904564315352


 14%|█▍        | 251/1770 [04:36<26:36,  1.05s/it]

0.7370517928286853


 15%|█▍        | 261/1770 [04:46<26:35,  1.06s/it]

0.7471264367816092


 15%|█▌        | 271/1770 [04:57<26:52,  1.08s/it]

0.7564575645756457


 16%|█▌        | 281/1770 [05:08<26:55,  1.09s/it]

0.7615658362989324


 16%|█▋        | 291/1770 [05:18<26:24,  1.07s/it]

0.7594501718213058


 17%|█▋        | 301/1770 [05:29<26:02,  1.06s/it]

0.7674418604651163


 18%|█▊        | 311/1770 [05:40<25:58,  1.07s/it]

0.7717041800643086


 18%|█▊        | 321/1770 [05:51<25:46,  1.07s/it]

0.7694704049844237


 19%|█▊        | 331/1770 [06:01<25:46,  1.07s/it]

0.770392749244713


 19%|█▉        | 341/1770 [06:12<25:17,  1.06s/it]

0.7771260997067448


 20%|█▉        | 351/1770 [06:23<25:07,  1.06s/it]

0.7720797720797721


 20%|██        | 361/1770 [06:33<25:00,  1.07s/it]

0.7700831024930748


 21%|██        | 371/1770 [06:44<24:58,  1.07s/it]

0.7762803234501348


 22%|██▏       | 381/1770 [06:55<24:44,  1.07s/it]

0.7769028871391076


 22%|██▏       | 391/1770 [07:05<24:25,  1.06s/it]

0.7749360613810742


 23%|██▎       | 401/1770 [07:16<24:09,  1.06s/it]

0.7780548628428927


 23%|██▎       | 411/1770 [07:27<25:11,  1.11s/it]

0.7785888077858881


 24%|██▍       | 421/1770 [07:38<24:14,  1.08s/it]

0.7767220902612827


 24%|██▍       | 431/1770 [07:48<24:00,  1.08s/it]

0.777262180974478


 25%|██▍       | 441/1770 [07:59<23:50,  1.08s/it]

0.7755102040816326


 25%|██▌       | 451/1770 [08:10<23:26,  1.07s/it]

0.7782705099778271


 26%|██▌       | 461/1770 [08:21<23:22,  1.07s/it]

0.7809110629067245


 27%|██▋       | 471/1770 [08:31<23:08,  1.07s/it]

0.7813163481953291


 27%|██▋       | 481/1770 [08:42<22:44,  1.06s/it]

0.7837837837837838


 28%|██▊       | 491/1770 [08:53<22:46,  1.07s/it]

0.7861507128309573


 28%|██▊       | 501/1770 [09:03<22:42,  1.07s/it]

0.7884231536926147


 29%|██▉       | 511/1770 [09:14<22:29,  1.07s/it]

0.7906066536203522


 29%|██▉       | 521/1770 [09:25<22:08,  1.06s/it]

0.7850287907869482


 30%|███       | 531/1770 [09:35<21:52,  1.06s/it]

0.7796610169491526


 31%|███       | 541/1770 [09:46<21:46,  1.06s/it]

0.7781885397412199


 31%|███       | 551/1770 [09:57<21:59,  1.08s/it]

0.7749546279491834


 32%|███▏      | 561/1770 [10:07<21:40,  1.08s/it]

0.7736185383244206


 32%|███▏      | 571/1770 [10:18<21:11,  1.06s/it]

0.7705779334500875


 33%|███▎      | 581/1770 [10:29<21:02,  1.06s/it]

0.7710843373493976


 33%|███▎      | 591/1770 [10:39<21:15,  1.08s/it]

0.7715736040609137


 34%|███▍      | 601/1770 [10:50<20:33,  1.06s/it]

0.7703826955074875


 35%|███▍      | 611/1770 [11:01<20:49,  1.08s/it]

0.7741407528641571


 35%|███▌      | 621/1770 [11:11<20:25,  1.07s/it]

0.7761674718196457


 36%|███▌      | 631/1770 [11:22<20:19,  1.07s/it]

0.7749603803486529


 36%|███▌      | 641/1770 [11:33<20:01,  1.06s/it]

0.7784711388455539


 37%|███▋      | 651/1770 [11:43<20:26,  1.10s/it]

0.7788018433179723


 37%|███▋      | 661/1770 [11:55<20:09,  1.09s/it]

0.7806354009077155


 38%|███▊      | 671/1770 [12:06<20:18,  1.11s/it]

0.7824143070044709


 38%|███▊      | 681/1770 [12:16<19:26,  1.07s/it]

0.7856093979441997


 39%|███▉      | 691/1770 [12:28<19:34,  1.09s/it]

0.7858176555716353


 40%|███▉      | 701/1770 [12:38<19:33,  1.10s/it]

0.7803138373751783


 40%|████      | 711/1770 [12:49<19:12,  1.09s/it]

0.7735583684950773


 41%|████      | 721/1770 [13:00<19:00,  1.09s/it]

0.7697642163661581


 41%|████▏     | 731/1770 [13:11<18:43,  1.08s/it]

0.771545827633379


 42%|████▏     | 741/1770 [13:22<18:46,  1.09s/it]

0.7678812415654521


 42%|████▏     | 751/1770 [13:33<18:48,  1.11s/it]

0.7696404793608522


 43%|████▎     | 761/1770 [13:45<19:53,  1.18s/it]

0.7674113009198423


 44%|████▎     | 771/1770 [13:56<18:42,  1.12s/it]

0.767833981841764


 44%|████▍     | 781/1770 [14:07<17:44,  1.08s/it]

0.7708066581306018


 45%|████▍     | 791/1770 [14:18<17:25,  1.07s/it]

0.7711757269279393


 45%|████▌     | 801/1770 [14:28<17:14,  1.07s/it]

0.7727840199750312


 46%|████▌     | 811/1770 [14:39<16:56,  1.06s/it]

0.7755856966707768


 46%|████▋     | 821/1770 [14:50<16:50,  1.06s/it]

0.7771010962241169


 47%|████▋     | 831/1770 [15:01<16:40,  1.07s/it]

0.779783393501805


 48%|████▊     | 841/1770 [15:11<16:43,  1.08s/it]

0.7788347205707491


 48%|████▊     | 851/1770 [15:22<16:50,  1.10s/it]

0.7802585193889542


 49%|████▊     | 861/1770 [15:33<16:22,  1.08s/it]

0.7816492450638792


 49%|████▉     | 871/1770 [15:44<16:01,  1.07s/it]

0.7818599311136625


 50%|████▉     | 881/1770 [15:55<16:02,  1.08s/it]

0.7820658342792282


 50%|█████     | 891/1770 [16:05<15:38,  1.07s/it]

0.7845117845117845


 51%|█████     | 901/1770 [16:16<16:01,  1.11s/it]

0.7869034406215316


 51%|█████▏    | 911/1770 [16:27<15:27,  1.08s/it]

0.7892425905598244


 52%|█████▏    | 921/1770 [16:38<15:11,  1.07s/it]

0.7850162866449512


 53%|█████▎    | 931/1770 [16:49<14:58,  1.07s/it]

0.7873254564983888


 53%|█████▎    | 941/1770 [17:00<14:51,  1.08s/it]

0.7874601487778958


 54%|█████▎    | 951/1770 [17:10<14:31,  1.06s/it]

0.7896950578338591


 54%|█████▍    | 961/1770 [17:21<14:50,  1.10s/it]

0.7887617065556711


 55%|█████▍    | 971/1770 [17:32<14:22,  1.08s/it]

0.7899073120494335


 55%|█████▌    | 981/1770 [17:43<14:08,  1.08s/it]

0.7900101936799184


 56%|█████▌    | 991/1770 [17:54<15:11,  1.17s/it]

0.7911200807265388


 57%|█████▋    | 1001/1770 [18:05<13:53,  1.08s/it]

0.7902097902097902


 57%|█████▋    | 1011/1770 [18:16<13:28,  1.06s/it]

0.7922848664688428


 58%|█████▊    | 1021/1770 [18:27<13:29,  1.08s/it]

0.7933398628795298


 58%|█████▊    | 1031/1770 [18:37<13:06,  1.06s/it]

0.7934044616876819


 59%|█████▉    | 1041/1770 [18:48<13:09,  1.08s/it]

0.7915465898174832


 59%|█████▉    | 1051/1770 [18:59<13:42,  1.14s/it]

0.7935299714557564


 60%|█████▉    | 1061/1770 [19:10<13:03,  1.10s/it]

0.7926484448633365


 61%|██████    | 1071/1770 [19:21<12:30,  1.07s/it]

0.7917833800186741


 61%|██████    | 1081/1770 [19:32<12:11,  1.06s/it]

0.7918593894542091


 62%|██████▏   | 1091/1770 [19:42<12:01,  1.06s/it]

0.7928505957836847


 62%|██████▏   | 1101/1770 [19:53<11:53,  1.07s/it]

0.7920072661217076


 63%|██████▎   | 1111/1770 [20:04<11:43,  1.07s/it]

0.7902790279027903


 63%|██████▎   | 1121/1770 [20:15<11:39,  1.08s/it]

0.7912578055307761


 64%|██████▍   | 1131/1770 [20:25<11:25,  1.07s/it]

0.7895667550839964


 64%|██████▍   | 1141/1770 [20:36<11:26,  1.09s/it]

0.7905346187554777


 65%|██████▌   | 1151/1770 [20:47<11:03,  1.07s/it]

0.788010425716768


 66%|██████▌   | 1161/1770 [20:58<10:47,  1.06s/it]

0.7863910422049957


 66%|██████▌   | 1171/1770 [21:08<10:41,  1.07s/it]

0.784799316823228


 67%|██████▋   | 1181/1770 [21:19<10:21,  1.06s/it]

0.7840812870448772


 67%|██████▋   | 1191/1770 [21:30<10:14,  1.06s/it]

0.783375314861461


 68%|██████▊   | 1201/1770 [21:40<10:04,  1.06s/it]

0.7835137385512073


 68%|██████▊   | 1211/1770 [21:51<09:50,  1.06s/it]

0.7853014037985137


 69%|██████▉   | 1221/1770 [22:02<09:51,  1.08s/it]

0.7854217854217854


 70%|██████▉   | 1231/1770 [22:13<09:57,  1.11s/it]

0.784727863525589


 70%|███████   | 1241/1770 [22:24<09:44,  1.11s/it]

0.7864625302175665


 71%|███████   | 1251/1770 [22:35<09:47,  1.13s/it]

0.7881694644284573


 71%|███████   | 1261/1770 [22:46<10:18,  1.21s/it]

0.7898493259318001


 72%|███████▏  | 1271/1770 [22:58<09:11,  1.11s/it]

0.7891424075531078


 72%|███████▏  | 1281/1770 [23:09<08:41,  1.07s/it]

0.790007806401249


 73%|███████▎  | 1291/1770 [23:20<08:37,  1.08s/it]

0.7916343919442292


 74%|███████▎  | 1301/1770 [23:30<08:19,  1.06s/it]

0.7916986933128363


 74%|███████▍  | 1311/1770 [23:41<08:08,  1.06s/it]

0.7917620137299771


 75%|███████▍  | 1321/1770 [23:52<08:06,  1.08s/it]

0.7910673732021196


 75%|███████▌  | 1331/1770 [24:02<07:51,  1.07s/it]

0.791885800150263


 76%|███████▌  | 1341/1770 [24:13<07:39,  1.07s/it]

0.7897091722595079


 76%|███████▋  | 1351/1770 [24:24<07:29,  1.07s/it]

0.7868245743893413


 77%|███████▋  | 1361/1770 [24:35<07:15,  1.06s/it]

0.7847171197648788


 77%|███████▋  | 1371/1770 [24:45<07:05,  1.07s/it]

0.7826404084609774


 78%|███████▊  | 1381/1770 [24:56<07:11,  1.11s/it]

0.78059377262853


 79%|███████▊  | 1391/1770 [25:07<06:59,  1.11s/it]

0.7800143781452192


 79%|███████▉  | 1401/1770 [25:18<06:31,  1.06s/it]

0.7801570306923626


 80%|███████▉  | 1411/1770 [25:29<06:23,  1.07s/it]

0.780297661233168


 80%|████████  | 1421/1770 [25:39<06:09,  1.06s/it]

0.7804363124560169


 81%|████████  | 1431/1770 [25:50<06:02,  1.07s/it]

0.7791754018169113


 81%|████████▏ | 1441/1770 [26:01<05:51,  1.07s/it]

0.7765440666204025


 82%|████████▏ | 1451/1770 [26:12<05:41,  1.07s/it]

0.7746381805651275


 83%|████████▎ | 1461/1770 [26:22<05:29,  1.07s/it]

0.7741273100616016


 83%|████████▎ | 1471/1770 [26:33<05:16,  1.06s/it]

0.7743031951053705


 84%|████████▎ | 1481/1770 [26:43<05:08,  1.07s/it]

0.7731262660364618


 84%|████████▍ | 1491/1770 [26:54<05:01,  1.08s/it]

0.7712944332662642


 85%|████████▍ | 1501/1770 [27:05<04:49,  1.08s/it]

0.7721518987341772


 85%|████████▌ | 1511/1770 [27:16<04:36,  1.07s/it]

0.7703507610853739


 86%|████████▌ | 1521/1770 [27:27<04:24,  1.06s/it]

0.7718606180144642


 86%|████████▋ | 1531/1770 [27:38<04:17,  1.08s/it]

0.7726975832789027


 87%|████████▋ | 1541/1770 [27:49<04:07,  1.08s/it]

0.772225827384815


 88%|████████▊ | 1551/1770 [27:59<03:58,  1.09s/it]

0.7730496453900709


 88%|████████▊ | 1561/1770 [28:10<03:51,  1.11s/it]

0.7738629083920564


 89%|████████▉ | 1571/1770 [28:22<03:45,  1.13s/it]

0.7733927434754934


 89%|████████▉ | 1581/1770 [28:32<03:23,  1.08s/it]

0.7741935483870968


 90%|████████▉ | 1591/1770 [28:43<03:13,  1.08s/it]

0.7724701445631679


 90%|█████████ | 1601/1770 [28:55<03:14,  1.15s/it]

0.7732667083073079


 91%|█████████ | 1611/1770 [29:07<03:20,  1.26s/it]

0.7728119180633147


 92%|█████████▏| 1621/1770 [29:19<02:57,  1.19s/it]

0.7735965453423812


 92%|█████████▏| 1631/1770 [29:30<02:33,  1.11s/it]

0.773758430410791


 93%|█████████▎| 1641/1770 [29:42<02:28,  1.15s/it]

0.773308957952468


 93%|█████████▎| 1651/1770 [29:55<02:23,  1.21s/it]

0.7746820109024833


 94%|█████████▍| 1661/1770 [30:07<02:03,  1.14s/it]

0.7754364840457556


 94%|█████████▍| 1671/1770 [30:18<01:52,  1.13s/it]

0.7749850388988629


 95%|█████████▍| 1681/1770 [30:29<01:35,  1.07s/it]

0.7763236168947055


 96%|█████████▌| 1691/1770 [30:40<01:26,  1.09s/it]

0.7764636309875813


 96%|█████████▌| 1701/1770 [30:51<01:15,  1.09s/it]

0.7777777777777778


 97%|█████████▋| 1711/1770 [31:02<01:03,  1.08s/it]

0.7773232028053769


 97%|█████████▋| 1721/1770 [31:13<00:53,  1.08s/it]

0.7762928529924462


 98%|█████████▊| 1731/1770 [31:25<00:45,  1.18s/it]

0.7770075101097631


 98%|█████████▊| 1741/1770 [31:36<00:31,  1.10s/it]

0.7777139574956922


 99%|█████████▉| 1751/1770 [31:46<00:20,  1.09s/it]

0.7784123358081096


 99%|█████████▉| 1761/1770 [31:58<00:09,  1.09s/it]

0.7796706416808632


100%|██████████| 1770/1770 [32:07<00:00,  1.08s/it]

Accuracy: 0.78


| method | subject | voxel selection | accuracy |
|---|---|---|---|
| multiple regressors | 3 | Most stable | 0.66, 0.55, 0.60 |
| multiple regressors, no normalization | 1 | Most stable | 0.60 |
| Ridge (not voxel-wise), no normalization | 1 | Most stable | 0.45 |

Note from professor: there is no need to determine the set of best predicted voxels across multiple folds. We just compute the accuracy for each fold and then average.

**Observation**

In this case fitting is way more expensive, as 21k voxels are considered.

In [13]:
from scipy.stats import pearsonr
import matplotlib.pyplot as plt

def best_K_predict(X, indices, predictors):
    predictors = [predictors[idx] for idx in indices]
    y_hat = np.array([predictor.predict(X) for predictor in predictors]) # voxels, sample
    return y_hat.reshape(y_hat.shape[1], y_hat.shape[0]) # sample, voxels

# voxel_indices

y_hat = best_K_predict(X_train, mitchell_voxels, predictors)
y = y_train[:, mitchell_voxels]

RDM_hat = np.matmul(y_hat, np.matrix.transpose(y_hat))

RDM = np.matmul(y, np.matrix.transpose(y))

test_pearson = pearsonr(
    RDM_hat.flatten(),
    RDM.flatten()
)

print(f"Test RDMs R^2:\t{test_pearson}")

plt.subplot(121)
plt.title("Truth")
plt.imshow(RDM)
plt.colorbar()

plt.subplot(122)
plt.title("Prediction")
plt.imshow(RDM_hat)
plt.colorbar()

TypeError: 'MultiOutputRegressor' object is not subscriptable

**Observation**

Here the the voxels from the last cross_val iteration have been selected. For these voxels, the object to object distance matrices have similar patterns.

### 2. GloVe embeddings

def get_word_activations(path, skip_lines=0):
    """
        Returns dataset of fMRI word activations
        path            Path to .txt fMRI data vectors (continuous) from Cognival
        context_len     Words before the occurring one
        data            Returned dictionary with key ['word'] -> {'context', 'activations'}
    """
    data = {}
    with open(path, "r") as datafile:
        lines = datafile.readlines()[skip_lines:] # skip header
        for line in tqdm(lines):
            word = line.split(" ")[0]
            activations = np.array([float(x) for x in line.split(" ")[1:]])
            data[word] = {"activations": activations}
    return data

In [None]:
filename = os.path.join("data", "glove.6B", f"glove.6B.100d.txt")
glove_embeddings = get_word_activations(filename, skip_lines = 0)

In [None]:
semantic_features = load_sorted_semantic_features()

In [None]:
X = []
Y = []
for key in common_keys:
    X.append(glove_embeddings[key]["activations"])
    Y.append(binder_features[key])

X = np.array(X).astype(np.float32) # word embeddings
Y = np.array(Y).astype(np.float32) # binder features

X.shape, Y.shape