Skip to content

Commit

Permalink
Merge pull request #15 from bmcfee/sampler
Browse files Browse the repository at this point in the history
Sampler module
  • Loading branch information
bmcfee committed Sep 1, 2016
2 parents 54a1755 + 02d3232 commit c5cfedd
Show file tree
Hide file tree
Showing 5 changed files with 160 additions and 3 deletions.
1 change: 1 addition & 0 deletions pumpp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@
from .core import *
from . import feature
from . import task
from .sampler import Sampler
64 changes: 64 additions & 0 deletions pumpp/sampler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
'''The sampler'''

from itertools import count

import numpy as np


class Sampler(object):
    '''Generate fixed-duration patches from a dictionary of arrays.

    Parameters
    ----------
    n_samples : int or None
        Number of patches to generate per call.
        If falsy (`None` or `0`), patches are generated indefinitely.

    duration : int > 0
        Length of each patch along the time-like axis.

    *ops
        Objects exposing a `fields` mapping from key to an object with a
        `shape` attribute.  A `None` entry in `shape` marks that
        dimension as time-like.
    '''

    def __init__(self, n_samples, duration, *ops):

        self.n_samples = n_samples
        self.duration = duration

        fields = dict()
        for op in ops:
            fields.update(op.fields)

        # Pre-determine which fields have time-like indices.
        # Fields without a time-like dimension map to None.
        self._time = {key: None for key in fields}
        for key, field in fields.items():
            if None in field.shape:
                # Add one for the batching index
                self._time[key] = 1 + field.shape.index(None)

    def sample(self, data, interval):
        '''Slice every known field of `data` along its time-like axis.

        Parameters
        ----------
        data : dict
            Mapping from key to array.  Keys that were not declared by
            any of the ops are dropped from the output.

        interval : slice
            The slice to apply along the time-like axis.

        Returns
        -------
        data_slice : dict
            Mapping from key to the sliced array.  Fields without a
            time-like axis are indexed with full slices on every axis.
        '''
        data_slice = dict()

        for key in data:
            if key not in self._time:
                continue

            index = [slice(None)] * data[key].ndim

            if self._time[key] is not None:
                index[self._time[key]] = interval

            # Multi-dimensional indexing requires a tuple: NumPy treats a
            # list as an index array and rejects a list of slices.
            data_slice[key] = data[key][tuple(index)]

        return data_slice

    def data_duration(self, data):
        '''Compute the usable duration of `data`.

        Parameters
        ----------
        data : dict
            Mapping from key to array; must contain every time-like field.

        Returns
        -------
        length : int
            The minimum extent over all time-like axes.

        Raises
        ------
        ValueError
            If none of the registered fields has a time-like axis.
        '''
        # Find all the time-like indices of the data
        lengths = [data[key].shape[axis]
                   for key, axis in self._time.items()
                   if axis is not None]

        if not lengths:
            raise ValueError('No time-like fields are available '
                             'to determine the data duration')

        return min(lengths)

    def __call__(self, data):
        '''Generate randomly positioned patches from `data`.

        Parameters
        ----------
        data : dict
            Mapping from key to array, as accepted by `sample`.

        Yields
        ------
        data_slice : dict
            A patch of length `self.duration` along each time-like axis.

        Raises
        ------
        ValueError
            If the data is shorter than `self.duration`.
            Raised on the first iteration, since this is a generator.
        '''
        duration = self.data_duration(data)

        if self.duration > duration:
            raise ValueError('Data duration={} is less than '
                             'sample duration={}'.format(duration,
                                                         self.duration))

        for i in count(0):
            # are we done?
            if self.n_samples and i >= self.n_samples:
                break

            # Generate a sampling interval.  The upper bound of randint is
            # exclusive, so add one to make the final window reachable
            # (and to allow data exactly as long as the patch).
            start = np.random.randint(0, duration - self.duration + 1)

            yield self.sample(data, slice(start, start + self.duration))
2 changes: 1 addition & 1 deletion pumpp/task/regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def __init__(self, name, namespace, dimension, dtype=np.float32):
self.dimension = dimension
self.dtype = dtype

self.register('vector', [None, self.dimension], self.dtype)
self.register('vector', [1, self.dimension], self.dtype)

def empty(self, duration):
ann = super(VectorTransformer, self).empty(duration)
Expand Down
92 changes: 92 additions & 0 deletions tests/test_sampler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
'''Testing the sampler module'''

import pytest

import pumpp


# Make a fixture with some audio and task output
@pytest.fixture(params=[11025], scope='module')
def sr(request):
    '''Sampling rate used by the feature extractors and tasks.'''
    rate = request.param
    return rate


@pytest.fixture(params=[512], scope='module')
def hop_length(request):
    '''Hop length shared by all operators in the test.'''
    hop = request.param
    return hop


@pytest.fixture(scope='module')
def ops(sr, hop_length):
    '''Operators used by the sampler test: two features and two tasks.'''

    # Let's put on two feature extractors
    stft = pumpp.feature.STFT(name='stft', sr=sr,
                              hop_length=hop_length,
                              n_fft=hop_length)

    tempogram = pumpp.feature.Tempogram(name='rhythm', sr=sr,
                                        hop_length=hop_length,
                                        win_length=hop_length)

    # A time-varying annotation
    beat = pumpp.task.BeatTransformer(name='beat', sr=sr,
                                      hop_length=hop_length)

    # And a static annotation
    vector = pumpp.task.VectorTransformer(namespace='vector',
                                          dimension=32,
                                          name='vec')

    yield [stft, tempogram, beat, vector]


@pytest.fixture(scope='module')
def data(ops):
    '''Output of `pumpp.transform` on the test audio and annotation.'''
    paths = ('tests/data/test.ogg', 'tests/data/test.jams')
    return pumpp.transform(paths[0], paths[1], *ops)


@pytest.fixture(params=[4, 16, None], scope='module')
def n_samples(request):
    '''Number of samples requested from the sampler (`None` = unbounded).'''
    requested = request.param
    return requested


@pytest.fixture(params=[16, 32], scope='module')
def duration(request):
    '''Patch duration passed to the sampler.'''
    patch_length = request.param
    return patch_length


def test_sampler(data, ops, n_samples, duration):
    '''Check the keys, shapes, and number of patches the sampler yields.'''

    # Cap on iterations so that the unbounded sampler terminates
    MAX_SAMPLES = 30
    sampler = pumpp.Sampler(n_samples, duration, *ops)

    # The reference keys are the union of all operator field names
    ref_keys = {key for op in ops for key in op.fields.keys()}

    for datum, n in zip(sampler(data), range(MAX_SAMPLES)):
        # Every patch must carry exactly the reference fields
        assert set(datum.keys()) == ref_keys

        # Shapes must match the source data, except on the time-like
        # axis, which is cut down to the patch duration
        for key in datum:
            time_axis = sampler._time[key]
            expected_shape = list(data[key].shape)
            if time_axis is not None:
                expected_shape[time_axis] = duration

            assert list(datum[key].shape) == expected_shape

    # The final loop index tells us how many patches were generated
    expected_count = MAX_SAMPLES if n_samples is None else n_samples
    assert n == expected_count - 1
4 changes: 2 additions & 2 deletions tests/test_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,7 @@ def test_task_vector_absent(dimension, name):
assert not np.any(output[var_name])

for key in trans.fields:
assert shape_match(output[key].shape[1:], trans.fields[key].shape)
assert shape_match(output[key].shape, trans.fields[key].shape)
assert type_match(output[key].dtype, trans.fields[key].dtype)


Expand Down Expand Up @@ -426,7 +426,7 @@ def test_task_vector_present(target_dimension, data_dimension, name):
assert np.allclose(output[var_name], ann.data.loc[0].value)

for key in trans.fields:
assert shape_match(output[key].shape[1:], trans.fields[key].shape)
assert shape_match(output[key].shape, trans.fields[key].shape)
assert type_match(output[key].dtype, trans.fields[key].dtype)


Expand Down

0 comments on commit c5cfedd

Please sign in to comment.