In [None]:
from spf.dataset.fake_dataset import create_fake_dataset, fake_yaml

# Create the fake dataset named `ds_fn` from `fake_yaml`; the next cell opens it
# with v5spfdataset.
ds_fn = "test_circle"
create_fake_dataset(fake_yaml, ds_fn)

In [None]:
import os
import tempfile

import torch

from spf.dataset.spf_dataset import v5spfdataset

# Open the dataset written above as paired receivers with 65 theta bins.
# The precompute cache goes in the platform temp directory instead of a
# hard-coded "/tmp/" (absent on some systems). The trailing separator is kept
# in case the library builds cache paths by string concatenation.
ds = v5spfdataset(
    ds_fn,
    nthetas=65,
    ignore_qc=True,
    precompute_cache=os.path.join(tempfile.gettempdir(), ""),
    segment_if_not_exist=True,
    paired=True,
)

In [None]:
import matplotlib.pyplot as plt

# One panel per receiver: the precomputed windowed beamformer averaged over
# axis 1 and transposed.
# NOTE(review): presumably this puts the 65 theta bins on the y-axis and
# samples on the x-axis — confirm.
# origin="lower" matches the orientation used by the later plots in this notebook.
fig, axs = plt.subplots(1, 2)
for ax, rx_name in zip(axs, ["r0", "r1"]):
    ax.imshow(
        ds.precomputed_zarr[rx_name].windowed_beamformer[:].mean(axis=1).T,
        origin="lower",
    )
    ax.set_title(f"{rx_name} mean windowed beamformer")

In [None]:
# Show the stored hierarchy for receiver r0 (presumably a zarr group).
receiver_r0 = ds.z.receivers["r0"]
receiver_r0.tree()

In [None]:
# Ground-truth theta of receivers 0 and 1 for sample 0, expressed in units of pi.
first_pair = ds[0]
tuple(first_pair[rx]["ground_truth_theta"] / torch.pi for rx in (0, 1))

In [None]:
# Receiver orientation (rx_theta_in_pis) of receivers 0 and 1 for sample 0.
first_pair = ds[0]
tuple(first_pair[rx]["rx_theta_in_pis"] for rx in (0, 1))

In [None]:
# Stack receiver 0's craft_ground_truth_theta for every sample, row-wise.
craft_theta_rows = []
for sample_idx in range(len(ds)):
    craft_theta_rows.append(ds[sample_idx][0]["craft_ground_truth_theta"])
craft_ground_truth_thetas = torch.vstack(craft_theta_rows)
craft_ground_truth_thetas.size()

In [None]:
# NOTE(review): `x` is not defined anywhere in this notebook, so `plt.imshow(x)`
# raised NameError on a fresh kernel. Best guess: show the craft ground-truth
# thetas just stacked. Replace with the intended array if different.
plt.imshow(craft_ground_truth_thetas.T, aspect="auto")

In [None]:
# One-hot ground-truth theta bin per sample, using 65 bins over 2*pi.
# craft_ground_truth_thetas holds one angle per sample, so int() on it cannot
# work. Instead, scatter writes a 1 into each sample's own bin column.
# Subtracting pi maps angle 0 to bin 32, the same binning the later cell uses.
n_bins = 65
bin_width = 2 * torch.pi / n_bins
gt_bin_idx = (
    ((craft_ground_truth_thetas - torch.pi) // bin_width) % n_bins
).to(torch.long)
gt = torch.zeros(craft_ground_truth_thetas.shape[0], n_bins).scatter(
    1, gt_bin_idx, 1.0
)

In [None]:
from spf.rf import rotate_dist

# Per receiver: mean windowed beamformer, rotated by rx_theta (pis -> radians),
# transposed, then normalized so every column sums to 1.
# NOTE(review): after `.T` the rows are presumably the 65 theta bins and the
# columns the samples — confirm.
rs = []
for rx_idx in range(2):
    r = torch.as_tensor(
        ds.precomputed_zarr[f"r{rx_idx}"].windowed_beamformer[:].mean(axis=1)
    )
    r = rotate_dist(
        r,
        torch.as_tensor(ds.z.receivers[f"r{rx_idx}"].rx_theta_in_pis[:]) * torch.pi,
    ).T
    r /= r.sum(axis=0)
    rs.append(r)

# One-hot ground-truth bin per sample. Subtracting pi before binning maps
# angle 0 to bin 32. The row count comes from the data instead of a
# hard-coded 50, so other dataset sizes work.
n_samples = craft_ground_truth_thetas.shape[0]
gt = torch.scatter(
    torch.zeros(n_samples, 65),
    1,
    (((craft_ground_truth_thetas - torch.pi) // (2 * torch.pi / 65)) % 65).to(
        torch.long
    ),
    1,
)
# Mean mass the joint (rx0 * rx1) distribution puts on the true bin, minus the
# largest value of the joint over all samples and bins.
(gt * rs[0].T * rs[1].T).sum(axis=1).mean() - (rs[0] * rs[1]).max()

In [None]:
# Receiver 0's normalized distribution kept only at each sample's ground-truth bin.
plt.imshow(rs[0].T * gt, origin="lower")

In [None]:
import matplotlib.pyplot as plt

from spf.rf import rotate_dist

# 2x3 grid of mean windowed beamformers.
# Row 0 rotates by 0 (angle * row_idx). Row 1 rotates each receiver by
# rx_theta_in_pis * pi. Column 2 is the elementwise product of the two
# receivers' panels.
fig, axs = plt.subplots(2, 3, figsize=(7, 6))
for row_idx in range(2):
    # Named `panels` (not `rs`) so this cell does not overwrite the normalized
    # `rs` list built and plotted by an earlier cell.
    panels = []
    for rx_idx in range(2):
        r = torch.as_tensor(
            ds.precomputed_zarr[f"r{rx_idx}"].windowed_beamformer[:].mean(axis=1)
        )
        r = rotate_dist(
            r,
            torch.as_tensor(ds.z.receivers[f"r{rx_idx}"].rx_theta_in_pis[:])
            * torch.pi
            * row_idx,
        ).T
        panels.append(r)
    panels.append(panels[0] * panels[1])
    for idx, panel in enumerate(panels):
        ax = axs[row_idx, idx]
        ax.imshow(panel, origin="lower")
        ax.set_yticks([0, 32, 64], ["-pi", "0", "+pi"])
        if idx < 2:
            title = f"rx_idx:{idx}"
            if row_idx == 1:
                title = "craft " + title
        elif row_idx == 0:
            title = "joint rx_idx 1*2"
        else:
            title = "joint craft rx_idx 1*2"
        ax.set_title(title)

# Lay out once, after every panel has its title and ticks
# (this call used to sit inside the row loop).
fig.tight_layout()

In [None]:
 ds[0][0]["weighted_beamformer"][0].shape

In [None]:
import matplotlib.pyplot as plt

from spf.rf import rotate_dist

# Display receiver 0's weighted beamformer for sample 0 after rotate_dist.
# NOTE(review): `0 * ... + 1` multiplies rx_theta_in_pis by zero, so the rotation
# angle is effectively the constant 1 (radian, if rotate_dist takes radians as the
# `* torch.pi` used elsewhere suggests). This looks like a debugging toggle —
# confirm the intended angle.
plt.imshow(
    rotate_dist(
        ds[0][0]["weighted_beamformer"][0],
        0 * ds[0][0]["rx_theta_in_pis"].reshape(1, 1) * torch.pi + 1,
    )
)

In [None]:
# Receiver 0's orientation (rx_theta_in_pis) for sample 0.
sample0_rx0 = ds[0][0]
sample0_rx0["rx_theta_in_pis"]

In [None]:
import matplotlib.pyplot as plt
import torch
from spf.rf import get_peaks_for_2rx, torch_pi_norm

rx_idx = 0

# Fetch each sample for this receiver once. The plots below previously
# re-indexed `ds` several times per sample.
rx_samples = [ds[idx][rx_idx] for idx in range(len(ds))]


def _beamformer_peak_angles(sample):
    """Return the beamformer peak(s) of `sample` as unwrapped angles in radians.

    Bin b of the 65-bin beamformer maps to angle (b / 65 - 0.5) * 2 * pi.
    """
    peaks = torch.tensor(get_peaks_for_2rx(sample["weighted_beamformer"][0, 0]))
    return (peaks / 65 - 0.5) * 2 * torch.pi


# Computed once and shared by the "uncorrected" and "corrected" series.
peak_angles = [_beamformer_peak_angles(sample) for sample in rx_samples]

fig, axs = plt.subplots(1, 2, figsize=(14, 4))
axs[0].plot(
    [sample["ground_truth_theta"] for sample in rx_samples],
    label=f"gt-theta-rx{rx_idx}",
)
axs[0].plot(
    [sample["craft_ground_truth_theta"] for sample in rx_samples],
    label=f"craft-theta-rx{rx_idx}",
)
axs[0].set_title("ground truth")
axs[0].legend()

axs[1].plot(
    [torch_pi_norm(angles) for angles in peak_angles],
    label="uncorrected",
)
# "corrected" adds the receiver's rx_theta (pis -> radians) before wrapping.
axs[1].plot(
    [
        torch_pi_norm(angles + sample["rx_theta_in_pis"] * torch.pi)
        for angles, sample in zip(peak_angles, rx_samples)
    ],
    label="corrected",
)
axs[1].set_title("beamformer peak angle")
axs[1].legend()

In [None]:
import matplotlib.pyplot as plt
import torch
from spf.rf import get_peaks_for_2rx, rotate_dist, torch_pi_norm

rx_idx = 0

# Fetch each paired sample once instead of re-indexing `ds` for every series.
samples = [ds[idx] for idx in range(len(ds))]

fig, axs = plt.subplots(1, 2, figsize=(14, 4))
axs[0].plot([s[0]["ground_truth_theta"] for s in samples], label="gt-theta-rx0")
axs[0].plot([s[1]["ground_truth_theta"] for s in samples], label="gt-theta-rx1")
axs[0].plot(
    [s[1]["craft_ground_truth_theta"] for s in samples], label="craft-theta-rx1"
)
axs[0].plot(
    [s[0]["craft_ground_truth_theta"] for s in samples], label="craft-theta-rx0"
)
axs[0].legend()


def _rotated_peak_angles(rx_sample):
    """Rotate one receiver's beamformer by its rx_theta and return wrapped peak angles.

    Bin b of the 65-bin beamformer maps to angle (b / 65 - 0.5) * 2 * pi.
    """
    rotated = rotate_dist(
        rx_sample["weighted_beamformer"][0],
        rx_sample["rx_theta_in_pis"].reshape(1, 1) * torch.pi,
    )
    # rotate_dist is used with 2-D (rows, bins) inputs in this notebook; take the
    # single row so get_peaks_for_2rx gets a 1-D beamformer, as at its other
    # call sites (`weighted_beamformer[0, 0]`).
    # NOTE(review): assumes `rotated` is (1, 65) here — confirm.
    peaks = torch.tensor(get_peaks_for_2rx(rotated[0]))
    return torch_pi_norm((peaks / 65 - 0.5) * 2 * torch.pi)


axs[1].plot(
    [_rotated_peak_angles(s[rx_idx]) for s in samples],
    label=f"rotated-peaks-rx{rx_idx}",
)
axs[1].legend()

In [None]:
from spf.rf import rotate_dist

# Compare receiver `rx_idx` (set in an earlier cell) three ways: raw stacked
# weighted beamformers, rotated one sample at a time, and rotated as one batch.
# Both rotations use (rx_theta + rx_heading) in pis, converted to radians.
rx_samples = [ds[idx][rx_idx] for idx in range(len(ds))]

stacked_beamformer = torch.vstack([s["weighted_beamformer"][0] for s in rx_samples])

per_sample_rotated = torch.vstack(
    [
        rotate_dist(
            s["weighted_beamformer"][0],
            (s["rx_theta_in_pis"] + s["rx_heading_in_pis"]).reshape(1, 1) * torch.pi,
        )
        for s in rx_samples
    ]
)

batch_rotated = rotate_dist(
    torch.vstack([s["weighted_beamformer"][0] for s in rx_samples]),
    torch.vstack(
        [s["rx_theta_in_pis"][0] + s["rx_heading_in_pis"][0] for s in rx_samples]
    )
    * torch.pi,
)

fig, axs = plt.subplots(1, 3)
panels = (
    (stacked_beamformer, "raw"),
    (per_sample_rotated, "rotated per sample"),
    (batch_rotated, "rotated as batch"),
)
for ax, (image, title) in zip(axs, panels):
    ax.imshow(image.T, origin="lower")
    ax.set_yticks([0, 32, 64], ["-pi", "0", "+pi"])
    ax.set_title(title)

In [None]:
# Receiver 0's heading and orientation (both in units of pi) for sample 0.
sample0_rx0 = ds[0][0]
sample0_rx0["rx_heading_in_pis"], sample0_rx0["rx_theta_in_pis"]

In [None]:
# rx_theta_in_pis[0] of receiver `rx_idx` for every sample, stacked row-wise.
a = torch.vstack(
    tuple(ds[sample_i][rx_idx]["rx_theta_in_pis"][0] for sample_i in range(len(ds)))
)

In [None]:
# weighted_beamformer[0] of receiver `rx_idx` for every sample, stacked row-wise.
b = torch.vstack(
    tuple(
        ds[sample_i][rx_idx]["weighted_beamformer"][0] for sample_i in range(len(ds))
    )
)

In [None]:
# Sizes of the stacked angles and beamformers.
a.size(), b.size()

In [None]:
# Rotating b's first row on its own must match row 0 of the batched rotation.
# allclose collapses the comparison to a single bool and tolerates float
# rounding differences between the single and batched code paths.
torch.allclose(rotate_dist(b[[0]], a[0]), rotate_dist(b, a)[[0]])

In [None]:
import torch
from spf.rf import get_peaks_for_2rx, torch_pi_norm

# Beamformer peak angle(s) for sample 0 / receiver 0 next to its ground truth.
# Bin b of the 65-bin beamformer maps to angle (b / 65 - 0.5) * 2 * pi.
first_rx0 = ds[0][0]
peak_bins = torch.tensor(get_peaks_for_2rx(first_rx0["weighted_beamformer"][0, 0]))
peak_angles = torch_pi_norm((peak_bins / 65 - 0.5) * 2 * torch.pi)
peak_angles, first_rx0["ground_truth_theta"]

In [None]:
# Receiver 0's weighted beamformer (leading index 0) for sample 0, unrotated.
sample0_beamformer = ds[0][0]["weighted_beamformer"][0]
plt.imshow(X=sample0_beamformer)

In [None]:
# Full size of receiver 0's weighted_beamformer tensor for sample 0.
ds[0][0]["weighted_beamformer"].size()

In [None]:
# Synthetic one-hot rows for exercising rotate_dist:
#   lines[i]       is hot at bin i % 65
#   lines_2x[i]    is hot at bin (2 * i) % 65
#   angle_diffs[i] is i bins expressed in radians, i.e. i * 2*pi / 65
row_ids = torch.arange(100)
lines = torch.zeros(100, 65)
lines[row_ids, row_ids % 65] = 1.0
lines_2x = torch.zeros(100, 65)
lines_2x[row_ids, (2 * row_ids) % 65] = 1.0
# Compute in float64 and let the assignment cast, so each value is rounded once,
# exactly as when storing Python floats element by element.
angle_diffs = torch.zeros(100, 1)
angle_diffs[:, 0] = torch.arange(100, dtype=torch.float64) * 2 * torch.pi / 65

In [None]:
# Expect one angle per synthetic row.
angle_diffs.size()

In [None]:
# Unrotated synthetic one-hot rows.
plt.imshow(X=lines)

In [None]:
# Each row rotated by its own angle_diffs entry (expected to match lines_2x,
# which a later cell checks).
plt.imshow(X=rotate_dist(lines, angle_diffs))

In [None]:
# Round trip on real data: rotate the stacked beamformers `b` by the per-sample
# angles `a`, then by `-a`; the image should look like `b` itself.
# (The arguments were swapped before, which rotated the angles by the beamformer.)
plt.imshow(rotate_dist(rotate_dist(b, a), -a))

In [None]:
# Rotating forward then backward by the same angles must give back `lines`.
roundtrip = rotate_dist(rotate_dist(lines, angle_diffs), -angle_diffs)
torch.isclose(lines - roundtrip, torch.tensor([0.0])).all()

In [None]:
# Undoing an i-bin rotation of lines_2x (hot at 2i) should land on lines (hot at i).
unrotated_2x = rotate_dist(lines_2x, -angle_diffs)
torch.isclose(lines - unrotated_2x, torch.tensor([0.0])).all()

In [None]:
# Rotating row i of lines by i bins should move its hot bin from i to 2i.
rotated_lines = rotate_dist(lines, angle_diffs)
torch.isclose(lines_2x - rotated_lines, torch.tensor([0.0])).all()

In [None]:
# Row sums after L1-normalizing each rotated row.
rotated_lines = rotate_dist(lines, angle_diffs)
torch.nn.functional.normalize(rotated_lines, p=1.0, dim=1).sum(dim=1)

In [None]:
# Row sums when each rotation is offset by 0.1 rad (a non-integer number of bins).
shifted_angles = angle_diffs + 0.1
rotate_dist(lines, shifted_angles).sum(dim=1)