# Regress out nuisance variables

In [None]:
import os
import glob
import numpy as np
import nibabel as nb
import pandas as pd
from sklearn.decomposition import PCA
from nistats.regression import OLSModel
from joblib import Parallel, delayed



In [3]:
# ---------------------------------------------------------------------------
# Directory layout
# ---------------------------------------------------------------------------
home_dir = '/home/shared/2018/visual/cerebellum_prf/'
der_dir = os.path.join(home_dir, 'derivatives')   # fmriprep output lives here
out_dir = os.path.join(der_dir, 'pp')             # preprocessing root
in_dir = os.path.join(out_dir, 'sgtf')            # input runs
ng_out_dir = os.path.join(out_dir, 'ng')          # nuisance-regressed output
if not os.path.isdir(ng_out_dir):
    os.mkdir(ng_out_dir)

# ---------------------------------------------------------------------------
# Subjects / sessions to process
# ---------------------------------------------------------------------------
# only subject '02' is processed; sess lists the session labels per subject
subs = ['02']
sess = {
    '01': ['01', '03', '02'],
    '02': ['01', '02', '03', '04'],
    '03': ['01', '02', '03'],
}

# ---------------------------------------------------------------------------
# Acquisition / analysis parameters
# ---------------------------------------------------------------------------
TR = 1.5                          # repetition time
space = 'MNI152NLin2009cAsym'     # template space of the inputs (alternative: 'T1w')
n_components = 5                  # number of confound PCA components regressed out

# Confound columns read from each run's fmriprep *_bold_confounds.tsv:
# DVARS variants + framewise displacement, six aCompCor components
# (aCompCor00 .. aCompCor05), and the six rigid-body motion parameters.
varr = (
    ['stdDVARS', 'non-stdDVARS', 'vx-wisestdDVARS', 'FramewiseDisplacement']
    + list('aCompCor%02d' % k for k in range(6))
    + ['X', 'Y', 'Z', 'RotX', 'RotY', 'RotZ']
)

In [4]:
def savgol_filter(data, polyorder=3, deriv=0, window_length=120, TR=1.5):
    """High-pass time series by subtracting a Savitzky-Golay fit.

    A Savitzky-Golay filter (``scipy.signal.savgol_filter``, applied along
    the last axis) is fitted to ``data``. The fit captures the slow,
    low-frequency component and is subtracted from the data. The mean of the
    fit is then added back per time series, so every series keeps its
    original mean.

    Parameters
    ----------
    data : array_like
        Time series with time on the last axis: a 1-D vector or an N-D
        array such as (voxels, time).
    polyorder : int
        Order of the polynomial fitted within each window.
    deriv : int
        Derivative order, passed straight through to scipy.
    window_length : float
        Window length in the same time unit as ``TR``. It is converted to
        samples and rounded up to an odd number.
    TR : float
        Sampling interval of the time series.

    Returns
    -------
    ndarray
        Filtered data with the same shape as ``data``.
    """
    from scipy.signal import savgol_filter as _scipy_savgol

    data = np.asarray(data)

    # window in samples; scipy requires an odd window length
    # (plain int(): np.int was removed in NumPy 1.24)
    window = int(window_length / TR)
    if window % 2 == 0:
        window += 1

    trend = _scipy_savgol(data, window_length=window, polyorder=polyorder,
                          deriv=deriv, mode='nearest')

    # Re-add each series' own mean (keepdims keeps it broadcastable).
    # A grand mean would give every series of an N-D input the same baseline.
    return data - trend + trend.mean(axis=-1, keepdims=True)

In [5]:
def _bids_entities(fn):
    """Return the ``key-value`` entities of a BIDS-style file name as a dict.

    e.g. '.../sub-02_ses-01_task-prf_run-1_bold_preproc.nii.gz'
    -> {'sub': '02', 'ses': '01', 'task': 'prf', 'run': '1'}
    """
    parts = os.path.basename(fn).split('_')
    return dict(part.split('-', 1) for part in parts if '-' in part)


def _nuisance_components(confounds_fn):
    """Return the first ``n_components`` PCs of the confounds listed in ``varr``.

    Each confound column is median-filled, high-passed with ``savgol_filter``
    and z-scored before the PCA. The result has shape
    (n_timepoints, n_components).
    """
    df = pd.read_csv(confounds_fn, sep='\t', header=0, index_col=None)

    regressors = []
    for var in varr:
        # 'n/a' cells become NaN, everything else becomes float
        ts = pd.to_numeric(df[var], errors='coerce').values.astype('float32')
        # median-fill missing values (typically the first volume)
        ts[np.isnan(ts)] = np.nanmedian(ts)
        # temporally filter
        ts = savgol_filter(ts)
        # z-score so explained-variance ratios are interpretable; a constant
        # series is only demeaned instead of producing 0/0 -> NaN
        ts = ts - np.mean(ts)
        sd = np.std(ts)
        if sd > 0:
            ts = ts / sd
        regressors.append(ts)

    pca = PCA(n_components=n_components)
    return pca.fit_transform(np.array(regressors).T)


def perform_ng(in_dir, out_dir, fn, ses, sub=None):
    """Regress confound principal components out of one BOLD run and save it.

    The output path is ``fn`` with ``in_dir`` replaced by ``out_dir`` and
    ``.nii.gz`` replaced by ``_ng.nii.gz``. If that file already exists, the
    run is skipped. Otherwise:

    1. The run's fmriprep confounds file
       (``der_dir/fmriprep_ses<ses>/fmriprep/sub-<sub>/ses-<ses>/func/...``)
       is reduced to ``n_components`` PCs (see ``_nuisance_components``).
    2. The PCs plus an intercept are fitted to every voxel with OLS.
    3. The residuals, with each voxel's temporal mean added back, are saved
       using the input image's affine and header.

    Parameters
    ----------
    in_dir, out_dir : str
        Input directory prefix of ``fn`` and its output replacement.
    fn : str
        Input nifti path. Its BIDS-style name must contain a ``run-``
        entity (and ``sub-`` when ``sub`` is not given).
    ses : str
        Session label.
    sub : str, optional
        Subject label. Defaults to the ``sub-`` entity parsed from ``fn``.

    Raises
    ------
    ValueError
        If the confounds and the image have different numbers of volumes.
    """
    out_fn = fn.replace(in_dir, out_dir).replace('.nii.gz', '_ng.nii.gz')
    if os.path.isfile(out_fn):
        return

    entities = _bids_entities(fn)
    if sub is None:
        sub = entities['sub']
    run = entities['run']

    # confound PCs for this run
    confounds_fn = os.path.join(
        der_dir, 'fmriprep_ses%s' % ses, 'fmriprep', 'sub-%s' % sub,
        'ses-%s' % ses, 'func',
        'sub-%s_ses-%s_task-prf_run-%s_bold_confounds.tsv' % (sub, ses, run))
    pcs = _nuisance_components(confounds_fn)

    # load data (dataobj replaces get_data(), which was removed in nibabel 5)
    img = nb.load(fn)
    bold = np.nan_to_num(np.asanyarray(img.dataobj))
    shape = bold.shape
    if pcs.shape[0] != shape[-1]:
        raise ValueError('%s: confounds have %d rows but the image has %d volumes'
                         % (fn, pcs.shape[0], shape[-1]))

    # OLS on (intercept + PCs); voxels are columns of the (time, voxel) matrix
    dm = np.hstack([np.ones((pcs.shape[0], 1)), pcs])
    fit = OLSModel(dm).fit(bold.reshape(-1, shape[-1]).T)
    resid = fit.resid.T.reshape(shape)
    # re-add the signal offset that the intercept regressed out
    resid += bold.mean(axis=-1)[..., np.newaxis]

    nb.save(nb.Nifti1Image(resid, affine=img.affine, header=img.header), out_fn)

In [6]:
# NOTE: keep the loop variable named `sub` in case perform_ng reads it as a
# module-level global.
for sub in subs:
    for ses in sess[sub]:

        print('now removing nuisances from sub %s, ses %s' % (sub, ses))

        # per-subject output directory
        sj_ng_out_dir = os.path.join(ng_out_dir, 'sub-%s' % sub)
        if not os.path.isdir(sj_ng_out_dir):
            os.mkdir(sj_ng_out_dir)

        # all preprocessed runs of this session
        sj_in_dir = os.path.join(in_dir, 'sub-%s' % sub)
        pattern = ('sub-%s_ses-%s*bold_space-%s_preproc_resampled_fnirted_smoothed_sgtf.nii.gz'
                   % (sub, ses, space))
        fns = sorted(glob.glob(os.path.join(sj_in_dir, pattern)))
        if not fns:
            # make a missing/misnamed session visible instead of a silent no-op
            print('no input files found for sub %s, ses %s' % (sub, ses))
            continue

        # nuisance regression per run, 6 runs in parallel
        Parallel(n_jobs=6, verbose=9)(
            delayed(perform_ng)(sj_in_dir, sj_ng_out_dir, fn, ses) for fn in fns)

# function-call form works on both Python 2 and 3 (`print 'x'` is py2-only)
print('done!')

now removing nuisances from sub 02, ses 01


[Parallel(n_jobs=6)]: Using backend LokyBackend with 6 concurrent workers.
[Parallel(n_jobs=6)]: Done   4 out of  13 | elapsed:  1.5min remaining:  3.4min
[Parallel(n_jobs=6)]: Done   6 out of  13 | elapsed:  1.5min remaining:  1.8min
[Parallel(n_jobs=6)]: Done   8 out of  13 | elapsed:  2.7min remaining:  1.7min
[Parallel(n_jobs=6)]: Done  10 out of  13 | elapsed:  2.7min remaining:   49.4s
[Parallel(n_jobs=6)]: Done  13 out of  13 | elapsed:  3.8min finished
[Parallel(n_jobs=6)]: Using backend LokyBackend with 6 concurrent workers.


now removing nuisances from sub 02, ses 02


[Parallel(n_jobs=6)]: Done   2 out of   8 | elapsed:  1.4min remaining:  4.2min
[Parallel(n_jobs=6)]: Done   3 out of   8 | elapsed:  1.4min remaining:  2.4min
[Parallel(n_jobs=6)]: Done   4 out of   8 | elapsed:  1.4min remaining:  1.4min
[Parallel(n_jobs=6)]: Done   5 out of   8 | elapsed:  1.4min remaining:   51.9s
[Parallel(n_jobs=6)]: Done   6 out of   8 | elapsed:  1.5min remaining:   29.0s
[Parallel(n_jobs=6)]: Done   8 out of   8 | elapsed:  2.6min remaining:    0.0s
[Parallel(n_jobs=6)]: Done   8 out of   8 | elapsed:  2.6min finished
[Parallel(n_jobs=6)]: Using backend LokyBackend with 6 concurrent workers.


now removing nuisances from sub 02, ses 03


[Parallel(n_jobs=6)]: Done   3 out of  10 | elapsed:  1.4min remaining:  3.3min
[Parallel(n_jobs=6)]: Done   5 out of  10 | elapsed:  1.4min remaining:  1.4min
[Parallel(n_jobs=6)]: Done   7 out of  10 | elapsed:  2.6min remaining:  1.1min
[Parallel(n_jobs=6)]: Done  10 out of  10 | elapsed:  2.7min finished
[Parallel(n_jobs=6)]: Using backend LokyBackend with 6 concurrent workers.


now removing nuisances from sub 02, ses 04


[Parallel(n_jobs=6)]: Done   6 out of  15 | elapsed:  1.4min remaining:  2.0min
[Parallel(n_jobs=6)]: Done   8 out of  15 | elapsed:  2.7min remaining:  2.4min
[Parallel(n_jobs=6)]: Done  10 out of  15 | elapsed:  2.7min remaining:  1.4min
[Parallel(n_jobs=6)]: Done  12 out of  15 | elapsed:  2.8min remaining:   41.5s


done!


[Parallel(n_jobs=6)]: Done  15 out of  15 | elapsed:  3.9min finished
