Skip to content

Commit

Permalink
Merge pull request #25 from bmcfee/unlabeled-transform
Browse files Browse the repository at this point in the history
Better syntax for unlabeled transformation
  • Loading branch information
bmcfee committed Sep 15, 2016
2 parents 97a976b + 36635cc commit 40b5435
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 17 deletions.
19 changes: 14 additions & 5 deletions pumpp/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,20 @@
from .feature import FeatureExtractor


def transform(audio_f, jams_f, *ops):
def transform(audio_f, *ops, **kwargs):
'''Apply a set of operations to a track
Parameters
----------
audio_f : str
The path to the audio file
jams_f : str
The path to the jams file
jam : str, jams.JAMS, or file-like
A JAMS object, or path to a JAMS file.
ops : list of pumpp.task.BaseTaskTransform or pumpp.feature.FeatureExtractor
If not provided, an empty jams object will be created.
ops : list of task.BaseTaskTransform or feature.FeatureExtractor
The operators to apply to the input data
Returns
Expand All @@ -39,8 +41,15 @@ def transform(audio_f, jams_f, *ops):
# Load the audio
y, sr = librosa.load(audio_f, sr=None, mono=True)

jam = kwargs.pop('jam', None)

if jam is None:
jam = jams.JAMS()
jam.file_metadata.duration = librosa.get_duration(y=y, sr=sr)

# Load the jams
jam = jams.load(jams_f)
if not isinstance(jam, jams.JAMS):
jam = jams.load(jam)

data = dict()

Expand Down
36 changes: 24 additions & 12 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,37 +6,49 @@
import numpy as np

import pumpp
import jams


@pytest.fixture(params=[11025, 22050])
def sr(request):
return request.param



@pytest.fixture(params=[128, 512])
def hop_length(request):
return request.param


@pytest.mark.parametrize('audio_f, jams_f',
[('tests/data/test.ogg',
'tests/data/test.jams')])
def test_transform(audio_f, jams_f, sr, hop_length):
@pytest.fixture(params=[None,
'tests/data/test.jams',
jams.load('tests/data/test.jams')])
def jam(request):
return request.param


@pytest.mark.parametrize('audio_f', ['tests/data/test.ogg'])
def test_transform(audio_f, jam, sr, hop_length):

ops = [pumpp.feature.STFT(name='stft', sr=sr,
hop_length=hop_length,
n_fft=2*hop_length),
pumpp.task.BeatTransformer(name='beat', sr=sr, hop_length=hop_length),
pumpp.task.ChordTransformer(name='chord', sr=sr, hop_length=hop_length),
pumpp.task.StaticLabelTransformer(name='tags', namespace='tag_open',

pumpp.task.BeatTransformer(name='beat', sr=sr,
hop_length=hop_length),

pumpp.task.ChordTransformer(name='chord', sr=sr,
hop_length=hop_length),

pumpp.task.StaticLabelTransformer(name='tags',
namespace='tag_open',
labels=['rock', 'jazz'])]

data = pumpp.transform(audio_f, jams_f, *ops)
data = pumpp.transform(audio_f, jam=jam, *ops)

# Fields we should have:
assert set(data.keys()) == set(['stft/mag', 'stft/phase',
'beat/beat', 'beat/downbeat', 'beat/_valid',
'beat/beat', 'beat/downbeat',
'beat/_valid',
'beat/mask_downbeat',
'chord/pitch', 'chord/root', 'chord/bass',
'chord/_valid',
Expand All @@ -48,7 +60,7 @@ def test_transform(audio_f, jams_f, sr, hop_length):
assert data['beat/beat'].shape[1] == data['chord/root'].shape[1]
assert data['beat/beat'].shape[1] == data['chord/bass'].shape[1]

# Audio features can be off by
assert (np.abs(data['stft/mag'].shape[1] - data['beat/beat'].shape[1])
# Audio features can be off by at most a frame
assert (np.abs(data['stft/mag'].shape[1] - data['beat/beat'].shape[1])
* hop_length / float(sr)) <= 0.05
pass

0 comments on commit 40b5435

Please sign in to comment.