In [None]:
# def write_to_record_matrix(self, thread_idx, record_idx, data):
#     tx_pos = self.position_controller.controller.position["xy"][
#         self.yaml_config["emitter"]["motor_channel"]
#     ]
#     rx_pos = self.position_controller.controller.position["xy"][
#         self.rx_configs[0].motor_channel
#     ]

#     data.tx_pos_x_mm = tx_pos[0]
#     data.tx_pos_y_mm = tx_pos[1]
#     data.rx_pos_x_mm = rx_pos[0]
#     data.rx_pos_y_mm = rx_pos[1]
#     assert data.rx_lo > 1

#     if not self.yaml_config["dry-run"]:
#         z = self.zarr[f"receivers/r{thread_idx}"]
#         z.signal_matrix[record_idx] = data.signal_matrix
#         for k in v5rx_f64_keys + v5rx_2xf64_keys:
#             z[k][record_idx] = getattr(data, k)

In [None]:
from spf.dataset.spf_dataset import v5spfdataset

from spf.dataset.v5_data import v5rx_2xf64_keys, v5rx_f64_keys, v5rx_new_dataset

# Reference dataset: April 2025 wall-array capture with 2 receivers (nRX2),
# rendered one snapshot per session with the 3.5 segmentation cache.
ds = v5spfdataset(
    "/mnt/md2/2d_wallarray_v2_data/april_nuand/wallarrayv3_2025_04_04_21_06_04_nRX2_rx_random_circle_spacing0p08.zarr",
    paired=True,
    snapshots_per_session=1,
    nthetas=65,
    segmentation_version=3.5,
    precompute_cache="/mnt/md2/cache/precompute_cache_3p5_chunk1/",
    ignore_qc=True,
    gpu=False,
    n_parallel=8,
)

In [None]:
# for idx in range(100):
#     ds[idx]

In [None]:
# Raw per-receiver v5 fields: the float64 scalars, the 2xfloat64 fields and the
# raw signal matrix.
v5_raw_keys = [*v5rx_f64_keys, *v5rx_2xf64_keys, "signal_matrix"]


def data_single_radio_to_raw(d):
    """Project one receiver's sample dict down to its raw v5 fields.

    Args:
        d: mapping that contains every key in ``v5_raw_keys``.

    Returns:
        A new dict holding only ``v5_raw_keys``. Values are shared, not copied.

    Raises:
        KeyError: if ``d`` lacks one of ``v5_raw_keys``.
    """
    raw = {}
    for key in v5_raw_keys:
        raw[key] = d[key]
    return raw

In [None]:
import multiprocessing
import pickle
from typing import List
from spf.dataset.segmentation import (
    DEFAULT_SEGMENT_ARGS,
    mean_phase_from_simple_segmentation,
    segment_session,
)
from spf.dataset.spf_dataset import (
    data_from_precomputed,
    encode_vehicle_type,
    get_empirical_dist,
    uri_to_device_type,
)
from spf.scripts.train_single_point import (
    global_config_to_keys_used,
    load_config_from_fn,
)
from spf.sdrpluto.sdr_controller import rx_config_from_receiver_yaml
from spf.utils import SEGMENTATION_VERSION, load_config, to_bin
import torch
from torch.utils.data import Dataset
from spf.rf import (
    precompute_steering_vectors,
    speed_of_light,
    torch_get_phase_diff,
)

training_only_keys = [
    "ground_truth_theta",
    "ground_truth_phi",
    "craft_ground_truth_theta",
    "y_rad",
    "y_phi",
    "craft_y_rad",
    "y_rad_binned",
]

segmentation_based_keys = [
    "weighted_beamformer",
    "all_windows_stats",
    "weighted_windows_stats",
    "downsampled_segmentation_mask",
    "simple_segmentations",
    "mean_phase_segmentation",
]


class v5inferencedataset(Dataset):
    def __init__(
        self,
        yaml_fn: str,
        nthetas: int,  # Number of theta angles for beamforming discretization
        model_config_fn: str = "",
        paired: bool = False,  # If True, return paired samples from all receivers at once
        gpu: bool = False,  # Use GPU for segmentation computation if available
        skip_fields: List[
            str
        ] = [],  # Data fields to exclude during loading to save memory
        n_parallel: int = 20,  # Number of parallel processes for segmentation
        empirical_data_fn: (
            str | None
        ) = None,  # Path to empirical distribution data file for phase-to-angle mapping
        empirical_individual_radio: bool = False,  # Use per-radio empirical distributions if True
        empirical_symmetry: bool = True,  # Use symmetric empirical distributions if True
        target_dtype=torch.float32,  # Target dtype for tensor conversion (memory optimization)
        distance_normalization: int = 1000,  # Divisor to normalize distance measurements (mm to meters)
        target_ntheta: (
            bool | None
        ) = None,  # Target number of theta bins for classification (defaults to nthetas)
        windows_per_snapshot: int = 256,  # Maximum number of windows per snapshot to use
        skip_detrend: bool = False,
        skip_segmentation: bool = True,
        vehicle_type: str = "",
        max_in_memory: int = 10,
    ):
        # Store configuration parameters
        self.yaml_fn = yaml_fn
        self.n_parallel = n_parallel
        self.nthetas = nthetas  # Number of angles to discretize space for beamforming
        self.target_ntheta = self.nthetas if target_ntheta is None else target_ntheta

        self.max_in_memory = max_in_memory
        self.min_idx = 0
        self.condition = multiprocessing.Condition()
        self.lock = multiprocessing.Lock()
        self.store = {}

        # Segmentation parameters control how raw signal is processed into windows
        # and how phase difference is computed between antenna elements
        self.skip_detrend = skip_detrend
        self.windows_per_snapshot = windows_per_snapshot

        self.distance_normalization = distance_normalization
        self.skip_fields = skip_fields
        self.skip_segmentation = skip_segmentation
        if self.skip_segmentation:
            self.skip_fields += segmentation_based_keys
        self.paired = paired
        assert self.paired
        self.gpu = gpu  # Whether to use GPU acceleration for beamforming calculations
        self.target_dtype = target_dtype
        self.precomputed_entries = 0
        self.precomputed_zarr = (
            None  # Will hold preprocessed beamforming and segmentation data
        )

        self.yaml_config = load_config(self.yaml_fn)

        if model_config_fn != "":
            self.model_config = load_config_from_fn(model_config_fn)
            self.keys_to_get = global_config_to_keys_used(self.model_config["global"])
        else:
            self.keys_to_get = global_config_to_keys_used(None)

        # Get system metadata
        self.vehicle_type = vehicle_type

        # Extract receiver properties - important for beamforming calculations
        self.wavelengths = [
            speed_of_light / receiver["f-carrier"]  # λ = c/f
            for receiver in self.yaml_config["receivers"]
        ]
        self.carrier_frequencies = [
            receiver["f-carrier"] for receiver in self.yaml_config["receivers"]
        ]
        self.rf_bandwidths = [
            receiver["bandwidth"] for receiver in self.yaml_config["receivers"]
        ]

        # Validate that all receivers have consistent configurations
        for rx_idx in range(1, 2):
            assert (
                self.yaml_config["receivers"][0]["antenna-spacing-m"]
                == self.yaml_config["receivers"][rx_idx]["antenna-spacing-m"]
            )
            assert self.wavelengths[0] == self.wavelengths[rx_idx]
            assert self.rf_bandwidths[0] == self.rf_bandwidths[rx_idx]

        # Set up receiver spacing properties - critical for beamforming
        # Spacing between antenna elements affects phase difference and angle estimation
        self.rx_spacing = self.yaml_config["receivers"][0]["antenna-spacing-m"]
        assert self.yaml_config["receivers"][1]["antenna-spacing-m"] == self.rx_spacing

        # rx_wavelength_spacing (d/λ) is a key parameter for beamforming
        # It determines how phase differences map to arrival angles
        self.rx_wavelength_spacing = self.rx_spacing / self.wavelengths[0]

        # Create receiver configs and determine device types
        self.rx_configs = [
            rx_config_from_receiver_yaml(receiver)
            for receiver in self.yaml_config["receivers"]
        ]
        self.sdr_device_types = [
            uri_to_device_type(rx_config.uri) for rx_config in self.rx_configs
        ]

        # Ensure all receivers use the same device type
        if len(self.sdr_device_types) > 1:
            for device_type in self.sdr_device_types:
                assert device_type == self.sdr_device_types[0]
        self.sdr_device_type = self.sdr_device_types[0]

        # Precompute steering vectors for beamforming
        # Steering vectors are complex weights applied to each antenna element
        # They're used to "steer" the array to look in a specific direction
        # For each possible angle (theta), calculate the appropriate phase shifts
        self.steering_vectors = [
            precompute_steering_vectors(
                receiver_positions=rx_config.rx_pos,
                carrier_frequency=rx_config.lo,
                spacing=nthetas,
            )
            for rx_config in self.rx_configs
        ]

        # Define keys to load per session
        self.keys_per_session = (
            v5rx_f64_keys + v5rx_2xf64_keys + ["rx_wavelength_spacing"]
        )
        if "signal_matrix" not in self.skip_fields:
            self.keys_per_session.append("signal_matrix")

        # Load empirical distribution data if provided
        # These are learned phase-to-angle mappings that can improve angle estimation
        if empirical_data_fn is not None:
            self.empirical_data_fn = empirical_data_fn
            self.empirical_data = pickle.load(open(empirical_data_fn, "rb"))
            self.empirical_individual_radio = empirical_individual_radio
            self.empirical_symmetry = empirical_symmetry
        else:
            self.empirical_data_fn = None
            self.empirical_data = None

    # ASSUMING EVERYTHING WILL BE REQUESTED IN SEQUENCE!!
    def __getitem__(self, idx):
        if 
        return [
            self.render_session(receiver_idx, idx)
            for receiver_idx in range(self.n_receivers)
        ]

    def write_to_idx(self, idx, ridx, raw):
        if idx < self.min_idx:
            return  # we dont need this sample
        rendered_data = self.render_session(idx, ridx, raw)

        self.lock.acquire()
        if idx not in self.store:
            self.store[idx] = {"count": 0, "data": [None, None]}  # entry not ready
        self.store[idx]["data"][ridx] = rendered_data
        self.store[idx]["count"] += 1
        self.lock.release()

        with self.condition:
            self.condition.notify_all()

    def render_session(self, idx, ridx, data):
        snapshot_idxs = [0]  # which snapshots to get

        data["rx_wavelength_spacing"] = torch.tensor(self.rx_wavelength_spacing)

        data["gains"] = data["gains"][:, None]
        data["receiver_idx"] = torch.tensor([[ridx]], dtype=torch.int)

        data["ground_truth_theta"] = torch.tensor([torch.inf])  # unknown
        data["y_rad"] = data["ground_truth_theta"]  # torch.inf

        data["ground_truth_phi"] = torch.tensor([torch.inf])  # unkown
        data["y_phi"] = data["ground_truth_phi"]  # torch.inf

        data["craft_ground_truth_theta"] = torch.tensor([torch.inf])  # unknown
        data["craft_y_rad"] = data["craft_ground_truth_theta"]  # torch.inf

        data["vehicle_type"] = torch.tensor(
            [encode_vehicle_type(self.vehicle_type)]
        ).reshape(1)
        data["sdr_device_type"] = torch.tensor([self.sdr_device_type.value]).reshape(1)

        if "signal_matrix" not in self.skip_fields:
            # WARNGING this does not respect flipping!
            abs_signal = data["signal_matrix"].abs().to(torch.float32)
            assert data["signal_matrix"].shape[0] == 1
            pd = torch_get_phase_diff(data["signal_matrix"][0]).to(torch.float32)
            data["abs_signal_and_phase_diff"] = torch.concatenate(
                [abs_signal, pd[None, :, None]], dim=2
            )

        data["rx_pos_mm"] = torch.vstack(
            [
                data["rx_pos_x_mm"],
                data["rx_pos_y_mm"],
            ]
        ).T

        data["tx_pos_mm"] = torch.vstack(
            [
                data["tx_pos_x_mm"],
                data["tx_pos_y_mm"],
            ]
        ).T

        data["rx_pos_xy"] = (
            data["rx_pos_mm"][snapshot_idxs].unsqueeze(0) / self.distance_normalization
        )

        data["tx_pos_xy"] = (
            data["tx_pos_mm"][snapshot_idxs].unsqueeze(0) / self.distance_normalization
        )

        segmentation = segment_session(
            data["signal_matrix"][0][ridx],
            gpu=False,
            skip_beamformer=False,
            skip_detrend=False,
            skip_segmentation=self.skip_segmentation,
            **{
                "steering_vectors": self.steering_vectors[ridx],
                **DEFAULT_SEGMENT_ARGS,
            },
        )

        data.update(
            data_from_precomputed(
                v5ds=self,
                precomputed_data=segmentation,
                segmentation=[segmentation],
                snapshot_idxs=[0],
            )
        )
        if not self.skip_segmentation:
            data["mean_phase_segmentation"] = torch.tensor(
                mean_phase_from_simple_segmentation([segmentation])
            ).unsqueeze(0)

            if self.empirical_data is not None:
                empirical_dist = get_empirical_dist(self, ridx)
                #  ~ 1, snapshots, ntheta(empirical_dist.shape[0])
                data["empirical"] = empirical_dist[
                    to_bin(data["mean_phase_segmentation"][0], empirical_dist.shape[0])
                ].unsqueeze(0)
                mask = data["mean_phase_segmentation"].isnan()
                data["empirical"][mask] = 1.0 / empirical_dist.shape[0]

        data["y_rad_binned"] = (
            to_bin(data["y_rad"], self.target_ntheta).unsqueeze(0).to(torch.long)
        )
        data["craft_y_rad_binned"] = (
            to_bin(data["craft_y_rad"], self.target_ntheta).unsqueeze(0).to(torch.long)
        )

        # convert to target dtype on CPU!
        for key in data:
            if isinstance(data[key], torch.Tensor) and data[key].dtype in (
                torch.float16,
                torch.float32,
                torch.float64,
            ):
                data[key] = data[key].to(self.target_dtype)
        return data

In [None]:
# Shape of mean_phase_segmentation for the first paired entry of sample 0
ds[0][0]["mean_phase_segmentation"].shape

In [None]:
from spf.dataset.segmentation import DEFAULT_SEGMENT_ARGS, segment_session

# Inference dataset built from the YAML config of the same April 2025 capture,
# plus one reference sample `a` from the stored dataset to compare against.
v5inf = v5inferencedataset(
    yaml_fn="/mnt/md2/2d_wallarray_v2_data/april_nuand/wallarrayv3_2025_04_04_21_06_04_nRX2_rx_random_circle_spacing0p08.yaml",
    model_config_fn="",
    vehicle_type="wallarray",
    nthetas=65,
    paired=True,
    skip_segmentation=True,
    gpu=False,
    n_parallel=8,
)
a = ds[0][0]

In [None]:
# Receiver index stored in the reference sample `a`
a["receiver_idx"]

In [None]:
# Push the raw fields of the reference sample through the inference path as
# sample 0 / receiver 0. `a` stays bound to the reference sample, because later
# cells (e.g. a["simple_segmentations"]) still read it.
b = v5inf.write_to_idx(0, 0, data_single_radio_to_raw(a))

In [None]:
# First snapshot's segment list of the reference sample.
# NOTE(review): requires `a` to still hold the sample dict from ds[0][0].
a["simple_segmentations"][0]

In [None]:
def compare_two_entries(a, b, skip_keys=None, rtol=1e-3, verbose=True):
    """Assert that sample dict ``b`` reproduces sample dict ``a``.

    Every key of ``a`` that is not in ``skip_keys`` must be present in ``b``.
    ``skip_keys`` defaults to ``training_only_keys``.

    * Tensors must have identical shapes and be elementwise close (relative
      tolerance ``rtol``). NaNs in the same position count as equal.
    * Lists are treated as per-snapshot segmentations. Every snapshot must
      contain the same segments (``start_idx`` / ``end_idx``).
    * Other value types are not compared.

    Args:
        a: reference sample dict.
        b: sample dict to check.
        skip_keys: keys of ``a`` to ignore.
        rtol: relative tolerance for tensor comparison.
        verbose: print each key as it is checked.

    Raises:
        AssertionError: on the first mismatch.
    """
    if skip_keys is None:
        skip_keys = training_only_keys
    for k, v in a.items():
        if k in skip_keys:
            continue
        assert k in b, f"{k} missing from b"
        vp = b[k]
        if verbose:
            print(k)
        if isinstance(v, torch.Tensor):
            assert isinstance(vp, torch.Tensor), f"{k}: expected tensor, got {type(vp)}"
            # Without a shape check, broadcasting could hide a mismatch.
            assert v.shape == vp.shape, f"{k}: shape {tuple(v.shape)} != {tuple(vp.shape)}"
            assert v.isclose(vp, rtol=rtol, equal_nan=True).all(), f"{v} {vp}"
        elif isinstance(v, list):
            assert len(v) == len(vp), f"{k}: {len(v)} snapshots != {len(vp)}"
            for snap_idx, (s, sp) in enumerate(zip(v, vp)):
                assert len(s) == len(sp), f"{k}[{snap_idx}]: {len(s)} != {len(sp)} segments"
                for seg, segp in zip(s, sp):
                    assert seg["start_idx"] == segp["start_idx"], f"{k}[{snap_idx}] start_idx"
                    assert seg["end_idx"] == segp["end_idx"], f"{k}[{snap_idx}] end_idx"

In [None]:
from spf.dataset.segmentation import DEFAULT_SEGMENT_ARGS, segment_session

# Re-run segmentation for one receiver by hand and inspect what
# data_from_precomputed builds from it.
receiver_idx = 0
x = data_single_radio_to_raw(ds[0][receiver_idx])
# render_session (called by write_to_idx) adds the rendered fields to `x` in
# place. write_to_idx's return value is not relied on here.
v5inf.write_to_idx(0, receiver_idx, x)

s = segment_session(
    # NOTE(review): signal_matrix appears to be (batch, snapshot, antenna, sample).
    # Take batch 0 / snapshot 0; the receiver was already chosen via ds[0][receiver_idx].
    x["signal_matrix"][0][0],
    gpu=False,
    skip_beamformer=False,
    skip_detrend=False,
    skip_segmentation=v5inf.skip_segmentation,
    **{
        "steering_vectors": v5inf.steering_vectors[receiver_idx],
        **DEFAULT_SEGMENT_ARGS,
    },
)

d = data_from_precomputed(
    v5ds=v5inf,
    precomputed_data=s,
    segmentation=[s],
    snapshot_idxs=[0],
)

s.keys(), d.keys()

In [None]:
# Segment one receiver's raw signal taken from an in-memory sample dict (raw["signal_matrix"])
from spf.dataset.segmentation import simple_segment
from spf.rf import beamformer_given_steering_nomean, reduce_theta_to_positive_y
from spf.sdrpluto.detrend import detrend_np
import numpy as np

from scipy.stats import trim_mean


def segment_single_session(
    raw,
    receiver_idx,
    skip_beamformer=False,
    skip_detrend=False,
    **kwrgs,
):
    """Segment one receiver's raw signal and summarize the signal windows.

    Args:
        raw: sample dict holding ``"signal_matrix"``. The array that gets
            segmented is ``raw["signal_matrix"][0][receiver_idx]``.
            NOTE(review): for per-receiver samples in this notebook that second
            index looks like the snapshot axis; confirm the layout.
        receiver_idx: second index into ``raw["signal_matrix"]``.
        skip_beamformer: if True, fill ``windowed_beamformer`` with NaNs instead
            of running the beamformer (no ``weighted_beamformer`` is produced).
        skip_detrend: if True, do not apply ``detrend_np`` first.
        **kwrgs: forwarded to ``simple_segment``. Must include
            ``steering_vectors`` (first dim = number of thetas) and
            ``window_size``.

    Returns:
        The dict from ``simple_segment`` with these entries added or replaced:

        * ``all_windows_stats``: transposed to ``(stats, windows)`` float16.
        * ``windowed_beamformer``: ``(windows, nthetas)`` float16.
        * ``weighted_beamformer``: only when the beamformer runs.
        * ``weighted_stats``: float32 ``[mean_phase, stddev, abs_signal]``,
          trimmed means over signal windows, or ``[-1, -1, -1]`` when no
          window is signal.
    """
    v = raw["signal_matrix"][0][receiver_idx]

    if not skip_detrend:
        v = detrend_np(v)

    segmentation_results = simple_segment(v, **kwrgs)

    segmentation_results["all_windows_stats"] = (
        segmentation_results["all_windows_stats"].astype(np.float16).T
    )

    _, windows = segmentation_results["all_windows_stats"].shape
    nthetas = kwrgs["steering_vectors"].shape[0]  # number of beamformer angle bins

    mask = segmentation_results["downsampled_segmentation_mask"]
    # Boolean copy for indexing: an integer 0/1 mask would otherwise select
    # elements 0 and 1 (fancy indexing) instead of masking windows.
    signal_windows = np.asarray(mask).astype(bool)

    if not skip_beamformer:
        # CPU beamformer, averaged within each window -> (windows, nthetas).
        windowed_beamformer = (
            beamformer_given_steering_nomean(
                steering_vectors=kwrgs["steering_vectors"],
                signal_matrix=v.astype(np.complex64),
            )
            .reshape(nthetas, -1, kwrgs["window_size"])
            .mean(axis=2)
            .T
        )

        # Session-level beamformer: average of the windows, weighted by the
        # segmentation mask (so only signal windows contribute). The 0.001
        # guards against division by zero.
        segmentation_results["weighted_beamformer"] = (
            windowed_beamformer.astype(np.float32) * mask[:, None]
        ).sum(axis=0) / (mask.sum() + 0.001)

        # float16 to save memory
        segmentation_results["windowed_beamformer"] = windowed_beamformer.astype(
            np.float16
        )
    else:
        segmentation_results["windowed_beamformer"] = np.full(
            (windows, nthetas), np.nan, dtype=np.float16
        )

    if mask.sum() > 0:
        # Trimmed mean (10% cut from each end) of the phase difference over
        # signal windows only.
        mean_phase = trim_mean(
            reduce_theta_to_positive_y(segmentation_results["all_windows_stats"][0])[
                signal_windows
            ].astype(np.float32),
            0.1,
        )

        # Trimmed means of the standard deviation and signal amplitude rows.
        stddev_and_abs_signal = trim_mean(
            segmentation_results["all_windows_stats"][1:][:, signal_windows].astype(
                np.float32
            ),
            0.1,
            axis=1,
        )

        segmentation_results["weighted_stats"] = np.array(
            [mean_phase, stddev_and_abs_signal[0], stddev_and_abs_signal[1]],
            dtype=np.float32,
        )
    else:
        # No signal windows: placeholder values, same dtype as the normal path.
        segmentation_results["weighted_stats"] = np.array(
            [-1, -1, -1], dtype=np.float32
        )

    return segmentation_results

In [None]:
# Raw signal matrix of the sample `x` used in the manual segmentation check above
x["signal_matrix"]

In [None]:
# Print every raw v5 key and, when present, its value in sample 0 / receiver 0.
# Fetch the sample once; ds[0] re-renders on every access.
sample0_rx0 = ds[0][0]
for k in v5rx_f64_keys + v5rx_2xf64_keys:
    print(k)
    if k in sample0_rx0:
        print(sample0_rx0[k])

In [None]:
# October 2024 capture used to compare segmentation caches/versions.
# NOTE(review): this zarr is on /mnt/md0 while the caches are on /mnt/md2; confirm intended.
ds_fn = "/mnt/md0/2d_wallarray_v2_data/oct_batch2/wallarrayv3_2024_10_26_01_14_28_nRX2_rx_circle_spacing0p075.zarr"

# Variant 1: segmentation version 3.5 from the chunked 3.5 cache.
ds1 = v5spfdataset(
    ds_fn,
    segmentation_version=3.5,
    precompute_cache="/mnt/md2/cache/precompute_cache_3p5_chunk1/",
    nthetas=65,
    snapshots_per_session=1,
    paired=True,
    ignore_qc=True,
    gpu=False,
    n_parallel=8,
)

In [None]:
# Variant 2: segmentation version 3.6 from the 3p6 cache.
ds2 = v5spfdataset(
    ds_fn,
    segmentation_version=3.6,
    precompute_cache="/mnt/md2/cache/precompute_cache_3p6/",
    nthetas=65,
    snapshots_per_session=1,
    paired=True,
    ignore_qc=True,
    gpu=False,
    n_parallel=8,
)

In [None]:
# Variant 3: segmentation version 3.6 again, from the separate 3p6x cache.
ds3 = v5spfdataset(
    ds_fn,
    segmentation_version=3.6,
    precompute_cache="/mnt/md2/cache/precompute_cache_3p6x/",
    nthetas=65,
    snapshots_per_session=1,
    paired=True,
    ignore_qc=True,
    gpu=False,
    n_parallel=8,
)

In [None]:
def comparez(a, b, rtol=1e-3, min_match_frac=0.99):
    """Print the tensor fields of sample ``a`` that ``b`` does not reproduce.

    For each key of ``a``:

    * If ``b`` lacks the key, prints ``B missing <key>``.
    * Non-tensor values and tensors with no finite element (e.g. inf
      ground-truth placeholders) are skipped.
    * If ``b``'s value is not a tensor of the same shape, prints a fail line.
    * Otherwise the field passes when more than ``min_match_frac`` of its
      elements are close (relative tolerance ``rtol``; NaNs in matching
      positions count as close). On failure, prints the same match fraction
      that was tested.
    """
    for k, v in a.items():
        if k not in b:
            print("B missing", k)
            continue
        if not isinstance(v, torch.Tensor) or (~v.isfinite()).all():
            continue
        vb = b[k]
        if not isinstance(vb, torch.Tensor) or vb.shape != v.shape:
            # Avoid silent broadcasting between mismatched shapes.
            print(k, "fail", "shape/type mismatch", v, vb)
            continue
        match_frac = v.isclose(vb.to(v.dtype), rtol=rtol, equal_nan=True).float().mean()
        if match_frac <= min_match_frac:
            print(k, "fail", match_frac, v, vb)


# Compare the first 30 samples across the three cache/segmentation variants,
# for every receiver. Each dataset index is rendered once, not once per receiver.
for idx in range(30):
    s1, s2, s3 = ds1[idx], ds2[idx], ds3[idx]
    for ridx in range(len(s1)):
        comparez(s1[ridx], s2[ridx])
        comparez(s2[ridx], s3[ridx])

In [None]:
# Sanity check: inf is not finite, so this evaluates to tensor(True).
# comparez uses the same test to skip all-inf fields.
~torch.tensor(torch.inf).isfinite()

In [None]:
# windowed_beamformer of sample 0 / first paired entry, segmentation 3.6 (3p6 cache)
ds2[0][0]["windowed_beamformer"]

In [None]:
# windowed_beamformer of sample 0 / first paired entry, segmentation 3.5 (3p5 cache)
ds1[0][0]["windowed_beamformer"]