In [11]:
import pickle
from typing import List

import torch
import webdataset as wbs
from torch.utils.data import DataLoader
from torch_geometric.data import Data, Batch
from graphnet.data.constants import FEATURES, TRUTH


[1;34mgraphnet[0m: [32mINFO    [0m 2023-03-08 22:30:39 - get_logger - Writing log to [1mlogs/graphnet_20230308-223039.log[0m


In [3]:
def get_features_truth(src):
    """Decode one WebDataset sample into a ``Data`` graph with truth labels.

    Args:
        src: Sample mapping. ``src["pickle"]`` must hold a pickled
            ``(features, truth)`` pair and ``src["__key__"]`` a string that
            parses as an integer event id.

    Returns:
        A ``Data`` object without edges. It carries the float32 feature
        matrix as ``x`` and, in addition, ``n_pulses``, ``features``,
        ``zenith``, ``azimuth``, one attribute per non-"x" feature column,
        ``direction`` and ``event_id``.
    """
    # NOTE(review): pickle.loads can execute arbitrary code while decoding;
    # only feed this function shards that come from a trusted source.
    pulse_features, labels = pickle.loads(src["pickle"])

    pulses = torch.as_tensor(pulse_features, dtype=torch.float32)
    pulse_count = torch.tensor(len(pulses), dtype=torch.int32)
    graph = Data(x=pulses, edge_index=None)
    graph.n_pulses = pulse_count
    # Assumed column order of the feature matrix -- TODO confirm against the
    # code that wrote the shards.
    graph.features = ['x', 'y', 'z', 'time', 'charge', 'auxiliary']

    # The first truth row supplies zenith (column 0) and azimuth (column 1).
    for column, target in enumerate(("zenith", "azimuth")):
        graph[target] = torch.tensor(labels[0][column], dtype=torch.float32)

    # Expose each feature column under its own name. "x" is skipped because
    # ``graph.x`` already holds the full feature matrix.
    for column, feature_name in enumerate(graph.features):
        if feature_name == "x":
            continue
        graph[feature_name] = graph.x[:, column].detach()

    # Direction vector built from the two angles; each component is reshaped
    # to a single column so the three can be concatenated side by side.
    sin_zenith = torch.sin(graph["zenith"]).reshape(-1, 1)
    dir_x = torch.cos(graph["azimuth"]) * sin_zenith
    dir_y = torch.sin(graph["azimuth"]) * sin_zenith
    dir_z = torch.cos(graph["zenith"]).reshape(-1, 1)
    graph["direction"] = torch.cat((dir_x, dir_y, dir_z), dim=1)

    graph["event_id"] = torch.tensor(int(src["__key__"]), dtype=torch.int32)

    return graph

In [24]:
import polars as pl

# Batch ids used to count events; presumably these are the batches packed
# into the shards loaded below -- verify against the shard-writing code.
batch_ids = [51]

df_meta = pl.read_parquet("../../raw/icecube-neutrinos-in-deep-ice/train_meta.parquet")

# Number of distinct event ids among the metadata rows of the selected batches.
n_events = (
    df_meta
    .filter(pl.col("batch_id").is_in(batch_ids))
    .get_column("event_id")
    .n_unique()
)

In [4]:
# Stream shards 0000-0001, decode every sample into a graph, and declare the
# dataset length as the event count computed from the metadata.
dataset = (
    wbs.WebDataset("./webdatasets/shards-{0000..0001}.tar")
    .map(get_features_truth)
    .with_length(n_events)
)

In [9]:
def collate_fn(graphs: List[Data]) -> Batch:
    """Collate graphs into one ``Batch``, skipping graphs with fewer than two pulses.

    A graph is kept only when its ``n_pulses`` is greater than 1.
    Such small graphs should not occur in "production".

    Args:
        graphs: Graphs to combine; each must expose ``n_pulses``.

    Returns:
        The ``Batch`` built from the graphs that were kept.
    """
    kept = []
    for graph in graphs:
        if graph.n_pulses > 1:
            kept.append(graph)
    return Batch.from_data_list(kept)

# Up to four graphs per batch; collate_fn may drop some of them, so a batch
# can end up smaller than batch_size.
dataloader = DataLoader(
    dataset=dataset,
    batch_size=4,
    collate_fn=collate_fn,
)

In [19]:
from graphnet.data.constants import FEATURES, TRUTH
from graphnet.models import StandardModel
from graphnet.models.detector.icecube import IceCubeKaggle
from graphnet.models.gnn import DynEdge
from graphnet.models.graph_builders import KNNGraphBuilder
from graphnet.models.task.reconstruction import \
    DirectionReconstructionWithKappa
from graphnet.training.callbacks import PiecewiseLinearLR, ProgressBar
from graphnet.training.labels import Direction
from graphnet.training.loss_functions import VonMisesFisher3DLoss
from graphnet.training.utils import make_dataloader
from typing import Dict, Any
from torch.optim import SGD


# Training configuration.
# In the cells visible here only "target" and "lr" (read by build_model) and
# "fit" (unpacked into model.fit) are consumed; the remaining keys are not
# read by any visible code -- the DataLoader above uses its own batch_size.
config = {
    # NOTE(review): absolute, machine-specific path; not read in the visible cells.
    "path": "/media/eden/sandisk/projects/icecube/input/sqlite/batch_51_100.db",
    "pulsemap": "pulse_table",
    "truth_table": "meta_table",
    "features": FEATURES.KAGGLE,
    "truth": TRUTH.KAGGLE,
    "index_column": "event_id",
    "batch_size": 64,
    "num_workers": 4,
    # build_model only builds a task when this equals "direction".
    "target": "direction",
    # NOTE(review): tag says batch_1_50 while "path" points at batch_51_100 -- confirm.
    "run_name_tag": "batch_1_50",
    "early_stopping_patience": 5,
    # Keyword arguments forwarded verbatim to model.fit(...).
    "fit": {
        "max_epochs": 100,
        "gpus": [0],
        "distribution_strategy": None,
        "limit_train_batches": 1.0,  # debug: lower to train on a fraction of the batches
        "limit_val_batches": 1.0,
        "precision": 16,
    },
    "base_dir": "training",
    # Learning rate handed to the SGD optimizer in build_model.
    "lr": 0.01,
}


def build_model(
    config: Dict[str, Any], train_dataloader: Any
) -> StandardModel:
    """Build the DynEdge-based ``StandardModel`` described by ``config``.

    Args:
        config: Settings dictionary. ``config["target"]`` selects the task
            (only ``"direction"`` is implemented) and ``config["lr"]`` is the
            SGD learning rate.
        train_dataloader: Not used by this function; kept in the signature so
            existing calls keep working.

    Returns:
        A ``StandardModel`` with one direction-reconstruction task. Two extra
        attributes are set on it: ``prediction_columns`` and
        ``additional_attributes``.

    Raises:
        ValueError: If ``config["target"]`` is anything other than
            ``"direction"``.
    """
    target = config["target"]
    if target != "direction":
        # Fail fast: no task is defined for any other target, so continuing
        # would reach the model assembly with `task` and the column lists
        # unbound and die with an opaque UnboundLocalError.
        raise ValueError(
            f"Unsupported target {target!r}; only 'direction' is implemented."
        )

    # Building model
    detector = IceCubeKaggle(
        graph_builder=KNNGraphBuilder(nb_nearest_neighbours=8),
    )
    gnn = DynEdge(
        nb_inputs=detector.nb_outputs,
        global_pooling_schemes=["min", "max", "mean"],
    )
    # NOTE(review): this overrides a private attribute after DynEdge has been
    # constructed -- confirm it actually affects layers that were already
    # built inside the constructor.
    gnn._activation = torch.nn.Mish()

    task = DirectionReconstructionWithKappa(
        hidden_size=gnn.nb_outputs,
        target_labels=target,
        loss_function=VonMisesFisher3DLoss(),
    )
    # Names for the four task outputs, followed by the graph attributes to
    # carry alongside predictions.
    prediction_columns = [
        target + "_x",
        target + "_y",
        target + "_z",
        target + "_kappa",
    ]
    additional_attributes = ["zenith", "azimuth", "event_id"]

    model = StandardModel(
        detector=detector,
        gnn=gnn,
        tasks=[task],
        optimizer_class=SGD,
        optimizer_kwargs={
            "lr": config["lr"],
            "momentum": 0.9,
            "nesterov": True,
            # "weight_decay": 1e-4,
        },
    )
    model.prediction_columns = prediction_columns
    model.additional_attributes = additional_attributes

    return model

In [20]:
# Instantiate the network from the configuration defined above.
model = build_model(config=config, train_dataloader=dataloader)

In [23]:
# Train on the WebDataset-backed loader; every entry of config["fit"] is
# forwarded as a keyword argument. No validation dataloader is passed here.
model.fit(dataloader, **config["fit"])

Using 16bit None Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
`Trainer(limit_train_batches=1.0)` was configured so 100% of the batches per epoch will be used..
`Trainer(limit_val_batches=1.0)` was configured so 100% of the batches will be used..
Using 16bit None Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
`Trainer(limit_train_batches=1.0)` was configured so 100% of the batches per epoch will be used..
`Trainer(limit_val_batches=1.0)` was configured so 100% of the batches will be used..
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type             | Params
-----------------------------------------------
0 | _detector | IceCubeKaggle    | 0     
1 | _gnn      | DynEdge          | 1.3 M 
2 | _tasks   

Training: 0it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


In [28]:
# Smoke test: push only the first batch through the model, then stop.
# NOTE(review): assumes `model` already lives on the GPU -- confirm. There is
# no torch.no_grad() here, so this forward pass tracks gradients.
for batch in dataloader:
    y = model(batch.cuda())
    break