In [None]:
import os
import tempfile

from spf.dataset.fake_dataset import create_fake_dataset, fake_yaml
from spf.dataset.spf_dataset import v5spfdataset

n = 1025  # passed as `n` to create_fake_dataset (presumably the sample count)
noise = 0.3  # passed as `noise` to create_fake_dataset
nthetas = 65  # passed as `nthetas` to v5spfdataset
orbits = 4  # passed as `orbits` to create_fake_dataset

# NOTE(review): this TemporaryDirectory is created but not used below; the
# dataset and precompute cache go to the fixed "/tmp/" path instead (the
# trailing comment shows the alternative), presumably so they survive reruns.
tmpdir = tempfile.TemporaryDirectory()
tmpdirname = "/tmp/"  # tmpdir.name
# os.path.join avoids the "/tmp//..." double separator an f-string join would
# produce, because tmpdirname already ends with "/".
temp_ds_fn = os.path.join(tmpdirname, f"sample_dataset_for_ekf_n{n}_noise{noise}")

In [None]:
# Generate the synthetic dataset at temp_ds_fn; the following cells load it.
create_fake_dataset(
    filename=temp_ds_fn,
    yaml_config_str=fake_yaml,
    n=n,
    noise=noise,
    orbits=orbits,
)

In [None]:
# Load the synthetic dataset in paired mode with ignore_qc=True.
# NOTE(review): `ds` is reassigned by the next cell (which also skips
# "windowed_beamformer"), so this load looks redundant apart from any
# precompute cache it may populate in tmpdirname.
ds = v5spfdataset(
    temp_ds_fn,
    paired=True,
    ignore_qc=True,
    nthetas=nthetas,
    skip_fields={"signal_matrix"},
    precompute_cache=tmpdirname,
)

In [None]:
# Set to True to run the filters on a recorded capture instead of the
# synthetic dataset generated above.
use_real_data = False
if not use_real_data:
    ds_fn = temp_ds_fn
    precompute_cache_dir = tmpdirname
else:
    # ds_fn = "/mnt/md1/2d_wallarray_v2_data/june_fix/wallarrayv3_2024_06_10_03_38_21_nRX2_rx_circle.zarr"
    ds_fn = "/mnt/md1/2d_wallarray_v2_data/june_fix/wallarrayv3_2024_06_15_11_44_13_nRX2_bounce.zarr"
    precompute_cache_dir = "/home/mouse9911/precompute_cache_chunk16_sept"

# Fields listed in skip_fields are handed to v5spfdataset (presumably not loaded).
ds = v5spfdataset(
    ds_fn,
    paired=True,
    ignore_qc=True,
    nthetas=nthetas,
    skip_fields={"signal_matrix", "windowed_beamformer"},
    precompute_cache=precompute_cache_dir,
)

In [None]:
import os


# Prefix for files saved by later cells: "./<dataset basename>_".
# normpath strips a trailing separator (e.g. "capture.zarr/"), for which
# basename would otherwise return "" and yield the useless prefix "./_".
output_prefix = "./" + os.path.basename(os.path.normpath(ds_fn)) + "_"

In [None]:
import matplotlib.pyplot as plt

fig, ax = plt.subplots(1, 2, figsize=(10, 5))

# One panel per radio: the per-sample mean phase estimate, the ideal phase
# curve, and the ground-truth theta.
for rx_idx in [0, 1]:
    panel = ax[rx_idx]
    panel.scatter(
        range(len(ds)),
        ds.mean_phase[f"r{rx_idx}"],
        label=f"radio{rx_idx} est phi",
        s=1.0,
        color="red",
    )
    panel.plot(ds.ground_truth_phis[rx_idx], label="perfect phi", color="blue")
    # Indexes every sample of ds individually; may be slow on large datasets.
    panel.plot(
        [ds[idx][rx_idx]["ground_truth_theta"] for idx in range(len(ds))],
        label=f"radio{rx_idx} gt theta",
        color="green",
    )
    panel.set_title(f"Radio {rx_idx}")
    panel.set_xlabel("Time step")
    panel.set_ylabel("theta/phi")
    panel.legend()
    panel.axhline(y=0, color="r", linestyle="-")
fig.suptitle("Phase(phi) recovered from radios after segmentation")
# output_prefix already ends with "_", so no extra separator is added here
# (previously this produced "<name>__raw_signal.png").
fig.savefig(f"{output_prefix}raw_signal.png")

In [None]:
import torch
from spf.filters.ekf_single_radio_filter import SPFKalmanFilter

# Run an independent single-radio EKF for each of the two receivers, then
# score every trajectory with its own filter's metrics.
kfs = [
    SPFKalmanFilter(ds=ds, rx_idx=radio_idx, phi_std=5.0, p=5)
    for radio_idx in range(2)
]
single_radio_trajectories = [filt.trajectory(debug=True) for filt in kfs]
single_radio_metrics = [
    filt.metrics(traj) for filt, traj in zip(kfs, single_radio_trajectories)
]
print(single_radio_metrics)

In [None]:
from spf.filters.ekf_single_radio_filter import run_and_plot_single_radio_EKF

# Plot the single-radio EKF trajectories computed in the previous cell.
fig = run_and_plot_single_radio_EKF(
    ds,
    trajectories=single_radio_trajectories,
)

In [None]:
from spf.filters.ekf_dualradio_filter import SPFPairedKalmanFilter


# Joint EKF over both radios of the paired dataset; dynamic_R=False
# (presumably a fixed measurement-noise covariance — confirm in the filter).
kf = SPFPairedKalmanFilter(dynamic_R=False, p=5, phi_std=5.0, ds=ds)
paired_trajectory = kf.trajectory(debug=True)
paired_metrics = kf.metrics(paired_trajectory)
print(paired_metrics)

In [None]:
from spf.filters.ekf_dualradio_filter import run_and_plot_dualradio_EKF

# Plot the paired (two-radio) EKF trajectory from the previous cell.
fig = run_and_plot_dualradio_EKF(
    ds,
    trajectory=paired_trajectory,
)

In [None]:
from spf.filters.ekf_dualradioXY_filter import SPFPairedXYKalmanFilter

# Paired XY EKF with dynamic_R enabled; trajectory() receives an explicit
# time step (dt) and noise_std.
kf = SPFPairedXYKalmanFilter(
    ds=ds,
    phi_std=5.0,
    p=0.1,
    dynamic_R=True,
)
pairedXY_trajectory = kf.trajectory(
    debug=True,
    dt=1.0,
    noise_std=10,
)
pairedXY_metrics = kf.metrics(pairedXY_trajectory)
print(pairedXY_metrics)

In [None]:
from spf.filters.ekf_dualradioXY_filter import (
    run_and_plot_dualradioXY_EKF,
)

# Plot the paired XY EKF trajectory from the previous cell.
fig = run_and_plot_dualradioXY_EKF(
    ds,
    trajectory=pairedXY_trajectory,
)