# Fine-tuning GraphNeT

## Import

In [1]:
import sqlite3
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

from icecube.constants import *
from graphnet.utilities.logging import get_logger

logger = get_logger(log_folder=log_dir)


import numpy as np
import pandas as pd
from tqdm import tqdm
from lion_pytorch import Lion
from sklearn.model_selection import KFold
from torch.optim import Adam, SGD, AdamW, Adagrad
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor
from pytorch_lightning.loggers import WandbLogger

from graphnet.data.constants import FEATURES, TRUTH
from graphnet.models import StandardModel
from graphnet.models.detector.icecube import IceCubeKaggle
from graphnet.models.gnn import DynEdge
from graphnet.data.sqlite.sqlite_dataset import SQLiteDataset
from graphnet.models.graph_builders import KNNGraphBuilder
from graphnet.models.task.reconstruction import (
    DirectionReconstructionWithKappa,
    ZenithReconstructionWithKappa,
    AzimuthReconstructionWithKappa,
)
from graphnet.training.callbacks import ProgressBar, PiecewiseLinearLR
from graphnet.training.loss_functions import (
    VonMisesFisher3DLoss,
    VonMisesFisher2DLoss,
)
from graphnet.training.labels import Direction
from graphnet.training.utils import make_dataloader


[1;34mgraphnet[0m: [32mINFO    [0m 2023-03-05 23:32:11 - get_logger - Writing log to [1m/media/eden/sandisk/projects/icecube/logs/graphnet_20230305-233211.log[0m


## Constants

In [4]:
# Name of the SQLite table that holds the per-pulse rows; passed to count_pulses below.
PULSEMAP = "pulse_table"
# NOTE(review): `database_dir`, `input_dir` and `log_dir` are not defined in this notebook —
# presumably they come from `from icecube.constants import *`; confirm.
DATABASE_PATH = database_dir / "batch_51_100.db"
# DATABASE_PATH = "/media/eden/sandisk/projects/icecube/input/sqlite/batch_1.db"
# Events with more than this many pulses are given fold -1 by make_selection.
PULSE_THRESHOLD = 200
# random_state of the shuffled KFold split in make_selection.
SEED = 42

# Training configs
MAX_EPOCHS = 100  # forwarded to model.fit through config["fit"]["max_epochs"]
LR = 5e-3  # learning rate handed to the optimizer in build_model
BS = 256  # dataloader batch size
ES = 5  # EarlyStopping patience
NUM_FOLDS = 5  # number of folds iterated over in the CV training loop
NUM_WORKERS = 16  # dataloader worker processes

# Paths
FOLD_PATH = input_dir / "folds"
# Cached per-event pulse counts (written by count_pulses).
COUNT_PATH = FOLD_PATH / "batch51_100_counts.csv"
# Fold assignment written by make_selection and read back by make_dataloaders.
CV_PATH = FOLD_PATH / f"batch51_100_cv_max_{PULSE_THRESHOLD}_pulses.csv"
WANDB_DIR = log_dir
PROJECT_NAME = "icecube"
GROUP_NAME = "ft_batch_51_100"

# When True, the fold-creation cell (re)computes the fold assignment.
CREATE_FOLDS = True


## Create selection

In [3]:
def make_selection(
    df: pd.DataFrame, num_folds: int = 5, pulse_threshold: int = 200
) -> None:
    """Assign every event to a cross-validation fold and save the result.

    A ``fold`` column is added to ``df`` in place. Each row receives a fold id
    in ``[0, num_folds)`` from a shuffled ``KFold`` split seeded with ``SEED``.
    Rows whose ``n_pulses`` exceeds ``pulse_threshold`` are then marked with
    ``-1`` so that they end up in neither the training nor the validation
    selection. The resulting frame is written to ``CV_PATH``.

    Args:
        df: Per-event table containing an ``n_pulses`` column. Mutated in place.
        num_folds: Number of folds; with ``k`` folds each validation fold
            holds roughly ``1/k`` of the events.
        pulse_threshold: Largest ``n_pulses`` an event may have and still be
            used for training/validation.
    """
    df["fold"] = 0
    fold_col = df.columns.get_loc("fold")

    kf = KFold(n_splits=num_folds, shuffle=True, random_state=SEED)
    for i, (_, val_idx) in enumerate(kf.split(np.arange(len(df)))):
        # KFold yields row *positions*; index positionally so the assignment
        # is also correct when `df` does not carry a default RangeIndex.
        df.iloc[val_idx, fold_col] = i

    # Remove events with large pulses from training and validation sample (memory).
    # A single .loc assignment is used instead of chained indexing, which may
    # write to a temporary copy and leave `df` unchanged.
    df.loc[df["n_pulses"] > pulse_threshold, "fold"] = -1

    df.to_csv(CV_PATH)


def get_number_of_pulses(db: Path, event_id: int, pulsemap: str) -> int:
    """Count the pulses recorded for one event, capped at 20000.

    Args:
        db: Path to the SQLite database.
        event_id: Event whose rows in ``pulsemap`` are counted.
        pulsemap: Name of the pulse table.

    Returns:
        Number of rows in ``pulsemap`` with the given ``event_id``, or 20000
        if the event has more rows than that.
    """
    # Table names cannot be bound as SQL parameters, so `pulsemap` is
    # interpolated; the event id is bound instead of being formatted in.
    # Counting inside SQLite avoids materialising the rows in Python.
    query = (
        "select count(*) from "
        f"(select 1 from {pulsemap} where event_id = ? limit 20000)"
    )
    con = sqlite3.connect(str(db))
    try:
        (n_pulses,) = con.execute(query, (int(event_id),)).fetchone()
    finally:
        # `with sqlite3.connect(...)` only manages the transaction; the
        # connection has to be closed explicitly.
        con.close()
    return n_pulses


def count_pulses(database: Path, pulsemap: str) -> pd.DataFrame:
    """Count the pulses of every event listed in ``meta_table``.

    The counts are obtained with one aggregated query over ``pulsemap``
    instead of one connection and one query per event. Events that appear in
    ``meta_table`` but have no rows in ``pulsemap`` get a count of 0, and
    counts are capped at 20000 per event.

    Args:
        database: Path to the SQLite database.
        pulsemap: Name of the pulse table.

    Returns:
        DataFrame with columns ``event_id`` and ``n_pulses``, one row per row
        of ``meta_table``. The frame is also written to ``COUNT_PATH``.
    """
    con = sqlite3.connect(str(database))
    try:
        events = pd.read_sql("select event_id from meta_table", con)
        # Table names cannot be bound as SQL parameters, hence the f-string.
        pulse_counts = pd.read_sql(
            f"select event_id, count(*) as n_pulses from {pulsemap} "
            "group by event_id",
            con,
        )
    finally:
        con.close()

    # Left join keeps the order and multiplicity of meta_table.
    df = events.merge(pulse_counts, on="event_id", how="left")
    df["n_pulses"] = (
        df["n_pulses"].fillna(0).clip(upper=20000).astype("int64")
    )
    df.to_csv(COUNT_PATH)
    return df


In [5]:
if CREATE_FOLDS:
    # Reuse the cached per-event pulse counts when they exist; counting is expensive.
    if COUNT_PATH.exists():
        df = pd.read_csv(COUNT_PATH)
    else:
        df = count_pulses(DATABASE_PATH, PULSEMAP)
    # Pass NUM_FOLDS explicitly so the folds on disk always match the number
    # of folds the training loop iterates over.
    make_selection(
        df=df, num_folds=NUM_FOLDS, pulse_threshold=PULSE_THRESHOLD
    )


## Training

### Configurations

In [6]:
# Run configuration shared by the dataloader, model and training helpers below.
# The whole dict is also uploaded to Weights & Biases by train_dynedge.
config = {
    "path": str(DATABASE_PATH),
    # "path": "/media/eden/sandisk/projects/icecube/input/sqlite/batch_1.db",
    "pulsemap": "pulse_table",
    "truth_table": "meta_table",
    # Logged by train_dynedge at the start of every run.
    "features": FEATURES.KAGGLE,
    "truth": TRUTH.KAGGLE,
    # Column used to select events for the training/validation dataloaders.
    "index_column": "event_id",
    "batch_size": BS,
    "num_workers": NUM_WORKERS,
    # build_model only defines a task for "direction".
    "target": "direction",
    # NOTE(review): the tag says batch_1_50 while DATABASE_PATH and GROUP_NAME
    # refer to batch 51-100 — confirm this is intended.
    "run_name_tag": "batch_1_50",
    "early_stopping_patience": ES,
    # Unpacked as keyword arguments into model.fit by train_dynedge.
    "fit": {
        "max_epochs": MAX_EPOCHS,
        "gpus": [0],
        "distribution_strategy": None,
        # "limit_train_batches": 10,  # debug
        # "limit_val_batches": 10,
    },
    # NOTE(review): not read by any code visible in this notebook.
    "base_dir": "training",
    # Weights & Biases project and group names.
    "wandb": {
        "project": PROJECT_NAME,
        "group": GROUP_NAME,
    },
}


In [7]:
# Scratch check (candidate for removal): does clipping sanitise non-finite errors?
# NOTE(review): `import torch` belongs in the import cell at the top.
import torch

err = torch.tensor([torch.nan, torch.inf])

# The recorded output is tensor(nan): clipping does not remove the NaN,
# so the sum of the clipped tensor is still NaN.
torch.clip(err, 0, 2*torch.pi).sum()

tensor(nan)

In [7]:
def build_model(
    config: Dict[str, Any], train_dataloader: Any
) -> StandardModel:
    """Build the DynEdge direction-reconstruction model described by ``config``.

    Args:
        config: Run configuration. Reads ``config["target"]`` (only
            ``"direction"`` is supported) and ``config["fit"]["max_epochs"]``.
        train_dataloader: Training dataloader; only its length (number of
            batches per epoch) is used, to place the learning-rate milestones.

    Returns:
        A ``StandardModel`` with ``prediction_columns`` and
        ``additional_attributes`` set.

    Raises:
        ValueError: If ``config["target"]`` is not ``"direction"``.
    """
    target = config["target"]
    # Fail early with a clear message instead of hitting an unbound `task`
    # further down when an unsupported target is configured.
    if target != "direction":
        raise ValueError(
            f"Unsupported target {target!r}; only 'direction' is implemented."
        )

    # Building model
    detector = IceCubeKaggle(
        graph_builder=KNNGraphBuilder(nb_nearest_neighbours=8),
    )
    gnn = DynEdge(
        nb_inputs=detector.nb_outputs,
        global_pooling_schemes=["min", "max", "mean"],
    )

    task = DirectionReconstructionWithKappa(
        hidden_size=gnn.nb_outputs,
        target_labels=target,
        loss_function=VonMisesFisher3DLoss(),
    )
    prediction_columns = [
        target + "_x",
        target + "_y",
        target + "_z",
        target + "_kappa",
    ]
    additional_attributes = ["zenith", "azimuth", "event_id"]

    steps_per_epoch = len(train_dataloader)
    model = StandardModel(
        detector=detector,
        gnn=gnn,
        tasks=[task],
        optimizer_class=SGD,
        optimizer_kwargs={
            "lr": LR,
            "momentum": 0.9,
            "nesterov": True,
        },
        scheduler_class=PiecewiseLinearLR,
        # LR factor 1e-3 at step 0, 1 after half an epoch, 1e-3 again at the
        # last step of the last epoch (presumably interpolated linearly in
        # between — confirm against PiecewiseLinearLR).
        scheduler_kwargs={
            "milestones": [
                0,
                steps_per_epoch / 2,
                steps_per_epoch * config["fit"]["max_epochs"],
            ],
            "factors": [1e-03, 1, 1e-03],
        },
        scheduler_config={
            "interval": "step",
        },
    )
    model.prediction_columns = prediction_columns
    model.additional_attributes = additional_attributes

    return model


def load_pretrained_model(
    config: Dict[str, Any],
    state_dict_path: str = "/kaggle/input/dynedge-pretrained/dynedge_pretrained_batch_1_to_50/state_dict.pth",
    train_dataloader: Any = None,
    fold: int = 0,
) -> StandardModel:
    """Build the model from ``config`` and load pretrained weights into it.

    Args:
        config: Run configuration, forwarded to ``build_model``.
        state_dict_path: Location of the saved state dict, handed to
            ``model.load_state_dict``.
        train_dataloader: Optional training dataloader used by ``build_model``
            to size the learning-rate schedule. When omitted, the dataloaders
            for ``fold`` are constructed just for this purpose.
        fold: Fold whose training dataloader is built when
            ``train_dataloader`` is not given. Defaults to 0, which matches
            the previous behaviour.

    Returns:
        The model with the pretrained weights loaded.
    """
    if train_dataloader is None:
        # Only the training loader is needed; the validation loader is dropped.
        train_dataloader, _ = make_dataloaders(config=config, fold=fold)

    model = build_model(config=config, train_dataloader=train_dataloader)
    model.load_state_dict(state_dict_path)

    target = config["target"]
    model.prediction_columns = [
        f"{target}_{suffix}" for suffix in ("x", "y", "z", "kappa")
    ]
    model.additional_attributes = ["zenith", "azimuth", "event_id"]
    return model


def make_dataloaders(
    config: Dict[str, Any], fold: int = 0
) -> Tuple[Any, Any]:
    """Construct the training and validation dataloaders for one CV fold.

    The fold assignment is read from ``CV_PATH``. Events of ``fold`` form the
    validation selection; events of every other fold, except those marked
    ``-1``, form the training selection.

    Args:
        config: Run configuration. Reads ``path``, ``pulsemap``,
            ``truth_table``, ``index_column``, ``batch_size`` and
            ``num_workers``; ``features`` and ``truth`` are used when present
            and fall back to ``FEATURES.KAGGLE`` / ``TRUTH.KAGGLE``.
        fold: Fold id used for validation.

    Returns:
        ``(train_dataloader, validate_dataloader)``.

    Raises:
        ValueError: If no event in ``CV_PATH`` belongs to ``fold``.
    """
    df_cv = pd.read_csv(CV_PATH)
    index_column = config["index_column"]

    in_val_fold = df_cv["fold"] == fold
    excluded = df_cv["fold"] == -1

    # Series.tolist() already yields plain Python scalars; no ravel() needed.
    val_idx = df_cv.loc[in_val_fold, index_column].tolist()
    train_idx = df_cv.loc[~(in_val_fold | excluded), index_column].tolist()

    if not val_idx:
        raise ValueError(f"No events assigned to fold {fold} in {CV_PATH}.")

    def _loader(selection: List[Any], shuffle: bool) -> Any:
        # Both loaders share every setting except the selection and shuffling.
        return make_dataloader(
            db=config["path"],
            selection=selection,
            pulsemaps=config["pulsemap"],
            features=config.get("features", FEATURES.KAGGLE),
            truth=config.get("truth", TRUTH.KAGGLE),
            batch_size=config["batch_size"],
            num_workers=config["num_workers"],
            shuffle=shuffle,
            labels={"direction": Direction()},
            index_column=index_column,
            truth_table=config["truth_table"],
        )

    train_dataloader = _loader(train_idx, shuffle=True)
    validate_dataloader = _loader(val_idx, shuffle=False)

    return train_dataloader, validate_dataloader


def train_dynedge(
    config: Dict[str, Any], fold: int = 0, resume: Optional[Path] = None
) -> StandardModel:
    """Build (or resume) and train the GNN on one CV fold.

    Args:
        config: Run configuration. ``config["fit"]`` is unpacked into
            ``model.fit``; ``config["wandb"]`` supplies the W&B project and
            group when present (falling back to ``PROJECT_NAME`` /
            ``GROUP_NAME``).
        fold: Fold used for validation; the other folds are used for training.
        resume: Optional path to a state dict loaded into the freshly built
            model before training starts.

    Returns:
        The trained model.
    """
    logger.info(f"features: {config['features']}")
    logger.info(f"truth: {config['truth']}")

    run_name = (
        f"dynedge_{config['target']}_{config['run_name_tag']}_fold{fold}"
    )

    wandb_cfg = config.get("wandb", {})
    wandb_logger = WandbLogger(
        project=wandb_cfg.get("project", PROJECT_NAME),
        group=wandb_cfg.get("group", GROUP_NAME),
        name=run_name,
        save_dir=WANDB_DIR,
        log_model=True,
    )

    try:
        wandb_logger.experiment.config.update(config)

        train_dataloader, validate_dataloader = make_dataloaders(
            config=config, fold=fold
        )

        # Build the model against this fold's training loader so the LR
        # schedule matches the data actually trained on, then optionally load
        # the pretrained weights. This avoids constructing a second pair of
        # (fold-0) dataloaders just to resume.
        model = build_model(config, train_dataloader)
        if resume:
            model.load_state_dict(resume)

        # Training model
        callbacks = [
            EarlyStopping(
                monitor="val_mae",
                patience=config["early_stopping_patience"],
            ),
            LearningRateMonitor(logging_interval="step"),
            ProgressBar(),
        ]

        model.fit(
            train_dataloader,
            validate_dataloader,
            callbacks=callbacks,
            logger=wandb_logger,
            **config["fit"],
        )
    finally:
        # Close the W&B run; otherwise a WandbLogger created for the next fold
        # in the same process reuses the run that is still in progress.
        wandb_logger.experiment.finish()

    return model


def convert_to_3d(df: pd.DataFrame) -> pd.DataFrame:
    """Add the true unit direction vector derived from zenith and azimuth.

    The columns ``true_x``, ``true_y`` and ``true_z`` are written to ``df``
    in place using the spherical-to-Cartesian relations, and the same frame
    is returned.
    """
    azimuth = df["azimuth"]
    zenith = df["zenith"]

    components = {
        "true_x": np.cos(azimuth) * np.sin(zenith),
        "true_y": np.sin(azimuth) * np.sin(zenith),
        "true_z": np.cos(zenith),
    }
    for column, values in components.items():
        df[column] = values
    return df


def calculate_angular_error(df: pd.DataFrame) -> pd.DataFrame:
    """Calculate the opening angle between true and reconstructed directions.

    Reads the columns ``true_x/y/z`` and ``direction_x/y/z``, writes the
    angle (in radians) between the two vectors to ``angular_error`` in place
    and returns the same frame.

    The dot product is clipped to ``[-1, 1]`` before ``arccos`` so that
    floating-point round-off on (anti-)parallel vectors does not produce NaN.
    Rows whose inputs are NaN still yield NaN.
    """
    cos_angle = (
        df["true_x"] * df["direction_x"]
        + df["true_y"] * df["direction_y"]
        + df["true_z"] * df["direction_z"]
    )
    df["angular_error"] = np.arccos(cos_angle.clip(lower=-1.0, upper=1.0))
    return df


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f9d74332b00>
Traceback (most recent call last):
  File "/home/eden/anaconda3/envs/icecube/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1466, in __del__
    self._shutdown_workers()
  File "/home/eden/anaconda3/envs/icecube/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1397, in _shutdown_workers
    if not self._shutdown:
AttributeError: '_MultiProcessingDataLoaderIter' object has no attribute '_shutdown'


In [12]:
# Build an untrained model at notebook level; the fold-0 training loader
# (make_dataloaders' default fold) is only needed by build_model for its length.
train_dataloader, _ = make_dataloaders(config=config)
model = build_model(config=config, train_dataloader=train_dataloader)
# model._inference_trainer = Trainer(config['fit'])
# model.load_state_dict("../../models/batch1_50/state_dict.pth")

### CV

In [18]:
# NOTE(review): `import torch` belongs in the import cell at the top.
import torch

# Pull the weights out of a checkpoint file (relative path — depends on the
# notebook's working directory).
# NOTE(review): torch.load unpickles the file; only load checkpoints from a
# trusted source.
state_dict = torch.load("../../models/graphnet_50_100.ckpt")["state_dict"]

In [19]:
# Load the checkpoint weights into the notebook-level `model`; the recorded
# output reports that all keys matched.
model.load_state_dict(state_dict)

<All keys matched successfully>

In [7]:
# NOTE(review): `import torch` belongs in the import cell at the top.
import torch


torch.set_float32_matmul_precision("high")

# Fine-tune one model per CV fold, each starting from the same saved weights.
# NOTE(review): the resume path is an absolute, machine-specific location —
# consider moving it to the constants cell. The returned models are discarded.
for fold in range(NUM_FOLDS):
    train_dynedge(
        config=config,
        fold=fold,
        resume="/media/eden/sandisk/projects/icecube/models/batch1_50/state_dict.pth",
    )


[1;34mgraphnet[0m: [32mINFO    [0m 2023-02-19 13:44:53 - train_dynedge - features: ['x', 'y', 'z', 'time', 'charge', 'auxiliary'][0m
[1;34mgraphnet[0m: [32mINFO    [0m 2023-02-19 13:44:53 - train_dynedge - truth: ['zenith', 'azimuth'][0m


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33medenn0[0m. Use [1m`wandb login --relogin`[0m to force relogin


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Trainer already configured with model summary callbacks: [<class 'pytorch_lightning.callbacks.model_summary.ModelSummary'>]. Skipping setting a default `ModelSummary` callback.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type             | Params
-----------------------------------------------
0 | _detector | IceCubeKaggle    | 0     
1 | _gnn      | DynEdge          | 1.3 M 
2 | _tasks    | ModuleList       | 387   
3 | mae       | MeanAngularError | 0     
-----------------------------------------------
1.3 M     Trainable params
0         Non-trainable params
1.3 M     Total params
5.395     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]