In [None]:
import tempfile
from spf.dataset.fake_dataset import create_fake_dataset, fake_yaml
from spf.dataset.spf_dataset import v5spfdataset

# Synthetic dataset parameters; the dataset written below is loaded in the next
# cell as `simulated_circle_dss`.
n = 1025
noise = 1.0
nthetas = 65
orbits = 4

# Hold on to the TemporaryDirectory object itself: tempfile removes the directory
# once that object is cleaned up.
tmpdir = tempfile.TemporaryDirectory()
tmpdirname = tmpdir.name
ds_fn = "{}/sample_dataset_for_ekf_n{}_noise{}".format(tmpdirname, n, noise)

create_fake_dataset(
    filename=ds_fn,
    yaml_config_str=fake_yaml,
    n=n,
    noise=noise,
    orbits=orbits,
)

In [None]:
# Load the synthetic dataset written above (paired, QC skipped, raw signal matrix
# not loaded), caching precomputed data in the same temporary directory.
simulated_circle_dss = [
    v5spfdataset(
        ds_fn,
        nthetas=nthetas,
        paired=True,
        ignore_qc=True,
        skip_fields={"signal_matrix"},
        precompute_cache=tmpdirname,
    )
]

In [None]:
from spf.dataset.spf_dataset import v5spfdataset
import glob
import torch


def load_dss(
    fns,
    nthetas=65,
    precompute_cache="/home/mouse9911/precompute_cache_chunk16_fresh/",
    gpu=True,
    n_parallel=20,
):
    """Open every dataset path in ``fns`` as a ``v5spfdataset`` with QC ignored.

    The keyword arguments default to the values this notebook was originally run
    with and are forwarded unchanged to ``v5spfdataset``. Override them (for example
    ``precompute_cache`` on another machine) instead of editing the body.

    Args:
        fns: iterable of dataset paths.
        nthetas: forwarded as ``nthetas``.
        precompute_cache: forwarded as ``precompute_cache``.
        gpu: forwarded as ``gpu``.
        n_parallel: forwarded as ``n_parallel``.

    Returns:
        list of datasets in the same order as ``fns`` (empty if ``fns`` is empty).
    """
    return [
        v5spfdataset(
            fn,
            nthetas=nthetas,
            ignore_qc=True,
            precompute_cache=precompute_cache,
            gpu=gpu,
            n_parallel=n_parallel,
        )
        for fn in fns
    ]


# Other glob patterns used previously with load_dss:
#   /mnt/md0/spf/2d_wallarray_v2_data/june_fix/wallarrayv3_2024_06_0*circle*.zarr
#   /mnt/md0/spf/2d_wallarray_v2_data/june_fix/wallarrayv3_2024_06_0*bounce*.zarr
#   /mnt/md0/spf/2d_wallarray_v2_data/june/wallarrayv3_2024_07_29_22_59_1*
real_bounce_glob = (
    "/mnt/md0/spf/2d_wallarray_v2_data/june/wallarrayv3_2024_07_30_*0p06*.zarr"
)
# glob.glob returns matches in filesystem-dependent order; sort so the dataset
# order (and everything concatenated from it below) is reproducible.
real_bounce_fns = sorted(glob.glob(real_bounce_glob))
if not real_bounce_fns:
    # Fail here rather than with an IndexError on real_circle_dss[0] later.
    raise FileNotFoundError(f"no datasets match {real_bounce_glob}")
real_bounce_dss = load_dss(real_bounce_fns)
# The "circle" analysis cells below currently reuse the same datasets.
real_circle_dss = real_bounce_dss

In [None]:
# The mean of an empty tensor is NaN; check that non-finite entries can be
# overwritten in place with a sentinel value (5).
x = torch.hstack((torch.tensor([]).mean(), torch.tensor(0.1)))
x.masked_fill_(~torch.isfinite(x), 5)
x

In [None]:
# hist(h(x), z)  # where x is the state and z is the angle
# hist(h(x), z)  # where x is the state and z is the beamformer output (peaks, width)


# Radio 0, then radio 1, separately
# Datasets: simulated, real circle, real bounce

In [None]:
# Segmentation result for sample 159 of radio 0 in the first real dataset.
segmentation_r0 = real_circle_dss[0].segmentation["segmentation_by_receiver"]["r0"]
segmentation_r0[159]["simple_segmentation"]

In [None]:
# Mean phase of radio 0 at sample 159 of the first real dataset.
mean_phase_r0 = real_circle_dss[0].mean_phase["r0"]
mean_phase_r0[159]

In [None]:
# numpy is otherwise first imported in a later cell; import it here so this cell
# also runs on a fresh kernel. Shows the indices where radio 0's mean phase in the
# first real dataset is NaN or +/-inf.
import numpy as np

np.where(~np.isfinite(real_circle_dss[0].mean_phase["r0"]))

In [None]:
import matplotlib.pyplot as plt
import numpy as np


def plot_hx_z(dss, tag, bins=50, hist_range=None):
    """Plot measured phase z against the state x (theta) and h(x) (phi) for radios 0 and 1.

    Uses a 1x6 figure, with three panels per radio ``ridx`` in {0, 1}:
      * ``axs[3*ridx + 0]``: scatter of ground-truth theta vs the measured mean phase
        (z), vs the ground-truth phi (h(x)), and vs ``phi * sqrt(|phi|)``;
      * ``axs[3*ridx + 1]``: 2D histogram of ground-truth theta vs z;
      * ``axs[3*ridx + 2]``: 2D histogram of ground-truth phi vs z.

    Args:
        dss: datasets providing ``ground_truth_thetas[ridx]``,
            ``ground_truth_phis[ridx]`` and ``mean_phase[f"r{ridx}"]``. Each field is
            concatenated across all datasets.
        tag: prefix for the figure title (used verbatim; it is not formatted).
        bins: ``bins`` argument for ``np.histogram2d``.
        hist_range: optional ``range`` argument for ``np.histogram2d``.
            ``None`` (the default) autodetects the range from the finite data, so bin
            edges may differ between radios. Pass e.g.
            ``[[-np.pi, np.pi], [-np.pi, np.pi]]`` to get histograms that share edges
            and can be summed bin by bin.

    Returns:
        list of the two theta-vs-z count histograms (index = radio), as returned by
        ``np.histogram2d``: rows are theta bins, columns are z bins.
    """

    def _finite_histogram2d(x, y):
        # np.histogram2d raises on NaN/inf when it autodetects its range, so keep
        # only the samples where both coordinates are finite.
        keep = np.isfinite(x) & np.isfinite(y)
        return np.histogram2d(x[keep], y[keep], bins=bins, range=hist_range)

    fig, axs = plt.subplots(1, 6, figsize=(30, 5), sharex=True, sharey=True)
    heatmaps = []
    for ridx in [0, 1]:
        scatter_ax = axs[3 * ridx]
        theta_hist_ax = axs[3 * ridx + 1]
        phi_hist_ax = axs[3 * ridx + 2]

        ground_truth_thetas = np.hstack([ds.ground_truth_thetas[ridx] for ds in dss])
        mean_phase = np.hstack([ds.mean_phase[f"r{ridx}"] for ds in dss])
        ground_truth_phis = np.hstack([ds.ground_truth_phis[ridx] for ds in dss])

        n_bad = int((~np.isfinite(mean_phase)).sum())
        if n_bad:
            print(
                f"radio {ridx}: {n_bad} non-finite mean_phase samples "
                "left out of the histograms"
            )

        scatter_ax.scatter(ground_truth_thetas, mean_phase, s=1, label="z", alpha=0.3)
        scatter_ax.scatter(ground_truth_thetas, ground_truth_phis, s=1, label="h(x)")
        scatter_ax.scatter(
            ground_truth_thetas,
            ground_truth_phis * np.sqrt(np.abs(ground_truth_phis)),
            s=1,
            label="phi*sqrt(|phi|)",
        )
        scatter_ax.set_xlabel("x/theta (-pi,+pi)")
        scatter_ax.set_ylabel("z/phi")
        scatter_ax.set_title(f"radio {ridx}")
        scatter_ax.legend()

        heatmap, xedges, yedges = _finite_histogram2d(ground_truth_phis, mean_phase)
        phi_hist_ax.imshow(
            heatmap.T,
            extent=[xedges[0], xedges[-1], yedges[0], yedges[-1]],
            origin="lower",
        )
        phi_hist_ax.set_title("h(x)/gt_phi vs z/phi")

        heatmap, xedges, yedges = _finite_histogram2d(ground_truth_thetas, mean_phase)
        heatmaps.append(heatmap)
        theta_hist_ax.imshow(
            heatmap.T,
            extent=[xedges[0], xedges[-1], yedges[0], yedges[-1]],
            origin="lower",
        )
        theta_hist_ax.set_title("x/theta vs z/phi")

    fig.suptitle(f"{tag}: z vs h(x)")
    return heatmaps


# `tag` is used verbatim as the figure-title prefix (it is never formatted), so it
# must not contain a "{ridx}" placeholder; each radio's panels get their own
# "radio N" title. Assign the result so the histogram arrays are not dumped as
# cell output.
# plot_hx_z(simulated_circle_dss, "simulated circle data")
circle_heatmaps = plot_hx_z(real_circle_dss, "real circle data")
# heatmaps = plot_hx_z(real_bounce_dss, "real bounce data")

In [None]:
# Combine both radios' theta-vs-z histograms (rows: theta bins, columns: z/phi bins)
# and fold them with np.flip, which (with no axis) reverses BOTH axes. Each theta
# bin is thus merged with its mirrored bin, with the z axis mirrored as well.
# NOTE(review): this assumes the autodetected histogram ranges are symmetric about
# 0 and identical for the two radios — TODO confirm.
heatmaps = plot_hx_z(real_bounce_dss, "real bounce data")
h = heatmaps[0] + heatmaps[1]
n_theta_bins = h.shape[0]
# Derive the split from the data instead of hard-coding 25 (which assumed bins=50).
assert n_theta_bins % 2 == 0, "symmetrization needs an even number of theta bins"
n_half = n_theta_bins // 2
half = h[:n_half] + np.flip(h[n_half:])
full = np.vstack([half, np.flip(half)])
# Normalize each theta row so it sums to 1. A row with no samples would otherwise
# divide by zero and become NaN; give such rows a uniform distribution instead.
row_sums = full.sum(axis=1, keepdims=True)
safe_row_sums = np.where(row_sums > 0, row_sums, 1.0)
full = np.where(row_sums > 0, full / safe_row_sums, 1.0 / full.shape[1])

In [None]:
import pickle

# Persist the row-normalized table under the key "full_p". The `with` block
# guarantees the file is flushed and closed, even if pickling fails.
with open("full_p.pkl", "wb") as f:
    pickle.dump({"full_p": full}, f)

In [None]:
# Rows of `full` are theta bins and columns are z/phi bins (from plot_hx_z's
# theta-vs-z histograms), so imshow puts theta on the y axis.
fig, ax = plt.subplots()
im = ax.imshow(full)
ax.set(
    xlabel="z/phi bin",
    ylabel="x/theta bin",
    title="symmetrized, row-normalized z vs theta histogram",
)
fig.colorbar(im, ax=ax);

In [None]:
def ds_beamformer_to_stats(ds, ridx=0):
    """Per-sample beamformer peak positions for receiver ``ridx`` of ``ds``.

    For every sample, the windowed beamformer is averaged over its first axis and the
    argmax of that mean is taken twice: once over all rows, and once over only the
    rows selected by the downsampled segmentation mask.

    Args:
        ds: indexable dataset. ``ds[idx][ridx]`` must provide ``"windowed_beamformer"``
            (2D; indexed as ``[mask, :]``) and ``"downsampled_segmentation_mask"``
            (reshaped to 1D and converted with ``.numpy()``, presumably a torch tensor
            with one entry per beamformer row). ``ds.ground_truth_thetas[ridx]`` is
            returned unchanged.
        ridx: receiver index.

    Returns:
        ``(thetas, peaks, segmented_peaks)``. ``peaks`` and ``segmented_peaks`` are
        lists with one argmax bin index per sample. If a sample's segmentation mask
        selects no rows, its ``segmented_peaks`` entry is ``-1``. Otherwise the mean
        over zero rows would be all-NaN, and its argmax would silently report bin 0.
    """
    thetas = ds.ground_truth_thetas[ridx]
    peaks = []
    segmented_peaks = []
    for idx in range(len(ds)):
        # Fetch the sample once; each ds[idx] may re-read data from storage.
        sample = ds[idx][ridx]
        beamformer = sample["windowed_beamformer"]
        seg_mask = (
            sample["downsampled_segmentation_mask"].reshape(-1).numpy().astype(bool)
        )
        peaks.append(np.argmax(beamformer.mean(axis=0, keepdims=True)))
        if seg_mask.any():
            segmented_beamformer = beamformer[seg_mask, :]
            segmented_peaks.append(
                np.argmax(segmented_beamformer.mean(axis=0, keepdims=True))
            )
        else:
            segmented_peaks.append(-1)
    return thetas, peaks, segmented_peaks


# ds_beamformer_to_stats(simulated_circle_dss[0])

In [None]:
import numpy as np

# `beamformer` and `segmented_beamformer` only exist as locals inside
# ds_beamformer_to_stats, so on a fresh kernel this cell (and the two imshow cells
# after it) raised NameError. Rebuild them here for one sample, the same way the
# function body does.
inspect_ds = simulated_circle_dss[0]
inspect_idx = 0
inspect_ridx = 0
inspect_sample = inspect_ds[inspect_idx][inspect_ridx]
beamformer = inspect_sample["windowed_beamformer"]
inspect_seg_mask = inspect_sample["downsampled_segmentation_mask"].reshape(-1)
segmented_beamformer = beamformer[inspect_seg_mask.numpy().astype(bool), :]
np.argmax(beamformer.mean(axis=0, keepdims=True))

In [None]:
# Mean over the first axis; keepdims keeps the result 2D so imshow can draw it.
beam_mean = beamformer.mean(axis=0, keepdims=True)
plt.imshow(beam_mean)

In [None]:
# Same view restricted to the rows kept by the segmentation mask.
segmented_beam_mean = segmented_beamformer.mean(axis=0, keepdims=True)
plt.imshow(segmented_beam_mean)