In [None]:
from spf.dataset.spf_dataset import v5spfdataset

# Recording to load. Other recordings previously inspected with this notebook
# (swap one in for `dataset_path` to look at it instead):
#   /mnt/md2/rovers/merged_test/rover_2025_03_29_16_16_59_nRX2_center_spacing0p043_tag_RO3.rover_2025_03_29_16_11_30_nRX1_circle_spacing0p05075_tag_RO2.zarr
#   /mnt/md2/rovers/merged/rover_2025_03_29_16_14_54_nRX2_diamond_spacing0p035_tag_RO1.rover_2025_03_29_16_11_30_nRX1_circle_spacing0p05075_tag_RO2.zarr
#   /mnt/md2/rovers/merged/rover_2025_03_29_17_34_29_nRX2_diamond_spacing0p035_tag_RO1.rover_2025_03_29_17_34_21_nRX1_circle_spacing0p05075_tag_RO2.zarr
#   /mnt/md2/2d_wallarray_v2_data/march_nuand/wallarrayv3_2025_03_29_19_03_59_nRX2_rx_random_circle_spacing0p025.zarr
#   /mnt/md2/2d_wallarray_v2_data/march/wallarrayv3_2025_03_15_11_18_30_nRX2_rx_random_circle_spacing0p043.zarr
#   /mnt/md2/2d_wallarray_v2_data/march_nuand/wallarrayv3_2025_03_29_14_52_33_nRX2_rx_random_circle_spacing0p025.zarr
dataset_path = "/mnt/md2/rovers/merged/rover_2025_03_22_19_55_11_nRX2_diamond_spacing0p035_tag_RO1.rover_2025_03_22_19_55_02_nRX1_circle_spacing0p05075_tag_RO2.zarr"

ds = v5spfdataset(
    dataset_path,
    nthetas=65,  # presumably the number of theta bins (later plots tick rows 0..64)
    ignore_qc=True,
    precompute_cache="/mnt/md2/cache/precompute_cache_3p5_chunk1/",
    gpu=True,
    snapshots_per_session=1,
    n_parallel=8,
    paired=True,
    segmentation_version=3.5,
)

In [None]:
import numpy as np

# Per-sample rx -> tx distance (mm) from receiver r0's logged positions.
# tx_pos / rx_pos are (2, n_samples): row 0 is x, row 1 is y.
tx_pos = np.vstack([ds.z.receivers.r0.tx_pos_x_mm, ds.z.receivers.r0.tx_pos_y_mm])
rx_pos = np.vstack([ds.z.receivers.r0.rx_pos_x_mm, ds.z.receivers.r0.rx_pos_y_mm])
# Square, sum over the x/y axis, THEN take the root. Taking the root before
# summing (the old form) gives |dx| + |dy|, the L1 distance, not the Euclidean one.
dists = np.sqrt(((rx_pos - tx_pos) ** 2).sum(axis=0))


# Sample window used by the plotting cells below.
start_idx = 10
end_idx = min(6000, len(ds))

In [None]:
import matplotlib.pyplot as plt


fig, axs = plt.subplots(1, 2, figsize=(12, 6))

# Subsample ~40 sessions from [start_idx, end_idx). Clamp the stride to >= 1:
# for windows shorter than 40 samples it used to be 0, and range() rejects a
# zero step with ValueError.
step = max(1, (end_idx - start_idx) // 40)
sessions = [ds[idx][0] for idx in range(start_idx, end_idx, step)]

# Left panel: tx / rx positions in the plane.
axs[0].scatter(
    [session["tx_pos_x_mm"] for session in sessions],
    [session["tx_pos_y_mm"] for session in sessions],
    label="tx",
)
axs[0].scatter(
    [session["rx_pos_x_mm"] for session in sessions],
    [session["rx_pos_y_mm"] for session in sessions],
    label="rx",
)
# x coordinates are plotted on the x axis (these two labels used to be swapped).
axs[0].set_xlabel("x_position(mm)")
axs[0].set_ylabel("y_position(mm)")
axs[0].set_title("Position plot")
axs[0].legend()

# Right panel: each coordinate over (subsampled) time.
for key, label in (
    ("tx_pos_x_mm", "tx_x"),
    ("tx_pos_y_mm", "tx_y"),
    ("rx_pos_x_mm", "rx_x"),
    ("rx_pos_y_mm", "rx_y"),
):
    axs[1].plot([session[key] for session in sessions], label=label)
# Sanity check: the raw zarr rx x track (from the distance cell) should overlay rx_x.
axs[1].plot(rx_pos[0][start_idx:end_idx:step], label="rx_xx")
axs[1].set_title("Position vs time")
axs[1].set_ylabel("value")
axs[1].set_xlabel("sample_idx (time)")
axs[1].legend()

In [None]:
import matplotlib.pyplot as plt
import torch
from spf.rf import torch_pi_norm
import numpy as np
import os


# show the beamformer vs expected


def normalize(x, dim):
    """Scale ``x`` so that its maximum along axis ``dim`` becomes 1.

    Alternatives tried earlier: returning ``x`` unchanged, or dividing by
    ``x.sum(axis=dim, keepdims=True)`` instead of the max.
    """
    peak = x.max(axis=dim, keepdims=True)
    return x / peak


for rx_idx in range(2):
    fig, axs = plt.subplots(5, 1, figsize=(14, 10))
    fig.suptitle(f"{os.path.basename(ds.zarr_fn)}:{start_idx}-{end_idx}")
    n_shown = end_idx - start_idx
    receiver = ds.z.receivers[f"r{rx_idx}"]

    # Panel 0: ground-truth and absolute theta, plus rx heading / rx theta
    # converted from units of pi (per their key names) to radians.
    ax = axs[0]
    ax.set_title(f"rx_idx{rx_idx} : theta")
    ax.plot(ds.ground_truth_thetas[rx_idx][start_idx:end_idx], label="ground truth")
    ax.plot(ds.absolute_thetas[rx_idx][start_idx:end_idx], label="absolute")
    ax.plot(
        torch_pi_norm(
            ds.cached_keys[rx_idx]["rx_heading_in_pis"][start_idx:end_idx] * torch.pi
        ),
        label="rx_heading",
    )
    ax.plot(
        ds.cached_keys[rx_idx]["rx_theta_in_pis"][start_idx:end_idx] * torch.pi,
        label="rx_theta",
    )
    ax.set_yticks([-torch.pi, 0, torch.pi], ["-pi", "0", "pi"])
    ax.legend()
    ax.set_xlim([0, n_shown])

    # Panel 1: windowed beamformer. Each window is scaled to peak 1 along axis 2,
    # averaged over axis 1, then every row is scaled to sum to 1.
    # NOTE(review): presumably (session, window, theta) before the mean — confirm.
    ax = axs[1]
    ax.set_title(f"rx_idx{rx_idx} : windowed beamformer")
    x = normalize(
        ds.precomputed_zarr[f"r{rx_idx}"].windowed_beamformer[:].astype(np.float32),
        2,
    ).mean(axis=1)[start_idx:end_idx]
    x /= x.sum(axis=1, keepdims=True)
    ax.imshow(x.T, origin="lower", aspect="auto")
    # Tick the first, middle and last theta bin. The old hard-coded [0, 32, 64]
    # was only correct when nthetas == 65.
    last_bin = x.shape[1] - 1
    ax.set_yticks([0, last_bin / 2, last_bin], ["-pi", "0", "pi"])

    # Panel 2: rx -> tx distance (the global `dists` from the distance cell).
    ax = axs[2]
    ax.set_title(f"rx_idx{rx_idx} : dist rx to tx")
    ax.plot(dists[start_idx:end_idx])
    ax.set_xlim([0, n_shown])

    # Panels 3-4: per-channel gain and RSSI of this receiver.
    for ax, field, title_name, label_name in (
        (axs[3], "gains", "gain", "Gain"),
        (axs[4], "rssis", "RSSI", "RSSI"),
    ):
        values = getattr(receiver, field)
        ax.set_title(f"rx_idx{rx_idx} : {title_name}")
        for channel in range(2):
            ax.plot(
                values[start_idx:end_idx, channel],
                label=f"R{rx_idx}-{label_name}{channel}",
            )
        ax.set_xlim([0, n_shown])
        ax.legend()
    fig.tight_layout()

In [None]:
from spf.sdrpluto.detrend import detrend_np
from spf.rf import beamformer_given_steering_nomean, get_phase_diff
import matplotlib.pyplot as plt
from spf.dataset.segmentation import default_segment_args, simple_segment

# look at radio rx_idx
rx_idx = 0
# load up the session_idx'th buffer in this recording and plot parts of it
session_idx = 4
offset = 0  # offset inside of buffer

data = ds[session_idx][rx_idx]

# Samples from `offset` to the end of the buffer, so the x values
# (np.arange(n)) always match the length of the sliced signal.
n = data["signal_matrix"].shape[3] - offset

# Detrended complex samples of both antennas of this receiver and the
# per-sample phase difference between them.
raw_radio_values = detrend_np(
    data["signal_matrix"][0, 0, :, offset : offset + n].numpy()
)
phase_difference = get_phase_diff(raw_radio_values)

fig, axs = plt.subplots(3, 1, figsize=(12, 6))
fig.suptitle(
    f"Raw signal + Phase offsets: {os.path.basename(ds.zarr_fn)} rx_idx{rx_idx} sessionidx:{session_idx}"
)
sample_idx = np.arange(n)
for ant_idx in range(2):
    ant_values = raw_radio_values[ant_idx]
    ax = axs[ant_idx]
    # Labels follow the antenna; the ant1 panel used to be labelled "ant0".
    ax.scatter(sample_idx, np.abs(ant_values), alpha=0.1, s=1, label=f"ant{ant_idx}")
    ax.scatter(sample_idx, ant_values.real, alpha=0.1, s=1, label=f"ant{ant_idx} r")
    ax.scatter(sample_idx, ant_values.imag, alpha=0.1, s=1, label=f"ant{ant_idx} i")
    ax.set_title(f"Raw signal ant{ant_idx}")
    ax.legend(markerscale=10)
for ax in axs:
    ax.set_xlabel("Sample# (time)")
axs[2].set_title("Phase estimates")
axs[2].scatter(sample_idx, phase_difference, s=1, alpha=0.01)

# Beamform with the steering vectors of the receiver these samples came from.
# The old code built one output per receiver from rx_idx's samples and only
# used entry 0, i.e. receiver 0's steering vectors whatever rx_idx was.
beamformer_output = beamformer_given_steering_nomean(
    steering_vectors=ds.steering_vectors[rx_idx],
    signal_matrix=raw_radio_values,
)

window_sds = []
for window in simple_segment(raw_radio_values, **default_segment_args)[
    "simple_segmentation"
]:
    print(window)
    # window["mean"] is a phase (it is averaged as `angles` in the weighted
    # mean-phase cell), so draw it on the phase panel, not the raw ant1 panel.
    axs[2].plot(
        [window["start_idx"], window["end_idx"]],
        [window["mean"], window["mean"]],
        color="red" if window["type"] == "signal" else "orange",
    )
    # Mean beamformer response over the window's samples (one value per row of
    # the beamformer output; presumably one per steering direction).
    window_sds.append(
        beamformer_output[:, window["start_idx"] : window["end_idx"]].mean(axis=1)
    )
window_sds = np.array(window_sds)
fig.tight_layout()

In [None]:
# Load the session once; each ds[...] access re-reads it.
session_data = ds[session_idx][rx_idx]
unnormed = session_data["windowed_beamformer"][0, 0, :]
# L1-normalise along dim 1 so each window's row has unit L1 norm.
normed = torch.nn.functional.normalize(unnormed, p=1, dim=1)
buffer_size = session_data["signal_matrix"].shape[-1]
# Raw samples per beamformer window (previously hard-coded as 2048 below).
window_size = buffer_size // normed.shape[0]
last_theta_bin = normed.shape[1] - 1

fig, axs = plt.subplots(2, 1, figsize=(10, 5))
axs[0].imshow(normed.T, aspect="auto", origin="lower")
axs[1].imshow(unnormed.T, aspect="auto", origin="lower")
# First / middle / last theta bin instead of 0/32/64, which assumed 65 bins.
y_ticks = [0, last_theta_bin / 2, last_theta_bin]
y_labels = [r"$-\pi$", "0", r"$\pi$"]
# One x tick every 100k raw samples, placed at the window containing it.
x_labels = np.arange(0, buffer_size, 100000)
x_indices = x_labels // window_size
titles = [
    f"Normalized Beamformer mean by window (size={window_size})",
    f"Beamformer mean by window (size={window_size})",
]
for ax, title in zip(axs, titles):
    ax.set_yticks(y_ticks)
    ax.set_yticklabels(y_labels)
    ax.set_xticks(x_indices)
    ax.set_xticklabels(x_labels)
    ax.set_xlabel("Sample # (time)")
    ax.set_ylabel("power at angle (rad)")
    ax.set_title(title)
fig.tight_layout()

In [None]:
# Raw IQ of both antennas over a short stretch of the current `data` buffer.
# Offsets looked at earlier: 70000, 208896.
offset = 470000
# Number of samples plotted; the title reports this same value (it used to say
# 1024 * 8 while 1024 * 16 samples were drawn).
n = 1024 * 16

iq_window = data["signal_matrix"][0, 0, :, offset : (offset + n)]
fig, axs = plt.subplots(1, 2, figsize=(10, 5))
for ant_idx in (1, 0):
    axs[1].plot(iq_window[ant_idx].imag, label=f"ant{ant_idx} imag")
    axs[0].plot(iq_window[ant_idx].real, label=f"ant{ant_idx} real")
axs[0].set_ylabel("I value")
axs[1].set_ylabel("Q value")
for ax in axs:
    ax.set_xlabel("idx in buffer")
    ax.legend()

fig.suptitle(
    f"{os.path.basename(ds.zarr_fn)}:{session_idx}:{offset}+{n} rx_idx{rx_idx} IQ values"
)
fig.tight_layout()

In [None]:
from spf.sdrpluto import detrend
from spf.rf import torch_pi_norm

from scipy.signal import butter, filtfilt


def high_pass_filter(data, cutoff=10, fs=16000000, order=5):
    """Zero-phase (forward-backward) Butterworth high-pass filter.

    Args:
        data: samples to filter; filtered along the last axis.
        cutoff: -3 dB corner frequency in Hz.
        fs: sample rate in Hz.
        order: Butterworth order of the single-pass filter.

    Returns:
        ndarray with the same shape as ``data``.
    """
    from scipy.signal import butter, sosfiltfilt

    nyq = 0.5 * fs
    normal_cutoff = cutoff / nyq
    # Second-order sections instead of (b, a): with the defaults the normalized
    # corner is ~1.25e-6, where a 5th-order transfer-function polynomial is
    # badly conditioned and the (b, a) filtfilt output can be meaningless.
    sos = butter(order, normal_cutoff, btype="high", analog=False, output="sos")
    return sosfiltfilt(sos, data)


offset = 75000
# offset = 50000
n = 5000
# session_idx = 287  # 395#+790


fig, axs = plt.subplots(2, 1, figsize=(10, 5))
for rx_idx in range(2):
    _data = ds[session_idx][rx_idx]
    ax = axs[rx_idx]
    ax.set_title(
        f"{os.path.basename(ds.zarr_fn)}:{session_idx}:{offset}+{n} PlutoPlus:{rx_idx}, signal phase"
    )
    # Only [offset, offset + n) is plotted, so slice that window once.
    iq = _data["signal_matrix"][0, 0, :, offset : (offset + n)]
    # Remove each antenna's complex mean over the window before taking the phase.
    rx0_mean = iq[0].mean()
    rx1_mean = iq[1].mean()
    ax.plot((iq[0] - rx0_mean).angle(), label="rx0")
    ax.plot((iq[1] - rx1_mean).angle(), label="rx1")
    ax.set_xlabel("IDX in captured buffer")
    ax.set_ylabel("Measured phase")
    ax.legend()
fig.tight_layout()

In [None]:
session = ds[session_idx][0]
# Every 20th session within +-600 of session_idx, limited to valid indices.
# For small session_idx (e.g. 4) the old range started negative, which either
# wraps to the end of the recording or fails, depending on how ds handles it.
neighbour_idxs = [
    idx
    for idx in range(session_idx - 600, session_idx + 600, 20)
    if 0 <= idx < len(ds)
]
sessions = [ds[idx][0] for idx in neighbour_idxs]
plt.scatter(
    [sess["tx_pos_x_mm"] for sess in sessions],
    [sess["tx_pos_y_mm"] for sess in sessions],
    label="tx",  # previously unlabelled, so the legend showed only rx
)
plt.scatter(
    [sess["rx_pos_x_mm"] for sess in sessions],
    [sess["rx_pos_y_mm"] for sess in sessions],
    label="rx",
)
plt.legend()
plt.gca().set_aspect("equal")
session["tx_pos_x_mm"], session["tx_pos_y_mm"]

In [None]:
# Detrended IQ of both antennas for both receivers over the same window.
# Offsets looked at earlier: 75000, 50000.
n = 15000
offset = 208896

fig, axs = plt.subplots(2, 1, figsize=(10, 5))
for rx_idx in range(2):
    # `_data`, not `data`: leave the buffer loaded by the raw-signal cell intact.
    _data = ds[session_idx][rx_idx]
    axs[rx_idx].set_title(
        f"{os.path.basename(ds.zarr_fn)}:{session_idx}:{offset}+{n} PlutoPlus:{rx_idx}"
    )
    # .numpy() as in the raw-signal cell, so detrend_np receives a numpy array.
    raw_radio_values = detrend_np(_data["signal_matrix"][0, 0].numpy())
    iq_window = raw_radio_values[:, offset : (offset + n)]
    for ant_idx in range(2):
        axs[rx_idx].plot(iq_window[ant_idx].real, label=f"ant{ant_idx}-real")
        axs[rx_idx].plot(iq_window[ant_idx].imag, label=f"ant{ant_idx}-imag")
    axs[rx_idx].set_xlabel("IDX in captured buffer")
    axs[rx_idx].set_ylabel("Measured real/imag (IQ)")
    axs[rx_idx].legend()
fig.tight_layout()

In [None]:
from spf.rf import torch_pi_norm


# Gain of both channels of each receiver across buffer captures.
# Capture-index slice to show. e = None keeps the last capture (the old
# e = -1 silently dropped it). Zoom example: s, e = 900, 1100.
s = 0
e = None
fig, axs = plt.subplots(2, 1, figsize=(10, 5))
for rx_idx in range(2):
    gains = ds.z.receivers[f"r{rx_idx}"].gains
    ax = axs[rx_idx]
    for channel in range(2):
        ax.plot(gains[s:e, channel], label=f"R{rx_idx}-Gain{channel}")
    ax.legend()
    ax.set_xlabel("Buffer capture # / IDX")
    ax.set_ylabel("Gain")
    # Per-capture series: session_idx / offset / n don't apply, so they are not in the title.
    ax.set_title(f"{os.path.basename(ds.zarr_fn)} PlutoPlus:{rx_idx}, Gain")
fig.tight_layout()

In [None]:
from spf.rf import torch_pi_norm


# RSSI of both channels of each receiver across buffer captures.
# Capture-index slice to show. e = None keeps the last capture (the old
# e = -1 silently dropped it). Zoom example: s, e = 900, 1100.
s = 0
e = None
fig, axs = plt.subplots(2, 1, figsize=(10, 5))
for rx_idx in range(2):
    rssis = ds.z.receivers[f"r{rx_idx}"].rssis
    ax = axs[rx_idx]
    for channel in range(2):
        ax.plot(rssis[s:e, channel], label=f"R{rx_idx}-RSSI{channel}")
    ax.legend()
    ax.set_xlabel("Buffer capture # / IDX")
    ax.set_ylabel("RSSI")
    # Per-capture series: session_idx / offset / n don't apply, so they are not in the title.
    ax.set_title(f"{os.path.basename(ds.zarr_fn)} PlutoPlus:{rx_idx}, RSSI")
fig.tight_layout()

In [None]:
# NOTE: segments whichever `raw_radio_values` the kernel holds right now; run
# top-to-bottom, that is receiver 1's detrended buffer from the IQ-trace cell.
segmentation_output = simple_segment(raw_radio_values, **default_segment_args)
segmentation = segmentation_output["simple_segmentation"]

In [None]:
from spf.rf import torch_pi_norm


# rx -> tx distance in metres for every sample (receiver 0's cached positions).
tx_pos_m = ds.cached_keys[0]["tx_pos_mm"] / 1000
rx_pos_m = ds.cached_keys[0]["rx_pos_mm"] / 1000
dists = (tx_pos_m - rx_pos_m).pow(2).sum(axis=1).sqrt()
plt.plot(dists)  # [670:700])

In [None]:
# Histogram of receiver 0's channel-0 gain over the whole recording.
plt.hist(ds.z.receivers.r0.gains[:, 0])

In [None]:
# Distance (m) and every receiver/channel RSSI on one shared axis. The units
# differ, so only the shapes are comparable.
tx_pos_m = ds.cached_keys[0]["tx_pos_mm"] / 1000
rx_pos_m = ds.cached_keys[0]["rx_pos_mm"] / 1000
dists = (tx_pos_m - rx_pos_m).pow(2).sum(axis=1).sqrt()
plt.plot(dists)
for rx_name in ("r0", "r1"):
    rssis = getattr(ds.z.receivers, rx_name).rssis
    for channel in (0, 1):
        plt.plot(rssis[:, channel], label=f"{rx_name.upper()}-rssi{channel}")
plt.legend()

In [None]:
from spf.rf import torch_pi_norm


# rx -> tx distance in metres for every sample (receiver 0's cached positions).
dists = (
    (ds.cached_keys[0]["tx_pos_mm"] / 1000 - ds.cached_keys[0]["rx_pos_mm"] / 1000)
    .pow(2)
    .sum(axis=1)
    .sqrt()
)

# Absolute wrapped error between the segmented mean phase and ground-truth phi.
r1_err = torch_pi_norm(ds.mean_phase["r1"] - ds.ground_truth_phis[1]).abs()
r0_err = torch_pi_norm(ds.mean_phase["r0"] - ds.ground_truth_phis[0]).abs()
plt.title("Distance vs phi error")
plt.scatter(dists, r1_err, s=1, label="R1")
# R0 needs a label as well; before, only R1 appeared in the legend.
plt.scatter(dists, r0_err, s=1, label="R0")
plt.legend()
plt.xlabel("Distance (m)")
plt.ylabel("abs phi error")

In [None]:
# -sin(theta) of receiver 0's ground-truth thetas: the unscaled term that
# get_ground_truth_phisX multiplies by rx_wavelength_spacing * 2 * pi.
-torch.sin(ds.ground_truth_thetas[0])

In [None]:
# Inspect receiver 1's ground-truth thetas.
ds.ground_truth_thetas[1]

In [None]:
# Was an incomplete `ds.` (a tab-completion leftover): a SyntaxError that broke
# Run All. List the dataset's public attributes instead.
sorted(name for name in dir(ds) if not name.startswith("_"))

In [None]:
from spf.rf import torch_pi_norm


# Signed, wrapped phi error (segmented mean phase minus ground truth) per receiver.
r0_err, r1_err = (
    torch_pi_norm(ds.mean_phase[f"r{rx}"] - ds.ground_truth_phis[rx]) for rx in (0, 1)
)
plt.hist(r1_err)
plt.hist(r0_err)

In [None]:
from spf.rf import torch_pi_norm


r1_err = torch_pi_norm(ds.mean_phase["r1"] - ds.ground_truth_phis[1]).abs()
r0_err = torch_pi_norm(ds.mean_phase["r0"] - ds.ground_truth_phis[0]).abs()
plt.title("Tx pos x vs phi error")
tx_x_mm = ds.cached_keys[0]["tx_pos_mm"][:, 0]
plt.scatter(tx_x_mm, r1_err, s=1, label="R1")
plt.scatter(tx_x_mm, r0_err, s=1, label="R0")
plt.legend()
# x is the transmitter's x coordinate in mm, not a distance in metres.
plt.xlabel("Tx x position (mm)")
plt.ylabel("abs phi error")

In [None]:
from spf.rf import mean_phase_mean

mean_phases = []
means = []
weights = []
for x in segmentation:
    if x["type"] != "signal":
        continue
    window_len = x["end_idx"] - x["start_idx"]
    means.append(x["mean"])
    # Favour long, strong, low-variance windows; the 1e-6 guards a zero stddev.
    weights.append(window_len * x["abs_signal_median"] / (x["stddev"] + 1e-6))
if means:
    means = np.array(means)
    weights = np.array(weights)
    # weights /= weights.sum()
    mean_phases.append(mean_phase_mean(angles=means, weights=weights))
else:
    # No signal windows: no phase estimate.
    mean_phases.append(torch.nan)
mean_phases

In [None]:
# Receiver 0's windowed beamformer for session_idx, each row L1-normalised along
# dim 1, drawn untransposed (unlike the .T plots earlier).
session_beam = ds[session_idx][0]["windowed_beamformer"][0, 0, :]
normed = torch.nn.functional.normalize(session_beam, p=1, dim=1)
plt.imshow(normed)

In [None]:
# Un-normalised windowed beamformer of receiver 0 for session_idx, untransposed.
plt.imshow(ds[session_idx][0]["windowed_beamformer"][0, 0, :])

In [None]:
# Peak value of that windowed beamformer (scale check for the imshow above).
ds[session_idx][0]["windowed_beamformer"][0, 0, :].max()

In [None]:
# First recorded rx position (x, y; mm per the key names) of receiver 0.
ds.receiver_data[0]["rx_pos_x_mm"][0], ds.receiver_data[0]["rx_pos_y_mm"][0]

In [None]:
from spf.rf import pi_norm


ridx = 0
receiver_info = ds.receiver_data[ridx]
rx_theta_in_pis = receiver_info["rx_theta_in_pis"]
tx_pos = np.array(
    [
        receiver_info["tx_pos_x_mm"],
        receiver_info["tx_pos_y_mm"],
    ]
)
rx_pos = np.array(
    [
        receiver_info["rx_pos_x_mm"],
        receiver_info["rx_pos_y_mm"],
    ]
)

# compute the angle of the tx with respect to rx
d = tx_pos - rx_pos

# np.arctan2(dx, dy): angle of the rx -> tx vector measured from the +y axis
# towards +x (not the usual from-+x convention).
rx_to_tx_theta = np.arctan2(d[0], d[1])
# theta = pi_norm(rx_to_tx_theta - rx_theta_in_pis[:] * np.pi)
# theta, ds.get_ground_truth_thetas()
# Show the computed angles (the cell used to end with a stray `1`).
rx_to_tx_theta

In [None]:
# Fraction of receiver-0 entries whose mean phase is exactly 0, and fraction that are finite.
r0_mean_phase = ds.mean_phase["r0"]
(r0_mean_phase == 0.0000).float().mean(), r0_mean_phase.isfinite().float().mean()

In [None]:
# Shapes of receiver 0's cached rx heading and ground-truth phi (should line up).
ds.cached_keys[0]["rx_heading_in_pis"].shape, ds.ground_truth_phis[0].shape

In [None]:
from spf.rf import torch_pi_norm


def get_ground_truth_phisX(ds):
    """Ground-truth inter-antenna phase difference (phi), one row per receiver.

    For receiver r: phi_r = torch_pi_norm(-sin(theta_r) * rx_wavelength_spacing * 2 * pi),
    where theta_r = ds.ground_truth_thetas[r]. theta_r is already relative to the
    array, so the rx heading / rx_theta_in_pis is NOT added.

    The leading minus sign is a convention choice (which way we spin, or
    equivalently which receiver sits towards +x); -1 here is the same as a
    negative rx spacing.

    Returns:
        torch.vstack of the per-receiver phi tensors.
    """
    return torch.vstack(
        [
            torch_pi_norm(
                -torch.sin(ds.ground_truth_thetas[ridx])
                * ds.rx_wavelength_spacing
                * 2
                * torch.pi
            )
            for ridx in range(ds.n_receivers)
        ]
    )


# z=get_ground_truth_phis(ds)
# z==ds.ground_truth_phis
# ds.cached_keys[0]["rx_heading_in_pis"][:first_n]

In [None]:
# Receiver 0's rx heading (in units of pi, per the key name) at sample 20.
ds.cached_keys[0]["rx_heading_in_pis"][20]

In [None]:
# Ground-truth theta and phi of receiver 0 at the same sample 20.
ds.ground_truth_thetas[0][20], ds.ground_truth_phis[0][20]

In [None]:
from spf.rf import torch_pi_norm
from matplotlib import pyplot as plt

# segmentation = ds.get_segmentation()


first_n = 500 * 4  # 12 * 8
fig, axs = plt.subplots(1, 2, figsize=(12, 6))
for idx in range(2):
    ax = axs[idx]
    # Segmented mean phase difference vs ground-truth phi for receiver `idx`.
    ax.scatter(range(first_n), ds.mean_phase[f"r{idx}"][:first_n], s=3, label=f"Rx{idx}")
    ax.scatter(
        range(first_n),
        ds.ground_truth_phis[idx][:first_n],
        s=3,
        label=f"Rx{idx} (GT)",
    )
    ax.legend()
    # axs.axvline(x=115)
    ax.set_title("Mean segmented phase diff")
    ax.set_xlabel("Chunk (time)")
    ax.set_ylabel("Mean phase diff of seg. chunk")

In [None]:
# segmentation_by_receiver.keys()

In [None]:
# NOTE(review): the first call's result is discarded; presumably it is run for a
# caching side effect — confirm, otherwise it is dead code.
ds.get_segmentation_mean_phase()
ds.get_estimated_thetas()

In [None]:
# Receiver 0's segmented mean phase as a numpy array.
ds.mean_phase["r0"].numpy()

In [None]:
from spf.dataset.spf_dataset import pi_norm
from spf.rf import c as speed_of_light


fig, axs = plt.subplots(1, 2, figsize=(12, 4))

estimated_thetas = ds.get_estimated_thetas()
for rx_idx in [0, 1]:
    receiver_estimates = estimated_thetas[f"r{rx_idx}"]
    # One faint scatter per estimate track (entries 0, 1, 2), each wrapped by pi_norm.
    for track_idx in range(3):
        track = receiver_estimates[track_idx]
        axs[rx_idx].scatter(range(track.shape[0]), pi_norm(track), s=0.4)
    axs[rx_idx].set_xlabel("Chunk")
    axs[rx_idx].set_ylabel("estimated theta")

In [None]:
from spf.dataset.spf_dataset import pi_norm
from spf.rf import reduce_theta_to_positive_y


fig, axs = plt.subplots(1, 2, figsize=(12, 4))

first_n = 1500
estimated_thetas = ds.get_estimated_thetas()
for rx_idx in [0, 1]:
    ax = axs[rx_idx]
    expected_theta = ds.ground_truth_thetas[rx_idx]
    ax.plot(
        expected_theta[:first_n],
        alpha=0.5,
        color="red",
        label="ground truth",
    )
    ax.plot(
        reduce_theta_to_positive_y(expected_theta[:first_n]),
        alpha=0.5,
        color="green",
        label="reduced ground truth",
    )

    # First two estimate tracks (peaks). x values use the truncated length, so a
    # recording shorter than first_n no longer trips scatter's x/y size check.
    for peak_idx in range(2):
        peak_theta = pi_norm(estimated_thetas[f"r{rx_idx}"][peak_idx])[:first_n]
        ax.scatter(
            range(len(peak_theta)),
            peak_theta,
            s=3,
            label=f"Rx{rx_idx}_peak{peak_idx + 1}",
        )
    ax.set_xlabel("Chunk")
    ax.set_ylabel("estimated theta")
    ax.legend()
    ax.set_title(f"Receiver (Rx) {rx_idx}")

In [None]:
import sys

repo_root = "/Users/miskodzamba/Dropbox/research/gits/spf/"
# Make the repository importable; appended, so existing sys.path entries take precedence.
if repo_root not in sys.path:
    sys.path.append(repo_root)  # go to parent dir

from spf.dataset.spf_dataset import v5spfdataset


ds = v5spfdataset(
    "/Volumes/SPFData/missions/april5/wallarrayv3_2024_05_06_19_04_15_nRX2_bounce",
    nthetas=11,
)

from functools import cache
import gc

from spf.dataset.spf_dataset import v5_collate_beamsegnet, v5_thetas_to_targets
from spf.model_training_and_inference.models.beamsegnet import (
    BeamNSegNetDirect,
    BeamNSegNetDiscrete,
    # BeamNetDirect,
    UNet1D,
    ConvNet,
)

# Run configuration (relies on `torch` imported by an earlier cell).
torch_device = torch.device("cpu")
nthetas = 11
lr = 0.001


dataloader_params = dict(
    batch_size=4,
    shuffle=True,
    num_workers=0,
    collate_fn=v5_collate_beamsegnet,
)
torch.manual_seed(1337)
train_dataloader = torch.utils.data.DataLoader(ds, **dataloader_params)

import random

w = False  # log to Weights & Biases?
if w:

    import wandb

    # start a new wandb run to track this script: project to log into, plus
    # hyperparameters / run metadata
    wandb.init(
        project="projectspf",
        config=dict(learning_rate=lr, architecture="beamsegnet1"),
    )

@cache
def mean_guess(shape):
    """Uniform guess: a ``shape`` tensor of ones, L1-normalised along dim 1.

    Memoised with functools.cache, so repeated calls with the same (hashable)
    ``shape`` return the very same tensor object; callers must not modify it in place.
    """
    ones = torch.ones(shape)
    return torch.nn.functional.normalize(ones, p=1, dim=1)


# Grab one batch to overfit on in the training loop below.
# NOTE(review): this unpacks the collated batch as a 3-tuple, but the later
# training cell indexes the output of the same v5_collate_beamsegnet as a dict
# (batch_data["x"], batch_data["y_rad"], ...). If the collate fn now returns a
# dict, this line unpacks its keys (or fails) — confirm which API is current.
X, Y_rad, segmentation = next(iter(train_dataloader))


def batch_to_gt_segmentation(X, Y_rad, segmentation):
    """Downsample per-sample segments to a per-window 0/1 mask.

    Windows are non-overlapping blocks of 2048 samples. For every segment in
    ``segmentation[row]["simple_segmentation"]``, windows
    ``start_idx // 2048`` up to (not including) ``end_idx // 2048`` are set to 1.
    ``Y_rad`` is unused.

    Args:
        X: tensor whose shape unpacks as (batch, channels, samples_per_session);
            samples_per_session must be a multiple of 2048.
        Y_rad: ignored.
        segmentation: per-row dicts holding a "simple_segmentation" list of
            {"start_idx", "end_idx", ...} windows.

    Returns:
        float tensor of shape (batch, 1, samples_per_session // 2048).
    """
    batch, _, samples_per_session = X.shape
    window_size = stride = 2048
    assert window_size == stride
    assert samples_per_session % window_size == 0
    window_status = torch.zeros(batch, samples_per_session // window_size)
    for row_idx, row_segmentation in enumerate(segmentation):
        for window in row_segmentation["simple_segmentation"]:
            first_window = window["start_idx"] // window_size
            stop_window = window["end_idx"] // window_size
            window_status[row_idx, first_window:stop_window] = 1
    return window_status[:, None]


def segmentation_mask(X, segmentations):
    """Per-sample 0/1 mask marking every segment of each batch row.

    Args:
        X: tensor whose dim 0 is the batch and dim 2 the sample axis; the mask
            is created on ``X.device``.
        segmentations: per-row dicts holding a "simple_segmentation" list of
            {"start_idx", "end_idx", ...} windows.

    Returns:
        float tensor of shape (X.shape[0], 1, X.shape[2]).
    """
    n_rows, n_samples = X.shape[0], X.shape[2]
    mask = torch.zeros(n_rows, n_samples, device=X.device)
    for row_idx in range(n_rows):
        for segment in segmentations[row_idx]["simple_segmentation"]:
            mask[row_idx, segment["start_idx"] : segment["end_idx"]] = 1
    # Alternative kept for reference: torch.nn.functional.normalize(mask, p=1, dim=1)
    return mask.unsqueeze(1)


# m = BeamNSegNetDiscrete(nthetas=nthetas, symmetry=False).to(torch_device)
# m = BeamNSegNetDirect(nthetas=nthetas, symmetry=False).to(torch_device)
# print("ALL", segmentation[0]["all_windows_stats"].shape)
m = UNet1D().to(torch_device).double()
# m = ConvNet(in_channels=3, out_channels=1, hidden=32)
optimizer = torch.optim.Adam(m.parameters(), lr=0.00001, weight_decay=0)
step = 0
sigmoid = torch.nn.Sigmoid()
X = X.double().to(torch_device)
# X[:, :2] /= 500

# This overfits a single fixed batch: X and `segmentation` never change inside
# the loop, so the per-sample target mask is built once instead of every step.
seg_mask = segmentation_mask(X, segmentation)
first_n = 3000  # samples drawn in the diagnostic plots

for epoch in range(10000):
    # for X, Y_rad, segmentation in train_dataloader:
    optimizer.zero_grad()

    # full-resolution input (named model_input: `input` shadowed the builtin)
    model_input = X.clone().to(torch_device)
    output = m(model_input)
    if step == 0:
        # Shapes are constant across steps; print them once, not 10000 times.
        print(model_input.shape, output.shape, seg_mask.shape)

    # downsampled alternative:
    # input = torch.Tensor(
    #     np.vstack(
    #         [
    #             segmentation[idx]["all_windows_stats"].transpose()[None]
    #             for idx in range(len(segmentation))
    #         ]
    #     )
    # )
    # input[:, 2] /= 50
    # output = m(input)
    # seg_mask = batch_to_gt_segmentation(X, Y_rad, segmentation)

    # Per-sample MSE between predicted and ground-truth segmentation.
    loss = ((output - seg_mask) ** 2).mean()
    loss.backward()
    optimizer.step()

    to_log = {"loss": loss.item()}

    if step % 1000 == 0:
        print(loss.item())
        # Copy to host / numpy only when a plot is actually drawn.
        _input = model_input.cpu()
        _output = output.cpu().detach().numpy()
        fig, axs = plt.subplots(1, 3, figsize=(8, 3))
        s = 0.3
        axs[0].set_title("input (track 0/1)")
        axs[0].scatter(range(first_n), _input[0, 0, :first_n], s=s)
        axs[0].scatter(range(first_n), _input[0, 1, :first_n], s=s)
        axs[1].set_title("input (track 2)")
        axs[1].scatter(range(first_n), _input[0, 2, :first_n], s=s)
        # mw = mask_weights.cpu().detach().numpy()

        axs[2].set_title("output vs gt")
        axs[2].scatter(range(first_n), _output[0, 0, :first_n], s=s)
        axs[2].scatter(
            range(first_n), seg_mask.cpu().detach().numpy()[0, 0, :first_n], s=s
        )
        to_log["fig"] = fig
    if w:
        wandb.log(to_log)
        if "fig" in to_log:
            plt.close(to_log["fig"])  # already logged; don't accumulate open figures
    step += 1


# [optional] finish the wandb run, necessary in notebooks.
# Guarded: with w == False, `wandb` was never imported and this raised NameError.
if w:
    wandb.finish()

In [None]:
# Shape of the fixed (double-precision) training batch.
X.shape

In [None]:
# Model output vs target-mask shapes from the last training step.
output.shape, seg_mask.shape

In [None]:
import sys
import torch

repo_root = "/Users/miskodzamba/Dropbox/research/gits/spf/"
# Make the repository importable; appended, so existing sys.path entries take precedence.
if repo_root not in sys.path:
    sys.path.append(repo_root)  # go to parent dir

from spf.dataset.spf_dataset import v5spfdataset
import matplotlib.pyplot as plt

# Run configuration.
torch_device = torch.device("cpu")
nthetas = 11
lr = 0.001
batch_size = 8

ds = v5spfdataset(
    "/Volumes/SPFData/missions/april5/wallarrayv3_2024_05_06_19_04_15_nRX2_bounce",
    nthetas=nthetas,
)

from functools import cache
import gc

from spf.dataset.spf_dataset import v5_collate_beamsegnet, v5_thetas_to_targets
from spf.model_training_and_inference.models.beamsegnet import (
    BeamNSegNet,
    BeamNetDirect,
    BeamNetDiscrete,
    ConvNet,
    UNet1D,
)

# Debug aid: anomaly detection checks every backward pass (and slows it down).
torch.autograd.set_detect_anomaly(True)


dataloader_params = dict(
    batch_size=batch_size,
    shuffle=True,
    num_workers=0,
    collate_fn=v5_collate_beamsegnet,
)
torch.manual_seed(1337)
train_dataloader = torch.utils.data.DataLoader(ds, **dataloader_params)
w = False  # log to Weights & Biases?
if w:
    import wandb

    # start a new wandb run to track this script: project to log into, plus
    # hyperparameters / run metadata
    wandb.init(
        project="projectspf",
        config=dict(learning_rate=lr, architecture="beamsegnet1"),
    )


def plot_instance(_x, _output_seg, _seg_mask, idx=0, first_n=None):
    """Three-panel diagnostic figure for one batch row.

    Panels: input tracks 0/1, input track 2, and predicted vs ground-truth
    segmentation.

    Args:
        _x: array indexed as [row, track, sample] with at least 3 tracks.
        _output_seg: predicted segmentation indexed as [row, 0, sample].
        _seg_mask: ground-truth mask indexed as [row, 0, sample].
        idx: batch row to draw.
        first_n: number of leading samples to draw. None falls back to the
            notebook-level ``first_n`` (the previous, implicit behaviour) when it
            exists, else the full length. Always clamped to the data length, so a
            short input no longer trips scatter's x/y size check.

    Returns:
        The matplotlib Figure.
    """
    if first_n is None:
        first_n = globals().get("first_n", _x.shape[-1])
    first_n = min(first_n, _x.shape[-1], _output_seg.shape[-1], _seg_mask.shape[-1])
    xs = range(first_n)

    fig, axs = plt.subplots(1, 3, figsize=(8, 3))
    s = 0.3
    axs[0].set_title("input (track 0/1)")
    axs[0].scatter(xs, _x[idx, 0, :first_n], s=s)
    axs[0].scatter(xs, _x[idx, 1, :first_n], s=s)
    axs[1].set_title("input (track 2)")
    axs[1].scatter(xs, _x[idx, 2, :first_n], s=s)
    # mw = mask_weights.cpu().detach().numpy()

    axs[2].set_title("output vs gt")
    axs[2].scatter(xs, _output_seg[idx, 0, :first_n], s=s, label="output")
    axs[2].scatter(xs, _seg_mask[idx, 0, :first_n], s=s, label="gt")
    axs[2].legend()
    return fig


batch_data = next(iter(train_dataloader))
import pickle

# Snapshot the fixed training batch. The context manager closes (and flushes)
# the file, which the previous bare open(...) never did.
with open("test_batch.pkl", "wb") as batch_file:
    pickle.dump(batch_data, batch_file)
skip_segmentation = False
segmentation_level = "downsampled"
if segmentation_level == "full":
    first_n = 10000  # samples drawn by plot_instance
    seg_m = UNet1D().to(torch_device)
elif segmentation_level == "downsampled":
    first_n = 256
    seg_m = ConvNet(3, 1, 32, bn=True).to(torch_device)
else:
    # Fail here rather than later with an undefined `seg_m`.
    raise NotImplementedError(segmentation_level)

import torch.nn as nn

beam_m = BeamNetDirect(
    nthetas=nthetas, hidden=16, symmetry=True, other=True, act=nn.SELU, bn=True
).to(torch_device)
# beam_m = BeamNetDiscrete(nthetas=nthetas, hidden=16, symmetry=False).to(torch_device)
m = BeamNSegNet(segnet=seg_m, beamnet=beam_m, circular_mean=True).to(torch_device)

# Phase 1 (step < head_start): optimise only the segmentation net.
optimizer = torch.optim.AdamW(seg_m.parameters(), lr=0.01, weight_decay=0)

# The same batch is used every step, so move it to the device once.
if segmentation_level == "full":
    x = batch_data["x"].to(torch_device)
    seg_mask = batch_data["segmentation_mask"].to(torch_device)
else:  # "downsampled"; other values were rejected above
    x = batch_data["all_windows_stats"].to(torch_device)
    seg_mask = batch_data["downsampled_segmentation_mask"].to(torch_device)
y_rad = batch_data["y_rad"].to(torch_device)
assert seg_mask.ndim == 3 and seg_mask.shape[1] == 1

step = 0
head_start = 200
for epoch in range(10000):
    if step == head_start:
        # Phase 2: from here on optimise only the beamformer net.
        optimizer = torch.optim.AdamW(beam_m.parameters(), lr=0.001, weight_decay=0)
    # for X, Y_rad in train_dataloader:
    optimizer.zero_grad()

    # run beamformer and segmentation (optionally feeding the ground-truth mask)
    if not skip_segmentation:
        output = m(x)
    else:
        output = m(x, seg_mask)

    # x to beamformer loss (indirectly including segmentation): negative log-likelihood
    x_to_beamformer_loss = -beam_m.loglikelihood(output["pred_theta"], y_rad)
    # Compare against the actual batch size rather than the configured one.
    assert x_to_beamformer_loss.shape == (x.shape[0], 1)
    x_to_beamformer_loss = x_to_beamformer_loss.mean()

    # segmentation loss (per-element squared error)
    x_to_segmentation_loss = (output["segmentation"] - seg_mask) ** 2
    assert x_to_segmentation_loss.ndim == 3 and x_to_segmentation_loss.shape[1] == 1
    x_to_segmentation_loss = x_to_segmentation_loss.mean()

    if skip_segmentation or step >= head_start:
        loss = x_to_beamformer_loss
    else:
        loss = x_to_segmentation_loss
    # if step in [799, 780]:
    #     print(step, output)
    loss.backward()
    optimizer.step()

    to_log = {
        "loss": loss.item(),
        "segmentation_loss": x_to_segmentation_loss.item(),
        "beam_former_loss": x_to_beamformer_loss.item(),
    }
    if step % 500 == 0:
        # beam outputs: predicted and ground-truth rows interleaved into one image
        img_beam_output = (
            (beam_m.render_discrete_x(output["pred_theta"]) * 255).cpu().byte()
        )
        img_beam_gt = (beam_m.render_discrete_y(y_rad) * 255).cpu().byte()
        train_target_image = torch.zeros(
            (img_beam_output.shape[0] * 2, img_beam_output.shape[1]),
        ).byte()
        for row_idx in range(img_beam_output.shape[0]):
            train_target_image[row_idx * 2] = img_beam_output[row_idx]
            train_target_image[row_idx * 2 + 1] = img_beam_gt[row_idx]
        if w:
            output_image = wandb.Image(
                train_target_image, caption="train vs target (interleaved)"
            )
            to_log["output"] = output_image

        # segmentation output
        _x = x.detach().cpu().numpy()
        _seg_mask = seg_mask.detach().cpu().numpy()
        # _output_seg = output_segmentation_upscaled.detach().cpu().numpy()
        _output_seg = output["segmentation"].detach().cpu().numpy()

        fig = plot_instance(_x, _output_seg, _seg_mask, idx=0)
        if w:
            to_log["fig"] = fig
    if w:
        wandb.log(to_log)
        if "fig" in to_log:
            plt.close(to_log["fig"])  # already logged; don't accumulate open figures
    else:
        # if step > 760 and step < 800:
        if step % 20 == 0:
            print(
                step,
                loss.item(),
                x_to_beamformer_loss.item(),
                x_to_segmentation_loss.item(),
            )
    step += 1

# [optional] finish the wandb run, necessary in notebooks
if w:
    wandb.finish()

In [None]:
import torch.nn as nn

# Standalone sigmoid module used to probe how logits map to (0, 1).
z = nn.Sigmoid()

In [None]:
# Sigmoid at -10, 0, 10: close to 0, exactly 0.5, close to 1.
z(torch.Tensor([-10, 0, 10]))

In [None]:
# Segmentation-weighted average of the input `x` along axis 2: weight every
# sample by the predicted segmentation, sum, then normalise by the total weight.
weighted_input = (
    (x * output["segmentation"]).sum(axis=2)
    / output["segmentation"].sum(axis=2)
)

In [None]:
# Run the segmentation-weighted input through beam_net, post-process the raw
# output with `fixify` (sign=1), and score the ground-truth angles y_rad.
param = m.beamnet.fixify(m.beamnet.beam_net(weighted_input), sign=1)

m.beamnet.likelihood(param, y_rad)

In [None]:
# Show the post-processed parameters next to the targets.
param, y_rad

In [None]:
# Probe likelihood sensitivity: override column 1 of the parameters with 100.
# The edit is made on a clone, so `param` itself is left untouched.
# NOTE(review): column 1 is presumably a width/concentration term — TODO confirm.
_param = param.clone()
_param[:, 1] = 100
m.beamnet.likelihood(_param, y_rad)

In [None]:
# Preview y_rad clamped to [0, 0.1] next to the raw values. `clamp` is
# out-of-place, so a bare `y_rad.clamp(...)` statement followed by `y_rad`
# silently discards the clamped tensor. y_rad itself is deliberately not modified,
# so later cells still see the raw targets.
y_rad.clamp(min=0, max=0.1), y_rad

In [None]:
# Raw beam_net output for the weighted input (before the `fixify` post-processing).
m.beamnet.beam_net(weighted_input)

In [None]:
# NOTE(review): identical to the previous cell — candidate for deletion.
m.beamnet.beam_net(weighted_input)

In [None]:
# Predicted theta parameters from whatever forward pass last populated `output`.
output["pred_theta"]

In [None]:
# Recompute the segmentation-weighted average of `x` along axis 2 and report
# the resulting shape.
weighted_input = (
    (x * output["segmentation"]).sum(axis=2)
    / output["segmentation"].sum(axis=2)
)
weighted_input.shape

In [None]:
# Compare the shape of the column-0 residual against y_rad with the shape of
# the beam-former loss (checks broadcasting).
(output["pred_theta"][:, 0] - y_rad).shape, x_to_beamformer_loss.shape

In [None]:
# Shape of beam_m's log-likelihood for the predicted parameters vs. y_rad.
beam_m.loglikelihood(output["pred_theta"], y_rad).shape

In [None]:
# Segmentation-weighted average of the beam-former output along axis 2.
(
    (output["beam_former"] * output["segmentation"]).sum(axis=2)
    / output["segmentation"].sum(axis=2)
)

In [None]:
# Shape check for the amplitude-scaled exp(-(c0 - y_rad)^2 / c1) term built
# from columns 0, 1 and 3 of pred_theta.
# Uses a distinct name instead of rebinding `x` / `y`: `x` is the input batch
# used by the weighted-input cells, and overwriting it with model output breaks
# re-running them.
pred_theta = output["pred_theta"]
(
    pred_theta[:, 3]
    * torch.exp(-((pred_theta[:, 0] - y_rad) ** 2) / pred_theta[:, 1])
).shape

In [None]:
# Shape of the predicted theta parameters.
output["pred_theta"].shape

In [None]:
# Shape of the per-sample beam-former output.
output["beam_former"].shape

In [None]:

    output_segmentation_upscaled = output["segmentation"] * seg_mask.sum(
        axis=2, keepdim=True
    )
    x_to_segmentation_loss = (output_segmentation_upscaled - seg_mask) ** 2

In [None]:
# Weight the beam-former output element-wise by the segmentation and normalise
# by the total segmentation weight per row (axis 2 is kept, not summed), then
# display all relevant shapes together. Previously the first shapes tuple was
# evaluated on its own line and silently discarded.
k = torch.mul(output["beam_former"], output["segmentation"]) / output[
    "segmentation"
].sum(axis=2, keepdim=True)
(
    output["mask_weights"].shape,
    output["segmentation"].shape,
    output["beam_former"].shape,
    k.shape,
)

In [None]:
# Sum over axis 2 of the segmentation scaled by the per-row seg_mask total.
# NOTE(review): the commented-out lines below use seg_mask.sum() (a global sum),
# whereas the live expression sums along axis 2 with keepdim — they are stale.
# output_segmentation_upscaled = output["segmentation"] * seg_mask.sum()
# x_to_segmentation_loss = (output_segmentation_upscaled - seg_mask) ** 2
(output["segmentation"] * seg_mask.sum(axis=2, keepdim=True)).sum(axis=2)

In [None]:
# Per-row total of seg_mask along axis 2 (the factor used to upscale the segmentation).
seg_mask.sum(axis=2, keepdim=True)

In [None]:
# Scatter the predicted segmentation values at index [0, 0] of
# output["segmentation"] (presumably first batch item, first channel — TODO
# confirm) against their sample index.
# Named `seg_values` rather than `z`, so the sigmoid module `z` created in an
# earlier cell is not clobbered. The trailing `;` suppresses the PathCollection repr.
seg_values = output["segmentation"].detach().cpu().numpy()[0, 0]
plt.scatter(range(len(seg_values)), seg_values);

In [None]:
# Shape of the predicted segmentation.
output["segmentation"].shape

In [None]:
# Mean and standard deviation of X at index 1 along dim 1 (channel 1, presumably).
X[:, 1, :].mean(), X[:, 1, :].std()

In [None]:
# Ground-truth angles for X (radians, going by the `_rad` suffix — defined outside this chunk).
Y_rad

In [None]:
# NOTE(review): other cells index `output` like a dict (output["pred_theta"]);
# `.shape` only works when `output` currently holds a tensor (e.g. after
# `output = m(k)` further down). Depends on execution order.
output.shape, Y_rad.shape

In [None]:
# Shape of the windowed statistics stored in the first segmentation entry.
segmentation[0]["all_windows_stats"].shape

In [None]:
# Build the segmentation mask for X from the precomputed segmentation (helper defined elsewhere).
segmentation_mask(X, segmentation)

In [None]:
# Forward pass of the current model on the full batch X.
m(X)

In [None]:
# Replay the model's forward pass by hand on a copy of X moved to torch_device,
# with its first two channels divided by 500.
_X = X.clone().to(torch_device)
_X[:, :2] /= 500
batch_size, input_channels, session_size = _X.shape

# (batch, channels, session) -> (batch * session, channels): one row per sample.
beam_former_input = _X.permute(0, 2, 1).reshape(
    batch_size * session_size, input_channels
)
print(_X.device, beam_former_input)

# Five beam-former outputs per sample: mu, o1, o2, k1, k2.
beam_former = m.beam_net(beam_former_input).reshape(batch_size, session_size, 5)

# Mask weights: index 0 along dim 1 of the 1D U-Net output, through the model's softmax.
mask_weights = m.softmax(m.unet1d(_X)[:, 0])

In [None]:
# Inspect the flattened (batch * session, channels) beam-former input.
beam_former_input

In [None]:
# First item of the dataset.
ds[0]

In [None]:
# Sum of seg_mask along axis 1.
seg_mask.sum(axis=1)

In [None]:
# Total of the first entry of seg_mask, computed in NumPy on the CPU.
seg_mask.cpu().detach().numpy()[0].sum()

In [None]:
# Inspect the first `first_n` samples of the first session in X:
# channels 0 and 1 (left), channel 2 (middle), and the model output against the
# segmentation mask (right).
first_n = 40000


def _scatter_head(ax, series, n):
    """Scatter the first `n` entries of a 1-D `series` against their index.

    The x-axis uses the length of the truncated series itself, so sessions with
    fewer than `n` samples plot cleanly. Pairing `range(first_n)` with the
    truncated series raised a size-mismatch ValueError for short sessions.
    """
    head = series[:n]
    ax.scatter(np.arange(len(head)), head, s=0.3)


x = X[0].cpu()
# Inference only: no autograd graph is needed for plotting.
with torch.no_grad():
    mw = m(X).cpu().detach().numpy()[0]

fig, axs = plt.subplots(1, 3, figsize=(12, 4))
_scatter_head(axs[0], x[0], first_n)
_scatter_head(axs[0], x[1], first_n)
_scatter_head(axs[1], x[2], first_n)
_scatter_head(axs[2], mw[0], first_n)
_scatter_head(axs[2], seg_mask.cpu().detach().numpy()[0], first_n)

In [None]:
# Mask weights for the first batch item.
mask_weights[0]

In [None]:
from spf.model_training_and_inference.models.beamsegnet import BeamNSegNetDirect


# Fresh BeamNSegNetDirect model and AdamW optimiser (lr=0.01) for a
# single-example gradient check. `nthetas` is defined outside this chunk.
m = BeamNSegNetDirect(nthetas=nthetas)

optimizer = torch.optim.AdamW(m.parameters(), lr=0.01)

# First beam_net layer's weight gradient; None until a backward pass has run.
m.beam_net.beam_net[0].weight.grad

In [None]:
# Pull a single (input, target) batch from the training dataloader.
x, y = next(iter(train_dataloader))

In [None]:
# Single-example batch. The list index keeps the batch dimension and, as
# advanced indexing, returns a copy, so `x` / `y` are untouched.
k = x[[0]]
k_y = y[[0]]
# Force channel 2 non-positive: -sign(v) * v == -|v|.
# (Use `k[:, 2].abs()` instead to force it non-negative.)
k[:, 2] = -k[:, 2].abs()

In [None]:
# Inspect channel 2 of the single-example batch after the sign flip.
k[:, 2]

In [None]:
# Reset gradients and put the model in training mode, then check the first
# beam_net layer's gradient.
optimizer.zero_grad()
m.train()
m.beam_net.beam_net[0].weight.grad

In [None]:
# Largest value in X (sanity check on input scaling).
X.max()

In [None]:
# Single-example backward pass: MSE between the model output and k_y. This
# populates gradients without updating weights (optimizer.step() stays commented out).
output = m(k)

loss_fn = torch.nn.MSELoss()
l = loss_fn(output, k_y)
l.backward()
# mean_loss = output
# optimizer.step()

In [None]:
# Model output for the single-example batch.
output

In [None]:
# First beam_net layer's weight gradient (populated once `l.backward()` has run).
m.beam_net.beam_net[0].weight.grad

In [None]:
# Render the target tensor Y (defined outside this chunk) as an image, on the CPU.
plt.imshow(Y.to("cpu"))

In [None]:
import torch


def detrend_iq(iq_tensor):
    """
    Remove linear trends from the I and Q components of a 1D complex PyTorch tensor.

    A straight line ``a * t + b`` (with ``t`` evenly spaced over [0, 1]) is fitted
    by least squares to the real (I) and imaginary (Q) parts independently. Each
    fitted line is then subtracted from its component.

    Parameters:
    - iq_tensor (torch.Tensor): 1D complex tensor of shape (N,), dtype=torch.complex64 or torch.complex128.

    Returns:
    - detrended_iq (torch.Tensor): 1D complex tensor of shape (N,), with the same
      dtype and device as the input, with linear trends removed from I and Q.
      An empty input is returned as an empty copy.

    Raises:
    - ValueError: if the input is not complex or is not 1D.
    """
    if not torch.is_complex(iq_tensor):
        raise ValueError("Input tensor must be a complex tensor.")
    if iq_tensor.dim() != 1:
        raise ValueError("Input tensor must be 1D.")

    # Number of samples
    N = iq_tensor.shape[0]
    if N == 0:
        # Nothing to fit; avoid calling lstsq on an empty system.
        return iq_tensor.clone()

    # Real dtype matching the input: float32 for complex64, float64 for complex128.
    # The design matrix must share it, because lstsq rejects operands of
    # different dtypes (a default float32 time axis broke complex128 input).
    real_dtype = iq_tensor.real.dtype

    # Time vector, shape (N, 1)
    t = torch.linspace(
        0, 1, steps=N, device=iq_tensor.device, dtype=real_dtype
    ).unsqueeze(1)

    # Design matrix [t, 1] for a slope + intercept fit, shape (N, 2)
    X = torch.cat([t, torch.ones_like(t)], dim=1)

    # I and Q as two right-hand-side columns, shape (N, 2), so both least-squares
    # problems are solved in a single call.
    iq_columns = torch.stack([iq_tensor.real, iq_tensor.imag], dim=1)
    coeffs = torch.linalg.lstsq(X, iq_columns).solution  # (2, 2): slope/intercept per column

    # Subtract the fitted trends, shape (N, 2)
    residual = iq_columns - X @ coeffs

    # Recombine the detrended I and Q into a complex tensor, shape (N,)
    return torch.complex(residual[:, 0], residual[:, 1])