Skip to content

Commit

Permalink
Merge pull request #42 from bmcfee/logamp
Browse files Browse the repository at this point in the history
Added decibel scaling to spectrogram representations
  • Loading branch information
bmcfee committed Mar 15, 2017
2 parents ca9b338 + b4338ee commit e61dfdc
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 12 deletions.
13 changes: 11 additions & 2 deletions pumpp/feature/cqt.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
'''CQT features'''

import numpy as np
from librosa import cqt, magphase, note_to_hz
from librosa import cqt, magphase, note_to_hz, amplitude_to_db

from .base import FeatureExtractor

Expand Down Expand Up @@ -31,9 +31,15 @@ class CQT(FeatureExtractor):
fmin : float > 0
The minimum frequency of the CQT
log : boolean
If `True`, scale the magnitude to decibels
Otherwise, use linear magnitude
'''
def __init__(self, name, sr, hop_length, n_octaves=8, over_sample=3,
fmin=None, conv=None):
fmin=None, log=False, conv=None):
super(CQT, self).__init__(name, sr, hop_length, conv=conv)

if fmin is None:
Expand All @@ -42,6 +48,7 @@ def __init__(self, name, sr, hop_length, n_octaves=8, over_sample=3,
self.n_octaves = n_octaves
self.over_sample = over_sample
self.fmin = fmin
self.log = log

n_bins = n_octaves * 12 * over_sample
self.register('mag', n_bins, np.float32)
Expand Down Expand Up @@ -71,6 +78,8 @@ def transform_audio(self, y):
n_bins=(self.n_octaves *
self.over_sample * 12),
bins_per_octave=(self.over_sample * 12)))
if self.log:
cqtm = amplitude_to_db(cqtm, ref=np.max)

return {'mag': cqtm.T.astype(np.float32)[self.idx],
'phase': np.angle(phase).T.astype(np.float32)[self.idx]}
Expand Down
11 changes: 10 additions & 1 deletion pumpp/feature/fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,21 @@ class STFT(FeatureExtractor):
n_fft : int > 0
The number of FFT bins per frame
log : bool
If `True`, scale magnitude in decibels.
Otherwise use linear magnitude.
See Also
--------
STFTMag
STFTPhaseDiff
'''
def __init__(self, name, sr, hop_length, n_fft, conv=None):
def __init__(self, name, sr, hop_length, n_fft, log=False, conv=None):
super(STFT, self).__init__(name, sr, hop_length, conv=conv)

self.n_fft = n_fft
self.log = log

self.register('mag', 1 + n_fft // 2, np.float32)
self.register('phase', 1 + n_fft // 2, np.float32)
Expand All @@ -61,6 +67,9 @@ def transform_audio(self, y):
hop_length=self.hop_length,
n_fft=self.n_fft,
dtype=np.float32))
if self.log:
mag = librosa.amplitude_to_db(mag, ref=np.max)

return {'mag': mag.T[self.idx],
'phase': np.angle(phase.T)[self.idx]}

Expand Down
10 changes: 9 additions & 1 deletion pumpp/feature/mel.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,20 @@ class Mel(FeatureExtractor):
fmax : number > 0
The maximum frequency bin.
Defaults to `0.5 * sr`
log : bool
If `True`, scale magnitude in decibels.
Otherwise, use a linear amplitude scale.
'''
def __init__(self, name, sr, hop_length, n_fft, n_mels, fmax=None,
conv=None):
log=False, conv=None):
super(Mel, self).__init__(name, sr, hop_length, conv=conv)

self.n_fft = n_fft
self.n_mels = n_mels
self.fmax = fmax
self.log = log

self.register('mag', n_mels, np.float32)

Expand All @@ -62,5 +68,7 @@ def transform_audio(self, y):
hop_length=self.hop_length,
n_mels=self.n_mels,
fmax=self.fmax)).astype(np.float32)
if self.log:
mel = librosa.amplitude_to_db(mel, ref=np.max)

return {'mag': mel.T[self.idx]}
23 changes: 15 additions & 8 deletions tests/test_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,13 @@ def conv(request):
return request.param



@pytest.fixture(params=[False, True])
def log(request):
    '''Parametrized fixture supplying the `log` flag.

    Every test that requests this fixture is run once per entry in
    `params`: first with `False`, then with `True`.
    '''
    # NOTE(review): the test signatures visible in this diff accept `log`,
    # but none of the visible hunks forward it (`log=log`) to an extractor
    # constructor -- confirm the flag actually reaches the code under test.
    flag = request.param
    return flag



# STFT features
def __check_shape(fields, key, dim, conv):

Expand All @@ -77,7 +84,7 @@ def __check_shape(fields, key, dim, conv):
assert fields[key].shape == (1, None, dim)


def test_feature_stft_fields(SR, HOP_LENGTH, n_fft, conv):
def test_feature_stft_fields(SR, HOP_LENGTH, n_fft, conv, log):

ext = pumpp.feature.STFT(name='stft',
sr=SR, hop_length=HOP_LENGTH,
Expand Down Expand Up @@ -123,7 +130,7 @@ def test_feature_stft_phasediff_fields(SR, HOP_LENGTH, n_fft, conv):
assert ext.fields['stft/dphase'].dtype is np.float32


def test_feature_stft(audio, SR, HOP_LENGTH, n_fft, conv):
def test_feature_stft(audio, SR, HOP_LENGTH, n_fft, conv, log):

ext = pumpp.feature.STFT(name='stft',
sr=SR, hop_length=HOP_LENGTH,
Expand All @@ -139,7 +146,7 @@ def test_feature_stft(audio, SR, HOP_LENGTH, n_fft, conv):
assert type_match(output[key].dtype, ext.fields[key].dtype)


def test_feature_stft_phasediff(audio, SR, HOP_LENGTH, n_fft, conv):
def test_feature_stft_phasediff(audio, SR, HOP_LENGTH, n_fft, conv, log):

ext = pumpp.feature.STFTPhaseDiff(name='stft',
sr=SR, hop_length=HOP_LENGTH,
Expand All @@ -156,7 +163,7 @@ def test_feature_stft_phasediff(audio, SR, HOP_LENGTH, n_fft, conv):
assert type_match(output[key].dtype, ext.fields[key].dtype)


def test_feature_stft_mag(audio, SR, HOP_LENGTH, n_fft, conv):
def test_feature_stft_mag(audio, SR, HOP_LENGTH, n_fft, conv, log):

ext = pumpp.feature.STFTMag(name='stft',
sr=SR, hop_length=HOP_LENGTH,
Expand Down Expand Up @@ -188,7 +195,7 @@ def test_feature_mel_fields(SR, HOP_LENGTH, n_fft, n_mels, conv):
assert ext.fields['mel/mag'].dtype is np.float32


def test_feature_mel(audio, SR, HOP_LENGTH, n_fft, n_mels, conv):
def test_feature_mel(audio, SR, HOP_LENGTH, n_fft, n_mels, conv, log):

ext = pumpp.feature.Mel(name='mel',
sr=SR, hop_length=HOP_LENGTH,
Expand Down Expand Up @@ -256,7 +263,7 @@ def test_feature_cqtphasediff_fields(SR, HOP_LENGTH, over_sample, n_octaves, con
assert ext.fields['cqt/dphase'].dtype is np.float32


def test_feature_cqt(audio, SR, HOP_LENGTH, over_sample, n_octaves, conv):
def test_feature_cqt(audio, SR, HOP_LENGTH, over_sample, n_octaves, conv, log):

ext = pumpp.feature.CQT(name='cqt',
sr=SR, hop_length=HOP_LENGTH,
Expand All @@ -273,7 +280,7 @@ def test_feature_cqt(audio, SR, HOP_LENGTH, over_sample, n_octaves, conv):
assert type_match(output[key].dtype, ext.fields[key].dtype)


def test_feature_cqtmag(audio, SR, HOP_LENGTH, over_sample, n_octaves, conv):
def test_feature_cqtmag(audio, SR, HOP_LENGTH, over_sample, n_octaves, conv, log):

ext = pumpp.feature.CQTMag(name='cqt',
sr=SR, hop_length=HOP_LENGTH,
Expand All @@ -290,7 +297,7 @@ def test_feature_cqtmag(audio, SR, HOP_LENGTH, over_sample, n_octaves, conv):
assert type_match(output[key].dtype, ext.fields[key].dtype)


def test_feature_cqtphasediff(audio, SR, HOP_LENGTH, over_sample, n_octaves, conv):
def test_feature_cqtphasediff(audio, SR, HOP_LENGTH, over_sample, n_octaves, conv, log):

ext = pumpp.feature.CQTPhaseDiff(name='cqt',
sr=SR, hop_length=HOP_LENGTH,
Expand Down

0 comments on commit e61dfdc

Please sign in to comment.