Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

NF: MSMT - CSD #1858

Merged
merged 33 commits into from Jun 25, 2019
Merged
Show file tree
Hide file tree
Changes from 25 commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
a04de3e
added msd with OSQP
ShreyasFadnavis Oct 9, 2018
e746079
added msd with matrices
ShreyasFadnavis Oct 9, 2018
9ef97dd
updates
ShreyasFadnavis Oct 11, 2018
00c2acf
MTMS- CVXPY WIP
ShreyasFadnavis Oct 11, 2018
3a020cd
ex_mrf_classifier
ShreyasFadnavis Oct 13, 2018
01a6b68
update example
ShreyasFadnavis Oct 16, 2018
c63113c
update example
ShreyasFadnavis Oct 17, 2018
2bbfa54
simulated response
ShreyasFadnavis Oct 17, 2018
a307697
working full slice
ShreyasFadnavis Oct 18, 2018
75fbe03
working one slice
ShreyasFadnavis Oct 18, 2018
562f8ab
working opt_msd 1 slice
ShreyasFadnavis Nov 3, 2018
d22f245
osqp partially fixed
ShreyasFadnavis Nov 11, 2018
be850d6
worksgit add .
ShreyasFadnavis Nov 15, 2018
80acbf9
msd changes
ShreyasFadnavis Nov 15, 2018
9a3b60f
working needs init
ShreyasFadnavis Nov 21, 2018
d9ccf78
changes
ShreyasFadnavis Nov 28, 2018
1a5a163
changes
ShreyasFadnavis Nov 30, 2018
3f8d4e8
changed cvxopt to cvxpy for pos_constrained
ShreyasFadnavis Jun 14, 2019
2398587
finally cvxpy refactored MCSD
ShreyasFadnavis Jun 14, 2019
d082020
del
ShreyasFadnavis Jun 14, 2019
794262b
clean and rename
ShreyasFadnavis Jun 14, 2019
f910732
fixes
ShreyasFadnavis Jun 17, 2019
0640aff
fixed positivity constraint
ShreyasFadnavis Jun 17, 2019
d413cc1
GLPK dependency resolved
ShreyasFadnavis Jun 17, 2019
0ee30ee
doc added
ShreyasFadnavis Jun 17, 2019
b67db6c
review fixes
ShreyasFadnavis Jun 18, 2019
e497109
added basic test
ShreyasFadnavis Jun 18, 2019
038a4ff
test fixed
ShreyasFadnavis Jun 18, 2019
7eebff4
test rem for posit constr
ShreyasFadnavis Jun 19, 2019
4c3dd12
posit constraint removed
ShreyasFadnavis Jun 19, 2019
d22e4a7
refactored
ShreyasFadnavis Jun 20, 2019
a1be125
changes suggested by Ariel + Doc
ShreyasFadnavis Jun 24, 2019
5ea2ed4
Update dipy/reconst/mcsd.py
ShreyasFadnavis Jun 24, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
347 changes: 347 additions & 0 deletions dipy/reconst/mcsd.py
@@ -0,0 +1,347 @@
import numpy as np
import numpy.linalg as la
from dipy.core import geometry as geo
from dipy.data import default_sphere
from dipy.reconst import shm
from dipy.reconst.multi_voxel import multi_voxel_fit

from dipy.utils.optpkg import optional_package
cvx, have_cvxpy, _ = optional_package("cvxpy")

SH_CONST = .5 / np.sqrt(np.pi)


def multi_tissue_basis(gtab, sh_order, iso_comp):
"""
Builds a basis for multi-shell CSD model.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Documentation for this function is still not complete.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Almost there. Just need to document these.

Sorry for that.. On it!

"""
if iso_comp < 1:
msg = ("Multi-tissue CSD requires at least 2 tissue compartments")
raise ValueError(msg)
r, theta, phi = geo.cart2sphere(*gtab.gradients.T)
m, n = shm.sph_harm_ind_list(sh_order)
B = shm.real_sph_harm(m, n, theta[:, None], phi[:, None])
B[np.ix_(gtab.b0s_mask, n > 0)] = 0.

iso = np.empty([B.shape[0], iso_comp])
iso[:] = SH_CONST

B = np.concatenate([iso, B], axis=1)
return B, m, n


class MultiShellResponse(object):
arokem marked this conversation as resolved.
Show resolved Hide resolved

def __init__(self, response, sh_order, shells):
""" Estimate Multi Shell response function for multiple tissues and
multiple shells.

Parameters
----------
response : tuple or AxSymShResponse object
A tuple with two elements. The first is the eigen-values as an (3,)
ndarray and the second is the signal value for the response
function without diffusion weighting. This is to be able to
generate a single fiber synthetic signal.
sh_order : ndarray
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this really an array? Why?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this really an array? Why?

My bad!

shells : int
Number of shells in the data
"""
self.response = response
self.sh_order = sh_order
self.n = np.arange(0, sh_order + 1, 2)
self.m = np.zeros_like(self.n)
self.shells = shells
if self.iso < 1:
raise ValueError("sh_order and shape of response do not agree")

@property
def iso(self):
return self.response.shape[1] - (self.sh_order // 2) - 1


def closest(haystack, needle):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function is:

  1. two lines implementing a fairly straightforward use of numpy,
  2. not documented or directly tested, and
  3. used exactly once.

Could you just move this code into the calling function and eliminate this?

diff = abs(haystack[:, None] - needle)
return diff.argmin(axis=0)


def _inflate_response(response, gtab, n, delta):
if any((n % 2) != 0) or (n.max() // 2) >= response.sh_order:
raise ValueError("Response and n do not match")

iso = response.iso
n_idx = np.empty(len(n) + iso, dtype=int)
n_idx[:iso] = np.arange(0, iso)
n_idx[iso:] = n // 2 + iso

b_idx = closest(response.shells, gtab.bvals)
kernal = response.response / delta

return kernal[np.ix_(b_idx, n_idx)]


def _basic_delta(iso, m, n, theta, phi):
"""Simple delta function
Parameters
----------
iso: int (optional)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Says it's optional, but it doesn't seem to be optional in the function signature.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That said, it would make sense to add error handling to make sure that it's at least 2.

Number of tissue compartments for running the MSMT-CSD. Minimum
number of compartments required is 2.
Default: 2
m : int ``|m| <= n``
The order of the harmonic.
n : int ``>= 0``
The degree of the harmonic.
theta : array_like
inclination or polar angle
phi : array_like
azimuth angle
"""
wm_d = shm.gen_dirac(m, n, theta, phi)
iso_d = [SH_CONST] * iso
return np.concatenate([iso_d, wm_d])


def _pos_constrained_delta(iso, m, n, theta, phi, reg_sphere=default_sphere):
"""
Delta function optimized to avoid negative lobes. Implements a Linear
Programming solver from `CVXPY` to impose this positivity constraint. The
default solver used is GLPK.

Parameters
----------
iso: int (optional)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same comment here.

Number of tissue compartments for running the MSMT-CSD. Minimum
number of compartments required is 2.
Default: 2
m : int ``|m| <= n``
The order of the harmonic.
n : int ``>= 0``
The degree of the harmonic.
theta : array_like
inclination or polar angle
phi : array_like
azimuth angle
reg_sphere : Sphere (optional)
sphere used to build the regularization B matrix.
Default: 'symmetric362'.
weight given to the constrained-positivity regularization part of
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems to document a parameter that no longer exists(?)

the deconvolution equation (see [1]_). Default: 1
"""

x, y, z = geo.sphere2cart(1., theta, phi)

# Realign reg_sphere so that the first vertex is aligned with delta
# orientation (theta, phi).
M = geo.vec2vec_rotmat(reg_sphere.vertices[0], [x, y, z])
new_vertices = np.dot(reg_sphere.vertices, M.T)
_, t, p = geo.cart2sphere(*new_vertices.T)

B = shm.real_sph_harm(m, n, t[:, None], p[:, None])
G_temp = np.ascontiguousarray(B[:, n != 0])
# c_ samples the delta function at the delta orientation.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This comment seems to refer to a variable that no longer exists (?).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or maybe this is c_temp?

c_temp = G_temp[0][:, None]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are we pulling out the first element of this array?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That would be a question for @MrBago ... it works as expected, nevertheless :)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

c_temp should be shm.real_sph_harm(m[n != 0], n[n != 0], theta(x, y, z), phi(x, y, z). Since new_vertices are reoriented so that new_vertices[0] == [x, y, z], we can sample the matrix generated from new_vertices.

a_temp, b_temp = G_temp.shape

c_int = cvx.Parameter((c_temp.shape[0], 1))
c_int.value = -c_temp
G = cvx.Parameter((a_temp, b_temp))
G.value = -G_temp
h = cvx.Parameter((a_temp, b_temp))
h_temp = np.full((a_temp, b_temp), SH_CONST ** 2)
h.value = h_temp

# n == 0 is set to sh_const to ensure a normalized delta function.
# n > 0 values are optimized so that delta > 0 on all points of the sphere
# and delta(theta, phi) is maximized.
lp_prob = cvx.Problem(cvx.Maximize(cvx.sum(c_temp)), [G*x <= h])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This doesn't seem right, is this x the same as the x defined in on 132? I thing x should be a variable & should be referenced in the Maximize.

I'm not sure how useful the _pos_constrained_delta is, consider removing it for this PR and adding it later with better tests.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will do.. I was about to ask the same..! Will take this function down for now :)

r = lp_prob.solve()
out = np.zeros(B.shape[1])
out[n == 0] = SH_CONST
out[n != 0] = r

iso_d = [SH_CONST] * iso
return np.concatenate([iso_d, out])


delta_functions = {"basic": _basic_delta,
"positivity_constrained": _pos_constrained_delta}


class MultiShellDeconvModel(shm.SphHarmModel):
def __init__(self, gtab, response, reg_sphere=default_sphere, iso=2,
delta_form='basic'):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe instead of delta_form, we can call this positive_delta and have it be a boolean that chooses _pos_constrained_delta for the delta function, when set to True?

r"""
Multi-Shell Multi-Tissue Constrained Spherical Deconvolution
(MSMT-CSD) [1]_. This method extends the CSD model proposed in [2]_ by
the estimation of multiple response functions as a function of multiple
b-values and multiple tissue types.

Spherical deconvolution computes a fiber orientation distribution
(FOD), also called fiber ODF (fODF) [2]_. The fODF is derived from
different tissue types and thus overcomes the overestimation of WM in
GM and CSF areas.

The response function is based on the different tissue types
and is provided as input to the MultiShellDeconvModel.
It will be used as deconvolution kernel, as described in [2]_.

Parameters
----------
gtab : GradientTable
response : tuple or AxSymShResponse object
A tuple with two elements. The first is the eigen-values as an (3,)
ndarray and the second is the signal value for the response
function without diffusion weighting. This is to be able to
generate a single fiber synthetic signal. The response function
will be used as deconvolution kernel ([1]_)
reg_sphere : Sphere (optional)
sphere used to build the regularization B matrix.
Default: 'symmetric362'.
weight given to the constrained-positivity regularization part of
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Again this text about the weight. Does it really belong together with reg_sphere?

the deconvolution equation (see [1]_). Default: 1
iso: int (optional)
arokem marked this conversation as resolved.
Show resolved Hide resolved
Number of tissue compartments for running the MSMT-CSD. Minimum
number of compartments required is 2.
Default: 2

References
----------
.. [1] Jeurissen, B., et al. NeuroImage 2014. Multi-tissue constrained
spherical deconvolution for improved analysisof multi-shell
diffusion MRI data
.. [2] Tournier, J.D., et al. NeuroImage 2007. Robust determination of
the fibre orientation distribution in diffusion MRI:
Non-negativity constrained super-resolved spherical
deconvolution
.. [3] Tournier, J.D, et al. Imaging Systems and Technology
2012. MRtrix: Diffusion Tractography in Crossing Fiber Regions
"""

sh_order = response.sh_order
super(MultiShellDeconvModel, self).__init__(gtab)
B, m, n = multi_tissue_basis(gtab, sh_order, iso)

delta_f = delta_functions[delta_form]
delta = delta_f(response.iso, response.m, response.n, 0., 0.)
self.delta = delta
multiplier_matrix = _inflate_response(response, gtab, n, delta)

r, theta, phi = geo.cart2sphere(*reg_sphere.vertices.T)
odf_reg, _, _ = shm.real_sym_sh_basis(sh_order, theta, phi)
reg = np.zeros([i + iso for i in odf_reg.shape])
reg[:iso, :iso] = np.eye(iso)
reg[iso:, iso:] = odf_reg

X = B * multiplier_matrix

self.fitter = QpFitter(X, reg)
self.sh_order = sh_order
self._X = X
self.sphere = reg_sphere
self.response = response
self.B_dwi = B
self.m = m
self.n = n

def predict(self, params, gtab=None, S0=None):
"""Compute a signal prediction given spherical harmonic coefficients
for the provided GradientTable class instance.

Parameters
----------
params : ndarray
The spherical harmonic representation of the FOD from which to make
the signal prediction.
gtab : GradientTable
The gradients for which the signal will be predicted. Use the
model's gradient table by default.
S0 : ndarray or float
The non diffusion-weighted signal value.
Default : None
"""
if gtab is None:
X = self._X
else:
iso = self.response.iso
B, m, n = multi_tissue_basis(gtab, self.sh_order, iso)
multiplier_matrix = _inflate_response(self.response, gtab, n,
self.delta)
X = B * multiplier_matrix
return np.dot(params, X.T)

@multi_voxel_fit
def fit(self, data):
coeff = self.fitter(data)
return MSDeconvFit(self, coeff, None)


class MSDeconvFit(shm.SphHarmFit):
arokem marked this conversation as resolved.
Show resolved Hide resolved

def __init__(self, model, coeff, mask):
self._shm_coef = coeff
self.mask = mask
self.model = model

@property
def shm_coeff(self):
return self._shm_coef[..., self.model.response.iso:]

@property
def volume_fractions(self):
tissue_classes = self.model.response.iso + 1
return self._shm_coef[..., :tissue_classes] / SH_CONST


def _rank(A, tol=1e-8):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not required anymore!

s = la.svd(A, False, False)
threshold = (s[0] * tol)
rnk = (s > threshold).sum()
return rnk


def quadprog(P, Q, G, H):
r"""
Helper funstion to set up the Quadratic Program (QP) solver in CVXPY.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Typo: funstion => function

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This documentation is a bit misleading. This doesn't just set up the problem. It sets it up and then solves it.

A QP problem has the following form:
minimize 1/2 x' P x + Q' x
subject to G x <= H

Here the QP solver is based on CVXPY and uses OSQP by default.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Both here and above, the solver is not a default that can be changed. This should be something like "Here the QP solver is based on CVXPY and uses OSQP." without that last phrase.

"""
x = cvx.Variable(Q.shape[0])
P = cvx.Constant(P)
objective = cvx.Minimize(0.5 * cvx.quad_form(x, P) + Q * x)
constraints = [G*x <= H]

# setting up the problem
prob = cvx.Problem(objective, constraints)
prob.solve()
opt = np.array(x.value).reshape((Q.shape[0],))
return opt


class QpFitter(object):

def __init__(self, X, reg):
r"""
Makes use of the quadratic programming solver `quadprog` to fit the
model. The initialization for the model is done using the warm-start by
default in `CVXPY`.
arokem marked this conversation as resolved.
Show resolved Hide resolved
"""
self._P = P = np.dot(X.T, X)
self._X = X

# No super res for now.
assert _rank(P) == P.shape[0]

self._reg = reg
self._P_mat = np.array(P)
self._reg_mat = np.array(-reg)
self._h_mat = np.array([0])

def __call__(self, signal):
Q = np.dot(self._X.T, signal)
Q_mat = np.array(-Q)
fodf_sh = quadprog(self._P_mat, Q_mat, self._reg_mat, self._h_mat)
return fodf_sh
6 changes: 3 additions & 3 deletions dipy/reconst/shm.py
Expand Up @@ -600,11 +600,11 @@ def odf(self, sphere):

"""
B = self.model.sampling_matrix(sphere)
return dot(self._shm_coef, B.T)
return dot(self.shm_coeff, B.T)

@auto_attr
def gfa(self):
return _gfa_sh(self._shm_coef, 0)
return _gfa_sh(self.shm_coeff, 0)

@property
def shm_coeff(self):
Expand Down Expand Up @@ -632,7 +632,7 @@ def predict(self, gtab=None, S0=1.0):
if not hasattr(self.model, 'predict'):
msg = "This model does not have prediction implemented yet"
raise NotImplementedError(msg)
return self.model.predict(self.shm_coeff, gtab, S0)
return self.model.predict(self._shm_coef, gtab, S0)


class CsaOdfModel(QballBaseModel):
Expand Down