Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make PeaksAndMetrics pickle-able #1195

Merged
merged 14 commits into from
Mar 22, 2017
Merged
93 changes: 71 additions & 22 deletions dipy/direction/peaks.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,8 +157,68 @@ def peak_directions(odf, sphere, relative_peak_threshold=.5,
return directions, values, indices


def _pam_from_attrs(klass, sphere, peak_indices, peak_values, peak_dirs,
gfa, qa, shm_coeff, B, odf):
"""
Construct a PeaksAndMetrics object (or object of a subclass) from its
attributes.

This is also useful for pickling/unpickling of these objects (see also
:func:`__reduce__` below).

Parameters
----------
klass : class
The class of object to be created.
sphere : `Sphere` class instance.
Sphere for discretization.
peak_indices : ndarray
Indices (in sphere vertices) of the peaks in each voxel.
peak_values : ndarray
The value of the peaks.
peak_dirs : ndarray
The direction of each peak.
gfa : ndarray
The Generalized Fractional Anisotropy in each voxel.
qa : ndarray
Quantitative Anisotropy in each voxel.
shm_coeff : ndarray
The coefficients of the spherical harmonic basis for the ODF in
each voxel.
B : ndarray
The spherical harmonic matrix, for multiplication with the
coefficients.
odf : ndarray
The orientation distribution function on the sphere in each voxel.

Returns
-------
pam : Instance of the class `klass`.
"""
this_pam = klass()
this_pam.sphere = sphere
this_pam.peak_dirs = peak_dirs
this_pam.peak_values = peak_values
this_pam.peak_indices = peak_indices
this_pam.gfa = gfa
this_pam.qa = qa
this_pam.shm_coeff = shm_coeff
this_pam.B = B
this_pam.odf = odf
return this_pam


class PeaksAndMetrics(PeaksAndMetricsDirectionGetter):
pass
def __reduce__(self): return _pam_from_attrs, (self.__class__,
self.sphere,
self.peak_indices,
self.peak_values,
self.peak_dirs,
self.gfa,
self.qa,
self.shm_coeff,
self.B,
self.odf)


def _peaks_from_model_parallel(model, data, sphere, relative_peak_threshold,
Expand Down Expand Up @@ -480,27 +540,16 @@ def peaks_from_model(model, data, sphere, relative_peak_threshold,

qa_array /= global_max

pam = PeaksAndMetrics()
pam.sphere = sphere
pam.peak_dirs = peak_dirs
pam.peak_values = peak_values
pam.peak_indices = peak_indices
pam.gfa = gfa_array
pam.qa = qa_array

if return_sh:
pam.shm_coeff = shm_coeff
pam.B = B
else:
pam.shm_coeff = None
pam.B = None

if return_odf:
pam.odf = odf_array
else:
pam.odf = None

return pam
return _pam_from_attrs(PeaksAndMetrics,
sphere,
peak_indices,
peak_values,
peak_dirs,
gfa_array,
qa_array,
shm_coeff if return_sh else None,
B if return_sh else None,
odf_array if return_odf else None)


def gfa(samples):
Expand Down
26 changes: 26 additions & 0 deletions dipy/direction/tests/test_peaks.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
import numpy as np

import pickle
from io import BytesIO

from numpy.testing import (assert_array_equal, assert_array_almost_equal,
assert_almost_equal, run_module_suite,
assert_equal, assert_)
Expand Down Expand Up @@ -485,6 +489,28 @@ def test_peaksFromModel():
assert_array_equal(pam.peak_indices[mask, 0], odf_argmax)
assert_array_equal(pam.peak_indices[mask, 1:], -1)

# Test serialization and deserialization:
for normalize_peaks in [True, False]:
for return_odf in [True, False]:
for return_sh in [True, False]:
pam = peaks_from_model(model, data, _sphere, .5, 45,
normalize_peaks=normalize_peaks,
return_odf=return_odf,
return_sh=return_sh)

b = BytesIO()
pickle.dump(pam, b)
b.seek(0)
new_pam = pickle.load(b)
b.close()

for attr in ['peak_dirs', 'peak_values', 'peak_indices',
'gfa', 'qa', 'shm_coeff', 'B', 'odf']:
assert_array_equal(getattr(pam, attr),
getattr(new_pam, attr))
assert_array_equal(pam.sphere.vertices,
new_pam.sphere.vertices)


def test_peaksFromModelParallel():
SNR = 100
Expand Down
1 change: 0 additions & 1 deletion dipy/reconst/peak_direction_getter.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -125,4 +125,3 @@ cdef class PeaksAndMetricsDirectionGetter(DirectionGetter):
return 0
else:
return 1