# Initialization

In [64]:
%matplotlib inline
import nibabel as nb
import nipy
import numpy as np
import matplotlib.pyplot as plt
import nipy.algorithms.registration
from ip_utils2 import *
from scipy import ndimage
import os
from glob import glob
from subprocess import call

def warp_mni_to_img(target, warp, phy_coords):
    ''' Warp (physical-space) MNI-space coords to a target image.

    target     -- nibabel image whose voxel grid the result is expressed in.
    warp       -- nibabel image holding a displacement field whose last axis
                  carries the x, y and z offsets. It is sampled at the MNI
                  physical coords, so presumably it must be defined on an
                  MNI-space grid -- TODO confirm (the round-trip test cell
                  passes acpc_dc2standard.nii.gz).
    phy_coords -- 3xN array of MNI physical (mm) coordinates, one column per
                  point. The caller's array is not modified.

    Returns a 3xN array of (fractional) voxel coordinates in target.
    '''
    # Work on a float copy: the offsets below are applied in place.
    phy_coords = np.array(phy_coords, dtype=np.float64)
    # Position of each point within the warp field's voxel grid.
    warp_img_coords = xform_coords(np.linalg.inv(warp.get_affine()), phy_coords)
    warp_data = warp.get_data()
    # Trilinear lookup of the per-axis offsets (0 outside the field).
    xo = ndimage.map_coordinates(warp_data[...,0], warp_img_coords, order=1)
    yo = ndimage.map_coordinates(warp_data[...,1], warp_img_coords, order=1)
    zo = ndimage.map_coordinates(warp_data[...,2], warp_img_coords, order=1)
    # Same sign convention as make_roi (x subtracted, y/z added), which was
    # matched empirically against FSL applywarp.
    phy_coords[0,:] -= xo
    phy_coords[1,:] += yo
    phy_coords[2,:] += zo
    warped_img_coords = xform_coords(np.linalg.inv(target.get_affine()), phy_coords)
    return warped_img_coords

def warp_img_to_mni(source, warp, img_coords, mni_ref=None):
    ''' Warp source image coords to mni space. If an mni reference image is passed,
        then the coordinates returned are in that image space. If mni_ref is None,
        the returned coords are in MNI physical space.

        source     -- nibabel image that img_coords index into.
        warp       -- nibabel displacement-field image (x/y/z offsets along the
                      last axis); it is sampled at the source's physical coords.
        img_coords -- 3xN array of voxel coords in source, one column per point.
        mni_ref    -- optional nibabel image defining the output voxel grid.

        Returns a 3xN array (physical mm, or mni_ref voxel coords).
    '''
    phy_coords = xform_coords(source.get_affine(), img_coords)
    # Position of each source point within the warp field's voxel grid.
    warp_img_coords = xform_coords(np.linalg.inv(warp.get_affine()), phy_coords)
    warp_data = warp.get_data()
    # Trilinear lookup of the per-axis offsets (0 outside the field).
    xo = ndimage.map_coordinates(warp_data[...,0], warp_img_coords, order=1)
    yo = ndimage.map_coordinates(warp_data[...,1], warp_img_coords, order=1)
    zo = ndimage.map_coordinates(warp_data[...,2], warp_img_coords, order=1)
    # Same sign convention as make_roi (x subtracted, y/z added).
    phy_coords[0,:] -= xo
    phy_coords[1,:] += yo
    phy_coords[2,:] += zo
    # Identity test: `== None` misbehaves for array-like arguments.
    if mni_ref is None:
        return phy_coords
    return xform_coords(np.linalg.inv(mni_ref.get_affine()), phy_coords)

def xform_sl(subcode, sl_file, img_coords=True):
    ''' Load one subject's streamline file and warp every streamline to MNI space.

    subcode    -- subject code (directory name under /data/hcp/data and /hcp).
    sl_file    -- streamline file name inside /data/hcp/data/<subcode>.
    img_coords -- if True, return voxel coords of the 1mm MNI152 template;
                  otherwise return MNI physical (mm) coords.

    Returns a list with one entry per streamline, each a list of [x, y, z] points.
    '''
    out_dir = os.path.join('/data/hcp/data',subcode)
    # NOTE(review): warp_img_to_mni samples the warp at *subject* physical
    # coords, but this loads acpc_dc2standard. make_roi and the round-trip test
    # cell sample standard2acpc_dc at subject coords instead -- confirm which
    # field is intended here.
    warp = nb.load(os.path.join('/hcp',subcode,'MNINonLinear/xfms/acpc_dc2standard.nii.gz'))
    ref = nb.load(os.path.join('/hcp', subcode, 'T1w/T1w_acpc_dc_restore_1.25.nii.gz'))
    # The template is only needed when converting to its voxel grid.
    if img_coords:
        mni = nb.load('/usr/share/fsl/data/standard/MNI152_T1_1mm.nii.gz')
    else:
        mni = None
    sl = load_streamline_file(os.path.join(out_dir, sl_file))
    # Each streamline is an (N, 3) point list; warp_img_to_mni expects 3xN.
    return [warp_img_to_mni(ref, warp, np.array(s).T, mni).T.tolist() for s in sl]
    
def make_roi(target, warp, mni_roi, dilation=0):
    ''' Resample an MNI-space ROI onto the voxel grid of a target image.

    Each target voxel is moved through the displacement field in warp
    (sampled at the voxel's physical position) to its MNI location. The ROI
    is read there with trilinear interpolation. Any non-zero value counts
    as inside; the mask is optionally dilated and then hole-filled.

    target   -- nibabel image defining the output grid and affine.
    warp     -- nibabel displacement field (x/y/z offsets along the last axis).
    mni_roi  -- nibabel ROI image in MNI space.
    dilation -- number of binary-dilation iterations (0 means none).

    Returns a Nifti1Image of int8 0/1 values on target's grid.
    '''
    shape = target.shape
    # 3 x Nvox array of every voxel index of the target, in ij-order.
    voxel_ijk = np.indices(shape[:3]).reshape((3, -1))
    # Physical position of every target voxel ...
    phy_coords = xform_coords(target.get_affine(), voxel_ijk)
    # ... expressed as fractional voxel coords of the warp field.
    warp_ijk = xform_coords(np.linalg.inv(warp.get_affine()), phy_coords)
    field = warp.get_data()
    dx, dy, dz = [ndimage.map_coordinates(field[..., axis], warp_ijk, order=1)
                  for axis in range(3)]
    # FIXME: the sign convention (x subtracted, y/z added) presumably follows
    # from the affine orientation. It reproduces FSL applywarp for the images
    # processed here, but may not be a general solution.
    phy_coords[0, :] -= dx
    phy_coords[1, :] += dy
    phy_coords[2, :] += dz
    # Warped positions as fractional voxel coords of the ROI image.
    roi_ijk = xform_coords(np.linalg.inv(mni_roi.get_affine()), phy_coords)
    mask = ndimage.map_coordinates(mni_roi.get_data(), roi_ijk, order=1).reshape(shape)
    if dilation > 0:
        mask = ndimage.binary_dilation(mask, iterations=dilation)
    mask = ndimage.binary_fill_holes(mask).astype(np.int8)
    return nb.Nifti1Image(mask, target.get_affine())

def make_rois(subcode, roi_names, roi_dilation=None, qa=False):
    ''' Warp a set of MNI-space ROIs onto a subject's diffusion grid.

    For each name in roi_names, /data/ROIs/<name>.nii.gz is resampled with
    make_roi onto the grid of the subject's nodif_brain_mask. The result is
    saved as /data/hcp/data/<subcode>/ROI_<name>.nii.gz.

    subcode      -- subject code (directory name under /hcp and /data/hcp/data).
    roi_names    -- ROI file base names under /data/ROIs.
    roi_dilation -- optional sequence of dilation iteration counts, matched to
                    roi_names by position. ROIs beyond its length, or all ROIs
                    when it is None, are not dilated.
    qa           -- if True, also write a PNG of each ROI overlaid on the
                    subject's 1.25mm T1w image.
    '''
    out_dir = os.path.join('/data/hcp/data',subcode)
    sub_dir = os.path.join('/hcp',subcode)
    mni_lut = nb.load(os.path.join(sub_dir,'MNINonLinear/xfms/standard2acpc_dc.nii.gz'))
    ref = nb.load(os.path.join(sub_dir,'T1w', 'Diffusion', 'nodif_brain_mask.nii.gz'))
    if qa:
        # Background for the QA snapshots; the same image the Data Quality
        # Check cells overlay ROIs on.
        qa_bg = nb.load(os.path.join(sub_dir, 'T1w', 'T1w_acpc_dc_restore_1.25.nii.gz'))
    for i, roi_name in enumerate(roi_names):
        mni_roi = nb.load('/data/ROIs/' + roi_name + '.nii.gz')
        # Reset for every ROI so a value from a previous iteration is never reused.
        dilation = 0
        if roi_dilation is not None and len(roi_dilation) > i:
            dilation = roi_dilation[i]
            print('Dilating %s by %d...' % (roi_name, dilation))
        # Pass this ROI's iteration count, not the whole roi_dilation sequence.
        roi = make_roi(ref, mni_lut, mni_roi, dilation)
        nb.save(roi, os.path.join(out_dir, 'ROI_'+roi_name+'.nii.gz'))
        if qa:
            # Slice through the ROI's voxel centroid, mapped through its affine.
            sl = xform_coord(roi.get_affine(), np.array(np.where(roi.get_data())).mean(axis=1)).round()
            outfile = os.path.join(out_dir, subcode+'_ROI_'+roi_name+'.png')
            show_brain(qa_bg, sl=sl, overlay_file=roi, overlay_clip=[1,2], outfile=outfile)
        
def make_rois_fsl(subcode, roi_names, roi_dir='/data/ROIs/'):
    ''' FSL counterpart of make_rois: warp MNI-space ROIs with `applywarp`.

    For each name in roi_names, writes
    /data/hcp/data/<subcode>/ROI_<name>_fsl.nii.gz. The applywarp reference
    is the subject's nodif_brain_mask and the warp is standard2acpc_dc.
    The applywarp executable must be on PATH. A failed run (non-zero exit
    status) is reported and the remaining ROIs are still processed.
    '''
    out_dir = os.path.join('/data/hcp/data',subcode)
    warp = os.path.join('/hcp',subcode,'MNINonLinear/xfms/standard2acpc_dc.nii.gz')
    ref = os.path.join('/hcp',subcode,'T1w', 'Diffusion', 'nodif_brain_mask.nii.gz')
    for roi_name in roi_names:
        # No extension is appended; presumably FSL resolves the image suffix itself.
        infile = os.path.join(roi_dir,roi_name)
        outfile = os.path.join(out_dir, 'ROI_'+roi_name+'_fsl.nii.gz')
        status = call(['applywarp', '-i', infile, '-r', ref, '-w', warp, '-o', outfile])
        if status != 0:
            print('applywarp failed (exit status %d) for %s' % (status, roi_name))

# Test warping functions

In [127]:
# Round-trip check: a hand-picked voxel of subject 100408's 1.25mm T1w grid is
# warped to MNI physical space and back; Img_out is expected to match Img_in.
img_coords_orig = np.array([[28.,122.,69.]],dtype=np.float32)
#img_coords_orig = sl[0][-1:,:]

subcode = '100408'
# Naming follows usage below: warp_to_img (acpc_dc2standard) is passed to
# warp_mni_to_img, warp_to_mni (standard2acpc_dc) to warp_img_to_mni.
warp_to_img = nb.load(os.path.join('/hcp',subcode,'MNINonLinear/xfms/acpc_dc2standard.nii.gz'))
warp_to_mni = nb.load(os.path.join('/hcp',subcode,'MNINonLinear/xfms/standard2acpc_dc.nii.gz'))
mni = nb.load('/usr/share/fsl/data/standard/MNI152_T1_1mm.nii.gz')
ref = nb.load(os.path.join('/hcp', subcode, 'T1w/T1w_acpc_dc_restore_1.25.nii.gz'))

# The warp functions take 3xN arrays, hence the transposes.
mni_coords = warp_img_to_mni(ref, warp_to_mni, img_coords_orig.T).T
img_coords = warp_mni_to_img(ref, warp_to_img, mni_coords.copy().T).T
print('Img_in: '+str(img_coords_orig.round())+', MNI: '+str(mni_coords.round())+', Img_out: '+str(img_coords.round()))

# Warp the first streamline of a TrackVis file (read in voxel space); keep only
# the point arrays from each streamline record.
sl_file = os.path.join('/data/hcp/data/',subcode,'RVLPFC2FIRSTamyg_bigRight_optimized.trk')
sl_trk,hdr = nb.trackvis.read(sl_file, points_space='voxel')
sl = [s[0] for s in sl_trk]

# NOTE(review): this uses warp_to_img (acpc_dc2standard) for the image->MNI
# direction, whereas the round-trip above uses warp_to_mni -- confirm which
# is intended.
sl_mni_coords = warp_img_to_mni(ref, warp_to_img, sl[0].T).T
print('Img_in: '+str(sl[0][-1:,:].round())+', MNI: '+str(sl_mni_coords[-1:,:].round()))

Img_in: [[  28.  122.   69.]], MNI: [[ 60.  28.  14.]], Img_out: [[  28.  122.   69.]]
Img_in: [[  32.  132.   62.]], MNI: [[ 48.  39.   7.]]


In [110]:
# Last point of the first streamline and its warped MNI position, rounded.
(sl[0][-1:,:].round(),sl_mni_coords[-1:,:].round())

(array([[  30.,  114.,   65.]], dtype=float32), array([[ 49.,  20.,   4.]]))

In [116]:
# For comparison: map the MNI physical coords through ref's inverse affine
# directly, without applying any warp.
xform_coords(np.linalg.inv(ref.get_affine()), mni_coords.T)

array([[  29.66108475],
       [ 114.17694321],
       [  64.82564278]])

In [117]:
# NOTE(review): relies on the globals target, warp and phy_coords, which are
# only defined by the "Test mni warping functions" cells further down.
warp_mni_to_img(target, warp, phy_coords)

# Merge MNI-space track files to summarize group data

In [17]:
from dipy.segment import select
from dipy_run import *
from dipy.segment.clustering import QuickBundles
from dipy.segment.metric import ResampleFeature
from dipy.segment.metric import AveragePointwiseEuclideanMetric
# QuickBundles setup: every streamline is resampled to 50 points, compared
# with the average point-wise Euclidean distance, and clustered with a
# threshold of 15 (in the streamlines' coordinate units).
feature = ResampleFeature(nb_points=50)
metric = AveragePointwiseEuclideanMetric(feature)
qb = QuickBundles(metric=metric, threshold=15.0)

def save_to_trackvis(streamlines, outname, dims, pixdim):
    ''' Write streamlines, given in voxel coordinates, to a TrackVis file.

    streamlines -- iterable of point arrays, one per streamline.
    outname     -- output file path.
    dims        -- volume dimensions stored in the header 'dim' field.
    pixdim      -- voxel sizes stored in the header 'voxel_size' field.
    '''
    header = nb.trackvis.empty_header()
    header['dim'] = dims
    header['voxel_size'] = pixdim
    header['voxel_order'] = 'LAS'
    # trackvis.write takes (points, scalars, properties) records.
    records = ((points, None, None) for points in streamlines)
    nb.trackvis.write(outname, records, header, points_space='voxel')

In [47]:
# Choose the track file to summarize (alternatives commented out) and list,
# sorted, the subject directories that contain it; the count is displayed.
trk_file = 'amyg_rifg12.trk'
#trk_file = 'amyg_rm1.trk'
#trk_file = 'amyg_rsfg.trk'
subcodes = sorted([os.path.basename(d) for d in glob('/data/hcp/data/*')
                   if os.path.exists(os.path.join(d,trk_file))])
len(subcodes)

300

## Get the centroid of the largest cluster for each subject

In [48]:
# Load the fibers for each subject, transforming to MNI space. With the default
# img_coords=True, xform_sl returns voxel coords of the 1mm MNI152 template.
# slx holds one list of streamlines per subject, in subcodes order.
slx = []
for sc in subcodes:
    slx.append(xform_sl(sc, trk_file))

In [None]:
# Cluster each subject's fibers with the QuickBundles instance `qb` defined
# above; clusters gets one clustering result per subject, in slx order.
clusters = []
for sl in slx:
    clusters.append(qb.cluster([np.array(s) for s in sl]))

In [50]:
# Find the largest cluster for each subject and keep its centroid; subjects
# with no clusters are skipped. The sizes are built as a list because under
# Python 3, np.array(map(...)) wraps the iterator as a 0-d object array whose
# argmax is always 0.
slx_all = []
for c in clusters:
    if len(c) > 0:
        clust_num = np.argmax([len(cluster) for cluster in c])
        slx_all.append(c.centroids[clust_num])

In [51]:
# Save every subject's largest-cluster centroid into one .trk file whose header
# takes its dimensions and voxel sizes from the 1mm MNI152 template.
ni = nb.load('/usr/share/fsl/data/standard/MNI152_T1_1mm.nii.gz')
dims, pixdim = ni.shape, ni.header.get_zooms()
save_to_trackvis(slx_all, '/data/hcp/mni_all_centroids_' + trk_file, dims, pixdim)

# STOP HERE

## Generate the ROIs

In [71]:
# ROI set to generate (alternatives commented out), then the sorted list of
# subjects that have the tractogram but no ROI_<first ROI>.nii.gz yet.
trk_file = '2M_SIFT.trk'
roi_names = ['RVLPFC_12mm_54_27_12','FIRSTamyg_smRight']
#roi_names = ['RVLPFC_15mm_54_27_12','FIRSTamyg_bigRight']
#roi_names = ['RSFG_10mm_gm','RM1_gm']
subcodes = sorted([os.path.basename(d) for d in glob('/data/hcp/data/*')
                   if os.path.exists(os.path.join(d,trk_file))
                   #and os.path.exists(os.path.join(d,'dwi_MD.nii.gz'))
                   and not os.path.exists(os.path.join(d,'ROI_'+roi_names[0]+'.nii.gz'))])
len(subcodes)

146

In [None]:
# Make the ROIs (without dilation) for every listed subject. The existence
# check on the first ROI file means re-running the cell skips finished subjects.
for subcode in subcodes:
    #if not os.path.exists(os.path.join('/data/hcp/data',subcode,'ROI_FIRSTamyg_bigRight_fsl.nii.gz')):
    #make_rois_fsl(subcode, roi_names)
    if not os.path.exists(os.path.join('/data/hcp/data',subcode,'ROI_'+roi_names[0]+'.nii.gz')):
        print('making ROIs for ' + subcode + '...')
        make_rois(subcode, roi_names)
    else:
        print('skipping ' + subcode)

## Test mni warping functions

In [None]:
# warp_mni_to_img on two MNI physical-space points (3x2 array, one column each).
phy_coords = np.array([[-24., -1., -21.],
                       [54., 33., 10.]]).T
# NOTE(review): this passes standard2acpc_dc for the MNI->image direction,
# whereas the round-trip cell above passes acpc_dc2standard -- confirm.
warp = nb.load('/hcp/115320/MNINonLinear/xfms/standard2acpc_dc.nii.gz')
target = nb.load('/data/hcp/data/115320/dwi_MD.nii.gz')

warped_img_coords = warp_mni_to_img(target, warp, phy_coords)
warped_img_coords

In [None]:
# Warp the voxel coords from the previous cell back to MNI physical space
# (no mni_ref, so the result is in mm).
source = nb.load('/data/hcp/data/115320/dwi_MD.nii.gz')
warp = nb.load('/hcp/115320/MNINonLinear/xfms/acpc_dc2standard.nii.gz')

warp_img_to_mni(source, warp, warped_img_coords)

# Data Quality Check
Display the ROIs from the TMS stimulation study, normalized onto the HCP subjects' data, so that their alignment can be inspected visually.

In [None]:
# Overlay subject 100307's amygdala ROI on its 1.25mm T1w image. sl is the ROI's
# voxel centroid mapped through the ROI affine (presumably the slice position
# show_brain displays).
roi = nb.load('/data/hcp/data/100307/ROI_FIRSTamyg_smRight.nii.gz')
t1 = nb.load('/hcp/100307/T1w/T1w_acpc_dc_restore_1.25.nii.gz')
sl = xform_coord(roi.get_affine(), np.argwhere(roi.get_data()).mean(axis=0)).round()
show_brain(t1, sl=sl, overlay_file=roi, clip=99, overlay_clip=[1,2])

In [None]:
# Sorted list of subjects that have both the tractogram and the
# FIRSTamyg_bigRight ROI; the count is displayed.
trk_file = '2M_SIFT.trk'
subcodes = sorted([os.path.basename(d) for d in glob('/data/hcp/data/*')
                   if os.path.exists(os.path.join(d,trk_file))
                   and os.path.exists(os.path.join(d,'ROI_FIRSTamyg_bigRight.nii.gz'))])
len(subcodes)

In [None]:
# Write a QC snapshot for every subject and ROI, comparing the Python-warped
# ROIs with their FSL (*_fsl) counterparts, overlaid on the 1.25mm T1w image.
roi_names = ['RVLPFC_15mm_54_27_12','FIRSTamyg_bigRight','RVLPFC_15mm_54_27_12_fsl','FIRSTamyg_bigRight_fsl']
outdir = '/data/hcp/data/'
for sc in subcodes:
    # Alternative background: the subject's MD map.
    #md = nb.load(os.path.join(outdir, sc, 'dwi_MD.nii.gz'))
    md = nb.load(os.path.join('/hcp', sc, 'T1w/T1w_acpc_dc_restore_1.25.nii.gz'))
    for roi_name in roi_names:
        roi_file = os.path.join(outdir, sc, 'ROI_' + roi_name + '.nii.gz')
        # The *_fsl ROIs come from make_rois_fsl, which may not have been run
        # for every subject; report and skip missing files instead of aborting.
        if not os.path.exists(roi_file):
            print('missing ' + roi_file)
            continue
        roi = nb.load(roi_file)
        # Slice through the ROI's voxel centroid, mapped through its affine.
        sl = xform_coord(roi.get_affine(), np.array(np.where(roi.get_data())).mean(axis=1)).round()
        # PNGs go into the subject's directory, as with make_rois(qa=True).
        outfile = os.path.join(outdir, sc, sc+'_ROI_'+roi_name+'.png')
        show_brain(md, sl=sl, overlay_file=roi, clip=99, overlay_clip=[1,2], outfile=outfile)