In [None]:
'''
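Voxelwise intersubject correlation (ISC) analysis for the control listening and reading tasks.
For each participant, the BOLD timeseries is correlated voxelwise with the mean timeseries of all other
participants. Significance is assessed with a permutation test (null timeseries generated by circle
shifting or phase scrambling) with FDR correction. Normal distributions are fit to the participant-
and group-level null distributions and checked with Kolmogorov-Smirnov goodness-of-fit tests.
Results are pickled to the control_tasks data folder.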
'''

In [None]:
import scipy.io as sio
import os
import sys
import numpy as np
import pandas as pd
import time
from joblib import Parallel, delayed
from scipy import stats
import matplotlib.pyplot as plt
import statsmodels.stats.multitest as multi
sys.path.append('/afs/dbic.dartmouth.edu/usr/wheatley/jd/')
from phaseScramble import *
from CircleShift import *
from scipy.stats import norm
from sklearn import preprocessing
%matplotlib inline

In [None]:
# mark starting time (used to report total runtime at the end)
startTime = time.time()

# ---------------------------------------------------------------------------
# participants: dbicIDs[i] and cbsIDs[i] share pair number i + 2
# ---------------------------------------------------------------------------
dbicIDs = np.array(["sid000007", "sid000009", "sid000560", "sid000535", "sid000102", "sid000416", "sid000499", "sid000142"])
cbsIDs = np.array(["hid000002", "hid000003", "hid000004", "hid000005", "hid000006", "hid000007", "hid000008", "hid000009"])
pairNums = np.arange(2, 2 + len(dbicIDs))

# one row per participant: all DBIC participants first, then all CBS participants
subList = pd.DataFrame({'pairNum': np.tile(pairNums, 2)})
subList['subID'] = np.concatenate((dbicIDs, cbsIDs))
print(subList)

numSubs = 2 * len(pairNums) # every pair contributes two participants

# ---------------------------------------------------------------------------
# analysis options
# ---------------------------------------------------------------------------
permutations = 100 # number of null permutations per participant
alpha = 0.05 # alpha for permutation tests

# debugging mode: keep only the first debugTRs rows (TRs) and debugVox columns (voxels) of each timeseries
debug = False
debugTRs = 201
debugVox = 800

dist = stats.norm # distribution fit to the null correlations
taskNames = np.array(['listening', 'reading'])
circShift = True # True = circle shift the timeseries for the null, False = phase scrambling
normalize = True # normalize timeseries before ISC

# ---------------------------------------------------------------------------
# execution / output options
# ---------------------------------------------------------------------------
parallel = True # run participants in parallel with joblib
numJobs = 32 # n_jobs passed to joblib.Parallel
verbosity = 50 # joblib verbosity -- don't go over 50 lest ye print like a million outputs and slow everything down
saveOutput = True # pickle results at the end
examplePlots = False # pop out example plots of good vs bad null distribution fits

# appended to the output file name; the loading cell sets it to 'norm_' when normalize is True
saveTag = ''

In [None]:
# set data folder
folder = '/afs/dbic.dartmouth.edu/usr/wheatley/jd/control_tasks/'

# boldData[TASK][SUB]: 'tseries' array loaded for participant SUB in task TASK
boldData = [None, None]
for TASK in [0, 1]: # 0 = listening, 1 = reading (stored on disk as storytelling3 / storytelling4)

    boldData[TASK] = [None] * numSubs

    for SUB in range(numSubs):

        subID = subList['subID'][SUB]
        pairID = subList['pairNum'][SUB]
        runNum = str(TASK + 3)

        # build the preprocessed .mat file name for this participant / task
        fileName = (folder + 'sub-' + subID + '_ses-pair0' + str(pairID)
                    + '_task-storytelling' + runNum + '_run-0' + runNum
                    + '_bold_space-MNI152NLin2009cAsym_preproc_nuisRegr_2021_interp.mat')

        tseries = sio.loadmat(fileName)['tseries']
        print('loaded ' + str(tseries.shape[0]) + ' x ' + str(tseries.shape[1]) + ' timeseries for ' + taskNames[TASK] + ' task, sub ' + subID)

        if normalize:
            # NOTE(review): sklearn's normalize() defaults to norm='l2', axis=1, i.e. each ROW is scaled to
            # unit length. If rows are TRs (as the debug slicing below assumes), this rescales each timepoint
            # across voxels rather than normalizing each voxel's timeseries -- confirm this is intended.
            tseries = preprocessing.normalize(tseries)
            print('normalizing timeseries')
            saveTag = 'norm_'

        if debug:
            # keep only a small corner of the data to speed up debugging
            tseries = tseries[np.ix_(np.arange(debugTRs), range(debugVox))]

        boldData[TASK][SUB] = tseries


In [None]:
def parallelSubWrapper(SUB,numSubs,boldData,permutations,circShift):
    """
    Leave-one-out ISC plus permutation test for a single participant (run via joblib below).

    The participant's timeseries is correlated voxelwise with the mean timeseries of all OTHER
    participants. A null distribution is built by correlating scrambled copies of the participant's
    timeseries with that same group mean.

    Uses names defined elsewhere in the notebook: fastColumnCorr, Circle_Shift and phase_scrambling
    (star-imported helper modules), multi (statsmodels multitest), and the config globals alpha and dist.

    :param SUB: participant index
    :param numSubs: total number of participants in the analysis
    :param boldData: list of per-participant timeseries arrays [timepoints x voxels]
    :param permutations: number of permutations to use (any positive int)
    :param circShift: True -> build null timeseries with Circle_Shift; False -> with phase_scrambling
    :return:
        corrData: real voxelwise correlations between participant and the mean across all other participants [voxels]
        nullCorr: null voxelwise correlations between participant and the mean across all other participants [permutations x voxels]
        permTest: 7-element list (labels in permTestMap, defined in the ISC cell below)
            permTest[0]: permutation test p-values, list [voxels]
            permTest[1]: output of multi.fdrcorrection(permTest[0], alpha=alpha/2):
                permTest[1][0]: boolean array, True where the voxel survives FDR correction [voxels]
                permTest[1][1]: FDR corrected p-values [voxels]
            permTest[2]: p-value summary stats
                permTest[2][0]: number of voxels with uncorrected p-vals < alpha / 2
                permTest[2][1]: proportion of voxels with uncorrected p-vals < alpha / 2
                permTest[2][2]: number of voxels surviving FDR correction
                permTest[2][3]: proportion of voxels surviving FDR correction
            permTest[3]: null dist normal fit parameters [voxels]
                permTest[3][voxel][0]: mean
                permTest[3][voxel][1]: standard deviation
            permTest[4]: Kolmogorov-Smirnov goodness of fit test results [voxels]
                permTest[4][voxel][0]: KS test statistic
                permTest[4][voxel][1]: p-value (values less than alpha suggest poor fit)
            permTest[5]: int vector for KS goodness of fit where 0 = good fit, 1 = bad fit [voxels]
            permTest[6]: DataFrame with number and proportion of bad normal fits to the perm-based null distribution [1 x 2 table]
    """

    # mean timeseries of every participant EXCEPT the current one
    otherSubs = np.delete(np.arange(0, numSubs), SUB)
    groupMean = np.mean([boldData[i] for i in otherSubs], axis=0)

    # REAL correlation between current participant and groupMean
    corrData = fastColumnCorr(boldData[SUB], groupMean)

    # generate null correlations
    scramble = Circle_Shift if circShift else phase_scrambling
    # print feedback roughly 10 times; max(1, ...) prevents a modulo-by-zero when permutations <= 5
    feedbackMultiple = max(1, round(permutations / 10))
    nullCorr = [[]] * permutations
    for PERM in range(permutations):
        nullCorr[PERM] = fastColumnCorr(scramble(boldData[SUB]), groupMean)
        if PERM % feedbackMultiple == 0:
            print('getting null correlation for SUB ' + str(SUB + 1) + ', permutation ' + str(PERM + 1)) # provide feedback about which subject and which permutation we're on

    # convert permutation data to a permutations x voxels array
    nullCorr = np.asarray(nullCorr)
    numVox = nullCorr.shape[1]

    # one slot per entry of permTestMap (7 entries, filled below)
    permTest = [[]] * 7

    # permutation p-value: proportion of permutations whose |null r| exceeds |real r|
    # NOTE(review): abs() already makes this a two-sided test, yet the thresholds below use alpha / 2,
    # which halves alpha a second time -- confirm this is intended.
    realAbs = np.abs(np.ravel(corrData))
    permTest[0] = (np.count_nonzero(np.abs(nullCorr) > realAbs, axis=0) / float(permutations)).tolist()
    permTest[1] = multi.fdrcorrection(permTest[0], alpha = alpha / 2) # (rejected, corrected p-values)
    permTest[2] = [[]] * 4
    permTest[2][0] = len(np.where(np.array(permTest[0]) < alpha / 2)[0]) # number of voxels significantly correlated with the group
    permTest[2][1] = permTest[2][0] / numVox # proportion of voxels significantly correlated with the group
    permTest[2][2] = np.count_nonzero(permTest[1][0]) # number of voxels significant after FDR correction
    permTest[2][3] = permTest[2][2] / numVox # proportion of voxels significant after FDR correction

    # fit normal distributions to the null data and test goodness of fit
    # (note: KS p-values are conservative when the parameters are estimated from the same data)
    permTest[3] = [[]] * numVox # voxelwise normal dist parameters
    permTest[4] = [[]] * numVox # voxelwise Kolmogorov-Smirnov test results
    permTest[5] = np.zeros((numVox,), dtype=int) # voxelwise bad-fit flags
    for VOX in range(numVox):
        permTest[3][VOX] = dist.fit(nullCorr[:,VOX])
        permTest[4][VOX] = stats.kstest(nullCorr[:,VOX], "norm", permTest[3][VOX])
        if permTest[4][VOX][1] < alpha: # bad normal fit -> flag it
            permTest[5][VOX] = 1
    permTest[6] = pd.DataFrame(data={'numBadFits': [sum(permTest[5])], 'propBadFits': [sum(permTest[5]) / len(permTest[5])]}) # number and proportion of voxels with bad normal fits to the null distribution

    return corrData, nullCorr, permTest

In [None]:
# preallocate correlation lists (one entry per task)
corrData = [[]] * 2
nullCorr = [[]] * 2
permTest = [[]] * 2

# labels for each participant-specific sublist of permTest (see the parallelSubWrapper docstring)
permTestMap = ['permPval','FDR_corrected_permPval','propSigFDRvoxels','normParams','KStest','badFits','badFitsSummary']

# Voxelwise correlation between participant and the rest of the group (mean)
for TASK in [0,1]: # for each task...

    corrData[TASK] = [[]] * numSubs
    nullCorr[TASK] = [[]] * numSubs
    permTest[TASK] = [[]] * numSubs

    # run ISC
    if parallel:

        # run joblib
        jobResults = Parallel(n_jobs=numJobs, verbose=verbosity)(delayed(parallelSubWrapper)(SUB,numSubs,boldData[TASK],permutations,circShift) for SUB in range(numSubs))

        # assign joblib outputs
        for SUB in range(numSubs):
            corrData[TASK][SUB], nullCorr[TASK][SUB], permTest[TASK][SUB] = jobResults[SUB]
    else:

        for SUB in range(numSubs):
            corrData[TASK][SUB], nullCorr[TASK][SUB], permTest[TASK][SUB] = parallelSubWrapper(SUB,numSubs,boldData[TASK],permutations,circShift)

    # print permutation test and goodness of fit summary info for EVERY participant
    # (permTest[..][2][3] = proportion of FDR-significant voxels; [6].iloc[0,1] = propBadFits)
    for SUB in range(numSubs):
        print('\nFinished processing participant ' + str(SUB + 1) + ' of ' + str(numSubs) + ' for task ' + str(TASK + 3))
        print('% voxels with FDR corrected significant correlation with group: ' + str((permTest[TASK][SUB][2][3]) * 100) + '%')
        print('% voxels for which null dist. was normal: ' + str((1 - permTest[TASK][SUB][6].iloc[0,1]) * 100) + '%')

In [None]:
# get group level null distributions for each voxel
# NOTE(review): with numSubs <= 1 groupFitData is never defined, so the save / print cells below will fail
if numSubs > 1: # if we're running the analysis on more than one participant...

    # preallocate lists for each task
    groupNull = [[]] * 2
    groupFitData = [[]] * 2

    for TASK in [0,1]:

        #feedback
        print('starting group level null distribution fits for the ' + taskNames[TASK] + ' task')

        # groupNull[TASK]: [pooled null, fit params, KS results, bad-fit flags, bad-fit summary]
        groupNull[TASK] = [[]] * 5

        # pool every participant's [permutations x voxels] null array along the permutation axis
        groupNull[TASK][0] = np.concatenate(list(nullCorr[TASK]), axis=0)
        numVox = groupNull[TASK][0].shape[1]

        # fit normal distribution to group null
        groupNull[TASK][1] = [[]] * numVox # normal fit parameters per voxel
        groupNull[TASK][2] = [[]] * numVox # KS goodness of fit per voxel
        groupNull[TASK][3] = np.zeros((numVox,), dtype=int) # 1 flags a voxel with a bad fit
        for VOX in range(numVox): # for each voxel (column of the pooled null)
            groupNull[TASK][1][VOX] = dist.fit(groupNull[TASK][0][:,VOX])
            groupNull[TASK][2][VOX] = stats.kstest(groupNull[TASK][0][:,VOX], "norm", groupNull[TASK][1][VOX])
            if groupNull[TASK][2][VOX][1] < alpha:
                groupNull[TASK][3][VOX] = 1
        groupNull[TASK][4] = pd.DataFrame(data={'numBadFits': [sum(groupNull[TASK][3])], 'propBadFits': [sum(groupNull[TASK][3]) / len(groupNull[TASK][3])]}) # number and proportion of voxels with bad normal fits to the group null distribution

        # group level fit parameters, KS results for ALL voxels, and the bad-fit summary
        groupFitData[TASK] = [groupNull[TASK][1], groupNull[TASK][2], groupNull[TASK][4]]

        #feedback
        print('*** group level null fits ***')
        print(groupFitData[TASK][2])

In [None]:
endTime = time.time()
# elapsed wall-clock time since startTime was recorded in the config cell, in minutes
duration = (endTime - startTime) / 60
print(f'control tasks ISC duration: {duration}')

In [None]:
if saveOutput:
    import pickle

    # the file name records the null-generation method, saveTag and the permutation count
    if circShift:
        saveFile = folder + 'controlISC_cShift_' + saveTag + str(permutations) + 'perm.pkl'
        # *_c aliases keep the circle-shift results available under method-specific names
        permTest_c, corrData_c, groupFitData_c, duration_c = permTest, corrData, groupFitData, duration
    else:
        saveFile = folder + 'controlISC_pScram_' + saveTag + str(permutations) + 'perm.pkl'

    savedResults = [permTest, corrData, groupFitData, duration]
    with open(saveFile, 'wb') as outFile:
        pickle.dump(savedResults, outFile, protocol=4)

In [None]:
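# example of how to reload saved results. The file name must match saveFile from the save cell,
# including saveTag (e.g. 'controlISC_cShift_norm_100perm.pkl' when normalize=True and permutations=100).
# Only unpickle files you created yourself: pickle.load can execute arbitrary code.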
# folder = '/afs/dbic.dartmouth.edu/usr/wheatley/jd/control_tasks/'
# loadFile = folder + 'controlISC_cShift_200perm.pkl'
# import pickle
# with open(loadFile, 'rb') as f:
#     permTest, corrData, groupFitData, duration = pickle.load(f)

In [None]:
# per-participant permutation-test summaries (uncomment to print):
# for TASK in [0,1]:
#     print('********** ' + taskNames[TASK] + ' task **********')
#     for SUB in range(len(permTest[TASK])):
#         print(permTest[TASK][SUB][2])

# group level bad-fit summary table for the listening (0) and reading (1) tasks
for taskIdx in (0, 1):
    print(groupFitData[taskIdx][2])

In [None]:
if examplePlots:

    TASK = 0 # arbitrary
    SUB = 1 # arbitrary
    # NOTE(review): hand-picked example voxel index; why this voxel is of interest is not recorded -- confirm
    exampleVox = 4210
    numVox = len(permTest[TASK][SUB][4]) # get number of voxels

    # voxel with the smallest permutation-test p-value
    pVals = permTest[TASK][SUB][0]
    min_value = min(pVals)
    min_index = pVals.index(min_value)
    print(min_index)

    # KS test statistic for every voxel
    KSval = np.array([ksResult[0] for ksResult in permTest[TASK][SUB][4]])

    # smallest KS statistic (should be a good normal fit) and largest (should be a bad fit)
    min_i = int(np.argmin(KSval))
    print(min_i)
    max_i = int(np.argmax(KSval))
    print(max_i)

    # for each example voxel: null histogram, its normal fit, and the real correlation
    for VOX in [min_i, max_i, exampleVox]:
        if VOX >= numVox: # e.g. debug mode keeps only debugVox voxels
            print('skipping voxel ' + str(VOX) + ': only ' + str(numVox) + ' voxels available')
            continue

        data = nullCorr[TASK][SUB][:,VOX] # null correlations for this voxel
        mu, std = norm.fit(data)

        fig, ax = plt.subplots(facecolor='white')

        # histogram of the null distribution plus the fitted normal PDF
        ax.hist(data, bins=25, density=True, alpha=0.6, color='m')
        xmin, xmax = ax.get_xlim()
        x = np.linspace(xmin, xmax, 100)
        ax.plot(x, norm.pdf(x, mu, std), 'k', linewidth=2)

        ksStat = round(permTest[TASK][SUB][4][VOX][0] * 100) / 100
        ksPval = round(permTest[TASK][SUB][4][VOX][1] * 100) / 100
        ax.set_title("KS test stat = " + str(ksStat) + ", KS pVal = " + str(ksPval) + ", perm pVal = " + str(pVals[VOX]))
        ax.set_xlabel('null correlation')
        ax.set_ylabel('density')

        # vertical line at the real (unpermuted) correlation
        x2 = corrData[TASK][SUB][VOX]
        ax.plot([x2, x2], [0, ax.get_ylim()[1]], 'y', linewidth=3)

        plt.show()
        plt.close(fig)
