In [None]:
from spf.filters.filters import dual_radio_mse_theta_metrics
import torch
from spf.filters.filters import (
    ParticleFilter,
    add_noise,
    theta_phi_to_bins,
)
from spf.dataset.spf_dataset import v5spfdataset
from spf.rf import rotate_dist, torch_pi_norm_pi
from spf.dataset.spf_dataset import v5_collate_keys_fast
from spf.model_training_and_inference.models.single_point_networks_inference import (
    convert_datasets_config_to_inference,
    get_inference_on_ds,
    load_model_and_config_from_config_fn_and_checkpoint,
)
import torch
from spf.scripts.train_single_point import (
    global_config_to_keys_used,
    load_config_from_fn,
)

nthetas = 65  # number of theta bins; passed as nthetas= to every v5spfdataset below


def cached_model_inference_to_absolute_north(ds, cached_model_inference):
    """Rotate cached per-receiver model distributions by each receiver's heading.

    Each receiver's heading ``ds.cached_keys[ridx]["rx_heading_in_pis"]`` is
    given in units of pi. It is converted to radians and wrapped with
    ``torch_pi_norm_pi``. It is then passed to ``rotate_dist`` as the rotation
    for that receiver's distribution at the same sample.

    Args:
        ds: dataset exposing ``cached_keys[0]`` and ``cached_keys[1]`` with an
            ``"rx_heading_in_pis"`` tensor each (one heading per sample).
        cached_model_inference: tensor whose last dimension is the number of
            theta bins. After flattening to ``(-1, nbins)``, its rows must be in
            (sample, receiver) order with exactly one row per sample per
            receiver. It is presumably shaped ``(n_samples, 2, 1, nbins)``;
            TODO confirm.

    Returns:
        The rotated distributions, with the same shape as
        ``cached_model_inference``.

    Raises:
        ValueError: if the number of flattened distributions differs from the
            number of (sample, receiver) headings. Without this check the
            rotations would be misaligned with the distributions.
    """
    n_bins = cached_model_inference.shape[-1]
    # One heading per (sample, receiver): receivers are stacked as columns and
    # then flattened row-major, so the order is sample0/rx0, sample0/rx1, ...
    rx_heading = torch.concatenate(
        [
            torch_pi_norm_pi(
                ds.cached_keys[ridx]["rx_heading_in_pis"][:, None] * torch.pi
            )
            for ridx in range(2)
        ],
        dim=1,
    ).reshape(-1, 1)
    flat_inference = cached_model_inference.reshape(-1, n_bins)
    if flat_inference.shape[0] != rx_heading.shape[0]:
        raise ValueError(
            f"cached_model_inference of shape {tuple(cached_model_inference.shape)} "
            f"flattens to {flat_inference.shape[0]} distributions but there are "
            f"{rx_heading.shape[0]} (sample, receiver) headings"
        )
    rotated = rotate_dist(
        flat_inference,
        rotations=rx_heading,
    )
    return rotated.reshape(cached_model_inference.shape)


class PFSingleThetaDualRadioNN(ParticleFilter):
    """Particle filter over one bearing angle, driven by a paired dual-radio network.

    Each particle is ``[theta, theta_rate]``: ``predict`` advances column 0 by
    ``dt * column 1``. Observations are the ``"paired"`` outputs of the network,
    indexed by sample.

    Two modes, selected by ``ds.temp_file``:

    * falsy: the network is run over the whole dataset once, via
      ``get_inference_on_ds``, and the results are cached. With
      ``absolute=True`` the cached distributions are rotated by each
      receiver's heading (``cached_model_inference_to_absolute_north``).
    * truthy: the network is loaded and evaluated one sample at a time.
      ``absolute=True`` is not supported in this mode.
    """

    def __init__(
        self,
        ds,
        checkpoint_fn,
        config_fn,
        inference_cache=None,
        device="cpu",
        absolute=False,
    ):
        """Set up the filter and either cache model inference or load the model.

        Args:
            ds: dataset; reads ``empirical_data_fn``, ``temp_file``, ``zarr_fn``,
                ``precompute_cache`` and (when ``absolute``) ``cached_keys``.
            checkpoint_fn: path to the model checkpoint.
            config_fn: path to the model/training config.
            inference_cache: forwarded to ``get_inference_on_ds`` (cached mode only).
            device: torch device used for inference.
            absolute: if True, track heading-rotated (absolute) theta instead of
                receiver-relative theta.

        Raises:
            AssertionError: if ``ds.empirical_data_fn`` differs from the one in the
                checkpoint config.
            NotImplementedError: if ``absolute`` is requested together with a
                truthy ``ds.temp_file``.
        """
        self.ds = ds
        self.absolute = absolute
        # Fixed seed so the process noise added in predict() is reproducible.
        self.generator = torch.Generator()
        self.generator.manual_seed(0)

        checkpoint_config = load_config_from_fn(config_fn)
        # The dataset must use the same empirical-distribution file as the checkpoint.
        assert (
            self.ds.empirical_data_fn
            == checkpoint_config["datasets"]["empirical_data_fn"]
        )

        if not self.ds.temp_file:
            # cache model results for the whole dataset up front
            self.cached_model_inference = torch.as_tensor(
                get_inference_on_ds(
                    ds_fn=ds.zarr_fn,
                    config_fn=config_fn,
                    checkpoint_fn=checkpoint_fn,
                    device=device,
                    inference_cache=inference_cache,
                    batch_size=64,
                    workers=0,
                    precompute_cache=ds.precompute_cache,
                    crash_if_not_cached=False,
                )["paired"]
            )
            if self.absolute:
                self.cached_model_inference = cached_model_inference_to_absolute_north(
                    ds, self.cached_model_inference
                )
        else:
            # Check before the (expensive) model load. Use an explicit raise
            # rather than an assert, because asserts are stripped under
            # `python -O`, which would silently feed receiver-relative
            # outputs into an "absolute" filter.
            if self.absolute:
                raise NotImplementedError(
                    "absolute=True is not implemented for datasets with temp_file set"
                )

            # load the model and such for per-sample inference
            self.model, self.model_config = (
                load_model_and_config_from_config_fn_and_checkpoint(
                    config_fn=config_fn, checkpoint_fn=checkpoint_fn, device=device
                )
            )
            self.model.eval()

            self.model_datasets_config = convert_datasets_config_to_inference(
                self.model_config["datasets"],
                ds_fn=ds.zarr_fn,
                precompute_cache=self.ds.precompute_cache,
            )

            self.model_optim_config = {"device": device, "dtype": torch.float32}

            self.model_keys_to_get = global_config_to_keys_used(
                global_config=self.model_config["global"]
            )

    def model_inference_at_observation_idx(self, idx):
        """Return the model's ``"paired"`` output for sample ``idx`` (on CPU).

        Uses the cached results when available; otherwise collates the single
        sample ``self.ds[idx]`` and runs the model under ``torch.no_grad()``.
        """
        if not self.ds.temp_file:
            return self.cached_model_inference[idx]

        z = v5_collate_keys_fast(self.model_keys_to_get, [self.ds[idx]]).to(
            self.model_optim_config["device"]
        )
        with torch.no_grad():
            return self.model(z)["paired"].cpu()

    def observation(self, idx):
        """Return the theta distribution used as the observation for sample ``idx``."""
        # even though the model outputs one paired dist for each receiver
        # they should be identical, so only the first is used
        return self.model_inference_at_observation_idx(idx)[0, 0]

    def fix_particles(self):
        """Wrap each particle's theta (column 0) with ``torch_pi_norm_pi``."""
        self.particles[:, 0] = torch_pi_norm_pi(self.particles[:, 0])

    def predict(self, our_state, dt, noise_std):
        """Advance particles with a constant-rate model and add process noise.

        Args:
            our_state: unused by this filter.
            dt: time step; theta += dt * theta_rate.
            noise_std: per-column noise std passed to ``add_noise``; when None,
                ``[[0.1, 0.001]]`` is used.
        """
        if noise_std is None:
            noise_std = torch.tensor([[0.1, 0.001]])
        self.particles[:, 0] += dt * self.particles[:, 1]
        add_noise(self.particles, noise_std=noise_std, generator=self.generator)

    def update(self, z):
        """Reweight particles by the observed distribution at their theta bin.

        Args:
            z: the processed model output (not a raw observation). It is
                presumably a 1-D distribution over theta bins; its length sets
                the number of bins.
        """
        theta_bin = theta_phi_to_bins(self.particles[:, 0], nbins=z.shape[0])
        prob_theta_given_observation = torch.take(z, theta_bin)

        self.weights *= prob_theta_given_observation
        self.weights += 1.0e-30  # avoid round-off to zero
        self.weights /= torch.sum(self.weights)  # normalize

    def metrics(self, trajectory):
        """Score ``trajectory`` against the matching ground-truth theta.

        Relative mode compares against ``ds.craft_ground_truth_thetas``.
        Absolute mode compares against the mean over axis 0 of
        ``ds.absolute_thetas``.
        NOTE(review): that is a plain arithmetic mean, not a circular one; it
        may be wrong if receivers' thetas straddle the +/-pi wrap.
        """
        return dual_radio_mse_theta_metrics(
            trajectory,
            (
                self.ds.craft_ground_truth_thetas
                if not self.absolute
                else self.ds.absolute_thetas.mean(axis=0)
            ),
        )

    def trajectory(self, **kwargs):
        """Run the base trajectory and add ``craft_theta``/``P_theta`` aliases per step."""
        trajectory = super().trajectory(**kwargs)
        for x in trajectory:
            x["craft_theta"] = x["mu"][0]
            x["P_theta"] = x["var"][0]
        return trajectory

In [None]:
import matplotlib.pyplot as plt


def plot_traj(ds, traj_paired):
    """Plot per-radio phase estimates and the particle-filter theta track.

    Top panel: for each of the two receivers, the estimated phase
    (``ds.mean_phase``) is drawn as faint dots and the ground-truth phase
    (``ds.ground_truth_phis``) as a dashed line. Both are truncated to the
    trajectory length.

    Bottom panel contains:
      * dotted guides at +/- pi/2;
      * craft ground-truth theta;
      * mean absolute theta;
      * the filter's mean +/- one standard deviation per time step.

    Args:
        ds: dataset exposing ``mean_phase``, ``ground_truth_phis``,
            ``craft_ground_truth_thetas`` and ``absolute_thetas``.
        traj_paired: sequence of per-step dicts with ``"mu"`` and ``"var"``.

    Returns:
        The matplotlib Figure.
    """
    fig, (phi_ax, theta_ax) = plt.subplots(2, 1, figsize=(10, 10))

    guide_grey = (0.7, 0.7, 0.7)
    for level in (torch.pi / 2, -torch.pi / 2):
        theta_ax.axhline(y=level, ls=":", c=guide_grey)

    n_steps = len(traj_paired)
    for rx_idx, color in enumerate(("blue", "orange")):
        phases = ds.mean_phase[f"r{rx_idx}"]
        phi_ax.scatter(
            range(min(n_steps, phases.shape[0])),
            phases[:n_steps],
            label=f"r{rx_idx} estimated phi",
            s=1.0,
            alpha=0.1,
            color=color,
        )
        phi_ax.plot(
            ds.ground_truth_phis[rx_idx][:n_steps],
            color=color,
            label=f"r{rx_idx} perfect phi",
            linestyle="dashed",
        )

    theta_ax.plot(
        torch_pi_norm_pi(ds.craft_ground_truth_thetas),
        label="craft gt theta",
        linestyle="dashed",
    )
    theta_ax.plot(
        torch_pi_norm_pi(ds.absolute_thetas.mean(axis=0)),
        label="absolute theta",
        linestyle="dashed",
    )

    pf_mean = torch.hstack([step["mu"][0] for step in traj_paired])
    pf_std = torch.hstack([step["var"][0] for step in traj_paired]).sqrt()
    n_points = pf_mean.shape[0]

    theta_ax.fill_between(
        torch.arange(n_points),
        pf_mean - pf_std,
        pf_mean + pf_std,
        label="PF-std",
        color="red",
        alpha=0.2,
    )
    theta_ax.scatter(range(n_points), pf_mean, label="PF-x", color="orange", s=0.5)

    phi_ax.set_ylabel("radio phi")
    phi_ax.legend()
    phi_ax.set_title("Radio 0 & 1")

    theta_ax.legend()
    theta_ax.set_xlabel("time step")
    theta_ax.set_ylabel("Theta between target and receiver craft")
    return fig

In [None]:
import os

# Roots can be overridden with environment variables so the notebook runs on
# other machines. The defaults reproduce the original absolute paths exactly.
# NOTE: PFSingleThetaDualRadioNN asserts that the dataset's empirical_data_fn
# equals the path stored in the checkpoint config. Overriding SPF_ROOT may
# therefore also require a checkpoint config that points at the same file.
spf_root = os.environ.get("SPF_ROOT", "/home/mouse9911/gits/spf")
cache_root = os.environ.get("SPF_CACHE_ROOT", "/mnt/md2/cache")

checkpoint_dir = os.path.join(
    spf_root,
    "checkpoints",
    "march16",
    "paired_wd0p02_gains_vehicle_0p2dropout_noroverbounceREAL_lowdrop_x2",
)
config_fn = os.path.join(checkpoint_dir, "config.yml")
checkpoint_fn = os.path.join(checkpoint_dir, "best.pth")

empirical_data_fn = os.path.join(spf_root, "empirical_dists", "full.pkl")
# The trailing "" keeps the trailing separator of the original path string.
precompute_cache = os.path.join(cache_root, "precompute_cache_3p5_chunk1", "")
inference_cache = os.path.join(cache_root, "inference")

In [None]:
# Wall-array recording: build the paired dataset, then run the absolute-theta filter.
wallarray_zarr_fn = "/mnt/md2/cache/nosig_data/wallarrayv3_2024_08_21_03_09_04_nRX2_rx_circle_spacing0p05075.zarr"
ds = v5spfdataset(
    wallarray_zarr_fn,
    paired=True,
    nthetas=nthetas,
    ignore_qc=True,
    empirical_data_fn=empirical_data_fn,
    precompute_cache=precompute_cache,
    skip_fields={"signal_matrix"},
)
pf = PFSingleThetaDualRadioNN(
    ds,
    config_fn=config_fn,
    checkpoint_fn=checkpoint_fn,
    inference_cache=inference_cache,
    device="cpu",
    absolute=True,
)

n_particles = 512 * 16
trajectory = pf.trajectory(
    mean=torch.tensor([[0, 0]]),
    std=torch.tensor([[1, 0.1]]),
    N=n_particles,
    return_particles=False,
    debug=True,
)

In [None]:
# Plot the filter track and report its metrics. The metrics used to be
# computed but never shown.
plot_traj(ds, traj_paired=trajectory)
metrics = pf.metrics(trajectory=trajectory)
print(metrics)

In [None]:
# Rover recording (RO3 with RO2): same pipeline as the wall-array run.
rover_ro3_zarr_fn = "/mnt/md2/cache/nosig_data/rover_2025_03_02_22_20_28_nRX2_center_spacing0p043_tag_RO3.rover_2025_03_02_22_19_58_nRX1_circle_spacing0p05075_tag_RO2.zarr"
ds = v5spfdataset(
    rover_ro3_zarr_fn,
    paired=True,
    nthetas=nthetas,
    ignore_qc=True,
    empirical_data_fn=empirical_data_fn,
    precompute_cache=precompute_cache,
    skip_fields={"signal_matrix"},
)
pf = PFSingleThetaDualRadioNN(
    ds,
    config_fn=config_fn,
    checkpoint_fn=checkpoint_fn,
    inference_cache=inference_cache,
    device="cpu",
    absolute=True,
)

n_particles = 512 * 16
trajectory = pf.trajectory(
    mean=torch.tensor([[0, 0]]),
    std=torch.tensor([[1, 0.1]]),
    N=n_particles,
    return_particles=False,
    debug=True,
)

In [None]:
# Plot the filter track and report its metrics. The metrics used to be
# computed but never shown.
plot_traj(ds, traj_paired=trajectory)
metrics = pf.metrics(trajectory=trajectory)
print(metrics)

In [None]:
# Rover recording (RO1 with RO2), run with 8x more particles than the other runs.
rover_ro1_zarr_fn = "/mnt/md2/cache/nosig_data/rover_2025_03_22_19_21_42_nRX2_diamond_spacing0p035_tag_RO1.rover_2025_03_22_19_21_01_nRX1_circle_spacing0p05075_tag_RO2.zarr"
ds = v5spfdataset(
    rover_ro1_zarr_fn,
    paired=True,
    nthetas=nthetas,
    ignore_qc=True,
    empirical_data_fn=empirical_data_fn,
    precompute_cache=precompute_cache,
    skip_fields={"signal_matrix"},
)
pf = PFSingleThetaDualRadioNN(
    ds,
    config_fn=config_fn,
    checkpoint_fn=checkpoint_fn,
    inference_cache=inference_cache,
    device="cpu",
    absolute=True,
)

n_particles = 512 * 16 * 8
# A tighter motion model can be tried via noise_std=torch.tensor([[0.01, 0.001]]).
trajectory = pf.trajectory(
    mean=torch.tensor([[0, 0]]),
    std=torch.tensor([[1, 0.1]]),
    N=n_particles,
    return_particles=False,
    debug=True,
)

In [None]:
metrics = pf.metrics(trajectory=trajectory)
print(metrics)
# Trailing ";" suppresses the returned Figure's repr. The inline backend
# already renders the figure, so returning it rendered the plot twice.
plot_traj(ds, traj_paired=trajectory);

In [None]:
# Raw "paired" network output for the current dataset, without heading rotation.
paired_inference = get_inference_on_ds(
    ds_fn=ds.zarr_fn,
    config_fn=config_fn,
    checkpoint_fn=checkpoint_fn,
    device="cpu",
    inference_cache=inference_cache,
    batch_size=64,
    workers=0,
    precompute_cache=ds.precompute_cache,
    crash_if_not_cached=False,
)["paired"]
cached_model_inference = torch.as_tensor(paired_inference)

In [None]:
# Model output before heading rotation, for index 1 along dim 1 and the first
# 100 samples. Transposed so theta bins run along the y-axis.
fig_raw, ax_raw = plt.subplots()
ax_raw.imshow(cached_model_inference[:100, 1, 0].T)
ax_raw.set_title("cached_model_inference[:100, 1, 0] (before heading rotation)")
ax_raw.set_xlabel("sample index")
ax_raw.set_ylabel("theta bin");

In [None]:
from spf.rf import rotate_dist

# Correct predictions to be relative to true north by rotating each receiver's
# distributions by that receiver's heading. This reuses
# cached_model_inference_to_absolute_north instead of duplicating it, so this
# check cannot drift from what PFSingleThetaDualRadioNN(absolute=True) does.
cached_model_inference_rotated = cached_model_inference_to_absolute_north(
    ds, cached_model_inference
)

In [None]:
# Same slice as the raw plot, after heading rotation, for side-by-side comparison.
fig_rot, ax_rot = plt.subplots()
ax_rot.imshow(cached_model_inference_rotated[:100, 1, 0].T)
ax_rot.set_title("cached_model_inference_rotated[:100, 1, 0] (after heading rotation)")
ax_rot.set_xlabel("sample index")
ax_rot.set_ylabel("theta bin");

In [None]:
# Shape of receiver 0's heading column as used when building the rotations.
ds.cached_keys[0]["rx_heading_in_pis"].unsqueeze(1).shape

In [None]:
# Same concatenation as the heading rotation, with each column filled with its
# receiver index instead of the heading.
receiver_index_columns = []
for ridx in range(2):
    heading_column = ds.cached_keys[ridx]["rx_heading_in_pis"][:, None]
    receiver_index_columns.append(heading_column * 0 + ridx)
torch.concatenate(receiver_index_columns, dim=1).shape

In [None]:
cached_model_inference.size()

In [None]:
# Check the flattening order that the heading rotation relies on. After
# reshape(-1, nbins), row 0 must come from [0, 0] and row 1 from [0, 1].
# That is the same (sample, receiver) order as the flattened headings.
probe = cached_model_inference.clone()
probe[0, 0] = 0
probe[0, 1] = 1
probe_rows = probe.reshape(-1, probe.shape[-1])
assert torch.all(probe_rows[0] == 0) and torch.all(probe_rows[1] == 1), (
    "flattened rows are not in (sample, receiver) order; heading rotations "
    "would be misaligned with the distributions"
)
probe_rows[1]