Skip to content

Commit

Permalink
Merge pull request #55 from bmcfee/sampler-enhancements
Browse files Browse the repository at this point in the history
refactored sampler, added random states
  • Loading branch information
bmcfee committed Apr 3, 2017
2 parents eefb7f1 + 127a2a9 commit 7fb3ca0
Show file tree
Hide file tree
Showing 3 changed files with 158 additions and 17 deletions.
4 changes: 2 additions & 2 deletions pumpp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
# -*- encoding: utf-8 -*-
'''Practically universal music pre-processing'''

from .version import version
from .version import version as __version__
from .core import *
from .exceptions import *
from . import feature
from . import task
from .sampler import Sampler
from .sampler import *
127 changes: 114 additions & 13 deletions pumpp/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,21 @@
:toctree: generated/
Sampler
SequentialSampler
'''

from itertools import count

import six
import numpy as np

__all__ = ['Sampler']
from .exceptions import ParameterError

__all__ = ['Sampler', 'SequentialSampler']


class Sampler(object):
'''Generate samples from a pumpp data dict.
'''Generate samples uniformly at random from a pumpp data dict.
Attributes
----------
Expand All @@ -28,7 +32,17 @@ class Sampler(object):
duration : int > 0
the duration (in frames) of each sample
ops : one or more pumpp.feature.FeatureExtractor or pumpp.task.BaseTaskTransformer
random_state : None, int, or np.random.RandomState
If int, random_state is the seed used by the random number
generator;
If RandomState instance, random_state is the random number
generator;
If None, the random number generator is the RandomState instance
used by np.random.
ops : array of pumpp.feature.FeatureExtractor or pumpp.task.BaseTaskTransformer
The operators to include when sampling data.
Expand All @@ -50,11 +64,22 @@ class Sampler(object):
>>> for example in stream(data):
... process(data)
'''
def __init__(self, n_samples, duration, *ops):
def __init__(self, n_samples, duration, *ops, **kwargs):

self.n_samples = n_samples
self.duration = duration

random_state = kwargs.pop('random_state', None)

if random_state is None:
self.rng = np.random
elif isinstance(random_state, int):
self.rng = np.random.RandomState(seed=random_state)
elif isinstance(random_state, np.random.RandomState):
self.rng = random_state
else:
raise ParameterError('Invalid random_state={}'.format(random_state))

fields = dict()
for op in ops:
fields.update(op.fields)
Expand Down Expand Up @@ -91,7 +116,7 @@ def sample(self, data, interval):
index = [slice(None)] * data[key].ndim

# if we have multiple observations for this key, pick one
index[0] = np.random.randint(0, data[key].shape[0])
index[0] = self.rng.randint(0, data[key].shape[0])
index[0] = slice(index[0], index[0] + 1)

if self._time.get(key, None) is not None:
Expand Down Expand Up @@ -122,6 +147,25 @@ def data_duration(self, data):

return min(lengths)

def indices(self, data):
'''Generate patch indices
Parameters
----------
data : dict of np.ndarray
As produced by pumpp.transform
Yields
------
start : int >= 0
The start index of a sample patch
'''
duration = self.data_duration(data)

while True:
# Generate a sampling interval
yield self.rng.randint(0, duration - self.duration)

def __call__(self, data):
'''Generate samples from a data dict.
Expand All @@ -136,14 +180,71 @@ def __call__(self, data):
A sequence of patch samples from `data`,
as parameterized by the sampler object.
'''
duration = self.data_duration(data)
if self.n_samples:
counter = six.moves.range(self.n_samples)
else:
counter = count(0)

for i in count(0):
# are we done?
if self.n_samples and i >= self.n_samples:
break
for i, start in six.moves.zip(counter, self.indices(data)):
yield self.sample(data, slice(start, start + self.duration))

# Generate a sampling interval
start = np.random.randint(0, duration - self.duration)

yield self.sample(data, slice(start, start + self.duration))
class SequentialSampler(Sampler):
'''Sample patches in sequential (temporal) order
Attributes
----------
duration : int > 0
the duration (in frames) of each sample
stride : int > 0
The number of frames to advance between samples.
By default, matches `duration` so there is no overlap.
ops : array of pumpp.feature.FeatureExtractor or pumpp.task.BaseTaskTransformer
The operators to include when sampling data.
random_state : None, int, or np.random.RandomState
If int, random_state is the seed used by the random number
generator;
If RandomState instance, random_state is the random number
generator;
If None, the random number generator is the RandomState instance
See Also
--------
Sampler
'''

def __init__(self, duration, *ops, **kwargs):

stride = kwargs.pop('stride', None)

super(SequentialSampler, self).__init__(None, duration, *ops, **kwargs)

if stride is None:
stride = duration

if not stride > 0:
raise ParameterError('Invalid patch stride={}'.format(stride))
self.stride = stride

def indices(self, data):
'''Generate patch start indices
Parameters
----------
data : dict of np.ndarray
As produced by pumpp.transform
Yields
------
start : int >= 0
The start index of a sample patch
'''
duration = self.data_duration(data)

for start in range(0, duration - self.duration, self.stride):
yield start
44 changes: 42 additions & 2 deletions tests/test_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
# -*- encoding: utf-8 -*-
'''Testing the sampler module'''

import numpy as np

import pytest

import pumpp
Expand Down Expand Up @@ -63,10 +65,25 @@ def duration(request):
return request.param


def test_sampler(data, ops, n_samples, duration):
@pytest.fixture(params=[None, 16, 256,
pytest.mark.xfail(-1, raises=pumpp.ParameterError)],
scope='module')
def stride(request):
return request.param


@pytest.fixture(params=[None, 20170401, np.random.RandomState(100),
pytest.mark.xfail('bad rng',
raises=pumpp.ParameterError)],
scope='module')
def rng(request):
return request.param


def test_sampler(data, ops, n_samples, duration, rng):

MAX_SAMPLES = 30
sampler = pumpp.Sampler(n_samples, duration, *ops)
sampler = pumpp.Sampler(n_samples, duration, *ops, random_state=rng)

# Build the set of reference keys that we want to track
ref_keys = set()
Expand All @@ -92,3 +109,26 @@ def test_sampler(data, ops, n_samples, duration):
assert n == MAX_SAMPLES - 1
else:
assert n == n_samples - 1


def test_sequential_sampler(data, ops, duration, stride, rng):
sampler = pumpp.SequentialSampler(duration, *ops, stride=stride, random_state=rng)

# Build the set of reference keys that we want to track
ref_keys = set()
for op in ops:
ref_keys |= set(op.fields.keys())

for datum in sampler(data):
# First, test that we have the right fields
assert set(datum.keys()) == ref_keys

# Now test that shape is preserved in the right way
for key in datum:
ref_shape = list(data[key].shape)
if sampler._time.get(key, None) is not None:
ref_shape[sampler._time[key]] = duration

# Check that all keys have length=1
assert datum[key].shape[0] == 1
assert list(datum[key].shape[1:]) == ref_shape[1:]

0 comments on commit 7fb3ca0

Please sign in to comment.