Skip to content

Commit

Permalink
Merge pull request #43 from bmcfee/keras2
Browse files Browse the repository at this point in the history
added keras2 terminology support for dimension ordering
  • Loading branch information
bmcfee committed Mar 15, 2017
2 parents e61dfdc + ae9b86d commit ea31faa
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 14 deletions.
24 changes: 13 additions & 11 deletions pumpp/feature/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,20 +23,22 @@ class FeatureExtractor(Scope):
hop_length : int > 0
The hop length between analysis frames
conv : {'tf', 'th', None}
conv : {'tf', 'th', 'channels_last', 'channels_first', None}
convolution dimension ordering:
- 'tf' for tensorflow-style 2D convolution
- 'th' for theano-style 2D convolution
- 'channels_last' for tensorflow-style 2D convolution
- 'tf' equivalent to 'channels_last'
- 'channels_first' for theano-style 2D convolution
- 'th' equivalent to 'channels_first'
- None for 1D or non-convolutional representations
'''
def __init__(self, name, sr, hop_length, conv=None):

super(FeatureExtractor, self).__init__(name)

if conv not in ('tf', 'th', None):
if conv not in ('tf', 'th', 'channels_last', 'channels_first', None):
raise ParameterError('conv="{}", must be one of '
'("tf", "th", None)'.format(conv))
'("channels_last", "tf", "channels_first", "th", None)'.format(conv))

self.sr = sr
self.hop_length = hop_length
Expand All @@ -46,10 +48,10 @@ def register(self, key, dimension, dtype):

shape = [None, dimension]

if self.conv == 'tf':
if self.conv in ('channels_last', 'tf'):
shape.append(1)

elif self.conv == 'th':
elif self.conv in ('channels_first', 'th'):
shape.insert(0, 1)

super(FeatureExtractor, self).register(key, shape, dtype)
Expand All @@ -59,10 +61,10 @@ def idx(self):
if self.conv is None:
return Ellipsis

elif self.conv == 'tf':
elif self.conv in ('channels_last', 'tf'):
return (slice(None), slice(None), np.newaxis)

elif self.conv == 'th':
elif self.conv in ('channels_first', 'th'):
return (np.newaxis, slice(None), slice(None))

def transform(self, y, sr):
Expand Down Expand Up @@ -109,9 +111,9 @@ def phase_diff(self, phase):

if self.conv is None:
axis = 0
elif self.conv == 'tf':
elif self.conv in ('channels_last', 'tf'):
axis = 0
elif self.conv == 'th':
elif self.conv in ('channels_first', 'th'):
axis = 1

# Compute the phase differential
Expand Down
6 changes: 3 additions & 3 deletions tests/test_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def n_octaves(request):
return request.param


@pytest.fixture(params=[None, 'tf', 'th',
@pytest.fixture(params=[None, 'tf', 'th', 'channels_last', 'channels_first',
pytest.mark.xfail('bad mode',
raises=pumpp.ParameterError)])
def conv(request):
Expand All @@ -78,9 +78,9 @@ def __check_shape(fields, key, dim, conv):

if conv is None:
assert fields[key].shape == (None, dim)
elif conv == 'tf':
elif conv in ('channels_last', 'tf'):
assert fields[key].shape == (None, dim, 1)
elif conv == 'th':
elif conv in ('channels_first', 'th'):
assert fields[key].shape == (1, None, dim)


Expand Down

0 comments on commit ea31faa

Please sign in to comment.