Skip to content

Commit

Permalink
Merge pull request #96 from bmcfee/length-exception-sampler
Browse files Browse the repository at this point in the history
fixed #95, check data duration before sampling
  • Loading branch information
bmcfee committed Jun 19, 2018
2 parents 68a14ca + d383865 commit a20c810
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 1 deletion.
6 changes: 5 additions & 1 deletion pumpp/sampler.py
Expand Up @@ -17,7 +17,7 @@
import numpy as np

from .base import Slicer
from .exceptions import ParameterError
from .exceptions import ParameterError, DataError

__all__ = ['Sampler', 'SequentialSampler', 'VariableLengthSampler']

Expand Down Expand Up @@ -134,6 +134,10 @@ def indices(self, data):
'''
duration = self.data_duration(data)

if self.duration > duration:
raise DataError('Data duration={} is less than '
'sample duration={}'.format(duration, self.duration))

while True:
# Generate a sampling interval
yield self.rng.randint(0, duration - self.duration + 1)
Expand Down
34 changes: 34 additions & 0 deletions tests/test_sampler.py
Expand Up @@ -219,3 +219,37 @@ def test_vlsampler(data, ops, n_samples, durations, rng):
assert n == MAX_SAMPLES - 1
else:
assert n == n_samples - 1


def test_sampler_short_error(data, ops):
    '''Check that sampling fails when the requested patch is too long.

    A `Sampler` whose sample duration exceeds the duration of the data
    must raise `pumpp.DataError` as soon as a sample is requested,
    rather than drawing an invalid interval.

    Parameters
    ----------
    data
        Fixture providing the data to sample from.
    ops
        Fixture providing the operators passed through to the sampler.
    '''
    MAX_SAMPLES = 2

    # NOTE(review): 5000 is assumed to exceed the duration of the `data`
    # fixture -- confirm against the fixture definition.
    sampler = pumpp.Sampler(MAX_SAMPLES, 5000, *ops)

    # Pull a single item so the duration check runs inside the `raises`
    # context.  Unlike a non-strict `xfail(raises=...)` marker, this fails
    # loudly if no exception is raised.
    with pytest.raises(pumpp.DataError):
        next(iter(sampler(data)))


0 comments on commit a20c810

Please sign in to comment.