In [None]:
from spf.dataset.spf_dataset import v5spfdataset
import glob


def load_dss(
    fns,
    nthetas=65,
    precompute_cache="/home/mouse9911/precompute_cache_chunk16_fresh/",
    gpu=True,
):
    """Open every recording in ``fns`` as a ``v5spfdataset``.

    Each dataset is opened with ``ignore_qc=True`` and with ``"signal_matrix"`` in
    ``skip_fields``. The remaining options were previously hard-coded and are now
    keyword parameters whose defaults keep the old behaviour.

    Args:
        fns: iterable of dataset paths.
        nthetas: forwarded to ``v5spfdataset(nthetas=...)``.
        precompute_cache: forwarded to ``v5spfdataset(precompute_cache=...)``.
            NOTE(review): the default is a user-specific absolute path.
        gpu: forwarded to ``v5spfdataset(gpu=...)``.

    Returns:
        list of datasets, in the same order as ``fns``.
    """
    return [
        v5spfdataset(
            fn,
            nthetas=nthetas,
            ignore_qc=True,
            precompute_cache=precompute_cache,
            gpu=gpu,
            skip_fields=["signal_matrix"],
        )
        for fn in fns
    ]


# Root directory of the wall-array recordings loaded in this notebook.
DATA_DIR = "/mnt/md0/spf/2d_wallarray_v2_data/june_fix_nosig"

# glob.glob returns matches in arbitrary (filesystem) order; sort so the dataset
# order, and everything concatenated across datasets, is reproducible.
circle_fns = sorted(glob.glob(f"{DATA_DIR}/wallarrayv3_2024_06_0*circle*.zarr"))
bounce_fns = sorted(glob.glob(f"{DATA_DIR}/wallarrayv3_2024_06_0*bounce*.zarr"))

# Fail here with a clear message instead of later inside np.hstack([]).
assert circle_fns, f"no circle recordings matched under {DATA_DIR}"
assert bounce_fns, f"no bounce recordings matched under {DATA_DIR}"

real_circle_dss = load_dss(circle_fns)
real_bounce_dss = load_dss(bounce_fns)

In [None]:
from matplotlib import pyplot as plt
import torch

from spf.scripts.create_empirical_p_dist import (
    apply_symmetry_rules_to_heatmap,
    get_heatmap_for_radio,
)


def plot_empirical(dss, bins):
    """Plot per-radio empirical heatmaps, raw and symmetry-folded, in a 2x3 grid.

    Row 0 holds the raw heatmaps and row 1 the output of
    ``apply_symmetry_rules_to_heatmap``. The columns are radio 0, radio 1 and their
    average. Every heatmap is row-normalised (each row divided by its sum) before it
    is plotted and returned.

    NOTE(review): a later notebook cell defines its own
    ``apply_symmetry_rules_to_heatmap``. Once that cell has run, the imported helper
    is shadowed and this function will use the local version on re-run.

    Args:
        dss: datasets, passed through to ``get_heatmap_for_radio``.
        bins: bin count, passed through to ``get_heatmap_for_radio``.

    Returns:
        dict mapping "r0" / "r1" / "r" to {"nosym": array, "sym": array}.
    """
    # The raw histograms do not depend on the symmetry flag, so compute them once
    # rather than once per loop iteration.
    raw_r0, _, _ = get_heatmap_for_radio(dss, 0, bins)
    raw_r1, _, _ = get_heatmap_for_radio(dss, 1, bins)
    raw_r = (raw_r0 + raw_r1) / 2
    panels = [("r0", "Radio0", raw_r0), ("r1", "Radio1", raw_r1), ("r", "Radio0+1", raw_r)]
    extent = [-torch.pi, torch.pi, -torch.pi, torch.pi]

    _, axs = plt.subplots(2, 3, figsize=(15, 10))
    heatmaps = {"r0": {}, "r1": {}, "r": {}}
    for row_idx, symmetry in enumerate([False, True]):
        sym_key = "sym" if symmetry else "nosym"
        for col_idx, (radio_key, label, raw) in enumerate(panels):
            heatmap = apply_symmetry_rules_to_heatmap(raw) if symmetry else raw
            heatmap = heatmap / heatmap.sum(axis=1, keepdims=True)
            heatmaps[radio_key][sym_key] = heatmap
            ax = axs[row_idx, col_idx]
            # origin="lower" puts row 0 of heatmap.T at the bottom so it matches the
            # bottom-to-top y extent; the default origin="upper" draws it flipped.
            ax.imshow(heatmap.T, extent=extent, origin="lower")
            ax.set_title(f"{label},sym={symmetry}")
    return heatmaps


# Row-normalised heatmaps for the circle recordings (64 bins), keyed by "r0"/"r1"/"r"
# and then "nosym"/"sym". The same call draws the 2x3 comparison grid.
heatmaps = plot_empirical(real_circle_dss, bins=64)

In [None]:
# Raw radio-0 result for the circle recordings with 65 bins (odd, unlike the 64 used
# above). The next cell works on element 0 of the returned tuple.
hm = get_heatmap_for_radio(real_circle_dss, 0, 65)

In [None]:
import math

import numpy as np

# Scratch: fold the radio-0 heatmap in half along axis 0 so the fold also works for
# an odd bin count (the middle row is shared by both halves).
raw_hm = hm[0].copy()
n_bins = raw_hm.shape[0]
# Both slices have ceil(n_bins / 2) rows. np.flip with no axis reverses both axes.
folded_half = raw_hm[: math.ceil(n_bins / 2)] + np.flip(raw_hm[n_bins // 2 :])

# Separate axes, so the folded plot does not simply cover the raw one.
fig, (ax_raw, ax_half) = plt.subplots(1, 2, figsize=(10, 5))
ax_raw.imshow(raw_hm.T, origin="lower")
ax_raw.set_title("raw heatmap (radio 0)")
ax_half.imshow(folded_half.T, origin="lower")
ax_half.set_title("folded half");
# TODO: mirror `folded_half` back to full width, e.g.
# np.vstack([folded_half[:-1], np.flip(folded_half)]), and normalise rows.

In [None]:
import matplotlib.pyplot as plt
import numpy as np


def plot_hx_z(dss, tag, bins=50):
    """Compare measured mean phase (z) with ground truth (theta, h(x)) for both radios.

    Draws a 1x6 figure with three panels per radio:
      1. scatter of ground-truth theta against mean phase, ground-truth phi, and
         phi * sqrt(|phi|);
      2. 2D histogram of ground-truth theta vs. mean phase;
      3. 2D histogram of ground-truth phi vs. mean phase.

    Args:
        dss: datasets exposing ``ground_truth_thetas[ridx]``,
            ``ground_truth_phis[ridx]`` and ``mean_phase[f"r{ridx}"]`` for ridx 0 and 1.
        tag: prefix for the figure title.
        bins: ``bins`` argument for ``np.histogram2d``.

    Returns:
        list with the raw-count theta-vs-phase histogram for radio 0 and radio 1.
        (The previous version always returned an empty list.)
    """
    fig, axs = plt.subplots(1, 6, figsize=(30, 5))
    heatmaps = []
    for ridx in [0, 1]:
        ground_truth_thetas = np.hstack([ds.ground_truth_thetas[ridx] for ds in dss])
        mean_phase = np.hstack([ds.mean_phase[f"r{ridx}"] for ds in dss])
        ground_truth_phis = np.hstack([ds.ground_truth_phis[ridx] for ds in dss])
        ax_scatter, ax_theta_hist, ax_phi_hist = axs[3 * ridx : 3 * ridx + 3]

        # Panel 1: measured phase and ground-truth phi against ground-truth theta.
        ax_scatter.scatter(ground_truth_thetas, mean_phase, s=1, label="z", alpha=0.3)
        ax_scatter.scatter(ground_truth_thetas, ground_truth_phis, s=1, label="h(x)")
        ax_scatter.scatter(
            ground_truth_thetas,
            ground_truth_phis * np.sqrt(np.abs(ground_truth_phis)),
            s=1,
            # Distinct label: this series is a rescaled h(x), not h(x) itself.
            label="h(x)*sqrt(|h(x)|)",
        )
        ax_scatter.set_xlabel("x/theta (-pi,+pi)")
        ax_scatter.set_ylabel("z/phi")
        ax_scatter.legend()
        ax_scatter.set_title(f"radio {ridx}")

        # Panel 2: ground-truth theta vs. measured phase (this histogram is returned).
        theta_hist, xedges, yedges = np.histogram2d(
            ground_truth_thetas, mean_phase, bins=bins
        )
        heatmaps.append(theta_hist)
        ax_theta_hist.imshow(
            theta_hist.T,
            extent=[xedges[0], xedges[-1], yedges[0], yedges[-1]],
            origin="lower",
        )
        ax_theta_hist.set_title("x/theta vs z/phi")
        ax_theta_hist.set_xlabel("G.T. theta")
        ax_theta_hist.set_ylabel("phi")

        # Panel 3: ground-truth phi vs. measured phase.
        phi_hist, xedges, yedges = np.histogram2d(
            ground_truth_phis, mean_phase, bins=bins
        )
        ax_phi_hist.imshow(
            phi_hist.T,
            extent=[xedges[0], xedges[-1], yedges[0], yedges[-1]],
            origin="lower",
        )
        ax_phi_hist.set_title("h(x)/gt_phi vs z/phi")
        ax_phi_hist.set_xlabel("G.T. phi")
        ax_phi_hist.set_ylabel("phi")

    fig.suptitle(f"{tag}: z vs h(x)")
    return heatmaps

In [None]:
# `tag` is a plain title prefix (not an f-string), and both radios share one figure,
# so it carries no radio index. The trailing ";" suppresses the returned list's repr.
plot_hx_z(real_circle_dss, "real circle data");

In [None]:
# NOTE: this rebinds `heatmaps`, replacing the dict returned by plot_empirical above.
# The tag is a plain title prefix; a literal "{ridx}" here would show up in the title.
heatmaps = plot_hx_z(real_bounce_dss, "real bounce data")

In [None]:
import numpy as np


def get_heatmap(dss, bins=50, hist_range=None):
    """Histogram ground-truth theta vs. mean phase, pooled over radios 0 and 1.

    The samples of both radios are pooled before binning, so every count uses the
    same bin edges. The previous version histogrammed each radio with its own
    data-derived edges and summed the counts bin-by-bin, which mixed bins that cover
    different theta/phase intervals whenever the two radios' value ranges differed.

    Args:
        dss: datasets exposing ``ground_truth_thetas[ridx]`` and
            ``mean_phase[f"r{ridx}"]`` for ridx 0 and 1.
        bins: ``bins`` argument for ``np.histogram2d``.
        hist_range: optional ``range`` for ``np.histogram2d``, e.g.
            ``[[-np.pi, np.pi], [-np.pi, np.pi]]``. If None, the edges span the
            pooled data.

    Returns:
        2D array of counts; axis 0 indexes theta bins, axis 1 mean-phase bins.
    """
    # Both comprehensions iterate in the same (radio, dataset) order, so thetas and
    # phases stay paired element-wise.
    thetas = np.hstack([ds.ground_truth_thetas[ridx] for ridx in (0, 1) for ds in dss])
    phases = np.hstack([ds.mean_phase[f"r{ridx}"] for ridx in (0, 1) for ds in dss])
    heatmap, _, _ = np.histogram2d(thetas, phases, bins=bins, range=hist_range)
    return heatmap


def apply_symmetry_rules_to_heatmap(h):
    """Fold a 2D heatmap onto its symmetries, then row-normalise it.

    Steps:
      1. Flip the second half of axis 0 on both axes (``np.flip`` with no axis) and
         add it onto the first half.
      2. Add that half to its own mirror along axis 0.
      3. Mirror the half back out to full length (flipped on both axes).
      4. Divide each row by its sum (a row summing to 0 yields NaN).

    The fold point was hard-coded to 25, so only 50-row inputs worked. It is now
    ``n_rows // 2``, and results for 50 rows are unchanged.

    NOTE(review): this rebinds the name ``apply_symmetry_rules_to_heatmap`` imported
    from ``spf.scripts.create_empirical_p_dist`` earlier in the notebook. Cells run
    after this one that call that name get this version.

    Args:
        h: 2D array with an even number of rows.

    Returns:
        array with the same shape as ``h``.

    Raises:
        ValueError: if ``h`` has an odd number of rows.
    """
    n_rows = h.shape[0]
    if n_rows % 2:
        raise ValueError(f"expected an even number of rows, got {n_rows}")
    mid = n_rows // 2
    half = h[:mid] + np.flip(h[mid:])
    half = half + np.flip(half, axis=0)
    full = np.vstack([half, np.flip(half)])
    return full / full.sum(axis=1, keepdims=True)


# Symmetry-folded, row-normalised theta-vs-phase heatmap for the bounce recordings,
# pooled over both radios.
heatmap = apply_symmetry_rules_to_heatmap(get_heatmap(real_bounce_dss))

In [None]:
import matplotlib.pyplot as plt

# Transpose and use origin="lower" like the other heatmap plots: axis 0 of
# `heatmap` (theta) runs along x and the phase bins increase upward.
fig, ax = plt.subplots()
ax.imshow(heatmap.T, origin="lower")
ax.set_xlabel("theta bin")
ax.set_ylabel("phase bin")
ax.set_title("bounce data: symmetrised, row-normalised heatmap");

In [None]:
import pickle

# Persist the symmetrised, row-normalised heatmap computed above. The old code
# referenced `full`, which exists only inside apply_symmetry_rules_to_heatmap, so it
# raised NameError. `with` also closes the file handle, which the old code leaked.
with open("full_p.pkl", "wb") as f:
    pickle.dump({"full_p": heatmap}, f)