In [1]:
import numpy as np
import webdataset as wbs
import pickle

from icecube.utils.coordinate import create_bins

def get_features_truth(src):
    features, truth = pickle.loads(src["pickle"])

    x = np.concatenate(
        [features, np.zeros((features.shape[0], 1))], axis=1
    )
    dtype = [
        ("x", "float16"),
        ("y", "float16"),
        ("z", "float16"),
        ("time", "float16"),
        ("charge", "float16"),
        ("auxiliary", "float16"),
        ("rank", "short"),
    ]

    n_pulses = len(x)
   
    event_x = np.zeros(n_pulses, dtype)

    event_x["x"] = x[:, 0]
    event_x["y"] = x[:, 1]
    event_x["z"] = x[:, 2]
    event_x["time"] = x[:, 3] - x[:, 3].min()
    event_x["charge"] = x[:, 4]
    event_x["auxiliary"] = x[:, 5]

    if n_pulses > 96:
        # Find valid time window
        t_peak = event_x["time"][event_x["time"].argmax()]
        t_valid_min = t_peak - 6199.700247193777
        t_valid_max = t_peak + 6199.700247193777

        t_valid = (event_x["time"] > t_valid_min) * (
            event_x["time"] < t_valid_max
        )

        # rank
        event_x["rank"] = 2 * (1 - event_x["auxiliary"]) + (t_valid)

        # sort by rank and charge (important goes to backward)
        event_x = np.sort(event_x, order=["rank", "charge"])

        # pick-up from backward
        event_x = event_x[-96 :]

        # resort by time
        event_x = np.sort(event_x, order="time")

    event_x["x"] /= 600
    event_x["y"] /= 600
    event_x["z"] /= 600
    event_x["time"] /= 1000
    event_x["charge"] /= 300

    event_y = truth.astype(dtype="float16")[:, ::-1]
    code = _y_to_angle_code(event_y)[0]

    placeholder = np.zeros(
        (len(event_x), 6), dtype=np.float16
    )
    placeholder[:n_pulses, 0] = event_x["x"]
    placeholder[:n_pulses, 1] = event_x["y"]
    placeholder[:n_pulses, 2] = event_x["z"]
    placeholder[:n_pulses, 3] = event_x["time"]
    placeholder[:n_pulses, 4] = event_x["charge"]
    placeholder[:n_pulses, 5] = event_x["auxiliary"]

    return placeholder, event_y.squeeze(0), code, np.clip(288, a_min=0, a_max=96)

azimuth_edges, zenith_edges = create_bins(19)

def _y_to_angle_code(y):
    azimuth_code = (y[:, 0] >azimuth_edges[1:].reshape((-1, 1))).sum(
        axis=0
    )
    zenith_code = (y[:, 1] > zenith_edges[1:].reshape((-1, 1))).sum(
        axis=0
    )
    angle_code = 19 * azimuth_code + zenith_code

    return angle_code

In [3]:
trainset = (
    wbs.WebDataset("../../input/webdatasets/batch-{100..145}.tar")
)

In [5]:
from tqdm import tqdm

for t in tqdm(trainset):
    ex, ey, ecode, lengths = get_features_truth(t)
    break

0it [00:00, ?it/s]


In [6]:
ex

array([[-5.8008e-01,  7.5244e-01,  7.5146e-01,  0.0000e+00,  3.5839e-03,
         1.0000e+00],
       [ 2.9077e-01,  5.2588e-01,  2.8735e-01,  1.2299e-01,  6.7520e-03,
         1.0000e+00],
       [ 2.9077e-01,  5.2588e-01,  2.8735e-01,  1.6003e-01,  1.2503e-03,
         1.0000e+00],
       [ 2.0837e-01, -2.1875e-01, -5.9570e-01,  3.7793e-01,  3.4180e-03,
         1.0000e+00],
       [ 5.4443e-01, -3.4863e-01,  5.4688e-01,  5.3516e-01,  3.2501e-03,
         1.0000e+00],
       [-6.7188e-01,  5.8174e-03,  6.0938e-01,  8.9404e-01,  5.8365e-04,
         1.0000e+00],
       [-4.4824e-01,  5.9033e-01,  1.5515e-01,  1.1016e+00,  1.9169e-03,
         1.0000e+00],
       [-4.4824e-01,  5.9033e-01,  1.5515e-01,  1.1152e+00,  2.5826e-03,
         1.0000e+00],
       [-1.2964e-01, -9.0576e-02, -2.9688e-01,  1.1973e+00,  2.7504e-03,
         1.0000e+00],
       [-1.6846e-01,  8.1689e-01,  6.6553e-01,  1.5713e+00,  3.4180e-03,
         1.0000e+00],
       [-3.9160e-01,  2.3413e-01,  3.2300e-01,  1.

In [12]:
for i in range(50):
    print(np.load("/media/eden/sandisk/projects/icecube/input/preprocessed/archive/pp_mpc96_n7_batch_100.npz")["x"][i].shape)

(96, 7)
(96, 7)
(96, 7)
(96, 7)
(96, 7)
(96, 7)
(96, 7)
(96, 7)
(96, 7)
(96, 7)
(96, 7)
(96, 7)
(96, 7)
(96, 7)
(96, 7)
(96, 7)
(96, 7)
(96, 7)
(96, 7)
(96, 7)
(96, 7)
(96, 7)
(96, 7)
(96, 7)
(96, 7)
(96, 7)
(96, 7)
(96, 7)
(96, 7)
(96, 7)
(96, 7)
(96, 7)
(96, 7)
(96, 7)
(96, 7)
(96, 7)
(96, 7)
(96, 7)
(96, 7)
(96, 7)
(96, 7)
(96, 7)
(96, 7)
(96, 7)
(96, 7)
(96, 7)
(96, 7)
(96, 7)
(96, 7)
(96, 7)


In [3]:
from icecube.data.datamodule import EventDataModule

dm = EventDataModule(
    24,
    data_dir=None,
    train_files="../../input/webdatasets/batch-{051..145}.tar",
    val_files="../../input/webdatasets/batch-{146..150}.tar",
    batch_size=32,
    num_workers=4,
    batch_ids=[100, 104],
    file_format="/media/eden/sandisk/projects/icecube/input/preprocessed/pp_mpc96_n7_batch_{batch_id}.npz",
)


In [4]:
dm.setup()

In [5]:
samples = []

for i, data in enumerate(dm.trainset):
    samples.append(data)
    if i == 50: break

In [7]:
import numpy as np

npz = np.load("/media/eden/sandisk/projects/icecube/input/preprocessed/pointpicker_mpc128_n9_batch_51.npz")

In [11]:
x = npz["x"][0]
y = npz["y"][0]
x, y

(array([[  0.   ,   0.525,   1.   , ...,  62.5  ,   8.5  ,   0.   ],
        [321.   ,   2.824,   1.   , ...,  62.5  ,   8.5  ,   0.   ],
        [362.   ,   0.975,   1.   , ...,  62.5  ,   8.5  ,   0.   ],
        ...,
        [  0.   ,   0.   ,  -1.   , ...,   0.   ,   0.   ,   0.   ],
        [  0.   ,   0.   ,  -1.   , ...,   0.   ,   0.   ,   0.   ],
        [  0.   ,   0.   ,  -1.   , ...,   0.   ,   0.   ,   0.   ]],
       dtype=float16),
 array([5.85 , 0.927], dtype=float16))

In [8]:
import torch
import numpy as np

batch = next(iter(dm.train_dataloader()))

In [10]:
import torch


x = torch.nn.utils.rnn.pack_padded_sequence(batch[0], batch[3], batch_first=True, enforce_sorted=False)

In [13]:
m = torch.nn.LSTM(6, 128, batch_first=True).float()

In [16]:
output, _ = m(x.float())

In [18]:
output[:, -1]

TypeError: tuple indices must be integers or slices, not tuple