Skip to content

Commit

Permalink
Merge pull request #80 from bmcfee/autocrop
Browse files Browse the repository at this point in the history
Autocrop
  • Loading branch information
bmcfee committed Jul 18, 2017
2 parents 8b30ee2 + 0ce50b9 commit 2ca0956
Show file tree
Hide file tree
Showing 5 changed files with 176 additions and 67 deletions.
87 changes: 84 additions & 3 deletions pumpp/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@
from collections import namedtuple, Iterable
import numpy as np

from .exceptions import *
__all__ = ['Tensor', 'Scope']
from .exceptions import ParameterError
__all__ = ['Tensor', 'Scope', 'Slicer']

# This type is used for storing shape information
Tensor = namedtuple('Tensor', ['shape', 'dtype'])
'''
Apparently you can document namedtuples here
Multi-dimensional array descriptions: `shape` and `dtype`
'''


Expand Down Expand Up @@ -109,3 +109,84 @@ def merge(self, data):
data_out[self.scope(key)] = np.stack([np.asarray(d[key]) for d in data],
axis=0)
return data_out


class Slicer(object):
    '''Slicer can compute the duration of data with time-like fields,
    and slice down to the common time index.

    This class serves as a base for Sampler and Pump, and should not
    be used directly.

    Parameters
    ----------
    ops : one or more Scope (TaskTransformer or FeatureExtractor)
    '''
    def __init__(self, *ops):

        # Maps field name -> index of its time-like axis,
        # or None if the field has no variable-length (time) axis
        self._time = dict()

        for operator in ops:
            self.add(operator)

    def add(self, operator):
        '''Add an operator to the Slicer

        Parameters
        ----------
        operator : Scope (TaskTransformer or FeatureExtractor)
            The new operator to add

        Raises
        ------
        ParameterError
            if `operator` is not a Scope
        '''
        if not isinstance(operator, Scope):
            raise ParameterError('Operator {} must be a TaskTransformer '
                                 'or FeatureExtractor'.format(operator))
        for key in operator.fields:
            self._time[key] = None
            if None in operator.fields[key].shape:
                # A None in the declared shape marks the time axis;
                # add one to account for the leading batch dimension
                self._time[key] = 1 + operator.fields[key].shape.index(None)

    def data_duration(self, data):
        '''Compute the valid data duration of a dict

        Parameters
        ----------
        data : dict
            As produced by pumpp.transform

        Returns
        -------
        length : int
            The minimum temporal extent of a dynamic observation in data
        '''
        # Collect the lengths along each time-like axis
        # NOTE(review): raises ValueError (min of empty list) if no
        # registered field has a time-like axis — presumably callers
        # always have at least one dynamic field; confirm upstream
        lengths = []
        for key in self._time:
            if self._time[key] is not None:
                lengths.append(data[key].shape[self._time[key]])

        return min(lengths)

    def crop(self, data):
        '''Crop a data dictionary down to its common time

        Parameters
        ----------
        data : dict
            As produced by pumpp.transform

        Returns
        -------
        data_cropped : dict
            Like `data` but with all time-like axes truncated to the
            minimum common duration
        '''

        duration = self.data_duration(data)
        data_out = dict()
        for key in data:
            idx = [slice(None)] * data[key].ndim
            if key in self._time and self._time[key] is not None:
                idx[self._time[key]] = slice(duration)
            # Index with a tuple: indexing an ndarray with a *list* of
            # slices is deprecated (NumPy >= 1.15) and an error in >= 1.24
            data_out[key] = data[key][tuple(idx)]

        return data_out
58 changes: 33 additions & 25 deletions pumpp/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,14 @@
import librosa
import jams

from .base import Slicer
from .exceptions import ParameterError
from .task import BaseTaskTransformer
from .feature import FeatureExtractor
from .sampler import Sampler


class Pump(object):
class Pump(Slicer):
'''Top-level pump object.
This class is used to collect feature and task transformers
Expand Down Expand Up @@ -65,35 +66,35 @@ def __init__(self, *ops):

self.ops = []
self.opmap = dict()
for op in ops:
self.add(op)
super(Pump, self).__init__(*ops)

def add(self, op):
def add(self, operator):
'''Add an operation to this pump.
Parameters
----------
op : BaseTaskTransformer, FeatureExtractor
operator : BaseTaskTransformer, FeatureExtractor
The operation to add
Raises
------
ParameterError
if `op` is not of a correct type
'''
if not isinstance(op, (BaseTaskTransformer, FeatureExtractor)):
raise ParameterError('op={} must be one of '
if not isinstance(operator, (BaseTaskTransformer, FeatureExtractor)):
raise ParameterError('operator={} must be one of '
'(BaseTaskTransformer, FeatureExtractor)'
.format(op))
.format(operator))

if op.name in self.opmap:
if operator.name in self.opmap:
raise ParameterError('Duplicate operator name detected: '
'{}'.format(op))
'{}'.format(operator))

self.opmap[op.name] = op
self.ops.append(op)
super(Pump, self).add(operator)
self.opmap[operator.name] = operator
self.ops.append(operator)

def transform(self, audio_f=None, jam=None, y=None, sr=None):
def transform(self, audio_f=None, jam=None, y=None, sr=None, crop=False):
'''Apply the transformations to an audio file, and optionally JAMS object.
Parameters
Expand All @@ -111,6 +112,10 @@ def transform(self, audio_f=None, jam=None, y=None, sr=None):
If provided, operate directly on an existing audio buffer `y` at
sampling rate `sr` rather than load from `audio_f`.
crop : bool
If `True`, then data are cropped to a common time index across all
fields. Otherwise, data may have different time extents.
Returns
-------
data : dict
Expand Down Expand Up @@ -145,11 +150,13 @@ def transform(self, audio_f=None, jam=None, y=None, sr=None):

data = dict()

for op in self.ops:
if isinstance(op, BaseTaskTransformer):
data.update(op.transform(jam))
elif isinstance(op, FeatureExtractor):
data.update(op.transform(y, sr))
for operator in self.ops:
if isinstance(operator, BaseTaskTransformer):
data.update(operator.transform(jam))
elif isinstance(operator, FeatureExtractor):
data.update(operator.transform(y, sr))
if crop:
data = self.crop(data)
return data

def sampler(self, n_samples, duration, random_state=None):
Expand Down Expand Up @@ -189,9 +196,10 @@ def sampler(self, n_samples, duration, random_state=None):

@property
def fields(self):
'''A dictionary of fields constructed by this pump'''
out = dict()
for op in self.ops:
out.update(**op.fields)
for operator in self.ops:
out.update(**operator.fields)

return out

Expand All @@ -206,11 +214,11 @@ def layers(self):
fields.
'''

L = dict()
for op in self.ops:
if hasattr(op, 'layers'):
L.update(op.layers())
return L
layermap = dict()
for operator in self.ops:
if hasattr(operator, 'layers'):
layermap.update(operator.layers())
return layermap

def __getitem__(self, key):
return self.opmap.get(key)
Expand Down
40 changes: 6 additions & 34 deletions pumpp/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,21 @@
Sampler
SequentialSampler
Slicer
'''

from itertools import count

import six
import numpy as np

from .base import Slicer
from .exceptions import ParameterError

__all__ = ['Sampler', 'SequentialSampler']


class Sampler(object):
class Sampler(Slicer):
'''Generate samples uniformly at random from a pumpp data dict.
Attributes
Expand Down Expand Up @@ -66,6 +68,8 @@ class Sampler(object):
'''
def __init__(self, n_samples, duration, *ops, **kwargs):

super(Sampler, self).__init__(*ops)

self.n_samples = n_samples
self.duration = duration

Expand All @@ -80,17 +84,6 @@ def __init__(self, n_samples, duration, *ops, **kwargs):
else:
raise ParameterError('Invalid random_state={}'.format(random_state))

fields = dict()
for op in ops:
fields.update(op.fields)

# Pre-determine which fields have time-like indices
self._time = {key: None for key in fields}
for key in fields:
if None in fields[key].shape:
# Add one for the batching index
self._time[key] = 1 + fields[key].shape.index(None)

def sample(self, data, interval):
'''Sample a patch from the data object
Expand Down Expand Up @@ -126,27 +119,6 @@ def sample(self, data, interval):

return data_slice

def data_duration(self, data):
'''Compute the valid data duration of a dict
Parameters
----------
data : dict
As produced by pumpp.transform
Returns
-------
length : int
The minimum temporal extent of a dynamic observation in data
'''
# Find all the time-like indices of the data
lengths = []
for key in self._time:
if self._time[key] is not None:
lengths.append(data[key].shape[self._time[key]])

return min(lengths)

def indices(self, data):
'''Generate patch indices
Expand Down Expand Up @@ -185,7 +157,7 @@ def __call__(self, data):
else:
counter = count(0)

for i, start in six.moves.zip(counter, self.indices(data)):
for _, start in six.moves.zip(counter, self.indices(data)):
yield self.sample(data, slice(start, start + self.duration))


Expand Down
15 changes: 10 additions & 5 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ def jam(request):
@pytest.mark.parametrize('audio_f', [None, 'tests/data/test.ogg'])
@pytest.mark.parametrize('y', [None, 'tests/data/test.ogg'])
@pytest.mark.parametrize('sr2', [None, 22050, 44100])
def test_pump(audio_f, jam, y, sr, sr2, hop_length):
@pytest.mark.parametrize('crop', [False, True])
def test_pump(audio_f, jam, y, sr, sr2, hop_length, crop):

ops = [pumpp.feature.STFT(name='stft', sr=sr,
hop_length=hop_length,
Expand Down Expand Up @@ -77,8 +78,8 @@ def test_pump(audio_f, jam, y, sr, sr2, hop_length):

assert set(P.fields.keys()) == fields

data = P.transform(audio_f=audio_f, jam=jam, y=y, sr=sr2)
data2 = P(audio_f=audio_f, jam=jam, y=y, sr=sr2)
data = P.transform(audio_f=audio_f, jam=jam, y=y, sr=sr2, crop=crop)
data2 = P(audio_f=audio_f, jam=jam, y=y, sr=sr2, crop=crop)

# Fields we should have:
assert set(data.keys()) == fields | valids
Expand All @@ -90,8 +91,12 @@ def test_pump(audio_f, jam, y, sr, sr2, hop_length):
assert data['beat/beat'].shape[1] == data['chord/bass'].shape[1]

# Audio features can be off by at most a frame
assert (np.abs(data['stft/mag'].shape[1] - data['beat/beat'].shape[1])
* hop_length / float(sr)) <= 0.05
if crop:
assert data['stft/mag'].shape[1] == data['beat/beat'].shape[1]
assert data['stft/mag'].shape[1] == data['chord/pitch'].shape[1]
else:
assert (np.abs(data['stft/mag'].shape[1] - data['beat/beat'].shape[1])
* hop_length / float(sr)) <= 0.05

assert data.keys() == data2.keys()
for k in data:
Expand Down
43 changes: 43 additions & 0 deletions tests/test_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,3 +133,46 @@ def test_sequential_sampler(data, ops, duration, stride, rng):
# Check that all keys have length=1
assert datum[key].shape[0] == 1
assert list(datum[key].shape[1:]) == ref_shape[1:]


def test_slicer():
    '''Slicer.crop should truncate every time-like axis to the minimum
    common duration, and leave static fields untouched.'''
    scope1 = pumpp.base.Scope('test1')
    # Use the builtin int: `np.int` was a deprecated alias for the
    # builtin and was removed entirely in NumPy 1.24
    scope1.register('first', (None, 10), int)
    scope1.register('second', (2, None), int)
    scope1.register('none', (16, 16), int)

    scope2 = pumpp.base.Scope('test2')
    scope2.register('first', (None, 5), int)
    scope2.register('second', (20, None), int)

    slicer = pumpp.base.Slicer(scope1, scope2)

    # Minimum time extent across all dynamic fields below is 8
    data_in = {'test1/first': np.random.randint(0, 7, size=(1, 8, 10)),
               'test1/second': np.random.randint(0, 7, size=(1, 2, 100)),
               'test1/none': np.random.randint(0, 7, size=(1, 16, 16)),
               'test2/first': np.random.randint(0, 7, size=(1, 9, 5)),
               'test2/second': np.random.randint(0, 7, (1, 20, 105))}

    data_out = slicer.crop(data_in)
    assert set(data_out.keys()) == set(data_in.keys())

    assert data_out['test1/first'].shape == (1, 8, 10)
    assert np.all(data_out['test1/first'] == data_in['test1/first'][:, :8, :])

    assert data_out['test1/second'].shape == (1, 2, 8)
    assert np.all(data_out['test1/second'] == data_in['test1/second'][:, :, :8])

    # Static (fully-specified) fields must pass through unchanged
    assert data_out['test1/none'].shape == (1, 16, 16)
    assert np.all(data_out['test1/none'] == data_in['test1/none'])

    assert data_out['test2/first'].shape == (1, 8, 5)
    assert np.all(data_out['test2/first'] == data_in['test2/first'][:, :8, :])

    assert data_out['test2/second'].shape == (1, 20, 8)
    assert np.all(data_out['test2/second'] == data_in['test2/second'][:, :, :8])


# Non-Scope inputs must be rejected: Slicer.add raises ParameterError
@pytest.mark.xfail(raises=pumpp.ParameterError)
def test_slicer_fail():
    pumpp.base.Slicer('not a scope')

0 comments on commit 2ca0956

Please sign in to comment.