-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
53fe18f
commit ac280c3
Showing
4 changed files
with
261 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,123 @@ | ||
"""Funtions for audio filtering.""" | ||
|
||
from typing import Optional | ||
|
||
import numpy as np | ||
import xarray as xr | ||
from scipy import signal | ||
|
||
__all__ = [ | ||
"filter", | ||
] | ||
|
||
|
||
def _get_filter( | ||
samplerate: int, | ||
low_freq: Optional[float] = None, | ||
high_freq: Optional[float] = None, | ||
order: int = 5, | ||
) -> np.ndarray: | ||
if low_freq is None and high_freq is None: | ||
raise ValueError( | ||
"At least one of low_freq and high_freq must be specified." | ||
) | ||
|
||
if low_freq is None: | ||
# Low pass filter | ||
return signal.butter( | ||
order, | ||
high_freq, | ||
btype="lowpass", | ||
output="sos", | ||
fs=samplerate, | ||
) | ||
|
||
if high_freq is None: | ||
# High pass filter | ||
return signal.butter( | ||
order, | ||
low_freq, | ||
btype="highpass", | ||
output="sos", | ||
fs=samplerate, | ||
) | ||
|
||
if low_freq > high_freq: | ||
raise ValueError("low_freq must be less than high_freq.") | ||
|
||
# Band pass filter | ||
return signal.butter( | ||
order, | ||
[low_freq, high_freq], | ||
btype="bandpass", | ||
output="sos", | ||
fs=samplerate, | ||
) | ||
|
||
|
||
def filter( | ||
audio: xr.DataArray, | ||
low_freq: Optional[float] = None, | ||
high_freq: Optional[float] = None, | ||
order: int = 5, | ||
) -> xr.DataArray: | ||
"""Filter audio data. | ||
This function assumes that the input audio object is a | ||
:class:`xarray.DataArray` with a "samplerate" attribute and a "time" | ||
dimension. | ||
The filtering is done using a Butterworth filter or the specified order. | ||
The type of filter (lowpass/highpass/bandpass filter) is determined | ||
by the specified cutoff frequencies. If only one cutoff frequency is | ||
specified, a low pass or high pass filter is used. If both cutoff | ||
frequencies are specified, a band pass filter is used. | ||
Parameters | ||
---------- | ||
audio : xr.DataArray | ||
The audio data to filter with a "samplerate" attribute and | ||
a "time" dimension. | ||
low_freq : float, optional | ||
The low cutoff frequency in Hz. | ||
high_freq : float, optional | ||
The high cutoff frequency in Hz. | ||
order : int, optional | ||
The order of the filter. By default, 5. | ||
Returns | ||
------- | ||
xr.DataArray | ||
The filtered audio data. | ||
Raises | ||
------ | ||
ValueError | ||
If neither low_freq nor high_freq is specified, or if both | ||
are specified and low_freq > high_freq. | ||
""" | ||
if not isinstance(audio, xr.DataArray): | ||
raise ValueError("Audio must be an xarray.DataArray") | ||
|
||
if "samplerate" not in audio.attrs: | ||
raise ValueError("Audio must have a 'samplerate' attribute") | ||
|
||
if "time" not in audio.dims: | ||
raise ValueError("Audio must have a time dimension") | ||
|
||
axis: int = audio.get_axis_num("time") # type: ignore | ||
sos = _get_filter( | ||
audio.attrs["samplerate"], | ||
low_freq, | ||
high_freq, | ||
order, | ||
) | ||
|
||
filtered = signal.sosfiltfilt(sos, audio.data, axis=axis) | ||
return xr.DataArray( | ||
data=filtered, | ||
dims=audio.dims, | ||
coords=audio.coords, | ||
attrs=audio.attrs, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,135 @@ | ||
"""Test suite for filtering functions.""" | ||
|
||
from unittest import mock | ||
|
||
import numpy as np | ||
import pytest | ||
import xarray as xr | ||
from scipy import signal | ||
|
||
from soundevent import audio | ||
|
||
|
||
def test_filter_audio_fails_if_no_samplerate(): | ||
"""Test that filter_audio fails if samplerate is missing.""" | ||
data = xr.DataArray(np.random.randn(100), dims=["time"]) | ||
with pytest.raises(ValueError): | ||
audio.filter(data, 16000) | ||
|
||
|
||
def test_filter_audio_fails_if_not_an_xarray(): | ||
"""Test that filter_audio fails if not an xarray.DataArray.""" | ||
data = np.random.randn(100) | ||
with pytest.raises(ValueError): | ||
audio.filter(data, 16000) # type: ignore | ||
|
||
|
||
def test_filter_audio_fails_if_no_time_axis(): | ||
"""Test that filter_audio fails with missing time axis.""" | ||
data = xr.DataArray( | ||
np.random.randn(100), | ||
dims=["channel"], | ||
coords={"channel": range(100)}, | ||
attrs={"samplerate": 16000}, | ||
) | ||
with pytest.raises(ValueError): | ||
audio.filter(data, 16000) | ||
|
||
|
||
def test_filter_audio_returns_an_xarray(): | ||
"""Test that filter_audio returns an xarray.DataArray.""" | ||
data = xr.DataArray( | ||
np.random.randn(1000), | ||
dims=["time"], | ||
coords={"time": np.linspace(0, 1, 1000, endpoint=False)}, | ||
attrs={"samplerate": 16000}, | ||
) | ||
filtered = audio.filter(data, 1000) | ||
assert isinstance(filtered, xr.DataArray) | ||
|
||
|
||
def test_filter_audio_preserves_attrs(): | ||
"""Test that filter_audio preserves attributes.""" | ||
data = xr.DataArray( | ||
np.random.randn(1000), | ||
dims=["time"], | ||
coords={"time": np.linspace(0, 1, 1000, endpoint=False)}, | ||
attrs={"samplerate": 16000, "other": "value"}, | ||
) | ||
filtered = audio.filter(data, 1000) | ||
assert filtered.attrs == data.attrs | ||
|
||
|
||
def test_filter_audio_fails_if_no_low_or_high_freq_provided(): | ||
"""Test filter_audio fails if low_freq and high_freq arent provided.""" | ||
data = xr.DataArray( | ||
np.random.randn(1000), | ||
dims=["time"], | ||
coords={"time": np.linspace(0, 1, 1000, endpoint=False)}, | ||
attrs={"samplerate": 16000, "other": "value"}, | ||
) | ||
with pytest.raises(ValueError): | ||
audio.filter(data) | ||
|
||
|
||
def test_filter_audio_applies_a_lowpass_filter(): | ||
"""Test that filter_audio applies a lowpass filter.""" | ||
data = xr.DataArray( | ||
np.random.randn(1000), | ||
dims=["time"], | ||
coords={"time": np.linspace(0, 1, 1000, endpoint=False)}, | ||
attrs={"samplerate": 16000}, | ||
) | ||
|
||
mock_butter = mock.Mock(side_effect=signal.butter) | ||
with mock.patch.object(signal, "butter", mock_butter): | ||
audio.filter(data, high_freq=6000) | ||
mock_butter.assert_called_once_with( | ||
5, | ||
6000, | ||
btype="lowpass", | ||
fs=16000, | ||
output="sos", | ||
) | ||
|
||
|
||
def test_filter_audio_applies_a_highpass_filter(): | ||
"""Test that filter_audio applies a highpass filter.""" | ||
data = xr.DataArray( | ||
np.random.randn(1000), | ||
dims=["time"], | ||
coords={"time": np.linspace(0, 1, 1000, endpoint=False)}, | ||
attrs={"samplerate": 16000}, | ||
) | ||
|
||
mock_butter = mock.Mock(side_effect=signal.butter) | ||
with mock.patch.object(signal, "butter", mock_butter): | ||
audio.filter(data, low_freq=6000) | ||
mock_butter.assert_called_once_with( | ||
5, | ||
6000, | ||
btype="highpass", | ||
fs=16000, | ||
output="sos", | ||
) | ||
|
||
|
||
def test_filter_audio_applies_a_bandpass_filter(): | ||
"""Test that filter_audio applies a bandpass filter.""" | ||
data = xr.DataArray( | ||
np.random.randn(1000), | ||
dims=["time"], | ||
coords={"time": np.linspace(0, 1, 1000, endpoint=False)}, | ||
attrs={"samplerate": 16000}, | ||
) | ||
|
||
mock_butter = mock.Mock(side_effect=signal.butter) | ||
with mock.patch.object(signal, "butter", mock_butter): | ||
audio.filter(data, low_freq=1000, high_freq=6000) | ||
mock_butter.assert_called_once_with( | ||
5, | ||
[1000, 6000], | ||
btype="bandpass", | ||
fs=16000, | ||
output="sos", | ||
) |