In [3]:
import os
import numpy as np
import matplotlib.pyplot as plt
from scipy.signal import (
    butter, filtfilt, welch,
    find_peaks, hilbert, medfilt
)

# -- 1. Decode raw BIN → ADC cube --
def decode_iwr6843_data(filename,
                        num_rx=4,
                        num_adc_samples=256,
                        num_chirps=128,
                        header_bytes=0):
    with open(filename,'rb') as f:
        f.seek(header_bytes)
        raw = np.fromfile(f, dtype=np.int16)
    raw = raw.reshape(-1,2)
    complex_data = raw[:,0] + 1j*raw[:,1]
    samp_pf = num_rx * num_chirps * num_adc_samples
    num_frames = len(complex_data)//samp_pf

    adc = np.empty((num_frames, num_rx, num_chirps, num_adc_samples),
                   dtype=complex)
    for fr in range(num_frames):
        for rx in range(num_rx):
            st = fr*samp_pf + rx*(num_chirps*num_adc_samples)
            en = st + (num_chirps*num_adc_samples)
            adc[fr,rx] = complex_data[st:en].reshape(num_chirps, num_adc_samples)
    return adc

# -- 2. Range FFT --
def compute_range_profiles(adc):
    return np.fft.fft(adc, axis=-1)

# -- 3. Range–Angle map --
def compute_range_angle_map(rp, n_angle_bins=256):
    avg_rx = rp.mean(axis=(0,2))
    ang = np.fft.fftshift(
        np.fft.fft(avg_rx, n=n_angle_bins, axis=0),
        axes=0
    )
    return np.abs(ang).T

# -- 4. Physical‐distance‐based target pick --
def detect_targets_physical(ra_map, num_targets=2,
                            min_dist_m=0.7, B=3.9e9,
                            n_angle_bins=256, R_mean=4.0):
    c = 3e8
    rng_res = c/(2*B)
    min_rb = int(np.ceil(min_dist_m/rng_res))
    ang_res = 180.0/n_angle_bins
    ang_sep = np.degrees(np.arcsin(min_dist_m/R_mean))
    min_ab = int(np.ceil(ang_sep/ang_res))

    R,A = ra_map.shape
    idxs = np.argsort(ra_map.flatten())[::-1]
    targets=[]
    for idx in idxs:
        r,a = divmod(idx,A)
        if r<min_rb or r>R-min_rb: continue
        if any(abs(r-r0)<min_rb and abs(a-a0)<min_ab for r0,a0 in targets):
            continue
        targets.append((r,a))
        if len(targets)==num_targets: break
    return targets

# -- 5. MVDR beamforming ±3-bin slab --
def separate_with_mvdr(adc, targets,
                       n_angle_bins=256,
                       fc=60e9,
                       rng_win=3,
                       eps=1e-6):
    frames, num_rx, chirps, samples = adc.shape
    rp = np.fft.fft(adc, axis=-1)
    rng_bins = rp.shape[-1]
    angle_axis = np.linspace(-90,90,n_angle_bins)

    def steer(theta):
        lam = 3e8/fc
        d   = lam/2
        k   = 2*np.pi/lam
        idx = np.arange(num_rx)
        return np.exp(-1j*k*d*idx*np.sin(np.deg2rad(theta)))[:,None]

    beams = np.zeros((frames,len(targets)),dtype=complex)
    for i,(rbin,abin) in enumerate(targets):
        lo,hi = max(0,rbin-rng_win), min(rng_bins,rbin+rng_win+1)
        gated = rp[:,:,:,lo:hi]
        X = gated.transpose(1,0,2,3).reshape(num_rx,-1)
        Rcov = X @ X.conj().T / X.shape[1] + eps*np.eye(num_rx)
        a_vec = steer(angle_axis[abin])
        w = np.linalg.inv(Rcov) @ a_vec
        w /= (a_vec.conj().T @ np.linalg.inv(Rcov) @ a_vec)
        for fr in range(frames):
            Y = gated[fr].reshape(num_rx,-1)
            beams[fr,i] = (w.conj().T @ Y).sum()
    return beams

# -- 6. HR via Welch (unchanged) --
def bandpass_filter(x, low, high, fs, order=4):
    b,a = butter(order, [low/(fs/2), high/(fs/2)], btype='band')
    return filtfilt(b, a, x)

def estimate_rate_welch(x, fs, fmin, fmax):
    f,P = welch(x, fs=fs, nperseg=min(len(x),int(5*fs)))
    mask = (f>=fmin)&(f<=fmax)
    return (f[mask][np.argmax(P[mask])]*60.0) if np.any(mask) else np.nan

def estimate_hr_from_beam(z, fs):
    phase = np.unwrap(np.angle(z))
    dphi  = np.gradient(phase)*fs/(2*np.pi)
    heart = bandpass_filter(dphi,0.8,3.0,fs,order=4)
    return estimate_rate_welch(heart, fs, 0.8, 3.0)

# -- NEW 7. RR via amplitude-envelope peak count --
def estimate_rr_from_beam_amplitude(z, fs,
                                    rr_band=(0.1,0.6),
                                    order=4,
                                    min_rr=5,
                                    max_rr=40,
                                    prom_factor=0.3):
    """
    z    : complex beam signal
    fs   : frame rate (Hz)
    rr_band : respiratory band in Hz
    Returns: (rr_bpm, peak_indices)
    """
    # 1) amplitude envelope
    amp = np.abs(z)
    # 2) bandpass envelope
    b,a = butter(order,
                [rr_band[0]/(fs/2), rr_band[1]/(fs/2)],
                btype='band')
    env = filtfilt(b, a, amp)
    # 3) smoothing
    env_s = medfilt(env, kernel_size=5)
    # 4) peak detection
    min_dist = int(fs*60.0/max_rr)
    prom     = prom_factor * np.std(env_s)
    peaks,_ = find_peaks(env_s,
                         distance=min_dist,
                         prominence=prom)
    # 5) compute BPM
    if len(peaks)<2:
        return np.nan, peaks
    ibis = np.diff(peaks)/fs
    bpm  = 60.0/ibis
    valid = bpm[(bpm>=min_rr)&(bpm<=max_rr)]
    return (float(np.mean(valid)), peaks) if valid.size>0 else (np.nan, peaks)

# -- 8. Main --
if __name__=="__main__":
    FILE = "083_TopFront_2m_45deg_Man2_CondA_CondF_ADC256_Chirp128_SR1000_RE60_FR40_Gain30_FS30_IWR_1.bin"
    FS   = 40.0

    # decode → range–angle → beamforming
    adc     = decode_iwr6843_data(FILE)
    rp      = compute_range_profiles(adc)
    ra_map  = compute_range_angle_map(rp, n_angle_bins=256)
    targets = detect_targets_physical(ra_map, 2,
                                      min_dist_m=0.7,
                                      B=3.9e9,
                                      n_angle_bins=256,
                                      R_mean=4.0)
    beams   = separate_with_mvdr(adc, targets,
                                 n_angle_bins=256,
                                 rng_win=3)

    # estimate HR & RR for each beam
    for i in range(beams.shape[1]):
        z   = beams[:,i]
        hr  = estimate_hr_from_beam(z, FS)
        rr, peaks = estimate_rr_from_beam_amplitude(z, FS,
                        rr_band=(0.1,0.6),
                        order=4,
                        min_rr=5,
                        max_rr=40,
                        prom_factor=0.3)
        print(f"Patient #{i+1}: HR = {hr:.1f} bpm, RR_amp = {rr:.1f} bpm  (peaks={len(peaks)})")

Patient #1: HR = 156.0 bpm, RR_amp = 22.0 bpm  (peaks=17)
Patient #2: HR = 156.0 bpm, RR_amp = 23.0 bpm  (peaks=18)


In [4]:
import os
import numpy as np
from scipy.signal import butter, filtfilt, welch, find_peaks, hilbert, medfilt

# =========================
# 0) CONFIG / CONSTANTS
# =========================
NUM_RX          = 4
NUM_TX          = 3
NUM_ADC_SAMPLES = 256
NUM_CHIRPS      = 128
HEADER_BYTES    = 0

FC = 60e9        # carrier (Hz)
B  = 3.9e9       # chirp bandwidth (Hz) for range resolution calc
FS = 40.0        # frame rate (Hz) for HR/RR time series

ANGLE_GRID = np.linspace(-90, 90, 181)  # azimuth scan grid (deg)
RNG_GATE   = 3                          # ± bins around target for MVDR
EPS        = 1e-6                       # MVDR diagonal loading

# IWR6843AOP antenna phase centers (mm) -> provided by you
RX_MM = np.array([
    [ 0.000, +1.800, 0.0],   # Rx0
    [+1.800,  0.000, 0.0],   # Rx1
    [ 0.000, -1.800, 0.0],   # Rx2
    [-1.800,  0.000, 0.0],   # Rx3
])
TX_MM = np.array([
    [+0.900, +0.900, 0.0],   # Tx0
    [-0.900, +0.900, 0.0],   # Tx1
    [ 0.000, -0.900, 0.0],   # Tx2
])

# =========================
# 1) IO / DECODE
# =========================
def decode_iwr6843_data(filename,
                        num_rx=NUM_RX,
                        num_adc_samples=NUM_ADC_SAMPLES,
                        num_chirps=NUM_CHIRPS,
                        header_bytes=HEADER_BYTES):
    """
    Returns adc of shape (frames, RX, chirps, samples) complex64.
    Assumes interleaved int16 I/Q pairs (I,Q,I,Q,...).
    """
    with open(filename, 'rb') as f:
        f.seek(header_bytes)
        raw = np.fromfile(f, dtype=np.int16)
    if raw.size % 2 != 0:
        raw = raw[:-1]  # safety: drop odd last sample
    raw = raw.reshape(-1, 2)
    complex_data = raw[:, 0].astype(np.float32) + 1j * raw[:, 1].astype(np.float32)

    samp_pf = num_rx * num_chirps * num_adc_samples
    num_frames = complex_data.size // samp_pf
    if num_frames == 0:
        raise ValueError("Not enough samples for even a single frame with given params.")
    leftover = complex_data.size - num_frames * samp_pf
    if leftover != 0:
        # Truncate to full frames
        complex_data = complex_data[:num_frames * samp_pf]

    adc = np.empty((num_frames, num_rx, num_chirps, num_adc_samples), dtype=np.complex64)
    for fr in range(num_frames):
        base = fr * samp_pf
        for rx in range(num_rx):
            st = base + rx * (num_chirps * num_adc_samples)
            en = st + (num_chirps * num_adc_samples)
            adc[fr, rx] = complex_data[st:en].reshape(num_chirps, num_adc_samples)
    return adc

# =========================
# 2) RANGE FFT (with Hann window)
# =========================
def compute_range_profiles(adc):
    """
    adc: (frames, RX, chirps, samples)
    returns rp: (frames, RX, chirps, range_bins) complex64
    """
    window = np.hanning(adc.shape[-1]).astype(np.float32)
    rp = np.fft.fft(adc * window, axis=-1)
    return rp

# =========================
# 3) FORM 12 VIRTUAL CHANNELS (TDM-MIMO)
# =========================
def form_virtual_channels_from_rp(
    rp,
    num_tx=3,
    mode="interleaved",          # or "grouped"
    drop_remainder=True,
    tx_order=None,               # e.g., (0,1,2) or (2,0,1) if known
    group_lengths=None           # for mode="grouped": list/tuple len=num_tx with chirps per Tx
):
    """
    Build virtual channels (TX*RX) from range profiles.
    rp: (frames, RX, chirps, R)
    Returns: (frames, V=TX*RX, chirps_per_tx, R)

    Modes:
      - "interleaved": chirps alternate by Tx index: Tx0,Tx1,Tx2,Tx0,Tx1,Tx2, ...
                        If chirps % num_tx != 0, we optionally drop the final remainder.
      - "grouped":     chirps are in contiguous blocks per Tx, e.g. [Tx0]*L0 + [Tx1]*L1 + [Tx2]*L2 (+ repeats?)
                        You MUST provide `group_lengths` (one pass) or a single pass split for the frame.
    """
    frames, rx, chirps, R = rp.shape
    if tx_order is None:
        tx_order = tuple(range(num_tx))  # default (0,1,2)

    if mode == "interleaved":
        if drop_remainder:
            usable = (chirps // num_tx) * num_tx
            if usable == 0:
                raise ValueError("Not enough chirps to form even a single TDM set.")
            if usable != chirps:
                # trim tail to make divisible
                rp = rp[:, :, :usable, :]
                chirps = usable
        else:
            if chirps % num_tx != 0:
                raise AssertionError("num_chirps must be divisible by num_tx (set drop_remainder=True to trim).")

        cpt = chirps // num_tx
        # stack Tx blocks along channel axis -> V = TX*RX
        virt_blocks = []
        for t in tx_order:
            rp_t = rp[:, :, t:chirps:num_tx, :]      # (F, RX, cpt, R)
            virt_blocks.append(rp_t)
        # concatenate along channel axis: (F, TX*RX, cpt, R)
        rp_virt = np.concatenate(virt_blocks, axis=1)
        return rp_virt

    elif mode == "grouped":
        if group_lengths is None or len(group_lengths) != num_tx:
            raise ValueError("For mode='grouped', provide group_lengths with length=num_tx.")
        total = sum(group_lengths)
        if chirps < total:
            raise ValueError(f"Frame has {chirps} chirps but group_lengths sum to {total}.")
        # Optionally trim to exactly one TDM cycle (first total chirps)
        rp1 = rp[:, :, :total, :]
        # Slice contiguous blocks in the specified tx_order
        start = 0
        virt_blocks = []
        for t in tx_order:
            L = group_lengths[t]
            rp_t = rp1[:, :, start:start+L, :]       # (F, RX, L, R)
            start += L
            virt_blocks.append(rp_t)
        # To align per-Tx chirp counts, we need same L for all Tx for a simple 3D stack.
        # If lengths differ, take the minimum across Tx blocks.
        minL = min(b.shape[2] for b in virt_blocks)
        virt_blocks = [b[:, :, :minL, :] for b in virt_blocks]
        # Concatenate along channel axis: (F, TX*RX, minL, R)
        rp_virt = np.concatenate(virt_blocks, axis=1)
        return rp_virt

    else:
        raise ValueError("mode must be 'interleaved' or 'grouped'")


# =========================
# 4) GEOMETRY: VIRTUAL ARRAY COORDS & STEERING
# =========================
def virtual_array_positions_wavelengths(fc=FC, rx_mm=RX_MM, tx_mm=TX_MM):
    """
    Returns virtual element coordinates in WAVELENGTHS, shape (V,3).
    Virtual = Tx + Rx (vector sum), in far-field plane-wave model.
    """
    lam = 3e8 / fc
    rx_m = rx_mm * 1e-3
    tx_m = tx_mm * 1e-3
    virt = []
    for t in tx_m:
        for r in rx_m:
            virt.append((t + r) / lam)
    return np.array(virt, dtype=np.float64)

def steering_vector_azimuth(theta_deg, virt_xyz_lam):
    """
    Azimuth-only steering in array plane (z≈0).
    θ=0° along +x, +90° along +y.
    Returns (V,1) complex vector.
    """
    th = np.deg2rad(theta_deg)
    kx, ky = np.cos(th), np.sin(th)
    phase = -1j * 2 * np.pi * (virt_xyz_lam[:, 0] * kx + virt_xyz_lam[:, 1] * ky)
    return np.exp(phase)[:, None]

# =========================
# 5) RANGE×AZIMUTH MAP (Bartlett/Capon with real geometry)
# =========================
def range_azimuth_map(rp_virt, virt_xyz_lam, angle_grid=ANGLE_GRID, method="capon"):
    """
    rp_virt: (frames, V, cpt, R)
    returns RA (R, A), normalized per range bin.
    """
    frames, V, cpt, R = rp_virt.shape
    # snapshots across frames & chirps-per-tx
    X = rp_virt.transpose(3, 1, 0, 2).reshape(R, V, -1)  # (R, V, snapshots)
    RA = np.zeros((R, len(angle_grid)), dtype=np.float64)

    for rbin in range(R):
        S = X[rbin]  # (V, snapshots)
        if S.shape[1] < V:
            # not enough snapshots; fall back to Bartlett power
            Rcov = (S @ S.conj().T) / max(1, S.shape[1]) + 1e-6 * np.eye(V)
            Rinv = np.linalg.pinv(Rcov)
        else:
            Rcov = (S @ S.conj().T) / S.shape[1] + 1e-6 * np.eye(V)
            Rinv = np.linalg.pinv(Rcov)

        for ai, ang in enumerate(angle_grid):
            a = steering_vector_azimuth(ang, virt_xyz_lam)  # (V,1)
            if method.lower() == "bartlett":
                num = np.real((a.conj().T @ Rcov @ a)[0,0])
                den = np.real((a.conj().T @ a)[0,0])
                RA[rbin, ai] = num / (den + 1e-12)
            else:  # Capon / MVDR spectrum
                RA[rbin, ai] = 1.0 / np.real((a.conj().T @ Rinv @ a)[0,0] + 1e-12)

    # normalize each range row for display/selection
    RA /= (RA.max(axis=1, keepdims=True) + 1e-12)
    return RA

# =========================
# 6) TARGET PICKING WITH PHYSICAL SPACING
# =========================
def detect_targets_physical(RA, num_targets=2, min_dist_m=0.7, B=B, angle_grid=ANGLE_GRID, R_mean=4.0):
    """
    RA: (R, A), A = len(angle_grid).
    Returns list of (range_bin, angle_index).
    """
    c = 3e8
    rng_res = c / (2 * B)
    min_rb = int(np.ceil(min_dist_m / rng_res))  # min separation in range bins

    if len(angle_grid) > 1:
        ang_res = float(np.abs(angle_grid[1] - angle_grid[0]))
    else:
        ang_res = 1.0
    ang_sep = np.degrees(np.arcsin(np.clip(min_dist_m / max(R_mean, 1e-6), -1.0, 1.0)))
    min_ab = int(np.ceil(ang_sep / max(ang_res, 1e-6)))

    R, A = RA.shape
    idxs = np.argsort(RA.ravel())[::-1]  # strongest first
    targets = []
    for idx in idxs:
        r, a = divmod(idx, A)
        if r < min_rb or r > R - min_rb:
            continue
        if any(abs(r - r0) < min_rb and abs(a - a0) < min_ab for r0, a0 in targets):
            continue
        targets.append((r, a))
        if len(targets) == num_targets:
            break
    return targets

# =========================
# 7) GEOMETRY-AWARE MVDR BEAMFORMING
# =========================
def mvdr_beamform_tracks(rp_virt, targets, virt_xyz_lam, angle_grid=ANGLE_GRID, rng_gate=RNG_GATE, eps=EPS):
    """
    rp_virt: (frames, V, cpt, R)
    targets: list of (rbin, abin) indices
    returns beams: (frames, num_targets) complex time series
    """
    frames, V, cpt, R = rp_virt.shape
    beams = np.zeros((frames, len(targets)), dtype=np.complex64)

    for ti, (rbin, abin) in enumerate(targets):
        # range gate
        lo = max(0, rbin - rng_gate)
        hi = min(R, rbin + rng_gate + 1)
        gated = rp_virt[:, :, :, lo:hi]  # (frames, V, cpt, gate)

        # Build snapshots across frames & chirps for covariance
        X = gated.transpose(1, 0, 2, 3).reshape(V, -1)  # (V, snapshots)
        Rcov = (X @ X.conj().T) / max(1, X.shape[1]) + eps * np.eye(V)
        Rinv = np.linalg.pinv(Rcov)

        theta = angle_grid[abin]
        a = steering_vector_azimuth(theta, virt_xyz_lam)  # (V,1)
        w = (Rinv @ a) / (a.conj().T @ Rinv @ a)          # (V,1) MVDR weights

        # Apply per frame to get temporal signal (sum over chirps & gate)
        for fr in range(frames):
            Y = gated[fr].reshape(V, -1)                  # (V, cpt*gate)
            beams[fr, ti] = (w.conj().T @ Y).sum()
    return beams

# =========================
# 8) HR / RR ESTIMATION
# =========================
def bandpass_filter(x, low, high, fs, order=4):
    b, a = butter(order, [low/(fs/2), high/(fs/2)], btype='band')
    return filtfilt(b, a, x)

def estimate_hr_from_beam_phase(z, fs=FS):
    """
    HR from instantaneous frequency of the beam phase.
    """
    phase = np.unwrap(np.angle(z))
    dphi  = np.gradient(phase) * fs / (2*np.pi)  # Hz
    heart = bandpass_filter(dphi, 1.0, 2.5, fs, order=4)  # ~60–150 bpm
    f, P  = welch(heart, fs=fs, nperseg=min(len(heart), int(8*fs)))
    mask  = (f >= 1.0) & (f <= 3.0)
    return (f[mask][np.argmax(P[mask])] * 60.0) if np.any(mask) else np.nan

def estimate_rr_from_beam_amp(z, fs=FS, rr_band=(0.1, 0.6), min_rr=5, max_rr=40, prom_factor=0.3):
    """
    RR from amplitude envelope peak counting in the respiration band.
    """
    amp   = np.abs(z)
    env   = bandpass_filter(amp, rr_band[0], rr_band[1], fs, order=4)
    env_s = medfilt(env, kernel_size=5)
    min_dist = int(fs * 60.0 / max_rr)
    prom     = prom_factor * np.std(env_s)
    peaks, _ = find_peaks(env_s, distance=min_dist, prominence=prom)
    if len(peaks) < 2:
        return np.nan, peaks
    ibis = np.diff(peaks) / fs
    bpm  = 60.0 / ibis
    valid = bpm[(bpm >= min_rr) & (bpm <= max_rr)]
    return (float(np.mean(valid)) if valid.size else np.nan, peaks)

# =========================
# 9) DRIVER / MAIN
# =========================
def run_pipeline(FILE, num_targets=2, method_ra="capon"):
    # 1) Decode
    adc = decode_iwr6843_data(FILE, NUM_RX, NUM_ADC_SAMPLES, NUM_CHIRPS, HEADER_BYTES)

    # 2) Range FFT
    rp  = compute_range_profiles(adc)  # (F, RX, C, R)

    # 3) Form 12 virtual channels
    rp_virt = form_virtual_channels_from_rp(rp, NUM_TX)  # (F, 12, C/3, R)

    # 4) Geometry-aware RA map
    virt_xyz_lam = virtual_array_positions_wavelengths(FC, RX_MM, TX_MM)
    RA = range_azimuth_map(rp_virt, virt_xyz_lam, ANGLE_GRID, method=method_ra)

    # 5) Target detection
    targets = detect_targets_physical(RA, num_targets=num_targets, min_dist_m=0.7, B=B, angle_grid=ANGLE_GRID, R_mean=4.0)

    # 6) MVDR beams per target
    beams = mvdr_beamform_tracks(rp_virt, targets, virt_xyz_lam, ANGLE_GRID, RNG_GATE, EPS)

    # 7) Vital signs
    est = []
    for i in range(beams.shape[1]):
        z = beams[:, i]
        hr = estimate_hr_from_beam_phase(z, FS)
        rr, peaks = estimate_rr_from_beam_amp(z, FS)
        est.append((i, hr, rr, len(peaks)))
    return RA, targets, beams, est

if __name__ == "__main__":
    FILE = "plots/SubsetD/083_TopFront_2m_45deg_Man2_CondA_CondF_ADC256_Chirp128_SR1000_RE60_FR40_Gain30_FS30_IWR_0.bin"
    RA, targets, beams, est = run_pipeline(FILE, num_targets=2, method_ra="capon")

    print("Targets (range_bin, angle_idx):", targets)
    for i, hr, rr, npeak in est:
        print(f"Patient #{i+1}: HR={hr:.1f} bpm, RR={rr:.1f} bpm (peaks={npeak})")


Targets (range_bin, angle_idx): [(26, 144), (118, 137)]
Patient #1: HR=97.5 bpm, RR=23.1 bpm (peaks=11)
Patient #2: HR=105.0 bpm, RR=9.6 bpm (peaks=5)


In [5]:
import os
import numpy as np
import matplotlib.pyplot as plt
from scipy.signal import butter, filtfilt, welch, find_peaks, hilbert, medfilt

# =========================
# CONFIG
# =========================
NUM_RX          = 4
NUM_TX          = 3
NUM_ADC_SAMPLES = 256
NUM_CHIRPS      = 128
HEADER_BYTES    = 0

FC = 60e9
B  = 3.9e9
FS = 40.0

ANGLE_GRID = np.linspace(-90, 90, 181)
RNG_GATE   = 3
EPS        = 1e-6

RX_MM = np.array([
    [ 0.000, +1.800, 0.0],
    [+1.800,  0.000, 0.0],
    [ 0.000, -1.800, 0.0],
    [-1.800,  0.000, 0.0],
])
TX_MM = np.array([
    [+0.900, +0.900, 0.0],
    [-0.900, +0.900, 0.0],
    [ 0.000, -0.900, 0.0],
])

PLOT_DIR = "plots_out"
os.makedirs(PLOT_DIR, exist_ok=True)

# =========================
# IO / DECODE
# =========================
def decode_iwr6843_data(filename,
                        num_rx=NUM_RX,
                        num_adc_samples=NUM_ADC_SAMPLES,
                        num_chirps=NUM_CHIRPS,
                        header_bytes=HEADER_BYTES):
    with open(filename,'rb') as f:
        f.seek(header_bytes)
        raw = np.fromfile(f, dtype=np.int16)
    if raw.size % 2 != 0:
        raw = raw[:-1]
    raw = raw.reshape(-1,2)
    complex_data = raw[:,0].astype(np.float32) + 1j*raw[:,1].astype(np.float32)

    samp_pf   = num_rx*num_chirps*num_adc_samples
    num_frames= complex_data.size // samp_pf
    complex_data = complex_data[:num_frames*samp_pf]
    adc = np.empty((num_frames, num_rx, num_chirps, num_adc_samples), dtype=np.complex64)
    for fr in range(num_frames):
        base = fr*samp_pf
        for rx in range(num_rx):
            st = base + rx*(num_chirps*num_adc_samples)
            en = st +   (num_chirps*num_adc_samples)
            adc[fr,rx] = complex_data[st:en].reshape(num_chirps, num_adc_samples)
    return adc

# =========================
# RANGE FFT
# =========================
def compute_range_profiles(adc):
    window = np.hanning(adc.shape[-1]).astype(np.float32)
    rp = np.fft.fft(adc * window, axis=-1)
    return rp

# =========================
# MIMO → 12 VIRTUAL CH
# =========================
def form_virtual_channels_from_rp(
    rp,
    num_tx=NUM_TX,
    mode="interleaved",
    drop_remainder=True,
    tx_order=(0,1,2),
    group_lengths=None
):
    frames, rx, chirps, R = rp.shape

    if mode == "interleaved":
        if drop_remainder:
            usable = (chirps // num_tx)*num_tx
            rp = rp[:, :, :usable, :]
            chirps = usable
        else:
            assert chirps % num_tx == 0, "num_chirps must be divisible by num_tx."

        cpt = chirps // num_tx
        virt_blocks = []
        for t in tx_order:
            rp_t = rp[:, :, t:chirps:num_tx, :]
            virt_blocks.append(rp_t)
        rp_virt = np.concatenate(virt_blocks, axis=1)  # (F, TX*RX, cpt, R)
        return rp_virt

    elif mode == "grouped":
        if group_lengths is None or len(group_lengths) != num_tx:
            raise ValueError("group_lengths must be provided for mode='grouped'")
        total = sum(group_lengths)
        if chirps < total:
            raise ValueError(f"chirps={chirps} < sum(group_lengths)={total}")
        rp1 = rp[:, :, :total, :]
        start = 0
        virt_blocks = []
        for t in tx_order:
            L = group_lengths[t]
            rp_t = rp1[:, :, start:start+L, :]
            start += L
            virt_blocks.append(rp_t)
        minL = min(b.shape[2] for b in virt_blocks)
        virt_blocks = [b[:, :, :minL, :] for b in virt_blocks]
        rp_virt = np.concatenate(virt_blocks, axis=1)
        return rp_virt

    else:
        raise ValueError("mode must be 'interleaved' or 'grouped'")

# =========================
# GEOMETRY
# =========================
def virtual_array_positions_wavelengths(fc=FC, rx_mm=RX_MM, tx_mm=TX_MM):
    lam = 3e8/fc
    rx_m = rx_mm*1e-3
    tx_m = tx_mm*1e-3
    virt=[]
    for t in tx_m:
        for r in rx_m:
            virt.append((t+r)/lam)
    return np.array(virt, dtype=np.float64)  # (12,3)

def steering_vector_azimuth(theta_deg, virt_xyz_lam):
    th = np.deg2rad(theta_deg)
    kx, ky = np.cos(th), np.sin(th)
    phase = -1j*2*np.pi*(virt_xyz_lam[:,0]*kx + virt_xyz_lam[:,1]*ky)
    return np.exp(phase)[:,None]

# =========================
# RA MAP (Bartlett/Capon)
# =========================
def range_azimuth_map(rp_virt, virt_xyz_lam, angle_grid=ANGLE_GRID, method="capon"):
    frames, V, cpt, R = rp_virt.shape
    X = rp_virt.transpose(3,1,0,2).reshape(R, V, -1)  # (R,V,snaps)
    RA = np.zeros((R, len(angle_grid)), dtype=np.float64)

    for rb in range(R):
        S = X[rb]
        Rcov = (S @ S.conj().T) / max(1, S.shape[1]) + 1e-6*np.eye(V)
        Rinv = np.linalg.pinv(Rcov)
        for ai, ang in enumerate(angle_grid):
            a = steering_vector_azimuth(ang, virt_xyz_lam)
            if method.lower()=="bartlett":
                num = np.real((a.conj().T @ Rcov @ a)[0,0])
                den = np.real((a.conj().T @ a)[0,0])
                RA[rb, ai] = num/(den+1e-12)
            else:
                RA[rb, ai] = 1.0/np.real((a.conj().T @ Rinv @ a)[0,0] + 1e-12)
    RA /= (RA.max(axis=1, keepdims=True) + 1e-12)
    return RA

# =========================
# TARGET PICK
# =========================
def detect_targets_physical(RA, num_targets=2, min_dist_m=0.7, B=B, angle_grid=ANGLE_GRID, R_mean=4.0):
    c = 3e8
    rng_res = c/(2*B)
    min_rb = int(np.ceil(min_dist_m/rng_res))
    ang_res = float(abs(angle_grid[1]-angle_grid[0])) if len(angle_grid)>1 else 1.0
    ang_sep = np.degrees(np.arcsin(np.clip(min_dist_m/max(R_mean,1e-6), -1, 1)))
    min_ab = int(np.ceil(ang_sep/max(ang_res,1e-6)))

    R, A = RA.shape
    idxs = np.argsort(RA.ravel())[::-1]
    targets=[]
    for idx in idxs:
        r,a = divmod(idx,A)
        if r<min_rb or r>R-min_rb: 
            continue
        if any(abs(r-r0)<min_rb and abs(a-a0)<min_ab for r0,a0 in targets):
            continue
        targets.append((r,a))
        if len(targets)==num_targets: 
            break
    return targets

# =========================
# MVDR BEAMS
# =========================
def mvdr_beamform_tracks(rp_virt, targets, virt_xyz_lam, angle_grid=ANGLE_GRID, rng_gate=RNG_GATE, eps=EPS):
    frames, V, cpt, R = rp_virt.shape
    beams = np.zeros((frames, len(targets)), dtype=np.complex64)
    for ti, (rbin, abin) in enumerate(targets):
        lo = max(0, rbin-rng_gate)
        hi = min(R, rbin+rng_gate+1)
        gated = rp_virt[:, :, :, lo:hi]  # (F,V,cpt,gate)
        X = gated.transpose(1,0,2,3).reshape(V,-1)
        Rcov = (X @ X.conj().T)/max(1,X.shape[1]) + eps*np.eye(V)
        Rinv = np.linalg.pinv(Rcov)
        theta = angle_grid[abin]
        a = steering_vector_azimuth(theta, virt_xyz_lam)
        w = (Rinv @ a)/(a.conj().T @ Rinv @ a)
        for fr in range(frames):
            Y = gated[fr].reshape(V,-1)
            beams[fr, ti] = (w.conj().T @ Y).sum()
    return beams

# =========================
# HR / RR
# =========================
def bandpass_filter(x, low, high, fs, order=4):
    b, a = butter(order, [low/(fs/2), high/(fs/2)], btype='band')
    return filtfilt(b, a, x)

def estimate_hr_from_beam_phase(z, fs=FS):
    phase = np.unwrap(np.angle(z))
    dphi  = np.gradient(phase)*fs/(2*np.pi)
    heart = bandpass_filter(dphi, 1.0, 2.5, fs, order=4)
    f, P  = welch(heart, fs=fs, nperseg=min(len(heart), int(8*fs)))
    mask  = (f>=1.0) & (f<=3.0)
    return (f[mask][np.argmax(P[mask])] * 60.0) if np.any(mask) else np.nan, dphi, heart, (f, P)

def estimate_rr_from_beam_amp(z, fs=FS, rr_band=(0.1,0.6), min_rr=5, max_rr=40, prom_factor=0.3):
    amp   = np.abs(z)
    env   = bandpass_filter(amp, rr_band[0], rr_band[1], fs, order=4)
    env_s = medfilt(env, kernel_size=5)
    min_dist = int(fs * 60.0 / max_rr)
    prom     = prom_factor * np.std(env_s)
    peaks, _ = find_peaks(env_s, distance=min_dist, prominence=prom)
    if len(peaks) < 2:
        return np.nan, amp, env, env_s, peaks
    ibis = np.diff(peaks) / fs
    bpm  = 60.0 / ibis
    valid = bpm[(bpm >= min_rr) & (bpm <= max_rr)]
    rr = float(np.mean(valid)) if valid.size else np.nan
    return rr, amp, env, env_s, peaks

# =========================
# PLOTTING HELPERS (each makes ONE figure)
# =========================
def savefig(name):
    path = os.path.join(PLOT_DIR, name)
    plt.tight_layout()
    plt.savefig(path, dpi=150)
    plt.close()
    print(f"saved: {path}")

def plot_adc_example(adc, frame=0, rx=0, chirp=0):
    x = adc[frame, rx, chirp].astype(np.complex64)
    plt.figure()
    plt.title(f"ADC | frame={frame}, rx={rx}, chirp={chirp}")
    plt.plot(np.abs(x))
    plt.xlabel("sample")
    plt.ylabel("|I+jQ|")
    savefig("01_adc_example.png")

def plot_range_profiles_per_rx(rp, frame=0):
    mag = np.abs(rp[frame])  # (RX, C, R)
    prof = mag.mean(axis=1)  # avg over chirps -> (RX, R)
    plt.figure()
    plt.title("Range profiles (avg over chirps) per Rx")
    for rx in range(mag.shape[0]):
        plt.plot(prof[rx], label=f"Rx{rx}")
    plt.xlabel("range bin")
    plt.ylabel("magnitude")
    plt.legend()
    savefig("02_range_profiles_per_rx.png")

def plot_ra_heatmap(RA):
    plt.figure()
    plt.title("Range × Azimuth (normalized)")
    plt.imshow(RA, aspect='auto', origin='lower',
               extent=[ANGLE_GRID[0], ANGLE_GRID[-1], 0, RA.shape[0]])
    plt.xlabel("azimuth (deg)")
    plt.ylabel("range bin")
    savefig("03_ra_heatmap.png")

def plot_angle_cut(RA):
    rbin = np.argmax(RA.max(axis=1))
    plt.figure()
    plt.title(f"Angle cut at strongest range bin r={rbin}")
    plt.plot(ANGLE_GRID, RA[rbin])
    plt.xlabel("azimuth (deg)")
    plt.ylabel("normalized power")
    savefig("04_angle_cut.png")

def plot_ra_with_targets(RA, targets):
    plt.figure()
    plt.title("RA with targets")
    plt.imshow(RA, aspect='auto', origin='lower',
               extent=[ANGLE_GRID[0], ANGLE_GRID[-1], 0, RA.shape[0]])
    if targets:
        azs = [ANGLE_GRID[a] for (_,a) in targets]
        rng = [r for (r,_) in targets]
        plt.scatter(azs, rng, marker='x')
    plt.xlabel("azimuth (deg)")
    plt.ylabel("range bin")
    savefig("05_ra_with_targets.png")

def plot_beam_time_series(beams):
    plt.figure()
    plt.title("Beam time series (magnitude)")
    for i in range(beams.shape[1]):
        plt.plot(np.abs(beams[:, i]), label=f"beam {i+1}")
    plt.xlabel("frame")
    plt.ylabel("|z|")
    plt.legend()
    savefig("06_beam_time_series.png")

def plot_inst_freq(dphi):
    plt.figure()
    plt.title("Instantaneous frequency from phase (Hz)")
    plt.plot(dphi)
    plt.xlabel("frame")
    plt.ylabel("Hz")
    savefig("07_inst_freq.png")

def plot_welch_hr(f, P):
    plt.figure()
    plt.title("Welch PSD (HR band)")
    plt.semilogy(f, P)
    plt.xlabel("Hz")
    plt.ylabel("PSD")
    savefig("08_welch_hr.png")

def plot_rr_envelope(env, env_s, peaks):
    plt.figure()
    plt.title("Respiratory envelope + peaks")
    plt.plot(env, alpha=0.6, label="env")
    plt.plot(env_s, label="env (medfilt)")
    if len(peaks) > 0:
        plt.plot(peaks, env_s[peaks], "o", label="peaks")
    plt.xlabel("frame")
    plt.ylabel("amplitude")
    plt.legend()
    savefig("09_rr_envelope.png")

# =========================
# DRIVER
# =========================
def run_pipeline(FILE, num_targets=2, method_ra="capon", PLOT=True):
    # 1) Decode
    adc = decode_iwr6843_data(FILE, NUM_RX, NUM_ADC_SAMPLES, NUM_CHIRPS, HEADER_BYTES)
    print("adc shape:", adc.shape)  # (F,RX,C,S)
    if PLOT:
        plot_adc_example(adc, frame=0, rx=0, chirp=0)

    # 2) Range FFT
    rp  = compute_range_profiles(adc)
    print("rp shape:", rp.shape)  # (F,RX,C,R)
    if PLOT:
        plot_range_profiles_per_rx(rp, frame=0)

    # 3) Virtual channels (robust)
    rp_virt = form_virtual_channels_from_rp(
        rp,
        num_tx=NUM_TX,
        mode="interleaved",   # change if grouped
        drop_remainder=True,
        tx_order=(0,1,2)
    )
    print("rp_virt shape:", rp_virt.shape)  # (F,12,C/3,R)

    # 4) Geometry-aware RA map
    virt_xyz_lam = virtual_array_positions_wavelengths(FC, RX_MM, TX_MM)
    RA = range_azimuth_map(rp_virt, virt_xyz_lam, ANGLE_GRID, method=method_ra)
    print("RA shape:", RA.shape)  # (R,A)
    if PLOT:
        plot_ra_heatmap(RA)
        plot_angle_cut(RA)

    # 5) Target detection
    targets = detect_targets_physical(RA, num_targets=num_targets, min_dist_m=0.7, B=B, angle_grid=ANGLE_GRID, R_mean=4.0)
    print("targets:", targets)
    if PLOT:
        plot_ra_with_targets(RA, targets)

    # 6) MVDR beams
    beams = mvdr_beamform_tracks(rp_virt, targets, virt_xyz_lam, ANGLE_GRID, RNG_GATE, EPS)
    print("beams shape:", beams.shape)  # (F, num_targets)
    if PLOT:
        plot_beam_time_series(beams)

    # 7) Vital signs (show full details for beam 1)
    est = []
    for i in range(beams.shape[1]):
        z = beams[:, i]
        hr, dphi, heart_bp, wel = estimate_hr_from_beam_phase(z, FS)
        rr, amp, env, env_s, peaks = estimate_rr_from_beam_amp(z, FS)
        est.append((i, hr, rr, len(peaks)))

        if i == 0 and PLOT:
            plot_inst_freq(dphi)
            f, P = wel
            plot_welch_hr(f, P)
            plot_rr_envelope(env, env_s, peaks)

    return RA, targets, beams, est

# =========================
# RUN
# =========================
if __name__ == "__main__":
    FILE = "plots/SubsetD/083_TopFront_2m_45deg_Man2_CondA_CondF_ADC256_Chirp128_SR1000_RE60_FR40_Gain30_FS30_IWR_0.bin"
    RA, targets, beams, est = run_pipeline(FILE, num_targets=2, method_ra="capon", PLOT=True)
    for i, hr, rr, npeak in est:
        print(f"Patient #{i+1}: HR={hr:.1f} bpm, RR={rr:.1f} bpm (peaks={npeak})")


adc shape: (1408, 4, 128, 256)
saved: plots_out/01_adc_example.png
rp shape: (1408, 4, 128, 256)
saved: plots_out/02_range_profiles_per_rx.png
rp_virt shape: (1408, 12, 42, 256)
RA shape: (256, 181)
saved: plots_out/03_ra_heatmap.png
saved: plots_out/04_angle_cut.png
targets: [(26, 144), (118, 137)]
saved: plots_out/05_ra_with_targets.png
beams shape: (1408, 2)
saved: plots_out/06_beam_time_series.png
saved: plots_out/07_inst_freq.png
saved: plots_out/08_welch_hr.png
saved: plots_out/09_rr_envelope.png
Patient #1: HR=97.5 bpm, RR=23.1 bpm (peaks=11)
Patient #2: HR=105.0 bpm, RR=9.6 bpm (peaks=5)


In [10]:
import os
import numpy as np
import matplotlib.pyplot as plt
from scipy.signal import butter, filtfilt, welch, find_peaks, hilbert, medfilt

# =========================
# CONFIG / CONSTANTS
# =========================
NUM_RX          = 4
NUM_TX          = 3
NUM_ADC_SAMPLES = 256
NUM_CHIRPS      = 128
HEADER_BYTES    = 0

FC = 60e9        # carrier
B  = 3.9e9       # bandwidth for range resolution
FS = 40.0        # frame rate (Hz) for vital signals

ANGLE_GRID = np.linspace(-90, 90, 361)  # dense azimuth grid
RNG_GATE   = 3
EPS        = 1e-6

# AOP phase centers you provided (mm)
RX_MM = np.array([
    [ 0.000, +1.800, 0.0],   # Rx0
    [+1.800,  0.000, 0.0],   # Rx1
    [ 0.000, -1.800, 0.0],   # Rx2
    [-1.800,  0.000, 0.0],   # Rx3
])
TX_MM = np.array([
    [+0.900, +0.900, 0.0],   # Tx0
    [-0.900, +0.900, 0.0],   # Tx1
    [ 0.000, -0.900, 0.0],   # Tx2
])

PLOT_DIR = "plots_out"
os.makedirs(PLOT_DIR, exist_ok=True)

# =========================
# UTILS (plots/anim)
# =========================
def savefig(name):
    path = os.path.join(PLOT_DIR, name)
    plt.tight_layout()
    plt.savefig(path, dpi=150)
    plt.close()
    print(f"[plot] {path}")

def plot_adc_example(adc, frame=0, rx=0, chirp=0, name="01_adc_example.png"):
    x = adc[frame, rx, chirp]
    plt.figure()
    plt.title(f"ADC magnitude | frame={frame}, rx={rx}, chirp={chirp}")
    plt.plot(np.abs(x))
    plt.xlabel("sample"); plt.ylabel("|I+jQ|")
    savefig(name)

def plot_range_profiles_per_rx(rp, frame=0, name="02_range_profiles_per_rx.png"):
    mag = np.abs(rp[frame])  # (RX, C, R)
    prof = mag.mean(axis=1)  # (RX, R)
    plt.figure()
    plt.title("Range profiles (avg over chirps) per Rx")
    for rx in range(mag.shape[0]): plt.plot(prof[rx], label=f"Rx{rx}")
    plt.xlabel("range bin"); plt.ylabel("magnitude"); plt.legend()
    savefig(name)

def plot_ra_heatmap(RA, name="03_ra_heatmap.png"):
    plt.figure()
    plt.title("Range × Azimuth (normalized)")
    plt.imshow(RA, aspect='auto', origin='lower',
               extent=[ANGLE_GRID[0], ANGLE_GRID[-1], 0, RA.shape[0]])
    plt.xlabel("azimuth (deg)"); plt.ylabel("range bin")
    savefig(name)

def plot_angle_cut(RA, name="04_angle_cut.png"):
    rbin = int(np.argmax(RA.max(axis=1)))
    plt.figure()
    plt.title(f"Angle cut at strongest range bin r={rbin}")
    plt.plot(ANGLE_GRID, RA[rbin])
    plt.xlabel("azimuth (deg)"); plt.ylabel("normalized power")
    savefig(name)

def plot_ra_with_targets(RA, targets, name="05_ra_with_targets.png"):
    plt.figure()
    plt.title("RA with targets")
    plt.imshow(RA, aspect='auto', origin='lower',
               extent=[ANGLE_GRID[0], ANGLE_GRID[-1], 0, RA.shape[0]])
    if targets:
        azs = [ANGLE_GRID[a] if a is not None else 0 for (r,a) in targets]
        rng = [r for (r,_) in targets]
        plt.scatter(azs, rng, marker='x', color='k')
    plt.xlabel("azimuth (deg)"); plt.ylabel("range bin")
    savefig(name)

def plot_beam_time_series(beams, name="06_beam_time_series.png"):
    plt.figure()
    plt.title("Beam time series (magnitude)")
    for i in range(beams.shape[1]):
        plt.plot(np.abs(beams[:, i]), label=f"beam {i+1}")
    plt.xlabel("frame"); plt.ylabel("|z|"); plt.legend()
    savefig(name)

def plot_inst_freq(dphi, name="07_inst_freq.png"):
    plt.figure(); plt.title("Instantaneous frequency (Hz)")
    plt.plot(dphi); plt.xlabel("frame"); plt.ylabel("Hz")
    savefig(name)

def plot_welch_hr(f, P, name="08_welch_hr.png"):
    plt.figure(); plt.title("Welch PSD (HR band)")
    plt.semilogy(f, P); plt.xlabel("Hz"); plt.ylabel("PSD")
    savefig(name)

def plot_rr_envelope(env, env_s, peaks, name="09_rr_envelope.png"):
    plt.figure(); plt.title("Respiratory envelope + peaks")
    plt.plot(env, alpha=0.6, label="env")
    plt.plot(env_s, label="env (medfilt)")
    if len(peaks) > 0: plt.plot(peaks, env_s[peaks], "o", label="peaks")
    plt.xlabel("frame"); plt.ylabel("amplitude"); plt.legend()
    savefig(name)

# Optional raw-ADC animations (GIF, Pillow)
def animate_adc_magnitude(adc, out_gif="adc_mag.gif", rx=0, chirp=0, fps=20, decim=1):
    try:
        from matplotlib.animation import FuncAnimation, PillowWriter
    except Exception as e:
        print("[anim] Pillow animation not available:", e); return
    F, RX, C, S = adc.shape
    frames_idx = np.arange(F)
    x = np.arange(0, S, decim)
    y0 = np.abs(adc[0, rx, chirp, ::decim])

    fig, ax = plt.subplots()
    ln, = ax.plot(x, y0)
    ax.set_title(f"ADC | frame 0 | Rx{rx} Chirp{chirp}")
    ax.set_xlabel("sample"); ax.set_ylabel("|I+jQ|")
    ax.set_ylim(0, max(1e-9, y0.max()*1.1)); ax.set_xlim(0, x[-1] if len(x)>0 else S-1)

    def update(fi):
        y = np.abs(adc[fi, rx, chirp, ::decim]); ln.set_ydata(y)
        ax.set_title(f"ADC | frame {fi} | Rx{rx} Chirp{chirp}")
        return (ln,)

    anim = plt.matplotlib.animation.FuncAnimation(fig, update, frames=frames_idx, blit=True, interval=1000.0/fps)
    anim.save(os.path.join(PLOT_DIR, out_gif), writer=plt.matplotlib.animation.PillowWriter(fps=fps))
    plt.close(fig)
    print(f"[anim] {os.path.join(PLOT_DIR, out_gif)}")

# =========================
# 1) DECODE
# =========================
def decode_iwr6843_data(filename,
                        num_rx=NUM_RX,
                        num_adc_samples=NUM_ADC_SAMPLES,
                        num_chirps=NUM_CHIRPS,
                        header_bytes=HEADER_BYTES):
    with open(filename,'rb') as f:
        f.seek(header_bytes)
        raw = np.fromfile(f, dtype=np.int16)
    if raw.size % 2 != 0: raw = raw[:-1]
    raw = raw.reshape(-1,2)
    complex_data = raw[:,0].astype(np.float32) + 1j*raw[:,1].astype(np.float32)

    samp_pf    = num_rx*num_chirps*num_adc_samples
    num_frames = complex_data.size // samp_pf
    complex_data = complex_data[:num_frames*samp_pf]

    adc = np.empty((num_frames, num_rx, num_chirps, num_adc_samples), dtype=np.complex64)
    for fr in range(num_frames):
        base = fr*samp_pf
        for rx in range(num_rx):
            st = base + rx*(num_chirps*num_adc_samples)
            en = st +   (num_chirps*num_adc_samples)
            adc[fr,rx] = complex_data[st:en].reshape(num_chirps, num_adc_samples)
    return adc  # (F,RX,C,S)

# =========================
# 2) RANGE FFT (Hann)
# =========================
def compute_range_profiles(adc):
    window = np.hanning(adc.shape[-1]).astype(np.float32)
    return np.fft.fft(adc * window, axis=-1)  # (F,RX,C,R)

# =========================
# 3) VIRTUAL CHANNELS (robust)
# =========================
def form_virtual_channels_from_rp(
    rp,
    num_tx=NUM_TX,
    mode="interleaved",      # or "grouped"
    drop_remainder=True,
    tx_order=(0,1,2),
    group_lengths=None
):
    frames, rx, chirps, R = rp.shape
    if mode == "interleaved":
        if drop_remainder:
            usable = (chirps // num_tx) * num_tx
            rp = rp[:, :, :usable, :]
            chirps = usable
        else:
            assert chirps % num_tx == 0, "num_chirps must be divisible by num_tx."
        cpt = chirps // num_tx
        virt_blocks = []
        for t in tx_order:
            rp_t = rp[:, :, t:chirps:num_tx, :]   # (F,RX,cpt,R)
            virt_blocks.append(rp_t)
        rp_virt = np.concatenate(virt_blocks, axis=1)   # (F, TX*RX=12, cpt, R)
        return rp_virt
    elif mode == "grouped":
        if group_lengths is None or len(group_lengths) != num_tx:
            raise ValueError("group_lengths required for 'grouped'")
        total = sum(group_lengths)
        if chirps < total:
            raise ValueError(f"chirps={chirps} < sum(group_lengths)={total}")
        rp1 = rp[:, :, :total, :]
        start = 0; blocks=[]
        for t in tx_order:
            L = group_lengths[t]
            blocks.append(rp1[:, :, start:start+L, :]); start += L
        minL = min(b.shape[2] for b in blocks)
        blocks = [b[:, :, :minL, :] for b in blocks]
        rp_virt = np.concatenate(blocks, axis=1)
        return rp_virt
    else:
        raise ValueError("mode must be 'interleaved' or 'grouped'")

# =========================
# 4) GEOMETRY (virtual coords + steering)
# =========================
def virtual_array_positions_wavelengths(fc=FC, rx_mm=RX_MM, tx_mm=TX_MM):
    lam = 3e8 / fc
    rx_m = rx_mm * 1e-3; tx_m = tx_mm * 1e-3
    virt=[]
    for t in tx_m:
        for r in rx_m: virt.append((t+r)/lam)
    return np.array(virt, dtype=np.float64)  # (12,3)

def steering_vector_azimuth(theta_deg, virt_xyz_lam):
    th = np.deg2rad(theta_deg)
    kx, ky = np.cos(th), np.sin(th)
    phase = -1j * 2*np.pi * (virt_xyz_lam[:,0]*kx + virt_xyz_lam[:,1]*ky)
    return np.exp(phase)[:,None]  # (V,1)

# =========================
# 5) RA MAP (geometry-aware)
# =========================
def range_azimuth_map(rp_virt, virt_xyz_lam, angle_grid=ANGLE_GRID, method="capon"):
    """
    rp_virt: (frames, V, cpt, R)
    returns RA: (R, A) normalized per range row
    """
    frames, V, cpt, R = rp_virt.shape
    X = rp_virt.transpose(3,1,0,2).reshape(R, V, -1)  # (R,V,snaps)
    RA = np.zeros((R, len(angle_grid)), dtype=np.float64)

    for rb in range(R):
        S = X[rb]
        Rcov = (S @ S.conj().T) / max(1, S.shape[1]) + 1e-6*np.eye(V)
        Rinv = np.linalg.pinv(Rcov)
        for ai, ang in enumerate(angle_grid):
            a = steering_vector_azimuth(ang, virt_xyz_lam)
            if method.lower()=="bartlett":
                num = np.real((a.conj().T @ Rcov @ a)[0,0])
                den = np.real((a.conj().T @ a)[0,0])
                RA[rb, ai] = num/(den+1e-12)
            else:
                RA[rb, ai] = 1.0/np.real((a.conj().T @ Rinv @ a)[0,0] + 1e-12)
    RA /= (RA.max(axis=1, keepdims=True) + 1e-12)
    return RA

# --- Short-time RA with clutter removal ---
def range_azimuth_map_sliding(rp_virt, virt_xyz_lam,
                              angle_grid=ANGLE_GRID,
                              method="bartlett",
                              win_frames=None,
                              step_frames=None,
                              clutter_remove=True,
                              aggregate="median"):
    """
    Slide over frames, build RA per window, then aggregate.
    """
    F = rp_virt.shape[0]
    if win_frames is None: win_frames = max(1, int(0.5*FS))  # ~0.5 s
    if step_frames is None: step_frames = max(1, win_frames//2)

    ras = []
    for s in range(0, F - win_frames + 1, step_frames):
        sl = rp_virt[s:s+win_frames]  # (win,V,cpt,R)
        if clutter_remove:
            sl = sl - sl.mean(axis=0, keepdims=True)
        RAw = range_azimuth_map(sl, virt_xyz_lam, angle_grid, method)
        ras.append(RAw)
    if not ras:
        sl = rp_virt
        if clutter_remove: sl = sl - sl.mean(axis=0, keepdims=True)
        ras = [range_azimuth_map(sl, virt_xyz_lam, angle_grid, method)]

    stack = np.stack(ras, axis=0)  # (W,R,A)
    if aggregate == "mean": RA = stack.mean(axis=0)
    else: RA = np.median(stack, axis=0)
    RA /= (RA.max(axis=1, keepdims=True) + 1e-12)
    return RA

# =========================
# 6) TARGET PICKING
# =========================
def detect_targets_physical(RA, num_targets=2, min_dist_m=0.7, B=B, angle_grid=ANGLE_GRID, R_mean=4.0):
    c = 3e8
    rng_res = c/(2*B)
    min_rb = int(np.ceil(min_dist_m/rng_res))
    ang_res = float(abs(angle_grid[1]-angle_grid[0])) if len(angle_grid)>1 else 1.0
    ang_sep = np.degrees(np.arcsin(np.clip(min_dist_m/max(R_mean,1e-6), -1, 1)))
    min_ab = int(np.ceil(ang_sep/max(ang_res,1e-6)))

    R, A = RA.shape
    idxs = np.argsort(RA.ravel())[::-1]
    targets=[]
    for idx in idxs:
        r,a = divmod(idx,A)
        if r<min_rb or r>R-min_rb: continue
        if any(abs(r-r0)<min_rb and abs(a-a0)<min_ab for r0,a0 in targets): continue
        targets.append((r,a))
        if len(targets)==num_targets: break
    return targets

# --- Fallback: range-only picking when azimuth can’t resolve ---
def detect_targets_by_range(RA, num_targets=2, min_dist_m=0.7, B=B):
    c = 3e8
    rng_res = c/(2*B)
    min_rb = int(np.ceil(min_dist_m/rng_res))
    range_energy = RA.sum(axis=1)
    # simple non-max suppression
    peaks = []
    used = np.zeros_like(range_energy, dtype=bool)
    for r in np.argsort(range_energy)[::-1]:
        if used[r]: continue
        peaks.append(r)
        lo, hi = max(0, r-min_rb), min(len(range_energy), r+min_rb+1)
        used[lo:hi] = True
        if len(peaks) == num_targets: break
    # choose angle as argmax across angles at each range (may be broad)
    targets = []
    for r in peaks:
        a = int(np.argmax(RA[r]))
        targets.append((int(r), int(a)))
    return targets

# =========================
# 7) MVDR BEAMFORMING (with eigen fallback)
# =========================
def mvdr_beamform_tracks(rp_virt, targets, virt_xyz_lam, angle_grid=ANGLE_GRID, rng_gate=RNG_GATE, eps=EPS):
    frames, V, cpt, R = rp_virt.shape
    beams = np.zeros((frames, len(targets)), dtype=np.complex64)

    for ti, (rbin, abin) in enumerate(targets):
        lo = max(0, rbin-rng_gate); hi = min(R, rbin+rng_gate+1)
        gated = rp_virt[:, :, :, lo:hi]  # (F,V,cpt,gate)
        X = gated.transpose(1,0,2,3).reshape(V,-1)
        Rcov = (X @ X.conj().T)/max(1,X.shape[1]) + eps*np.eye(V)

        # try MVDR with steering; if angle not reliable, use principal eigen-beam
        if abin is not None:
            theta = angle_grid[abin]
            a = steering_vector_azimuth(theta, virt_xyz_lam)  # (V,1)
            Rinv = np.linalg.pinv(Rcov)
            w = (Rinv @ a)/(a.conj().T @ Rinv @ a)
        else:
            # dominant eigenvector of Rcov (max-power beam)
            vals, vecs = np.linalg.eigh(Rcov)
            w = vecs[:, np.argmax(vals)][:, None]

        for fr in range(frames):
            Y = gated[fr].reshape(V,-1)       # (V, snapshots in gate)
            beams[fr, ti] = (w.conj().T @ Y).sum()
    return beams

# =========================
# 8) HR / RR
# =========================
def bandpass_filter(x, low, high, fs, order=4):
    b, a = butter(order, [low/(fs/2), high/(fs/2)], btype='band')
    return filtfilt(b, a, x)

def estimate_hr_from_beam_phase(z, fs=FS):
    phase = np.unwrap(np.angle(z))
    dphi  = np.gradient(phase)*fs/(2*np.pi)
    heart = bandpass_filter(dphi, 1.0, 2.5, fs, order=4)     # 60–150 bpm
    f, P  = welch(heart, fs=fs, nperseg=min(len(heart), int(8*fs)))
    mask  = (f>=1.0) & (f<=3.0)
    hr = (f[mask][np.argmax(P[mask])] * 60.0) if np.any(mask) else np.nan
    return hr, dphi, heart, (f, P)

def estimate_rr_from_beam_amp(z, fs=FS, rr_band=(0.1,0.6), min_rr=5, max_rr=40, prom_factor=0.3):
    amp   = np.abs(z)
    env   = bandpass_filter(amp, rr_band[0], rr_band[1], fs, order=4)
    env_s = medfilt(env, kernel_size=5)
    min_dist = int(fs * 60.0 / max_rr)
    prom     = prom_factor * np.std(env_s)
    peaks, _ = find_peaks(env_s, distance=min_dist, prominence=prom)
    if len(peaks) < 2:
        return np.nan, amp, env, env_s, peaks
    ibis = np.diff(peaks) / fs
    bpm  = 60.0 / ibis
    valid = bpm[(bpm >= min_rr) & (bpm <= max_rr)]
    rr = float(np.mean(valid)) if valid.size else np.nan
    return rr, amp, env, env_s, peaks

# =========================
# 9) DRIVER
# =========================
def run_pipeline(
    FILE,
    num_targets=2,
    method_ra="bartlett",     # 'bartlett' is robust; 'capon' is sharper
    PLOT=True,
    ANIMATE_RAW=False
):
    # 1) decode
    adc = decode_iwr6843_data(FILE, NUM_RX, NUM_ADC_SAMPLES, NUM_CHIRPS, HEADER_BYTES)
    print("adc:", adc.shape)  # (F,RX,C,S)
    if PLOT: plot_adc_example(adc)
    if ANIMATE_RAW: animate_adc_magnitude(adc)

    # 2) range fft
    rp  = compute_range_profiles(adc)
    print("rp:", rp.shape)    # (F,RX,C,R)
    if PLOT: plot_range_profiles_per_rx(rp)

    # 3) virtual channels (robust)
    rp_virt = form_virtual_channels_from_rp(
        rp, num_tx=NUM_TX, mode="interleaved", drop_remainder=True, tx_order=(0,1,2)
    )
    print("rp_virt:", rp_virt.shape)  # (F,12,C/3,R)

    # 4) geometry-aware short-time RA with clutter removal
    virt_xyz_lam = virtual_array_positions_wavelengths(FC, RX_MM, TX_MM)
    RA = range_azimuth_map_sliding(
        rp_virt, virt_xyz_lam,
        angle_grid=ANGLE_GRID,
        method=method_ra,
        win_frames=max(1, int(0.5*FS)),
        step_frames=max(1, int(0.25*FS)),
        clutter_remove=True,
        aggregate="median"
    )
    print("RA:", RA.shape)    # (R,A)
    if PLOT:
        plot_ra_heatmap(RA)
        plot_angle_cut(RA)

    # 5) pick targets — try physical (RA), then fallback to range-only
    targets = detect_targets_physical(RA, num_targets=num_targets, min_dist_m=0.7, B=B, angle_grid=ANGLE_GRID, R_mean=4.0)
    if len(targets) < num_targets:
        print("[info] RA could not separate angles well; falling back to range-only picking.")
        targets = detect_targets_by_range(RA, num_targets=num_targets, min_dist_m=0.7, B=B)
    print("targets:", targets)
    if PLOT: plot_ra_with_targets(RA, targets)

    # 6) MVDR beams (eigen fallback if angle unreliable/None)
    beams = mvdr_beamform_tracks(rp_virt, targets, virt_xyz_lam, ANGLE_GRID, RNG_GATE, EPS)
    print("beams:", beams.shape)
    if PLOT: plot_beam_time_series(beams)

    # 7) vital signs on each beam
    est = []
    for i in range(beams.shape[1]):
        z = beams[:, i]
        hr, dphi, heart_bp, (f, P) = estimate_hr_from_beam_phase(z, FS)
        rr, amp, env, env_s, peaks = estimate_rr_from_beam_amp(z, FS)
        est.append((i, hr, rr, len(peaks)))

        if PLOT and i == 0:
            plot_inst_freq(dphi)
            plot_welch_hr(f, P)
            plot_rr_envelope(env, env_s, peaks)

    return RA, targets, beams, est

# =========================
# RUN
# =========================
if __name__ == "__main__":
    FILE = "plots/SubsetD/083_TopFront_2m_45deg_Man2_CondA_CondF_ADC256_Chirp128_SR1000_RE60_FR40_Gain30_FS30_IWR_0.bin"
    RA, targets, beams, est = run_pipeline(FILE, num_targets=2, method_ra="bartlett", PLOT=True, ANIMATE_RAW=False)
    for i, hr, rr, npeak in est:
        print(f"Patient #{i+1}: HR={hr:.1f} bpm, RR={rr:.1f} bpm (peaks={npeak})")


adc: (1408, 4, 128, 256)
[plot] plots_out/01_adc_example.png
rp: (1408, 4, 128, 256)
[plot] plots_out/02_range_profiles_per_rx.png
rp_virt: (1408, 12, 42, 256)
RA: (256, 361)
[plot] plots_out/03_ra_heatmap.png
[plot] plots_out/04_angle_cut.png
targets: [(146, 0), (231, 0)]
[plot] plots_out/05_ra_with_targets.png
beams: (1408, 2)
[plot] plots_out/06_beam_time_series.png
[plot] plots_out/07_inst_freq.png
[plot] plots_out/08_welch_hr.png
[plot] plots_out/09_rr_envelope.png
Patient #1: HR=97.5 bpm, RR=17.0 bpm (peaks=7)
Patient #2: HR=97.5 bpm, RR=22.9 bpm (peaks=12)


In [13]:
import os, re, csv
import numpy as np
import matplotlib.pyplot as plt
from numpy.linalg import eigh, pinv
from scipy.signal import butter, filtfilt, welch, find_peaks, medfilt

# -------------------------
# Ground truth dictionary
# -------------------------
CONDITION_GT = {
    'A': {'name': "Healthy rest",              'RR': 14, 'HR': 80},
    'B': {'name': "Healthy sleeping",          'RR': 12, 'HR': 60},
    'C': {'name': "Healthy agitated",          'RR': 20, 'HR': 110},
    'D': {'name': "Apnea",                     'RR': 0,  'HR': 80},
    'E': {'name': "Asthma",                    'RR': 22, 'HR': 90},
    'F': {'name': "Asthma attack",             'RR': 33, 'HR': 130},
    'G': {'name': "Pneumothorax",              'RR': 30, 'HR': 120},
    'H': {'name': "Bradypnea",                 'RR': 6,  'HR': 50},
    'I': {'name': "Cardiac arrest",            'RR': 16, 'HR': 160},
    'J': {'name': "ACS Chest Pain",            'RR': 18, 'HR': 104},
    'K': {'name': "Tachypnea/Hyperventilation",'RR': 30, 'HR': 125},
}

# -------------------------
# Radar / array constants
# -------------------------
NUM_RX          = 4
NUM_TX          = 3
NUM_ADC_SAMPLES = 256
NUM_CHIRPS      = 128
HEADER_BYTES    = 0

FC = 60e9
B  = 3.9e9
FS = 40.0

ANGLE_GRID = np.linspace(-90, 90, 361)
RNG_GATE   = 3
EPS        = 1e-6

# IWR6843AOP phase centers you provided (mm)
RX_MM = np.array([
    [ 0.000, +1.800, 0.0],   # Rx0
    [+1.800,  0.000, 0.0],   # Rx1
    [ 0.000, -1.800, 0.0],   # Rx2
    [-1.800,  0.000, 0.0],   # Rx3
])
TX_MM = np.array([
    [+0.900, +0.900, 0.0],   # Tx0
    [-0.900, +0.900, 0.0],   # Tx1
    [ 0.000, -0.900, 0.0],   # Tx2
])

# =============================================
# 1) Filename parser (robust, captures ManN)
# =============================================
def parse_iwr6843_filename(filename):
    """
    Expected example:
    083_TopFront_2m_45deg_Man2_CondA_CondF_..._ADC256_Chirp128_SR1000_RE60_FR40_Gain30_FS30_IWR_0.bin
    Returns dict with keys: id, num_mannequins, cond_a/b/c (if present),
    adc, chirp, sr, re, fr, gain, fs
    """
    base = os.path.basename(filename)
    out = {}

    # id
    m = re.match(r'(?P<id>\d+)_', base)
    out['id'] = int(m.group('id')) if m else None

    # mannequins
    m = re.search(r'Man(?P<man>\d+)', base, re.IGNORECASE)
    out['num_mannequins'] = int(m.group('man')) if m else 2

    # conditions (up to 3)
    conds = re.findall(r'Cond([A-Z])', base)
    out['cond_a'] = conds[0] if len(conds) > 0 else None
    out['cond_b'] = conds[1] if len(conds) > 1 else None
    out['cond_c'] = conds[2] if len(conds) > 2 else None

    # numeric params
    def grab(tag, cast=int, default=None):
        m = re.search(tag + r'(\d+)', base)
        return cast(m.group(1)) if m else default

    out['adc']   = grab(r'ADC',   int, NUM_ADC_SAMPLES)
    out['chirp'] = grab(r'Chirp', int, NUM_CHIRPS)
    out['sr']    = grab(r'SR',    int, None)
    out['re']    = grab(r'RE',    int, None)
    out['fr']    = grab(r'FR',    int, None)
    out['gain']  = grab(r'Gain',  int, None)
    out['fs']    = grab(r'FS',    int, int(FS))

    return out

# =============================================
# 2) Decode BIN -> ADC cube (F,RX,C,S)
# =============================================
def decode_iwr6843_data(filepath, params, num_rx=NUM_RX, header_bytes=HEADER_BYTES):
    num_adc_samples = int(params['adc'])
    num_chirps      = int(params['chirp'])

    with open(filepath, 'rb') as f:
        f.seek(header_bytes)
        raw = np.fromfile(f, dtype=np.int16)

    if raw.size % 2 != 0:
        raw = raw[:-1]
    raw = raw.reshape(-1, 2)
    complex_data = raw[:, 0].astype(np.float32) + 1j * raw[:, 1].astype(np.float32)

    samp_pf    = num_rx * num_chirps * num_adc_samples
    num_frames = complex_data.size // samp_pf
    if num_frames == 0:
        return None
    complex_data = complex_data[: num_frames * samp_pf]

    adc = np.empty((num_frames, num_rx, num_chirps, num_adc_samples), dtype=np.complex64)
    for fr in range(num_frames):
        base = fr * samp_pf
        for rx in range(num_rx):
            st = base + rx * (num_chirps * num_adc_samples)
            en = st + (num_chirps * num_adc_samples)
            adc[fr, rx] = complex_data[st:en].reshape(num_chirps, num_adc_samples)
    return adc

# =============================================
# 3) Range FFT, virtual channels, geometry
# =============================================
def compute_range_profiles(adc):
    window = np.hanning(adc.shape[-1]).astype(np.float32)
    return np.fft.fft(adc * window, axis=-1)  # (F,RX,C,R)

def form_virtual_channels_from_rp(rp, num_tx=NUM_TX, drop_remainder=True, tx_order=(0,1,2)):
    F, RX, C, R = rp.shape
    if drop_remainder:
        usable = (C // num_tx) * num_tx
        rp = rp[:, :, :usable, :]
        C = usable
    else:
        assert C % num_tx == 0, "num_chirps must be divisible by num_tx."
    cpt = C // num_tx
    virt = []
    for t in tx_order:
        virt.append(rp[:, :, t:C:num_tx, :])  # (F,RX,cpt,R)
    rp_virt = np.concatenate(virt, axis=1)   # (F, TX*RX=12, cpt, R)
    return rp_virt

def virtual_array_positions_wavelengths(fc=FC, rx_mm=RX_MM, tx_mm=TX_MM):
    lam = 3e8 / fc
    rx_m = rx_mm * 1e-3; tx_m = tx_mm * 1e-3
    virt=[]
    for t in tx_m:
        for r in rx_m:
            virt.append((t + r) / lam)
    return np.array(virt, dtype=np.float64)  # (12,3)

def steering_vector_azimuth(theta_deg, virt_xyz_lam):
    th = np.deg2rad(theta_deg)
    kx, ky = np.cos(th), np.sin(th)
    phase = -1j * 2*np.pi * (virt_xyz_lam[:,0]*kx + virt_xyz_lam[:,1]*ky)
    return np.exp(phase)[:,None]

# =============================================
# 4) RA map (short-time + clutter removal)
# =============================================
def range_azimuth_map(rp_virt, virt_xyz_lam, angle_grid=ANGLE_GRID, method="bartlett"):
    F, V, C, R = rp_virt.shape
    X = rp_virt.transpose(3,1,0,2).reshape(R, V, -1)
    RA = np.zeros((R, len(angle_grid)), dtype=np.float64)
    for rb in range(R):
        S = X[rb]
        Rcov = (S @ S.conj().T) / max(1, S.shape[1]) + 1e-6*np.eye(V)
        Rinv = pinv(Rcov)
        for ai, ang in enumerate(angle_grid):
            a = steering_vector_azimuth(ang, virt_xyz_lam)
            if method.lower() == "bartlett":
                num = np.real((a.conj().T @ Rcov @ a)[0,0]); den = np.real((a.conj().T @ a)[0,0])
                RA[rb, ai] = num/(den+1e-12)
            else:
                RA[rb, ai] = 1.0/np.real((a.conj().T @ Rinv @ a)[0,0] + 1e-12)
    RA /= (RA.max(axis=1, keepdims=True) + 1e-12)
    return RA

def range_azimuth_map_sliding(rp_virt, virt_xyz_lam, angle_grid=ANGLE_GRID,
                              method="bartlett", win_frames=None, step_frames=None,
                              clutter_remove=True, aggregate="median"):
    F = rp_virt.shape[0]
    if win_frames is None:  win_frames  = max(1, int(0.5*FS))
    if step_frames is None: step_frames = max(1, win_frames//2)
    ras = []
    for s in range(0, F - win_frames + 1, step_frames):
        sl = rp_virt[s:s+win_frames]
        if clutter_remove:
            sl = sl - sl.mean(axis=0, keepdims=True)
        ras.append(range_azimuth_map(sl, virt_xyz_lam, angle_grid, method))
    if not ras:
        sl = rp_virt
        if clutter_remove: sl = sl - sl.mean(axis=0, keepdims=True)
        ras = [range_azimuth_map(sl, virt_xyz_lam, angle_grid, method)]
    stack = np.stack(ras, axis=0)
    RA = np.median(stack, axis=0) if aggregate=="median" else stack.mean(axis=0)
    RA /= (RA.max(axis=1, keepdims=True) + 1e-12)
    return RA

# =============================================
# 5) Range-first picking (AOP azimuth is weak)
# =============================================
def detect_targets_by_range(RA, num_targets=2, min_dist_m=0.7, B=B):
    c = 3e8; rng_res = c/(2*B); min_rb = int(np.ceil(min_dist_m/rng_res))
    energy = RA.sum(axis=1)  # (R,)
    idxs = np.argsort(energy)[::-1]
    used = np.zeros_like(energy, dtype=bool)
    picks = []
    for r in idxs:
        if used[r]: continue
        picks.append(r)
        lo, hi = max(0, r-min_rb), min(len(energy), r+min_rb+1)
        used[lo:hi] = True
        if len(picks) == num_targets: break
    targets = [(int(r), int(np.argmax(RA[r]))) for r in picks]
    return targets

# =============================================
# 6) MUSIC on respiration band (slow-time AoA)
# =============================================
def bp_slowtime(x_F, fs, lo, hi, order=3):
    b,a = butter(order, [lo/(fs/2), hi/(fs/2)], btype='band')
    return filtfilt(b, a, x_F)

def music_spectrum(Rcov, virt_xyz_lam, angle_grid, k):
    vals, vecs = eigh(Rcov)                 # ascending
    V = Rcov.shape[0]
    m = max(1, V - k)                       # noise-subspace dim
    En = vecs[:, :m]
    EnH = En.conj().T
    P = np.zeros(len(angle_grid), dtype=np.float64)
    for i, th in enumerate(angle_grid):
        a = steering_vector_azimuth(th, virt_xyz_lam)
        P[i] = 1.0 / (np.linalg.norm(EnH @ a, 2)**2 + 1e-12)
    P /= (P.max() + 1e-12)
    return P

def estimate_angles_music_resp(rp_virt, rbin, virt_xyz_lam, angle_grid,
                               fs=FS, rng_gate=RNG_GATE, resp_band=(0.1,0.6), k=2):
    F, V, C, R = rp_virt.shape
    lo = max(0, rbin - rng_gate); hi = min(R, rbin + rng_gate + 1)
    # average chirps and gate -> slow-time per channel
    X = rp_virt[:, :, :, lo:hi].mean(axis=(2,3))  # (F, V)
    for v in range(V):
        re = bp_slowtime(X[:, v].real, fs, *resp_band)
        im = bp_slowtime(X[:, v].imag, fs, *resp_band)
        X[:, v] = re + 1j*im
    S = X.T  # (V,F)
    Rcov = (S @ S.conj().T) / max(1, S.shape[1]) + 1e-6*np.eye(V)
    P = music_spectrum(Rcov, virt_xyz_lam, angle_grid, k=k)
    order = np.argsort(P)[::-1]
    taken = []
    for idx in order:
        if all(abs(idx - t) > 3 for t in taken):
            taken.append(idx)
        if len(taken) == k: break
    return taken, P

# =============================================
# 7) Beamforming at chosen (rbin, angle)
# =============================================
def mvdr_beam_at_angle(rp_virt, rbin, angle_idx, virt_xyz_lam, angle_grid=ANGLE_GRID, rng_gate=RNG_GATE, eps=EPS):
    F, V, C, R = rp_virt.shape
    lo = max(0, rbin - rng_gate); hi = min(R, rbin + rng_gate + 1)
    gated = rp_virt[:, :, :, lo:hi]                # (F,V,C,gate)
    Xall  = gated.transpose(1,0,2,3).reshape(V,-1) # (V,snaps)
    Rcov  = (Xall @ Xall.conj().T)/max(1,Xall.shape[1]) + eps*np.eye(V)
    Rinv  = pinv(Rcov)
    theta = angle_grid[angle_idx]
    a     = steering_vector_azimuth(theta, virt_xyz_lam)
    w     = (Rinv @ a) / (a.conj().T @ Rinv @ a)
    beam  = np.zeros(F, dtype=np.complex64)
    for fr in range(F):
        Y = gated[fr].reshape(V,-1)
        beam[fr] = (w.conj().T @ Y).sum()
    return beam

# =============================================
# 8) HR / RR estimators
# =============================================
def bandpass_filter(x, low, high, fs, order=4):
    b, a = butter(order, [low/(fs/2), high/(fs/2)], btype='band')
    return filtfilt(b, a, x)

def estimate_hr_from_beam_phase(z, fs=FS):
    phase = np.unwrap(np.angle(z))
    dphi  = np.gradient(phase)*fs/(2*np.pi)          # Hz
    heart = bandpass_filter(dphi, 1.0, 2.5, fs, 4)   # 60–150 bpm
    f, P  = welch(heart, fs=fs, nperseg=min(len(heart), int(8*fs)))
    mask  = (f>=1.0) & (f<=3.0)
    return (f[mask][np.argmax(P[mask])] * 60.0) if np.any(mask) else np.nan

def estimate_rr_from_beam_amp(z, fs=FS, rr_band=(0.1,0.6), min_rr=5, max_rr=40, prom_factor=0.3):
    amp   = np.abs(z)
    env   = bandpass_filter(amp, rr_band[0], rr_band[1], fs, 4)
    env_s = medfilt(env, kernel_size=5)
    min_dist = int(fs * 60.0 / max_rr)
    prom     = prom_factor * np.std(env_s)
    peaks, _ = find_peaks(env_s, distance=min_dist, prominence=prom)
    if len(peaks) < 2: return np.nan
    ibis = np.diff(peaks) / fs
    bpm  = 60.0 / ibis
    valid = bpm[(bpm >= min_rr) & (bpm <= max_rr)]
    return float(np.mean(valid)) if valid.size else np.nan

# =============================================
# 9) Per-file pipeline → list of subject estimates
# =============================================
def process_file(filepath):
    params = parse_iwr6843_filename(filepath)
    adc = decode_iwr6843_data(filepath, params)
    if adc is None:
        print(f"[skip] {os.path.basename(filepath)}: decode failed.")
        return []

    rp = compute_range_profiles(adc)                         # (F,RX,C,R)
    rp_virt = form_virtual_channels_from_rp(rp)              # (F,12,C/3,R)
    virt_xyz_lam = virtual_array_positions_wavelengths(FC, RX_MM, TX_MM)

    # Short-time RA (robust)
    RA = range_azimuth_map_sliding(rp_virt, virt_xyz_lam, ANGLE_GRID,
                                   method="bartlett", win_frames=max(1,int(0.5*FS)),
                                   step_frames=max(1,int(0.25*FS)),
                                   clutter_remove=True, aggregate="median")

    # Pick ranges first
    n_targets = max(1, int(params.get('num_mannequins', 2)))
    rng_targets = detect_targets_by_range(RA, num_targets=n_targets, min_dist_m=0.7, B=B)

    # MUSIC-on-respiration per range
    chosen = []
    if len(rng_targets) >= 2 and rng_targets[0][0] != rng_targets[1][0]:
        k_per = 1
    else:
        k_per = n_targets
    for (rbin, _) in rng_targets:
        ang_idx, _ = estimate_angles_music_resp(rp_virt, rbin, virt_xyz_lam, ANGLE_GRID,
                                                fs=FS, rng_gate=RNG_GATE, resp_band=(0.1,0.6), k=k_per)
        for ai in ang_idx:
            chosen.append((rbin, ai))
        if len(chosen) >= n_targets: break
    while len(chosen) < n_targets:
        for (rbin, _) in rng_targets:
            ai = int(np.argmax(RA[rbin]))
            if (rbin, ai) not in chosen:
                chosen.append((rbin, ai))
            if len(chosen) == n_targets: break

    # Beamform & estimate vitals
    beams = []
    for (rbin, ai) in chosen:
        beams.append(mvdr_beam_at_angle(rp_virt, rbin, ai, virt_xyz_lam, ANGLE_GRID, RNG_GATE, EPS))
    beams = [b for b in beams if b is not None]
    estimates = []
    for i, z in enumerate(beams):
        hr = estimate_hr_from_beam_phase(z, FS)
        rr = estimate_rr_from_beam_amp(z, FS)
        estimates.append({'beam_idx': i+1, 'HR': hr, 'RR': rr})

    # Map to mannequin condition labels (a,b,c) by sorting chosen by range
    chosen_sorted = sorted(chosen, key=lambda t: t[0])  # ascending range
    cond_codes = [params.get('cond_a'), params.get('cond_b'), params.get('cond_c')]
    cond_codes = [c for c in cond_codes if c is not None][:len(estimates)]

    per_subject = []
    for i, est in enumerate(estimates):
        cond_code = cond_codes[i] if i < len(cond_codes) else None
        gt = CONDITION_GT.get(cond_code, None)
        per_subject.append({
            'file': os.path.basename(filepath),
            'mannequin': i,  # 0-based index
            'condition': cond_code,
            'condition_name': gt['name'] if gt else None,
            'estimated_RR': float(est['RR']) if est['RR'] is not None and not np.isnan(est['RR']) else None,
            'estimated_HR': float(est['HR']) if est['HR'] is not None and not np.isnan(est['HR']) else None,
            'gt_RR': gt['RR'] if gt else None,
            'gt_HR': gt['HR'] if gt else None
        })
    return per_subject

# =============================================
# 10) Metrics (no sklearn)
# =============================================
def compute_metrics(results):
    rr_pairs = [(r['gt_RR'], r['estimated_RR']) for r in results if r['gt_RR'] is not None and r['estimated_RR'] is not None]
    hr_pairs = [(r['gt_HR'], r['estimated_HR']) for r in results if r['gt_HR'] is not None and r['estimated_HR'] is not None]

    def mae_rmse(pairs):
        if not pairs: return None, None
        gt  = np.array([p[0] for p in pairs], dtype=float)
        est = np.array([p[1] for p in pairs], dtype=float)
        mae = float(np.mean(np.abs(gt - est)))
        rmse= float(np.sqrt(np.mean((gt - est)**2)))
        return mae, rmse

    rr_mae, rr_rmse = mae_rmse(rr_pairs)
    hr_mae, hr_rmse = mae_rmse(hr_pairs)

    return {
        "RR_MAE": rr_mae, "RR_RMSE": rr_rmse,
        "HR_MAE": hr_mae, "HR_RMSE": hr_rmse
    }

# =============================================
# 11) Save results CSV
# =============================================
def save_results_csv(results, out_csv="folder_results.csv"):
    if not results:
        print("[warn] no results to save.")
        return
    keys = ['file','mannequin','condition','condition_name','estimated_RR','estimated_HR','gt_RR','gt_HR']
    with open(out_csv, 'w', newline='') as f:
        w = csv.DictWriter(f, fieldnames=keys)
        w.writeheader()
        for r in results:
            w.writerow({k: r.get(k, None) for k in keys})
    print(f"[save] {out_csv}")

# =============================================
# 12) Folder runner
# =============================================
def run_pipeline(folder):
    files = [os.path.join(folder, f) for f in os.listdir(folder)
             if f.endswith('.bin') and '_IWR' in f]
    files.sort()
    print(f"Found {len(files)} files in {folder}")
    results = []
    for fpath in files:
        try:
            rows = process_file(fpath)
            results.extend(rows)
            print(f"[ok] {os.path.basename(fpath)} -> {len(rows)} subjects")
        except Exception as e:
            print(f"[err] {os.path.basename(fpath)}: {e}")
    metrics = compute_metrics(results)
    print("\n--- Metrics ---")
    for k, v in metrics.items():
        print(f"{k}: {v:.2f}" if v is not None else f"{k}: N/A")
    return results, metrics

# -------------------------
# Example usage
# -------------------------
if __name__ == "__main__":
    folder = "plots/SubsetC"   # <-- set your folder
    res, metrics = run_pipeline(folder)
    save_results_csv(res, out_csv=os.path.join(folder, "results_with_metrics.csv"))
    # quick peek
    for r in res[:5]:
        print(r)


Found 30 files in plots/SubsetC
[ok] 146_TopFront_4m_0deg_Man3_CondE_CondG_CondB_ADC256_Chirp128_SR1000_RE60_FR40_Gain30_FS30_IWR_0.bin -> 3 subjects
[ok] 147_TopFront_4m_0deg_Man3_CondH_CondA_CondL_ADC256_Chirp128_SR1000_RE60_FR40_Gain30_FS30_IWR_0.bin -> 3 subjects
[ok] 148_TopFront_4m_0deg_Man3_CondK_CondJ_CondA_ADC256_Chirp128_SR1000_RE60_FR40_Gain30_FS30_IWR_0.bin -> 3 subjects
[ok] 149_TopFront_4m_0deg_Man3_CondF_CondC_CondH_ADC256_Chirp128_SR1000_RE60_FR40_Gain30_FS30_IWR_0.bin -> 3 subjects
[ok] 150_TopFront_4m_0deg_Man3_CondD_CondI_CondI_ADC256_Chirp128_SR1000_RE60_FR40_Gain30_FS30_IWR_0.bin -> 3 subjects
[ok] 151_TopFront_4m_0deg_Man3_CondH_CondB_CondJ_ADC256_Chirp128_SR1000_RE60_FR40_Gain30_FS30_IWR_0.bin -> 3 subjects
[ok] 152_TopFront_4m_0deg_Man3_CondI_CondE_CondK_ADC256_Chirp128_SR1000_RE60_FR40_Gain30_FS30_IWR_0.bin -> 3 subjects
[ok] 153_TopFront_4m_0deg_Man3_CondJ_CondF_CondD_ADC256_Chirp128_SR1000_RE60_FR40_Gain30_FS30_IWR_0.bin -> 3 subjects
[ok] 154_TopFront_4m_0de