In [3]:
import os.path as op
import matplotlib.pyplot as plt
import numpy as np
import nibabel as nib
import dipy.data as dpd
from dipy.data import fetcher
import dipy.tracking.utils as dtu
import dipy.tracking.streamline as dts
from dipy.io.streamline import save_tractogram, load_tractogram
from dipy.stats.analysis import afq_profile, gaussian_weights
from dipy.io.stateful_tractogram import StatefulTractogram
from dipy.io.stateful_tractogram import Space
from os.path import abspath, expanduser, join
import AFQ.utils.streamlines as aus
import AFQ.data as afd
import AFQ.tractography as aft
import AFQ.registration as reg
import AFQ.dti as dti
import AFQ.segmentation as seg
from AFQ.utils.volume import patch_up_roi
from pandas import Series, read_csv, to_numeric

ModuleNotFoundError: No module named 'dipy.stats'

In [None]:
# Machine-dependent path configuration.
# `local` is a string flag compared against 'True' below (not a bool); on the
# laptop it selects the laptop-local workflow/data folders.
local = 'False'
user = expanduser('~')

# Project root: the mounted copy on the laptop, otherwise the cluster path.
if user == '/Users/lucindasisk':
    home = join(user, 'Desktop/Milgram/candlab')
else:
    home = '/gpfs/milgram/project/gee_dylan/candlab'

# The raw BIDS data sits at the same location relative to `home` everywhere.
raw_dir = join(home, 'data/mri/bids_recon/shapes')

if user == '/Users/lucindasisk' and local == 'True':
    # Laptop with local storage: workflows and data go to the DATA folder.
    laptop = '/Users/lucindasisk/Desktop/DATA'
    proc_dir = join(home, 'analyses/shapes/dwi')
    workflow_dir = join(laptop, 'workflows_ls')
    data_dir = join(laptop, 'data_ls')
else:
    # Laptop without local storage, or any other machine: everything lives
    # under the project root.
    proc_dir = join(home, 'analyses/shapes/dwi/data')
    workflow_dir = join(home, 'analyses/shapes/dwi/workflows')
    data_dir = join(home, 'analyses/shapes/dwi/data')

# Read the subject list: a space-separated text file without a header row;
# the first column is kept as a plain Python list.
subject_list_file = home + '/scripts/shapes/mri/dwi/shapes_dwi_subjList_08.07.2019.txt'
subject_info = read_csv(subject_list_file, sep=' ', header=None)
subject_list = subject_info[0].tolist()


In [2]:
# Subject handled by this cell; every path below is derived from it so that
# switching subject means changing a single value.
subject_id = 'sub-A698'

# Per-subject output folder for the tract reconstruction. It is built from
# proc_dir (like the eddy-corrected inputs below) instead of a laptop-only
# absolute path, so the cell also works with the cluster configuration.
# With the laptop configuration and local == 'False' this resolves to the
# same folder that was previously hard-coded.
# NOTE(review): with local == 'True', proc_dir has no trailing 'data'
# component, so this folder differs from the old literal — confirm intent.
origin = join(proc_dir, '5_tract_Reconstruction', subject_id)

# Inputs: b-values from the raw BIDS data, rotated b-vectors and the
# eddy-corrected DWI volume from the eddy-correction output folder.
bval = join(raw_dir, subject_id, 'ses-shapesV1', 'dwi',
            subject_id + '_ses-shapesV1_dwi.bval')
eddy_dir = join(proc_dir, '3_Eddy_Corrected', subject_id)
bvec = join(eddy_dir, 'eddy_corrected.eddy_rotated_bvecs')
dtifile = join(eddy_dir, 'eddy_corrected.nii.gz')
# Reuse dtifile rather than rebuilding the same path by string concatenation.
img = nib.load(dtifile)

print("Calculating DTI...")
# Fit the tensor model only when no FA map exists yet in the output folder;
# otherwise point at the files from the earlier run.
if not op.exists(join(origin, 'dti_FA.nii.gz')):
    dti_params = dti.fit_dti(dtifile, bval, bvec,
                             out_dir=origin)
else:
    dti_params = {'FA': join(origin, 'dti_FA.nii.gz'),
                  'params': join(origin, 'dti_params.nii.gz')}

# FA volume as an array; used later when extracting tract profiles.
FA_img = nib.load(dti_params['FA'])
FA_data = FA_img.get_fdata()

# Template ROIs and probability maps, keyed by name.
templates = afd.read_templates()
bundle_names = ["UNC", "CGC"]

# One entry per bundle and hemisphere (e.g. "UNC_R", "UNC_L"). Each entry
# lists the two ROIs, a rule flag per ROI, the probability map, and whether
# the bundle crosses the midline.
bundles = {}
for name in bundle_names:
    for hemi in ('_R', '_L'):
        bundles[name + hemi] = dict(
            ROIs=[templates[name + roi + hemi] for roi in ('_roi1', '_roi2')],
            rules=[True, True],
            prob_map=templates[name + hemi + '_prob_map'],
            cross_midline=False,
        )

print("Registering to template...")
MNI_T2_img = dpd.read_mni_template()

# The mapping is cached in the current working directory; the path is spelled
# once so the existence check, the write and the read always agree.
# NOTE(review): the cache file name is not subject-specific — confirm that
# reusing it across subjects is intended.
mapping_file = './mapping.nii.gz'
if not op.exists(mapping_file):
    # No cached mapping: register this subject's DWI data to the template.
    # The previous version referenced hardi_fbval / hardi_fbvec / hardi_fdata,
    # which are never defined in this notebook and raised a NameError; the
    # subject's own b-values, b-vectors and DWI file are used instead.
    import dipy.core.gradients as dpg
    gtab = dpg.gradient_table(bval, bvec)
    warped_hardi, mapping = reg.syn_register_dwi(dtifile, gtab)
    reg.write_mapping(mapping, mapping_file)
else:
    # Cached mapping found: load it against the DWI image and the template.
    mapping = reg.read_mapping(mapping_file, img, MNI_T2_img)



# Load the subject's tractogram from the output folder, with the DWI image
# as the reference.
tck_file = join(origin, 'SIFT_msCSD_brain_tracktography.tck')
tg = load_tractogram(tck_file, img)

# Apply the inverse of the image affine to the loaded streamlines
# (presumably moving them from scanner/world space into voxel space —
# the save step later applies img.affine again before writing).
inverse_affine = np.linalg.inv(img.affine)
streamlines = dts.Streamlines(
    dtu.transform_tracking_output(tg.streamlines, inverse_affine))

# ---- Segmentation ----
print("Segmenting fiber groups...")

# Keyword inputs for the segmenter: the DWI file with its b-values and
# b-vectors, plus the template mapping and the registration template.
segment_inputs = dict(
    fdata=dtifile,
    fbval=bval,
    fbvec=bvec,
    mapping=mapping,
    reg_template=MNI_T2_img,
)

segmentation = seg.Segmentation()
segmentation.segment(bundles, streamlines, **segment_inputs)

# Segmentation result, indexed by bundle name in the steps that follow.
fiber_groups = segmentation.fiber_groups

print("Cleaning fiber groups...")

# NOTE: a disabled earlier variant of this step indexed
# fiber_groups[bundle]['sl'] / ['idx'] and called
# seg.clean_fiber_group(..., return_idx=True). The active code below treats
# each fiber_groups entry as the streamlines themselves.

# Replace every segmented bundle with its cleaned version.
for bundle in bundles:
    cleaned = seg.clean_fiber_group(fiber_groups[bundle])
    fiber_groups[bundle] = cleaned

# Report the size of each fiber group and write it to the output folder as
# a .trk file, after applying the image affine to the streamlines.
for bundle_name, group in fiber_groups.items():
    print(bundle_name, len(group))

    group_rasmm = dtu.transform_tracking_output(group, img.affine)
    sft = StatefulTractogram(group_rasmm, img, Space.RASMM)

    trk_file = origin + '/msCSD_%s_afq.trk' % bundle_name
    save_tractogram(sft, trk_file, bbox_valid_check=False)
 #############   
# NOTE: a disabled earlier variant of this step read the streamlines from
# fiber_groups[bundle]['sl']; the active code uses fiber_groups[bundle]
# directly.

print("Extracting tract profiles...")
for bundle in bundles:
    # One figure per bundle.
    fig, ax = plt.subplots(1)

    # Weighted FA profile along the bundle; the identity affine is passed
    # because the streamlines are used in the coordinates they already have.
    bundle_streamlines = fiber_groups[bundle]
    weights = gaussian_weights(bundle_streamlines)
    profile = afq_profile(FA_data, bundle_streamlines, np.eye(4),
                          weights=weights)

    ax.plot(profile)
    ax.set_title(bundle)

plt.show()

NameError: name 'join' is not defined