Skip to content

Commit

Permalink
Merge pull request #57 from bmcfee/sampler-interface
Browse files Browse the repository at this point in the history
fixed #56, Pump.sampler binding
  • Loading branch information
bmcfee committed Apr 3, 2017
2 parents 7486f99 + 9cc0be6 commit 49ef34e
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 6 deletions.
16 changes: 14 additions & 2 deletions pumpp/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def transform(self, audio_f, jam=None):

return transform(audio_f, jam, *self.ops)

def sampler(self, n_samples, duration):
def sampler(self, n_samples, duration, random_state=None):
'''Construct a sampler object for this pump's operators.
Parameters
Expand All @@ -161,6 +161,16 @@ def sampler(self, n_samples, duration):
duration : int > 0
The duration (in frames) of each sample patch
random_state : None, int, or np.random.RandomState
If int, random_state is the seed used by the random number
generator;
If RandomState instance, random_state is the random number
generator;
If None, the random number generator is the RandomState instance
used by np.random.
Returns
-------
sampler : pumpp.Sampler
Expand All @@ -171,7 +181,9 @@ def sampler(self, n_samples, duration):
pumpp.sampler.Sampler
'''

return Sampler(n_samples, duration, *self.ops)
return Sampler(n_samples, duration,
random_state=random_state,
*self.ops)

@property
def fields(self):
Expand Down
2 changes: 1 addition & 1 deletion pumpp/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@
"""Version info"""

short_version = '0.1'
version = '0.1.3'
version = '0.1.4pre'
7 changes: 4 additions & 3 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,8 @@ def test_pump_badkey(sr, hop_length):

@pytest.mark.parametrize('n_samples', [None, 10])
@pytest.mark.parametrize('duration', [1, 5])
def test_pump_sampler(sr, hop_length, n_samples, duration):
@pytest.mark.parametrize('rng', [None, 1])
def test_pump_sampler(sr, hop_length, n_samples, duration, rng):
ops = [pumpp.feature.STFT(name='stft', sr=sr,
hop_length=hop_length,
n_fft=2*hop_length),
Expand All @@ -176,8 +177,8 @@ def test_pump_sampler(sr, hop_length, n_samples, duration):

P = pumpp.Pump(*ops)

S1 = pumpp.Sampler(n_samples, duration, *ops)
S2 = P.sampler(n_samples, duration)
S1 = pumpp.Sampler(n_samples, duration, random_state=rng, *ops)
S2 = P.sampler(n_samples, duration, random_state=rng)

assert S1._time == S2._time
assert S1.n_samples == S2.n_samples
Expand Down

0 comments on commit 49ef34e

Please sign in to comment.