Skip to content

Commit

Permalink
added some tests for the VirtualSource class and made the code a bit …
Browse files Browse the repository at this point in the history
…more robust.
  • Loading branch information
Aaron Parsons authored and takluyver committed Jun 24, 2018
1 parent ace8955 commit d87ca6e
Show file tree
Hide file tree
Showing 9 changed files with 171 additions and 13 deletions.
30 changes: 23 additions & 7 deletions h5py/_hl/group.py
Expand Up @@ -638,17 +638,33 @@ def __getitem__(self, *key):
raise IndexError('Index rank is greater than dataset rank')
# need to deal with integer inputs
tmp = copy(self)
tmp.slice_list = list(key[0] + (slice(None, None, None),)*(len(self.shape)-len(key[0]))) # generate the right slice

if not isinstance(key, tuple):
key = tuple(key)
elif not isinstance(key[0],slice):
key = key[0]
tmp.slice_list = list(key + (slice(None, None, None),)*(len(self.shape)-len(key))) # generate the right slice

# sanitize this slice list to get rid of the nones and integers/floats(?)
tmp.slice_list = [slice(ix) if isinstance(ix, (int,float)) else ix for ix in tmp.slice_list]
new_shape = ()
tmp.slice_list = [slice(ix,ix+1,1) if isinstance(ix, (int,float)) else ix for ix in tmp.slice_list]
new_shape = []
for ix,sl in enumerate(tmp.slice_list):
start = 0 if sl.start is None else sl.start
step = 1 if sl.step is None else sl.step
stop = self.shape[ix]
new_shape+=((stop-start)/abs(step),)
if step>0:
start = 0 if sl.start is None else sl.start
stop = self.shape[ix] if sl.stop is None else sl.stop
new_shape.append((stop-start)/step)
elif step<0:
stop = 0 if sl.stop is None else sl.stop
start = self.shape[ix] if sl.start is None else sl.start
if start>stop: # this gets the same behaviour as numpy array
new_shape.append((start-stop)/abs(step))
else:
new_shape.append(0)
elif step==0:
raise IndexError("A step of 0 is not valid")
tmp.slice_list[ix] = slice(start,stop,step)
tmp.shape = new_shape
tmp.shape = tuple(new_shape)
return tmp

class VirtualTarget(DatasetContainer):
Expand Down
2 changes: 2 additions & 0 deletions h5py/tests/hl/test_vds/__init__.py
@@ -1,3 +1,5 @@
from __future__ import absolute_import

from .test_eiger_high_level import *
from .test_eiger_low_level import *
from .test_excalibur_high_level import *
Expand Down
2 changes: 1 addition & 1 deletion h5py/tests/hl/test_vds/test_eiger_low_level.py
Expand Up @@ -10,7 +10,7 @@
import tempfile


class EigerLowLevelTest(unittest.TestCase):
class TestEigerLowLevel(unittest.TestCase):
def setUp(self):
self.working_dir = tempfile.mkdtemp()
self.fname = ['raw_file_1.h5', 'raw_file_2.h5', 'raw_file_3.h5']
Expand Down
2 changes: 1 addition & 1 deletion h5py/tests/hl/test_vds/test_excalibur_high_level.py
Expand Up @@ -40,7 +40,7 @@ def generate_fem_stripe_image(self, value, dtype='uint16'):



class ExcaliburHighLevelTest(unittest.TestCase):
class TestExcaliburHighLevel(unittest.TestCase):
def create_excalibur_fem_stripe_datafile(self, fname, nframes, excalibur_data,scale):
shape = (nframes,) + excalibur_data.fem_stripe_dimensions
max_shape = shape#(None,) + excalibur_data.fem_stripe_dimensions
Expand Down
2 changes: 1 addition & 1 deletion h5py/tests/hl/test_vds/test_excalibur_low_level.py
Expand Up @@ -40,7 +40,7 @@ def generate_fem_stripe_image(self, value, dtype='uint16'):



class ExcaliburLowLevelTest(unittest.TestCase):
class TestExcaliburLowLevel(unittest.TestCase):
def create_excalibur_fem_stripe_datafile(self, fname, nframes, excalibur_data,scale):
shape = (nframes,) + excalibur_data.fem_stripe_dimensions
max_shape = (nframes,) + excalibur_data.fem_stripe_dimensions
Expand Down
2 changes: 1 addition & 1 deletion h5py/tests/hl/test_vds/test_percival_high_level.py
Expand Up @@ -9,7 +9,7 @@
import h5py as h5
import tempfile

class PercivalHighLevelTest(unittest.TestCase):
class TestPercivalHighLevel(unittest.TestCase):

def setUp(self):
self.working_dir = tempfile.mkdtemp()
Expand Down
2 changes: 1 addition & 1 deletion h5py/tests/hl/test_vds/test_percival_low_level.py
Expand Up @@ -9,7 +9,7 @@
import h5py as h5
import tempfile

class PercivalLowLevelTest(unittest.TestCase):
class TestPercivalLowLevel(unittest.TestCase):

def setUp(self):
self.working_dir = tempfile.mkdtemp()
Expand Down
138 changes: 138 additions & 0 deletions h5py/tests/hl/vds_tests/test_virtual_source.py
@@ -0,0 +1,138 @@
import unittest
import h5py as h5


class TestVirtualSource(unittest.TestCase):
def test_full_slice(self):
dataset = h5.VirtualSource('test','test',(20,30,30))
sliced = dataset[:,:,:]
self.assertEqual(dataset.shape,sliced.shape)

def test_full_slice_inverted(self):
dataset = h5.VirtualSource('test','test',(20,30,30))
sliced = dataset[:,:,::-1]
self.assertEqual(dataset.shape,sliced.shape)

def test_subsampled_slice_inverted(self):
dataset = h5.VirtualSource('test','test',(20,30,30))
sliced = dataset[:,:,::-2]
self.assertEqual((20,30,15),sliced.shape)

def test_integer_indexed(self):
dataset = h5.VirtualSource('test','test',(20,30,30))
sliced = dataset[5,:,:]
self.assertEqual((1,30,30),sliced.shape)

def test_integer_single_indexed(self):
dataset = h5.VirtualSource('test','test',(20,30,30))
sliced = dataset[5]
self.assertEqual((1,30,30),sliced.shape)

def test_two_integer_indexed(self):
dataset = h5.VirtualSource('test','test',(20,30,30))
sliced = dataset[5,:,10]
self.assertEqual((1,30,1),sliced.shape)

def test_single_range(self):
dataset = h5.VirtualSource('test','test',(20,30,30))
sliced = dataset[5:10,:,:]
self.assertEqual((5,)+dataset.shape[1:],sliced.shape)

def test_shape_calculation_positive_step(self):
dataset = h5.VirtualSource('test','test',(20,))
cmp = []
for i in range(5):
d = dataset[2:12+i:3].shape[0]
ref = np.arange(20)[2:12+i:3].size
cmp.append(ref==d)
self.assertEqual(5, sum(cmp))

def test_shape_calculation_positive_step_switched_start_stop(self):
dataset = h5.VirtualSource('test','test',(20,))
cmp = []
for i in range(5):
d = dataset[12+i:2:3].shape[0]
ref = np.arange(20)[12+i:2:3].size
print d,ref
cmp.append(ref==d)
self.assertEqual(5, sum(cmp))


def test_shape_calculation_negative_step(self):
dataset = h5.VirtualSource('test','test',(20,))
cmp = []
for i in range(5):
d = dataset[12+i:2:-3].shape[0]
ref = np.arange(20)[12+i:2:-3].size
cmp.append(ref==d)
self.assertEqual(5, sum(cmp))

def test_shape_calculation_negative_step_switched_start_stop(self):
dataset = h5.VirtualSource('test','test',(20,))
cmp = []
for i in range(5):
d = dataset[2:12+i:-3].shape[0]
ref = np.arange(20)[2:12+i:-3].size
cmp.append(ref==d)
self.assertEqual(5, sum(cmp))


def test_double_range(self):
dataset = h5.VirtualSource('test','test',(20,30,30))
sliced = dataset[5:10,:,20:25]
self.assertEqual((5,30,5),sliced.shape)

def test_double_strided_range(self):
dataset = h5.VirtualSource('test','test',(20,30,30))
sliced = dataset[6:12:2,:,20:26:3]
self.assertEqual((3,30,2,),sliced.shape)

def test_double_strided_range_inverted(self):
dataset = h5.VirtualSource('test','test',(20,30,30))
sliced = dataset[12:6:-2,:,26:20:-3]
self.assertEqual((3,30,2),sliced.shape)

def test_negative_start_index(self):
dataset = h5.VirtualSource('test','test',(20,30,30))
sliced = dataset[-10:16]
self.assertEqual((6,30,30),sliced.shape)

def test_negative_stop_index(self):
dataset = h5.VirtualSource('test','test',(20,30,30))
sliced = dataset[10:-4]
self.assertEqual((6,30,30),sliced.shape)

def test_negative_start_and_stop_index(self):
dataset = h5.VirtualSource('test','test',(20,30,30))
sliced = dataset[-10:-4]
self.assertEqual((6,30,30),sliced.shape)

def test_negative_start_and_stop_and_stride_index(self):
dataset = h5.VirtualSource('test','test',(20,30,30))
sliced = dataset[-4:-10:-2]
self.assertEqual((3,30,30),sliced.shape)
#
def test_ellipsis(self):
dataset = h5.VirtualSource('test','test',(20,30,30))
sliced = dataset[...]
self.assertEqual(dataset.shape,sliced.shape)

def test_ellipsis_end(self):
dataset = h5.VirtualSource('test','test',(20,30,30))
sliced = dataset[0,...]
self.assertEqual((1,)+dataset.shape[1:],sliced.shape)

def test_ellipsis_start(self):
dataset = h5.VirtualSource('test','test',(20,30,30))
sliced = dataset[...,0]
self.assertEqual(dataset.shape[:-1]+(1,),sliced.shape)

def test_ellipsis_sandwich(self):
dataset = h5.VirtualSource('test','test',(20,30,30,40))
sliced = dataset[0,...,5]
self.assertEqual((1,)+dataset.shape[1:-1]+(1,),sliced.shape)



if __name__ == "__main__":
unittest.main()
4 changes: 3 additions & 1 deletion setup.py
Expand Up @@ -158,7 +158,9 @@ def run(self):
license = 'BSD',
url = 'http://www.h5py.org',
download_url = 'https://pypi.python.org/pypi/h5py',
packages = ['h5py', 'h5py._hl', 'h5py.tests', 'h5py.tests.old', 'h5py.tests.hl'],
packages = ['h5py', 'h5py._hl', 'h5py.tests',
'h5py.tests.old', 'h5py.tests.hl',
'h5py.tests.hl.test_vds'],
package_data = package_data,
ext_modules = [Extension('h5py.x',['x.c'])], # To trick build into running build_ext
install_requires = RUN_REQUIRES,
Expand Down

0 comments on commit d87ca6e

Please sign in to comment.