Merge branch 'nipy-dipy-master' into textactor2d
dmreagan committed Aug 4, 2017
2 parents b33d5b9 + a0df597 commit f92f06b
Showing 13 changed files with 354 additions and 298 deletions.
3 changes: 2 additions & 1 deletion .travis.yml
@@ -17,7 +17,7 @@ addons:
 
 env:
     global:
-        - DEPENDS="cython numpy scipy matplotlib h5py nibabel cvxopt"
+        - DEPENDS="cython numpy scipy matplotlib h5py nibabel cvxpy"
         - VENV_ARGS="--python=python"
         - INSTALL_TYPE="setup"
         - EXTRA_WHEELS="https://5cf40426d9f06eb7461d-6fe47d9331aba7cd62fc36c7196769e4.ssl.cf2.rackcdn.com"
@@ -88,6 +88,7 @@ before_install:
     # Needed for Python 3.5 wheel fetching
     - $PIPI --upgrade pip setuptools
     - $PIPI nose;
+    - $PIPI numpy;
     - if [ -n "$DEPENDS" ]; then $PIPI $DEPENDS; fi
     - if [ "${COVERAGE}" == "1" ]; then pip install coverage coveralls codecov; fi
     - if [ "${VTK}" == "1" ]; then
4 changes: 2 additions & 2 deletions cythexts.py
@@ -38,7 +38,7 @@ def stamped_pyx_ok(exts, hash_stamp_fname):
     for mod in exts:
         for source in mod.sources:
             base, ext = splitext(source)
-            if not ext in ('.pyx', '.py'):
+            if ext not in ('.pyx', '.py'):
                 continue
             source_hash = sha1(open(source, 'rb').read()).hexdigest()
             c_fname = base + '.c'
@@ -58,7 +58,7 @@ def stamped_pyx_ok(exts, hash_stamp_fname):
         if line.startswith('#'):
             continue
         fname, hash = [e.strip() for e in line.split(',')]
-        if not hash in stamps:
+        if hash not in stamps:
             return False
         # Compare path made canonical for \/
         fname = fname.replace(filesep, '/')
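For context, the stamp check both style fixes live in works roughly as follows — a minimal sketch assuming the stamp-file format visible above (comma-separated "fname, hash" lines, with "#" comment lines), not dipy's full cythexts.py implementation:

from hashlib import sha1
from os.path import splitext

def pyx_hashes(exts):
    # Hash every .pyx/.py source named by the extension modules.
    stamps = set()
    for mod in exts:
        for source in mod.sources:
            base, ext = splitext(source)
            if ext not in ('.pyx', '.py'):
                continue
            with open(source, 'rb') as f:
                stamps.add(sha1(f.read()).hexdigest())
    return stamps

def stamped_pyx_ok_sketch(exts, hash_stamp_fname):
    # The generated C files are trusted only if every hash recorded in
    # the stamp file is among the current source hashes.
    stamps = pyx_hashes(exts)
    with open(hash_stamp_fname) as f:
        for line in f:
            if line.startswith('#'):
                continue
            fname, hash_ = [e.strip() for e in line.split(',')]
            if hash_ not in stamps:
                return False
    return True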
33 changes: 16 additions & 17 deletions dipy/io/dpy.py
@@ -1,12 +1,12 @@
-''' A class for handling large tractography datasets.
+""" A class for handling large tractography datasets.
 
 It is built using the pytables tools which in turn implement
 key features of the HDF5 (hierachical data format) API [1]_.
 
 References
 ----------
 .. [1] http://www.hdfgroup.org/HDF5/doc/H5.intro.html
-'''
+"""
 
 import numpy as np
 
@@ -22,8 +22,7 @@
 tables, have_tables, _ = optional_package('tables')
 
 # Useful variable for backward compatibility.
-if have_tables:
-    TABLES_LESS_3_0 = LooseVersion(tables.__version__) < "3.0"
+TABLES_LESS_3_0 = LooseVersion(tables.__version__) < "3.0" if have_tables else False
 
 # Make sure not to carry across setup module from * import
 __all__ = ['Dpy']
@@ -32,7 +31,7 @@
 class Dpy(object):
     @doctest_skip_parser
     def __init__(self, fname, mode='r', compression=0):
-        ''' Advanced storage system for tractography based on HDF5
+        """ Advanced storage system for tractography based on HDF5
 
         Parameters
         ------------
@@ -50,7 +49,7 @@ def __init__(self, fname, mode='r', compression=0):
         >>> from dipy.io.dpy import Dpy
         >>> def dpy_example():
         ...     fd,fname = mkstemp()
-        ...     fname = fname + '.dpy' #add correct extension
+        ...     fname += '.dpy'#add correct extension
         ...     dpw = Dpy(fname,'w')
         ...     A=np.ones((5,3))
         ...     B=2*A.copy()
@@ -67,7 +66,7 @@ def __init__(self, fname, mode='r', compression=0):
         ...     os.remove(fname) #delete file from disk
         >>> dpy_example() # skip if not have_tables
-        '''
+        """
 
         self.mode = mode
         self.f = tables.openFile(fname, mode=self.mode) if TABLES_LESS_3_0 else tables.open_file(fname, mode=self.mode)
@@ -116,30 +115,30 @@ def version(self):
         return ver[0].decode()
 
     def write_track(self, track):
-        ''' write on track each time
-        '''
+        """ write on track each time
+        """
         self.tracks.append(track.astype(np.float32))
         self.curr_pos += track.shape[0]
         self.offsets.append(np.array([self.curr_pos]).astype(np.int64))
 
     def write_tracks(self, T):
-        ''' write many tracks together
-        '''
+        """ write many tracks together
+        """
         for track in T:
             self.tracks.append(track.astype(np.float32))
             self.curr_pos += track.shape[0]
             self.offsets.append(np.array([self.curr_pos]).astype(np.int64))
 
     def read_track(self):
-        ''' read one track each time
-        '''
+        """ read one track each time
+        """
         off0, off1 = self.offsets[self.offs_pos:self.offs_pos + 2]
         self.offs_pos += 1
         return self.tracks[off0:off1]
 
     def read_tracksi(self, indices):
-        ''' read tracks with specific indices
-        '''
+        """ read tracks with specific indices
+        """
         T = []
         for i in indices:
             # print(self.offsets[i:i+2])
@@ -148,8 +147,8 @@ def read_tracksi(self, indices):
         return T
 
     def read_tracks(self):
-        ''' read the entire tractography
-        '''
+        """ read the entire tractography
+        """
         I = self.offsets[:]
         TR = self.tracks[:]
         T = []
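The recurring TABLES_LESS_3_0 guard in this file exists because PyTables renamed its camelCase API (openFile) to snake_case (open_file) in 3.0. A minimal sketch of the pattern, using the names that appear in the diff plus a hypothetical open_h5 helper:

from distutils.version import LooseVersion
from dipy.utils.optpkg import optional_package

tables, have_tables, _ = optional_package('tables')

# Evaluates to False when pytables is absent, so importing this module
# still succeeds on machines without the optional dependency.
TABLES_LESS_3_0 = LooseVersion(tables.__version__) < "3.0" if have_tables else False

def open_h5(fname, mode='r'):
    # Hypothetical helper: dispatch to whichever API this PyTables has.
    if TABLES_LESS_3_0:
        return tables.openFile(fname, mode=mode)
    return tables.open_file(fname, mode=mode)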
40 changes: 13 additions & 27 deletions dipy/io/peaks.py
@@ -3,24 +3,21 @@
 import os
 import numpy as np
 
+from dipy.core.sphere import Sphere
 from dipy.direction.peaks import (PeaksAndMetrics,
                                   reshape_peaks_for_visualization)
-from dipy.core.sphere import Sphere
+from dipy.io.image import save_nifti
 
 from distutils.version import LooseVersion
 
 # Conditional import machinery for pytables
 from dipy.utils.optpkg import optional_package
 
 # Allow import, but disable doctests, if we don't have pytables
-tables, have_tables, _ = optional_package('tables')
+tables, have_tables, _ = optional_package('tables', 'PyTables is not installed')
 
 # Useful variable for backward compatibility.
-if have_tables:
-    TABLES_LESS_3_0 = LooseVersion(tables.__version__) < "3.0"
-
-from dipy.data import get_sphere
-from dipy.core.sphere import Sphere
-from dipy.io.image import save_nifti
+TABLES_LESS_3_0 = LooseVersion(tables.__version__) < "3.0" if have_tables else False
 
 
 def _safe_save(f, group, array, name):
@@ -34,10 +31,11 @@ def _safe_save(f, group, array, name):
     name : string
     """
 
-    if TABLES_LESS_3_0:
-        func_create_carray = f.createCArray
-    else:
-        func_create_carray = f.create_carray
+    if not have_tables:
+        # We generate a TripWireError via this call
+        _ = tables.any_attributes
+
+    func_create_carray = f.createCArray if TABLES_LESS_3_0 else f.create_carray
 
     if array is not None:
         atom = tables.Atom.from_dtype(array.dtype)
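The new `_ = tables.any_attributes` line deserves a note: when pytables is missing, optional_package returns a TripWire stand-in rather than a real module, and touching any attribute on it raises TripWireError. A simplified model of that behaviour (a sketch, not dipy's actual dipy.utils.tripwire code):

class TripWireError(AttributeError):
    # dipy's real error class also derives from AttributeError.
    pass

class TripWire(object):
    # Stand-in object returned when an optional package is missing.
    def __init__(self, msg):
        self._msg = msg

    def __getattr__(self, name):
        raise TripWireError(self._msg)

tables = TripWire('PyTables is not installed')
# Any attribute access now fails loudly at the point of use:
# tables.any_attributes  ->  TripWireError: PyTables is not installed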
@@ -189,20 +187,9 @@ def save_peaks(fname, pam, affine=None, verbose=False):
                                  [b"0.0.1"], 'PAM5 version number')
         version_string = f.root.version[0].decode()
 
-    try:
-        affine = pam.affine
-    except AttributeError:
-        pass
-
-    try:
-        shm_coeff = pam.shm_coeff
-    except AttributeError:
-        shm_coeff = None
-
-    try:
-        odf = pam.odf
-    except AttributeError:
-        odf = None
+    affine = pam.affine if hasattr(pam, 'affine') else affine
+    shm_coeff = pam.shm_coeff if hasattr(pam, 'shm_coeff') else None
+    odf = pam.odf if hasattr(pam, 'odf') else None
 
     _safe_save(f, group, affine, 'affine')
     _safe_save(f, group, pam.peak_dirs, 'peak_dirs')
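The three hasattr conditionals that replace the try/except blocks are equivalent to getattr with a default; an even more compact form of the same logic would be (sketch; pam and the affine argument as in save_peaks above):

affine = getattr(pam, 'affine', affine)
shm_coeff = getattr(pam, 'shm_coeff', None)
odf = getattr(pam, 'odf', None)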
@@ -271,4 +258,3 @@ def peaks_to_niftis(pam,
         save_nifti(fname_indices, pam.peak_indices, pam.affine)
 
     save_nifti(fname_gfa, pam.gfa, pam.affine)
-
65 changes: 51 additions & 14 deletions dipy/io/tests/test_io_peaks.py
@@ -4,31 +4,26 @@
 import numpy as np
 import numpy.testing as npt
 
-from distutils.version import LooseVersion
-from dipy.reconst.peaks import PeaksAndMetrics
 from nibabel.tmpdirs import InTemporaryDirectory
 
+from dipy.reconst.peaks import PeaksAndMetrics
+from dipy.data import get_sphere
+from dipy.io.peaks import load_peaks, save_peaks, peaks_to_niftis, _safe_save
+
 # Conditional import machinery for pytables
 from dipy.utils.optpkg import optional_package
+from dipy.utils.tripwire import TripWireError
 
 # Allow import, but disable doctests, if we don't have pytables
 tables, have_tables, _ = optional_package('tables')
 
-from dipy.data import get_sphere
-from dipy.core.sphere import Sphere
-
-from dipy.io.peaks import load_peaks, save_peaks, peaks_to_niftis
-
 # Decorator to protect tests from being run without pytables present
-iftables = npt.dec.skipif(not have_tables,
-                          'Pytables does not appear to be installed')
+iftables = npt.dec.skipif(not have_tables, 'Pytables does not appear to be installed')
 
 
 @iftables
 def test_io_peaks():
-
     with InTemporaryDirectory():
-
         fname = 'test.pam5'
 
         sphere = get_sphere('repulsion724')
@@ -54,7 +49,9 @@ def test_io_peaks():
         pam2.affine = None
 
         fname2 = 'test2.pam5'
-        save_peaks(fname2, pam2)
+        save_peaks(fname2, pam2, np.eye(4))
+        pam2_res = load_peaks(fname2, verbose=True)
+        npt.assert_array_equal(pam.peak_dirs, pam2_res.peak_dirs)
 
         pam3 = load_peaks(fname2, verbose=False)
 
@@ -97,7 +94,6 @@ def test_io_peaks():
 
         del pam.shm_coeff
         save_peaks(fname6, pam, verbose=True)
-        pam_tmp = load_peaks(fname6, True)
 
         fname_shm = 'shm.nii.gz'
         fname_dirs = 'dirs.nii.gz'
@@ -116,6 +112,47 @@ def test_io_peaks():
         os.path.isfile(fname_gfa)
 
 
-if __name__ == '__main__':
+def test_io_save_peaks_error():
+    with InTemporaryDirectory():
+        fname = 'test.pam5'
+
+        pam = PeaksAndMetrics()
+
+        npt.assert_raises(IOError, save_peaks, 'test.pam', pam)
+        npt.assert_raises(ValueError, save_peaks, fname, pam)
+
+        sphere = get_sphere('repulsion724')
+
+        pam.affine = np.eye(4)
+        pam.peak_dirs = np.random.rand(10, 10, 10, 5, 3)
+        pam.peak_values = np.zeros((10, 10, 10, 5))
+        pam.peak_indices = np.zeros((10, 10, 10, 5))
+        pam.shm_coeff = np.zeros((10, 10, 10, 45))
+        pam.sphere = sphere
+        pam.B = np.zeros((45, sphere.vertices.shape[0]))
+        pam.total_weight = 0.5
+        pam.ang_thr = 60
+        pam.gfa = np.zeros((10, 10, 10))
+        pam.qa = np.zeros((10, 10, 10, 5))
+        pam.odf = np.zeros((10, 10, 10, sphere.vertices.shape[0]))
+
+        if not have_tables:
+            npt.assert_raises(TripWireError, save_peaks, fname, pam)
+
+
+@npt.dec.skipif(have_tables)
+def test_io_safe_save_error():
+
+    class Mock(object):
+        pass
+
+    fake_hdf5 = Mock()
+    fake_group = Mock()
+    fake_array = np.eye(4)
+    fake_name = "function_tester"
+
+    npt.assert_raises(TripWireError, _safe_save, fake_hdf5, fake_group,
+                      fake_array, fake_name)
+
+
+if __name__ == '__main__':
     npt.run_module_suite()
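The two skip decorators in this file are mirror images: iftables runs a test only when pytables is available, while @npt.dec.skipif(have_tables) runs test_io_safe_save_error only when it is absent, so the TripWireError path gets exercised on minimal CI installs. A sketch of the same pattern for any optional dependency (the package name here is illustrative, not dipy API):

import numpy.testing as npt
from dipy.utils.optpkg import optional_package

# Hypothetical optional dependency; replace with a real package name.
pkg, have_pkg, _ = optional_package('some_optional_pkg')

# Runs only when the dependency is present...
@npt.dec.skipif(not have_pkg, 'some_optional_pkg is not installed')
def test_with_pkg():
    assert pkg is not None

# ...and this one only when it is absent, covering the failure path.
@npt.dec.skipif(have_pkg)
def test_without_pkg():
    # Attribute access on the TripWire stand-in should raise.
    npt.assert_raises(Exception, getattr, pkg, 'anything')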
