Skip to content

Commit

Permalink
added effects.split
Browse files Browse the repository at this point in the history
added tests for effects.split

cleaned up effects.py __all__ and docstring

refactored trim and split

cleanup in effects

renamed n_fft to frame_length in trim/split

strengthened trim test

cleaned up trim test

cleaned up trim test more

strengthened split test

converted trim index output from slice to ndarray

updated rmse parameter in effects._signal_to_frame_nonsilent

updated docstring for trim
  • Loading branch information
bmcfee committed Nov 2, 2016
1 parent ce61ee0 commit c8280cd
Show file tree
Hide file tree
Showing 2 changed files with 184 additions and 52 deletions.
148 changes: 120 additions & 28 deletions librosa/effects.py
Expand Up @@ -27,6 +27,8 @@
:toctree: generated/
remix
trim
split
"""

import numpy as np
Expand All @@ -39,7 +41,7 @@

__all__ = ['hpss', 'harmonic', 'percussive',
'time_stretch', 'pitch_shift',
'remix', 'trim']
'remix', 'trim', 'split']


def hpss(y, **kwargs):
Expand Down Expand Up @@ -377,8 +379,49 @@ def remix(y, intervals, align_zeros=True):
return np.concatenate(y_out, axis=-1)


def trim(y, top_db=60, ref_power=np.max, n_fft=2048, hop_length=512,
index=False):
def _signal_to_frame_nonsilent(y, frame_length=2048, hop_length=512, top_db=60,
ref_power=np.max):
'''Frame-wise non-silent indicator for audio input.
This is a helper function for `trim` and `split`.
Parameters
----------
y : np.ndarray, shape=(n,) or (2,n)
Audio signal, mono or stereo
frame_length : int > 0
The number of samples per frame
hop_length : int > 0
The number of samples between frames
top_db : number > 0
The threshold (in decibels) below reference to consider as
silence
ref_power : callable or float
The reference power
Returns
-------
non_silent : np.ndarray, shape=(m,), dtype=bool
Indicator of non-silent frames
'''
# Convert to mono
y_mono = core.to_mono(y)

# Compute the MSE for the signal
mse = feature.rmse(y=y_mono,
frame_length=frame_length,
hop_length=hop_length)**2

return (core.logamplitude(mse.squeeze(),
ref_power=ref_power,
top_db=None) > - top_db)


def trim(y, top_db=60, ref_power=np.max, frame_length=2048, hop_length=512):
'''Trim leading and trailing silence from an audio signal.
Parameters
Expand All @@ -394,60 +437,109 @@ def trim(y, top_db=60, ref_power=np.max, n_fft=2048, hop_length=512,
The reference power. By default, it uses `np.max` and compares
to the peak power in the signal.
n_fft : int > 0
frame_length : int > 0
The number of samples per analysis frame
hop_length : int > 0
The number of samples between analysis frames
index : bool
If `True`, return the start and end of the non-silent
region of `y` along with the trimmed signal.
If `False`, only return the trimmed signal.
Returns
-------
y_trimmed : np.ndarray, shape=(m,) or (2, m)
The trimmed signal
index : slice, optional
If `index=True` is provided, then this contains
the slice of `y` corresponding to the non-silent region:
`y_trimmed = y[index]`.
index : np.ndarray, shape=(2,)
the interval of `y` corresponding to the non-silent region:
`y_trimmed = y[index[0]:index[1]]` (for mono) or
`y_trimmed = y[:, index[0]:index[1]]` (for stereo).
Examples
--------
>>> # Load some audio
>>> y, sr = librosa.load(librosa.util.example_audio_file())
>>> # Trim the beginning and ending silence
>>> yt = librosa.effects.trim(y)
>>> yt, index = librosa.effects.trim(y)
>>> # Print the durations
>>> print(librosa.get_duration(y), librosa.get_duration(yt))
61.45886621315193 60.58086167800454
'''

# Convert to mono
y_mono = core.to_mono(y)

# Compute the MSE for the signal
mse = feature.rmse(y=y_mono, n_fft=n_fft, hop_length=hop_length)**2
non_silent = _signal_to_frame_nonsilent(y,
frame_length=frame_length,
hop_length=hop_length,
ref_power=ref_power,
top_db=top_db)

# Compute the log power indicator and non-zero positions
logp = core.logamplitude(mse, ref_power=ref_power, top_db=None) > - top_db
nonzero = np.flatnonzero(logp)
nonzero = np.flatnonzero(non_silent)

# Compute the start and end positions
# End position goes one frame past the last non-zero
start = int(core.frames_to_samples(nonzero[0], hop_length))
end = min(len(y_mono),
end = min(y.shape[-1],
int(core.frames_to_samples(nonzero[-1] + 1, hop_length)))

# Build the mono/stereo index
full_index = [slice(None)] * y.ndim
full_index[-1] = slice(start, end)

if index:
return y[full_index], full_index[-1]
else:
return y[full_index]
return y[full_index], np.asarray([start, end])


def split(y, top_db=60, ref_power=np.max, frame_length=2048, hop_length=512):
'''Split an audio signal into non-silent intervals.
Parameters
----------
y : np.ndarray, shape=(n,) or (2, n)
An audio signal
top_db : number > 0
The threshold (in decibels) below reference to consider as
silence
ref_power : number or callable
The reference power. By default, it uses `np.max` and compares
to the peak power in the signal.
frame_length : int > 0
The number of samples per analysis frame
hop_length : int > 0
The number of samples between analysis frames
Returns
-------
intervals : np.ndarray, shape=(m, 2)
`intervals[i] == (start_i, end_i)` are the start and end time
(in samples) if the `i`th non-silent interval.
'''

non_silent = _signal_to_frame_nonsilent(y,
frame_length=frame_length,
hop_length=hop_length,
ref_power=ref_power,
top_db=top_db)

# Interval slicing, adapted from
# https://stackoverflow.com/questions/2619413/efficiently-finding-the-interval-with-non-zeros-in-scipy-numpy-in-python
# Find points where the sign flips
edges = np.flatnonzero(np.diff(non_silent.astype(int)))

# Pad back the sample lost in the diff
edges = [edges + 1]

# If the first frame had high energy, count it
if non_silent[0]:
edges.insert(0, [0])

# Likewise for the last frame
if non_silent[-1]:
edges.append([len(non_silent)])

# Convert from frames to samples
edges = core.frames_to_samples(np.concatenate(edges),
hop_length=hop_length)

# Stack the results back as an ndarray
return edges.reshape((-1, 2))
88 changes: 64 additions & 24 deletions tests/test_effects.py
Expand Up @@ -143,42 +143,82 @@ def test_harmonic():

def test_trim():

def __test(y, top_db, ref_power, index):
def __test(y, top_db, ref_power, trim_duration):
yt, idx = librosa.effects.trim(y, top_db=top_db,
ref_power=ref_power)

if index:
yt, idx = librosa.effects.trim(y, top_db=top_db,
ref_power=ref_power,
index=True)

# Test for index position
fidx = [slice(None)] * y.ndim
fidx[-1] = idx
assert np.allclose(yt, y[fidx])

else:
yt = librosa.effects.trim(y, top_db=top_db, ref_power=ref_power,
index=False)
# Test for index position
fidx = [slice(None)] * y.ndim
fidx[-1] = slice(*idx.tolist())
assert np.allclose(yt, y[fidx])

# Verify logamp
rms = librosa.feature.rmse(librosa.to_mono(yt))
logamp = librosa.logamplitude(rms**2, ref_power=ref_power, top_db=None)
assert np.all(logamp > - top_db)

# Verify logamp
rms_all = librosa.feature.rmse(librosa.to_mono(y)).squeeze()
logamp_all = librosa.logamplitude(rms_all**2, ref_power=ref_power,
top_db=None)

assert np.all(logamp >= - top_db)
start = int(librosa.samples_to_frames(idx[0]))
stop = int(librosa.samples_to_frames(idx[1]))
assert np.all(logamp_all[:start] <= - top_db)
assert np.all(logamp_all[stop:] <= - top_db)

# Verify duration
duration = librosa.get_duration(yt)
assert np.allclose(duration, 3.0, atol=1e-1), duration
assert np.allclose(duration, trim_duration, atol=1e-1), duration

# construct 5 seconds of stereo silence
# Stick a sine wave in the middle three seconds
sr = float(22050)
y = np.zeros((2, int(5 * sr)))
y[0, sr:4*sr] = np.sin(2 * np.pi * 440 * np.arange(0, 3 * sr) / sr)
trim_duration = 3.0
y = np.sin(2 * np.pi * 440. * np.arange(0, trim_duration * sr) / sr)
y = librosa.util.pad_center(y, 5 * sr)
y = np.vstack([y, np.zeros_like(y)])

for top_db in [60, 40, 20]:
for index in [False, True]:
for ref_power in [1, np.max]:
# Test stereo
yield __test, y, top_db, ref_power, index
# Test mono
yield __test, y[0], top_db, ref_power, index
for ref_power in [1, np.max]:
# Test stereo
yield __test, y, top_db, ref_power, trim_duration
# Test mono
yield __test, y[0], top_db, ref_power, trim_duration


def test_split():

def __test(hop_length, frame_length, top_db):

intervals = librosa.effects.split(y,
top_db=top_db,
frame_length=frame_length,
hop_length=hop_length)

int_match = librosa.util.match_intervals(intervals, idx_true)

for i in range(len(intervals)):
i_true = idx_true[int_match[i]]

assert np.all(np.abs(i_true - intervals[i]) <= frame_length), intervals[i]

# Make some high-frequency noise
sr = 8192

y = np.ones(10 * sr)
y[::2] *= -1

# Zero out all but two intervals
y[:sr] = 0
y[2 * sr:3 * sr] = 0
y[4 * sr:] = 0

# The true non-silent intervals
idx_true = np.asarray([[sr, 2 * sr],
[3 * sr, 4 * sr]])

for frame_length in [1024, 2048, 4096]:
for hop_length in [256, 512, 1024]:
for top_db in [20, 60, 80]:
yield __test, hop_length, frame_length, top_db

0 comments on commit c8280cd

Please sign in to comment.