# Interpolate to reference grid + filtering (IMU filtered at native rate)

This notebook implements the following pipeline:

- **IMU**: collapse duplicate timestamps → estimate IMU sampling rate from median `dt` → apply **zero-phase 4th-order Butterworth LPF at 10 Hz** on the IMU stream **before interpolation** → then interpolate onto `t_ref`.
- **Joints + Wheelchair**: interpolate onto `t_ref` → apply **zero-phase 4th-order Butterworth LPF at 5 Hz** on the reference grid (with Nyquist-safety clamping) → save.

Outputs:
- `refgrid_interpolated_and_filtered.csv`
- `refgrid_interpolated_and_filtered.parquet`

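Notes:
- `filtfilt` assumes approximately uniform sampling. We use the **median** IMU `dt` to define `fs_imu` after de-duplication.
- Filtering is NaN-aware on the refgrid (handles out-of-range NaNs from interpolation).
- The 10 Hz IMU low-pass runs at the native IMU rate, so it is not an anti-aliasing filter for the reference grid. If `t_ref` is sampled at ~12 Hz (Nyquist ≈ 6 Hz), IMU content between ~6 and 10 Hz can alias when the filtered stream is interpolated onto `t_ref`.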


In [None]:
import os
import numpy as np
import pandas as pd
import h5py

from scipy.signal import butter, filtfilt

# -----------------------------
# Paths (edit if needed)
# -----------------------------
REF_CSV   = "timestamps_synced_refgrid.csv"
IMU_H5    = "kinova_gen3_imu.h5"
JOINT_H5  = "kinova_gen3_joint_states.h5"
WHILL_H5  = "kinova_gen3_wheelchair_states.h5"

OUT_CSV     = "refgrid_interpolated_and_filtered.csv"
OUT_PARQUET = "refgrid_interpolated_and_filtered.parquet"

# Fail fast on missing inputs. Use an explicit exception instead of `assert`
# (asserts are stripped when Python runs with -O), and report every missing
# file at once rather than only the first one.
_missing_inputs = [p for p in (REF_CSV, IMU_H5, JOINT_H5, WHILL_H5) if not os.path.exists(p)]
if _missing_inputs:
    raise FileNotFoundError(f"Missing input file(s): {', '.join(_missing_inputs)}")
print("All input files found.")

All input files found.


In [None]:
# -----------------------------
# Load reference grid (target timestamps)
# -----------------------------
ref = pd.read_csv(REF_CSV)
if "ref_time_s" not in ref.columns:
    raise ValueError(f"'ref_time_s' column not found. Columns are: {list(ref.columns)}")

t_ref = ref["ref_time_s"].to_numpy(dtype=np.float64)

# Drop non-finite timestamps (e.g. blank rows parse as NaN). They must be removed
# explicitly: NaN comparisons are False, so they would pass the monotonicity
# check below and then turn the median dt into NaN.
_finite_ref = np.isfinite(t_ref)
if not _finite_ref.all():
    print(f"[WARN] Dropping {int((~_finite_ref).sum())} non-finite reference timestamp(s).")
    t_ref = t_ref[_finite_ref]

# Ensure strictly increasing (np.unique sorts and removes duplicates).
# NOTE: after this, t_ref no longer lines up row-for-row with `ref`; in the
# cells shown here only t_ref is used downstream.
if np.any(np.diff(t_ref) <= 0):
    _n_before = t_ref.size
    t_ref = np.unique(t_ref)
    print(f"[WARN] Reference timestamps were not strictly increasing; sorted them and "
          f"removed {_n_before - t_ref.size} duplicate(s).")

if t_ref.size < 2:
    raise ValueError("Reference grid needs at least 2 distinct finite timestamps.")

dt_ref = np.median(np.diff(t_ref))
fs_ref = 1.0 / dt_ref
nyq_ref = 0.5 * fs_ref
print(f"Reference grid: N={len(t_ref)}  median dt={dt_ref:.6f}s  fs≈{fs_ref:.3f} Hz  Nyquist≈{nyq_ref:.3f} Hz")

Reference grid: N=1860  median dt=0.083333s  fs≈12.000 Hz  Nyquist≈6.000 Hz


In [None]:
# -----------------------------
# Helpers
# -----------------------------
def _as_2d(y: np.ndarray) -> np.ndarray:
    y = np.asarray(y)
    if y.ndim == 1:
        return y.reshape(-1, 1)
    return y

def collapse_duplicate_timestamps(t: np.ndarray, y: np.ndarray, agg: str = "mean"):
    '''
    Collapse duplicate timestamps by aggregating y over identical t.

    - t: (N,) timestamps. Rows whose timestamp is NaN/inf are dropped, since
      they cannot be placed on a time axis.
    - y: (N, D) or (N,) values; 1-D input is treated as a single column.
    - agg: "mean" averages the rows sharing a timestamp; "last" keeps the row
      that appeared last in the input among its duplicates (a stable sort is
      used so that "last" is well-defined).
    Returns: (t_unique, y_agg) where t_unique is (M,) strictly increasing and
    y_agg is always 2-D, (M, D).
    Raises: ValueError if agg is unknown or t and y have different lengths.
    '''
    if agg not in ("mean", "last"):
        raise ValueError("agg must be 'mean' or 'last'")

    t = np.asarray(t, dtype=np.float64)
    y2 = _as_2d(np.asarray(y))
    if t.shape[0] != y2.shape[0]:
        # Without this check a longer y would be silently truncated/misaligned.
        raise ValueError(f"t has {t.shape[0]} samples but y has {y2.shape[0]} rows")

    finite = np.isfinite(t)
    if not finite.all():
        t, y2 = t[finite], y2[finite]

    # Stable sort keeps the original relative order of equal timestamps.
    order = np.argsort(t, kind="stable")
    t = t[order]
    y2 = y2[order]

    if t.size == 0:
        return t, y2  # nothing to aggregate

    # t is sorted, so idx_start marks the first row of each run of equal values.
    t_u, idx_start, counts = np.unique(t, return_index=True, return_counts=True)

    if agg == "mean":
        # reduceat sums each run [idx_start[i], idx_start[i+1]).
        y_sum = np.add.reduceat(y2, idx_start, axis=0)
        y_agg = y_sum / counts[:, None]
    else:  # agg == "last"
        idx_last = idx_start + counts - 1
        y_agg = y2[idx_last]

    return t_u, y_agg

def interp_linear(t_src: np.ndarray, y_src: np.ndarray, t_tgt: np.ndarray):
    '''
    Linear interpolation column-wise.

    - t_src: (N,) source timestamps; must be finite and strictly increasing
      (call collapse_duplicate_timestamps first).
    - y_src: (N,) or (N, D) source values.
    - t_tgt: (M,) target timestamps.
    Returns a float64 array of shape (M, D) passed through .squeeze(), so
    size-1 dimensions are dropped (e.g. a single column comes back as (M,)).
    Targets outside [t_src[0], t_src[-1]] are NaN.
    Raises ValueError for fewer than 2 source samples, mismatched t/y lengths,
    non-finite source timestamps, or non-increasing t_src.
    '''
    t_src = np.asarray(t_src, dtype=np.float64)
    y2 = _as_2d(np.asarray(y_src, dtype=np.float64))
    t_tgt = np.asarray(t_tgt, dtype=np.float64)

    if t_src.size < 2:
        raise ValueError("Need at least 2 source samples for interpolation.")
    if y2.shape[0] != t_src.size:
        raise ValueError(f"y_src has {y2.shape[0]} rows but t_src has {t_src.size} samples.")
    # Reject NaN/inf explicitly: NaN comparisons are False, so NaN timestamps
    # would otherwise pass the strict-increase check below.
    if not np.all(np.isfinite(t_src)):
        raise ValueError("t_src contains non-finite timestamps.")
    if np.any(np.diff(t_src) <= 0):
        raise ValueError("t_src must be strictly increasing. (Call collapse_duplicate_timestamps first.)")

    out = np.empty((t_tgt.size, y2.shape[1]), dtype=np.float64)
    for d in range(y2.shape[1]):
        out[:, d] = np.interp(t_tgt, t_src, y2[:, d], left=np.nan, right=np.nan)
    return out.squeeze()

def _contiguous_segments(mask: np.ndarray):
    idx = np.flatnonzero(mask)
    if idx.size == 0:
        return
    breaks = np.where(np.diff(idx) > 1)[0]
    starts = np.r_[idx[0], idx[breaks + 1]]
    ends   = np.r_[idx[breaks] + 1, idx[-1] + 1]
    for s, e in zip(starts, ends):
        yield int(s), int(e)

def clamp_cutoff(cutoff_hz: float, fs_hz: float, label: str):
    '''
    Keep a low-pass cutoff strictly below the Nyquist frequency of fs_hz.

    Cutoffs at or above Nyquist (fs_hz / 2) are replaced by 99.9% of Nyquist
    and a warning naming `label` is printed; other cutoffs pass through.
    Returns the cutoff to use, in Hz.
    '''
    nyquist = fs_hz / 2.0
    if cutoff_hz >= nyquist:
        ceiling = nyquist * 0.999
        print(f"[WARN] {label} cutoff {cutoff_hz:.3f} Hz >= Nyquist {nyquist:.3f} Hz (fs={fs_hz:.3f} Hz). "
              f"Clamping to {ceiling:.3f} Hz.")
        return ceiling
    return cutoff_hz

def butter_lowpass(order: int, cutoff_hz: float, fs_hz: float):
    '''
    Design a digital Butterworth low-pass filter.

    The cutoff is normalized by the Nyquist frequency (fs_hz / 2) as required
    by scipy.signal.butter. Returns the (b, a) transfer-function coefficients.
    '''
    normalized_cutoff = cutoff_hz / (0.5 * fs_hz)
    return butter(order, normalized_cutoff, btype="low", analog=False)

def filtfilt_nanaware(x: np.ndarray, b: np.ndarray, a: np.ndarray):
    '''
    Zero-phase filtering for 1D signal x, handling NaNs by filtering each valid segment.

    Each maximal run of finite samples is filtered independently with
    scipy.signal.filtfilt (method="pad"); non-finite samples are copied through.
    Runs whose length is <= 3 * max(len(a), len(b)) are too short for the
    edge padding and are left unchanged.

    Returns a new float64 array; x itself is not modified.
    '''
    x = np.asarray(x, dtype=np.float64)
    y = x.copy()

    valid = np.isfinite(x)
    if not valid.any():
        return y

    # Must equal the padlen that filtfilt actually uses: filtfilt raises
    # ValueError unless the input is longer than padlen. 3 * max(len(a), len(b))
    # is scipy's default; passing it explicitly keeps the skip test and the
    # call in agreement.
    padlen = 3 * max(len(a), len(b))

    for s, e in _contiguous_segments(valid):
        if e - s <= padlen:
            continue
        y[s:e] = filtfilt(b, a, x[s:e], method="pad", padlen=padlen)
    return y

def filtfilt_columns_nanaware(Y: np.ndarray, b: np.ndarray, a: np.ndarray):
    '''
    Apply filtfilt_nanaware independently to every column of Y.

    - Y: (N,) or (N, D); NaN gaps in a column are preserved and each finite
      run is filtered separately.
    Returns a float64 array passed through .squeeze() (size-1 dims dropped).
    '''
    Y2 = _as_2d(Y)
    # Allocate float64 explicitly: np.empty_like would inherit an integer dtype
    # from integer input and silently truncate the filtered values.
    out = np.empty(Y2.shape, dtype=np.float64)
    for d in range(Y2.shape[1]):
        out[:, d] = filtfilt_nanaware(Y2[:, d], b, a)
    return out.squeeze()

def filtfilt_columns_dense(Y: np.ndarray, b: np.ndarray, a: np.ndarray):
    '''
    For dense (no-NaN) arrays: apply filtfilt per column.

    - Y: (N,) or (N, D). Any NaN in a column propagates through that whole
      column; use filtfilt_columns_nanaware for gappy data.
    Returns an array of at least float64 precision passed through .squeeze().
    Raises ValueError (from scipy) if a column is too short for filtfilt's padding.
    '''
    Y2 = _as_2d(Y)
    # Promote to at least float64 so integer input is not truncated on
    # assignment (complex input stays complex).
    out = np.empty(Y2.shape, dtype=np.result_type(Y2.dtype, np.float64))
    for d in range(Y2.shape[1]):
        out[:, d] = filtfilt(b, a, Y2[:, d], method="pad")
    return out.squeeze()

In [None]:
# -----------------------------
# IMU: filter at native rate, then interpolate
# -----------------------------
IMU_CUTOFF_HZ = 10.0
ORDER = 4

with h5py.File(IMU_H5, "r") as f:
    t_imu = f["timestamps"][:].astype(np.float64)
    imu_lin = f["linear_accelerations"][:].astype(np.float64)      # (N,3)
    imu_ang = f["angular_velocitys"][:].astype(np.float64)         # (N,3)
    imu_q   = f["orientations"][:].astype(np.float64)              # (N,4)

# De-dup timestamps (and sort). All three streams share t_imu, so each call
# returns the same unique time axis; keep the first.
t_imu_u, imu_lin_u = collapse_duplicate_timestamps(t_imu, imu_lin, agg="mean")
_,      imu_ang_u  = collapse_duplicate_timestamps(t_imu, imu_ang, agg="mean")
_,      imu_q_u    = collapse_duplicate_timestamps(t_imu, imu_q,   agg="mean")

# Estimate IMU sampling rate from median dt after de-duplication
dt_imu = np.diff(t_imu_u)
dt_imu = dt_imu[dt_imu > 0]
if dt_imu.size == 0:
    # np.median of an empty array is NaN and the filter design would then fail obscurely.
    raise ValueError("IMU stream needs at least 2 distinct timestamps to estimate its sampling rate.")
fs_imu = 1.0 / np.median(dt_imu)
nyq_imu = 0.5 * fs_imu
print(f"IMU unique samples: {len(t_imu_u)}  fs≈{fs_imu:.3f} Hz  Nyquist≈{nyq_imu:.3f} Hz")

# Design IMU filter on IMU fs (the cutoff is checked against the native IMU Nyquist)
imu_cutoff_used = clamp_cutoff(IMU_CUTOFF_HZ, fs_imu, "IMU(native)")
b_imu, a_imu = butter_lowpass(order=ORDER, cutoff_hz=imu_cutoff_used, fs_hz=fs_imu)

# The filtered stream is later sampled onto t_ref, whose Nyquist (nyq_ref) can be
# lower than this cutoff; frequency content between the two then aliases.
if imu_cutoff_used >= nyq_ref:
    print(f"[WARN] IMU cutoff {imu_cutoff_used:.3f} Hz >= refgrid Nyquist {nyq_ref:.3f} Hz; "
          f"content in ({nyq_ref:.3f}, {imu_cutoff_used:.3f}] Hz may alias onto t_ref.")

# Filter IMU in native time order (assumes approximately uniform sampling).
# The NaN-aware variant gives the same result as plain filtfilt on NaN-free
# columns longer than the filter padding, but keeps any NaNs in the raw data
# confined to their own gaps instead of spreading them over the whole column.
# NOTE(review): orientations are filtered/interpolated component-wise; results
# are not renormalized to unit length and q/-q sign flips are not handled. The
# displayed sample output shows 0.0 orientation values — confirm whether this
# IMU publishes orientation at all.
imu_lin_u_f = filtfilt_columns_nanaware(imu_lin_u, b_imu, a_imu)
imu_ang_u_f = filtfilt_columns_nanaware(imu_ang_u, b_imu, a_imu)
imu_q_u_f   = filtfilt_columns_nanaware(imu_q_u,   b_imu, a_imu)

print(f"IMU filter cutoff used: {imu_cutoff_used:.3f} Hz")

# Interpolate BOTH raw and filtered IMU onto reference grid
imu_lin_ref    = interp_linear(t_imu_u, imu_lin_u,   t_ref)
imu_ang_ref    = interp_linear(t_imu_u, imu_ang_u,   t_ref)
imu_q_ref      = interp_linear(t_imu_u, imu_q_u,     t_ref)

imu_lin_ref_f  = interp_linear(t_imu_u, imu_lin_u_f, t_ref)
imu_ang_ref_f  = interp_linear(t_imu_u, imu_ang_u_f, t_ref)
imu_q_ref_f    = interp_linear(t_imu_u, imu_q_u_f,   t_ref)

print("IMU done (filtered native → interpolated).")
# Summarize the time spans instead of dumping whole arrays (which print as
# unreadable 1.765e+09 values) so the IMU/refgrid overlap can be checked.
print(f"IMU time span:     [{t_imu_u[0]:.3f}, {t_imu_u[-1]:.3f}] s")
print(f"Refgrid time span: [{t_ref[0]:.3f}, {t_ref[-1]:.3f}] s")


IMU unique samples: 7002  fs≈65.950 Hz  Nyquist≈32.975 Hz
IMU filter cutoff used: 10.000 Hz
IMU done (filtered native → interpolated).
timu [1.76541179e+09 1.76541179e+09 1.76541179e+09 ... 1.76541195e+09
 1.76541195e+09 1.76541195e+09]
t_ref [1.76541179e+09 1.76541179e+09 1.76541179e+09 ... 1.76541195e+09
 1.76541195e+09 1.76541195e+09]


In [None]:
# -----------------------------
# Joints: interpolate, then filter on refgrid (5 Hz)
# -----------------------------
OTHER_CUTOFF_HZ = 5.0

# Refgrid low-pass design; b_other / a_other are also reused by the wheelchair cell.
other_cutoff_used = clamp_cutoff(OTHER_CUTOFF_HZ, fs_ref, "Joints(refgrid)")
b_other, a_other = butter_lowpass(order=ORDER, cutoff_hz=other_cutoff_used, fs_hz=fs_ref)

with h5py.File(JOINT_H5, "r") as f:
    # positions / velocitys: (N,7)
    t_j, j_pos, j_vel = (
        f[key][:].astype(np.float64) for key in ("timestamps", "positions", "velocitys")
    )

# Sort + average duplicate timestamps; both streams share t_j.
t_j_u, j_pos_u = collapse_duplicate_timestamps(t_j, j_pos, agg="mean")
_, j_vel_u = collapse_duplicate_timestamps(t_j, j_vel, agg="mean")

# Resample onto the reference grid (NaN outside the joint time range), then
# low-pass each column while skipping NaN gaps.
# NOTE(review): if any joint position wraps (e.g. at ±π), interpolating and
# filtering across the wrap yields spurious values — confirm positions are unwrapped.
j_pos_ref, j_vel_ref = (interp_linear(t_j_u, sig, t_ref) for sig in (j_pos_u, j_vel_u))
j_pos_ref_f, j_vel_ref_f = (
    filtfilt_columns_nanaware(sig, b_other, a_other) for sig in (j_pos_ref, j_vel_ref)
)

print(f"Joints done (interp → filtfilt @ {other_cutoff_used:.3f} Hz).")

Joints done (interp → filtfilt @ 5.000 Hz).


In [None]:
# -----------------------------
# Wheelchair: interpolate, then filter on refgrid (5 Hz)
# -----------------------------
with h5py.File(WHILL_H5, "r") as f:
    n = int(f["num_datapoints"][()])
    start = float(f["record_start_time"][()])
    freq  = float(f["record_frequency"][()])  # Hz

    w_l_ang = f["left_motor_angles"][:].astype(np.float64)
    w_l_spd = f["left_motor_speeds"][:].astype(np.float64)
    w_r_ang = f["right_motor_angles"][:].astype(np.float64)
    w_r_spd = f["right_motor_speeds"][:].astype(np.float64)

# The time axis is synthesized from header fields (start + k / freq), so
# validate them against the data first. freq == 0 would produce inf/NaN times
# that slip past the monotonicity check, and a count mismatch would otherwise
# only surface as a generic length error from np.interp.
if not freq > 0:
    raise ValueError(f"record_frequency must be positive, got {freq}")
_whill_lengths = {
    "left_motor_angles": len(w_l_ang),
    "left_motor_speeds": len(w_l_spd),
    "right_motor_angles": len(w_r_ang),
    "right_motor_speeds": len(w_r_spd),
}
_whill_bad = {name: length for name, length in _whill_lengths.items() if length != n}
if _whill_bad:
    raise ValueError(f"num_datapoints={n} does not match dataset lengths {_whill_bad}")

t_w = start + np.arange(n, dtype=np.float64) / freq

# NOTE(review): if motor angles wrap around (e.g. at 360°), linear interpolation
# and low-pass filtering across a wrap produce spurious values — confirm the
# angles are unwrapped.
w_l_ang_ref = interp_linear(t_w, w_l_ang, t_ref)
w_l_spd_ref = interp_linear(t_w, w_l_spd, t_ref)
w_r_ang_ref = interp_linear(t_w, w_r_ang, t_ref)
w_r_spd_ref = interp_linear(t_w, w_r_spd, t_ref)

# b_other / a_other: refgrid low-pass designed in the joints cell.
w_l_ang_ref_f = filtfilt_nanaware(w_l_ang_ref, b_other, a_other)
w_l_spd_ref_f = filtfilt_nanaware(w_l_spd_ref, b_other, a_other)
w_r_ang_ref_f = filtfilt_nanaware(w_r_ang_ref, b_other, a_other)
w_r_spd_ref_f = filtfilt_nanaware(w_r_spd_ref, b_other, a_other)

print("Wheelchair done (interp → filtfilt on refgrid).")

Wheelchair done (interp → filtfilt on refgrid).


In [None]:
# -----------------------------
# Merge + Save (raw + filtered columns)
# -----------------------------
# Collect every column in an insertion-ordered dict, then build the table once.
columns = {"t_ref_s": t_ref}

# IMU: all raw interpolated columns first, then the filtered-native →
# interpolated ones (suffix _f10), each group in lin / ang / q order.
imu_groups = (
    (("imu_lin_x", "imu_lin_y", "imu_lin_z"), imu_lin_ref, imu_lin_ref_f),
    (("imu_ang_x", "imu_ang_y", "imu_ang_z"), imu_ang_ref, imu_ang_ref_f),
    (("imu_q0", "imu_q1", "imu_q2", "imu_q3"), imu_q_ref, imu_q_ref_f),
)
for names, raw, _filt in imu_groups:
    for k, name in enumerate(names):
        columns[name] = raw[:, k]
for names, _raw, filt in imu_groups:
    for k, name in enumerate(names):
        columns[f"{name}_f10"] = filt[:, k]

# Joints: raw and 5 Hz-filtered columns interleaved per joint, positions then velocities.
for kind, raw, filt in (("pos", j_pos_ref, j_pos_ref_f), ("vel", j_vel_ref, j_vel_ref_f)):
    for i in range(raw.shape[1]):
        columns[f"joint_{kind}_{i}"] = raw[:, i]
        columns[f"joint_{kind}_{i}_f5"] = filt[:, i]

# Wheelchair: raw followed by its 5 Hz-filtered counterpart.
for name, raw, filt in (
    ("wheel_left_angle", w_l_ang_ref, w_l_ang_ref_f),
    ("wheel_left_speed", w_l_spd_ref, w_l_spd_ref_f),
    ("wheel_right_angle", w_r_ang_ref, w_r_ang_ref_f),
    ("wheel_right_speed", w_r_spd_ref, w_r_spd_ref_f),
):
    columns[name] = raw
    columns[f"{name}_f5"] = filt

out = pd.DataFrame(columns)

print("Output table:", out.shape)
display(out.head(3))

out.to_csv(OUT_CSV, index=False)
out.to_parquet(OUT_PARQUET, index=False)
print("Saved:", OUT_CSV, "and", OUT_PARQUET)

Output table: (1860, 57)


Unnamed: 0,t_ref_s,imu_lin_x,imu_lin_y,imu_lin_z,imu_ang_x,imu_ang_y,imu_ang_z,imu_q0,imu_q1,imu_q2,...,joint_vel_6,joint_vel_6_f5,wheel_left_angle,wheel_left_angle_f5,wheel_left_speed,wheel_left_speed_f5,wheel_right_angle,wheel_right_angle_f5,wheel_right_speed,wheel_right_speed_f5
0,1765412000.0,,,,,,,,,,...,,,,,,,,,,
1,1765412000.0,-6.468126,0.975801,7.257833,0.003196,0.002815,-0.002893,0.0,0.0,0.0,...,0.0,-2.009665e-06,,,,,,,,
2,1765412000.0,-6.487327,0.912944,7.210071,0.00172,0.0014,-0.002952,0.0,0.0,0.0,...,0.0,8.987092e-07,,,,,,,,


Saved: refgrid_interpolated_and_filtered.csv and refgrid_interpolated_and_filtered.parquet


In [None]:
# -----------------------------
# Optional: NaN report
# -----------------------------
# Share of missing values per column, showing the 20 worst columns first.
missing_share = out.isna().mean()
nan_frac = missing_share.sort_values(ascending=False).head(20)
nan_frac

Unnamed: 0,0
wheel_left_angle,0.002151
wheel_right_speed_f5,0.002151
wheel_right_speed,0.002151
wheel_right_angle_f5,0.002151
wheel_right_angle,0.002151
wheel_left_speed_f5,0.002151
wheel_left_speed,0.002151
wheel_left_angle_f5,0.002151
imu_lin_x,0.000538
imu_q0_f10,0.000538
