Skip to content

Commit

Permalink
Merge pull request #8034 from andfoy/fix_lsosfilt_zi
Browse files Browse the repository at this point in the history
Fix lfilter_zi and sosfilt_zi when any IIR coefficient is zero
  • Loading branch information
takagi committed Dec 19, 2023
2 parents 7f55d35 + bb2ca1a commit ec763cd
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 4 deletions.
23 changes: 19 additions & 4 deletions cupyx/scipy/signal/_signaltools.py
Original file line number Diff line number Diff line change
Expand Up @@ -918,6 +918,7 @@ def lfilter_zi(b, a):
y, _ = lfilter(b, a, cupy.ones(num_a + 1), zi=zi_t)
y1 = y[:num_a]
y2 = y[-num_a:]
zero_coef = cupy.where(a_r == 0)[0]

C = compute_correction_factors(a_r, a_r.size + 1, a_r.dtype)
C = C[:, a_r.size:]
Expand All @@ -926,9 +927,15 @@ def lfilter_zi(b, a):

# Take the difference between the non-adjusted output values and
# compute which initial output state would cause them to be constant.
y_zi = cupy.linalg.solve(C1 - C2, y2 - y1)
if not len(zero_coef):
y_zi = cupy.linalg.solve(C1 - C2, y2 - y1)
else:
# Any zero coefficient would cause the system to be underdetermined
# therefore a least square solution is computed instead.
y_zi, _, _, _ = cupy.linalg.lstsq(C1 - C2, y2 - y1, rcond=None)

y_zi = cupy.nan_to_num(y_zi, nan=0, posinf=cupy.inf, neginf=-cupy.inf)
zi = cupy.r_[zi, y_zi]
zi = cupy.r_[zi, y_zi[::-1]]
return zi


Expand Down Expand Up @@ -1546,11 +1553,19 @@ def sosfilt_zi(sos):
C1 = C_s[:, :2].T
C2 = C_s[:, -2:].T

zero_iir_coef = cupy.where(sos[s, 3:] == 0)[0]

# Take the difference between the non-adjusted output values and
# compute which initial output state would cause them to be constant.
y_zi = cupy.linalg.solve(C1 - C2, y2 - y1)
if not len(zero_iir_coef):
y_zi = cupy.linalg.solve(C1 - C2, y2 - y1)
else:
# Any zero coefficient would cause the system to be underdetermined
# therefore a least square solution is computed instead.
y_zi, _, _, _ = cupy.linalg.lstsq(C1 - C2, y2 - y1, rcond=None)

y_zi = cupy.nan_to_num(y_zi, nan=0, posinf=cupy.inf, neginf=-cupy.inf)
zi_s[0, 2:] = y_zi
zi_s[0, 2:] = y_zi[::-1]
x_s, _ = sosfilt(sos_s, x_s, zi=zi_s)

return zi
Expand Down
31 changes: 31 additions & 0 deletions tests/cupyx_tests/scipy_tests/signal_tests/test_signaltools.py
Original file line number Diff line number Diff line change
Expand Up @@ -603,6 +603,24 @@ def test_lfilter_zi(self, fir_order, iir_order, xp, scp):
out, _ = scp.signal.lfilter(b, a, x, zi=zi)
return out

@pytest.mark.parametrize(
'zeros', [(2,), (1,), (0,), (1, 2), (0, 1), (0, 2), (0, 1, 2)])
@testing.numpy_cupy_array_almost_equal(scipy_name='scp', decimal=5)
def test_lfilter_zi_with_zeros(self, zeros, xp, scp):
fir_order = 3
iir_order = 3

x = xp.ones(20)
b = testing.shaped_random((fir_order,), xp, scale=0.3)
a = testing.shaped_random((iir_order,), xp, scale=0.3)
a[list(zeros)] = 0
a = xp.r_[1, a]
a = a.astype(x.dtype)

zi = scp.signal.lfilter_zi(b, a)
out, _ = scp.signal.lfilter(b, a, x, zi=zi)
return out


@testing.with_requires('scipy')
class TestDeconvolve:
Expand Down Expand Up @@ -783,6 +801,19 @@ def test_sosfilt_zi(self, sections, xp, scp):
out, _ = scp.signal.sosfilt(sos, x, zi=zi)
return out

@pytest.mark.parametrize(
'zeros', [(4,), (5,), (4, 5)])
@testing.numpy_cupy_array_almost_equal(scipy_name='scp', decimal=5)
def test_sosfilt_zi_with_zeros(self, zeros, xp, scp):
x = xp.ones(20)
sos = testing.shaped_random((1, 6), xp, xp.float64, scale=0.2)
sos[:, 3] = 1
sos[0, list(zeros)] = 0

zi = scp.signal.sosfilt_zi(sos)
out, _ = scp.signal.sosfilt(sos, x, zi=zi)
return out


@testing.with_requires('scipy')
class TestDetrend:
Expand Down

0 comments on commit ec763cd

Please sign in to comment.