In [1]:
# An ipcluster controller (with engines) must already be running before this
# cell is executed; Client() only connects to it, it does not start one.
from ipyparallel import Client
rc = Client()

In [2]:
# Repetition time of the EPI runs. Passed to Realign4d (tr=TR) in
# estimate_motion and used by compute_qa to build the time axis.
# NOTE(review): presumably in seconds -- confirm against the acquisition protocol.
TR = 2

In [3]:
def mask(d, raw_d=None, nskip=3, mask_bad_end_vols=True):
    """Mask non-brain voxels, the first `nskip` volumes and corrupt trailing volumes.

    Parameters
    ----------
    d : ndarray, 4-D
        Time series; the last axis indexes volumes and axis 2 indexes slices.
    raw_d : ndarray, optional
        Data with the same layout as `d` that is searched for empty slices
        instead of `d`. Motion correction might have interpolated empty
        slices so that they are no longer exactly zero, so the raw data is
        the better place to look for them.
    nskip : int
        Number of leading volumes that are always masked and that are left
        out of the mean image used to compute the brain mask.
    mask_bad_end_vols : bool
        If True, volumes containing an all-zero slice are masked, but only
        when they form a run that is contiguous to the end of the series.
        NOTE(review): a 2015.10.29 note says this check caused problems with
        some non-mux EPI scans that have empty slices at the top of the brain
        and was to be disabled, yet the default here is True -- confirm.

    Returns
    -------
    brain : numpy.ma.MaskedArray
        `d` with non-brain voxels and all rejected volumes masked.
    good_vols : ndarray of bool
        One entry per volume; True where the volume was not rejected.
    """
    n_vols = d.shape[3]
    mn = d[:, :, :, nskip:].mean(3)
    # Keyword arguments: the positional order of median_otsu's parameters is
    # not the same in every dipy release, the keyword names are.
    _, brain_mask = median_otsu(mn, median_radius=3, numpass=2)

    # True means "masked out": everything in the skipped volumes, and every
    # non-brain voxel in the remaining volumes.
    voxel_mask = np.empty(d.shape, dtype=bool)
    voxel_mask[:, :, :, :nskip] = True
    voxel_mask[:, :, :, nskip:] = (brain_mask == False)[:, :, :, np.newaxis]

    # Defined up front so that it also exists when the bad-volume check is off.
    mask_vols = np.zeros(n_vols, dtype=bool)
    if mask_bad_end_vols:
        # Some runs have corrupt volumes at the end (e.g., mux scans that are
        # stopped prematurely). Prefer the raw data for detecting them.
        source = d if raw_d is None else raw_d
        slice_max = source.max(0).max(0)
        bad = np.any(slice_max == 0, axis=0)
        # We don't want to miss a bad volume somewhere in the middle, as that
        # could be a valid artifact. So only mask bad vols that are contiguous
        # to the end: mask_vols[i] is True iff every volume from i on is bad.
        mask_vols = np.logical_and.accumulate(bad[::-1])[::-1]

    # Mask out the skip volumes at the beginning
    mask_vols[0:nskip] = True
    voxel_mask[:, :, :, mask_vols] = True
    brain = np.ma.masked_array(d, mask=voxel_mask)
    good_vols = np.logical_not(mask_vols)
    return brain, good_vols

In [4]:
def estimate_motion(nifti_image):
    """Realign a 4-D image and summarise the estimated head motion.

    A copy of a volume near the middle of the run is prepended to the series
    so that the realignment effectively uses it as the reference; the
    prepended volume is dropped again from every returned quantity.

    Parameters
    ----------
    nifti_image : 4-D image accepted by nb.four_to_three / Realign4d.

    Returns
    -------
    aligned : the resampled (realigned) series without the prepended reference.
    abs_disp : ndarray
        Mean displacement of each volume relative to the identity transform.
    rel_disp : ndarray
        Mean displacement of each volume relative to the previous volume
        (0.0 for the first volume).
    transrot : ndarray, shape (n_volumes, 6)
        Per-volume translation followed by rotation parameters, as reported
        by the estimated transforms.
    """
    def mean_displacement(T_error, R=80.):
        # Mean displacement over a sphere for an affine difference matrix.
        # See http://www.fmrib.ox.ac.uk/analysis/techrep/tr99mj1/tr99mj1/node5.html
        # R is the radius of the spherical head assumption (in mm).
        # The center of the volume is assumed to be 0,0,0 in world coordinates.
        # Note: it might be better to use the center of mass of the brain mask.
        xc = np.zeros(3)
        A = T_error[0:3, 0:3]
        t = T_error[0:3, 3]
        offset = t + A.dot(xc)
        return float(np.sqrt(R**2. / 5 * np.trace(A.T.dot(A)) + offset.dot(offset)))

    n_vols = nifti_image.shape[3]
    # We want to use the middle time point as the reference. But the algorithm
    # doesn't allow that, so fake it. Floor division keeps the index an
    # integer; the clamp keeps it in range for one- or two-volume runs.
    ref_vol = min(n_vols // 2 + 1, n_vols - 1)

    # BEGIN STDOUT SUPPRESSION
    # try/finally guarantees stdout is restored (and the devnull handle
    # closed) even when the registration raises.
    actualstdout = sys.stdout
    devnull = open(os.devnull, 'w')
    sys.stdout = devnull
    try:
        ims = nb.four_to_three(nifti_image)
        reg = Realign4d(nb.concat_images([ims[ref_vol]] + ims), tr=TR)
        reg.estimate(loops=3)  # default: loops=5
        aligned = reg.resample(0)[:, :, :, 1:]
    finally:
        sys.stdout = actualstdout
        devnull.close()
    # END STDOUT SUPPRESSION

    identity = np.eye(4)
    abs_disp = []
    rel_disp = []
    transrot = []
    prev_affine = None
    # skip the first one, since it's the reference volume
    for T in reg._transforms[0][1:]:
        transrot.append(T.translation.tolist() + T.rotation.tolist())
        cur_affine = T.as_affine()
        abs_disp.append(mean_displacement(cur_affine - identity))
        if prev_affine is None:
            rel_disp.append(0.0)
        else:
            rel_disp.append(mean_displacement(cur_affine - prev_affine))
        prev_affine = cur_affine
    return aligned, np.array(abs_disp), np.array(rel_disp), np.array(transrot)

In [5]:
def find_spikes(d, spike_thresh):
    """Two-pass z-score spike detection on per-slice mean time courses.

    `d` is a 4-D masked array; averaging over the first two axes leaves a
    (slice, volume) table. Entries whose absolute z-score along the volume
    axis exceeds `spike_thresh` are spikes. The spikes found in the first
    pass are masked in `d` (the mask of `d` is modified in place) and the
    z-scores are recomputed with a standard deviation that excludes them,
    which catches smaller spikes that may have been swamped by big ones.

    Returns a tuple (spike_inds, t_z): spike_inds is an array with one
    (slice, volume) row per spike found in either pass; t_z is the
    second-pass z-score table.
    """
    def as_column(per_slice):
        # one value per slice -> column that broadcasts across the volume axis
        return np.atleast_2d(per_slice).T

    slice_ts = d.mean(axis=0).mean(axis=0)
    centered = slice_ts - as_column(slice_ts.mean(axis=1))

    # First pass: spread estimated from all volumes.
    first_pass = np.abs(centered / as_column(slice_ts.std(axis=1))) > spike_thresh
    slice_idx, vol_idx = first_pass.nonzero()

    # Second pass: hide the first-pass spikes and re-estimate the spread.
    # NOTE(review): only the standard deviation is re-estimated; the mean
    # subtracted above still includes the first-pass spikes -- confirm intended.
    d.mask[:, :, slice_idx, vol_idx] = True
    clean_std = d.mean(axis=0).mean(axis=0).std(axis=1)
    t_z = centered / as_column(clean_std)

    flagged = np.logical_or(first_pass, np.abs(t_z) > spike_thresh)
    return (np.transpose(flagged.nonzero()), t_z)

In [6]:
def compute_qa(ni, tr, spike_thresh=6., nskip=4):
    """Compute the QA measures for one 4-D image.

    Steps: brain-mask the data (skipping the first `nskip` volumes), detect
    spikes on the data with the global mean signal removed, realign the
    series, and compute a temporal SNR map on the realigned data after
    removing a 3rd-order polynomial trend of the global signal.

    Returns, in this order: aligned, abs_disp, rel_disp, transrot (rotation
    columns 3 onwards multiplied by 180/pi, i.e. converted to degrees), tsnr,
    global_ts, t_z, spike_inds.
    """
    brain, good_vols = mask(ni.get_data(), nskip=nskip)
    n_vols = brain.shape[3]
    t = np.arange(0., n_vols) * tr

    # Global mean signal of the unaligned data; subtracted before the
    # simple z-score-based spike detection.
    global_ts = brain.mean(0).mean(0).mean(0)
    demeaned = brain - global_ts
    spike_inds, t_z = find_spikes(demeaned, spike_thresh)

    # Temporal SNR is computed on motion-corrected data, reusing the mask
    # derived from the unaligned data.
    aligned, abs_disp, rel_disp, transrot = estimate_motion(ni)
    brain_aligned = np.ma.masked_array(aligned.get_data(), brain.mask)

    # Slow drift: cubic fit to the global signal over the good volumes only,
    # evaluated at every time point.
    aligned_global_ts = brain_aligned.mean(0).mean(0).mean(0)
    drift_fit = np.polyfit(t[good_vols], aligned_global_ts[good_vols], 3)
    global_trend = np.poly1d(drift_fit)(t)
    detrended_std = (brain_aligned - global_trend).std(axis=3)
    tsnr = brain_aligned.mean(axis=3) / detrended_std

    # Rotation columns: radians -> degrees (in place).
    transrot[:, 3:] *= 180. / np.pi

    return aligned, abs_disp, rel_disp, transrot, tsnr, global_ts, t_z, spike_inds

In [12]:
def plot_figs(im_id, median_tsnr,aligned,abs_disp,rel_disp,transrot, tsnr,global_ts,t_z,spike_inds):
    """Write the QA figures for one image into a directory named after it.

    The output directory is `im_id` with its last 7 characters removed
    (the '.nii.gz' suffix of the file names used in this notebook); it is
    created if missing. Four PNG files are written there:
    displacement.png, translations.png, rotations.png and signal.png.

    Parameters
    ----------
    im_id : str
        Path of the image the measures were computed from.
    median_tsnr : number
        Shown in the title of the signal plot.
    abs_disp, rel_disp : sequences
        Absolute and relative displacement per volume.
    transrot : ndarray, shape (n_volumes, 6)
        Columns 0-2 are translations, columns 3-5 rotations. The rotations
        must already be in degrees, which is what compute_qa returns; they
        are plotted as given and the array is not modified.
    global_ts : masked array
        Global signal; only its non-masked entries are plotted (z-scored).
    aligned, tsnr, t_z, spike_inds
        Accepted so the signature matches compute_qa's output; not used.
    """
    out_dir = im_id[:-7]
    if not os.path.exists(out_dir):
        os.mkdir(out_dir)

    def long_frame(value_name, series, labels):
        # Stack several equally meaningful time courses into the long-format
        # table expected by sns.tsplot (one row per value).
        values, kinds, times = [], [], []
        for label, column in zip(labels, series):
            column = list(column)
            values += column
            kinds += [label] * len(column)
            times += list(range(len(column)))
        return pd.DataFrame({value_name: values,
                             'kind': kinds,
                             'time': times,
                             'subject': [0] * len(values)})

    def save_tsplot(frame, value, fname, time='time', unit='subject',
                    condition=None, title=None):
        # Draw one time-series plot, save it into out_dir and close the figure.
        sns.tsplot(value=value, time=time, unit=unit, condition=condition, data=frame)
        if title is not None:
            plt.title(title)
        plt.savefig(out_dir + '/' + fname)
        plt.close()

    #plot displacement
    disp = long_frame('Displacement (mm)', [abs_disp, rel_disp], ['abs', 'rel'])
    save_tsplot(disp, 'Displacement (mm)', 'displacement.png', condition='kind')

    #plot translations
    trans = long_frame('Translations (mm)',
                       [transrot[:, 0], transrot[:, 1], transrot[:, 2]],
                       ['x', 'y', 'z'])
    save_tsplot(trans, 'Translations (mm)', 'translations.png', condition='kind')

    #plot rotations
    # No unit conversion here: compute_qa already converted the rotation
    # columns to degrees, and converting again would scale them twice and
    # silently modify the caller's array.
    rot = long_frame('Rotations (deg)',
                     [transrot[:, 3], transrot[:, 4], transrot[:, 5]],
                     ['roll', 'pitch', 'yaw'])
    save_tsplot(rot, 'Rotations (deg)', 'rotations.png', condition='kind')

    #Plot signal intensity
    # np.ma.compressed keeps only the non-masked samples and also copes with
    # an input that carries no explicit mask.
    ts = scipy.stats.zscore(np.ma.compressed(global_ts))
    signal = pd.DataFrame({'Intensity (z)': ts,
                           'unit': ['TR'] * len(ts),
                           'TR': list(range(len(ts)))})
    save_tsplot(signal, 'Intensity (z)', 'signal.png', time='TR', unit='unit',
                title='TSNR = ' + str(median_tsnr))


In [13]:
def run_QA(im_id):
    """Load the image at path `im_id`, compute its QA measures and save the figures."""
    qa_results = compute_qa(nb.load(im_id), TR)
    # compute_qa returns (aligned, abs_disp, rel_disp, transrot, tsnr,
    # global_ts, t_z, spike_inds); the tSNR map is the fifth entry.
    tsnr_map = qa_results[4]
    plot_figs(im_id, np.ma.median(tsnr_map), *qa_results)

In [14]:
#####################################################################
###Here you set the files you want to run and it will loop through###
###You may need to adjust this a bit if the filenames have changed###
#####################################################################
# Folder that holds one sub-folder per subject/session.
home_dir = '/Users/ianballard/Dropbox/Decision Neuroscience Lab/fMRI_Data/Habitization/'
# Sub-folders to process.
subs = ['HAB02 Session 1', 'HAB03 Session 1']
# Run numbers 1..6 as strings; each run is stored as EPI<run>.nii.gz.
scans = list(map(str, range(1, 7)))
# One path per (run, subject) pair, ordered by run first and subject second.
files = ['{0}{1}/EPI{2}.nii.gz'.format(home_dir, sub, scan)
         for scan in scans
         for sub in subs]

In [None]:
# Blocking view on the engines selected by rc[0:3].
dview = rc[0:3]
dview.block = True

# run_QA executes on the engines, so every notebook-level name it reaches
# (the helper functions and TR) has to exist in the engines' namespace.
dview.push({
    'home_dir': home_dir,
    'plot_figs': plot_figs,
    'compute_qa': compute_qa,
    'find_spikes': find_spikes,
    'mask': mask,
    'TR': TR,
    'estimate_motion': estimate_motion,
})

# Aliased imports are executed as source on the engines so the short names
# used by the helpers (np, nb, sns, pd) are defined there.
for engine_stmt in ("import numpy as np",
                    "import nibabel as nb",
                    "import seaborn as sns",
                    "import pandas as pd"):
    dview.execute(engine_stmt)

# Plain imports, performed both locally and on the engines.
with dview.sync_imports():
    import os
    import numpy
    import matplotlib
    from nipy.algorithms.registration import affine,Realign4d
    from dipy.segment.mask import median_otsu
    import sys
    import scipy.stats

dview.execute("%matplotlib inline")
dview.execute("import matplotlib.pyplot as plt")

# Distribute the files over the engines and wait for all of them to finish.
dview.map_sync(run_QA,files)

importing os on engine(s)
importing numpy on engine(s)
importing matplotlib on engine(s)
importing glob from glob on engine(s)
importing affine,Realign4d from nipy.algorithms.registration on engine(s)

//anaconda/lib/python2.7/site-packages/skimage/filter/__init__.py:6: skimage_deprecation: The `skimage.filter` module has been renamed to `skimage.filters`.  This placeholder module will be removed in v0.13.
  warn(skimage_deprecation('The `skimage.filter` module has been renamed '



importing median_otsu from dipy.segment.mask on engine(s)
importing sys on engine(s)
importing json on engine(s)
importing argparse on engine(s)
importing time on engine(s)
importing shutil on engine(s)
importing nipype on engine(s)