Skip to content

Commit

Permalink
Merge pull request #14 from bmcfee/rhythm-features
Browse files Browse the repository at this point in the history
adding tempogram and temposcale features
  • Loading branch information
bmcfee committed Aug 31, 2016
2 parents b3432f5 + 273fc5a commit 2070d93
Show file tree
Hide file tree
Showing 5 changed files with 130 additions and 4 deletions.
4 changes: 3 additions & 1 deletion pumpp/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,9 @@ def register(self, field, shape, dtype):

self.fields[self.scope(field)] = Tensor(tuple(shape), dtype)

def pop(self, field):
return self.fields.pop(self.scope(field))

def merge(self, data):
'''Merge an array of output dictionaries into a single dictionary
with properly scoped names.
Expand All @@ -100,4 +103,3 @@ def merge(self, data):
data_out[self.scope(key)] = np.stack([np.asarray(d[key]) for d in data],
axis=0)
return data_out

3 changes: 3 additions & 0 deletions pumpp/feature/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,12 @@
STFTMag
STFTPhaseDiff
Mel
Tempogram
TempoScale
'''

from .base import *
from .cqt import *
from .fft import *
from .mel import *
from .rhythm import *
6 changes: 4 additions & 2 deletions pumpp/feature/fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,13 @@ def transform_audio(self, y):
dtype=np.float32))
return {'mag': mag.T, 'phase': np.angle(phase.T)}


class STFTPhaseDiff(STFT):

def __init__(self, *args, **kwargs):

super(STFTPhaseDiff, self).__init__(*args, **kwargs)
phase_field = self.fields.pop(self.scope('phase'))
phase_field = self.pop('phase')
self.register('dphase', phase_field.shape, phase_field.dtype)

def transform_audio(self, y):
Expand All @@ -41,12 +42,13 @@ def transform_audio(self, y):
data['dphase'] = phase_diff(data.pop('phase'), axis=0)
return data


class STFTMag(STFT):

def __init__(self, *args, **kwargs):

super(STFTMag, self).__init__(*args, **kwargs)
self.fields.pop(self.scope('phase'))
self.pop('phase')

def transform_audio(self, y):

Expand Down
45 changes: 45 additions & 0 deletions pumpp/feature/rhythm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
#!/usr/bin/env python

import numpy as np
import librosa

from .base import FeatureExtractor

__all__ = ['Tempogram', 'TempoScale']


class Tempogram(FeatureExtractor):

def __init__(self, name, sr, hop_length, win_length):

super(Tempogram, self).__init__(name, sr, hop_length)

self.win_length = win_length

self.register('tempogram', [None, win_length], np.float32)

def transform_audio(self, y):
tgram = librosa.feature.tempogram(y=y, sr=self.sr,
hop_length=self.hop_length,
win_length=self.win_length).astype(np.float32)

return {'tempogram': tgram.T}


class TempoScale(Tempogram):

def __init__(self, name, sr, hop_length, win_length, n_fmt=128):

super(TempoScale, self).__init__(name, sr, hop_length, win_length)

self.n_fmt = n_fmt
self.pop('tempogram')
self.register('temposcale', [None, 1 + n_fmt // 2], np.float32)

def transform_audio(self, y):

data = super(TempoScale, self).transform_audio(y)
data['temposcale'] = np.abs(librosa.fmt(data.pop('tempogram'),
axis=1,
n_fmt=self.n_fmt)).astype(np.float32)
return data
76 changes: 75 additions & 1 deletion tests/test_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def audio(request):
def n_fft(request):
return request.param


@pytest.fixture(params=[32, 128])
def n_mels(request):
return request.param
Expand All @@ -38,6 +39,16 @@ def HOP_LENGTH():
return 512


@pytest.fixture(params=[192, 384])
def WIN_LENGTH(request):
return request.param


@pytest.fixture(params=[16, 128])
def N_FMT(request):
return request.param


@pytest.fixture(params=[1, 3])
def over_sample(request):
return request.param
Expand All @@ -48,7 +59,6 @@ def n_octaves(request):
return request.param



# STFT features

def test_feature_stft_fields(SR, HOP_LENGTH, n_fft):
Expand Down Expand Up @@ -188,6 +198,7 @@ def test_feature_cqt_fields(SR, HOP_LENGTH, over_sample, n_octaves):
assert ext.fields['cqt/mag'].dtype is np.float32
assert ext.fields['cqt/phase'].dtype is np.float32


def test_feature_cqtmag_fields(SR, HOP_LENGTH, over_sample, n_octaves):

ext = pumpp.feature.CQTMag(name='cqt',
Expand All @@ -201,6 +212,7 @@ def test_feature_cqtmag_fields(SR, HOP_LENGTH, over_sample, n_octaves):
assert ext.fields['cqt/mag'].shape == (None, over_sample * n_octaves * 12)
assert ext.fields['cqt/mag'].dtype is np.float32


def test_feature_cqtphasediff_fields(SR, HOP_LENGTH, over_sample, n_octaves):

ext = pumpp.feature.CQTPhaseDiff(name='cqt',
Expand All @@ -216,6 +228,7 @@ def test_feature_cqtphasediff_fields(SR, HOP_LENGTH, over_sample, n_octaves):
assert ext.fields['cqt/mag'].dtype is np.float32
assert ext.fields['cqt/dphase'].dtype is np.float32


def test_feature_cqt(audio, SR, HOP_LENGTH, over_sample, n_octaves):

ext = pumpp.feature.CQT(name='cqt',
Expand All @@ -231,6 +244,7 @@ def test_feature_cqt(audio, SR, HOP_LENGTH, over_sample, n_octaves):
assert shape_match(output[key].shape[1:], ext.fields[key].shape)
assert type_match(output[key].dtype, ext.fields[key].dtype)


def test_feature_cqtmag(audio, SR, HOP_LENGTH, over_sample, n_octaves):

ext = pumpp.feature.CQTMag(name='cqt',
Expand All @@ -246,6 +260,7 @@ def test_feature_cqtmag(audio, SR, HOP_LENGTH, over_sample, n_octaves):
assert shape_match(output[key].shape[1:], ext.fields[key].shape)
assert type_match(output[key].dtype, ext.fields[key].dtype)


def test_feature_cqtphasediff(audio, SR, HOP_LENGTH, over_sample, n_octaves):

ext = pumpp.feature.CQTPhaseDiff(name='cqt',
Expand All @@ -261,3 +276,62 @@ def test_feature_cqtphasediff(audio, SR, HOP_LENGTH, over_sample, n_octaves):
assert shape_match(output[key].shape[1:], ext.fields[key].shape)
assert type_match(output[key].dtype, ext.fields[key].dtype)


# Rhythm features
def test_feature_tempogram_fields(SR, HOP_LENGTH, WIN_LENGTH):

ext = pumpp.feature.Tempogram(name='rhythm',
sr=SR, hop_length=HOP_LENGTH,
win_length=WIN_LENGTH)

# Check the fields
assert set(ext.fields.keys()) == set(['rhythm/tempogram'])

assert ext.fields['rhythm/tempogram'].shape == (None, WIN_LENGTH)
assert ext.fields['rhythm/tempogram'].dtype is np.float32


def test_feature_tempogram(audio, SR, HOP_LENGTH, WIN_LENGTH):

ext = pumpp.feature.Tempogram(name='rhythm',
sr=SR, hop_length=HOP_LENGTH,
win_length=WIN_LENGTH)


output = ext.transform(**audio)

assert set(output.keys()) == set(ext.fields.keys())

for key in ext.fields:
assert shape_match(output[key].shape[1:], ext.fields[key].shape)
assert type_match(output[key].dtype, ext.fields[key].dtype)


def test_feature_temposcale_fields(SR, HOP_LENGTH, WIN_LENGTH, N_FMT):

ext = pumpp.feature.TempoScale(name='rhythm',
sr=SR, hop_length=HOP_LENGTH,
win_length=WIN_LENGTH,
n_fmt=N_FMT)

# Check the fields
assert set(ext.fields.keys()) == set(['rhythm/temposcale'])

assert ext.fields['rhythm/temposcale'].shape == (None, 1 + N_FMT // 2)
assert ext.fields['rhythm/temposcale'].dtype is np.float32


def test_feature_temposcale(audio, SR, HOP_LENGTH, WIN_LENGTH, N_FMT):

ext = pumpp.feature.TempoScale(name='rhythm',
sr=SR, hop_length=HOP_LENGTH,
win_length=WIN_LENGTH,
n_fmt=N_FMT)

output = ext.transform(**audio)

assert set(output.keys()) == set(ext.fields.keys())

for key in ext.fields:
assert shape_match(output[key].shape[1:], ext.fields[key].shape)
assert type_match(output[key].dtype, ext.fields[key].dtype)

0 comments on commit 2070d93

Please sign in to comment.