Skip to content

Commit

Permalink
BF - move sampling_matrix to model
Browse files Browse the repository at this point in the history
  • Loading branch information
MrBago committed Sep 19, 2014
1 parent 27a6f49 commit 4f89efd
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 25 deletions.
24 changes: 24 additions & 0 deletions dipy/reconst/odf.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,30 @@ class OdfModel(ReconstModel):
def __init__(self, gtab):
ReconstModel.__init__(self, gtab)

def sampling_matrix(self, sphere):
"""The matrix needed to sample ODFs from coefficients of the model.
Parameters
----------
sphere : Sphere
Points used to sample ODF.
Returns
-------
sampling_matrix : array
The size of the matrix will be (N, M) where N is the number of
vertices on sphere and M is the number of coefficients needed by
the model.
"""
sampling_matrix = self.cache_get("sampling_matrix", sphere)
if sampling_matrix is None:
sh_order = self.sh_order
theta = sphere.theta
phi = sphere.phi
sampling_matrix, m, n = real_sym_sh_basis(sh_order, theta, phi)
self.cache_set("sampling_matrix", sphere, sampling_matrix)
return sampling_matrix

def fit(self, data):
"""To be implemented by specific odf models"""
raise NotImplementedError("To be implemented in sub classes")
Expand Down
25 changes: 1 addition & 24 deletions dipy/reconst/shm.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,7 +554,7 @@ def odf(self, sphere):
The value of the odf on each point of `sphere`.
"""
B = self.sampling_matrix(sphere)
B = self.model.sampling_matrix(sphere)
return dot(self._shm_coef, B.T)

@auto_attr
Expand All @@ -571,29 +571,6 @@ def shm_coeff(self):
"""
return self._shm_coef

def sampling_matrix(self, sphere):
"""The matrix needed to sample ODFs from coefficients of the model.
Parameters
----------
sphere : Sphere
Points used to sample ODF.
Returns
-------
sampling_matrix : array
The size of the matrix will be (N, M) where N is the number of
vertices on sphere and M is the number of coefficients needed by
the model.
"""
sampling_matrix = self.model.cache_get("sampling_matrix", sphere)
if sampling_matrix is None:
sh_order = self.model.sh_order
theta = sphere.theta
phi = sphere.phi
sampling_matrix, m, n = real_sym_sh_basis(sh_order, theta, phi)
self.model.cache_set("sampling_matrix", sphere, sampling_matrix)
return sampling_matrix

def prediction_matrix(self, sphere, gtab):
"""
Expand Down
2 changes: 1 addition & 1 deletion dipy/tracking/local/direction_getter_py.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class ShmFitPmfGen(PmfGen):
def __init__(self, shmfit, sphere):
self.fit = shmfit
self.sphere = sphere
self._B = shmfit.sampling_matrix(sphere)
self._B = shmfit.model.sampling_matrix(sphere)
self._coeff = shmfit.shm_coeff

def get_pmf(self, point):
Expand Down

0 comments on commit 4f89efd

Please sign in to comment.