# Two-step SRM
 @Ruiqing Zhang

+ ### Generate GM mask using all sub's highres.nii.gz 

+ ### Concatenate all runs, generate train dataset  

+ ### Time segment matching

+ ### First step SRM for all subjects

+ ### Second step SRM within group

+ ### ISC and permutation test 

In [3]:
from os.path import abspath, dirname, join
from brainiak import image, io
from brainiak.isc import isc, isfc, compute_summary_statistic,permutation_isc
from brainiak.fcma.util import compute_correlation
import brainiak.funcalign.srm
import sys
import os
import numpy as np
import nibabel as nib
import h5py
from itertools import combinations
from scipy.cluster.hierarchy import fcluster, linkage, dendrogram
from sklearn.cluster import KMeans
import scipy.spatial.distance as sp_distance
from sklearn.manifold import TSNE
from scipy import stats
from scipy.stats import norm, pearsonr, zscore
from scipy.spatial.distance import squareform
from statsmodels.sandbox.stats.multicomp import multipletests
from nilearn import datasets,image,masking
from nilearn.input_data import NiftiMasker, NiftiLabelsMasker


In [None]:
from sklearn.metrics import confusion_matrix

 Subjects info table
 
 |Proficiency|  N   |ACC_Comp | subID |
 |:---:|:---|:---|:---|
 |HPL|10|0.861538462| 1,2,4,8,11,16,19,20,23,28|
 |MPL|9|0.671328671|3,6,7,14,21,24,32,35,37|  
 |LPL|16|0.415865385|5,9,10,12,13,15,17,18,22,25,26,27,30,31,33,34,36|

In [3]:
# Root directory of the dataset (dirname() strips the trailing slash).
root_dir = dirname('/data/neuro/LLS_audio/')

# Directory the functional runs are read from.
fdir = dirname(join(root_dir, 'derivatives/analysis01/FUNC_reorg/'))

# Subject IDs per proficiency group, ordered HPL, MPL, LPL.
# NOTE(review): the subjects table above lists subject 34 under LPL, while it
# is part of the MPL list here -- confirm which assignment is intended.
PL = [
    [1, 2, 4, 8, 11, 16, 19, 20, 23, 28],
    [3, 6, 7, 14, 21, 24, 32, 34, 35, 37],
    [5, 9, 10, 12, 13, 15, 17, 18, 22, 25, 26, 27, 30, 31, 33, 36],
]

- #### organize train data for SRM

   subjects = (range(1,29),range(30,38)) 
   
   sub029 is removed due to the earphone problem

In [5]:
def generate_train_dset(mask, subjects):
    """Build the z-scored SRM training array and save it to ``train_dset.h5``.

    For every subject the four run files
    ``sub{ID:03d}.run{RUN:02d}.func.resampl.nii.gz`` (runs 1-4, under
    ``fdir``) are passed together to ``masking.apply_mask``. The per-subject
    results are stacked and their last two axes swapped, so the array is
    unpacked as (subjects, voxels, TRs). Each subject's data is then z-scored
    along the TR axis (``ddof=1``) and NaNs are replaced by zeros.

    The result is written as dataset ``train_dset_<n_subjects>`` in
    ``train_dset.h5`` in the current working directory.

    Parameters
    ----------
    mask : str
        File name of the mask image, relative to ``fdir``.
    subjects : iterable of iterables of int
        Groups of subject IDs, e.g. ``(range(1, 29), range(30, 38))``.

    Raises
    ------
    ValueError
        If ``subjects`` contains no subject ID.
    """
    mask_img = nib.load(join(fdir, mask))
    run_template = join(fdir, 'sub{:03d}.run{:02d}.func.resampl.nii.gz')

    images_concatenated = []
    for subject in subjects:
        for sub in subject:
            run_files = [run_template.format(sub, run) for run in range(1, 5)]
            images_concatenated.append(masking.apply_mask(run_files, mask_img))

    if not images_concatenated:
        raise ValueError('generate_train_dset: no subjects were given')

    # Stacking requires every subject to have the same number of TRs/voxels.
    data = np.array(images_concatenated)
    train_dset = np.swapaxes(data, 1, 2)
    num_subs, vox_num, nTR = train_dset.shape
    print('train dataset: ',
          'Participants ', num_subs,
          'Voxels per participant ', vox_num,
          'TRs per participant ', nTR)

    # Normalise one subject at a time to keep peak memory low.
    for sub in range(num_subs):
        train_dset[sub] = stats.zscore(train_dset[sub], axis=1, ddof=1)
        # Constant voxels have zero variance and z-score to NaN.
        train_dset[sub] = np.nan_to_num(train_dset[sub])

    # Open the output only once the data is ready, so a failure above neither
    # leaks a handle nor truncates an existing file.
    with h5py.File('train_dset.h5', 'w') as f:
        f.create_dataset('train_dset_{}'.format(num_subs), data=train_dset)

+ #### validation performance of the feature in second step SRM

In [6]:
def time_segment_matching(data, win_size=10):
    """Leave-one-subject-out time-segment matching accuracy.

    Each subject's data is cut into overlapping segments of ``win_size``
    consecutive samples. A held-out subject's segments are correlated with
    the segments summed over all other subjects; a segment counts as correct
    when its best-correlated match is the segment at the same time index.
    Segments overlapping the tested one are excluded as candidates.

    Parameters
    ----------
    data : sequence of 2-D arrays
        One array per subject, each of shape (features, samples); all are
        indexed with the shape of ``data[0]``.
    win_size : int, optional
        Number of consecutive samples per segment.

    Returns
    -------
    numpy.ndarray
        Matching accuracy for every subject, in the order of ``data``.
    """
    n_subjects = len(data)
    n_features, n_samples = data[0].shape
    n_segments = n_samples - win_size
    accu = np.zeros(shape=n_subjects)

    def stack_windows(subject_data):
        # Row block w holds the samples at offset w of every segment.
        stacked = np.zeros((n_features * win_size, n_segments), order='f')
        for w in range(win_size):
            rows = slice(w * n_features, (w + 1) * n_features)
            stacked[rows, :] = subject_data[:, w:(w + n_segments)]
        return stacked

    # Sum over ALL subjects; the tested subject is subtracted again below.
    trn_data = np.zeros((n_features * win_size, n_segments), order='f')
    for subject_data in data:
        trn_data += stack_windows(subject_data)

    # Candidate segments that overlap the tested segment (but are not the
    # segment itself) must never be chosen; the mask is the same for everyone.
    seg_index = np.arange(n_segments)
    lag = np.abs(seg_index[:, np.newaxis] - seg_index[np.newaxis, :])
    overlapping = (lag < win_size) & (lag != 0)

    for tst_subj in range(n_subjects):
        tst_data = stack_windows(data[tst_subj])

        A = np.nan_to_num(stats.zscore((trn_data - tst_data), axis=0, ddof=1))
        B = np.nan_to_num(stats.zscore(tst_data, axis=0, ddof=1))

        # Rows: segments of the tested subject; columns: candidate segments.
        corr_mtx = compute_correlation(B.T, A.T)
        corr_mtx[overlapping] = -np.inf

        # Correlation classifier: best match must be the same time index.
        max_idx = np.argmax(corr_mtx, axis=1)
        accu[tst_subj] = np.count_nonzero(max_idx == seg_index) / n_segments

        print("Accuracy for subj %d is: %0.4f" % (tst_subj, accu[tst_subj] ))

    print("The average accuracy among all subjects is {0:f} +/- {1:f}".format(np.mean(accu), np.std(accu)))
    return accu

+ #### First step SRM to get residuals by subtracting recons_dset from train_dset
    
    dset='train_dset_{}'.format(num_subs)
    
    params = [10,20]
    n_iter = 100

In [7]:
def first_step_srm(params, n_iter, dset):
    """Fit the first-step SRM and save shared, reconstructed and residual data.

    For every feature count ``k1`` in ``params`` an SRM is fitted to the
    dataset ``dset`` read from ``train_dset.h5``. Three files are written to
    the current working directory:

    * ``shared_signal_<k1>_<n_iter>.h5`` -- z-scored shared responses,
    * ``recons_signal_<k1>_<n_iter>.h5`` -- ``w_[subject] . shared`` per
      subject, z-scored,
    * ``residuals_<k1>_<n_iter>.h5`` -- training data minus reconstruction.

    Parameters
    ----------
    params : iterable of int
        Numbers of SRM features (``k1``) to fit.
    n_iter : int
        Number of SRM iterations.
    dset : str
        Name of the dataset inside ``train_dset.h5``.
    """
    # Context managers guarantee every HDF5 handle is closed even when
    # fitting or writing fails part-way through.
    with h5py.File('train_dset.h5', 'r') as f:
        for k1 in params:
            train_dset = f[dset][:]
            print('feature numbers :', k1, 'srm iterations : ', n_iter)

            srm = brainiak.funcalign.srm.SRM(n_iter=n_iter, features=k1)
            srm.fit(train_dset)
            print('SRM has been fit')

            shared_train = srm.transform(train_dset)
            for sub in range(len(shared_train)):
                shared_train[sub] = stats.zscore(shared_train[sub], axis=1, ddof=1)
                shared_train[sub] = np.nan_to_num(shared_train[sub])

            with h5py.File('shared_signal_{}_{}.h5'.format(k1, n_iter), 'w') as f1:
                f1.create_dataset('shared_signal_s1_{}_{}'.format(k1, n_iter),
                                  data=shared_train)
            print('shared train data of srm1 is saved!')

            # Project each subject's shared response back to voxel space.
            recons_dset = np.zeros(train_dset.shape)
            for p in range(len(recons_dset)):
                recons_dset[p] = srm.w_[p].dot(shared_train[p])
                # NOTE(review): zscore() is called with its defaults here
                # (axis=0, ddof=0), unlike the axis=1, ddof=1 normalisation
                # used for the shared responses -- confirm this is intended.
                recons_dset[p] = stats.zscore(recons_dset[p])
                recons_dset[p] = np.nan_to_num(recons_dset[p])

            with h5py.File('recons_signal_{}_{}.h5'.format(k1, n_iter), 'w') as f2:
                f2.create_dataset('recons_signal_s1_{}_{}'.format(k1, n_iter),
                                  data=recons_dset)
            print('reconstruction data of srm1 is saved!')

            # recons_dset is float64, so the difference is float64 as well.
            residuals = train_dset - recons_dset

            with h5py.File('residuals_{}_{}.h5'.format(k1, n_iter), 'w') as f3:
                f3.create_dataset('residuals_{}_{}'.format(k1, n_iter),
                                  data=residuals)

            # Release the large arrays before the next iteration reloads data.
            del train_dset, shared_train, recons_dset, residuals

+  #### Voxel Selection

In [None]:
def voxel_selection(k1, n_iter):
    """Keep only residual voxels with significant inter-subject correlation.

    Loads ``residuals_<k1>_<n_iter>.h5``, computes leave-one-out ISC for
    every voxel, runs a one-sample permutation test on the per-subject ISC
    values, and keeps the voxels whose p-value is below the strictest of
    0.001, 0.005 or 0.05 that selects at least one voxel. The selection is
    written to ``masked_residual_k1<k1>_n_iter<n_iter>_sig<sig>.h5`` with
    shape (subjects, selected voxels, TRs).

    If no voxel reaches p < 0.05, a message is printed and no file is written.

    Parameters
    ----------
    k1 : int
        Feature count of the first-step SRM (part of the file name).
    n_iter : int
        SRM iteration count (part of the file name).
    """
    residuals_name = 'residuals_{}_{}'.format(k1, n_iter)
    with h5py.File('residuals_{}_{}.h5'.format(k1, n_iter), 'r') as f:
        dset = f[residuals_name][:]

    # Stored as (subjects, voxels, TRs); isc() takes (TRs, voxels, subjects).
    data = np.transpose(dset, (2, 1, 0))

    print('ISC computation begins...')
    # Keep per-subject ISCs (no summary) so the permutation test sees a
    # subjects-by-voxels array; a pre-averaged vector would be read as
    # "one voxel, many subjects".
    iscs = isc(data, summary_statistic=None, tolerate_nans=0.85)
    iscs = np.nan_to_num(iscs)

    print('permutation test begins...')
    # permutation_isc returns (observed, p, distribution); only p is needed.
    _, p, _ = permutation_isc(iscs, summary_statistic='mean')
    p = np.ravel(p)

    # Use the strictest threshold that still selects at least one voxel.
    for sig in (0.001, 0.005, 0.05):
        selected = p < sig
        if selected.any():
            print('permutation test in residuals: found significant voxels in p<{}'.format(sig))
            break
    else:
        print('permutation test in residuals: found no statistical significant voxels in p < 0.05')
        return

    masked_residual = dset[:, selected, :]
    print('masked_residuals shape is :', masked_residual.shape)

    out_name = 'masked_residual_k1{}_n_iter{}_sig{}'.format(k1, n_iter, sig)
    with h5py.File(out_name + '.h5', 'w') as ff:
        ff.create_dataset(out_name, data=masked_residual)

+ #### Second step SRM to fit residuals within-group 
 
    K1 = [20,50] params = [50,100] n_iter =100

In [8]:
def second_step_srm(k1, params, n_iter):
    """Fit a second SRM on the first-step residuals within each group.

    Reads ``residuals_<k1>_<n_iter>.h5`` and, for each group in ``PL``
    (labelled H, M, L in that order) and each feature count ``k2`` in
    ``params``, fits an SRM on that group's residuals. For every
    (group, k2) pair two files are written to the current directory:

    * ``shared_<G>G_k1_<k1>_k2_<k2>.h5`` -- z-scored shared responses,
    * ``recons_<G>G_k1_<k1>_k2_<k2>.h5`` -- z-scored reconstructions.

    Parameters
    ----------
    k1 : int
        Feature count of the first-step SRM (selects the residual file).
    params : iterable of int
        Feature counts (``k2``) for the second-step SRM.
    n_iter : int
        Number of SRM iterations (also part of the residual file name).
    """
    print('second srm fit within-group data')
    residuals_name = 'residuals_{}_{}'.format(k1, n_iter)
    with h5py.File('residuals_{}_{}.h5'.format(k1, n_iter), 'r') as f:
        residuals = f[residuals_name][:]

    for key, members in zip(('H', 'M', 'L'), PL):
        train_dset2 = np.zeros((len(members),) + residuals.shape[1:])
        for j, sub_id in enumerate(members):
            print('extracting subject', sub_id)
            # sub029 is absent from the data, so IDs below 30 sit at row
            # ID-1 and IDs from 30 upwards at row ID-2.
            row = sub_id - 1 if sub_id < 30 else sub_id - 2
            train_dset2[j] = residuals[row]

        num_subs, n_voxels, nTR = train_dset2.shape
        print('subjects number in 2ed srm:', num_subs,
              '\n number of voxels:', n_voxels,
              '\n number of TRs:', nTR)

        # Fit, reconstruct and save for EVERY k2; previously only the last
        # value in ``params`` was actually fitted.
        for k2 in params:
            srm2 = brainiak.funcalign.srm.SRM(n_iter=n_iter, features=k2)

            print('second step srm fitting within-group data...',
                  '\n features number ', k2, '\n iteration', n_iter)

            srm2.fit(train_dset2)
            shared_train2 = srm2.transform(train_dset2)

            for sub in range(len(shared_train2)):
                shared_train2[sub] = stats.zscore(shared_train2[sub], axis=1, ddof=1)
                shared_train2[sub] = np.nan_to_num(shared_train2[sub])

            recons_dset2 = np.zeros(train_dset2.shape)
            for p in range(len(train_dset2)):
                recons_dset2[p] = srm2.w_[p].dot(shared_train2[p])
                # NOTE(review): zscore() uses its defaults here (axis=0,
                # ddof=0), unlike the axis=1, ddof=1 call above -- confirm.
                recons_dset2[p] = stats.zscore(recons_dset2[p])
                recons_dset2[p] = np.nan_to_num(recons_dset2[p])

            shared_name = ('shared_' + key + 'G_k1_{}_k2_{}').format(k1, k2)
            recons_name = ('recons_' + key + 'G_k1_{}_k2_{}').format(k1, k2)
            with h5py.File(shared_name + '.h5', 'w') as f1:
                f1.create_dataset(shared_name, data=shared_train2)
            with h5py.File(recons_name + '.h5', 'w') as f2:
                f2.create_dataset(recons_name, data=recons_dset2)

            print('validation parameters :',
                  '\n subjects ID:', members,
                  '\n k:', k2)

+ #### Validation using TSM

In [None]:
def validation(group, k1, k2):
    """Run time-segment matching on a group's reconstruction and save it.

    Loads dataset ``recons_<group>G_k1_<k1>_k2_<k2>`` from the ``.h5`` file
    of the same name in the current directory, passes it to
    ``time_segment_matching`` (``win_size=10``) and writes the per-subject
    accuracies to ``ts_acc_<group>G_k1_<k1>_k2_<k2>.txt`` under ``fdir``.

    Parameters
    ----------
    group : str
        Group label used in the file name (e.g. ``'H'``).
    k1 : int
        Feature count of the first-step SRM.
    k2 : int
        Feature count of the second-step SRM.
    """
    recons_name = 'recons_{}G_k1_{}_k2_{}'.format(group, k1, k2)
    # The context manager closes the file even if the dataset is missing.
    with h5py.File(recons_name + '.h5', 'r') as f:
        dset = f[recons_name][:]

    accuracy = time_segment_matching(dset, win_size=10)

    # Format the file name before joining so braces in ``fdir`` can never be
    # mistaken for format fields.
    ot_name = join(fdir, 'ts_acc_{}G_k1_{}_k2_{}.txt'.format(group, k1, k2))
    np.savetxt(ot_name, accuracy)

+ #### permutation test for group ISC 

   group is two samples : ['H','L'] or just one sample : ['H']
   
   k1=[20,50] k2=[50,100]
   
   mask = 'GM_{}_mask.nii.gz'.format(num_sub)

In [1]:
def permut_isc(group, k1, k2, mask):
    """Permutation test on ISC of second-step reconstructions.

    With one group label a one-sample test is run on that group's
    reconstruction; with two labels the reconstructions are concatenated and
    a two-sample test compares the groups. The p-values and observed
    statistics are saved to HDF5 files, and the observed statistic is also
    written as a NIfTI image on the grid of the mask file.

    Parameters
    ----------
    group : sequence of str
        One label (e.g. ``['H']``) or two labels (e.g. ``['H', 'L']``).
    k1 : int
        Feature count of the first-step SRM.
    k2 : int
        Feature count of the second-step SRM.
    mask : str
        Mask file name, looked up under ``fdir/results``.

    Raises
    ------
    ValueError
        If ``group`` does not contain exactly one or two labels.
    """
    if len(group) not in (1, 2):
        raise ValueError('group must contain one or two labels, got {!r}'.format(group))

    GM_epi_mask = join(fdir, 'results', mask)
    bi_mask = io.load_boolean_mask(GM_epi_mask, lambda x: x > 0.5)
    coords = np.where(bi_mask)

    def load_recons(label_):
        # second_step_srm writes these files with an '.h5' extension.
        name = 'recons_{}G_k1_{}_k2_{}'.format(label_, k1, k2)
        with h5py.File(name + '.h5', 'r') as handle:
            return handle[name][:]

    if len(group) == 1:
        dset = load_recons(group[0])
        label = None
        p_file = 'permutation_isc_p_{}_{}.h5'.format(k1, k2)
        value_file = 'permutation_isc_value_{}_{}.h5'.format(k1, k2)
        p_name = '{}_p_k1_{}_k2_{}'.format(group[0], k1, k2)
        isc_name = '{}_observed_{}_k2_{}.nii.gz'.format(group[0], k1, k2)
    else:
        first = load_recons(group[0])
        second = load_recons(group[1])
        dset = np.concatenate((first, second))
        # Labels follow the concatenation order: 0 = first group, 1 = second.
        label = np.concatenate((np.zeros(len(first)), np.ones(len(second))))
        p_file = 'permutation_isc_p_{}vs{}_{}_{}.h5'.format(group[0], group[1], k1, k2)
        value_file = 'permutation_isc_value_{}vs{}_{}_{}.h5'.format(group[0], group[1], k1, k2)
        p_name = 'diff_{}vs{}_p_k1_{}_k2_{}'.format(group[0], group[1], k1, k2)
        isc_name = 'diff_{}vs{}_observed_k1_{}_k2_{}.nii.gz'.format(group[0], group[1], k1, k2)

    # Stored as (subjects, voxels, TRs); isc() takes (TRs, voxels, subjects).
    data = np.transpose(dset, (2, 1, 0))

    # Keep per-subject ISCs: the permutation test needs a subjects-by-voxels
    # array (and one row per entry of ``label``), not a pre-averaged vector.
    iscs = isc(data, summary_statistic=None, tolerate_nans=0.85)
    iscs = np.nan_to_num(iscs)
    observed, p, distribution = permutation_isc(iscs, group_assignment=label,
                                                summary_statistic='mean')
    observed = np.ravel(observed)
    p = np.ravel(p)

    # Use the strictest threshold that still selects at least one voxel.
    for sig in (0.001, 0.005, 0.05):
        significant = p < sig
        if significant.any():
            print('permutation test in group {}: found significant voxels in p<{}'.format(group, sig))
            break
    else:
        print('permutation test in group {}: found no statistical significant voxels in p < {}'.format(group, sig))

    # Outputs are opened only once the results exist, so a failure above
    # leaves no empty or truncated files behind.
    with h5py.File(p_file, 'w') as f1:
        f1.create_dataset(p_name, data=p)
    with h5py.File(value_file, 'w') as f2:
        # Dataset name is the image name without the '.nii.gz' suffix.
        f2.create_dataset(isc_name[:-7], data=observed)

    nii_template = nib.load(GM_epi_mask)
    out_isc = np.zeros(nii_template.shape)
    out_isc[coords] = observed
    isc_obj = nib.Nifti1Image(out_isc, nii_template.affine, nii_template.header)
    nib.save(isc_obj, isc_name)

    # A threshold only exists when at least one voxel is significant.
    if significant.any():
        threshod = np.min(np.abs(observed[significant]))
        print('one-tail threshold of the difference isc values is ', threshod, 'p < ', sig)