In [None]:
import os
import pickle
from spf.dataset.fake_dataset import (
    create_empirical_dist_for_datasets,
    create_fake_dataset,
    fake_yaml,
)
from spf.dataset.spf_dataset import v5spfdataset
from spf.scripts.create_empirical_p_dist import (
    apply_symmetry_rules_to_heatmap,
    get_heatmap,
)


# Scratch directory for the synthetic dataset, its precompute cache and the
# pickled probability maps. NOTE: the name shadows the builtin `dir()`; it is
# kept because the cells below refer to it.
dir = "./temp_delme"
# Create it up front (idempotent), so Restart & Run All works on a fresh checkout
# and the later `open(f"{dir}/...", "wb")` never hits a missing directory.
os.makedirs(dir, exist_ok=True)


def noise1_n128_obits2(n=128, noise=0.3, orbits=2, nthetas=65, empirical_nthetas=50):
    """Create a synthetic dataset and the empirical distribution derived from it.

    The function name is kept for existing callers. The defaults generate
    ``n=128`` samples with ``noise=0.3`` over ``orbits=2``.

    Steps:
      1. ``create_fake_dataset`` writes the dataset (read back below as
         ``{fn}.zarr``) under the global ``dir``.
      2. The dataset is opened once with ``v5spfdataset`` using
         ``precompute_cache=dir`` so everything gets segmented up front. The
         instance itself is discarded.
      3. ``create_empirical_dist_for_datasets`` builds the empirical
         distribution pickle from that dataset.

    The noise level and orbit count are encoded in the filename. Runs with
    different settings therefore never share a dataset path or precompute-cache
    entry.

    Args:
        n: number of samples passed to ``create_fake_dataset``.
        noise: noise level passed to ``create_fake_dataset``.
        orbits: number of orbits passed to ``create_fake_dataset``.
        nthetas: ``nthetas`` used when opening the dataset.
        empirical_nthetas: ``nthetas`` passed to
            ``create_empirical_dist_for_datasets``.

    Returns:
        ``(dir, empirical_pkl_fn, fn)``, where ``fn`` is the dataset path
        without the ``.zarr`` suffix.
    """
    # Replace "." so the only dot in the final path is the ".zarr" suffix.
    noise_tag = str(noise).replace(".", "p")
    fn = f"{dir}/perfect_circle_n{n}_noise{noise_tag}_orbits{orbits}"
    create_fake_dataset(
        filename=fn, yaml_config_str=fake_yaml, n=n, noise=noise, orbits=orbits
    )

    v5spfdataset(  # make sure everything gets segmented here
        fn,
        nthetas=nthetas,
        ignore_qc=True,
        precompute_cache=dir,
        paired=True,
        skip_fields=set(["signal_matrix"]),
    )

    empirical_pkl_fn = create_empirical_dist_for_datasets(
        datasets=[f"{fn}.zarr"], precompute_cache=dir, nthetas=empirical_nthetas
    )
    return dir, empirical_pkl_fn, fn


def heatmap(ds_fn, bins=50):
    """Build the symmetrised heatmap for one dataset and pickle it.

    Args:
        ds_fn: dataset path (as returned by ``noise1_n128_obits2``).
        bins: number of bins passed to ``get_heatmap``.

    Returns:
        Path of the written pickle, ``{dir}/full_p.pkl``. The pickle holds a
        dict ``{"full_p": <heatmap after apply_symmetry_rules_to_heatmap>}``.
    """
    ds = v5spfdataset(
        ds_fn,
        precompute_cache=dir,
        nthetas=65,
        skip_fields=set(["signal_matrix"]),
        paired=True,
        ignore_qc=True,
        gpu=False,
    )
    # Named `full_p` rather than `heatmap` so the local does not shadow this function.
    full_p = apply_symmetry_rules_to_heatmap(get_heatmap([ds], bins=bins))
    full_p_fn = f"{dir}/full_p.pkl"
    # The context manager guarantees the pickle is flushed and closed before
    # the path is handed to the cells that read it back.
    with open(full_p_fn, "wb") as f:
        pickle.dump({"full_p": full_p}, f)
    return full_p_fn


# Build the synthetic dataset and its empirical distribution. The first return
# value is just the global `dir`. It goes into a named variable rather than `_`:
# once user code assigns `_`, IPython stops updating it with the last cell output.
dataset_dir, empirical_pkl_fn, ds_fn = noise1_n128_obits2()

In [None]:
# Compute and pickle the symmetrised heatmap for the synthetic dataset.
full_p_fn = heatmap(ds_fn=ds_fn)

In [None]:
# Display the paths of the two pickled distributions.
(empirical_pkl_fn, full_p_fn)

In [None]:
# Load both distributions back. These pickles were written by this notebook, so
# unpickling is trusted here. Never point these paths at untrusted files:
# pickle.load can execute arbitrary code.
with open(empirical_pkl_fn, "rb") as f:
    a = pickle.load(f)
with open(full_p_fn, "rb") as f:
    b = pickle.load(f)

In [None]:
import matplotlib.pyplot as plt

# Symmetrised empirical map stored under a["0.050750"]["r"]["sym"]. The cell
# output shows the image together with its maximum and shape.
empirical_sym = a["0.050750"]["r"]["sym"]
plt.imshow(empirical_sym), empirical_sym.max(), empirical_sym.shape

In [None]:
# Heatmap from `heatmap()`, transposed for display, with its maximum and shape.
full_p = b["full_p"]
plt.imshow(full_p.T), full_p.max(), full_p.shape

In [None]:
import torch
from spf.filters.filters import (
    ParticleFilter,
    add_noise,
    fix_particles_single,
    theta_phi_to_p_vec,
)
from spf.rf import reduce_theta_to_positive_y, torch_pi_norm_pi


class PFSingleThetaSingleRadio(ParticleFilter):
    """
    Particle filter over the angle observed by one receiver of ``ds``.

    particle state is [ theta, dtheta/dt]
    """

    def __init__(self, ds, rx_idx, full_p_fn):
        """
        Args:
            ds: dataset exposing ``mean_phase``, ``ground_truth_thetas`` and
                ``get_empirical_dist`` (all used below).
            rx_idx: index of the receiver whose phase is filtered.
            full_p_fn: path to a pickle holding ``{"full_p": ...}``. It is
                loaded into ``self.full_p`` as an alternative observation model
                (see ``update``).
        """
        self.ds = ds
        self.rx_idx = rx_idx
        with open(full_p_fn, "rb") as f:
            self.full_p = torch.as_tensor(pickle.load(f)["full_p"])

    def observation(self, idx):
        """Return the mean phase recorded for this receiver at step ``idx``."""
        return self.ds.mean_phase[f"r{self.rx_idx}"][idx]

    def fix_particles(self):
        """Replace the particle array with ``fix_particles_single(particles)``."""
        self.particles = fix_particles_single(self.particles)

    def predict(self, our_state, dt, noise_std):
        """Constant-velocity step: theta += dt * dtheta/dt, then add noise.

        ``our_state`` is not used by this filter. When ``noise_std`` is None,
        it defaults to ``[[0.1, 0.001]]`` (theta, dtheta/dt).
        """
        if noise_std is None:
            noise_std = torch.tensor([[0.1, 0.001]])
        self.particles[:, 0] += dt * self.particles[:, 1]
        add_noise(self.particles, noise_std=noise_std)

    def update(self, z):
        """Reweight particles by the likelihood of observed phase ``z``.

        The likelihood comes from the dataset's (transposed) empirical
        distribution for this receiver. The alternative model is
        ``theta_phi_to_p_vec(self.particles[:, 0], z, self.full_p)``.
        """
        self.weights *= theta_phi_to_p_vec(
            self.particles[:, 0],
            z,
            self.ds.get_empirical_dist(self.rx_idx).T,
        )
        self.weights += 1.0e-30  # avoid round-off to zero
        self.weights /= torch.sum(self.weights)  # normalize

    def metrics(self, trajectory):
        """Mean squared angular error of predicted theta vs. ground truth.

        Ground truth is passed through ``reduce_theta_to_positive_y``. The
        difference is wrapped with ``torch_pi_norm_pi`` before squaring.
        """
        pred_theta = torch.hstack([x["mu"][0] for x in trajectory])
        ground_truth_reduced_theta = torch.as_tensor(
            reduce_theta_to_positive_y(self.ds.ground_truth_thetas[self.rx_idx])
        )
        return {
            "mse_theta": (
                torch_pi_norm_pi(ground_truth_reduced_theta - pred_theta) ** 2
            )
            .mean()
            .item()
        }

    def trajectory(self, **kwargs):
        """Run the base-class trajectory and add ``theta``/``P_theta`` per step."""
        trajectory = super().trajectory(**kwargs)
        for x in trajectory:
            x["theta"] = x["mu"][0]
            x["P_theta"] = x["var"][0]
        return trajectory

In [None]:
from spf.model_training_and_inference.models.run_filters_on_data import (
    run_PF_single_theta_single_radio,
)


# Opened with the empirical distribution attached. Not referenced directly by
# the run below, but kept for the (commented) plotting call at the end.
ds = v5spfdataset(
    ds_fn,
    precompute_cache=dir,
    nthetas=65,
    skip_fields=set(["signal_matrix"]),
    empirical_data_fn=empirical_pkl_fn,
    paired=True,
    ignore_qc=True,
    gpu=False,
)
args = {
    "ds_fn": ds_fn,
    "precompute_fn": dir,
    "empirical_pkl_fn": empirical_pkl_fn,
    "N": 1024 * 4,
    "theta_err": 0.01,
    "theta_dot_err": 0.01,
    "full_p_fn": full_p_fn,
}
# Seed torch's global RNG so the threshold check below is reproducible
# (assumes the filter draws its particles/noise from torch).
torch.manual_seed(0)
results = run_PF_single_theta_single_radio(**args)
# Guard against a vacuous pass: an empty result list would skip every assert.
assert len(results) > 0, "run_PF_single_theta_single_radio returned no results"
for result in results:
    mse_theta = result["metrics"]["mse_theta"]
    assert mse_theta < 0.05, (
        f"mse_theta={mse_theta!r} >= 0.05 (rx_idx={result.get('rx_idx')!r})"
    )
# TODO: plot_single_theta_single_radio(ds, full_pkl_fn=full_p_fn)

In [None]:
import torch


# Re-open the dataset with a slimmer field set for the locally defined filter.
ds = v5spfdataset(
    ds_fn,
    nthetas=65,
    ignore_qc=True,
    precompute_cache=dir,
    paired=True,
    snapshots_per_session=1,
    readahead=True,
    skip_fields=set(
        [
            "windowed_beamformer",
            "all_windows_stats",
            "downsampled_segmentation_mask",
            "signal_matrix",
            "simple_segmentations",
        ]
    ),
    empirical_data_fn=empirical_pkl_fn,
)
theta_err = 0.01
theta_dot_err = 0.01
N = 1024 * 4
# Seed torch's global RNG so the metrics are reproducible across runs
# (assumes the filter draws its particles/noise from torch).
torch.manual_seed(0)
metrics = []
for rx_idx in [0, 1]:
    pf = PFSingleThetaSingleRadio(ds=ds, rx_idx=rx_idx, full_p_fn=full_p_fn)
    trajectory = pf.trajectory(
        # Float literals so the initial state has the same dtype as `std`.
        mean=torch.tensor([[0.0, 0.0]]),
        std=torch.tensor([[2, 0.1]]),
        noise_std=torch.tensor([[theta_err, theta_dot_err]]),
        return_particles=False,
        N=N,
    )
    metrics.append(
        {
            "type": "single_theta_single_radio",
            "ds_fn": ds_fn,
            "rx_idx": rx_idx,
            "theta_err": theta_err,
            "theta_dot_err": theta_dot_err,
            "N": N,
            "metrics": pf.metrics(trajectory=trajectory),
        }
    )
# Show the per-receiver results as the cell output.
metrics