In [None]:
from brainiak.isc import isc,bootstrap_isc,compute_summary_statistic
import pandas as pd
import numpy as np
import nibabel as nib
from nilearn.image import index_img,concat_imgs,resample_to_img,smooth_img,math_img
from nilearn.masking import apply_mask,unmask
from nilearn import plotting as niplt
import os
import time
from scipy.stats import mode as statmode
import matplotlib.pyplot as plt
import pickle
from scipy.ndimage import gaussian_filter1d, gaussian_filter
from skimage.measure import label as dolabel
from skimage.measure import regionprops
from scipy.stats import ttest_1samp,ttest_ind
from statsmodels.stats.multitest import fdrcorrection
from scipy.stats import f_oneway
from statsmodels.stats.anova import AnovaRM

def find_fmri_duration(dataf, min_volumes=50):
    """Return the number of volumes available in every fMRI run of ``dataf``.

    Parameters
    ----------
    dataf : pandas.DataFrame
        One row per run; the ``fmri`` column holds paths to 4D NIfTI files.
    min_volumes : int, optional
        Runs with fewer volumes than this are printed as suspicious
        (default 50, the value previously hard-coded).

    Returns
    -------
    int
        Size of the last axis of the shortest run. If the runs differ in
        length a warning is printed; returning the minimum (rather than the
        first run's length) guarantees that any window fitting in the returned
        length fits in every run.

    Raises
    ------
    ValueError
        If ``dataf`` has no rows.
    """
    if len(dataf) == 0:
        raise ValueError('find_fmri_duration: received an empty DataFrame')

    lengths = []
    for fmri in dataf['fmri']:
        n_vols = nib.load(fmri).shape[-1]
        if n_vols < min_volumes:
            # Suspiciously short run: report the file so it can be inspected
            print(fmri)
        lengths.append(n_vols)

    lengths = np.array(lengths)
    distinct = np.unique(lengths)
    if distinct.size > 1:
        print('Warning: fMRI runs have different lengths {}; using the minimum '.format(distinct.tolist()) + 50*'-')

    return int(lengths.min())

def export_results(iscstat,c,met,t,fld_out,logsumm):
    """Threshold, cluster-filter, save and plot one ISC statistical map.

    Relies on the notebook-level globals ``av_mask_img`` (mask used to unmask
    voxel vectors), ``p_thres`` / ``c_thres`` (per-correction p-value and
    cluster-size thresholds) and ``isc_type`` (used in file names and titles).

    Parameters
    ----------
    iscstat : dict
        ``iscstat[c][met]`` holds voxel-wise arrays: the statistic (key ``'m'``
        for ``met='boot'``, ``'t'`` for ``met='t-test'``) and p-values (``'p'``).
    c : str
        Condition key (e.g. ``'M'`` or ``'S'``).
    met : {'boot', 't-test'}
        Statistical method whose results are exported.
    t : {'uncorrected', 'FDR'}
        Multiple-comparison strategy; must also be a key of ``p_thres``/``c_thres``.
    fld_out : str
        Output folder prefix. File names are appended directly, so it should end with '/'.
    logsumm : pandas.DataFrame
        Log summary; its sorted unique ``subject`` and ``Title`` values are
        written to ``included_subjects.txt`` / ``included_movies.txt``.

    Returns
    -------
    nibabel image
        The thresholded, cluster-filtered statistical map.

    Raises
    ------
    ValueError
        If ``met`` or ``t`` is not one of the supported values.
    """
    stat_keys = {'boot': 'm', 't-test': 't'}
    if met not in stat_keys:
        raise ValueError('export_results: unsupported method {!r}; expected one of {}'.format(met, list(stat_keys)))
    stat = iscstat[c][met][stat_keys[met]]
    pvals = iscstat[c][met]['p']

    # Boolean voxel vector of the voxels that survive the p-value threshold
    if t == 'uncorrected':
        survive = pvals < p_thres[t]
    elif t == 'FDR':
        survive, _ = fdrcorrection(pvals, p_thres[t])
    else:
        raise ValueError("export_results: unsupported correction {!r}; expected 'uncorrected' or 'FDR'".format(t))

    thres_m = unmask(stat * survive, av_mask_img)
    thres_mc, _ = filter_clusters(thres_m, c_thres[t])
    # Binary map: only positive statistic values are marked
    thres_mc_bin = math_img('im > 0', im=thres_mc)

    # File-name stem shared by all outputs of this map
    stem = fld_out + '{}ISC_{}_{}_{}_p{}_k{:02d}'.format(
        isc_type, c, met, t, str(p_thres[t])[2:], c_thres[t])

    ## Save maps
    nib.save(thres_mc, stem + '.nii.gz')          # thresholded
    nib.save(thres_mc_bin, stem + '_bin.nii.gz')  # thresholded, binarised

    # Figure
    fig, ax = plt.subplots(figsize=(9, 3))
    niplt.plot_glass_brain(thres_mc, figure=fig, axes=ax, colorbar=True)
    ax.set_title(r'{} ISC - {} correction ($\alpha = {}$, $\kappa = {}$)'.format(isc_type,t,p_thres[t],c_thres[t]))
    fig.savefig(stem + '.pdf')
    plt.close('all')

    # Logs of what went into the analysis.
    # fmt='%s': the np.savetxt default ('%.18e') would write subject IDs as floats.
    included_subjects = logsumm.sort_values('subject')['subject'].unique()
    np.savetxt(fld_out + 'included_subjects.txt', included_subjects, fmt='%s')

    included_movies = logsumm.sort_values('Title')['Title'].unique()
    np.savetxt(fld_out + 'included_movies.txt', included_movies, fmt='%s')

    return thres_mc


def filter_clusters(thresholded_image,cluster_threshold):
    """Remove connected clusters smaller than ``cluster_threshold`` voxels.

    Clusters are connected components of the non-zero voxels, computed with
    skimage ``label(connectivity=2)`` (face- and edge-adjacent voxels in 3D).

    Parameters
    ----------
    thresholded_image : nibabel image
        Map whose non-zero voxels form the candidate clusters; NaN/inf are
        replaced by finite values first (NaN -> 0).
    cluster_threshold : int
        Minimum cluster size (number of voxels) to keep.

    Returns
    -------
    filtered_image : nib.Nifti2Image
        Input values inside kept clusters, 0 elsewhere.
    clusters : np.ndarray, shape (n_kept, 2)
        ``[label, n_voxels]`` of each kept cluster (shape (0, 2) if none).
    """
    # nan=0.0 keyword (copy=True by default): the previous positional call
    # nan_to_num(x, 0) meant copy=False and overwrote the array returned by
    # get_fdata() in place, which nibabel may cache inside the input image.
    imgdata = np.nan_to_num(thresholded_image.get_fdata(), nan=0.0)
    labeled = dolabel(imgdata != 0, background=0, connectivity=2)

    kept = [[p.label, p.area] for p in regionprops(labeled) if p.area >= cluster_threshold]
    clusters = np.array(kept).reshape(-1, 2)

    # Single vectorised pass instead of one full-volume comparison per small cluster
    keep_mask = np.isin(labeled, [lab for lab, _ in kept])
    filtered_data = imgdata * keep_mask

    # NOTE(review): output is always NIfTI-2, whatever the input class; confirm intended.
    filtered_image = nib.Nifti2Image(filtered_data,header=thresholded_image.header,affine=thresholded_image.affine)

    return filtered_image,clusters

def do_RM_Ftest(iscdata):
    """Voxel-wise one-way repeated-measures ANOVA across conditions.

    Parameters
    ----------
    iscdata : dict
        Maps condition name -> array of shape (n_sub, n_vox). Row ``k`` is
        assumed to be the same subject in every condition — TODO confirm.

    Returns
    -------
    F, pF : np.ndarray, shape (n_vox,)
        F statistic and its p-value (statsmodels ``AnovaRM``) for each voxel.

    Raises
    ------
    ValueError
        If ``iscdata`` is empty or the conditions have different shapes.
    """
    conds = list(iscdata)
    if not conds:
        raise ValueError('do_RM_Ftest: iscdata contains no condition')
    shapes = {np.shape(iscdata[c]) for c in conds}
    if len(shapes) != 1:
        raise ValueError('do_RM_Ftest: all conditions must share the same (n_sub, n_vox) shape, got {}'.format(shapes))
    n_sub, n_vox = shapes.pop()

    # The long-format design (subject x condition) is the same for every voxel:
    # build it once and only swap the dependent variable inside the loop.
    # (DataFrame.append, used previously, no longer exists in pandas >= 2.0.)
    design = pd.DataFrame({
        'sub_id': np.tile(np.arange(n_sub), len(conds)),
        'cond': np.repeat(conds, n_sub),
    })
    # Rows ordered like `design`: all subjects of cond 0, then cond 1, ...
    stacked = np.concatenate([np.asarray(iscdata[c]) for c in conds], axis=0)

    F = np.zeros(n_vox)
    pF = np.zeros_like(F)

    for nv in range(n_vox):
        df = design.assign(z=stacked[:, nv])
        res = AnovaRM(df, 'z', 'sub_id', within=['cond']).fit()

        # anova_table has one row (single within factor); take the scalar
        F[nv] = res.anova_table['F Value'].iloc[0]
        pF[nv] = res.anova_table['Pr > F'].iloc[0]

        print('{:07d} out of {:07d} done...'.format(nv + 1, n_vox), end='\r', flush=True)

    return F,pF

In [None]:
froot = '/data00/layerfMRI/'
isc_fld = froot + 'analyses/ISC/'
av_mask_img = isc_fld + 'masks/average_mask_mni.nii.gz'
# Number of in-mask voxels = length of the masked vector. Counting elements
# instead of summing mask values stays correct if in-mask values are not 1.
Nvox = apply_mask(av_mask_img, av_mask_img).shape[-1]

In [None]:
def _zero_pad(values, width):
    """Format every integer of ``values`` zero-padded to ``width`` digits."""
    return values.apply(lambda n: '{:0{w}d}'.format(n, w=width))


def _fmri_paths(d):
    """Path of each run's 4D MNI-space NIfTI under ``froot + 'regdata/'``."""
    return (froot
            + 'regdata/sub_' + _zero_pad(d['subject'], 2)
            + '/ses_' + _zero_pad(d['session'], 2)
            + '/func/task_' + _zero_pad(d['task'], 1)
            + '_run_' + _zero_pad(d['run'], 1)
            + '_4D_MNI.nii.gz')


# One row per logged run, with the expected fMRI file and whether it is absent on disk
logsumm = pd.read_csv(froot + 'logs/log_summary.csv')
logsumm = logsumm.assign(fmri=_fmri_paths)
logsumm = logsumm.assign(fmri_missing=logsumm['fmri'].apply(lambda path: not os.path.isfile(path)))

In [None]:
# Display the path of the log summary read above (quick sanity check)
froot+ 'logs/log_summary.csv'

In [None]:
# Subject 8 is excluded from every analysis below (reason not recorded here).
logsumm = logsumm[logsumm['subject'] != 8]

# Movie titles grouped by movie type ('Type' column), in order of appearance
movies = {mtype: list(grp['Title'].unique()) for mtype, grp in logsumm.groupby('Type')}

Nsub = logsumm['subject'].nunique()

In [None]:
smoothing = 6  # spatial smoothing kernel, in mm (also names the cache folder)
# Folder caching the preprocessed per-run input arrays; created when missing
os.makedirs(isc_fld + f'isc_preloaded_input_{smoothing}mm/', exist_ok=True)

In [None]:
# Build the per-condition arrays used for ISC.
# For each condition -> movie -> run, the smoothed, masked, voxel-wise z-scored
# time series of all subjects are stored as a (time, voxel, subject) array and
# cached as a pickle. Runs are concatenated along time per movie, and movies
# per condition.
data = {}
data_mov = {}
data_mov_run = {}
preload_fld = isc_fld + f'isc_preloaded_input_{smoothing}mm/'

for c in ['M','S']:
    data_mov[c] = {}
    data_mov_run[c] = {}

    for im, movfile in enumerate(movies[c]):

        mov = movfile.split('.')[0]
        data_mov_run[c][mov] = {}

        for run, rdf in logsumm.loc[lambda d: d['Title'] == movfile].groupby('run'):

            fpickle = preload_fld + f'{c}_{mov}_run{run}.pickle'
            if os.path.isfile(fpickle):
                # Cache written by this cell; never point this at untrusted pickles
                with open(fpickle, 'rb') as fid:
                    data_mov_run[c][mov][run] = pickle.load(fid)
                print('{} loaded successfully'.format(fpickle))
                continue

            # Frames to read: logged duration if unambiguous, else the expected
            # duration, else the shortest expected duration.
            if rdf['Duration'].nunique() == 1:
                lmovie = rdf['Duration'].iloc[0]
            else:
                print('Multiple values found for the movie duration')
                if rdf['ExpectedDuration'].nunique() == 1:
                    print('Picked the expected duration of the movie')
                    lmovie = rdf['ExpectedDuration'].iloc[0]
                else:
                    print('Also the exnpected duration contains more than one value. Picking the minimum')
                    print(rdf['ExpectedDuration'].values)
                    lmovie = np.min(rdf['ExpectedDuration'].values)
                    print(lmovie)
            # Durations may come from the CSV as floats; shapes and slices need ints
            lmovie = int(lmovie)

            print('Reading movie {} ({} frames)'.format(mov, lmovie))

            ## Check that start_TR + lmovie is not bigger than the niftis
            l_fmri = find_fmri_duration(rdf)
            start_TR_max = rdf.start_TR.max()
            if l_fmri < start_TR_max + lmovie:
                print('Warning! Some movie size exceeds the number of recorded volumes')
                lmovie = int(l_fmri - start_TR_max)
                print('length of the movie has been decreased to {}'.format(lmovie))

            # Subject columns are filled by sorted position, so a run that does not
            # contain exactly Nsub subjects misaligns subjects across runs/movies.
            n_run_sub = rdf['subject'].nunique()
            if n_run_sub != Nsub:
                print('Warning! Run {} of movie {} has {} subjects instead of {}'.format(run, mov, n_run_sub, Nsub))

            run_data = np.zeros((lmovie, Nvox, Nsub))
            subject = None  # subject currently being read, for the error message
            try:
                for i, row in rdf.sort_values('subject').reset_index().iterrows():
                    subject = row.subject

                    # Read data (same kernel as the one naming the cache folder)
                    st = time.time()
                    start = int(row.start_TR)
                    subjdata = apply_mask(
                        smooth_img(row.fmri, smoothing),
                        av_mask_img
                    )[start:start + lmovie, :]

                    # Standardize each voxel over time; constant voxels (0/0) become 0
                    avdata = np.average(subjdata, axis=0)
                    stdata = np.std(subjdata, axis=0)
                    with np.errstate(divide='ignore', invalid='ignore'):
                        subjdata = (subjdata - avdata[None, :]) / stdata[None, :]
                    subjdata[np.isnan(subjdata)] = 0
                    print('Subject {} processed in {:.01f} s'.format(row.subject, time.time() - st))

                    # Place it in the right spot of the array
                    run_data[:, :, i] = subjdata
            except Exception as err:
                # Best effort: report and skip the run. Nothing is stored, so a
                # partially filled array can no longer leak into the analysis.
                print('Some error occurred at condition {} movie {} run {} for subject {}...'.format(c, mov, run, subject))
                print(repr(err))
                print('Skipping this whole run')
                continue

            data_mov_run[c][mov][run] = run_data
            # Save it in a dedicated folder
            with open(fpickle, 'wb') as fid:
                pickle.dump(run_data, fid, pickle.HIGHEST_PROTOCOL)

        # After looping through the runs, concatenate them along time
        runs = list(data_mov_run[c][mov].values())
        if not runs:
            print('No usable run for movie {}; it is left out'.format(mov))
            continue
        data_mov[c][mov] = np.concatenate(runs, axis=0)
        print('Movie {} completed [{:02d}/{:02d}]'.format(mov, im + 1, len(movies[c])))

    if not data_mov[c]:
        raise RuntimeError('No data could be loaded for condition {}'.format(c))
    data[c] = np.concatenate(list(data_mov[c].values()), axis=0)

In [None]:
# Per-movie array shapes (time, voxel, subject) for every condition.
# Loops over conditions explicitly instead of reusing the `c` leaked by the loading loop.
for c in data_mov:
    for mov in data_mov[c]:
        print(c, mov, data_mov[c][mov].shape)

In [None]:
isc_type = 'loo'  # 'loo' (leave-one-out) or 'pair' (pairwise)
dopair = isc_type == 'pair'  # passed as `pairwise` to brainiak's isc / bootstrap_isc

# Alternative output location (disabled):
# fld_out = isc_fld + '{}_isc_{}mm/'.format(isc_type,smoothing)
fld_out = '/data00/layerfMRI/Github_repo/layerfMRI/analyses/dual_ISC/4figures/fig_2_results_ISC1/'
os.makedirs(fld_out, exist_ok=True)

In [None]:
# Display the output folder where maps, figures and logs are written
fld_out

In [None]:
def _isc_without_nans(cond_data):
    """ISC values of one condition (brainiak ``isc``), with NaN replaced by 0."""
    scores = isc(cond_data, pairwise=dopair)
    scores[np.isnan(scores)] = 0
    return scores


iscdata = {cond: _isc_without_nans(data[cond]) for cond in ['M', 'S']}

### One-sample tests (bootstrap and t-test)

In [None]:
# One-sample tests of the ISC values against 0, per condition:
# - 'boot': brainiak bootstrap of the median ISC ('m' = observed median,
#   'ci' = confidence interval, 'p' = p-values, 'd' = bootstrap distribution)
# - 't-test': voxel-wise one-sample t-test ('t' = t values, 'p' = p-values)
n_bootstraps = 5000
boot_seed = 0  # fixed seed so bootstrap p-values are reproducible on re-run

iscstat = {}  # created here so the cell also works on a fresh kernel
for c in ['M','S']:
    boot_m, boot_ci, boot_p, boot_d = bootstrap_isc(
        iscdata[c], n_bootstraps=n_bootstraps, pairwise=dopair,
        summary_statistic='median', random_state=boot_seed)
    t_val, t_p = ttest_1samp(iscdata[c], 0)

    iscstat[c] = {
        'boot': {'m': boot_m, 'ci': boot_ci, 'p': boot_p, 'd': boot_d},
        't-test': {'t': t_val, 'p': t_p},
    }

### Repeated-measures F-test (currently disabled)

In [None]:
# Repeated-measures F-test across conditions (disabled).
# NOTE(review): the export loop iterates over every key of iscstat, and
# export_results does not handle met='repF'. Extend it before enabling this.
# c = 'M-S'
# iscstat[c] = {'repF': {}}
# iscstat[c]['repF']['m'], iscstat[c]['repF']['p'] = do_RM_Ftest(iscdata)

In [None]:
# Per-correction voxel-wise p-value threshold and minimum cluster size (voxels).
# FDR (alpha 0.05, k = 20) is currently disabled.
p_thres = dict(uncorrected=0.001)
c_thres = dict(uncorrected=50)

In [None]:
# Threshold, save and plot every (condition, method, correction) map and
# keep the returned images, indexed as thres_map[cond][method][correction].
thres_map = {
    cond: {
        method: {
            corr: export_results(iscstat, cond, method, corr, fld_out, logsumm)
            for corr in p_thres
        }
        for method in iscstat[cond]
    }
    for cond in iscstat
}

In [None]:
# Save OR map
t = 'uncorrected'

for met in ['boot','t-test']:

    M_OR_S = math_img('(im1>0) + (im2>0)',im1 = thres_map['M'][met][t], im2 = thres_map['S'][met][t])

    # Thresholded binary
    fout = fld_out + '{}ISC_{}_{}_{}_p{}_k{:02d}_bin.nii.gz'.format(isc_type,'M_OR_S',met,t,str(p_thres[t])[2:],c_thres[t])
    nib.save(M_OR_S, fout)

    # Figure
    fig,ax = plt.subplots(figsize = (9,3))
    niplt.plot_glass_brain(M_OR_S, figure=fig,axes=ax,colorbar = True)
    ax.set_title(r'{} ISC - {} correction ($\alpha = {}$, $\kappa = {}$)'.format(isc_type,t,p_thres[t],c_thres[t]))

    fout = fld_out + '{}ISC_{}_{}_{}_p{}_k{:02d}_bin.pdf'.format(isc_type,'M_OR_S',met,t,str(p_thres[t])[2:],c_thres[t])

    plt.savefig(fout)
    plt.close('all')