In [8]:
import warnings
import sys 
import os 

import numpy as np
from scipy import stats
import matplotlib.pyplot as plt

import nibabel as nib

from brainiak.isc import isc
from brainiak.fcma.util import compute_correlation
import brainiak.funcalign.srm
from brainiak import image, io

import scipy.spatial.distance as sp_distance

from tqdm.notebook import tqdm

from sklearn.model_selection import LeaveOneOut

from shared_gpfa import SharedGpfa

%autosave 5

%load_ext autoreload
%autoreload 2

%matplotlib widget

Autosaving every 5 seconds
The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [9]:
# Load raider/movie.npy; its three axes are unpacked below as
# (voxels, TRs, subjects).
movie_path = os.path.join('raider', 'movie.npy')
movie_data = np.load(movie_path)
vox_num, nTR, num_subs = movie_data.shape
print(vox_num, nTR, num_subs)

1000 2203 10


In [10]:
# Split every subject's time course into two halves of nTR // 2 TRs each:
# the first nTR // 2 TRs become the training half, the last nTR // 2 the test half.
# (When nTR is odd, the middle TR ends up in neither half.)
train_data = [movie_data[:, :nTR//2, sub] for sub in range(num_subs)]
test_data = [movie_data[:, -(nTR//2):, sub] for sub in range(num_subs)]

In [11]:
# Z-score each row (axis=1, i.e. over TRs) of every subject's array in both
# halves, then replace any NaN left by the z-scoring with 0.
for half in (train_data, test_data):
    for sub in range(num_subs):
        half[sub] = np.nan_to_num(stats.zscore(half[sub], axis=1, ddof=1))

In [None]:
# Fit a Shared Response Model on the training halves of all subjects.
features = 50  # number of shared features
n_iter = 20    # number of fitting iterations
srm = brainiak.funcalign.srm.SRM(features=features, n_iter=n_iter)
srm.fit(train_data)

In [None]:
print('SRM: Features X Time-points ', srm.s_.shape)

# Heat map of the fitted shared response (features x TRs).
fig, ax = plt.subplots(figsize=(15, 4))
im = ax.imshow(srm.s_, cmap='viridis')
ax.set_title('SRM: Features X Time-points')
ax.set_xlabel('TR')
ax.set_ylabel('feature')
fig.colorbar(im, ax=ax)
# tight_layout only has an effect once the axes and colorbar exist,
# so it is called after drawing rather than on the empty figure.
fig.tight_layout()

# Time courses of the first three shared features.
fig, ax = plt.subplots(figsize=(20, 4))
ax.plot(srm.s_[0:3, :].T)
ax.set_title('SRM: first three features')
ax.set_xlabel('TR')
ax.grid()

In [None]:
# Distance between the shared-space feature vectors of every pair of TRs.
# pdist returns the condensed form; squareform expands it to a full TR x TR matrix.
condensed_dist = sp_distance.pdist(srm.s_.T)
dist_mat = sp_distance.squareform(condensed_dist)

fig, ax = plt.subplots(figsize=(10, 10))
ax.set_title('Distance between pairs of time points in shared space')
ax.set_xlabel('TR')
ax.set_ylabel('TR')
im = ax.imshow(dist_mat, cmap='viridis')
fig.colorbar(im, ax=ax)

In [None]:
# Compare the weights of row 0 of the weight matrix (one voxel) across the
# shared features for the first two subjects.
voxel_w_sub0 = srm.w_[0][0, :]
voxel_w_sub1 = srm.w_[1][0, :]
feature_corr = np.corrcoef(voxel_w_sub0, voxel_w_sub1)[0, 1]

fig, ax = plt.subplots()
ax.plot(voxel_w_sub0)
ax.plot(voxel_w_sub1)
ax.set_title('SRM: Weights x Features for one voxel (correlation of loading, r: %0.3f)' % feature_corr)
ax.set_xlabel('feature')
ax.set_ylabel('weight for one voxel')
fig.tight_layout()

In [None]:
# Project the training data of every subject into the SRM shared space
shared_train = srm.transform(train_data)

# Z-score each row (over TRs) of the projected training data
shared_train = [stats.zscore(shared, axis=1, ddof=1) for shared in shared_train]

# ISC expects one (rows x TRs x subjects) array, so stack the per-subject
# arrays along a new last axis (as float64, like a zero-initialised buffer).
raw_obj = np.stack(train_data, axis=-1).astype(np.float64, copy=False)

# ISC on the raw voxel data, averaged across participants; NaNs become 0
corr_raw = np.nan_to_num(isc(raw_obj, summary_statistic='mean'))

# Same stacking for the shared-space data
shared_obj = np.stack(shared_train, axis=-1).astype(np.float64, copy=False)

# ISC on the shared features, averaged across participants; NaNs become 0
corr_shared = np.nan_to_num(isc(shared_obj, summary_statistic='mean'))


In [None]:
# Side-by-side histograms of ISC values: raw voxels (left) and shared features (right).
isc_panels = (
    (corr_raw, 'ISC for all voxels', 'number of voxels'),
    (corr_shared, 'ISC for shared features', 'number of features'),
)
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
for ax, (isc_values, panel_title, count_label) in zip(axes, isc_panels):
    ax.set_title(panel_title)
    ax.hist(isc_values)
    ax.set_xlabel('correlation')
    ax.set_ylabel(count_label)
    ax.set_xlim([-1, 1])

fig.tight_layout()

# Compare the two sets of ISC values after the Fisher z (arctanh) transform.
tstat = stats.ttest_ind(np.arctanh(corr_shared), np.arctanh(corr_raw))
print('Independent samples t test between raw and SRM transformed data:', tstat.statistic, 'p:', tstat.pvalue)


In [None]:
# Project the test data into the shared space with the per-subject weight
# matrices learned on the training half
shared_test = srm.transform(test_data)

# Z-score each row (over TRs) of the projected test data
shared_test = [stats.zscore(shared, axis=1, ddof=1) for shared in shared_test]

# ISC expects one (rows x TRs x subjects) array, so stack the per-subject
# arrays along a new last axis (as float64, like a zero-initialised buffer).
raw_obj = np.stack(test_data, axis=-1).astype(np.float64, copy=False)

# ISC on the raw voxel data, averaged across participants; NaNs become 0
corr_raw = np.nan_to_num(isc(raw_obj, summary_statistic='mean'))

# Same stacking for the shared-space data
shared_obj = np.stack(shared_test, axis=-1).astype(np.float64, copy=False)

# ISC on the shared features, averaged across participants; NaNs become 0
corr_shared = np.nan_to_num(isc(shared_obj, summary_statistic='mean'))


In [None]:
# Side-by-side histograms of ISC values: raw voxels (left) and shared features (right).
fig, (ax_raw, ax_shared) = plt.subplots(1, 2, figsize=(14, 5))

ax_raw.hist(corr_raw)
ax_raw.set_title('ISC for all voxels')
ax_raw.set_xlabel('correlation')
ax_raw.set_ylabel('number of voxels')
ax_raw.set_xlim([-1, 1])

ax_shared.hist(corr_shared)
ax_shared.set_title('ISC for shared features')
ax_shared.set_xlabel('correlation')
ax_shared.set_ylabel('number of features')
ax_shared.set_xlim([-1, 1])

fig.tight_layout()

# Compare the two sets of ISC values after the Fisher z (arctanh) transform.
tstat = stats.ttest_ind(np.arctanh(corr_shared), np.arctanh(corr_raw))
print('Independent samples t test between raw and SRM transformed data:', tstat.statistic, 'p:', tstat.pvalue)


In [None]:
# Map subject 1's shared-space test data back to voxel space with that
# subject's own weight matrix.
w0 = srm.w_[0]  # Weights for subject 1
signal_srm0 = w0 @ shared_test[0]  # Reconstructed signal for subject 1

# Overlay the first 100 TRs of voxel 100: reconstruction vs. original.
fig, ax = plt.subplots(figsize=(10, 5))
ax.set_title('SRM reconstructed vs. original signal for one voxel', fontsize=14)
ax.plot(signal_srm0[100, :100], label='Reconstructed data')
ax.plot(test_data[0][100, :100], label='Original data')
ax.set_xlabel('TR')
ax.set_ylabel('signal of one voxel')
ax.legend(loc=(1.04, 0.5))
fig.tight_layout()

In [None]:
# Reconstruct voxel-space signals for every participant and organize them for ISC.
# NOTE(review): every participant's shared response is projected with w0
# (subject 1's weights), i.e. into subject 1's voxel space - confirm this is
# intended rather than each participant's own srm.w_[ppt].
n_vox, n_tr = test_data[0].shape
signal_srm = np.zeros((n_vox, n_tr, len(test_data)))
for ppt in range(len(test_data)):
    reconstructed = np.dot(w0, shared_test[ppt])
    signal_srm[:, :, ppt] = reconstructed

# ISC on the reconstructed signals, averaged across participants; NaNs become 0
corr_reconstructed = np.nan_to_num(isc(signal_srm, summary_statistic='mean'))

In [None]:
# Plot the figure
plt.figure(figsize=(14,5))
plt.subplot(1, 2, 1)
plt.title('ISC for all voxels')
plt.hist(corr_raw);
plt.xlabel('correlation')
plt.ylabel('number of voxels')
plt.xlim([-1, 1]);

plt.subplot(1, 2, 2)
plt.title('ISC for shared features')
plt.hist(corr_reconstructed[0]);
plt.xlabel('correlation')
plt.ylabel('number of features')
plt.xlim([-1, 1]);

plt.tight_layout()

tstat = stats.ttest_1samp(np.arctanh(corr_reconstructed) - np.arctanh(corr_raw), 0)
print('Dependent samples t test between raw and SRM transformed data:', tstat.statistic, 'p:', tstat.pvalue)


# 5. Time-segment matching

In [None]:
# Take in a list of participants of voxel by TR data. Also specify how big the time segment is to be matched
def time_segment_matching(data, win_size=10):
    """Leave-one-subject-out time-segment matching.

    Every subject's array is cut into ``nseg = nsample - win_size`` overlapping
    windows of ``win_size`` samples (each window flattened into one column).
    For each left-out subject, its windows are correlated with the windows of
    the summed data of all other subjects; a window counts as correct when its
    best-correlated counterpart (ignoring windows that overlap it in time) is
    the window at the same position.

    Parameters
    ----------
    data : sequence of 2-D arrays
        One array per subject, all of shape (ndim, nsample).
    win_size : int, default 10
        Number of consecutive samples in one time segment.

    Returns
    -------
    accu : ndarray of shape (nsubjs,)
        Fraction of correctly matched segments for each left-out subject.
        The mean and standard deviation over subjects are also printed.

    Raises
    ------
    ValueError
        If ``win_size`` is smaller than 1 or leaves no segment to match
        (``win_size >= nsample``).
    """
    nsubjs = len(data)
    (ndim, nsample) = data[0].shape
    nseg = nsample - win_size
    if win_size < 1 or nseg < 1:
        raise ValueError(
            "win_size must be between 1 and nsample - 1 "
            "(got win_size={}, nsample={})".format(win_size, nsample))
    accu = np.zeros(shape=nsubjs)

    # mysseg prediction
    trn_data = np.zeros((ndim*win_size, nseg), order='f')

    # the training data also include the test data, but will be subtracted when calculating A
    for m in range(nsubjs):
        for w in range(win_size):
            trn_data[w*ndim:(w+1)*ndim, :] += data[m][:, w:(w+nseg)]

    # Segments whose start indices differ by less than win_size share samples.
    # The mask marks those overlapping (but not identical) pairs; it does not
    # depend on the subject, so it is built once with array operations.
    seg_idx = np.arange(nseg)
    lag = np.abs(seg_idx[:, None] - seg_idx[None, :])
    overlap = (lag < win_size) & (lag > 0)

    for tst_subj in range(nsubjs):
        tst_data = np.zeros((ndim*win_size, nseg), order='f')
        for w in range(win_size):
            tst_data[w*ndim:(w+1)*ndim, :] = data[tst_subj][:, w:(w+nseg)]

        # A: all other subjects (total minus the test subject); B: the test subject.
        A = np.nan_to_num(stats.zscore((trn_data - tst_data), axis=0, ddof=1))
        B = np.nan_to_num(stats.zscore(tst_data, axis=0, ddof=1))

        # compute correlation matrix (rows: test segments, columns: training segments)
        corr_mtx = compute_correlation(B.T, A.T)

        # The correlation classifier.
        # exclude segments overlapping with the testing segment
        corr_mtx[overlap] = -np.inf
        max_idx = np.argmax(corr_mtx, axis=1)
        accu[tst_subj] = np.mean(max_idx == seg_idx)

    print("The average accuracy among all subjects is {0:f} +/- {1:f}".format(np.mean(accu), np.std(accu)))
    return accu

In [None]:
# definitely double dipping!

# Run the matcher on the raw (_r) and shared-space (_s) versions of both halves,
# in the order: raw train, shared train, raw test, shared test.
accu_train_r, accu_train_s, accu_test_r, accu_test_s = (
    time_segment_matching(dataset, win_size=10)
    for dataset in (train_data, shared_train, test_data, shared_test)
)

# Corrected Time-segment Matching

In [17]:
# Take in a list of participants of voxel by TR data. Also specify how big the time segment is to be matched
def correlation_classifier(train, test, win_size):
    """Time-segment matching accuracy of one held-out subject against a training group.

    Both the summed training data and the test data are cut into
    ``nseg = nsample - win_size`` overlapping windows of ``win_size`` samples
    (each window flattened into one column). Windows are then matched by
    correlation, ignoring pairs that overlap in time.

    Parameters
    ----------
    train : sequence of 2-D arrays
        One (ndim, nsample) array per training subject; the arrays are summed.
    test : 2-D array
        (ndim, nsample) array of the held-out subject.
    win_size : int
        Number of consecutive samples in one time segment.

    Returns
    -------
    accu : float
        Fraction of segments whose best match is the segment at the same position.

    Raises
    ------
    ValueError
        If ``win_size`` is smaller than 1 or leaves no segment to match
        (``win_size >= nsample``).
    """
    nsubjs = len(train)
    (ndim, nsample) = train[0].shape
    nseg = nsample - win_size
    if win_size < 1 or nseg < 1:
        raise ValueError(
            "win_size must be between 1 and nsample - 1 "
            "(got win_size={}, nsample={})".format(win_size, nsample))

    # mysseg prediction
    trn_data = np.zeros((ndim*win_size, nseg), order='f')
    for m in range(nsubjs):
        for w in range(win_size):
            trn_data[w*ndim:(w+1)*ndim, :] += train[m][:, w:(w+nseg)]

    tst_data = np.zeros((ndim*win_size, nseg), order='f')
    for w in range(win_size):
        tst_data[w*ndim:(w+1)*ndim, :] = test[:, w:(w+nseg)]

    # compute correlation matrix
    # NOTE(review): the training matrix is passed first, so rows index training
    # segments and the argmax below picks the best *test* segment for each
    # training segment. The single-argument time_segment_matching in this
    # notebook passes the test matrix first - confirm which direction is intended.
    corr_mtx = compute_correlation(trn_data.T, tst_data.T)

    # The correlation classifier.
    # exclude segments overlapping with the testing segment: start indices
    # closer than win_size share samples (the identical pair i == j is kept).
    seg_idx = np.arange(nseg)
    lag = np.abs(seg_idx[:, None] - seg_idx[None, :])
    corr_mtx[(lag < win_size) & (lag > 0)] = -np.inf

    max_idx = np.argmax(corr_mtx, axis=1)
    accu = np.mean(max_idx == seg_idx)
    return accu

In [18]:
def time_segment_matching(half1, half2, n_features=50, win_size=10, desc=''):
    """Leave-one-subject-out time-segment matching for raw data, SRM and shared GPFA.

    For every left-out subject the alignment models are fitted on ``half1`` of
    the remaining subjects, the left-out subject is added using its own
    ``half1``, and matching accuracy is measured on ``half2`` with
    ``correlation_classifier``.

    Parameters
    ----------
    half1, half2 : ndarray
        Arrays indexed by subject along the first axis (they are indexed with
        the integer index arrays produced by ``LeaveOneOut``).
        ``half1`` is used for fitting, ``half2`` for evaluation.
    n_features : int, default 50
        Number of shared features for both SRM and SharedGpfa.
    win_size : int or sequence of int, default 10
        Window size(s) to evaluate. A single Python or NumPy integer is
        treated as a one-element list.
    desc : str, default ''
        Prefix for the progress-bar description.

    Returns
    -------
    acc_g, acc_s, acc_r : ndarray
        Accuracies for SharedGpfa, SRM and raw data respectively, one row per
        left-out subject and one column per window size.
    """
    # Accept NumPy integer scalars too; `type(win_size) is int` rejected them
    # and the later iteration over win_size failed.
    if isinstance(win_size, (int, np.integer)):
        win_size = [win_size]
    loo = LeaveOneOut()
    acc_r = []
    acc_s = []
    acc_g = []

    pbar = tqdm(loo.split(half1), total=loo.get_n_splits(half1), desc=desc)
    for train_subs, test_sub in pbar:
        # LeaveOneOut yields the test index as a one-element array
        test_sub = test_sub[0]

        # no functional registration
        pbar.set_description(f'{desc} Raw')
        pbar.refresh()
        # time-segment matching performance
        sub_acc = [correlation_classifier(half2[train_subs], half2[test_sub], w) for w in win_size]
        acc_r.append(sub_acc)

        # Shared Response Model
        pbar.set_description(f'{desc} SRM')
        pbar.refresh()
        # fit srm using half1[train_subs]
        srm = brainiak.funcalign.srm.SRM(n_iter=20, features=n_features, rand_seed=0)
        srm.fit(half1[train_subs])
        # held-out subject base using half1[test_sub]
        w_test_sub = srm.transform_subject(half1[test_sub])
        # map half2[train_subs] to shared space
        half2_shared_train = srm.transform(half2[train_subs])
        # map held-out subject half2 to shared space
        half2_shared_test = w_test_sub.T.dot(half2[test_sub])
        # time-segment matching performance
        sub_acc = [correlation_classifier(half2_shared_train, half2_shared_test, w) for w in win_size]
        acc_s.append(sub_acc)

        # Shared GPFA
        # NOTE(review): SharedGpfa is project-local; the step descriptions below
        # follow the original comments and the call order is kept unchanged
        # because the model appears to be stateful - confirm against shared_gpfa.
        pbar.set_description(f'{desc} SGPFA')
        pbar.refresh()
        # fit sgpfa using half1[train_subs]
        sgpfa = SharedGpfa(len(train_subs), half1.shape[1], n_features)
        sgpfa.fit(train_data=half1[train_subs], n_iters=1100, learning_rate=0.075, fa_init=True, reg=0.)
        # map half2[train_subs] to shared space
        half2_shared_train = sgpfa.add_video(half2[train_subs], n_iters=400, learning_rate=0.06, ls_init=False, desc='mapping 2nd half for train')
        # held-out subject base using half1[test_sub]
        w_test_sub = sgpfa.add_subject(half1[test_sub], add_to_model=True, solve_ls=True)
        # map held-out subject half2 to shared space
        ## TODO: what about the noise variance for the training subject??
        half2_shared_test = sgpfa.add_video(half2[test_sub], n_iters=300, learning_rate=0.06, ls_init=False, subs=sgpfa.m, desc='mapping 2nd half for test')
        half2_shared_test = half2_shared_test[0]

        # time-segment matching performance
        sub_acc = [correlation_classifier(half2_shared_train, half2_shared_test, w) for w in win_size]
        acc_g.append(sub_acc)

        # Per-subject progress report, one line per window size
        for i in range(len(acc_g[-1])):
            print(f'w = {win_size[i]} \t-- SGPFA: {acc_g[-1][i]:.2f} \t SRM: {acc_s[-1][i]:.2f} \t Raw: {acc_r[-1][i]:.2f}')

    acc_s = np.array(acc_s)
    acc_r = np.array(acc_r)
    acc_g = np.array(acc_g)

    return acc_g, acc_s, acc_r

In [26]:
# Stack the per-subject lists into arrays with the subject on the first axis.
half1 = np.stack(train_data, axis=0)
half2 = np.stack(test_data, axis=0)

# Quick run: first n subjects, first 75 samples of the last axis, one window size.
n = 3
acc_g, acc_s, acc_r = time_segment_matching(
    half1[:n, :, :75],
    half2[:n, :, :75],
    n_features=40,
    win_size=[20],
)

print(f'The average accuracy among all subjects using SGPFA is:  {acc_g.mean(0)} +/- {acc_g.std(0)}')
print(f'The average accuracy among all subjects using SRM is:    {acc_s.mean(0)} +/- {acc_s.std(0)}')
print(f'The average accuracy among all subjects using RAW is:    {acc_r.mean(0)} +/- {acc_r.std(0)}')

HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, description='fitting SGPFA', max=1100.0, style=ProgressStyle(descripti…





KeyboardInterrupt: 

In [None]:
# Repeat the evaluation for a growing number of subjects (3 .. all) and
# several window sizes, collecting one accuracy array per subject count.
nsubs = np.arange(3, len(half1) + 1)
acc_r, acc_s, acc_g = [], [], []
win_size = [3, 5, 10, 20, 30]

for n in nsubs:
    _acc_g, _acc_s, _acc_r = time_segment_matching(
        half1[:n, :, :75],
        half2[:n, :, :75],
        n_features=40,
        desc=f'{n} subjects',
        win_size=win_size,
    )
    print(f'The average accuracy among all subjects using SGPFA is:  {_acc_g.mean(0)} +/- {_acc_g.std(0)}')
    print(f'The average accuracy among all subjects using SRM is:    {_acc_s.mean(0)} +/- {_acc_s.std(0)}')
    print(f'The average accuracy among all subjects using RAW is:    {_acc_r.mean(0)} +/- {_acc_r.std(0)}')
    for collected, latest in ((acc_r, _acc_r), (acc_s, _acc_s), (acc_g, _acc_g)):
        collected.append(latest)

In [None]:
# One panel per window size: accuracy (mean +/- std over left-out subjects)
# as a function of the number of subjects, for each method.
plt.style.use('ggplot')
ncols = 2
# Round up so that every window size gets a panel (5 windows -> 3 rows).
nrows = int(np.ceil(len(win_size) / ncols))
# squeeze=False keeps axs 2-D for any grid size so that .flat yields single Axes.
fig, axs = plt.subplots(nrows, ncols, figsize=(10*ncols, 5*nrows), squeeze=False)

for w, ax in enumerate(axs.flat):
    if w >= len(win_size):
        # Hide panels left over when the grid has more cells than window sizes.
        ax.set_visible(False)
        continue
    for acc, label in zip((acc_g, acc_s, acc_r), ('SGPFA', 'SRM', 'Raw')):
        # acc holds one (left-out subjects x window sizes) array per entry of
        # nsubs; take column w of each to get one value per subject count.
        y = np.array([a.mean(0)[w] for a in acc])
        yerr = np.array([a.std(0)[w] for a in acc])
        ax.errorbar(nsubs, y=y, yerr=yerr, label=label, marker='o', capsize=3, elinewidth=1.5)
    ax.grid(color='w')
    ax.legend()
    ax.set_ylabel('Accuracy')
    ax.set_xlabel('Number of subjects')
    ax.set_ylim(0, 1)
    ax.set_title(f'window size: {win_size[w]}')


In [None]:
# Hard-coded accuracy summaries used by the plotting code that follows.
# Layout, as implied by the z[j::3, 0, i] / z[j::3, 1, i] indexing used there:
#   - each row is [plotted values, error-bar sizes], with one number per window size;
#   - rows come in consecutive groups of three, plotted as SGPFA, SRM, Raw;
#   - successive groups correspond to successive entries of nsubs.
# NOTE(review): presumably copied from the printed mean/std output of an
# earlier run of the subject-count sweep - confirm the provenance.
z = [
    [[0.31481481, 0.5047619 , 0.82051282, 0.93333333, 1.        ], [0.07550697, 0.10583862, 0.01918799, 0.0942809 , 0.        ]],
    [[0.31018519, 0.47142857, 0.76410256, 0.96969697, 1.        ], [0.06835566, 0.09965928, 0.05664288, 0.04285496, 0.        ]],
    [[0.11574074, 0.21428571, 0.43076923, 0.68484848, 0.98518519], [0.06245712, 0.07648752, 0.16426274, 0.03090315, 0.02095131]],
    [[0.33680556, 0.56071429, 0.82307692, 0.95454545, 1.        ], [0.07039283, 0.08885611, 0.07956985, 0.0522233 , 0.        ]],
    [[0.36111111, 0.53928571, 0.84230769, 0.96363636, 1.        ], [0.09771699, 0.09112393, 0.10568178, 0.06298367, 0.        ]],
    [[0.17361111, 0.30357143, 0.53461538, 0.77272727, 1.        ], [0.0625    , 0.06952829, 0.15890714, 0.04165978, 0.        ]],
    [[0.35833333, 0.58      , 0.88615385, 0.99272727, 1.        ], [0.08117577, 0.06916411, 0.07511727, 0.01454545, 0.        ]],
    [[0.37777778, 0.57714286, 0.86153846, 0.99272727, 1.        ], [0.09018157, 0.1329723 , 0.11051252, 0.01454545, 0.        ]],
    [[0.19444444, 0.30285714, 0.53846154, 0.85818182, 1.        ], [0.062113  , 0.07845446, 0.1513646 , 0.08635885, 0.        ]],
    [[0.36574074, 0.60714286, 0.87179487, 0.97272727, 1.        ], [0.04721314, 0.02945075, 0.09216513, 0.06098367, 0.        ]],
    [[0.40277778, 0.61666667, 0.88461538, 0.98787879, 1.        ], [0.10111264, 0.1152981 , 0.10723345, 0.02710385, 0.        ]],
    [[0.17361111, 0.27619048, 0.47435897, 0.82727273, 1.        ], [0.06551376, 0.12027934, 0.19488866, 0.10271789, 0.        ]],
    [[0.39484127, 0.6244898 , 0.9010989 , 1.        , 1.        ], [0.05027325, 0.09122244, 0.08540378, 0.        , 0.        ]],
    [[0.42460317, 0.62244898, 0.8967033 , 1.        , 1.        ], [0.0938641 , 0.11875377, 0.11266818, 0.        , 0.        ]],
    [[0.18452381, 0.27755102, 0.50549451, 0.85454545, 1.        ], [0.06459849, 0.12070182, 0.16841533, 0.12559187, 0.        ]],
    [[0.40277778, 0.60357143, 0.88461538, 1.        , 1.        ], [0.05379144, 0.04559695, 0.09036415, 0.        , 0.        ]],
    [[0.41666667, 0.63035714, 0.875     , 0.99090909, 1.        ], [0.09547033, 0.10137763, 0.13618589, 0.02405228, 0.        ]],
    [[0.20486111, 0.29821429, 0.55192308, 0.79545455, 1.        ], [0.07818285, 0.11052692, 0.12904692, 0.15844023, 0.        ]],
    [[0.41358025, 0.6015873 , 0.89059829, 0.9979798 , 1.        ], [0.0514609 , 0.06749299, 0.09660767, 0.00571399, 0.        ]],
    [[0.41975309, 0.61587302, 0.88205128, 0.99393939, 1.        ], [0.0817173 , 0.11979659, 0.13548559, 0.01714198, 0.        ]],
    [[0.19444444, 0.3       , 0.56239316, 0.77979798, 1.        ], [0.05892557, 0.09965928, 0.16169892, 0.17279519, 0.        ]],
    [[0.43194444, 0.63714286, 0.89846154, 0.99272727, 1.        ], [0.0649816 , 0.07261866, 0.08233285, 0.02181818, 0.        ]],
    [[0.42361111, 0.62714286, 0.88923077, 0.99454545, 1.        ], [0.09027778, 0.11472238, 0.13461319, 0.01636364, 0.        ]],
    [[0.19861111, 0.31571429, 0.56769231, 0.77818182, 1.        ], [0.05009636, 0.10110613, 0.14340739, 0.15487665, 0.        ]]
]

# Convert to an array so it can be sliced along all three axes.
z = np.array(z)

# Plot the hard-coded summaries in z: one panel per window size, one
# error-bar series per method, x axis = number of subjects.
win_size = [3, 5, 10, 20, 30]
nsubs = np.arange(3, 11)
jitter = 0.015  # small horizontal offset so the three series do not overlap exactly
fig, axs = plt.subplots(3, 2, figsize=(30, 20))
panels = axs.ravel()

for i, w in enumerate(win_size):
    ax = panels[i]
    for j, label in enumerate(('SGPFA', 'SRM', 'Raw')):
        # rows j, j+3, j+6, ... of z belong to method j
        ax.errorbar(
            nsubs + j*jitter,
            y=z[j::3, 0, i],
            yerr=z[j::3, 1, i],
            label=label,
            marker='o',
            capsize=3,
            elinewidth=1.5,
            alpha=1,
        )
    ax.grid(color='w')
    ax.legend()
    ax.set_ylabel('Accuracy')
    ax.set_xlabel('Number of subjects')
    ax.set_ylim(0, 1.05)
    ax.set_title(f'window size: {w}')

# The 3 x 2 grid has one more panel than there are window sizes; hide it.
panels[-1].set_visible(False)