Skip to content

Commit

Permalink
FIX: Fix filtering
Browse files Browse the repository at this point in the history
  • Loading branch information
larsoner authored and agramfort committed Jun 26, 2015
1 parent 4393f01 commit beedbfb
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 49 deletions.
8 changes: 6 additions & 2 deletions mne/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,11 @@ def fft_resample(x, W, new_len, npad, to_remove,
def _smart_pad(x, n_pad):
"""Pad vector x
"""
if n_pad == 0:
return x
elif n_pad < 0:
raise RuntimeError('n_pad must be non-negative')
# need to pad with zeros if len(x) <= npad
z_pad = np.zeros(max(n_pad - len(x) + 1, 0), dtype=x.dtype)
return np.r_[z_pad, 2 * x[0] - x[n_pad:0:-1], x,
2 * x[-1] - x[-2:-n_pad - 2:-1], z_pad]
return np.concatenate([z_pad, 2 * x[0] - x[n_pad:0:-1], x,
2 * x[-1] - x[-2:-n_pad - 2:-1], z_pad])
85 changes: 40 additions & 45 deletions mne/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def _overlap_add_filter(x, h, n_fft=None, zero_phase=True, picks=None,
applied in forward and backward direction, resulting in a zero-phase
filter.
WARNING: This operates on the data in-place.
.. warning:: This operates on the data in-place.
Parameters
----------
Expand Down Expand Up @@ -77,9 +77,9 @@ def _overlap_add_filter(x, h, n_fft=None, zero_phase=True, picks=None,
# Extend the signal by mirroring the edges to reduce transient filter
# response
n_h = len(h)
n_edge = min(n_h, x.shape[1])
n_edge = max(min(n_h, x.shape[1]) - 1, 0)

n_x = x.shape[1] + 2 * n_edge - 2
n_x = x.shape[1] + 2 * n_edge

# Determine FFT length to use
if n_fft is None:
Expand Down Expand Up @@ -112,75 +112,66 @@ def _overlap_add_filter(x, h, n_fft=None, zero_phase=True, picks=None,
warnings.warn("FFT length is not a power of 2. Can be slower.")

# Filter in frequency domain
h_fft = fft(np.r_[h, np.zeros(n_fft - n_h, dtype=h.dtype)])

if zero_phase:
# We will apply the filter in forward and backward direction: Scale
# frequency response of the filter so that the shape of the amplitude
# response stays the same when it is applied twice

# be careful not to divide by too small numbers
idx = np.where(np.abs(h_fft) > 1e-6)
h_fft[idx] = h_fft[idx] / np.sqrt(np.abs(h_fft[idx]))

# Segment length for signal x
n_seg = n_fft - n_h + 1

# Number of segments (including fractional segments)
n_segments = int(np.ceil(n_x / float(n_seg)))
h_fft = fft(np.concatenate([h, np.zeros(n_fft - n_h, dtype=h.dtype)]))

# Figure out if we should use CUDA
n_jobs, cuda_dict, h_fft = setup_cuda_fft_multiply_repeated(n_jobs, h_fft)

# Process each row separately
if n_jobs == 1:
for p in picks:
x[p] = _1d_overlap_filter(x[p], h_fft, n_edge, n_fft, zero_phase,
n_segments, n_seg, cuda_dict)
x[p] = _1d_overlap_filter(x[p], h_fft, n_h, n_edge, zero_phase,
cuda_dict)
else:
parallel, p_fun, _ = parallel_func(_1d_overlap_filter, n_jobs)
data_new = parallel(p_fun(x[p], h_fft, n_edge, n_fft, zero_phase,
n_segments, n_seg, cuda_dict)
data_new = parallel(p_fun(x[p], h_fft, n_h, n_edge, zero_phase,
cuda_dict)
for p in picks)
for pp, p in enumerate(picks):
x[p] = data_new[pp]

return x


def _1d_overlap_filter(x, h_fft, n_edge, n_fft, zero_phase, n_segments, n_seg,
cuda_dict):
def _1d_overlap_filter(x, h_fft, n_h, n_edge, zero_phase, cuda_dict):
"""Do one-dimensional overlap-add FFT FIR filtering"""
# pad to reduce ringing
x_ext = _smart_pad(x, n_edge - 1)
n_fft = len(h_fft)
x_ext = _smart_pad(x, n_edge)
n_x = len(x_ext)
filter_input = x_ext
x_filtered = np.zeros_like(filter_input)

for pass_no in list(range(2)) if zero_phase else list(range(1)):
# Segment length for signal x
n_seg = n_fft - n_h + 1

# Number of segments (including fractional segments)
n_segments = int(np.ceil(n_x / float(n_seg)))

for pass_no in list(range(2 if zero_phase else 1)):

if pass_no == 1:
# second pass: flip signal
filter_input = np.flipud(x_filtered)
filter_input = x_filtered[::-1]
x_filtered = np.zeros_like(x_ext)

for seg_idx in range(n_segments):
seg = filter_input[seg_idx * n_seg:(seg_idx + 1) * n_seg]
seg = np.r_[seg, np.zeros(n_fft - len(seg))]
start = seg_idx * n_seg
stop = (seg_idx + 1) * n_seg
seg = filter_input[start:stop]
seg = np.concatenate([seg, np.zeros(n_fft - len(seg))])
prod = fft_multiply_repeated(h_fft, seg, cuda_dict)

if seg_idx * n_seg + n_fft < n_x:
x_filtered[seg_idx * n_seg:seg_idx * n_seg + n_fft] += prod
x_filtered[start:start + n_fft] += prod
else:
# Last segment
x_filtered[seg_idx * n_seg:] += prod[:n_x - seg_idx * n_seg]

# Remove mirrored edges that we added
x_filtered = x_filtered[n_edge - 1:-n_edge + 1]

if zero_phase:
# flip signal back
x_filtered = np.flipud(x_filtered)
x_filtered[start:] += prod[:n_x - seg_idx * n_seg]

# Remove mirrored edges that we added, flip back if necessary, cast
if n_edge > 0:
x_filtered = x_filtered[n_edge:-n_edge]
x_filtered = x_filtered[::-1] if zero_phase else x_filtered
x_filtered = x_filtered.astype(x.dtype)
return x_filtered

Expand Down Expand Up @@ -307,16 +298,16 @@ def _filter(x, Fs, freq, gain, filter_length='10s', picks=None, n_jobs=1,

N = x.shape[1] + (extend_x is True)

H = firwin2(N, freq, gain)[np.newaxis, :]
h = firwin2(N, freq, gain)[np.newaxis, :]

att_db, att_freq = _filter_attenuation(H, freq, gain)
att_db, att_freq = _filter_attenuation(h, freq, gain)
if att_db < min_att_db:
att_freq *= Fs / 2
warnings.warn('Attenuation at stop frequency %0.1fHz is only '
'%0.1fdB.' % (att_freq, att_db))

# Make zero-phase filter function
B = np.abs(fft(H)).ravel()
B = np.abs(fft(h)).ravel()

# Figure out if we should use CUDA
n_jobs, cuda_dict, B = setup_cuda_fft_multiply_repeated(n_jobs, B)
Expand All @@ -339,17 +330,21 @@ def _filter(x, Fs, freq, gain, filter_length='10s', picks=None, n_jobs=1,
# Gain at Nyquist freq: 1: make N EVEN, 0: make N ODD
N += 1

H = firwin2(N, freq, gain)
# construct filter with gain resulting from forward-backward filtering
h = firwin2(N, freq, gain, window='hann')

att_db, att_freq = _filter_attenuation(H, freq, gain)
att_db, att_freq = _filter_attenuation(h, freq, gain)
att_db += 6 # the filter is applied twice (zero phase)
if att_db < min_att_db:
att_freq *= Fs / 2
warnings.warn('Attenuation at stop frequency %0.1fHz is only '
'%0.1fdB. Increase filter_length for higher '
'attenuation.' % (att_freq, att_db))

x = _overlap_add_filter(x, H, zero_phase=True, picks=picks,
# reconstruct filter, this time with appropriate gain for fwd-bkwd
gain = np.sqrt(gain)
h = firwin2(N, freq, gain, window='hann')
x = _overlap_add_filter(x, h, zero_phase=True, picks=picks,
n_jobs=n_jobs)

x.shape = orig_shape
Expand Down
21 changes: 19 additions & 2 deletions mne/tests/test_filter.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import numpy as np
from numpy.testing import (assert_array_almost_equal, assert_almost_equal,
assert_array_equal)
assert_array_equal, assert_allclose)
from nose.tools import assert_equal, assert_true, assert_raises
import os.path as op
import warnings
Expand Down Expand Up @@ -144,7 +144,11 @@ def test_filters():
lp_oa = low_pass_filter(a, sfreq, 8, filter_length)
hp_oa = high_pass_filter(lp_oa, sfreq, 4, filter_length)
assert_array_almost_equal(hp_oa, bp_oa, 2)
assert_array_almost_equal(bp_oa + bs_oa, a, 2)
# Our filters are no longer quite complementary with linear rolloffs :(
# this is the tradeoff for stability of the filtering
# obtained by directly using the result of firwin2 instead of
# modifying it...
assert_array_almost_equal(bp_oa + bs_oa, a, 1)

# The two methods should give the same result
# As filtering for short signals uses a circular convolution (FFT) and
Expand Down Expand Up @@ -206,6 +210,19 @@ def test_filters():
assert_raises(ValueError, band_pass_filter, a, sfreq, Fp1=4, Fp2=8,
picks=np.array([0, 1]))

# test that our overlap-add filtering doesn't introduce strange
# artifacts (from mne_analyze mailing list 2015/06/25)
N = 300
sfreq = 100.
lp = 10.
sine_freq = 1.
x = np.ones(N)
x += np.sin(2 * np.pi * sine_freq * np.arange(N) / sfreq)
with warnings.catch_warnings(record=True): # filter attenuation
x_filt = low_pass_filter(x, sfreq, lp, '1s')
# the firwin2 function gets us this close
assert_allclose(x, x_filt, rtol=1e-3, atol=1e-3)


def test_cuda():
"""Test CUDA-based filtering
Expand Down

0 comments on commit beedbfb

Please sign in to comment.