In [None]:
import os
import sys

# Root of the spf checkout. It is appended to sys.path so the `spf.*` imports
# in the next cell resolve. Set the SPF_REPO_ROOT environment variable to use a
# different checkout; the default is the original author's local path.
repo_root = os.environ.get(
    "SPF_REPO_ROOT", "/Users/miskodzamba/Dropbox/research/gits/spf/"
)

# Guard so re-running this cell does not append duplicate entries.
if repo_root not in sys.path:
    sys.path.append(repo_root)

In [None]:
from torch.utils.data import Dataset
import yaml
import torch
from spf.data_collector import rx_config_from_receiver_yaml
from spf.dataset.spf_dataset import pi_norm
from spf.rf import precompute_steering_vectors
from spf.utils import zarr_open_from_lmdb_store
from spf.dataset.v5_data import v5rx_f64_keys, v5rx_2xf64_keys
import numpy as np


class v5spfdataset(Dataset):
    """Map-style dataset over a v5 SPF recording (a ``.zarr`` store plus a ``.yaml`` config).

    Items are flattened over (session, receiver). Flat index ``idx`` maps to
    receiver ``idx % n_receivers`` and session ``idx // n_receivers``, so
    consecutive indices cycle through all receivers of one session before moving
    on to the next session.
    """

    def __init__(self, prefix, nthetas=64 + 1):
        """Open the recording and precompute per-receiver steering vectors.

        Args:
            prefix: Path of the recording, with or without a trailing ``.zarr``.
                Both ``<prefix>.zarr`` and ``<prefix>.yaml`` are derived from it.
            nthetas: Forwarded as ``spacing`` to ``precompute_steering_vectors``
                (presumably the number of candidate angles -- confirm).

        Raises:
            AssertionError: if receivers do not have exactly 2 antennas, or if
                the receivers' signal matrices disagree in shape.
        """
        # Strip only a trailing ".zarr"; str.replace would also mangle
        # ".zarr" occurring elsewhere in the path.
        if prefix.endswith(".zarr"):
            prefix = prefix[: -len(".zarr")]
        self.zarr_fn = f"{prefix}.zarr"
        self.yaml_fn = f"{prefix}.yaml"
        self.z = zarr_open_from_lmdb_store(self.zarr_fn)
        # Context manager so the YAML file handle is closed promptly.
        with open(self.yaml_fn, "r") as yaml_file:
            self.yaml_config = yaml.safe_load(yaml_file)

        self.n_receivers = len(self.yaml_config["receivers"])

        self.rx_configs = [
            rx_config_from_receiver_yaml(receiver)
            for receiver in self.yaml_config["receivers"]
        ]

        # One zarr group per receiver, named r0, r1, ...
        self.receiver_data = [
            self.z.receivers[f"r{ridx}"] for ridx in range(self.n_receivers)
        ]

        self.signal_matrices = [
            receiver.signal_matrix for receiver in self.receiver_data
        ]

        # The first receiver's signal matrix defines the expected
        # (n_sessions, n_antennas, session_length) shape for all receivers.
        self.n_sessions, self.n_antennas_per_receiver, self.session_length = (
            self.signal_matrices[0].shape
        )
        assert self.n_antennas_per_receiver == 2
        for ridx in range(self.n_receivers):
            assert self.signal_matrices[ridx].shape == (
                self.n_sessions,
                self.n_antennas_per_receiver,
                self.session_length,
            )

        self.steering_vectors = [
            precompute_steering_vectors(
                receiver_positions=rx_config.rx_pos,
                carrier_frequency=rx_config.lo,
                spacing=nthetas,
            )
            for rx_config in self.rx_configs
        ]

        # Per-session fields read from each receiver group by render_session.
        self.keys_per_session = v5rx_f64_keys + v5rx_2xf64_keys + ["signal_matrix"]

    def __len__(self):
        """Total number of (session, receiver) samples."""
        return self.n_sessions * self.n_receivers

    def render_session(self, receiver_idx, session_idx):
        """Load all per-session fields of one receiver and convert them to tensors.

        Returns:
            dict keyed by ``self.keys_per_session`` plus ``"estimated_theta"``.
            Array values are wrapped with ``torch.from_numpy``, which shares memory
            with the numpy array. Scalar values (numpy or Python numbers) become
            1-element tensors built with ``torch.Tensor``, which uses torch's
            default float dtype.
        """
        receiver = self.receiver_data[receiver_idx]
        data = {key: receiver[key][session_idx] for key in self.keys_per_session}

        data["estimated_theta"] = self.get_theta(data)
        # torch.from_numpy only accepts ndarrays, so every scalar type (not just
        # float64) has to be boxed into a 1-element tensor instead.
        return {
            key: (
                torch.Tensor([value])
                if isinstance(value, (np.generic, float, int))
                else torch.from_numpy(value)
            )
            for key, value in data.items()
        }

    def get_theta(self, session):
        """Angle of the transmitter as seen from the receiver, relative to the receiver heading.

        Args:
            session: mapping with ``rx_theta_in_pis``, ``tx_pos_x_mm``,
                ``tx_pos_y_mm``, ``rx_pos_x_mm`` and ``rx_pos_y_mm`` entries.

        Returns:
            ``pi_norm(rx_to_tx_theta - rx_theta_in_pis * pi)``, where
            ``rx_theta_in_pis`` is the receiver heading in units of pi.
        """
        rx_theta_in_pis = session["rx_theta_in_pis"]
        tx_pos = np.array([session["tx_pos_x_mm"], session["tx_pos_y_mm"]])
        rx_pos = np.array([session["rx_pos_x_mm"], session["rx_pos_y_mm"]])

        # compute the angle of the tx with respect to rx
        d = tx_pos - rx_pos

        # NOTE(review): arctan2 receives (dx, dy), not numpy's (y, x) order.
        # The angle is therefore 0 along +y and +pi/2 along +x. Presumably this
        # matches the rx_theta convention -- confirm.
        rx_to_tx_theta = np.arctan2(d[0], d[1])
        return pi_norm(rx_to_tx_theta - rx_theta_in_pis * np.pi)

    def __getitem__(self, idx):
        """Return the rendered sample for flat index ``idx``.

        Negative indices count from the end. Out-of-range indices raise
        ``IndexError``, which is what Python's sequence iteration protocol
        (``for x in ds``, ``list(ds)``) uses to stop.
        """
        n_items = len(self)
        flat_idx = idx + n_items if idx < 0 else idx
        if not 0 <= flat_idx < n_items:
            raise IndexError(
                f"index {idx} out of range for dataset of length {n_items}"
            )
        receiver_idx = flat_idx % self.n_receivers
        return self.render_session(receiver_idx, flat_idx // self.n_receivers)

    def get_sample_for_receiver(self):
        """Placeholder, not implemented yet; returns None."""
        pass

In [None]:
# Path prefix of the recording to load (v5spfdataset also accepts a trailing ".zarr").
# Alternative recording:
#   /Volumes/SPFData/missions/april5/wallarrayv3_2024_04_05_22_13_07_nRX2_rx_circle
DATASET_PREFIX = (
    "/Volumes/SPFData2/missions/april5/wallarrayv3_2024_05_06_19_04_15_nRX2_bounce"
)
ds = v5spfdataset(DATASET_PREFIX)

In [None]:
# Smoke test: render the first (up to) 1000 samples. The count is capped at
# len(ds) so a shorter recording does not index past the end.
k = [ds[sample_idx] for sample_idx in range(min(1000, len(ds)))]

In [None]:
# Scratch check: shape of a 2-element 1-D numpy array, displayed as (2,).
np.asarray([1, 2]).shape

In [None]:
# Scratch check: joining two 1-element tensors end-to-end gives torch.Size([2]).
# Presumably this probes how the 1-element scalar tensors from render_session combine.
torch.cat((torch.Tensor([1]), torch.Tensor([2]))).shape