In [None]:
from spf.dataset.fake_dataset import create_fake_dataset, fake_yaml

# Synthetic dataset used to exercise the EKF below.
n = 1025
noise = 0.05
nthetas = 65  # passed to v5spfdataset when the dataset is loaded
ds_fn = f"sample_dataset_for_ekf_n{n}_noise{noise}"

create_fake_dataset(
    filename=ds_fn,
    yaml_config_str=fake_yaml,
    n=n,
    noise=noise,
    orbits=5,
)

In [None]:
from spf.dataset.spf_dataset import v5spfdataset

import os
import tempfile

# Load the synthetic dataset; samples are paired, so each ds[idx] is indexed by receiver.
# Cache in the platform temp dir instead of a hardcoded POSIX "/tmp/"; the
# join with "" keeps the trailing separator the original path had.
ds = v5spfdataset(
    ds_fn,
    nthetas=nthetas,
    ignore_qc=True,
    precompute_cache=os.path.join(tempfile.gettempdir(), ""),
    paired=True,
    skip_signal_matrix=True,
)

In [None]:
# import torch


# def estimate_phi(rx_idx, ds):
#     estimates = []
#     for idx in range(len(ds)):
#         estimates.append(ds.estimate_phi(ds[idx][rx_idx]))
#     return torch.vstack(estimates)

In [None]:
# ds.get_segmentation_mean_phase().keys()

In [None]:
import matplotlib.pyplot as plt

fig, ax = plt.subplots(1, 2, figsize=(10, 5))

# One panel per receiver: ground-truth bearing vs. the dataset's mean phase estimate.
for rx_idx in [0, 1]:
    ax[rx_idx].plot(
        [ds[idx][rx_idx]["ground_truth_theta"] for idx in range(len(ds))],
        label=f"radio{rx_idx} gt theta",
    )
    ax[rx_idx].plot(
        ds.mean_phase[f"r{rx_idx}"],
        label=f"radio{rx_idx} est phi",
    )
    ax[rx_idx].set_title(f"radio {rx_idx}")
    ax[rx_idx].set_xlabel("sample index")
    ax[rx_idx].set_ylabel("angle / phase (rad)")
    ax[rx_idx].legend()

fig.tight_layout()

In [None]:
import numpy as np

"""
EKF model for a single receiver.

State:       x = [ theta, dtheta/dt ]^T
Measurement: z = [ phi ]   (the segmentation mean phase)

F = [ [ 1 dt ],
      [ 0  1 ]]

h(x) = sin(x[0]) * (d * 2 * pi / wavelength)

H(x) = [ dh/dx_1 , dh/dx_2 ] = [ cos(x[0]) * (d * 2 * pi / wavelength) , 0 ]

where d is the antenna spacing; h does not depend on dtheta/dt, so dh/dx_2 = 0.
"""

# hx / HJacobian_at use a single spacing and wavelength for both receivers,
# so check that the dataset config actually agrees across them.
antenna_spacing = ds.yaml_config["receivers"][0]["antenna-spacing-m"]
assert antenna_spacing == ds.yaml_config["receivers"][1]["antenna-spacing-m"], (
    "receivers have different antenna spacings; hx/HJacobian_at assume one value"
)

wavelength = ds.wavelengths[0]
assert wavelength == ds.wavelengths[1], (
    "receivers have different wavelengths; hx/HJacobian_at assume one value"
)


def hx(x):
    """Measurement function h(x) for the EKF.

    Maps the bearing stored in x[0, 0] to the expected phase
    sin(theta) * d * 2*pi / wavelength, using the module-level
    `antenna_spacing` (d) and `wavelength`.

    Returns a length-1 array (dim_z == 1).
    """
    theta = x[0, 0]
    expected_phi = np.sin(theta) * antenna_spacing * 2 * np.pi / wavelength
    return np.array([expected_phi])


def HJacobian_at(x):
    """Jacobian of the measurement function `hx`, evaluated at state x.

    Returns a 1x2 row [dh/dtheta, dh/d(dtheta/dt)]. The second entry is 0
    because the expected phase depends only on the bearing x[0, 0].
    """
    theta = x[0, 0]
    dphi_dtheta = np.cos(theta) * antenna_spacing * 2 * np.pi / wavelength
    return np.array([[dphi_dtheta, 0]])

In [None]:
from filterpy.kalman import ExtendedKalmanFilter

from filterpy.common import Q_discrete_white_noise

from spf.rf import pi_norm, reduce_theta_to_positive_y


def residual(a, b):
    """Angular difference ``a - b`` wrapped into the interval (-pi, pi].

    Passed to the EKF update as its ``residual`` callback so that a measured
    phase and a predicted phase that differ by a multiple of 2*pi count as equal.
    Operates element-wise, so scalars and arrays of any shape are accepted
    (the previous ``if y > np.pi`` only worked for single-element inputs).
    """
    y = np.mod(a - b, 2 * np.pi)  # force into [0, 2*pi)
    # Shift the upper half down by 2*pi, element-wise, giving (-pi, pi].
    return np.where(y > np.pi, y - 2 * np.pi, y)


def trajectory_for_phi(rx_idx, dx, dt=0.1, phi_std=0.5, process_var=1.0):
    """Filter one receiver's phase measurements with an EKF; return the bearing track.

    State is [theta, dtheta/dt] with a constant-rate transition; the measurement
    is the sample's "mean_phase_segmentation", modelled by `hx` / `HJacobian_at`
    and compared through the wrapped `residual`.

    Args:
        rx_idx: receiver index within each sample (``dx[idx][rx_idx]``).
        dx: the dataset to filter (callers pass ``ds``). Previously this argument
            was ignored and the global ``ds`` was read instead.
        dt: time step of the transition model.
        phi_std: standard deviation of the phase measurement noise.
        process_var: variance passed to ``Q_discrete_white_noise``.

    Returns:
        np.ndarray with one entry per sample: theta after that sample's update.
    """
    rk = ExtendedKalmanFilter(dim_x=2, dim_z=1)
    # Start from this receiver's first ground-truth bearing with zero angular rate.
    # (Was ds[rx_idx][0], i.e. sample rx_idx of receiver 0.)
    y_rad = dx[0][rx_idx]["ground_truth_theta"].item()
    rk.x = np.array([[y_rad], [0.0]])

    # theta_{k+1} = theta_k + dt * dtheta/dt ; rate is carried over unchanged.
    rk.F = np.array([[1.0, dt], [0.0, 1.0]])
    rk.R = np.array([[phi_std**2]])
    rk.Q = Q_discrete_white_noise(2, dt=dt, var=process_var)
    rk.P = np.eye(2)  # unit initial covariance

    traj = []
    for idx in range(len(dx)):
        z = np.array([dx[idx][rx_idx]["mean_phase_segmentation"]])
        rk.update(z, HJacobian_at, hx, residual=residual)
        traj.append(rk.x[0, 0])
        rk.predict()
        # Wrap only the angle; the angular rate is not periodic.
        rk.x[0] = pi_norm(rk.x[0])
    return np.array(traj)

In [None]:
import matplotlib.pyplot as plt

fig, ax = plt.subplots(1, 2, figsize=(10, 5))

# One panel per receiver: ground truth, dataset mean phase, and EKF-filtered bearing.
for rx_idx in [0, 1]:
    ax[rx_idx].plot(
        [ds[idx][rx_idx]["ground_truth_theta"] for idx in range(len(ds))],
        label=f"radio{rx_idx} gt theta",
    )
    ax[rx_idx].plot(
        ds.mean_phase[f"r{rx_idx}"],
        label=f"radio{rx_idx} est phi",
    )
    ax[rx_idx].plot(
        trajectory_for_phi(rx_idx, ds),
        label=f"radio{rx_idx} EKF theta",
    )
    ax[rx_idx].set_title(f"radio {rx_idx}")
    ax[rx_idx].set_xlabel("sample index")
    ax[rx_idx].set_ylabel("angle / phase (rad)")
    ax[rx_idx].legend()

fig.tight_layout()

In [None]:
# Choose the receiver explicitly instead of relying on a loop variable left
# over from an earlier cell (its last value there was 1).
rx_idx = 1
z = trajectory_for_phi(rx_idx, ds)

In [None]:
# Display the filtered trajectory as a NumPy array.
np.array(z, copy=True)

In [None]:
# Scratch check: a (2, 1) column times a (1, 1) matrix stays (2, 1).
(np.array([[2], [1]]) @ np.array([[3]])).shape

In [None]:
# Shape of a single phase measurement (sample 0, receiver 0). Indices are
# explicit because `idx` is not defined at module level at this point.
np.array(ds[0][0]["mean_phase_segmentation"]).shape

In [None]:
# Debug: step a fresh EKF through the first `max_idx` measurements of one
# receiver. `rk` has to be built here; it only exists inside trajectory_for_phi.
max_idx = 10
debug_rx_idx = 0

rk = ExtendedKalmanFilter(dim_x=2, dim_z=1)
rk.x = np.array([[ds[0][debug_rx_idx]["ground_truth_theta"].item()], [0.0]])
rk.F = np.array([[1.0, 0.1], [0.0, 1.0]])  # dt = 0.1, as in trajectory_for_phi
rk.R = np.array([[0.5**2]])
rk.Q = Q_discrete_white_noise(2, dt=0.1, var=1.0)

for idx in range(min(max_idx, len(ds))):
    rk.update(
        np.array([ds[idx][debug_rx_idx]["mean_phase_segmentation"]]),
        HJacobian_at,
        hx,
        residual=residual,
    )
    rk.predict()

rk.x

In [None]:
# FIXME(review): this cell appears to be pasted from a filterpy radar-tracking
# example and cannot run as-is: `radar`, `dt`, `array` and `asarray` are not
# defined at module level in the cells visible here, and `rk` is only created
# inside `trajectory_for_phi`. Either adapt it to this dataset or delete it.
xs, track = [], []
for i in range(int(20 / dt)):
    z = radar.get_range()
    track.append((radar.pos, radar.vel, radar.alt))

    # measurement update, record the posterior state, then propagate
    rk.update(array([z]), HJacobian_at, hx)
    xs.append(rk.x)
    rk.predict()

xs = asarray(xs)
track = asarray(track)
time = np.arange(0, len(xs) * dt, dt)

In [None]:
# x = batch_data["all_windows_stats"].to(torch_device).to(torch.float32)
# seg_mask = batch_data["downsampled_segmentation_mask"].to(torch_device)