Skip to content

Commit

Permalink
RF - clean up MSD and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
MrBago committed Dec 25, 2016
1 parent 6dce91a commit 653504c
Show file tree
Hide file tree
Showing 2 changed files with 121 additions and 80 deletions.
130 changes: 73 additions & 57 deletions dipy/reconst/msd.py
@@ -1,67 +1,30 @@
import numpy as np
import numpy.linalg as la

from dipy.core.gradients import GradientTable
from dipy.sims.voxel import single_tensor

from dipy.core.geometry import cart2sphere
from dipy.core import geometry as geo
from dipy.data import default_sphere
from dipy.reconst import shm
from dipy.reconst import csdeconv as csd
from dipy.reconst.multi_voxel import multi_voxel_fit

from cvxopt.solvers import lp, qp
from cvxopt import matrix, solvers

from cvxopt import matrix
from cvxopt import solvers
from cvxopt.solvers import qp
solvers.options['show_progress'] = False

csf_md=3e-3
gm_md=.76e-3
evals_d = np.array([.992, .254, .254]) * 1e-3

def sim_response(sh_order, bvals, evals=evals_d, csf_md=3e-3, gm_md=.76e-3):
bvals = np.array(bvals, copy=True)
evecs = np.zeros((3, 3))
z = np.array([0, 0, 1.])
evecs[:, 0] = z
evecs[:2, 1:] = np.eye(2)

n = np.arange(0, sh_order + 1, 2)
m = np.zeros_like(n)

big_sphere = default_sphere.subdivide()
theta, phi = big_sphere.theta, big_sphere.phi

B = shm.real_sph_harm(m, n, theta[:, None], phi[:, None])
A = shm.real_sph_harm(0, 0, 0, 0)

response = np.empty([len(bvals), len(n) + 2])
for i, bvalue in enumerate(bvals):
gtab = GradientTable(big_sphere.vertices * bvalue)
wm_response = single_tensor(gtab, 1., evals, evecs, snr=None)
response[i, 2:] = np.linalg.lstsq(B, wm_response)[0]

response[i, 0] = np.exp(-bvalue * csf_md) / A
response[i, 1] = np.exp(-bvalue * gm_md) / A

return MultiShellResponse(response, sh_order, bvals)

sh_const = .5 / np.sqrt(np.pi)

def multi_tissue_basis(gtab, sh_order, iso_comp):
"""Builds a basis for multi-shell CSD model"""
r, theta, phi = cart2sphere(*gtab.gradients.T)
if iso_comp < 1:
msg = ("Multi-tissue CSD requires at least 2 tissue compartments")
raise ValueError(msg)
r, theta, phi = geo.cart2sphere(*gtab.gradients.T)
m, n = shm.sph_harm_ind_list(sh_order)
B = shm.real_sph_harm(m, n, theta[:, None], phi[:, None])

if iso_comp == 0:
B[gtab.b0s_mask, :] = 0.
return B
else:
B[np.ix_(gtab.b0s_mask, n > 0)] = 0.
B[np.ix_(gtab.b0s_mask, n > 0)] = 0.

iso = np.empty([B.shape[0], iso_comp])
iso[:] = shm.real_sph_harm(0, 0, 0, 0)
iso[:] = sh_const

B = np.concatenate([iso, B], axis=1)
return B, m, n
Expand All @@ -75,8 +38,7 @@ def __init__(self, response, sh_order, shells):
self.n = np.arange(0, sh_order + 1, 2)
self.m = np.zeros_like(self.n)
self.shells = shells
self.n_isotripic = response.shape[1] - len(self.n)
if self.n_isotripic < 1:
if self.iso < 1:
raise ValueError("sh_order and shape of response do not agree")

@property
Expand All @@ -89,7 +51,7 @@ def closest(haystack, needle):
return diff.argmin(axis=0)


def _inflate_response(response, gtab, n):
def _inflate_response(response, gtab, n, delta):
if any((n % 2) != 0) or (n.max() // 2) >= response.sh_order:
raise ValueError("Response and n do not match")

Expand All @@ -99,23 +61,74 @@ def _inflate_response(response, gtab, n):
n_idx[iso:] = n // 2 + iso

b_idx = closest(response.shells, gtab.bvals)
kernal = response.response / delta

return kernal[np.ix_(b_idx, n_idx)]


def _basic_delta(iso, m, n, theta, phi):
"""Simple delta function"""
wm_d = shm.gen_dirac(m, n, theta, phi)
iso_d = [sh_const] * iso
return np.concatenate([iso_d, wm_d])


def _pos_constrained_delta(iso, m, n, theta, phi, reg_sphere=default_sphere):
"""Delta function optimized to avoid negative lobes."""

x, y, z = geo.sphere2cart(1., theta, phi)

# Realign reg_sphere so that the first vertex is aligned with delta
# orientation (theta, phi).
M = geo.vec2vec_rotmat(reg_sphere.vertices[0], [x, y, z])
new_vertices = np.dot(reg_sphere.vertices, M.T)
_, t, p = geo.cart2sphere(*new_vertices.T)

B = shm.real_sph_harm(m, n, t[:, None], p[:, None])
G = B[:, n != 0]
# c samples the delta function at the delta orientation.
c = G[0]
a, b = G.shape

c = matrix(-c)
G = matrix(-G)
h = matrix(sh_const**2, (a, 1))

# n == 0 is set to sh_const to ensure a normalized delta function.
# n > 0 values are optimized so that delta > 0 on all points of the sphere
# and delta(theta, phi) is maximized.
r = lp(c, G, h)
x = np.asarray(r['x'])[:, 0]
out = np.zeros(B.shape[1])
out[n == 0] = sh_const
out[n != 0] = x

iso_d = [sh_const] * iso
return np.concatenate([iso_d, out])

delta_functions = {"basic":_basic_delta,
"positivity_constrained":_pos_constrained_delta}

return response.response[np.ix_(b_idx, n_idx)]

class MultiShellDeconvModel(shm.SphHarmModel):

def __init__(self, gtab, response, reg_sphere=default_sphere, iso=2):
def __init__(self, gtab, response, reg_sphere=default_sphere, iso=2,
delta_form='basic'):
"""
"""
sh_order = response.sh_order
super(MultiShellDeconvModel, self).__init__(gtab)
B, m, n = multi_tissue_basis(gtab, sh_order, iso)
multiplier_matrix = _inflate_response(response, gtab, n)

r, theta, phi = cart2sphere(reg_sphere.x, reg_sphere.y, reg_sphere.z)
delta_f = delta_functions[delta_form]
delta = delta_f(response.iso, response.m, response.n, 0., 0.)
self.delta = delta
multiplier_matrix = _inflate_response(response, gtab, n, delta)

r, theta, phi = geo.cart2sphere(*reg_sphere.vertices.T)
odf_reg, _, _ = shm.real_sym_sh_basis(sh_order, theta, phi)
reg = np.zeros([i + iso for i in odf_reg.shape])
reg[:iso, :iso] = np.eye(iso) * 1000.
reg[:iso, :iso] = np.eye(iso)
reg[iso:, iso:] = odf_reg

X = B * multiplier_matrix
Expand All @@ -126,14 +139,17 @@ def __init__(self, gtab, response, reg_sphere=default_sphere, iso=2):
self.sphere = reg_sphere
self.response = response
self.B_dwi = B
self.m = m
self.n = n

def predict(self, params, gtab=None, S0=None):
if gtab is None:
X = self._X
else:
iso = self.response.iso
B, m, n = multi_tissue_basis(gtab, self.sh_order, iso)
multiplier_matrix = _inflate_response(self.response, gtab, n)
multiplier_matrix = _inflate_response(self.response, gtab, n,
self.delta)
X = B * multiplier_matrix
return np.dot(params, X.T)

Expand All @@ -157,7 +173,7 @@ def shm_coeff(self):
@property
def volume_fractions(self):
tissue_classes = self.model.response.iso + 1
return self._shm_coef[..., :tissue_classes]
return self._shm_coef[..., :tissue_classes] / sh_const


def _rank(A, tol=1e-8):
Expand Down
71 changes: 48 additions & 23 deletions dipy/reconst/tests/test_msd.py
Expand Up @@ -4,29 +4,15 @@

from dipy.sims.voxel import (multi_tensor, single_tensor)
from dipy.reconst import shm
from dipy.sims import voxel as voxSim
from dipy.data import default_sphere, get_3shell_gtab
from dipy.core.gradients import GradientTable

import dipy.viz.fvtk as fvtk

def show_response(r, m, n):
sphere = default_sphere.mirror().subdivide()
theta, phi = sphere.theta, sphere.phi
B = shm.real_sph_harm(m, n, theta[:, None], phi[:, None])
sp_func = np.dot(r, B.T)
print(sp_func[:, :5])

act = fvtk.sphere_funcs(sp_func, sphere, norm=False)
ren = fvtk.ren()
fvtk.add(ren, act)
fvtk.show(ren)

csf_md=3e-3
gm_md=.76e-3
evals_d = np.array([.992, .254, .254]) * 1e-3

def sim_response(sh_order, bvals, evals=evals_d, csf_md=3e-3, gm_md=.76e-3):
def sim_response(sh_order, bvals, evals=evals_d, csf_md=csf_md, gm_md=gm_md):
bvals = np.array(bvals, copy=True)
evecs = np.zeros((3, 3))
z = np.array([0, 0, 1.])
Expand Down Expand Up @@ -54,6 +40,45 @@ def sim_response(sh_order, bvals, evals=evals_d, csf_md=3e-3, gm_md=.76e-3):
return MultiShellResponse(response, sh_order, bvals)


def _expand(m, iso, coeff):
params = np.zeros(len(m))
params[m == 0] = coeff[iso:]
params = np.concatenate([coeff[:iso], params])
return params


def test_msd_model_delta():
sh_order = 8
gtab = get_3shell_gtab()
shells = np.unique(gtab.bvals // 100.) * 100.
response = sim_response(sh_order, shells, evals_d)
model = MultiShellDeconvModel(gtab, response, delta_form='positivity_constrained')
iso = response.iso

theta, phi = default_sphere.theta, default_sphere.phi
B = shm.real_sph_harm(response.m, response.n, theta[:, None], phi[:, None])

wm_delta = model.delta.copy()
# set isotropic components to zero
wm_delta[:iso] = 0.
wm_delta = _expand(model.m, iso, wm_delta)

for i, s in enumerate(shells):
g = GradientTable(default_sphere.vertices * s)
signal = model.predict(wm_delta, g)
expected = np.dot(response.response[i, iso:], B.T)
npt.assert_array_almost_equal(signal, expected)

signal = model.predict(wm_delta, gtab)
fit = model.fit(signal)
m = model.m
npt.assert_array_almost_equal(fit.shm_coeff[m != 0], 0., 2)

expected = model.delta[2:]
npt.assert_array_almost_equal(fit.shm_coeff[m == 0], expected, 2)



def test_MultiShellDeconvModel():

gtab = get_3shell_gtab()
Expand All @@ -64,20 +89,20 @@ def test_MultiShellDeconvModel():
angles = [(0, 0), (60, 0)]

S_wm, sticks = multi_tensor(gtab, mevals, S0, angles=angles,
fractions=[50., 50.], snr=None)
fractions=[30., 70.], snr=None)
S_gm = np.exp(-gtab.bvals * gm_md)
S_csf = np.exp(-gtab.bvals * csf_md)

sh_order = 8
response = sim_response(sh_order, [0, 1000, 2000, 3500])
model = MultiShellDeconvModel(gtab, response)
signal = S_wm + 2 * S_gm + .5 * S_csf
model = MultiShellDeconvModel(gtab, response,
delta_form='positivity_constrained')
vf = [1.3, .8, 1.9]
signal = sum(i * j for i, j in zip(vf, [S_csf, S_gm, S_wm]))
fit = model.fit(signal)
q = model.predict(fit._shm_coef)

S = fit.predict()
refit = model.fit(S)
npt.assert_array_almost_equal(refit._shm_coef, fit._shm_coef)
npt.assert_array_almost_equal(fit.volume_fractions, vf, 4)

S_pred = fit.predict()
npt.assert_array_almost_equal(S_pred, signal, 3)

test_MultiShellDeconvModel()

0 comments on commit 653504c

Please sign in to comment.