# improve across-session alignment

Problem: after fMRIprep, the sessions are aligned to MNI. However, they do not nicely align across sessions. 

Solution: pick a target session, and compute warpfield from average over runs and time from other session(s) to target session.

In [1]:
import os
import glob
import numpy as np
import nibabel as nb
from nipype.interfaces import fsl
from joblib import Parallel, delayed


In [2]:
# setup dirs
home_dir = '/home/shared/2018/visual/cerebellum_prf/'
der_dir = os.path.join(home_dir,'derivatives')
out_dir = os.path.join(home_dir,'derivatives','pp')
res_out_dir = os.path.join(out_dir,'res')

# setup subs
subs = ['02']
target_session = '03'
source_sessions = {
    '01':['01','03','02'],
    '02':['01','02','03','04'],
    '03':['01','02','03']
}
space = 'MNI152NLin2009cAsym' # 'T1w' 


In [3]:
def load_data(fn):
    img = nb.load(fn)
    data = img.get_data()

    return data
    

In [4]:
def compute_warp(fn,target_fn,postFix='fnirted'):
    
    flt = fsl.FNIRT()
    flt.inputs.in_file = fn
    flt.inputs.ref_file = target_fn
    flt.inputs.warped_file = fn.replace('.nii.gz','_%s.nii.gz'%postFix)
    flt.inputs.field_file = fn.replace('.nii.gz','_%s_field.nii.gz'%postFix)


    flt.run()

In [5]:
def apply_warp(fn,target_fn,field_file):
    
    flt = fsl.ApplyWarp()
    flt.inputs.ref_file = target_fn
    flt.inputs.in_file = fn
    flt.inputs.out_file = fn.replace('.nii.gz','_fnirted.nii.gz')
    flt.inputs.field_file = field_file
    flt.run()

### step 1: create average over runs and time for each session

In [6]:
# # create avg files
for sub in subs:
    for ses in source_sessions[sub]:
        print 'now computing run and time avg for sub %s, ses %s'%(sub,ses)

        sj_res_out_dir = os.path.join(res_out_dir,'sub-%s'%sub)

        fns = sorted(glob.glob(os.path.join(sj_res_out_dir,'*sub-%s_ses-%s*bold_space-%s_preproc_resampled.nii.gz'%(sub,ses,space))))

        all_data = Parallel(n_jobs=len(fns),verbose=9)(delayed(load_data)(fn)  for fn in fns)
        
        sample_img = nb.load(fns[0])
        
        print 'now computing mean over runs'
        avg_over_runs = np.nanmean(all_data,axis=0)
        new_img = nb.Nifti1Image(avg_over_runs,affine=sample_img.affine,header=sample_img.header)
        out_fn = os.path.join(sj_res_out_dir,'mean_over_runs_ses_%s.nii.gz'%ses)
        nb.save(new_img,out_fn)

        print 'now computing mean over time'
        avg_over_time = np.nanmean(avg_over_runs,axis=-1)
        new_img = nb.Nifti1Image(avg_over_time,affine=sample_img.affine,header=sample_img.header)
        out_fn = os.path.join(sj_res_out_dir,'mean_over_runs_timemean_ses_%s.nii.gz'%ses)
        nb.save(new_img,out_fn)


now computing run and time avg for sub 02, ses 01


[Parallel(n_jobs=13)]: Using backend LokyBackend with 13 concurrent workers.
[Parallel(n_jobs=13)]: Done   2 out of  13 | elapsed:   22.1s remaining:  2.0min
[Parallel(n_jobs=13)]: Done   4 out of  13 | elapsed:   25.6s remaining:   57.6s
[Parallel(n_jobs=13)]: Done   6 out of  13 | elapsed:   29.0s remaining:   33.9s
[Parallel(n_jobs=13)]: Done   8 out of  13 | elapsed:   31.9s remaining:   19.9s
[Parallel(n_jobs=13)]: Done  10 out of  13 | elapsed:   34.8s remaining:   10.4s
[Parallel(n_jobs=13)]: Done  13 out of  13 | elapsed:   39.3s finished


now computing mean over runs
now computing mean over time
now computing run and time avg for sub 02, ses 02


[Parallel(n_jobs=8)]: Using backend LokyBackend with 8 concurrent workers.
[Parallel(n_jobs=8)]: Done   2 out of   8 | elapsed:   33.0s remaining:  1.6min
[Parallel(n_jobs=8)]: Done   3 out of   8 | elapsed:   34.5s remaining:   57.4s
[Parallel(n_jobs=8)]: Done   4 out of   8 | elapsed:   36.0s remaining:   36.0s
[Parallel(n_jobs=8)]: Done   5 out of   8 | elapsed:   37.6s remaining:   22.6s
[Parallel(n_jobs=8)]: Done   6 out of   8 | elapsed:   39.4s remaining:   13.1s
[Parallel(n_jobs=8)]: Done   8 out of   8 | elapsed:   43.0s remaining:    0.0s
[Parallel(n_jobs=8)]: Done   8 out of   8 | elapsed:   43.0s finished


now computing mean over runs
now computing mean over time
now computing run and time avg for sub 02, ses 03


[Parallel(n_jobs=10)]: Using backend LokyBackend with 10 concurrent workers.
[Parallel(n_jobs=10)]: Done   3 out of  10 | elapsed:   23.3s remaining:   54.4s
[Parallel(n_jobs=10)]: Done   5 out of  10 | elapsed:   26.4s remaining:   26.4s
[Parallel(n_jobs=10)]: Done   7 out of  10 | elapsed:   29.6s remaining:   12.7s
[Parallel(n_jobs=10)]: Done  10 out of  10 | elapsed:   33.9s finished


now computing mean over runs
now computing mean over time
now computing run and time avg for sub 02, ses 04


[Parallel(n_jobs=15)]: Using backend LokyBackend with 15 concurrent workers.
[Parallel(n_jobs=15)]: Done   2 out of  15 | elapsed:   32.5s remaining:  3.5min
[Parallel(n_jobs=15)]: Done   4 out of  15 | elapsed:   35.9s remaining:  1.6min
[Parallel(n_jobs=15)]: Done   6 out of  15 | elapsed:   39.2s remaining:   58.8s
[Parallel(n_jobs=15)]: Done   8 out of  15 | elapsed:   42.0s remaining:   36.7s
[Parallel(n_jobs=15)]: Done  10 out of  15 | elapsed:   45.0s remaining:   22.5s
[Parallel(n_jobs=15)]: Done  12 out of  15 | elapsed:   48.3s remaining:   12.1s
[Parallel(n_jobs=15)]: Done  15 out of  15 | elapsed:   52.7s finished


now computing mean over runs
now computing mean over time


### step 2: compute warpfield from source to target session

Let's do this for all source and target sessions simultaneously

In [7]:
all_source_fns = []
all_target_fns = []
for sub in subs:
    sj_res_out_dir = os.path.join(res_out_dir,'sub-%s'%sub)
    target_fn = os.path.join(sj_res_out_dir,'mean_over_runs_timemean_ses_%s.nii.gz'%target_session)
    
    for ses in source_sessions[sub]:
        
        source_fn = os.path.join(sj_res_out_dir,'mean_over_runs_timemean_ses_%s.nii.gz'%ses)
        all_source_fns.append(source_fn)
        all_target_fns.append(target_fn)

        
Parallel(n_jobs=len(all_source_fns),verbose=9)(delayed(compute_warp)(all_source_fns[i],all_target_fns[i])  for i in range(len(all_source_fns)))


[Parallel(n_jobs=4)]: Using backend LokyBackend with 4 concurrent workers.
[Parallel(n_jobs=4)]: Done   2 out of   4 | elapsed:  4.3min remaining:  4.3min
[Parallel(n_jobs=4)]: Done   4 out of   4 | elapsed:  6.4min remaining:    0.0s
[Parallel(n_jobs=4)]: Done   4 out of   4 | elapsed:  6.4min finished


[None, None, None, None]

### step 3: apply the appropriate warpfield to full timecourse imgs

In [8]:
for sub in subs:
    
    sj_res_out_dir = os.path.join(res_out_dir,'sub-%s'%sub)

    target_fn = os.path.join(sj_res_out_dir,'mean_over_runs_timemean_ses_%s.nii.gz'%target_session)
    
    for ses in source_sessions[sub]:

        field_file = os.path.join(sj_res_out_dir,'mean_over_runs_timemean_ses_%s_fnirted_field.nii.gz'%ses)
        
        fns = sorted(glob.glob(os.path.join(sj_res_out_dir,'sub-%s_ses-%s*bold_space-%s_preproc_resampled.nii.gz'%(sub,ses,space))))

        Parallel(n_jobs=len(fns),verbose=9)(delayed(apply_warp)(fn,target_fn,field_file) for fn in fns)


[Parallel(n_jobs=13)]: Using backend LokyBackend with 13 concurrent workers.
[Parallel(n_jobs=13)]: Done   2 out of  13 | elapsed:   52.7s remaining:  4.8min
[Parallel(n_jobs=13)]: Done   4 out of  13 | elapsed:  1.1min remaining:  2.5min
[Parallel(n_jobs=13)]: Done   6 out of  13 | elapsed:  1.1min remaining:  1.3min
[Parallel(n_jobs=13)]: Done   8 out of  13 | elapsed:  1.2min remaining:   43.8s
[Parallel(n_jobs=13)]: Done  10 out of  13 | elapsed:  1.2min remaining:   21.4s
[Parallel(n_jobs=13)]: Done  13 out of  13 | elapsed:  1.2min finished
[Parallel(n_jobs=8)]: Using backend LokyBackend with 8 concurrent workers.
[Parallel(n_jobs=8)]: Done   2 out of   8 | elapsed:   46.0s remaining:  2.3min
[Parallel(n_jobs=8)]: Done   3 out of   8 | elapsed:   47.1s remaining:  1.3min
[Parallel(n_jobs=8)]: Done   4 out of   8 | elapsed:   47.1s remaining:   47.1s
[Parallel(n_jobs=8)]: Done   5 out of   8 | elapsed:   47.6s remaining:   28.5s
[Parallel(n_jobs=8)]: Done   6 out of   8 | elapsed: