Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FIX: Stateful tractogram examples #1925

Merged
merged 12 commits into from
Jul 29, 2019
22 changes: 11 additions & 11 deletions dipy/tracking/tests/test_life.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@
import os.path as op

import numpy as np
import numpy.testing as npt
import scipy.linalg as la
import nibabel as nib
import dipy.core.gradients as grad
import dipy.core.optimize as opt
import dipy.data as dpd
import dipy.reconst.dti as dti
import dipy.tracking.life as life

from dipy.io.gradients import read_bvals_bvecs
from dipy.tracking.utils import move_streamlines
from dipy.io.stateful_tractogram import Space, StatefulTractogram
import dipy.tracking.life as life
import nibabel as nib
import numpy as np
import numpy.testing as npt
import scipy.linalg as la

THIS_DIR = op.dirname(__file__)

Expand Down Expand Up @@ -162,10 +160,12 @@ def test_fit_data():
ni_data = nib.load(fdata)
data = ni_data.get_data()
tensor_streamlines = nib.streamlines.load(fstreamlines).streamlines
tensor_streamlines = move_streamlines(tensor_streamlines, np.eye(4),
ni_data.affine)
sft = StatefulTractogram(tensor_streamlines, ni_data, Space.RASMM)
sft.to_vox()
tensor_streamlines_vox = sft.streamlines

life_model = life.FiberModel(gtab)
life_fit = life_model.fit(data, tensor_streamlines)
life_fit = life_model.fit(data, tensor_streamlines_vox)
model_error = life_fit.predict() - life_fit.data
model_rmse = np.sqrt(np.mean(model_error ** 2, -1))
matlab_rmse, matlab_weights = dpd.matlab_life_results()
Expand Down
12 changes: 7 additions & 5 deletions dipy/workflows/tests/test_tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from dipy.data import get_fnames
from dipy.io.image import save_nifti
from dipy.io.streamline import load_tractogram
from dipy.workflows.mask import MaskFlow
from dipy.workflows.reconst import ReconstCSDFlow
from dipy.workflows.tracking import (LocalFiberTrackingPAMFlow,
Expand Down Expand Up @@ -217,13 +218,14 @@ def tractogram_has_seeds(tractogram_path):


def seeds_are_same_space_as_streamlines(tractogram_path):
tractogram = \
nib.streamlines.load(tractogram_path).tractogram
seeds = tractogram.data_per_streamline['seeds']
streamlines = tractogram.streamlines
sft = load_tractogram(tractogram_path, 'same', bbox_valid_check=False)
seeds = sft.data_per_streamline['seeds']
streamlines = sft.streamlines

for seed, streamline in zip(seeds, streamlines):
map_res = list(map(lambda x: np.allclose(seed, x), streamline))
map_res = list(map(lambda x: np.allclose(seed, x,
atol=1e-2,
rtol=1e-4), streamline))
# If no point is close to the seed, it likely means that the seed is
# not in the same space as the streamline
if not np.any(map_res):
Expand Down
26 changes: 14 additions & 12 deletions dipy/workflows/tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@
import logging
import numpy as np

from nibabel.streamlines import save, Tractogram

from dipy.direction import (DeterministicMaximumDirectionGetter,
ProbabilisticDirectionGetter,
ClosestPeakDirectionGetter)
from dipy.io.image import load_nifti
from dipy.io.peaks import load_peaks
from dipy.io.stateful_tractogram import Space, StatefulTractogram
from dipy.io.streamline import save_tractogram
from dipy.tracking import utils
from dipy.tracking.local import (BinaryTissueClassifier,
ThresholdTissueClassifier, LocalTracking,
Expand Down Expand Up @@ -107,12 +107,14 @@ def _core_run(self, stopping_path, use_binary_mask, stopping_thr,

if save_seeds:
streamlines, seeds = zip(*tracking_result)
tractogram = Tractogram(streamlines, affine_to_rasmm=np.eye(4))
tractogram.data_per_streamline['seeds'] = seeds
seeds = {'seeds': seeds}
else:
tractogram = Tractogram(tracking_result, affine_to_rasmm=np.eye(4))
streamlines = list(tracking_result)
seeds = {}

save(tractogram, out_tract)
sft = StatefulTractogram(streamlines, seeding_path, Space.RASMM,
data_per_streamline=seeds)
save_tractogram(sft, out_tract, bbox_valid_check=False)
logging.info('Saved {0}'.format(out_tract))

def run(self, pam_files, stopping_files, seeding_files,
Expand Down Expand Up @@ -319,12 +321,12 @@ def run(self, pam_files, wm_files, gm_files, csf_files, seeding_files,

if save_seeds:
streamlines, seeds = zip(*tracking_result)
tractogram = Tractogram(streamlines, affine_to_rasmm=np.eye(4))
tractogram.data_per_streamline['seeds'] = seeds
seeds = {'seeds': seeds}
else:
tractogram = Tractogram(tracking_result,
affine_to_rasmm=np.eye(4))

save(tractogram, out_tract)
streamlines = list(tracking_result)
seeds = {}

sft = StatefulTractogram(streamlines, seeding_path, Space.RASMM,
data_per_streamline=seeds)
save_tractogram(sft, out_tract, bbox_valid_check=False)
logging.info('Saved {0}'.format(out_tract))
36 changes: 19 additions & 17 deletions doc/examples/afq_tract_profiles.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,15 @@

"""

import dipy.stats.analysis as dsa
import nibabel as nib
import dipy.tracking.streamline as dts
from dipy.segment.clustering import QuickBundles
from dipy.segment.metric import (AveragePointwiseEuclideanMetric,
ResampleFeature)
from dipy.data.fetcher import get_two_hcp842_bundles
import dipy.data as dpd
from dipy.io.streamline import load_trk
import matplotlib.pyplot as plt
import numpy as np
import os.path as op
Expand All @@ -40,9 +49,8 @@

"""

from dipy.io.streamline import load_trk
cst_l, hdr = load_trk("CST_L.trk")
af_l, hdr = load_trk("AF_L.trk")
cst_l = load_trk("CST_L.trk", "same", bbox_valid_check=False).streamlines
af_l = load_trk("AF_L.trk", "same", bbox_valid_check=False).streamlines

transform = np.load("slr_transform.npy")

Expand All @@ -61,17 +69,13 @@
for different subjects, which means that we'll get roughly the same orientation
"""

import dipy.data as dpd
from dipy.data.fetcher import get_two_hcp842_bundles
model_af_l_file, model_cst_l_file = get_two_hcp842_bundles()

model_af_l, hdr = load_trk(model_af_l_file)
model_cst_l, hdr = load_trk(model_cst_l_file)

model_af_l = load_trk(model_af_l_file, "same",
bbox_valid_check=False).streamlines
model_cst_l = load_trk(model_cst_l_file, "same",
bbox_valid_check=False).streamlines

from dipy.segment.metric import (AveragePointwiseEuclideanMetric,
ResampleFeature)
from dipy.segment.clustering import QuickBundles

feature = ResampleFeature(nb_points=100)
metric = AveragePointwiseEuclideanMetric(feature)
Expand All @@ -98,7 +102,6 @@
individual, and not relative to the atlas space.
"""

import dipy.tracking.streamline as dts

oriented_cst_l = dts.orient_by_streamline(cst_l, standard_cst_l,
affine=transform)
Expand All @@ -109,13 +112,13 @@
"""
Read volumetric data from an image corresponding to this subject.

For the purpose of this, we've extracted only the FA within the bundles in question,
but in real use, this is where you would add the FA map of your subject.
For the purpose of this, we've extracted only the FA within the bundles in
question, but in real use, this is where you would add the FA map of your
subject.
"""

files, folder = dpd.fetch_bundle_fa_hcp()

import nibabel as nib
img = nib.load(op.join(folder, "hcp_bundle_fa.nii.gz"))
fa = img.get_fdata()

Expand All @@ -124,7 +127,6 @@
Calculate weights for each bundle:
"""

import dipy.stats.analysis as dsa

w_cst_l = dsa.gaussian_weights(oriented_cst_l)
w_af_l = dsa.gaussian_weights(oriented_af_l)
Expand Down Expand Up @@ -171,4 +173,4 @@

.. [Garyfallidis12] Garyfallidis E. et al., QuickBundles a method for
tractography simplification, Frontiers in Neuroscience, vol 6, no 175, 2012.
"""
"""
64 changes: 38 additions & 26 deletions doc/examples/bundle_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,30 +9,40 @@
First import the necessary modules.
"""

from dipy.data.fetcher import get_two_hcp842_bundles
from dipy.data.fetcher import (fetch_target_tractogram_hcp,
fetch_bundle_atlas_hcp842,
get_bundle_atlas_hcp842,
get_target_tractogram_hcp)
import numpy as np
from dipy.segment.bundles import RecoBundles
from dipy.align.streamlinear import whole_brain_slr
from dipy.viz import window, actor
from fury import actor, window
from dipy.io.stateful_tractogram import Space, StatefulTractogram
from dipy.io.streamline import load_trk, save_trk
from dipy.io.utils import create_tractogram_header


"""
Download and read data for this tutorial
"""

from dipy.data.fetcher import (fetch_target_tractogram_hcp,
fetch_bundle_atlas_hcp842,
get_bundle_atlas_hcp842,
get_target_tractogram_hcp)

target_file, target_folder = fetch_target_tractogram_hcp()
atlas_file, atlas_folder = fetch_bundle_atlas_hcp842()

atlas_file, all_bundles_files = get_bundle_atlas_hcp842()
target_file = get_target_tractogram_hcp()

atlas, atlas_header = load_trk(atlas_file)
target, target_header = load_trk(target_file)
sft_atlas = load_trk(atlas_file, "same", bbox_valid_check=False)
atlas = sft_atlas.streamlines
atlas_header = create_tractogram_header(atlas_file,
*sft_atlas.space_attribute)

sft_target = load_trk(target_file, "same", bbox_valid_check=False)
target = sft_target.streamlines
target_header = create_tractogram_header(atlas_file,
*sft_atlas.space_attribute)

"""
let's visualize atlas tractogram and target tractogram before registration
Expand All @@ -42,8 +52,8 @@

ren = window.Renderer()
ren.SetBackground(1, 1, 1)
ren.add(actor.line(atlas, colors=(1,0,1)))
ren.add(actor.line(target, colors=(1,1,0)))
ren.add(actor.line(atlas, colors=(1, 0, 1)))
ren.add(actor.line(target, colors=(1, 1, 0)))
window.record(ren, out_path='tractograms_initial.png', size=(600, 600))
if interactive:
window.show(ren)
Expand All @@ -62,7 +72,7 @@
"""

moved, transform, qb_centroids1, qb_centroids2 = whole_brain_slr(
atlas, target, x0='affine', verbose=True, progressive=True)
atlas, target, x0='affine', verbose=True, progressive=True)


"""
Expand All @@ -81,8 +91,8 @@

ren = window.Renderer()
ren.SetBackground(1, 1, 1)
ren.add(actor.line(atlas, colors=(1,0,1)))
ren.add(actor.line(moved, colors=(1,1,0)))
ren.add(actor.line(atlas, colors=(1, 0, 1)))
ren.add(actor.line(moved, colors=(1, 1, 0)))
window.record(ren, out_path='tractograms_after_registration.png',
size=(600, 600))
if interactive:
Expand All @@ -101,14 +111,13 @@
as model bundles
"""

from dipy.data.fetcher import get_two_hcp842_bundles
model_af_l_file, model_cst_l_file = get_two_hcp842_bundles()

"""
Extracting bundles using recobundles [Garyfallidis17]_
"""

model_af_l, hdr = load_trk(model_af_l_file)
sft_af_l = load_trk(model_af_l_file, "same", bbox_valid_check=False)
model_af_l = sft_af_l.streamlines

rb = RecoBundles(moved, verbose=True)

Expand All @@ -129,10 +138,10 @@

ren = window.Renderer()
ren.SetBackground(1, 1, 1)
ren.add(actor.line(model_af_l, colors=(.1,.7,.26)))
ren.add(actor.line(recognized_af_l, colors=(.1,.1,6)))
ren.add(actor.line(model_af_l, colors=(.1, .7, .26)))
ren.add(actor.line(recognized_af_l, colors=(.1, .1, 6)))
ren.set_camera(focal_point=(320.21296692, 21.28884506, 17.2174015),
position=(2.11, 200.46, 250.44) , view_up=(0.1, -1.028, 0.18))
position=(2.11, 200.46, 250.44), view_up=(0.1, -1.028, 0.18))
window.record(ren, out_path='AF_L_recognized_bundle.png',
size=(600, 600))
if interactive:
Expand All @@ -153,10 +162,12 @@
space of the subject anatomy.

"""
reco_af_l = StatefulTractogram(target[af_l_labels], target_header,
Space.RASMM)
save_trk(reco_af_l, "AF_L.trk", bbox_valid_check=False)

save_trk( "AF_L.trk", target[af_l_labels], hdr['voxel_to_rasmm'])

model_cst_l, hdr = load_trk(model_cst_l_file)
sft_cst_l = load_trk(model_cst_l_file, "same", bbox_valid_check=False)
model_cst_l = sft_cst_l.streamlines

recognized_cst_l, cst_l_labels = rb.recognize(model_bundle=model_cst_l,
model_clust_thr=5.,
Expand All @@ -175,8 +186,8 @@

ren = window.Renderer()
ren.SetBackground(1, 1, 1)
ren.add(actor.line(model_cst_l, colors=(.1,.7,.26)))
ren.add(actor.line(recognized_cst_l, colors=(.1,.1,6)))
ren.add(actor.line(model_cst_l, colors=(.1, .7, .26)))
ren.add(actor.line(recognized_cst_l, colors=(.1, .1, 6)))
ren.set_camera(focal_point=(-18.17281532, -19.55606842, 6.92485857),
position=(-360.11, -340.46, -40.44),
view_up=(-0.03, 0.028, 0.89))
Expand All @@ -199,8 +210,9 @@
Save the bundle as a trk file:

"""

save_trk("CST_L.trk", target[cst_l_labels], hdr['voxel_to_rasmm'])
reco_cst_l = StatefulTractogram(target[cst_l_labels], target_header,
Space.RASMM)
save_trk(reco_af_l, "CST_L.trk", bbox_valid_check=False)


"""
Expand All @@ -212,4 +224,4 @@
bundles using local and global streamline-based registration
and clustering, Neuroimage, 2017.

"""
"""