In [None]:
from spf.dataset.spf_dataset import v5spfdataset

# Recording analysed by this notebook. Other recordings that have been used
# here are kept below (commented out) for reference:
# ds_fn = "/mnt/md1/2d_wallarray_v2_data/june_fix/wallarrayv3_2024_06_10_03_38_21_nRX2_rx_circle.zarr"
# ds_fn = "/mnt/md0/spf/2d_wallarray_v2_data/june_fix/wallarrayv3_2024_06_15_11_44_13_nRX2_bounce.zarr"
ds_fn = "/mnt/4tb_ssd/nosig_data/wallarrayv3_2024_06_10_03_38_21_nRX2_rx_circle.zarr"

nthetas = 65

# paired=True: the cells below index samples as ds[idx][receiver_idx].
# "signal_matrix" is skipped on load; nothing in this notebook reads it.
ds = v5spfdataset(
    ds_fn,
    nthetas=nthetas,
    paired=True,
    ignore_qc=True,
    skip_fields=["signal_matrix"],
    precompute_cache="/mnt/4tb_ssd/precompute_cache_new",
)

In [None]:
# Prefix (directory + filename stem) prepended to the name of the figure
# written by fig.savefig in the plotting cell.
output_prefix = "./test_"

In [None]:
from filterpy.kalman import ExtendedKalmanFilter
import numpy as np
from filterpy.common import Q_discrete_white_noise

from spf.rf import pi_norm, reduce_theta_to_positive_y


def residual(a, b):
    """Return the angular difference ``a - b`` wrapped into ``(-pi, pi]``.

    The difference is first reduced modulo ``2*pi`` and then shifted down by a
    full turn when it exceeds ``pi`` (a value of exactly ``pi`` is kept as is).

    The comparison is used as a plain truth value, so ``a`` and ``b`` must be
    scalars or single-element arrays.
    """
    full_turn = 2 * np.pi
    wrapped = (a - b) % full_turn  # now in [0, 2*pi)
    return wrapped - full_turn if wrapped > np.pi else wrapped


def residual_paired(a, b):
    """Wrapped angular residual for a two-component measurement.

    Applies ``residual`` independently to components 0 and 1 of ``a`` and
    ``b`` and stacks the two results into a numpy array.
    """
    per_receiver = [residual(a[rx], b[rx]) for rx in (0, 1)]
    return np.array(per_receiver)


def trajectory_for_phi_paired(ds, dt=0.05, phi_std=0.5):
    """Estimate the angle trajectory with an EKF fed by both receivers' phases.

    The filter state is ``[theta, d(theta)/dt]`` under a constant-velocity
    model. At every time step the two ``"mean_phase_segmentation"`` values
    (one per receiver) form the measurement; the measurement model and its
    Jacobian are ``hx_paired`` and ``HJacobian_at_paired``, which are looked
    up in the notebook namespace when this function is called.

    Args:
        ds: paired dataset indexed as ``ds[idx][receiver_idx]``; each entry
            must provide ``"mean_phase_segmentation"`` and the first sample of
            receiver 0 must provide ``"ground_truth_theta"``.
        dt: time between consecutive samples used by the motion model.
        phi_std: standard deviation assumed for each phase measurement.

    Returns:
        1-D ``np.ndarray`` of length ``len(ds)`` holding the theta estimate
        taken right after each measurement update (empty if ``ds`` is empty).
    """
    n_steps = len(ds)
    if n_steps == 0:
        return np.array([])

    rk = ExtendedKalmanFilter(dim_x=2, dim_z=2)

    # Initialize with the first ground truth state (sample 0, receiver 0) and
    # zero angular velocity. This is indexed explicitly rather than through a
    # loop variable leaked from another cell.
    y_rad = ds[0][0]["ground_truth_theta"].item()
    rk.x = np.array([[y_rad], [0]])

    # Constant-velocity transition: theta += dt * theta_dot.
    rk.F = np.eye(2) + np.array([[0, 1], [0, 0]]) * dt

    # filterpy creates R and P as identity matrices, so scaling them in place
    # yields isotropic measurement noise / initial state covariance.
    rk.R *= phi_std**2
    rk.Q = Q_discrete_white_noise(2, dt=dt, var=1.0)
    rk.P *= 0.1

    traj = []
    for idx in range(n_steps):
        sample = ds[idx]  # fetch once; both receivers come from the same item
        z = np.array(
            [
                [sample[0]["mean_phase_segmentation"]],
                [sample[1]["mean_phase_segmentation"]],
            ]
        )
        rk.update(
            z,
            HJacobian_at_paired,
            hx_paired,
            residual=residual_paired,
        )
        traj.append(rk.x[0, 0])
        rk.predict()
        # Only the angle is periodic; wrapping the angular velocity as well
        # would corrupt it whenever its magnitude exceeds pi.
        rk.x[0, 0] = pi_norm(rk.x[0, 0])
    return np.array(traj)

In [None]:
# paired EKF -- geometry constants shared by hx_paired / HJacobian_at_paired

# Receiver orientations in radians (the config stores them as multiples of pi).
offsets = [ds.yaml_config["receivers"][rx]["theta-in-pis"] * np.pi for rx in (0, 1)]

# Flip the order of the antennas by negating the spacing; both receivers are
# required to share the same spacing.
antenna_spacing = -ds.yaml_config["receivers"][0]["antenna-spacing-m"]
assert -ds.yaml_config["receivers"][1]["antenna-spacing-m"] == antenna_spacing

# Both receivers are required to share the same wavelength as well.
wavelength = ds.wavelengths[0]
assert ds.wavelengths[1] == wavelength


def hx_paired(x):
    """Measurement model: expected phase at each receiver for state ``x``.

    Reads the angle from ``x[0, 0]`` and, for each of the two receiver
    offsets, returns ``sin(theta - offset) * antenna_spacing * 2*pi / wavelength``.
    ``offsets``, ``antenna_spacing`` and ``wavelength`` are notebook globals.

    Returns:
        numpy array with one expected phase per receiver (two entries).
    """
    theta = x[0, 0]
    expected = [
        np.sin(theta - rx_offset) * antenna_spacing * 2 * np.pi / wavelength
        for rx_offset in (offsets[0], offsets[1])
    ]
    return np.array(expected)


def HJacobian_at_paired(x):
    """Jacobian of ``hx_paired`` with respect to the state, evaluated at ``x``.

    Each row belongs to one receiver. The first column is the derivative of
    the expected phase with respect to the angle ``x[0, 0]``, i.e.
    ``cos(theta - offset) * antenna_spacing * 2*pi / wavelength``; the second
    column is 0 because the measurement does not depend on the second state
    component. ``offsets``, ``antenna_spacing`` and ``wavelength`` are
    notebook globals.
    """
    theta = x[0, 0]
    rows = []
    for rx_offset in (offsets[0], offsets[1]):
        d_phi_d_theta = (
            np.cos(theta - rx_offset) * antenna_spacing * 2 * np.pi / wavelength
        )
        rows.append([d_phi_d_theta, 0])
    return np.array(rows)

In [None]:
import matplotlib.pyplot as plt

from spf.rf import torch_pi_norm

fig, ax = plt.subplots(1, 3, figsize=(15, 5))

# Panels 0 and 1: measured mean phase of each receiver against the
# ground-truth phase for that receiver.
for rx_idx in [0, 1]:
    panel = ax[rx_idx]
    measured_phi = ds.mean_phase[f"r{rx_idx}"]
    panel.scatter(
        range(measured_phi.shape[0]),
        measured_phi,
        label=f"r{rx_idx} estimated phi",
        s=1.0,
        alpha=1.0,
        color="red",
    )
    panel.plot(ds.ground_truth_phis[rx_idx], label="perfect phi")

    panel.legend()
    panel.set_xlabel("time step")
    panel.set_ylabel("phi")
    panel.set_title(f"Radio {rx_idx}")

# Panel 2: wrapped ground-truth craft angle against the paired-radio EKF estimate.
craft_theta = [pi_norm(ds[idx][0]["craft_y_rad"].item()) for idx in range(len(ds))]
ax[2].plot(craft_theta, label="craft gt theta")
ax[2].set_title("Paired radio EKF")
ax[2].plot(trajectory_for_phi_paired(ds), label="EKF")

# Label the axes so this panel can be read on its own, like the other two.
ax[2].set_xlabel("time step")
ax[2].set_ylabel("theta")
ax[2].legend()

fig.suptitle("EKF: When two (radios) become one")
fig.savefig(f"{output_prefix}_when_two_become_one_ekf.png")