# Import

In [1]:
import numpy as np
import matplotlib.pyplot as plt

from scipy import signal
from IPython.display import display
import ipywidgets as widgets



# Setup

In [3]:
def get_window(win_type: str, N: int):
    """Return an analysis window of length ``N``.

    ``"rect"`` is built locally as an all-ones window; every other name is
    forwarded to ``scipy.signal.get_window`` with ``fftbins=True`` (the
    periodic variant).
    """
    if win_type == "rect":
        return np.ones(N)
    # scipy.signal.get_window supports many window names
    return signal.get_window(win_type, N, fftbins=True)


def next_pow2(n: int) -> int:
    """Return the smallest power of two that is >= ``n`` (1 for ``n <= 1``)."""
    n = int(n)
    if n <= 1:
        return 1
    return 1 << (n - 1).bit_length()


def ola_fft_convolve(x, h, Nw, hop, win_type="hann", nfft=None, normalize=True):
    """
    Online-ish convolution via frame-by-frame FFT and overlap-add.

    The input is cut into frames of length ``Nw`` advanced by ``hop`` samples.
    Each frame is multiplied by the analysis window and linearly convolved
    with ``h`` through a zero-padded FFT of size ``nfft``. The frame outputs
    are overlap-added into a result of length ``len(x) + len(h) - 1``.

    Because convolution is linear, the overlap-added result equals
    ``(x * W) conv h``, where ``W[n]`` is the sum of all window copies
    covering input sample ``n``. With ``normalize=True`` the input is divided
    by ``W`` before windowing. The result then matches the full linear
    convolution ``np.convolve(x, h)`` up to FFT round-off for any window,
    provided every input sample sees a non-zero window value.

    Frames start ``Nw - hop`` samples before ``x[0]`` (zero-padded). This
    gives the first samples the same overlap as the interior; otherwise a
    window with ``w[0] == 0`` (e.g. periodic Hann) would drop ``x[0]``.

    x: input signal (1D)
    h: FIR filter (1D, non-empty)
    Nw: frame/window length (> 0)
    hop: hop size (advance per frame, > 0). If hop > Nw, the samples between
         frames are never seen by any frame and are missing from the output.
    win_type: window type (rect, hann, hamming, blackman, ...)
    nfft: FFT size (>= Nw + len(h) - 1). If None -> next power of 2.
    normalize: if True, divide the input by the overlap-added window sum so
               the windowing cancels exactly. Samples whose window sum is ~0
               are dropped rather than divided by zero.

    Returns: float array of length len(x) + len(h) - 1.
    Raises: ValueError for non-positive Nw/hop, empty h, or nfft too small.
    """
    x = np.asarray(x, dtype=float)
    h = np.asarray(h, dtype=float)
    Nw = int(Nw)
    hop = int(hop)
    if Nw <= 0:
        raise ValueError("Nw must be a positive integer.")
    if hop <= 0:
        raise ValueError("hop must be a positive integer.")

    Lh = len(h)
    if Lh == 0:
        raise ValueError("h must contain at least one tap.")
    if nfft is None:
        nfft = next_pow2(Nw + Lh - 1)
    if nfft < Nw + Lh - 1:
        raise ValueError("nfft must be >= Nw + len(h) - 1 for linear convolution per frame.")

    Lx = len(x)
    # Output length for full linear convolution
    y_len = Lx + Lh - 1
    y = np.zeros(y_len)
    if Lx == 0:
        return y

    w = get_window(win_type, Nw)

    # Zero-padded lead-in: frames begin before x[0] so early samples get full overlap.
    pad = max(Nw - hop, 0)
    starts = range(-pad, Lx, hop)

    if normalize:
        # Overlap-added window sum at every *input* position.
        eps = 1e-12
        win_sum = np.zeros(Lx)
        for s in starts:
            lo, hi = max(s, 0), min(s + Nw, Lx)
            win_sum[lo:hi] += w[lo - s:hi - s]
        gain = np.zeros(Lx)
        covered = np.abs(win_sum) > eps
        gain[covered] = 1.0 / win_sum[covered]
        # Pre-dividing by W makes sum_m(frame_m * w_m) == x exactly, so the
        # overlap-added filtered frames reproduce x conv h.
        x_in = x * gain
    else:
        x_in = x

    H = np.fft.rfft(h, nfft)

    for s in starts:
        # Copy the part of the frame that overlaps x; negative indices stay zero.
        lo, hi = max(s, 0), min(s + Nw, Lx)
        frame = np.zeros(Nw)
        frame[lo - s:hi - s] = x_in[lo:hi]

        # nfft >= Nw + Lh - 1, so this is linear (not circular) convolution.
        y_block = np.fft.irfft(np.fft.rfft(frame * w, nfft) * H, nfft)

        # Overlap-add; y_block[k] belongs at output index s + k.
        # Indices < 0 come only from zero padding, so they are discarded.
        out_lo, out_hi = max(s, 0), min(s + nfft, y_len)
        y[out_lo:out_hi] += y_block[out_lo - s:out_hi - s]

    return y

def make_test_signal(fs=16000, duration=1.5, seed=0):
    """Build a reproducible test signal and return ``(x, fs)``.

    The signal is a mixture of a 440 Hz tone (amplitude 0.6), an 880 Hz tone
    (amplitude 0.3) and Gaussian noise (std 0.05) seeded by ``seed``.
    A 200-sample Hann bump (peak 1.0) is added starting at 60% of the
    signal. When fewer than 200 samples remain after that point, the bump is
    truncated instead of raising a shape-mismatch error.
    """
    rng = np.random.default_rng(seed)
    n = int(fs * duration)
    t = np.arange(n) / fs

    # Mixture: two tones + some noise + a transient
    x = 0.6 * np.sin(2 * np.pi * 440 * t) + 0.3 * np.sin(2 * np.pi * 880 * t)
    x += 0.05 * rng.standard_normal(n)

    onset = int(0.6 * n)
    bump = signal.windows.hann(200)
    # Clip to the samples actually available so short signals do not fail.
    span = min(len(bump), n - onset)
    x[onset:onset + span] += bump[:span]
    return x, fs

def make_fir_lowpass(numtaps, cutoff_hz, fs, window="hann"):
    """Design a windowed-sinc low-pass FIR with ``scipy.signal.firwin``.

    An even ``numtaps`` is bumped up by one, so the returned filter always has
    an odd number of taps (symmetric, linear phase). ``cutoff_hz`` and ``fs``
    are both in Hz; ``window`` names the design window passed to firwin.
    """
    odd_taps = numtaps + 1 if numtaps % 2 == 0 else numtaps
    return signal.firwin(
        odd_taps, cutoff_hz, fs=fs, window=window, pass_zero="lowpass"
    )


# Signal filtering

In [4]:
# Demo input shared by every run; run_demo reads ``x`` and ``fs`` as globals.
x, fs = make_test_signal()

# Analysis window selector: (label shown in the UI, name passed to get_window).
win_type_widget = widgets.Dropdown(
    options=[
        ("Rectangular (exact-ish)", "rect"),
        ("Hann", "hann"),
        ("Hamming", "hamming"),
        ("Blackman", "blackman"),
    ],
    value="hann",
    description="Window:",
)

# All sliders set continuous_update=False so the full-signal convolution is
# recomputed only when the user lets go of a slider.
Nw_widget = widgets.IntSlider(
    value=1024,
    min=128,
    max=8192,
    step=128,
    description="Win size",
    continuous_update=False,
)

# NOTE(review): hop may exceed the window size here; samples between frames
# are then never processed by ola_fft_convolve.
hop_widget = widgets.IntSlider(
    value=256,
    min=64,
    max=4096,
    step=64,
    description="Hop",
    continuous_update=False,
)

# Odd start + step of 2 keeps the requested tap count odd.
Lh_widget = widgets.IntSlider(
    value=257,
    min=17,
    max=2049,
    step=2,
    description="Filter taps",
    continuous_update=False,
)

cutoff_widget = widgets.FloatSlider(
    value=2000.0,
    min=200.0,
    max=7000.0,
    step=100.0,
    description="Cutoff (Hz)",
    continuous_update=False,
)

normalize_widget = widgets.Checkbox(
    value=True,
    description="Normalize overlap",
)

def run_demo(win_type, Nw, hop, Lh, cutoff_hz, normalize):
    """Filter the demo signal with ola_fft_convolve and compare to a reference.

    Designs a low-pass FIR via make_fir_lowpass (Hann design window). Runs
    the frame-by-frame OLA-FFT convolution with the requested window, frame
    and hop settings, and compares it with ``scipy.signal.fftconvolve``.

    Plots three panels: the first 0.5 s of input, the first 4000 output
    samples of both results, and their difference. Then prints the settings
    and the RMSE / relative RMSE.

    Reads the module-level ``x`` and ``fs``; returns None.
    """
    h = make_fir_lowpass(Lh, cutoff_hz, fs, window="hann")
    # Power-of-two FFT large enough for linear (not circular) convolution per frame.
    nfft = next_pow2(Nw + len(h) - 1)

    y_ola = ola_fft_convolve(
        x, h, Nw=Nw, hop=hop, win_type=win_type, nfft=nfft, normalize=normalize
    )
    y_ref = signal.fftconvolve(x, h, mode="full")

    # Compare only over the common prefix, in case the lengths ever differ.
    n_common = min(len(y_ola), len(y_ref))
    diff = y_ola[:n_common] - y_ref[:n_common]
    rmse = np.sqrt(np.mean(diff ** 2))
    ref_rms = np.sqrt(np.mean(y_ref[:n_common] ** 2))
    rel = rmse / (ref_rms + 1e-12)

    n_in = min(len(x), int(0.5 * fs))
    n_out = min(4000, n_common)
    t_out = np.arange(n_out) / fs

    fig, (ax_in, ax_out, ax_err) = plt.subplots(3, 1, figsize=(12, 8))

    ax_in.set_title("Input signal (excerpt)")
    ax_in.plot(np.arange(n_in) / fs, x[:n_in])

    ax_out.set_title(f"Output excerpt: OLA-FFT vs Reference (nfft={nfft})")
    ax_out.plot(t_out, y_ref[:n_out], label="Reference (fftconvolve)")
    ax_out.plot(t_out, y_ola[:n_out], label="OLA-FFT (frame-by-frame)", alpha=0.8)
    ax_out.legend()

    ax_err.set_title(f"Error excerpt | RMSE={rmse:.3e}, Relative={rel:.3e}")
    ax_err.plot(t_out, diff[:n_out])

    for ax in (ax_in, ax_out, ax_err):
        ax.set_xlabel("Time (s)")
        ax.grid(True)

    fig.tight_layout()
    plt.show()

    print(f"Window type: {win_type}, Nw={Nw}, hop={hop}, filter taps={len(h)}, cutoff={cutoff_hz:.0f} Hz")
    print(f"FFT size used per frame: nfft={nfft}")
    print(f"RMSE: {rmse:.6e} | Relative RMSE: {rel:.6e}")

# Stack the controls vertically; this order is also the order of ui.children below.
ui = widgets.VBox(
    [win_type_widget, Nw_widget, hop_widget, Lh_widget, cutoff_widget, normalize_widget]
)

# Bind each run_demo keyword to its control (names listed in ui.children order);
# interactive_output re-calls run_demo whenever one of them changes.
out = widgets.interactive_output(
    run_demo,
    dict(zip(("win_type", "Nw", "hop", "Lh", "cutoff_hz", "normalize"), ui.children)),
)

display(ui, out)


VBox(children=(Dropdown(description='Window:', index=1, options=(('Rectangular (exact-ish)', 'rect'), ('Hann',…

Output()