In [1]:
import sys
import sqlite3
from typing import List, Optional, Dict, Any
from pathlib import Path
import os

# Project root; overridable via ICECUBE_PROJECT_ROOT so the notebook is not
# tied to one machine. Added to sys.path so the local `src` package imports.
prj_path = Path(
    os.environ.get("ICECUBE_PROJECT_ROOT", "/media/eden/sandisk/projects/icecube/")
)
# Guard the append: re-running this cell must not keep growing sys.path.
if str(prj_path) not in sys.path:
    sys.path.append(str(prj_path))

# NOTE(review): star import — this notebook relies on at least `log_dir`,
# `database_dir` and `input_dir` from here; consider importing them explicitly.
from src.constants import *
from graphnet.utilities.logging import get_logger

logger = get_logger(log_folder=log_dir)


import numpy as np
import pandas as pd
from tqdm import tqdm
from lion_pytorch import Lion
from sklearn.model_selection import KFold
from torch.optim.adam import Adam
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.loggers import WandbLogger

from graphnet.data.constants import FEATURES, TRUTH
from graphnet.models import StandardModel
from graphnet.models.detector.icecube import IceCubeKaggle
from graphnet.models.gnn import DynEdge
from graphnet.models.graph_builders import KNNGraphBuilder
from graphnet.models.task.reconstruction import (
    DirectionReconstructionWithKappa,
    ZenithReconstructionWithKappa,
    AzimuthReconstructionWithKappa,
)
from graphnet.training.callbacks import ProgressBar, PiecewiseLinearLR
from graphnet.training.loss_functions import VonMisesFisher3DLoss, VonMisesFisher2DLoss
from graphnet.training.labels import Direction
from graphnet.training.utils import make_dataloader


[1;34mgraphnet[0m: [32mINFO    [0m 2023-02-19 02:42:43 - get_logger - Writing log to [1m/media/eden/sandisk/projects/icecube/logs/graphnet_20230219-024243.log[0m


In [2]:
PULSEMAP = "pulse_table"
DATABASE_PATH = database_dir / "batch_51_100.db"
# In this notebook it only selects the matching CV split file (CV_PATH below).
PULSE_THRESHOLD = 400
SEED = 42

# Training configs
MAX_EPOCHS = 100
LR = 1e-3
BS = 512  # batch size
ES = 5  # early-stopping patience, passed as config["early_stopping_patience"]
NUM_FOLDS = 5
NUM_WORKERS = 16

# Paths
FOLD_PATH = input_dir / "folds"
COUNT_PATH = FOLD_PATH / "batch51_100_counts.csv"
CV_PATH = FOLD_PATH / f"batch51_100_cv_max_{PULSE_THRESHOLD}_pulses.csv"
WANDB_DIR = log_dir / "wandb"
PROJECT_NAME = "icecube"
GROUP_NAME = "ft_batch_51_100"

CREATE_FOLDS = False

# SEED was previously defined but never applied. Seed the Python, NumPy and
# torch RNGs (torch's global generator drives the shuffled training
# dataloader) so runs are reproducible.
from pytorch_lightning import seed_everything

seed_everything(SEED)


In [3]:
# Single source of truth for data/training settings consumed by the helpers below.
config = {
    "path": str(DATABASE_PATH),
    "pulsemap": PULSEMAP,
    "truth_table": "meta_table",
    "features": FEATURES.KAGGLE,
    "truth": TRUTH.KAGGLE,
    "index_column": "event_id",
    "batch_size": BS,
    "num_workers": NUM_WORKERS,
    "target": "direction",
    # NOTE(review): tag says batch_1_50 but DATABASE_PATH points at
    # batch_51_100 — confirm this is the intended run name.
    "run_name_tag": "batch_1_50",
    "early_stopping_patience": ES,
    "fit": {
        "max_epochs": MAX_EPOCHS,
        "gpus": [0],
        "distribution_strategy": None,
    },
    "base_dir": "training",
    "wandb": {
        "project": PROJECT_NAME,
        "group": GROUP_NAME,
    },
}


def build_model(config: Dict[str, Any], train_dataloader: Any) -> StandardModel:
    """Build a DynEdge ``StandardModel`` from ``config``.

    Args:
        config: Training configuration. Reads ``config["target"]`` (only
            ``"direction"`` is supported) and ``config["fit"]["max_epochs"]``.
        train_dataloader: Training dataloader. Only its ``len()`` is used, to
            place the learning-rate schedule milestones in optimizer steps.

    Returns:
        The assembled model, with ``prediction_columns`` and
        ``additional_attributes`` set.

    Raises:
        ValueError: If ``config["target"]`` is not ``"direction"``.
    """
    target = config["target"]
    # Fail fast: any other target used to fall through and crash later with an
    # UnboundLocalError on `task` / `prediction_columns`.
    if target != "direction":
        raise ValueError(
            f"Unsupported target {target!r}; only 'direction' is implemented."
        )

    detector = IceCubeKaggle(
        graph_builder=KNNGraphBuilder(nb_nearest_neighbours=8),
    )
    gnn = DynEdge(
        nb_inputs=detector.nb_outputs,
        global_pooling_schemes=["min", "max", "mean"],
    )
    task = DirectionReconstructionWithKappa(
        hidden_size=gnn.nb_outputs,
        target_labels=target,
        loss_function=VonMisesFisher3DLoss(),
    )

    steps_per_epoch = len(train_dataloader)
    model = StandardModel(
        detector=detector,
        gnn=gnn,
        tasks=[task],
        optimizer_class=Adam,
        optimizer_kwargs={"lr": LR, "eps": 1e-03},
        # Per-step schedule: LR factor 1e-2 at step 0, 1 after half an epoch,
        # back to 1e-2 at the final step (piecewise-linear in between).
        scheduler_class=PiecewiseLinearLR,
        scheduler_kwargs={
            "milestones": [
                0,
                steps_per_epoch / 2,
                steps_per_epoch * config["fit"]["max_epochs"],
            ],
            "factors": [1e-02, 1, 1e-02],
        },
        scheduler_config={
            "interval": "step",
        },
    )
    model.prediction_columns = [
        f"{target}_{component}" for component in ("x", "y", "z", "kappa")
    ]
    model.additional_attributes = ["zenith", "azimuth", "event_id"]

    return model


def load_pretrained_model(
    config: Dict[str, Any],
    state_dict_path: str = "/kaggle/input/dynedge-pretrained/dynedge_pretrained_batch_1_to_50/state_dict.pth",
) -> StandardModel:
    """Build the model described by ``config`` and load pretrained weights.

    The training dataloader is constructed only because ``build_model`` needs
    its length to place the LR-schedule milestones.

    Args:
        config: Training configuration (see ``build_model`` / ``make_dataloaders``).
        state_dict_path: Path (str or ``pathlib.Path``) to a saved state dict.
            The default points at a Kaggle input dataset.

    Returns:
        The model with the checkpoint's weights loaded. ``prediction_columns``
        and ``additional_attributes`` are already set by ``build_model``.
    """
    import torch  # local: the notebook's top import cell does not import torch

    train_dataloader, _ = make_dataloaders(config=config)
    model = build_model(config=config, train_dataloader=train_dataloader)
    # Map to CPU so GPU-saved checkpoints also load on CPU-only hosts; values
    # are copied into the model's existing parameters either way.
    # torch.load unpickles the file — only use it on trusted checkpoints.
    state_dict = torch.load(state_dict_path, map_location="cpu")
    model.load_state_dict(state_dict)
    return model


def _make_split_dataloader(
    config: Dict[str, Any], selection: List[Any], shuffle: bool
) -> Any:
    """Build one graphnet dataloader restricted to the events in ``selection``."""
    return make_dataloader(
        db=config["path"],
        selection=selection,
        pulsemaps=config["pulsemap"],
        features=config.get("features", FEATURES.KAGGLE),
        truth=config.get("truth", TRUTH.KAGGLE),
        batch_size=config["batch_size"],
        num_workers=config["num_workers"],
        shuffle=shuffle,
        labels={"direction": Direction()},
        index_column=config["index_column"],
        truth_table=config["truth_table"],
    )


def make_dataloaders(config: Dict[str, Any], fold: int = 0) -> List[Any]:
    """Construct training and validation dataloaders for one CV fold.

    Reads the fold assignment from ``CV_PATH``. Events whose ``fold`` equals
    ``fold`` form the validation set; events with ``fold == -1`` are excluded
    from both splits (presumably events dropped from CV — TODO confirm); all
    others form the training set.

    Args:
        config: Training configuration (``path``, ``pulsemap``, ``index_column``,
            ``batch_size``, ``num_workers``, ``truth_table``; optionally
            ``features`` / ``truth``).
        fold: Fold id to hold out for validation.

    Returns:
        A 2-tuple ``(train_dataloader, validate_dataloader)``; the training
        loader shuffles, the validation loader does not.

    Raises:
        ValueError: If no events are assigned to ``fold``.
    """
    df_cv = pd.read_csv(CV_PATH)
    index_column = config["index_column"]
    fold_ids = df_cv["fold"]

    # Series.tolist() already yields Python scalars; the former .ravel() step
    # was redundant (and Series.ravel is deprecated in pandas 2.2).
    val_idx = df_cv.loc[fold_ids == fold, index_column].tolist()
    train_idx = df_cv.loc[~fold_ids.isin([-1, fold]), index_column].tolist()
    if not val_idx:
        raise ValueError(f"No events assigned to fold {fold} in {CV_PATH}")

    train_dataloader = _make_split_dataloader(config, train_idx, shuffle=True)
    validate_dataloader = _make_split_dataloader(config, val_idx, shuffle=False)

    return train_dataloader, validate_dataloader



In [4]:
# Weights from the batch 1-50 pretraining run, located under the project root.
model = load_pretrained_model(
    config=config,
    state_dict_path=str(prj_path / "models" / "batch1_50" / "state_dict.pth"),
)



In [5]:
import torch

# Load the raw checkpoint onto the CPU so it can be inspected without a GPU.
# torch.load unpickles the file — only use it on trusted checkpoints.
state_dict = torch.load(
    prj_path / "models" / "batch1_50" / "state_dict.pth", map_location="cpu"
)

In [7]:
# List the parameter names stored in the checkpoint.
state_dict.keys()

odict_keys(['_gnn._conv_layers.0.nn.0.weight', '_gnn._conv_layers.0.nn.0.bias', '_gnn._conv_layers.0.nn.2.weight', '_gnn._conv_layers.0.nn.2.bias', '_gnn._conv_layers.1.nn.0.weight', '_gnn._conv_layers.1.nn.0.bias', '_gnn._conv_layers.1.nn.2.weight', '_gnn._conv_layers.1.nn.2.bias', '_gnn._conv_layers.2.nn.0.weight', '_gnn._conv_layers.2.nn.0.bias', '_gnn._conv_layers.2.nn.2.weight', '_gnn._conv_layers.2.nn.2.bias', '_gnn._conv_layers.3.nn.0.weight', '_gnn._conv_layers.3.nn.0.bias', '_gnn._conv_layers.3.nn.2.weight', '_gnn._conv_layers.3.nn.2.bias', '_gnn._post_processing.0.weight', '_gnn._post_processing.0.bias', '_gnn._post_processing.2.weight', '_gnn._post_processing.2.bias', '_gnn._readout.0.weight', '_gnn._readout.0.bias', '_tasks.0._affine.weight', '_tasks.0._affine.bias'])

In [5]:
train_loader, _ = make_dataloaders(config=config)

# Push a single training batch through the model to get outputs that carry an
# autograd graph (visualised in the next cell).
batch = next(iter(train_loader), None)
if batch is None:
    raise RuntimeError("Training dataloader is empty; cannot draw a sample batch.")
y = model(batch)



In [6]:
from torchviz import make_dot

# Reduce the first task output to a scalar and render its autograd graph,
# labelling nodes with the model's parameter names.
scalar_output = y[0].mean()
param_lookup = dict(model.named_parameters())
make_dot(scalar_output, params=param_lookup).render(format="png")

'Digraph.gv.png'

In [None]:
# Module repr of the StandardModel (presumably copied from `print(model)`),
# kept here as an architecture reference. This cell has no other effect.
"""
StandardModel(
  (_detector): IceCubeKaggle(
    (_graph_builder): KNNGraphBuilder()
  )
  (_gnn): DynEdge(
    (_activation): LeakyReLU(negative_slope=0.01)
    (_conv_layers): ModuleList(
      (0): DynEdgeConv(nn=Sequential(
        (0): Linear(in_features=34, out_features=128, bias=True)
        (1): LeakyReLU(negative_slope=0.01)
        (2): Linear(in_features=128, out_features=256, bias=True)
        (3): LeakyReLU(negative_slope=0.01)
      ))
      (1): DynEdgeConv(nn=Sequential(
        (0): Linear(in_features=512, out_features=336, bias=True)
        (1): LeakyReLU(negative_slope=0.01)
        (2): Linear(in_features=336, out_features=256, bias=True)
        (3): LeakyReLU(negative_slope=0.01)
      ))
      (2): DynEdgeConv(nn=Sequential(
        (0): Linear(in_features=512, out_features=336, bias=True)
        (1): LeakyReLU(negative_slope=0.01)
        (2): Linear(in_features=336, out_features=256, bias=True)
        (3): LeakyReLU(negative_slope=0.01)
      ))
      (3): DynEdgeConv(nn=Sequential(
        (0): Linear(in_features=512, out_features=336, bias=True)
        (1): LeakyReLU(negative_slope=0.01)
        (2): Linear(in_features=336, out_features=256, bias=True)
        (3): LeakyReLU(negative_slope=0.01)
      ))
    )
    (_post_processing): Sequential(
      (0): Linear(in_features=1041, out_features=336, bias=True)
      (1): LeakyReLU(negative_slope=0.01)
      (2): Linear(in_features=336, out_features=256, bias=True)
      (3): LeakyReLU(negative_slope=0.01)
    )
    (_readout): Sequential(
      (0): Linear(in_features=768, out_features=128, bias=True)
      (1): LeakyReLU(negative_slope=0.01)
    )
  )
  (_tasks): ModuleList(
    (0): DirectionReconstructionWithKappa(
      (_loss_function): VonMisesFisher3DLoss()
      (_affine): Linear(in_features=128, out_features=3, bias=True)
    )
  )
)
"""