Skip to content

Commit

Permalink
Merge pull request #81 from bmcfee/multi-time
Browse files Browse the repository at this point in the history
Support multiple time-like indices
  • Loading branch information
bmcfee committed Jul 18, 2017
2 parents 2ca0956 + 848a535 commit c8d7be6
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 14 deletions.
16 changes: 9 additions & 7 deletions pumpp/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,9 +141,11 @@ def add(self, operator):
raise ParameterError('Operator {} must be a TaskTransformer '
'or FeatureExtractor'.format(operator))
for key in operator.fields:
self._time[key] = None
if None in operator.fields[key].shape:
self._time[key] = 1 + operator.fields[key].shape.index(None)
self._time[key] = []
# We add 1 to the dimension here to account for batching
for tdim, idx in enumerate(operator.fields[key].shape, 1):
if idx is None:
self._time[key].append(tdim)

def data_duration(self, data):
'''Compute the valid data duration of a dict
Expand All @@ -161,8 +163,8 @@ def data_duration(self, data):
# Find all the time-like indices of the data
lengths = []
for key in self._time:
if self._time[key] is not None:
lengths.append(data[key].shape[self._time[key]])
for idx in self._time.get(key, []):
lengths.append(data[key].shape[idx])

return min(lengths)

Expand All @@ -185,8 +187,8 @@ def crop(self, data):
data_out = dict()
for key in data:
idx = [slice(None)] * data[key].ndim
if key in self._time and self._time[key] is not None:
idx[self._time[key]] = slice(duration)
for tdim in self._time.get(key, []):
idx[tdim] = slice(duration)
data_out[key] = data[key][idx]

return data_out
4 changes: 2 additions & 2 deletions pumpp/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,8 @@ def sample(self, data, interval):
index[0] = self.rng.randint(0, data[key].shape[0])
index[0] = slice(index[0], index[0] + 1)

if self._time.get(key, None) is not None:
index[self._time[key]] = interval
for tdim in self._time[key]:
index[tdim] = interval

data_slice[key] = data[key][index]

Expand Down
15 changes: 10 additions & 5 deletions tests/test_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,8 @@ def test_sampler(data, ops, n_samples, duration, rng):
# Now test that shape is preserved in the right way
for key in datum:
ref_shape = list(data[key].shape)
if sampler._time.get(key, None) is not None:
ref_shape[sampler._time[key]] = duration
for tdim in sampler._time[key]:
ref_shape[tdim] = duration

# Check that all keys have length=1
assert datum[key].shape[0] == 1
Expand Down Expand Up @@ -127,8 +127,8 @@ def test_sequential_sampler(data, ops, duration, stride, rng):
# Now test that shape is preserved in the right way
for key in datum:
ref_shape = list(data[key].shape)
if sampler._time.get(key, None) is not None:
ref_shape[sampler._time[key]] = duration
for tdim in sampler._time[key]:
ref_shape[tdim] = duration

# Check that all keys have length=1
assert datum[key].shape[0] == 1
Expand All @@ -144,6 +144,7 @@ def test_slicer():
scope2 = pumpp.base.Scope('test2')
scope2.register('first', (None, 5), np.int)
scope2.register('second', (20, None), np.int)
scope2.register('square', (None, None, 3), np.int)

slicer = pumpp.base.Slicer(scope1, scope2)

Expand All @@ -152,7 +153,8 @@ def test_slicer():
'test1/second': np.random.randint(0, 7, size=(1, 2, 100)),
'test1/none': np.random.randint(0, 7, size=(1, 16, 16)),
'test2/first': np.random.randint(0, 7, size=(1, 9, 5)),
'test2/second': np.random.randint(0, 7, (1, 20, 105))}
'test2/second': np.random.randint(0, 7, (1, 20, 105)),
'test2/square': np.random.randint(0, 7, (1, 20, 20, 3))}

data_out = slicer.crop(data_in)
assert set(data_out.keys()) == set(data_in.keys())
Expand All @@ -172,6 +174,9 @@ def test_slicer():
assert data_out['test2/second'].shape == (1, 20, 8)
assert np.all(data_out['test2/second'] == data_in['test2/second'][:, :, :8])

assert data_out['test2/square'].shape == (1, 8, 8, 3)
assert np.all(data_out['test2/square'] == data_in['test2/square'][:, :8, :8, :])


@pytest.mark.xfail(raises=pumpp.ParameterError)
def test_slicer_fail():
Expand Down

0 comments on commit c8d7be6

Please sign in to comment.