Skip to content

Commit

Permalink
Merge pull request #65 from bmcfee/transform-api
Browse files Browse the repository at this point in the history
Transform api
  • Loading branch information
bmcfee committed Apr 26, 2017
2 parents 797a732 + 69f8bfc commit f29b0d5
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 111 deletions.
108 changes: 57 additions & 51 deletions pumpp/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,49 +19,6 @@
from .sampler import Sampler


def transform(audio_f, jam, *ops):
    '''Run a collection of operators over one track and merge their outputs.

    Parameters
    ----------
    audio_f : str
        Path to the audio file; it is loaded with `librosa.load`.

    jam : str, jams.JAMS, file-like, or None
        Annotation data for the track.  A `jams.JAMS` instance is used
        as-is; any other non-None value is passed to `jams.load`.
        If `None`, an empty JAMS object is created and its
        `file_metadata.duration` is set from the loaded audio.

    ops : list of task.BaseTaskTransformer or feature.FeatureExtractor
        The operators to apply.  Task transformers receive the JAMS
        object, feature extractors receive the audio buffer and its
        sampling rate.  Objects of any other type are skipped.

    Returns
    -------
    data : dict
        Union of the dictionaries returned by each operator, merged in
        the order the operators were given (later entries replace
        earlier ones that share a key).
    '''

    # sr=None / mono=True are forwarded verbatim to librosa
    signal, rate = librosa.load(audio_f, sr=None, mono=True)

    if jam is None:
        # No annotations supplied: start from an empty JAMS that at least
        # knows how long the track is
        jam = jams.JAMS()
        jam.file_metadata.duration = librosa.get_duration(y=signal, sr=rate)
    elif not isinstance(jam, jams.JAMS):
        # Anything that is not already a JAMS object is treated as
        # something `jams.load` can read
        jam = jams.load(jam)

    results = {}

    for operator in ops:
        if isinstance(operator, BaseTaskTransformer):
            encoded = operator.transform(jam)
        elif isinstance(operator, FeatureExtractor):
            encoded = operator.transform(signal, rate)
        else:
            # Unrecognized operator types contribute nothing
            continue
        results.update(encoded)

    return results


class Pump(object):
'''Top-level pump object.
Expand All @@ -79,7 +36,18 @@ class Pump(object):
>>> p_cqt = pumpp.feature.CQT('cqt', sr=44100, hop_length=1024)
>>> p_chord = pumpp.task.ChordTagTransformer(sr=44100, hop_length=1024)
>>> pump = pumpp.Pump(p_cqt, p_chord)
>>> data = pump.transform('/my/audio/file.mp3', '/my/jams/annotation.jams')
>>> data = pump.transform(audio_f='/my/audio/file.mp3',
... jam='/my/jams/annotation.jams')
Or use the call interface:
>>> data = pump(audio_f='/my/audio/file.mp3',
... jam='/my/jams/annotation.jams')
Or apply to audio in memory, and without existing annotations:
>>> y, sr = librosa.load('/my/audio/file.mp3')
>>> data = pump(y=y, sr=sr)
Access all the fields produced by this pump:
Expand All @@ -92,10 +60,6 @@ class Pump(object):
>>> pump['chord'].fields
{'chord/chord': Tensor(shape=(None, 170), dtype=<class 'bool'>)}
See Also
--------
transform
'''

def __init__(self, *ops):
Expand Down Expand Up @@ -124,12 +88,13 @@ def add(self, op):
.format(op))

if op.name in self.opmap:
raise ParameterError('Duplicate operator name detected: {}'.format(op))
raise ParameterError('Duplicate operator name detected: '
'{}'.format(op))

self.opmap[op.name] = op
self.ops.append(op)

def transform(self, audio_f, jam=None):
def transform(self, audio_f=None, jam=None, y=None, sr=None):
'''Apply the transformations to an audio file, and optionally JAMS object.
Parameters
Expand All @@ -142,13 +107,51 @@ def transform(self, audio_f, jam=None):
If provided, this will provide data for task transformers.
y : np.ndarray
sr : number > 0
If provided, operate directly on an existing audio buffer `y` at
sampling rate `sr` rather than load from `audio_f`.
Returns
-------
data : dict
Data dictionary containing the transformed audio (and annotations)
Raises
------
ParameterError
At least one of `audio_f` or `(y, sr)` must be provided.
'''

return transform(audio_f, jam, *self.ops)
if y is None:
if audio_f is None:
raise ParameterError('At least one of `y` or `audio_f` '
'must be provided')

# Load the audio
y, sr = librosa.load(audio_f, sr=sr, mono=True)

if sr is None:
raise ParameterError('If audio is provided as `y`, you must '
'specify the sampling rate as sr=')

if jam is None:
jam = jams.JAMS()
jam.file_metadata.duration = librosa.get_duration(y=y, sr=sr)

# Load the jams
if not isinstance(jam, jams.JAMS):
jam = jams.load(jam)

data = dict()

for op in self.ops:
if isinstance(op, BaseTaskTransformer):
data.update(op.transform(jam))
elif isinstance(op, FeatureExtractor):
data.update(op.transform(y, sr))
return data

def sampler(self, n_samples, duration, random_state=None):
'''Construct a sampler object for this pump's operators.
Expand Down Expand Up @@ -212,3 +215,6 @@ def layers(self):

def __getitem__(self, key):
return self.opmap.get(key)

def __call__(self, *args, **kwargs):
return self.transform(*args, **kwargs)
106 changes: 47 additions & 59 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@
import pytest
import numpy as np

import pumpp
import librosa
import jams

import pumpp


@pytest.fixture(params=[11025, 22050])
def sr(request):
Expand All @@ -26,48 +28,10 @@ def jam(request):
return request.param


@pytest.mark.parametrize('audio_f', ['tests/data/test.ogg'])
def test_transform(audio_f, jam, sr, hop_length):

ops = [pumpp.feature.STFT(name='stft', sr=sr,
hop_length=hop_length,
n_fft=2*hop_length),

pumpp.task.BeatTransformer(name='beat', sr=sr,
hop_length=hop_length),

pumpp.task.ChordTransformer(name='chord', sr=sr,
hop_length=hop_length),

pumpp.task.StaticLabelTransformer(name='tags',
namespace='tag_open',
labels=['rock', 'jazz'])]

data = pumpp.transform(audio_f, jam, *ops)

# Fields we should have:
assert set(data.keys()) == set(['stft/mag', 'stft/phase',
'beat/beat', 'beat/downbeat',
'beat/_valid',
'beat/mask_downbeat',
'chord/pitch', 'chord/root', 'chord/bass',
'chord/_valid',
'tags/tags', 'tags/_valid'])

# time shapes should be the same for annotations
assert data['beat/beat'].shape[1] == data['beat/downbeat'].shape[1]
assert data['beat/beat'].shape[1] == data['chord/pitch'].shape[1]
assert data['beat/beat'].shape[1] == data['chord/root'].shape[1]
assert data['beat/beat'].shape[1] == data['chord/bass'].shape[1]

# Audio features can be off by at most a frame
assert (np.abs(data['stft/mag'].shape[1] - data['beat/beat'].shape[1])
* hop_length / float(sr)) <= 0.05
pass


@pytest.mark.parametrize('audio_f', ['tests/data/test.ogg'])
def test_pump(audio_f, jam, sr, hop_length):
@pytest.mark.parametrize('audio_f', [None, 'tests/data/test.ogg'])
@pytest.mark.parametrize('y', [None, 'tests/data/test.ogg'])
@pytest.mark.parametrize('sr2', [None, 22050, 44100])
def test_pump(audio_f, jam, y, sr, sr2, hop_length):

ops = [pumpp.feature.STFT(name='stft', sr=sr,
hop_length=hop_length,
Expand All @@ -83,22 +47,46 @@ def test_pump(audio_f, jam, sr, hop_length):
namespace='tag_open',
labels=['rock', 'jazz'])]

data1 = pumpp.transform(audio_f, jam, *ops)

pump = pumpp.Pump(*ops)
data2 = pump.transform(audio_f, jam)

assert data1.keys() == data2.keys()

for key in data1:
assert np.allclose(data1[key], data2[key])

fields = dict()
for op in ops:
fields.update(**op.fields)
assert pump[op.name] == op

assert pump.fields == fields
P = pumpp.Pump(*ops)
if audio_f is None and y is None:
# no input
with pytest.raises(pumpp.ParameterError):
data = P.transform(audio_f=audio_f, jam=jam, y=y, sr=sr2)
elif y is not None and sr2 is None:
# input buffer, but no sampling rate
y = librosa.load(y, sr=sr2)[0]
with pytest.raises(pumpp.ParameterError):
data = P.transform(audio_f=audio_f, jam=jam, y=y, sr=sr2)
elif y is not None:
y = librosa.load(y, sr=sr2)[0]
data = P.transform(audio_f=audio_f, jam=jam, y=y, sr=sr2)
else:
data = P.transform(audio_f=audio_f, jam=jam, y=y, sr=sr2)
data2 = P(audio_f=audio_f, jam=jam, y=y, sr=sr2)

# Fields we should have:
assert set(data.keys()) == set(['stft/mag', 'stft/phase',
'beat/beat', 'beat/downbeat',
'beat/_valid',
'beat/mask_downbeat',
'chord/pitch', 'chord/root',
'chord/bass',
'chord/_valid',
'tags/tags', 'tags/_valid'])

# time shapes should be the same for annotations
assert data['beat/beat'].shape[1] == data['beat/downbeat'].shape[1]
assert data['beat/beat'].shape[1] == data['chord/pitch'].shape[1]
assert data['beat/beat'].shape[1] == data['chord/root'].shape[1]
assert data['beat/beat'].shape[1] == data['chord/bass'].shape[1]

# Audio features can be off by at most a frame
assert (np.abs(data['stft/mag'].shape[1] - data['beat/beat'].shape[1])
* hop_length / float(sr)) <= 0.05

assert data.keys() == data2.keys()
for k in data:
assert np.allclose(data[k], data2[k])


@pytest.mark.parametrize('audio_f', ['tests/data/test.ogg'])
Expand Down
3 changes: 2 additions & 1 deletion tests/test_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ def data(ops):
audio_f = 'tests/data/test.ogg'
jams_f = 'tests/data/test.jams'

return pumpp.transform(audio_f, jams_f, *ops)
P = pumpp.Pump(*ops)
return P.transform(audio_f=audio_f, jam=jams_f)


@pytest.fixture(params=[4, 16, None], scope='module')
Expand Down

0 comments on commit f29b0d5

Please sign in to comment.