Skip to content

Commit

Permalink
Merge pull request #1384 from gabknight/NF_cmc_pft_merged
Browse files Browse the repository at this point in the history
NF - Particle Filtering Tractography (merge)
  • Loading branch information
Garyfallidis committed Feb 6, 2018
2 parents 7ae848e + 14f4c76 commit e4c8b9b
Show file tree
Hide file tree
Showing 13 changed files with 1,127 additions and 130 deletions.
23 changes: 15 additions & 8 deletions dipy/direction/pmf.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,11 @@ cdef class SimplePmfGen(PmfGen):
self.out = np.empty(pmf_array.shape[3])

cdef double[:] get_pmf_c(self, double* point) nogil:
trilinear_interpolate4d_c(self.pmf_array, point, self.out)
cdef:
size_t len_pmf = self.out.shape[0]
if trilinear_interpolate4d_c(self.pmf_array, point, self.out):
for i in range(len_pmf):
self.out[i] = 0.0
return self.out


Expand Down Expand Up @@ -62,13 +66,16 @@ cdef class SHCoeffPmfGen(PmfGen):
size_t len_B = self.B.shape[1]
double _sum

trilinear_interpolate4d_c(self.shcoeff, point, self.coeff)
for i in range(len_pmf):
_sum = 0
for j in range(len_B):
_sum += self.B[i, j] * self.coeff[j]
self.pmf[i] = _sum
if self.pmf[i] < 0.0:
if trilinear_interpolate4d_c(self.shcoeff, point, self.coeff):
for i in range(len_pmf):
self.pmf[i] = 0.0
else:
for i in range(len_pmf):
_sum = 0
for j in range(len_B):
_sum += self.B[i, j] * self.coeff[j]
self.pmf[i] = _sum
if self.pmf[i] < 0.0:
self.pmf[i] = 0.0

return self.pmf
43 changes: 43 additions & 0 deletions dipy/direction/tests/test_pmf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import numpy as np
import numpy.testing as npt

from dipy.core.sphere import HemiSphere, unit_octahedron
from dipy.direction.pmf import SimplePmfGen, SHCoeffPmfGen


def test_pmf_from_sh():
sphere = HemiSphere.from_sphere(unit_octahedron)
pmfgen = SHCoeffPmfGen(np.ones([2, 2, 2, 28]), sphere, None)

# Test that the pmf is greater than 0 for a valid point
pmf = pmfgen.get_pmf(np.array([0, 0, 0], dtype='float'))
npt.assert_equal(np.sum(pmf) > 0, True)

# Test that the pmf is 0 for invalid Points
npt.assert_array_equal(pmfgen.get_pmf(np.array([-1, 0, 0], dtype='float')),
np.zeros(len(sphere.vertices)))
npt.assert_array_equal(pmfgen.get_pmf(np.array([0, 0, 10], dtype='float')),
np.zeros(len(sphere.vertices)))


def test_pmf_from_array():
sphere = HemiSphere.from_sphere(unit_octahedron)
pmfgen = SimplePmfGen(np.ones([2, 2, 2, len(sphere.vertices)]))

# Test that the pmf is greater than 0 for a valid point
pmf = pmfgen.get_pmf(np.array([0, 0, 0], dtype='float'))
npt.assert_equal(np.sum(pmf) > 0, True)

# Test that the pmf is 0 for invalid Points
npt.assert_array_equal(pmfgen.get_pmf(np.array([-1, 0, 0], dtype='float')),
np.zeros(len(sphere.vertices)))
npt.assert_array_equal(pmfgen.get_pmf(np.array([0, 0, 10], dtype='float')),
np.zeros(len(sphere.vertices)))

npt.assert_raises(
ValueError,
lambda: SimplePmfGen(np.ones([2, 2, 2, len(sphere.vertices)])*-1))


if __name__ == '__main__':
npt.run_module_suite()
10 changes: 5 additions & 5 deletions dipy/segment/tests/test_mrf.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
square_1[99:157, 99:157, :] = temp_3


def test_greyscale_image():
def test_grayscale_image():

com = ConstantObservationModel()
icm = IteratedConditionalModes()
Expand Down Expand Up @@ -124,7 +124,7 @@ def test_greyscale_image():
npt.assert_(icm_segmentation.min() == 0)


def test_greyscale_iter():
def test_grayscale_iter():

max_iter = 15
beta = np.float64(0.1)
Expand Down Expand Up @@ -195,11 +195,11 @@ def test_greyscale_iter():
npt.assert_(PLY[100, 100, 1, 3] > PLY[100, 100, 1, 1])
npt.assert_(PLY[100, 100, 1, 3] > PLY[100, 100, 1, 2])

mu_upd, sigmasq_upd = com.update_param(image_gauss, PLY, mu, nclasses)
mu_upd, sigmasq_upd = com.update_param(image_gauss, PLY, mu, nclasses)
npt.assert_(mu_upd[0] >= 0.0)
npt.assert_(mu_upd[1] >= 0.0)
npt.assert_(mu_upd[2] >= 0.0)
npt.assert_(mu_upd[3] >= 0.0)
npt.assert_(mu_upd[3] >= 0.0)
npt.assert_(sigmasq_upd[0] >= 0.0)
npt.assert_(sigmasq_upd[1] >= 0.0)
npt.assert_(sigmasq_upd[2] >= 0.0)
Expand Down Expand Up @@ -294,7 +294,7 @@ def test_square_iter():
npt.assert_(mu_upd[0] >= 0.0)
npt.assert_(mu_upd[1] >= 0.0)
npt.assert_(mu_upd[2] >= 0.0)
npt.assert_(mu_upd[3] >= 0.0)
npt.assert_(mu_upd[3] >= 0.0)
npt.assert_(sigmasq_upd[0] >= 0.0)
npt.assert_(sigmasq_upd[1] >= 0.0)
npt.assert_(sigmasq_upd[2] >= 0.0)
Expand Down
18 changes: 12 additions & 6 deletions dipy/tracking/local/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
from .localtracking import LocalTracking
from .tissue_classifier import (
ActTissueClassifier, BinaryTissueClassifier, ThresholdTissueClassifier,
TissueClassifier)
from .direction_getter import DirectionGetter
from .localtracking import LocalTracking, ParticleFilteringTracking
from .tissue_classifier import (ActTissueClassifier,
BinaryTissueClassifier,
CmcTissueClassifier,
ConstrainedTissueClassifier,
ThresholdTissueClassifier,
TissueClassifier)

from dipy.tracking import utils

__all__ = ["ActTissueClassifier", "BinaryTissueClassifier", "LocalTracking",
"ThresholdTissueClassifier"]
__all__ = ["ActTissueClassifier", "BinaryTissueClassifier",
"CmcTissueClassifier", "ConstrainedTissueClassifier",
"DirectionGetter", "LocalTracking", "ParticleFilteringTracking",
"ThresholdTissueClassifier", "TissueClassifier"]

0 comments on commit e4c8b9b

Please sign in to comment.