Skip to content

Commit

Permalink
RF: Use the same step in prediction as you use in fitting.
Browse files Browse the repository at this point in the history
  • Loading branch information
arokem committed Feb 2, 2016
1 parent 2355db4 commit fc87b90
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 2 deletions.
24 changes: 22 additions & 2 deletions dipy/reconst/dti.py
Original file line number Diff line number Diff line change
Expand Up @@ -1133,7 +1133,7 @@ def adc(self, sphere):
"""
return apparent_diffusion_coef(self.quadratic_form, sphere)

def predict(self, gtab, S0=1):
def predict(self, gtab, S0=1, step=1e4):
r"""
Given a model fit, predict the signal on the vertices of a sphere
Expand Down Expand Up @@ -1162,7 +1162,27 @@ def predict(self, gtab, S0=1):
which a signal is to be predicted and $b$ is the b value provided in
the GradientTable input for that direction
"""
return tensor_prediction(self.model_params[..., 0:12], gtab, S0=S0)
shape = self.model_params.shape[:-1]
size = np.prod(shape)
step = self.model.kwargs.get('step', size)
#if step == 2:
# 1/0.
if step >= size:
return tensor_prediction(self.model_params[..., 0:12], gtab, S0=S0)
params = np.reshape(self.model_params,
(-1, self.model_params.shape[-1]))
predict = np.empty((size, gtab.bvals.shape[0]))
if isinstance(S0, np.ndarray):
S0 = S0.ravel()
for i in range(0, size, step):
if isinstance(S0, np.ndarray):
this_S0 = S0[i:i+step]
else:
this_S0 = S0
predict[i:i+step] = tensor_prediction(params[i:i+step], gtab,
S0=this_S0)
return predict.reshape(shape + (gtab.bvals.shape[0], ))



def iter_fit_tensor(step=1e4):
Expand Down
12 changes: 12 additions & 0 deletions dipy/reconst/tests/test_dti.py
Original file line number Diff line number Diff line change
Expand Up @@ -674,6 +674,18 @@ def test_predict():
p = dtif.predict(gtab, S0)
assert_equal(p.shape, data.shape)

# Use a smaller step in predicting:

dtim = dti.TensorModel(gtab, step=2)
dtif = dtim.fit(data)
S0 = np.mean(data[..., gtab.b0s_mask], -1)
p = dtif.predict(gtab, S0)
assert_equal(p.shape, data.shape)
# And with a scalar S0:
S0 = 1
p = dtif.predict(gtab, S0)
assert_equal(p.shape, data.shape)


def test_eig_from_lo_tri():
psphere = get_sphere('symmetric362')
Expand Down

0 comments on commit fc87b90

Please sign in to comment.