In [None]:
import os
import pickle

import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset

import sys

sys.path.append("/Users/miskodzamba/Dropbox/research/gits/")
sys.path.append(
    "/Users/miskodzamba/Dropbox/research/gits/spf/software/model_training_and_inference"
)
import importlib
import spf.software.model_training_and_inference.utils.rf as rf


from spf.software.model_training_and_inference.utils.image_utils import (
    detector_positions_to_theta_grid,
    labels_to_source_images,
    radio_to_image,
)
from spf.software.model_training_and_inference.utils.spf_generate import (
    generate_session,
)
from compress_pickle import dump, load

# Maps each named model-output field to its column indices.
# TODO: maybe this should get moved to the dataset part...
output_cols = dict(
    src_pos=[0, 1],
    src_theta=[2],
    src_dist=[3],
    det_delta=[4, 5],
    det_theta=[6],
    det_space=[7],
    src_v=[8, 9],
)

# Maps each named model-input field to its column indices.
input_cols = dict(
    det_pos=[0, 1],
    time=[2],
    space_delta=[3, 4],
    space_theta=[5],
    space_dist=[6],
    det_theta2=[7],
)

In [None]:
class SessionsDatasetReal(Dataset):
    """Map-style dataset of fixed-length snapshot windows from recorded files.

    Every file in ``root_dir`` whose name contains ``.npy`` is opened read-only
    with ``np.memmap`` as a float32 array of shape
    ``(snapshots_in_file, nthetas + 5)``. Each file is cut into consecutive,
    non-overlapping windows of ``snapshots_in_sample`` rows; one window is one
    dataset item.

    Row layout as consumed by ``__getitem__`` (the original note names the
    first five columns ``time_step, x, y, mean_angle, _mean_angle``):
      * column 0      -> ``"time_stamps"``
      * columns 1:3   -> ``"source_positions_at_t"``
      * columns 3, 4  -> not read by this class
      * columns 5:    -> ``"beam_former_outputs_at_t"``

    NOTE(review): the files are read as raw float32 through ``np.memmap``, i.e.
    no ``.npy`` header is parsed -- confirm the recordings are raw dumps.
    """

    def __init__(
        self,
        root_dir,
        snapshots_in_file=400000,
        nthetas=65,
        snapshots_in_sample=128,
        nsources=1,
        width=3000,
        receiver_pos_x=1352.7,
        receiver_pos_y=2647.7,
        receiver_spacing=60.0,
    ):
        """
        Arguments:
          root_dir (string): Directory containing the recording files.
          snapshots_in_file (int): Number of rows each file is mapped with.
          nthetas (int): Number of beam-former outputs per row; also the length
            of ``self.thetas`` (evenly spaced over [-pi, pi]).
          snapshots_in_sample (int): Number of consecutive rows per item.
          nsources (int): Must be 1; multiple sources are not implemented.
          width (number): Stored in ``self.args["width"]`` and repeated for
            every snapshot as ``"width_at_t"``.
          receiver_pos_x, receiver_pos_y (float): Detector position; the two
            receivers sit at ``x -/+ receiver_spacing / 2`` with the same y.
          receiver_spacing (float): Distance between the two receivers.
        """
        assert nsources == 1  # TODO implement more
        self.root_dir = root_dir
        self.nthetas = nthetas
        self.thetas = np.linspace(-np.pi, np.pi, self.nthetas)
        self.args = {
            "width": width,
        }
        self.receiver_pos = np.array(
            [
                [receiver_pos_x - receiver_spacing / 2, receiver_pos_y],
                [receiver_pos_x + receiver_spacing / 2, receiver_pos_y],
            ]
        )
        self.detector_position = np.array([[receiver_pos_x, receiver_pos_y]])
        self.snapshots_in_file = snapshots_in_file
        self.snapshots_in_sample = snapshots_in_sample
        # Filter on the file name only: testing the joined path would accept
        # every file whenever root_dir itself happens to contain ".npy".
        self.filenames = sorted(
            "%s/%s" % (self.root_dir, name)
            for name in os.listdir(self.root_dir)
            if ".npy" in name
        )
        self.datas = [
            np.memmap(
                filename,
                dtype="float32",
                mode="r",
                shape=(self.snapshots_in_file, self.nthetas + 5),
            )
            for filename in self.filenames
        ]
        # Only whole windows are served; trailing rows that do not fill a
        # complete window are dropped.
        self.samples_per_file = [
            d.shape[0] // self.snapshots_in_sample for d in self.datas
        ]
        # cumsum_samples_per_file[i] is the global index of file i's first item.
        self.cumsum_samples_per_file = np.cumsum([0] + self.samples_per_file)
        self.len = sum(self.samples_per_file)
        # Constant per-window arrays, built once and shared by every item.
        self.zeros = np.zeros((self.snapshots_in_sample, 5))
        self.ones = np.ones((self.snapshots_in_sample, 5))
        self.widths = np.ones((self.snapshots_in_sample, 1)) * self.args["width"]
        self.halfpis = np.ones((self.snapshots_in_sample, 1)) * np.pi / 2

    def idx_to_fileidx_and_startidx(self, idx):
        """Translate a global item index into ``(file_idx, start_row)``.

        Returns ``None`` when ``idx`` is outside ``[0, len(self))``.
        """
        if idx < 0 or idx >= self.len:
            return None
        # Rightmost file whose first global index is <= idx. side="right" makes
        # an idx that starts a file land in that file and skips files that
        # contribute zero items.
        file_idx = (
            int(np.searchsorted(self.cumsum_samples_per_file, idx, side="right")) - 1
        )
        offset_in_file = int(idx - self.cumsum_samples_per_file[file_idx])
        return file_idx, offset_in_file * self.snapshots_in_sample

    def __len__(self):
        """Total number of windows across all files."""
        return self.len

    def __getitem__(self, idx):
        """Return the ``idx``-th window as a dict of arrays.

        Negative indices count from the end.

        Raises:
          IndexError: if ``idx`` is out of range (this is also what ends plain
            ``for item in dataset`` iteration).
        """
        if idx < 0:
            idx += self.len
        location = self.idx_to_fileidx_and_startidx(idx)
        if location is None:
            raise IndexError(
                "index out of range for dataset with %d samples" % self.len
            )
        fileidx, startidx = location
        m = self.datas[fileidx][startidx : startidx + self.snapshots_in_sample]
        return {
            "broadcasting_positions_at_t": self.ones[:, [0]],  # TODO multi source
            "source_positions_at_t": m[:, 1:3],
            "source_velocities_at_t": self.zeros[:, :2],  # TODO calc velocity,
            "receiver_positions_at_t": np.broadcast_to(
                self.receiver_pos[None], (m.shape[0], 2, 2)
            ),
            "beam_former_outputs_at_t": m[:, 5:],
            "thetas_at_t": self.thetas,
            "time_stamps": m[:, 0],
            "width_at_t": self.widths,
            "detector_orientation_at_t": self.halfpis,  # np.arctan2(1,0)=np.pi/2
            "detector_position_at_t": np.broadcast_to(
                self.detector_position, (m.shape[0], 2)
            ),
            # Placeholders: source angle/distance are not computed here.
            "source_theta_at_t": self.zeros[:, [0]],
            "source_distance_at_t": self.zeros[:, [0]],
        }

In [None]:
# Directory holding the recorded sessions. Override it with the SPF_DATA_DIR
# environment variable; the default is the original author's local path.
data_root = os.environ.get(
    "SPF_DATA_DIR", "/Users/miskodzamba/Dropbox/research/gits/spf/software/data"
)
sdr = SessionsDatasetReal(data_root)
# A notebook cell only displays its last expression, so print the index
# sanity check explicitly instead of leaving it as a discarded tuple.
print(
    sdr.len,
    sdr.idx_to_fileidx_and_startidx(2),
    sdr.idx_to_fileidx_and_startidx(sdr.len - 1),
)
# Summarise one sample by the shape of each entry rather than dumping arrays.
{name: np.shape(value) for name, value in sdr[2].items()}

In [None]:
# Scratch check: indexing a column with a list ([0]) keeps the column axis,
# so the result has shape (128, 1) rather than (128,).
np.ones((128, 5))[:, [0]]