In [None]:
import tempfile
from spf.dataset.fake_dataset import create_fake_dataset, fake_yaml
from spf.dataset.spf_dataset import v5spfdataset

# Parameters of the synthetic dataset and the theta discretisation used below.
n, noise = 1025, 0.3
nthetas, orbits = 65, 4

# A TemporaryDirectory is created, but the dataset/cache location actually used is
# the fixed "/tmp/" prefix (the temporary directory's name is left unused).
tmpdir = tempfile.TemporaryDirectory()
tmpdirname = "/tmp/"  # tmpdir.name
temp_ds_fn = f"{tmpdirname}/sample_dataset_for_ekf_n{n}_noise{noise}"

In [None]:
from spf.dataset.fake_dataset import create_empirical_dist_for_datasets


# Build the fake dataset at `temp_ds_fn` from the stock `fake_yaml` configuration,
# using the size/noise/orbit settings chosen in the first cell.
create_fake_dataset(
    filename=temp_ds_fn,
    yaml_config_str=fake_yaml,
    n=n,
    noise=noise,
    orbits=orbits,
)

In [None]:
# Build the empirical distribution for the fake dataset. `nthetas` is taken from the
# notebook-level setting (instead of a repeated literal 65) so it cannot drift away
# from the value the datasets below are opened with.
empirical_pkl_fn = create_empirical_dist_for_datasets(
    datasets=[temp_ds_fn], precompute_cache=tmpdirname, nthetas=nthetas
)

In [None]:
from spf.model_training_and_inference.models.single_point_networks_inference import (
    get_md5_of_file,
)


# get_md5_of_file(
#     "/home/mouse9911/gits/spf/nov2_checkpoints/paired_checkpoints_inputdo0p3/best.pth"
# )

In [None]:
# Open the fake dataset together with the empirical distribution built for it.
ds = v5spfdataset(
    temp_ds_fn,
    precompute_cache=tmpdirname,
    empirical_data_fn=empirical_pkl_fn,
    nthetas=nthetas,
    paired=True,
    ignore_qc=True,
    skip_fields={"signal_matrix"},
)

In [None]:
from functools import cache
import numpy as np
import torch
from spf.filters.filters import (
    ParticleFilter,
    add_noise,
    fix_particles_single,
    single_radio_mse_theta_metrics,
    theta_phi_to_bins,
    theta_phi_to_p_vec,
)
from spf.rf import reduce_theta_to_positive_y, torch_pi_norm_pi
from spf.dataset.spf_dataset import v5_collate_keys_fast
from spf.model_training_and_inference.models.single_point_networks_inference import (
    convert_datasets_config_to_inference,
    get_inference_on_ds,
    load_model_and_config_from_config_fn_and_checkpoint,
)
import torch
from spf.scripts.train_single_point import (
    global_config_to_keys_used,
    load_config_from_fn,
    load_dataloaders,
)

from tqdm import tqdm


class PFSingleThetaSingleRadioNN(ParticleFilter):
    """
    Single-radio particle filter whose observations are neural-network outputs.

    particle state is [ theta, dtheta/dt]

    The observation at each step is the model's "single" output for receiver
    ``rx_idx``; ``update`` reweights every particle by the value of that output at
    the particle's theta bin.
    """

    def __init__(
        self, ds, rx_idx, checkpoint_fn, config_fn, inference_cache=None, device="cpu"
    ):
        """
        Args:
            ds: dataset to filter over. Its ``empirical_data_fn`` must equal the one
                recorded in the checkpoint config.
            rx_idx: index of the receiver whose model output is used.
            checkpoint_fn: path to the model checkpoint.
            config_fn: path to the model config.
            inference_cache: forwarded to ``get_inference_on_ds`` (only used when
                ``ds.temp_file`` is falsy).
            device: device used for model inference.
        """
        self.ds = ds
        self.rx_idx = rx_idx
        # Fixed seed so the noise injected in predict() is reproducible.
        self.generator = torch.Generator()
        self.generator.manual_seed(0)

        checkpoint_config = load_config_from_fn(config_fn)
        expected_empirical_fn = checkpoint_config["datasets"]["empirical_data_fn"]
        assert self.ds.empirical_data_fn == expected_empirical_fn, (
            f"dataset empirical_data_fn {self.ds.empirical_data_fn!r} does not match "
            f"the checkpoint config's {expected_empirical_fn!r}"
        )

        if not self.ds.temp_file:
            # cache model results: run inference over the whole dataset once and
            # index into it later.
            self.cached_model_inference = torch.as_tensor(
                get_inference_on_ds(
                    ds_fn=ds.zarr_fn,
                    config_fn=config_fn,
                    checkpoint_fn=checkpoint_fn,
                    device=device,
                    inference_cache=inference_cache,
                    batch_size=64,
                    workers=0,
                    precompute_cache=ds.precompute_cache,
                    crash_if_not_cached=False,
                )["single"]
            )
        else:
            # load the model and such, so inference can be run one sample at a time
            self.model, self.model_config = (
                load_model_and_config_from_config_fn_and_checkpoint(
                    config_fn=config_fn, checkpoint_fn=checkpoint_fn, device=device
                )
            )
            self.model.eval()

            self.model_datasets_config = convert_datasets_config_to_inference(
                self.model_config["datasets"],
                ds_fn=ds.zarr_fn,
                precompute_cache=self.ds.precompute_cache,
            )

            self.model_optim_config = {"device": device, "dtype": torch.float32}

            self.model_keys_to_get = global_config_to_keys_used(
                global_config=self.model_config["global"]
            )

    def model_inference_at_observation_idx(self, idx):
        """Return the model's "single" output for dataset index ``idx``.

        Uses the precomputed inference when ``ds.temp_file`` is falsy, otherwise runs
        the model on that one sample (without gradients) and moves the result to CPU.
        """
        if not self.ds.temp_file:
            return self.cached_model_inference[idx]

        z = v5_collate_keys_fast(self.model_keys_to_get, [self.ds[idx]]).to(
            self.model_optim_config["device"]
        )
        with torch.no_grad():
            return self.model(z)["single"].cpu()

    def observation(self, idx):
        """Return entry ``[rx_idx, 0]`` of the model output for index ``idx``."""
        return self.model_inference_at_observation_idx(idx)[self.rx_idx, 0]

    def fix_particles(self):
        """Normalise the particle states via ``fix_particles_single``."""
        self.particles = fix_particles_single(self.particles)
        # self.particles[:, 0] = reduce_theta_to_positive_y(self.particles[:, 0])

    def predict(self, our_state, dt, noise_std=None):
        """Advance theta by ``dt * dtheta/dt`` and add noise to the particles.

        ``our_state`` is accepted for interface compatibility and not used here.
        ``noise_std`` defaults to ``[[0.1, 0.001]]`` when None.
        """
        if noise_std is None:
            noise_std = torch.tensor([[0.1, 0.001]])
        self.particles[:, 0] += dt * self.particles[:, 1]
        add_noise(self.particles, noise_std=noise_std, generator=self.generator)

    def update(self, z):
        """Reweight particles by the model output ``z`` at each particle's theta bin."""
        # z is not the raw observation, but the processed model output
        # (assumed 1-D over theta bins: torch.take indexes the flattened tensor)
        theta_bin = theta_phi_to_bins(self.particles[:, 0], nbins=z.shape[0])
        prob_theta_given_observation = torch.take(z, theta_bin)

        self.weights *= prob_theta_given_observation
        self.weights += 1.0e-30  # avoid round-off to zero
        self.weights /= torch.sum(self.weights)  # normalize

    def metrics(self, trajectory):
        """Score ``trajectory`` against this receiver's ground-truth thetas."""
        return single_radio_mse_theta_metrics(
            trajectory, self.ds.ground_truth_thetas[self.rx_idx]
        )

    def trajectory(self, **kwargs):
        """Run the base-class trajectory and expose theta mean/variance per step.

        Adds ``"theta"`` (``mu[0]``) and ``"P_theta"`` (``var[0]``) to every entry.
        """
        trajectory = super().trajectory(**kwargs)
        for x in trajectory:
            x["theta"] = x["mu"][0]
            x["P_theta"] = x["var"][0]
        return trajectory

In [None]:
import numpy as np

# Fake dataset again, this time opened with the empirical distribution file that the
# filter constructor compares against the checkpoint config.
# (alternative used during development: empirical_data_fn=empirical_pkl_fn)
ds_with_model_empirical = v5spfdataset(
    temp_ds_fn,
    precompute_cache=tmpdirname,
    empirical_data_fn="/home/mouse9911/gits/spf/empirical_dists/full.pkl",
    nthetas=nthetas,
    paired=True,
    ignore_qc=True,
    skip_fields={"signal_matrix"},
)

# One NN-backed single-radio particle filter per receiver index.
# Alternatives used during development:
#   config_fn / checkpoint_fn under
#     /home/mouse9911/gits/spf/nov2_checkpoints/paired_checkpoints_inputdo0p3/
#   inference_cache="/mnt/4tb_ssd/inference_cache/"
pfs = [
    PFSingleThetaSingleRadioNN(
        ds_with_model_empirical,
        rx_idx=receiver_idx,
        checkpoint_fn="/home/mouse9911/gits/spf/checkpoints/march16/paired_wd0p02_gains_vehicle_0p2dropout_noroverbounceREAL_lowdrop_x2/best.pth",
        config_fn="/home/mouse9911/gits/spf/checkpoints/march16/paired_wd0p02_gains_vehicle_0p2dropout_noroverbounceREAL_lowdrop_x2/config.yml",
        device="cpu",
        inference_cache="/mnt/md2/cache/inference",
    )
    for receiver_idx in range(2)
]

# Run every filter with the same initial mean/std and particle count.
trajectories = [
    single_pf.trajectory(
        N=512 * 16 * 8,
        mean=torch.tensor([[0, 0]]),
        std=torch.tensor([[1, 0.1]]),
        debug=True,
        return_particles=False,
    )
    for single_pf in pfs
]

# Per-receiver metrics, displayed as the cell output.
metrics = [
    single_pf.metrics(trajectory=single_traj)
    for single_pf, single_traj in zip(pfs, trajectories)
]
metrics

In [None]:
from spf.filters.filters import dual_radio_mse_theta_metrics


class PFSingleThetaDualRadioNN(ParticleFilter):
    """
    Dual-radio particle filter whose observations are neural-network outputs.

    Particle column 0 is the craft theta and column 1 its rate of change (see
    ``predict``). The observation at each step is the model's "paired" output;
    ``update`` reweights every particle by the value of that output at the
    particle's theta bin.
    """

    def __init__(
        self, ds, checkpoint_fn, config_fn, inference_cache=None, device="cpu"
    ):
        """
        Args:
            ds: dataset to filter over. Its ``empirical_data_fn`` must equal the one
                recorded in the checkpoint config.
            checkpoint_fn: path to the model checkpoint.
            config_fn: path to the model config.
            inference_cache: forwarded to ``get_inference_on_ds`` (only used when
                ``ds.temp_file`` is falsy).
            device: device used for model inference.
        """
        self.ds = ds

        # Fixed seed so the noise injected in predict() is reproducible.
        self.generator = torch.Generator()
        self.generator.manual_seed(0)

        checkpoint_config = load_config_from_fn(config_fn)
        expected_empirical_fn = checkpoint_config["datasets"]["empirical_data_fn"]
        assert self.ds.empirical_data_fn == expected_empirical_fn, (
            f"dataset empirical_data_fn {self.ds.empirical_data_fn!r} does not match "
            f"the checkpoint config's {expected_empirical_fn!r}"
        )

        if not self.ds.temp_file:
            # cache model results: run inference over the whole dataset once and
            # index into it later.
            self.cached_model_inference = torch.as_tensor(
                get_inference_on_ds(
                    ds_fn=ds.zarr_fn,
                    config_fn=config_fn,
                    checkpoint_fn=checkpoint_fn,
                    device=device,
                    inference_cache=inference_cache,
                    batch_size=64,
                    workers=0,
                    precompute_cache=ds.precompute_cache,
                    crash_if_not_cached=False,
                )["paired"]
            )
        else:
            # load the model and such, so inference can be run one sample at a time
            self.model, self.model_config = (
                load_model_and_config_from_config_fn_and_checkpoint(
                    config_fn=config_fn, checkpoint_fn=checkpoint_fn, device=device
                )
            )
            self.model.eval()

            self.model_datasets_config = convert_datasets_config_to_inference(
                self.model_config["datasets"],
                ds_fn=ds.zarr_fn,
                precompute_cache=self.ds.precompute_cache,
            )

            self.model_optim_config = {"device": device, "dtype": torch.float32}

            self.model_keys_to_get = global_config_to_keys_used(
                global_config=self.model_config["global"]
            )

    def model_inference_at_observation_idx(self, idx):
        """Return the model's "paired" output for dataset index ``idx``.

        Uses the precomputed inference when ``ds.temp_file`` is falsy, otherwise runs
        the model on that one sample (without gradients) and moves the result to CPU.
        """
        if not self.ds.temp_file:
            return self.cached_model_inference[idx]

        z = v5_collate_keys_fast(self.model_keys_to_get, [self.ds[idx]]).to(
            self.model_optim_config["device"]
        )
        with torch.no_grad():
            return self.model(z)["paired"].cpu()

    def observation(self, idx):
        """Return entry ``[0, 0]`` of the paired model output for index ``idx``."""
        # even though the model outputs one paired dist for each receiver
        # they should be identical
        return self.model_inference_at_observation_idx(idx)[0, 0]

    def fix_particles(self):
        """Normalise particle thetas with ``torch_pi_norm_pi``."""
        self.particles[:, 0] = torch_pi_norm_pi(self.particles[:, 0])

    def predict(self, our_state, dt, noise_std=None):
        """Advance theta by ``dt * dtheta/dt`` and add noise to the particles.

        ``our_state`` is accepted for interface compatibility and not used here.
        ``noise_std`` defaults to ``[[0.1, 0.001]]`` when None.
        """
        if noise_std is None:
            noise_std = torch.tensor([[0.1, 0.001]])
        self.particles[:, 0] += dt * self.particles[:, 1]
        add_noise(self.particles, noise_std=noise_std, generator=self.generator)

    def update(self, z):
        """Reweight particles by the model output ``z`` at each particle's theta bin."""
        # z is not the raw observation, but the processed model output
        # (assumed 1-D over theta bins: torch.take indexes the flattened tensor)
        theta_bin = theta_phi_to_bins(self.particles[:, 0], nbins=z.shape[0])
        prob_theta_given_observation = torch.take(z, theta_bin)

        self.weights *= prob_theta_given_observation
        self.weights += 1.0e-30  # avoid round-off to zero
        self.weights /= torch.sum(self.weights)  # normalize

    def metrics(self, trajectory):
        """Score ``trajectory`` against the dataset's craft ground-truth thetas."""
        return dual_radio_mse_theta_metrics(
            trajectory, self.ds.craft_ground_truth_thetas
        )

    def trajectory(self, **kwargs):
        """Run the base-class trajectory and expose theta mean/variance per step.

        Adds ``"craft_theta"`` (``mu[0]``) and ``"P_theta"`` (``var[0]``) to every
        entry.
        """
        trajectory = super().trajectory(**kwargs)
        for x in trajectory:
            x["craft_theta"] = x["mu"][0]
            x["P_theta"] = x["var"][0]
        return trajectory

In [None]:
# Open a recorded (non-synthetic) dataset with the shared empirical distribution.
# The same call is repeated at the top of the following cell.
# Alternative locations used during development:
#   "/mnt/4tb_ssd/nosig_data/wallarrayv3_2024_08_21_03_09_04_nRX2_rx_circle_spacing0p05075.zarr"
#   precompute_cache="/home/mouse9911/precompute_cache_chunk16_sept/"
ds_with_model_empirical = v5spfdataset(
    "/mnt/md2/cache/nosig_data/wallarrayv3_2024_08_21_03_09_04_nRX2_rx_circle_spacing0p05075.zarr",
    precompute_cache="/mnt/md2/cache/precompute_cache_3p5_chunk1/",
    empirical_data_fn="/home/mouse9911/gits/spf/empirical_dists/full.pkl",
    nthetas=nthetas,
    paired=True,
    ignore_qc=True,
    skip_fields={"signal_matrix"},
)

In [None]:
import numpy as np

# Recorded dataset opened with the shared empirical distribution.
# Alternatives used during development:
#   "/mnt/4tb_ssd/nosig_data/wallarrayv3_2024_08_21_03_09_04_nRX2_rx_circle_spacing0p05075.zarr"
#   precompute_cache="/home/mouse9911/precompute_cache_chunk16_sept/"
#   the fake dataset: temp_ds_fn with precompute_cache=tmpdirname
ds_with_model_empirical = v5spfdataset(
    "/mnt/md2/cache/nosig_data/wallarrayv3_2024_08_21_03_09_04_nRX2_rx_circle_spacing0p05075.zarr",
    precompute_cache="/mnt/md2/cache/precompute_cache_3p5_chunk1/",
    empirical_data_fn="/home/mouse9911/gits/spf/empirical_dists/full.pkl",
    nthetas=nthetas,
    paired=True,
    ignore_qc=True,
    skip_fields={"signal_matrix"},
)

# NN-backed dual-radio particle filter on that dataset.
# Alternatives used during development:
#   config_fn / checkpoint_fn under
#     /home/mouse9911/gits/spf/nov2_checkpoints/paired_checkpoints_inputdo0p3/
#   inference_cache="/mnt/4tb_ssd/inference_cache/"
pf = PFSingleThetaDualRadioNN(
    ds_with_model_empirical,
    checkpoint_fn="/home/mouse9911/gits/spf/checkpoints/march16/paired_wd0p02_gains_vehicle_0p2dropout_noroverbounceREAL_lowdrop_x2/best.pth",
    config_fn="/home/mouse9911/gits/spf/checkpoints/march16/paired_wd0p02_gains_vehicle_0p2dropout_noroverbounceREAL_lowdrop_x2/config.yml",
    inference_cache="/mnt/md2/cache/inference",
    device="cpu",
)

# Run the filter, then show its metrics as the cell output.
trajectory = pf.trajectory(
    N=512 * 16 * 8,
    mean=torch.tensor([[0, 0]]),
    std=torch.tensor([[1, 0.1]]),
    debug=True,
    return_particles=False,
)
metrics = pf.metrics(trajectory=trajectory)
metrics

In [None]:
import matplotlib.pyplot as plt


def plot_traj(ds, traj_paired):
    """
    Draw a two-panel summary of a paired filter trajectory and return the figure.

    Top panel: for receivers 0 and 1, the dataset's mean phase (scatter) and the
    ground-truth phi (dashed line), both limited to the trajectory length.
    Bottom panel: the craft ground-truth theta, the filter mean (``mu[0]``) and a
    band of one standard deviation (``sqrt(var[0])``) around it.
    """
    fig, panels = plt.subplots(2, 1, figsize=(10, 10))
    phi_panel, theta_panel = panels[0], panels[1]

    # Dotted guide lines at +/- pi/2 on the theta panel.
    guide_color = (0.7, 0.7, 0.7)
    for guide_y in (torch.pi / 2, -torch.pi / 2):
        theta_panel.axhline(y=guide_y, ls=":", c=guide_color)

    num_steps = len(traj_paired)
    for rx_idx, rx_color in enumerate(["blue", "orange"]):
        measured_phi = ds.mean_phase[f"r{rx_idx}"]
        phi_panel.scatter(
            range(min(num_steps, measured_phi.shape[0])),
            measured_phi[:num_steps],
            label=f"r{rx_idx} estimated phi",
            s=1.0,
            alpha=0.1,
            color=rx_color,
        )
        phi_panel.plot(
            ds.ground_truth_phis[rx_idx][:num_steps],
            color=rx_color,
            label=f"r{rx_idx} perfect phi",
            linestyle="dashed",
        )

    theta_panel.plot(
        torch_pi_norm_pi(ds.craft_ground_truth_thetas),
        label="craft gt theta",
        linestyle="dashed",
    )

    # Filter mean and standard deviation of the first state component per step.
    pf_mean = torch.hstack([step["mu"][0] for step in traj_paired])
    pf_std = torch.sqrt(torch.hstack([step["var"][0] for step in traj_paired]))
    step_count = pf_mean.shape[0]

    theta_panel.fill_between(
        torch.arange(step_count),
        pf_mean - pf_std,
        pf_mean + pf_std,
        label="PF-std",
        color="red",
        alpha=0.2,
    )
    theta_panel.scatter(range(step_count), pf_mean, label="PF-x", color="orange", s=0.5)

    phi_panel.set_ylabel("radio phi")
    phi_panel.legend()
    phi_panel.set_title(f"Radio 0 & 1")

    theta_panel.legend()
    theta_panel.set_xlabel("time step")
    theta_panel.set_ylabel("Theta between target and receiver craft")
    return fig


# Trailing ";" keeps the returned Figure from being echoed as the cell result, which
# can otherwise render the same plot twice under the inline matplotlib backend.
plot_traj(ds_with_model_empirical, traj_paired=trajectory);

In [None]:
# Show the shapes of the craft ground-truth thetas and the ground-truth phis of the
# dataset currently bound to `ds`.
ds.craft_ground_truth_thetas.shape, ds.ground_truth_phis.shape

In [None]:
# Number of entries in the trajectory currently bound to `trajectory`.
len(trajectory)

In [None]:
import numpy as np

from spf.filters.particle_dualradio_filter import PFSingleThetaDualRadio

# Recorded dataset for the non-NN dual-radio particle filter.
# NOTE(review): this cell reads from /mnt/4tb_ssd and
# /home/mouse9911/precompute_cache_chunk16_sept/, while the NN cells above read the
# same recording from /mnt/md2 - confirm which location is current.
# Alternative used during development: temp_ds_fn with precompute_cache=tmpdirname.
ds_with_model_empirical = v5spfdataset(
    "/mnt/4tb_ssd/nosig_data/wallarrayv3_2024_08_21_03_09_04_nRX2_rx_circle_spacing0p05075.zarr",
    precompute_cache="/home/mouse9911/precompute_cache_chunk16_sept/",
    empirical_data_fn="/home/mouse9911/gits/spf/empirical_dists/full.pkl",
    nthetas=nthetas,
    paired=True,
    ignore_qc=True,
    skip_fields={"signal_matrix"},
)

# This rebinds `pf`, `trajectory` and `metrics` from the NN-based run above.
pf = PFSingleThetaDualRadio(ds_with_model_empirical)

trajectory = pf.trajectory(
    N=512 * 16 * 8,
    mean=torch.tensor([[0, 0]]),
    std=torch.tensor([[1, 0.1]]),
    debug=True,
    return_particles=False,
)
metrics = pf.metrics(trajectory=trajectory)
metrics

In [None]:
# Inspect the observation that the filter currently bound to `pf` returns for
# dataset index 16.
pf.observation(16)

In [None]:
# Scratch check: compare the shape of theta_phi_to_p_vec's result with the shapes of
# the filter's weights and particles.
# NOTE(review): `pf` is rebound here to pfs[0], a PFSingleThetaSingleRadioNN. That
# class, as defined in this notebook, never sets `cached_empirical_dist`; unless the
# ParticleFilter base class provides it, this cell raises AttributeError - confirm.
# The first line is a bare expression whose value is discarded.
pfs[0].particles
pf = pfs[0]
theta_phi_to_p_vec(
    pf.particles[:, 0],
    pf.observation(16),
    pf.cached_empirical_dist,
).shape, pf.weights.shape, pf.particles.shape

In [None]:
# Shape of the model's "paired" output for `single_example`.
# NOTE(review): `model` and `single_example` are only assigned in cells further down
# this notebook, so this cell fails on a fresh top-to-bottom run.
model(single_example)["paired"].shape

In [None]:
from spf.dataset.spf_dataset import v5_collate_keys_fast
from spf.filters.filters import theta_phi_to_bins


# Manual walk-through of the lookup the NN filters perform in update(): take the
# paired model output for dataset index 16, bin the particle thetas, and read the
# output value at each particle's bin.
# NOTE(review): `keys_to_get`, `optim_config` and `model` are only assigned in the
# "dev particle filter from nn" cell below, so this cell depends on execution order.
single_example = v5_collate_keys_fast(keys_to_get, [ds[16]]).to(optim_config["device"])
dist = model(single_example)["paired"][0, 0].cpu()
print(pf.cached_empirical_dist.shape, dist.shape[0])
particles = pf.particles
theta_bin = theta_phi_to_bins(particles[:, 0], nbins=dist.shape[0])
print(theta_bin)
x = torch.take(dist, theta_bin)
x

In [None]:
# dev particle filter from nn
from spf.dataset.spf_dataset import v5_collate_keys_fast
from spf.model_training_and_inference.models.single_point_networks_inference import (
    convert_datasets_config_to_inference,
    load_model_and_config_from_config_fn_and_checkpoint,
)
import torch
from spf.scripts.train_single_point import global_config_to_keys_used, load_dataloaders

from tqdm import tqdm

# Inference settings. Fall back to the CPU when CUDA is unavailable so the cell does
# not fail on machines without a GPU.
optim_config = {
    "device": "cuda" if torch.cuda.is_available() else "cpu",
    "dtype": torch.float32,
}

# load model
# The device is passed explicitly so the model lives where the inputs are moved to,
# and the model is put in eval mode, matching what the filter classes in this
# notebook do before running inference.
model, config = load_model_and_config_from_config_fn_and_checkpoint(
    config_fn="/home/mouse9911/gits/spf/nov2_checkpoints/paired_checkpoints_inputdo0p3/config.yml",
    checkpoint_fn="/home/mouse9911/gits/spf/nov2_checkpoints/paired_checkpoints_inputdo0p3/best.pth",
    device=optim_config["device"],
)
model.eval()

datasets_config = convert_datasets_config_to_inference(
    config["datasets"],
    ds_fn="/mnt/4tb_ssd/nosig_data/wallarrayv3_2024_08_21_10_30_58_nRX2_bounce_spacing0p05075.zarr",
)

# NOTE: this rebinds the notebook-level `ds`.
ds = v5spfdataset(
    datasets_config["train_paths"][0],
    nthetas=config["global"]["nthetas"],
    ignore_qc=True,
    precompute_cache=datasets_config["precompute_cache"],
    empirical_data_fn=datasets_config["empirical_data_fn"],
    paired=True,
    skip_fields=set(["signal_matrix"]),
)

# Run the model one sample at a time on (at most) the first 20 samples.
keys_to_get = global_config_to_keys_used(global_config=config["global"])
outputs = []
with torch.no_grad():
    for idx in tqdm(range(min(20, len(ds)))):
        single_example = v5_collate_keys_fast(keys_to_get, [ds[idx]]).to(
            optim_config["device"]
        )
        outputs.append(model(single_example))

# Stack the per-sample outputs along a new leading dimension.
results = {
    "single": torch.vstack([output["single"].unsqueeze(0) for output in outputs])
}
if "paired" in outputs[0]:
    results["paired"] = torch.vstack(
        [output["paired"].unsqueeze(0) for output in outputs]
    )

In [None]:
# Display the "global" section of the loaded model config.
config["global"]

In [None]:
# Display the device/dtype settings used for the manual inference above.
optim_config

In [None]:
from matplotlib import pyplot as plt


def run_and_plot_single_radio_PF(ds, trajectories):
    """
    Plot single-radio particle-filter trajectories against dataset ground truth.

    Despite the name, no filter is run here; ``trajectories`` must already hold one
    trajectory per receiver (indices 0 and 1), each entry providing
    ``"observation"``, ``"theta"`` and ``"P_theta"``.

    Per receiver (one column each):
      row 0: mean phase, ground-truth phi and the filter observations,
      row 1: ground-truth theta (raw and reduced) with the filter estimate +/- std,
      row 2: histogram of the z-scores of the estimate vs. the reduced ground truth.

    Returns the matplotlib figure.
    """
    fig, ax = plt.subplots(3, 2, figsize=(10, 15))

    for rx_idx in [0, 1]:
        ax[1, rx_idx].axhline(y=np.pi / 2, ls=":", c=(0.7, 0.7, 0.7))
        ax[1, rx_idx].axhline(y=-np.pi / 2, ls=":", c=(0.7, 0.7, 0.7))

        trajectory = trajectories[rx_idx]
        n = len(trajectory)

        zs = np.array([x["observation"] for x in trajectory])

        ax[0, rx_idx].scatter(
            range(min(n, ds.mean_phase[f"r{rx_idx}"].shape[0])),
            ds.mean_phase[f"r{rx_idx}"][:n],
            label=f"r{rx_idx} estimated phi",
            s=1.0,
            alpha=1.0,
            color="red",
        )
        ax[0, rx_idx].plot(ds.ground_truth_phis[rx_idx][:n], label="perfect phi")
        ax[0, rx_idx].plot(zs, label="zs")

        # Read each dataset item once and reuse it for both the raw and the reduced
        # ground-truth curves (the dataset was previously indexed twice per step).
        gt_theta = [
            ds[idx][rx_idx]["ground_truth_theta"] for idx in range(min(n, len(ds)))
        ]
        ax[1, rx_idx].plot(gt_theta, label=f"r{rx_idx} gt theta")
        reduced_gt_theta = np.array(
            [reduce_theta_to_positive_y(theta) for theta in gt_theta]
        )
        ax[1, rx_idx].plot(
            reduced_gt_theta,
            label=f"r{rx_idx} reduced gt theta",
        )

        xs = np.array([x["theta"] for x in trajectory])
        stds = np.sqrt(np.array([x["P_theta"] for x in trajectory]))
        # Small constant in the denominator avoids division by zero.
        zscores = (xs - reduced_gt_theta) / (stds + 0.0001)
        print(zscores)

        # These curves come from a particle filter, so label them "PF" (they were
        # labelled "EKF"), consistent with plot_traj's "PF-x" / "PF-std".
        ax[1, rx_idx].plot(xs, label="PF-x", color="orange")
        ax[1, rx_idx].fill_between(
            np.arange(xs.shape[0]),
            xs - stds,
            xs + stds,
            label="PF-std",
            color="orange",
            alpha=0.2,
        )

        ax[0, rx_idx].set_ylabel("radio phi")

        ax[0, rx_idx].legend()
        ax[0, rx_idx].set_title(f"Radio {rx_idx}")
        ax[1, rx_idx].legend()
        ax[1, rx_idx].set_xlabel("time step")
        ax[1, rx_idx].set_ylabel("radio theta")

        ax[2, rx_idx].hist(zscores.reshape(-1), bins=25)
        ax[2, rx_idx].set_xlabel("z-score of theta estimate")
        ax[2, rx_idx].set_ylabel("count")
    return fig

In [None]:
# Trailing ";" keeps the returned Figure from being echoed as the cell result, which
# can otherwise render the same plot twice under the inline matplotlib backend.
# NOTE(review): on a top-to-bottom run `ds` is the dataset opened in the
# "dev particle filter from nn" cell, while `trajectories` was computed on the fake
# dataset - confirm these are meant to be plotted together.
run_and_plot_single_radio_PF(ds, trajectories=trajectories);