diff --git a/doc/changes/devel/12507.bugfix.rst b/doc/changes/devel/12507.bugfix.rst new file mode 100644 index 00000000000..c172701bb93 --- /dev/null +++ b/doc/changes/devel/12507.bugfix.rst @@ -0,0 +1,5 @@ +Fix bug where using ``phase="minimum"`` in filtering functions like +:meth:`mne.io.Raw.filter` constructed a filter half the desired length with +compromised attenuation. Now ``phase="minimum"`` has the same length and comparable +suppression as ``phase="zero"``, and the old (incorrect) behavior can be achieved +with ``phase="minimum-half"``, by `Eric Larson`_. diff --git a/mne/filter.py b/mne/filter.py index 290ddf7f7d7..82b77a17a7c 100644 --- a/mne/filter.py +++ b/mne/filter.py @@ -20,6 +20,7 @@ _setup_cuda_fft_resample, _smart_pad, ) +from .fixes import minimum_phase from .parallel import parallel_func from .utils import ( _check_option, @@ -307,39 +308,7 @@ def _overlap_add_filter( copy=True, pad="reflect_limited", ): - """Filter the signal x using h with overlap-add FFTs. - - Parameters - ---------- - x : array, shape (n_signals, n_times) - Signals to filter. - h : 1d array - Filter impulse response (FIR filter coefficients). Must be odd length - if ``phase='linear'``. - n_fft : int - Length of the FFT. If None, the best size is determined automatically. - phase : str - If ``'zero'``, the delay for the filter is compensated (and it must be - an odd-length symmetric filter). If ``'linear'``, the response is - uncompensated. If ``'zero-double'``, the filter is applied in the - forward and reverse directions. If 'minimum', a minimum-phase - filter will be used. - picks : list | None - See calling functions. - n_jobs : int | str - Number of jobs to run in parallel. Can be ``'cuda'`` if ``cupy`` - is installed properly. - copy : bool - If True, a copy of x, filtered, is returned. Otherwise, it operates - on x in place. - pad : str - Padding type for ``_smart_pad``. - - Returns - ------- - x : array, shape (n_signals, n_times) - x filtered. - """ + """Filter the signal x using h with overlap-add FFTs.""" # set up array for filtering, reshape to 2D, operate on last axis x, orig_shape, picks = _prep_for_filtering(x, copy, picks) # Extend the signal by mirroring the edges to reduce transient filter @@ -526,34 +495,6 @@ def _construct_fir_filter( (windowing is a smoothing in frequency domain). If x is multi-dimensional, this operates along the last dimension. - - Parameters - ---------- - sfreq : float - Sampling rate in Hz. - freq : 1d array - Frequency sampling points in Hz. - gain : 1d array - Filter gain at frequency sampling points. - Must be all 0 and 1 for fir_design=="firwin". - filter_length : int - Length of the filter to use. Must be odd length if phase == "zero". - phase : str - If 'zero', the delay for the filter is compensated (and it must be - an odd-length symmetric filter). If 'linear', the response is - uncompensated. If 'zero-double', the filter is applied in the - forward and reverse directions. If 'minimum', a minimum-phase - filter will be used. - fir_window : str - The window to use in FIR design, can be "hamming" (default), - "hann", or "blackman". - fir_design : str - Can be "firwin2" or "firwin". - - Returns - ------- - h : array - Filter coefficients. """ assert freq[0] == 0 if fir_design == "firwin2": @@ -562,7 +503,7 @@ def _construct_fir_filter( assert fir_design == "firwin" fir_design = partial(_firwin_design, sfreq=sfreq) # issue a warning if attenuation is less than this - min_att_db = 12 if phase == "minimum" else 20 + min_att_db = 12 if phase == "minimum-half" else 20 # normalize frequencies freq = np.array(freq) / (sfreq / 2.0) @@ -575,11 +516,13 @@ def _construct_fir_filter( # Use overlap-add filter with a fixed length N = _check_zero_phase_length(filter_length, phase, gain[-1]) # construct symmetric (linear phase) filter - if phase == "minimum": + if phase == "minimum-half": h = fir_design(N * 2 - 1, freq, gain, window=fir_window) - h = signal.minimum_phase(h) + h = minimum_phase(h) else: h = fir_design(N, freq, gain, window=fir_window) + if phase == "minimum": + h = minimum_phase(h, half=False) assert h.size == N att_db, att_freq = _filter_attenuation(h, freq, gain) if phase == "zero-double": @@ -2162,7 +2105,7 @@ def detrend(x, order=1, axis=-1): "blackman": dict(name="Blackman", ripple=0.0017, attenuation=74), } _known_fir_windows = tuple(sorted(_fir_window_dict.keys())) -_known_phases_fir = ("linear", "zero", "zero-double", "minimum") +_known_phases_fir = ("linear", "zero", "zero-double", "minimum", "minimum-half") _known_phases_iir = ("zero", "zero-double", "forward") _known_fir_designs = ("firwin", "firwin2") _fir_design_dict = { diff --git a/mne/fixes.py b/mne/fixes.py index 2af4eba73b9..6d874be8805 100644 --- a/mne/fixes.py +++ b/mne/fixes.py @@ -889,3 +889,58 @@ def _numpy_h5py_dep(): "ignore", "`product` is deprecated.*", DeprecationWarning ) yield + + +def minimum_phase(h, method="homomorphic", n_fft=None, *, half=True): + """Wrap scipy.signal.minimum_phase with half option.""" + # Can be removed once + from scipy.fft import fft, ifft + from scipy.signal import minimum_phase as sp_minimum_phase + + assert isinstance(method, str) and method == "homomorphic" + + if "half" in inspect.getfullargspec(sp_minimum_phase).kwonlyargs: + return sp_minimum_phase(h, method=method, n_fft=n_fft, half=half) + h = np.asarray(h) + if np.iscomplexobj(h): + raise ValueError("Complex filters not supported") + if h.ndim != 1 or h.size <= 2: + raise ValueError("h must be 1-D and at least 2 samples long") + n_half = len(h) // 2 + if not np.allclose(h[-n_half:][::-1], h[:n_half]): + warnings.warn( + "h does not appear to by symmetric, conversion may fail", + RuntimeWarning, + stacklevel=2, + ) + if n_fft is None: + n_fft = 2 ** int(np.ceil(np.log2(2 * (len(h) - 1) / 0.01))) + n_fft = int(n_fft) + if n_fft < len(h): + raise ValueError("n_fft must be at least len(h)==%s" % len(h)) + + # zero-pad; calculate the DFT + h_temp = np.abs(fft(h, n_fft)) + # take 0.25*log(|H|**2) = 0.5*log(|H|) + h_temp += 1e-7 * h_temp[h_temp > 0].min() # don't let log blow up + np.log(h_temp, out=h_temp) + if half: # halving of magnitude spectrum optional + h_temp *= 0.5 + # IDFT + h_temp = ifft(h_temp).real + # multiply pointwise by the homomorphic filter + # lmin[n] = 2u[n] - d[n] + # i.e., double the positive frequencies and zero out the negative ones; + # Oppenheim+Shafer 3rd ed p991 eq13.42b and p1004 fig13.7 + win = np.zeros(n_fft) + win[0] = 1 + stop = n_fft // 2 + win[1:stop] = 2 + if n_fft % 2: + win[stop] = 1 + h_temp *= win + h_temp = ifft(np.exp(fft(h_temp))) + h_minimum = h_temp.real + + n_out = (n_half + len(h) % 2) if half else len(h) + return h_minimum[:n_out] diff --git a/mne/tests/test_filter.py b/mne/tests/test_filter.py index 23ff37b8591..00dce484a08 100644 --- a/mne/tests/test_filter.py +++ b/mne/tests/test_filter.py @@ -606,12 +606,12 @@ def test_filters(): # try new default and old default freqs = fftfreq(a.shape[-1], 1.0 / sfreq) A = np.abs(fft(a)) - kwargs = dict(fir_design="firwin") + kw = dict(fir_design="firwin") for fl in ["auto", "10s", "5000ms", 1024, 1023]: - bp = filter_data(a, sfreq, 4, 8, None, fl, 1.0, 1.0, **kwargs) - bs = filter_data(a, sfreq, 8 + 1.0, 4 - 1.0, None, fl, 1.0, 1.0, **kwargs) - lp = filter_data(a, sfreq, None, 8, None, fl, 10, 1.0, n_jobs=2, **kwargs) - hp = filter_data(lp, sfreq, 4, None, None, fl, 1.0, 10, **kwargs) + bp = filter_data(a, sfreq, 4, 8, None, fl, 1.0, 1.0, **kw) + bs = filter_data(a, sfreq, 8 + 1.0, 4 - 1.0, None, fl, 1.0, 1.0, **kw) + lp = filter_data(a, sfreq, None, 8, None, fl, 10, 1.0, n_jobs=2, **kw) + hp = filter_data(lp, sfreq, 4, None, None, fl, 1.0, 10, **kw) assert_allclose(hp, bp, rtol=1e-3, atol=2e-3) assert_allclose(bp + bs, a, rtol=1e-3, atol=1e-3) # Sanity check ttenuation @@ -619,12 +619,18 @@ def test_filters(): assert_allclose(np.mean(np.abs(fft(bp)[:, mask]) / A[:, mask]), 1.0, atol=0.02) assert_allclose(np.mean(np.abs(fft(bs)[:, mask]) / A[:, mask]), 0.0, atol=0.2) # now the minimum-phase versions - bp = filter_data(a, sfreq, 4, 8, None, fl, 1.0, 1.0, phase="minimum", **kwargs) + bp = filter_data(a, sfreq, 4, 8, None, fl, 1.0, 1.0, phase="minimum-half", **kw) bs = filter_data( - a, sfreq, 8 + 1.0, 4 - 1.0, None, fl, 1.0, 1.0, phase="minimum", **kwargs + a, sfreq, 8 + 1.0, 4 - 1.0, None, fl, 1.0, 1.0, phase="minimum-half", **kw ) assert_allclose(np.mean(np.abs(fft(bp)[:, mask]) / A[:, mask]), 1.0, atol=0.11) assert_allclose(np.mean(np.abs(fft(bs)[:, mask]) / A[:, mask]), 0.0, atol=0.3) + bp = filter_data(a, sfreq, 4, 8, None, fl, 1.0, 1.0, phase="minimum", **kw) + bs = filter_data( + a, sfreq, 8 + 1.0, 4 - 1.0, None, fl, 1.0, 1.0, phase="minimum", **kw + ) + assert_allclose(np.mean(np.abs(fft(bp)[:, mask]) / A[:, mask]), 1.0, atol=0.12) + assert_allclose(np.mean(np.abs(fft(bs)[:, mask]) / A[:, mask]), 0.0, atol=0.27) # and since these are low-passed, downsampling/upsampling should be close n_resamp_ignore = 10 @@ -1050,3 +1056,45 @@ def test_filter_picks(): raw.filter(picks=picks, **kwargs) want = want[1:] assert_allclose(raw.get_data(), want) + + +def test_filter_minimum_phase_bug(): + """Test gh-12267 is fixed.""" + sfreq = 1000.0 + n_taps = 1001 + l_freq = 10.0 # Hz + kwargs = dict( + data=None, + sfreq=sfreq, + l_freq=l_freq, + h_freq=None, + filter_length=n_taps, + l_trans_bandwidth=l_freq / 2.0, + ) + h = create_filter(phase="zero", **kwargs) + h_min = create_filter(phase="minimum", **kwargs) + h_min_half = create_filter(phase="minimum-half", **kwargs) + assert h_min.size == h.size + kwargs = dict(worN=10000, fs=sfreq) + w, H = freqz(h, **kwargs) + assert w[0] == 0 + dc_dB = 20 * np.log10(np.abs(H[0])) + assert dc_dB < -100 + # good + w_min, H_min = freqz(h_min, **kwargs) + assert_allclose(w, w_min) + dc_dB_min = 20 * np.log10(np.abs(H_min[0])) + assert dc_dB_min < -100 + mask = w < 5 + assert 10 < mask.sum() < 101 + assert_allclose(np.abs(H[mask]), np.abs(H_min[mask]), atol=1e-3, rtol=1e-3) + assert_array_less(20 * np.log10(np.abs(H[mask])), -40) + assert_array_less(20 * np.log10(np.abs(H_min[mask])), -40) + # bad + w_min_half, H_min_half = freqz(h_min_half, **kwargs) + assert_allclose(w, w_min_half) + dc_dB_min_half = 20 * np.log10(np.abs(H_min_half[0])) + assert -80 < dc_dB_min_half < 40 + dB_min_half = 20 * np.log10(np.abs(H_min_half[mask])) + assert_array_less(dB_min_half, -20) + assert not (dB_min_half < -30).all() diff --git a/mne/utils/docs.py b/mne/utils/docs.py index e8c30716d48..f5d7c4f4669 100644 --- a/mne/utils/docs.py +++ b/mne/utils/docs.py @@ -2809,21 +2809,36 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): docdict["phase"] = """ phase : str Phase of the filter. - When ``method='fir'``, symmetric linear-phase FIR filters are constructed, - and if ``phase='zero'`` (default), the delay of this filter is compensated - for, making it non-causal. If ``phase='zero-double'``, - then this filter is applied twice, once forward, and once backward - (also making it non-causal). If ``'minimum'``, then a minimum-phase filter - will be constructed and applied, which is causal but has weaker stop-band - suppression. - When ``method='iir'``, ``phase='zero'`` (default) or - ``phase='zero-double'`` constructs and applies IIR filter twice, once - forward, and once backward (making it non-causal) using - :func:`~scipy.signal.filtfilt`. - If ``phase='forward'``, it constructs and applies forward IIR filter using + When ``method='fir'``, symmetric linear-phase FIR filters are constructed + with the following behaviors when ``method="fir"``: + + ``"zero"`` (default) + The delay of this filter is compensated for, making it non-causal. + ``"minimum"`` + A minimum-phase filter will be constructed by decomposing the zero-phase filter + into a minimum-phase and all-pass systems, and then retaining only the + minimum-phase system (of the same length as the original zero-phase filter) + via :func:`scipy.signal.minimum_phase`. + ``"zero-double"`` + *This is a legacy option for compatibility with MNE <= 0.13.* + The filter is applied twice, once forward, and once backward + (also making it non-causal). + ``"minimum-half"`` + *This is a legacy option for compatibility with MNE <= 1.6.* + A minimum-phase filter will be reconstructed from the zero-phase filter with + half the length of the original filter. + + When ``method='iir'``, ``phase='zero'`` (default) or equivalently ``'zero-double'`` + constructs and applies IIR filter twice, once forward, and once backward (making it + non-causal) using :func:`~scipy.signal.filtfilt`; ``phase='forward'`` will apply + the filter once in the forward (causal) direction using :func:`~scipy.signal.lfilter`. .. versionadded:: 0.13 + .. versionchanged:: 1.7 + + The behavior for ``phase="minimum"`` was fixed to use a filter of the requested + length and improved suppression. """ docdict["physical_range_export_params"] = """ diff --git a/tutorials/preprocessing/25_background_filtering.py b/tutorials/preprocessing/25_background_filtering.py index cbd10ab213b..c0f56098bad 100644 --- a/tutorials/preprocessing/25_background_filtering.py +++ b/tutorials/preprocessing/25_background_filtering.py @@ -148,6 +148,7 @@ from scipy import signal import mne +from mne.fixes import minimum_phase from mne.time_frequency.tfr import morlet from mne.viz import plot_filter, plot_ideal_filter @@ -168,7 +169,7 @@ gain = [1, 1, 0, 0] third_height = np.array(plt.rcParams["figure.figsize"]) * [1, 1.0 / 3.0] -ax = plt.subplots(1, figsize=third_height)[1] +ax = plt.subplots(1, figsize=third_height, layout="constrained")[1] plot_ideal_filter(freq, gain, ax, title="Ideal %s Hz lowpass" % f_p, flim=flim) # %% @@ -249,7 +250,7 @@ freq = [0.0, f_p, f_s, nyq] gain = [1.0, 1.0, 0.0, 0.0] -ax = plt.subplots(1, figsize=third_height)[1] +ax = plt.subplots(1, figsize=third_height, layout="constrained")[1] title = f"{f_p} Hz lowpass with a {trans_bandwidth} Hz transition" plot_ideal_filter(freq, gain, ax, title=title, flim=flim) @@ -316,15 +317,15 @@ # is constant) but small in the pass-band. Unlike zero-phase filters, which # require time-shifting backward the output of a linear-phase filtering stage # (and thus becoming non-causal), minimum-phase filters do not require any -# compensation to achieve small delays in the pass-band. Note that as an -# artifact of the minimum phase filter construction step, the filter does -# not end up being as steep as the linear/zero-phase version. +# compensation to achieve small delays in the pass-band. # # We can construct a minimum-phase filter from our existing linear-phase -# filter with the :func:`scipy.signal.minimum_phase` function, and note -# that the falloff is not as steep: +# filter, and note that the falloff is not as steep. Here we do this with function +# ``mne.fixes.minimum_phase()`` to avoid a SciPy bug; once SciPy 1.14.0 is released you +# could directly use +# :func:`scipy.signal.minimum_phase(..., half=False) `. -h_min = signal.minimum_phase(h) +h_min = minimum_phase(h, half=False) plot_filter(h_min, sfreq, freq, gain, "Minimum-phase", **kwargs) # %% @@ -683,7 +684,6 @@ def plot_signal(x, offset): for text in axes[0].get_yticklabels(): text.set(rotation=45, size=8) axes[1].set(xlim=flim, ylim=(-60, 10), xlabel="Frequency (Hz)", ylabel="Magnitude (dB)") -mne.viz.adjust_axes(axes) plt.show() # %% @@ -779,7 +779,7 @@ def plot_signal(x, offset): xlabel = "Time (s)" ylabel = r"Amplitude ($\mu$V)" tticks = [0, 0.5, 1.3, t[-1]] -axes = plt.subplots(2, 2)[1].ravel() +axes = plt.subplots(2, 2, layout="constrained")[1].ravel() for ax, x_f, title in zip( axes, [x_lp_2, x_lp_30, x_hp_2, x_hp_p1], @@ -791,7 +791,6 @@ def plot_signal(x, offset): ylim=ylim, xlim=xlim, xticks=tticks, title=title, xlabel=xlabel, ylabel=ylabel ) -mne.viz.adjust_axes(axes) plt.show() # %% @@ -830,7 +829,7 @@ def plot_signal(x, offset): def baseline_plot(x): - all_axes = plt.subplots(3, 2, layout="constrained")[1] + fig, all_axes = plt.subplots(3, 2, layout="constrained") for ri, (axes, freq) in enumerate(zip(all_axes, [0.1, 0.3, 0.5])): for ci, ax in enumerate(axes): if ci == 0: @@ -846,8 +845,7 @@ def baseline_plot(x): ax.set(title=("No " if ci == 0 else "") + "Baseline Correction") ax.set(xticks=tticks, ylim=ylim, xlim=xlim, xlabel=xlabel) ax.set_ylabel("%0.1f Hz" % freq, rotation=0, horizontalalignment="right") - mne.viz.adjust_axes(axes) - plt.suptitle(title) + fig.suptitle(title) plt.show()