Skip to content

Commit

Permalink
Merge pull request #49 from bmcfee/pump-layers
Browse files Browse the repository at this point in the history
added layers constructor to Pump class
  • Loading branch information
bmcfee committed Mar 20, 2017
2 parents 7879d33 + 349f571 commit e8a38bc
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 0 deletions.
17 changes: 17 additions & 0 deletions pumpp/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,3 +164,20 @@ def fields(self):
out.update(**op.fields)

return out

def layers(self):
'''Construct Keras input layers for all feature transformers
in the pump.
Returns
-------
layers : {field: keras.layers.Input}
A dictionary of keras input layers, keyed by the corresponding
fields.
'''

L = dict()
for op in self.ops:
if hasattr(op, 'layers'):
L.update(op.layers())
return L
27 changes: 27 additions & 0 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,3 +156,30 @@ def test_pump_sampler(sr, hop_length, n_samples, duration):
assert S1._time == S2._time
assert S1.n_samples == S2.n_samples
assert S1.duration == S2.duration


@pytest.mark.skip
def test_pump_layers(sr, hop_length):
ops = [pumpp.feature.STFT(name='stft', sr=sr,
hop_length=hop_length,
n_fft=2*hop_length),

pumpp.feature.CQT(name='cqt', sr=sr,
hop_length=hop_length),

pumpp.task.BeatTransformer(name='beat', sr=sr,
hop_length=hop_length)]

P = pumpp.Pump(*ops)

L1 = P.layers()
L2 = dict()
L2.update(ops[0].layers())
L2.update(ops[1].layers())

assert L1.keys() == L2.keys()

for k in L1:
assert L1[k].dtype == L2[k].dtype
for d1, d2 in zip(L1[k].shape, L2[k].shape):
assert str(d1) == str(d2)

0 comments on commit e8a38bc

Please sign in to comment.