In [23]:
# One-cell Jupyter implementation: High-pass filter → block RMS/peakiness → adaptive threshold → hits → Plotly chart
from pathlib import Path

import numpy as np
import plotly.graph_objects as go
import plotly.io as pio
from plotly.subplots import make_subplots
from scipy.io import wavfile
from scipy.signal import butter, sosfilt


def _load_mono_wav(wav_path):
    """Read a WAV file and return ``(fs, x)`` with ``x`` as a 1-D float32 signal.

    Integer PCM is scaled to [-1, 1): signed formats (int16; int32, which scipy
    also uses for left-justified 24-bit data) are divided by 2**(bits-1), and
    unsigned formats (8-bit WAV is uint8) are first centred on their mid-code.
    Float data is only cast to float32. Multichannel files keep channel 0.
    """
    fs, x = wavfile.read(wav_path)

    if np.issubdtype(x.dtype, np.signedinteger):
        x = x.astype(np.float32) / float(-np.iinfo(x.dtype).min)
    elif np.issubdtype(x.dtype, np.unsignedinteger):
        mid = (np.iinfo(x.dtype).max + 1) / 2.0
        x = (x.astype(np.float32) - mid) / mid
    else:
        x = x.astype(np.float32)

    if x.ndim > 1:
        x = x[:, 0]  # multichannel: analyze the first channel only
    return fs, x


def _block_features(x, x_hp, block):
    """Return per-block ``(rms_raw, rms_hp, peakiness)`` over non-overlapping
    blocks of ``block`` samples; a trailing partial block is dropped."""
    n_blocks = len(x_hp) // block
    x_hp_b = x_hp[:n_blocks * block].reshape(n_blocks, block)
    x_raw_b = x[:n_blocks * block].reshape(n_blocks, block)

    # The 1e-12 terms keep sqrt/division finite on digital silence.
    rms_hp = np.sqrt(np.mean(x_hp_b * x_hp_b, axis=1) + 1e-12)
    rms_raw = np.sqrt(np.mean(x_raw_b * x_raw_b, axis=1) + 1e-12)
    peak_hp = np.max(np.abs(x_hp_b), axis=1) + 1e-12
    peakiness = peak_hp / (rms_hp + 1e-12)
    return rms_raw, rms_hp, peakiness


def _detect_hits(t, rms_hp, peakiness, *, refractory_s, noise_alpha, mult,
                 min_thresh, peakiness_min, use_peakiness_gate, noise_holdoff_s):
    """Run the causal adaptive-threshold detector over the block features.

    Returns ``(hits, thr, gate_rms, gate_pk)``: hit times in seconds, the
    per-block threshold, and the per-block RMS / peakiness gate booleans.
    """
    noise = 0.0
    thr = np.zeros_like(rms_hp)
    gate_rms = np.zeros_like(rms_hp, dtype=bool)
    gate_pk = np.zeros_like(rms_hp, dtype=bool)

    hits = []
    last_hit_t = -1e9  # "no hit yet": far enough back that no hold-off applies

    for i in range(len(rms_hp)):
        ti = float(t[i])

        # Freeze the noise floor shortly after a hit so the hit is not learned
        # as noise. rms_hp is always >= 1e-6 (the +1e-12 under the sqrt), so
        # noise == 0.0 only means "not initialised yet" -> seed with this block.
        if (ti - last_hit_t) > noise_holdoff_s:
            noise = rms_hp[i] if noise == 0.0 else (
                1.0 - noise_alpha) * noise + noise_alpha * rms_hp[i]

        thr_i = max(min_thresh, mult * noise)
        thr[i] = thr_i

        gate_r = rms_hp[i] > thr_i
        gate_p = peakiness[i] > peakiness_min
        gate_rms[i] = gate_r
        gate_pk[i] = gate_p

        ok = gate_r and ((not use_peakiness_gate) or gate_p)

        if ok and (ti - last_hit_t) >= refractory_s:
            hits.append(ti)
            last_hit_t = ti

    return hits, thr, gate_rms, gate_pk


def doAnalysis(wav_path: str, *,
               block: int = 2048,
               hp_cutoff_hz: int = 15000,
               refractory_s: float = 1.0,
               noise_alpha: float = 0.01,
               mult: float = 8.0,
               min_thresh: float = 0.03,
               peakiness_min: float = 8.0,
               use_peakiness_gate: bool = False,
               noise_holdoff_s: float = 0.25):
    """Analyze a WAV file for strike events using a high-pass RMS detector.

    Pipeline: load (mono float) -> remove DC -> 4th-order Butterworth high-pass
    -> per-block RMS / peakiness -> causal adaptive threshold with a refractory
    period -> hits. Prints the hit times and shows a 4-row Plotly figure.

    Parameters
    ----------
    wav_path : str
        Path to the input WAV file. Only the first channel of a multichannel
        file is analyzed.
    block : int
        Block length in samples (non-overlapping; a trailing partial block is
        ignored). Must be positive.
    hp_cutoff_hz : int
        High-pass cutoff in Hz; must lie strictly between 0 and fs/2.
    refractory_s : float
        Minimum time in seconds between two reported hits.
    noise_alpha : float
        Exponential-averaging coefficient of the HP-RMS noise-floor estimate.
    mult : float
        Threshold is ``max(min_thresh, mult * noise_floor)``.
    min_thresh : float
        Absolute lower bound on the threshold, in the amplitude units of the
        loaded signal (integer PCM is normalized to [-1, 1)).
    peakiness_min : float
        Peak/RMS ratio a block must exceed to pass the peakiness gate.
    use_peakiness_gate : bool
        If True, a hit additionally requires the peakiness gate.
    noise_holdoff_s : float
        The noise floor is not updated within this many seconds after a hit.

    Returns
    -------
    None
        Hit times are printed as integer nanoseconds and the figure is shown.

    Raises
    ------
    ValueError
        If ``block`` is not positive or ``hp_cutoff_hz`` is not inside
        (0, Nyquist) for the file's sample rate.
    """
    HP_ORDER = 4  # 4th-order Butterworth, as second-order sections for stability

    if block <= 0:
        raise ValueError(f"block must be a positive number of samples, got {block}")

    # ---------- Load WAV (mono float) ----------
    fs, x = _load_mono_wav(wav_path)
    x = x - np.mean(x)  # remove DC
    duration_s = len(x) / fs

    # ---------- High-pass filter ----------
    nyquist = fs / 2.0
    if not 0 < hp_cutoff_hz < nyquist:
        raise ValueError(
            f"hp_cutoff_hz={hp_cutoff_hz} must be between 0 and the Nyquist "
            f"frequency ({nyquist:g} Hz) of {wav_path}")
    sos = butter(HP_ORDER, hp_cutoff_hz, btype="highpass", fs=fs, output="sos")
    x_hp = sosfilt(sos, x)

    # ---------- Block features + detection ----------
    rms_raw, rms_hp, peakiness = _block_features(x, x_hp, block)
    t = (np.arange(len(rms_hp)) * block) / fs  # block start times in seconds

    hits, thr, gate_rms, gate_pk = _detect_hits(
        t, rms_hp, peakiness,
        refractory_s=refractory_s, noise_alpha=noise_alpha, mult=mult,
        min_thresh=min_thresh, peakiness_min=peakiness_min,
        use_peakiness_gate=use_peakiness_gate, noise_holdoff_s=noise_holdoff_s)

    # ---------- Plotly chart ----------
    # NOTE: this sets Plotly's global default renderer for the whole session.
    pio.renderers.default = "notebook_connected"  # if issues, try: "browser"

    fig = make_subplots(
        rows=4, cols=1, shared_xaxes=True, vertical_spacing=0.06,
        subplot_titles=(
            "Raw RMS per block (pre-HP)",
            f"High-pass RMS per block (HP>{hp_cutoff_hz} Hz) with Adaptive Threshold (mult={mult}, min={min_thresh})",
            f"Peakiness per block (peak/RMS) (min={peakiness_min}, enabled={use_peakiness_gate})",
            "Gate / Hits",
        )
    )

    # Row 1: block RMS of the raw (pre-HP) signal
    fig.add_trace(go.Scatter(x=t, y=rms_raw, mode="lines",
                             name="raw", line=dict(width=1)), row=1, col=1)

    # Row 2: high-pass RMS + adaptive threshold
    fig.add_trace(go.Scatter(x=t, y=rms_hp, mode="lines",
                             name="rms_hp", line=dict(width=1)), row=2, col=1)
    fig.add_trace(go.Scatter(x=t, y=thr, mode="lines", name="threshold",
                             line=dict(width=1, dash="dash")), row=2, col=1)

    # Row 3: peakiness + its gate level
    fig.add_trace(go.Scatter(x=t, y=peakiness, mode="lines",
                             name="peakiness", line=dict(width=1)), row=3, col=1)
    fig.add_hline(y=peakiness_min, line_dash="dash", row=3, col=1)

    # Row 4: gate signals, plus one marker per detected hit
    fig.add_trace(go.Scatter(x=t, y=gate_rms.astype(int), mode="lines",
                             name="gate_rms", line=dict(width=1)), row=4, col=1)
    if use_peakiness_gate:
        fig.add_trace(go.Scatter(x=t, y=gate_pk.astype(int), mode="lines",
                                 name="gate_peakiness", line=dict(width=1)), row=4, col=1)
    fig.add_trace(go.Scatter(x=hits, y=[1] * len(hits),
                             mode="markers", name="HIT", marker=dict(size=6)), row=4, col=1)

    # Vertical lines for hits across all rows
    for hit_t in hits:
        for r in (1, 2, 3, 4):
            fig.add_vline(x=hit_t, line_dash="dash", line_width=1, row=r, col=1)

    fig.update_yaxes(title_text="RMS", row=1, col=1)
    fig.update_yaxes(title_text="RMS", row=2, col=1)
    fig.update_yaxes(title_text="Peak/RMS", row=3, col=1)
    fig.update_yaxes(title_text="Gate", row=4, col=1, range=[-0.1, 1.1])
    fig.update_xaxes(title_text="Time (s)", row=4, col=1)

    fig.update_layout(
        title=f"StrikePoint: HP-filtered RMS detector on {Path(wav_path).name} (fs={fs} Hz, duration={duration_s:.1f}s)",
        template="plotly_white",
        height=900,
        legend=dict(orientation="h", yanchor="bottom",
                    y=1.02, xanchor="left", x=0.0)
    )

    # Hit times are reported in integer nanoseconds (seconds * 1e9, truncated).
    print(f"Detected hits ({len(hits)}): {[int(a * 1e9) for a in hits]}")
    fig.show()

In [24]:
# Run the detector on the first test recording using the default parameters.
doAnalysis(wav_path="test-01.wav")

Detected hits (5): [14421333333, 23424000000, 33664000000, 44629333333, 56405333333]


In [25]:
# Run the detector on the second test recording using the default parameters.
doAnalysis(wav_path="test-02.wav")

Detected hits (6): [5077333333, 11520000000, 21888000000, 33962666666, 44416000000, 54869333333]
