In [1]:
import gc
from pathlib import Path
from typing import Callable

import torch
import numpy as np
import pandas as pd
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning import Trainer, LightningModule, LightningDataModule
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import EarlyStopping
from torchmetrics import Metric, Accuracy

from icecube import constants

In [2]:
# --- Input data -------------------------------------------------------------
# Path template of the preprocessed per-batch .npz files; `{batch_id}` is
# filled in when the training data is read. Each file is read through its
# "x" (features) and "y" (target angles) entries.
point_picker_format = '../../input/preprocessed/pointpicker_mpc128_n9_batch_{batch_id:d}.npz'

# --- Output -----------------------------------------------------------------
# Intended location for the trained model.
# NOTE(review): not referenced by any other cell visible in this notebook.
model_output_path = "../../models/PointPicker_mpc128bin16_LSTM160DENSE0"

# --- Target discretisation ----------------------------------------------------
# Number of bins per angle; the classifier predicts bin_num * bin_num classes.
bin_num = 16

# Inclusive range of batch ids used for training.
train_batch_id_min = 51
train_batch_id_max = 55

train_batch_ids = list(range(train_batch_id_min, train_batch_id_max + 1))

# --- Model size ---------------------------------------------------------------
# Hidden size of the LSTM (passed as `hidden_size` in the model config).
LSTM_width = 160
# NOTE(review): DENSE_width is not referenced by any other visible cell.
DENSE_width = 0

# --- Training -----------------------------------------------------------------
# Fraction of the loaded events held out (from the end of the data) for validation.
validation_split = 0.05

In [4]:
def angular_dist_score(az_true, zen_true, az_pred, zen_pred, avg=True):
    '''
    Angular distance between true and predicted directions.

    Both directions are converted to cartesian unit vectors; their scalar
    product equals the cosine of the angle between them, so the inverse
    cosine (arccos) of that product is the angular distance.

    Parameters:
    -----------

    az_true : float or torch.Tensor
        true azimuth value(s) in radian
    zen_true : float or torch.Tensor
        true zenith value(s) in radian
    az_pred : float or torch.Tensor
        predicted azimuth value(s) in radian
    zen_pred : float or torch.Tensor
        predicted zenith value(s) in radian
    avg : bool, default True
        if True return the mean over all angular distances, otherwise
        return the per-element angular distances

    Returns:
    --------

    dist : torch.Tensor
        mean angular distance in radian (0-dim tensor) if `avg` is True,
        else a tensor holding one angular distance per input element

    Raises:
    -------

    ValueError
        if any input contains NaN or infinite values
    '''

    # Accept plain floats / arrays as promised by the docstring; tensors pass
    # through unchanged (no copy).
    az_true, zen_true, az_pred, zen_pred = (
        torch.as_tensor(angle) for angle in (az_true, zen_true, az_pred, zen_pred)
    )

    if not all(
        bool(torch.isfinite(angle).all())
        for angle in (az_true, zen_true, az_pred, zen_pred)
    ):
        raise ValueError("All arguments must be finite")

    # pre-compute all sine and cosine values
    sa1 = torch.sin(az_true)
    ca1 = torch.cos(az_true)
    sz1 = torch.sin(zen_true)
    cz1 = torch.cos(zen_true)

    sa2 = torch.sin(az_pred)
    ca2 = torch.cos(az_pred)
    sz2 = torch.sin(zen_pred)
    cz2 = torch.cos(zen_pred)

    # scalar product of the two cartesian vectors (x = sz*ca, y = sz*sa, z = cz)
    scalar_prod = sz1 * sz2 * (ca1 * ca2 + sa1 * sa2) + (cz1 * cz2)

    # The scalar product of two unit vectors lies in [-1, 1]; clip to protect
    # arccos from values pushed slightly outside by finite precision.
    scalar_prod = torch.clip(scalar_prod, -1, 1)

    # arccos already returns values in [0, pi], so no abs() is needed.
    ae = torch.arccos(scalar_prod)

    # torch has no `average`; `mean` is the reduction intended here.
    return torch.mean(ae) if avg else ae

In [5]:
# Azimuth: bin_num equal-width bins spanning [0, 2*pi].
azimuth_edges = np.linspace(0, 2 * np.pi, bin_num + 1)
# Equal-width zenith edges over [0, pi].
# NOTE(review): not referenced by the other cells visible in this notebook.
zenith_edges_flat = np.linspace(0, np.pi, bin_num + 1)

# Zenith: edges are spaced so that consecutive edges satisfy
#   cos(zen_before) - cos(zen_now) = 2 / bin_num,
# i.e. every bin spans the same interval in cos(zenith).
zenith_edges = [0]
while len(zenith_edges) < bin_num:
    zenith_edges.append(np.arccos(np.cos(zenith_edges[-1]) - 2 / bin_num))
# Close the last bin exactly at pi rather than at an accumulated value.
zenith_edges.append(np.pi)
zenith_edges = np.array(zenith_edges)

In [6]:
def y_to_onehot(batch_y):
    """
    One-hot encode (azimuth, zenith) targets over the direction bins.

    Parameters
    ----------
    batch_y : np.ndarray
        2-D array whose column 0 is compared against `azimuth_edges` and
        whose column 1 is compared against `zenith_edges`.

    Returns
    -------
    np.ndarray
        Array of shape (len(batch_y), bin_num * bin_num) filled with zeros
        except for a single 1 per row at index
        `bin_num * azimuth_bin + zenith_bin`.
    """
    # A bin index is the number of upper bin edges strictly below the angle.
    azimuth_idx = (batch_y[:, :1] > azimuth_edges[1:]).sum(axis=1)
    zenith_idx = (batch_y[:, 1:2] > zenith_edges[1:]).sum(axis=1)
    flat_idx = azimuth_idx * bin_num + zenith_idx

    # Scatter a single 1 into each row.
    n_rows = flat_idx.size
    onehot = np.zeros((n_rows, bin_num * bin_num))
    onehot[np.arange(n_rows), flat_idx] = 1

    return onehot

In [7]:
# Lower (0) and upper (1) edges of every (azimuth, zenith) bin, flattened so
# that the zenith bin varies fastest: index = bin_num * azimuth_bin + zenith_bin.
angle_bin_zenith0 = np.tile(zenith_edges[:-1], bin_num)
angle_bin_zenith1 = np.tile(zenith_edges[1:], bin_num)
angle_bin_azimuth0 = np.repeat(azimuth_edges[:-1], bin_num)
angle_bin_azimuth1 = np.repeat(azimuth_edges[1:], bin_num)

# Shared factors of the surface integrals below.
angle_bin_azimuth_width = angle_bin_azimuth1 - angle_bin_azimuth0
# Integral of sin^2(zen) over the zenith range: [zen / 2 - sin(2 zen) / 4].
angle_bin_zenith_sin2 = (angle_bin_zenith1 - angle_bin_zenith0) / 2 - (np.sin(2 * angle_bin_zenith1) - np.sin(2 * angle_bin_zenith0)) / 4

# Solid angle of each bin: integral of sin(zen) d(zen) d(az).
angle_bin_area = angle_bin_azimuth_width * (np.cos(angle_bin_zenith0) - np.cos(angle_bin_zenith1))

# Integral of the unit vector (sin zen cos az, sin zen sin az, cos zen)
# over each bin, weighted with the solid-angle element sin(zen).
angle_bin_vector_sum_x = (np.sin(angle_bin_azimuth1) - np.sin(angle_bin_azimuth0)) * angle_bin_zenith_sin2
angle_bin_vector_sum_y = (np.cos(angle_bin_azimuth0) - np.cos(angle_bin_azimuth1)) * angle_bin_zenith_sin2
angle_bin_vector_sum_z = angle_bin_azimuth_width * ((np.cos(2 * angle_bin_zenith0) - np.cos(2 * angle_bin_zenith1)) / 4)

# Mean direction inside each bin = vector integral / bin area.
angle_bin_vector_mean_x = angle_bin_vector_sum_x / angle_bin_area
angle_bin_vector_mean_y = angle_bin_vector_sum_y / angle_bin_area
angle_bin_vector_mean_z = angle_bin_vector_sum_z / angle_bin_area

# Lookup table of shape (1, bin_num * bin_num, 3); the leading axis lets it
# broadcast against a batch of per-bin weights.
angle_bin_vector = np.stack(
    [angle_bin_vector_mean_x, angle_bin_vector_mean_y, angle_bin_vector_mean_z],
    axis=-1,
)[np.newaxis]

In [8]:
# Put the per-bin mean-direction table on the compute device.
# torch.as_tensor accepts the NumPy array (first run) as well as an existing
# tensor (cell re-run), so re-executing this cell neither warns nor fails.
# Falls back to the CPU when no CUDA device is present.
_angle_bin_device = "cuda" if torch.cuda.is_available() else "cpu"
angle_bin_vector = torch.as_tensor(angle_bin_vector, device=_angle_bin_device)

In [9]:
def pred_to_angle(pred, epsilon=1e-8):
    """
    Convert per-bin weights into one (azimuth, zenith) direction per row.

    Each row of `pred` weights the mean direction vectors of the
    bin_num * bin_num direction bins (`angle_bin_vector`); the weighted sum
    is normalised and converted back to spherical angles.

    Parameters
    ----------
    pred : torch.Tensor
        Tensor reshapeable to (-1, bin_num * bin_num) holding one weight per
        direction bin.
    epsilon : float, default 1e-8
        Weighted sums with a norm below this value are treated as degenerate
        and replaced by the fallback direction <1, 0, 0>.

    Returns
    -------
    azimuth : torch.Tensor
        Azimuth in radian, mapped into [0, 2*pi).
    zenith : torch.Tensor
        Zenith in radian, in [0, pi].
    """
    # Use the lookup table on whatever device `pred` lives on (no copy when it
    # is already there); also tolerates the table still being a NumPy array.
    bin_vectors = torch.as_tensor(angle_bin_vector, device=pred.device)

    # convert prediction to vector
    pred_vector = (pred.reshape((-1, bin_num * bin_num, 1)) * bin_vectors).sum(axis=1)

    # normalize; degenerate rows are divided by 1 to avoid dividing by ~0
    pred_vector_norm = (pred_vector**2).sum(axis=1).sqrt()
    mask = pred_vector_norm < epsilon
    safe_norm = torch.where(mask, torch.ones_like(pred_vector_norm), pred_vector_norm)
    pred_vector = pred_vector / safe_norm.reshape((-1, 1))

    # assign <1, 0, 0> to very small vectors (badly predicted)
    fallback = torch.tensor([1., 0., 0.], device=pred_vector.device, dtype=pred_vector.dtype)
    pred_vector = torch.where(mask.reshape((-1, 1)), fallback, pred_vector)

    # convert to angle; arctan2 returns (-pi, pi], shift negatives by 2*pi
    azimuth = torch.arctan2(pred_vector[:, 1], pred_vector[:, 0])
    azimuth = torch.where(azimuth < 0, azimuth + 2 * np.pi, azimuth)
    # Clamp guards arccos against a z component that rounding pushed
    # marginally outside [-1, 1], which would otherwise yield NaN.
    zenith = torch.arccos(torch.clamp(pred_vector[:, 2], -1.0, 1.0))

    return azimuth, zenith

In [10]:
print("Reading training data...")

# Collect the per-batch arrays and concatenate once at the end; growing the
# arrays with np.append inside the loop re-copies all previously loaded data
# on every iteration.
train_x_parts = []
train_y_parts = []
for batch_id in tqdm(train_batch_ids):
    # The context manager closes the .npz archive even if a key lookup fails.
    with np.load(point_picker_format.format(batch_id=batch_id)) as train_data_file:
        train_x_parts.append(train_data_file["x"])
        train_y_parts.append(train_data_file["y"])

if not train_x_parts:
    raise ValueError("train_batch_ids is empty; no training data was loaded")

train_x = np.concatenate(train_x_parts, axis=0)
train_y = np.concatenate(train_y_parts, axis=0)

# Release the per-batch copies now that the combined arrays exist.
del train_x_parts, train_y_parts
_ = gc.collect()

Reading training data...


100%|█████████████████████████████████████████████| 5/5 [00:02<00:00,  2.05it/s]


In [11]:
# Disabled manual per-feature scaling (kept for reference). The features are
# standardised with a StandardScaler in the following cells instead.
# train_x[:, :, 0] /= 1000  # time
# train_x[:, :, 1] /= 300  # charge
# train_x[:, :, 3:] /= 600  # space

# Encode every (azimuth, zenith) target as a one-hot vector over the
# bin_num * bin_num direction bins.
train_y_onehot = y_to_onehot(train_y)

In [12]:
from sklearn.preprocessing import StandardScaler

# Fit per-feature mean/std on the first 100000 events only. The last axis is
# treated as the feature axis; all leading axes are flattened into rows.
original_shape = train_x.shape
n_feature_columns = original_shape[-1]
scaler = StandardScaler()
scaler.fit(train_x[:100000].reshape(-1, n_feature_columns))

In [13]:
# Standardise every feature column with the fitted scaler, then restore the
# original array shape.
# NOTE: this rebinds `train_x`, so re-running the cell scales the data twice.
train_x = scaler.transform(train_x.reshape(-1, original_shape[-1])).reshape(original_shape)

In [14]:
# Hold out the last `validation_split` fraction of the events for validation.
num_valid = int(validation_split * len(train_x))

# Split on an explicit index instead of negative slices: with num_valid == 0,
# `a[-0:]` is the whole array and `a[:-0]` is empty, which would silently swap
# the training and validation sets.
num_train = max(len(train_x) - num_valid, 0)

valid_x = train_x[num_train:]
valid_y = train_y[num_train:]
valid_y_onehot = train_y_onehot[num_train:]

train_x = train_x[:num_train]
train_y = train_y[:num_train]
train_y_onehot = train_y_onehot[:num_train]

In [15]:
# Print a small table with the shape and memory footprint of every array.
def _memory_row(label, shape_text, size_text):
    """Format one table row: 16-char label, 24-char shape, then the size text."""
    return f"{label:16s}" + f"{shape_text:24s}" + size_text


_memory_rule = "-" * (16 + 24 + 8)
_memory_groups = (
    (("train_x", train_x), ("train_y", train_y), ("train_y_onehot", train_y_onehot)),
    (("valid_x", valid_x), ("valid_y", valid_y), ("valid_y_onehot", valid_y_onehot)),
)

print(_memory_row("data", "shape", f"{'mem [MB]':8s}"))
for _group in _memory_groups:
    for _label, _array in _group:
        # Size in MB, truncated to at most 8 characters to fit the column.
        print(_memory_row(_label, str(_array.shape), f"{_array.nbytes / 1024 / 1024:.4f}"[:8]))
    print(_memory_rule)

total = sum(_array.nbytes for _group in _memory_groups for _, _array in _group) / 1024 / 1024
print(_memory_row("total", "", f"{total:.4f}"[:8]))
print("        real RAM usage can be doubled...")

data            shape                   mem [MB]
train_x         (950000, 128, 9)        2087.402
train_y         (950000, 2)             3.6240
train_y_onehot  (950000, 256)           1855.468
------------------------------------------------
valid_x         (50000, 128, 9)         109.8633
valid_y         (50000, 2)              0.1907
valid_y_onehot  (50000, 256)            97.6562
------------------------------------------------
total                                   4154.205
        real RAM usage can be doubled...


In [16]:
# Sizes of axis 1 and axis 2 of the feature array, i.e. (going by the names)
# the number of pulses kept per event and the number of features per pulse.
max_pulse_count = train_x.shape[1]
n_features = train_x.shape[2]

In [17]:
# Wrap the NumPy arrays as tensors. torch.as_tensor shares the existing
# buffers instead of copying them (torch.tensor always copies, temporarily
# doubling the RAM held by these arrays) and is a no-op on a cell re-run,
# when the names already refer to tensors.
train_x = torch.as_tensor(train_x)
train_y = torch.as_tensor(train_y)
train_y_onehot = torch.as_tensor(train_y_onehot)

valid_x = torch.as_tensor(valid_x)
valid_y = torch.as_tensor(valid_y)
valid_y_onehot = torch.as_tensor(valid_y_onehot)

In [18]:
# Single source of truth for the run; the whole dict is also uploaded to
# Weights & Biases as the run configuration.
config = dict(
    # Keyword arguments forwarded to pytorch_lightning.Trainer.
    trainer=dict(
        accelerator="gpu",
        devices=1,
        max_epochs=100,
        precision=16,
    ),
    # Keyword arguments forwarded to LSTMClassifier.
    model=dict(
        # Derived from the data instead of hard-coding the feature count.
        input_size=n_features,
        # One class per (azimuth bin, zenith bin) pair.
        output_size=bin_num**2,
        hidden_size=LSTM_width,
        bidirectional=True,
        num_layers=1,
        bias=False,
        batch_first=True,
        dropout=0,
        lr=5e-4,
    ),
    # Keyword arguments forwarded to WandbLogger.
    wandb=dict(
        project="icecube",
        # Keep the run name in sync with the batch range actually loaded.
        name=f"pointpicker_lstm_batch_{train_batch_id_min}_{train_batch_id_max}",
    ),
    # DataLoader settings; `train_set` is recorded for bookkeeping only.
    data=dict(
        batch_size=256,
        num_workers=16,
        train_set=train_batch_ids,
    ),
)

In [19]:
# Each sample is a triple: (features, target angles, one-hot target).
trainset = TensorDataset(train_x, train_y, train_y_onehot)
validset = TensorDataset(valid_x, valid_y, valid_y_onehot)

# Options shared by both loaders; only the shuffling differs.
_loader_options = dict(
    batch_size=config["data"]["batch_size"],
    num_workers=config["data"]["num_workers"],
    pin_memory=True,
)

trainloader = DataLoader(dataset=trainset, shuffle=True, **_loader_options)
validloader = DataLoader(dataset=validset, shuffle=False, **_loader_options)


In [20]:
class LSTMClassifier(LightningModule):
    """
    LSTM followed by a linear layer that classifies a pulse sequence into one
    of `output_size` direction bins.

    Training and validation batches are triples
    `(x, y, y_onehot)`: input sequences, target angles (azimuth, zenith) and
    the one-hot encoded direction bin.

    Parameters
    ----------
    input_size : int
        Number of features per sequence element.
    output_size : int
        Number of classes (direction bins).
    hidden_size : int
        Hidden size of the LSTM.
    num_layers, bias, batch_first, dropout, bidirectional
        Forwarded unchanged to `torch.nn.LSTM`.
    lr : float
        Learning rate of the Adam optimizer.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        hidden_size: int,
        num_layers: int = 1,
        bias: bool = False,
        batch_first: bool = True,
        dropout: float = 0,
        bidirectional: bool = False,
        lr: float = 3e-4,
    ):
        super().__init__()
        # Stores all constructor arguments in `self.hparams`.
        self.save_hyperparameters()

        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            bias=bias,
            batch_first=batch_first,
            dropout=dropout,
            bidirectional=bidirectional,
        )

        # A bidirectional LSTM concatenates both directions on the feature axis.
        fc_input_size = hidden_size * 2 if bidirectional else hidden_size
        self.linear = nn.Linear(fc_input_size, output_size)
        self.criterion = nn.CrossEntropyLoss()
        self.accuracy = Accuracy(task="multiclass", num_classes=output_size)
        self.mae = MeanAngularError()

    def forward(self, x):
        """Return unnormalised class scores (logits) of shape (batch, output_size)."""
        # lstm_out = (batch_size, seq_len, hidden_size * num_directions)
        lstm_out, _ = self.lstm(x)
        # NOTE(review): with bidirectional=True the reverse direction's output
        # at the last time step has only processed that one step; using the
        # final hidden states would give both directions the full sequence.
        last = lstm_out[:, -1] if self.hparams.batch_first else lstm_out[-1]
        return self.linear(last)

    def configure_optimizers(self):
        """Adam over all parameters with the configured learning rate."""
        return optim.Adam(self.parameters(), lr=self.hparams.lr)

    def training_step(self, batch, batch_idx):
        """Cross-entropy between the logits and the one-hot targets."""
        x, _, y_oh = batch
        y_hat = self(x)
        loss = self.criterion(y_hat, y_oh)
        self.log("train/loss", loss, on_step=True)

        return loss

    def validation_step(self, batch, batch_idx):
        """Log the validation loss and accumulate accuracy / angular error."""
        x, y, y_oh = batch
        y_hat = self(x)
        loss = self.criterion(y_hat, y_oh)
        self.log("val/loss", loss)

        # Multiclass accuracy expects class indices as target. Passing the
        # one-hot matrix (same shape as the logits) would compare logits and
        # 0/1 targets element-wise instead of comparing predicted classes.
        self.accuracy.update(y_hat, y_oh.argmax(dim=-1))

        # pred_to_angle forms a weighted mean of the bin directions, so it
        # needs non-negative weights that sum to one: convert the logits to
        # probabilities first (in float32 for stability under 16-bit AMP).
        probs = torch.softmax(y_hat.float(), dim=-1)
        azimuth, zenith = pred_to_angle(probs)
        preds = torch.stack([azimuth, zenith], axis=-1)
        self.mae.update(preds, y)

        return loss

    def on_validation_epoch_end(self) -> None:
        """Log the epoch-level metrics and clear their accumulated state."""
        acc = self.accuracy.compute()
        mae = self.mae.compute()
        self.log("val/acc", acc)
        self.log("val/mae", mae, prog_bar=True)

        # Plain values (not the Metric objects) are logged, so Lightning does
        # not reset the metrics automatically. Without an explicit reset every
        # epoch would report a running average over all previous epochs (and
        # the sanity-check batches), which also distorts early stopping.
        self.accuracy.reset()
        self.mae.reset()

    def test_step(self, batch, batch_idx):
        """Cross-entropy loss on a test batch.

        Accepts both `(x, target)` pairs and the `(x, y, y_onehot)` triples
        used for training/validation; the last element is the loss target.
        """
        x, *_, target = batch
        y_hat = self(x)
        loss = self.criterion(y_hat, target)
        self.log("test/loss", loss)

        return loss

    

class MeanAngularError(Metric):
    """
    Mean angular distance (in radian) between predicted and true directions.

    `update` expects `preds` and `target` as 2-D tensors whose column 0 is
    the azimuth and column 1 the zenith, both in radian. The metric keeps a
    running sum of per-sample angular distances and a sample count; `compute`
    returns their ratio.
    """

    # A smaller angular error is better.
    higher_is_better = False
    # `update` only adds to the states and never reads their previous
    # values, so `forward` may use the cheaper single-update path.
    full_state_update = False

    def __init__(self):
        super().__init__()
        # Both states are summed across processes in distributed runs.
        self.add_state("err", default=torch.tensor(0.0, dtype=torch.float32), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0, dtype=torch.long), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        """Accumulate the angular distances of one batch."""
        ae = angular_dist_score(target[:, 0], target[:, 1], preds[:, 0], preds[:, 1], avg=False)

        self.err += ae.sum()
        self.total += ae.shape[0]

    def compute(self):
        """Return the mean angular distance over all samples seen since the last reset."""
        return self.err / self.total


In [21]:
# Build the model from the "model" section of the config.
model = LSTMClassifier(**config["model"])

# Weights & Biases logging: create the logger, upload the full config dict
# and let W&B watch the model.
logger = WandbLogger(**config["wandb"])
logger.experiment.config.update(config)
logger.experiment.watch(model)

# Stop training once "val/mae" has not improved for 5 validation runs.
# NOTE(review): no `mode` is passed, so the callback's default comparison
# direction applies — presumably "min", which is what an error metric needs; confirm.
early_stopping = EarlyStopping(monitor="val/mae", patience=5)

trainer = Trainer(**config["trainer"], logger=logger, callbacks=[early_stopping])

[34m[1mwandb[0m: Currently logged in as: [33medenn0[0m. Use [1m`wandb login --relogin`[0m to force relogin


Using 16bit None Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [22]:
# Train on `trainloader` and validate on `validloader`; the run ends when
# `max_epochs` is reached or the early-stopping callback triggers.
trainer.fit(model, trainloader, validloader)

You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type               | Params
-------------------------------------------------
0 | lstm      | LSTM               | 216 K 
1 | linear    | Linear             | 82.2 K
2 | criterion | CrossEntropyLoss   | 0     
3 | accuracy  | MulticlassAccuracy | 0     
4 | mae       | MeanAngularError   | 0     
-------------------------------------------------
298 K     Trainable params
0         Non-trainable params
298 K     Total params
0.597     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=100` reached.
