In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
Train script (train + val) for PONE ParquetDataset with DynEdge.
- Adam optimizer
- PiecewiseLinearLR (paper schedule: 1e-5 -> 1e-3 -> 1e-5)
- EarlyStopping(patience=5) saving best_model.pth + config.yml
"""

import os
import sys
import argparse
from typing import Tuple

import torch
from torch import Tensor

from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.loggers import CSVLogger

# GraphNeT
from graphnet.data.dataset.parquet.parquet_dataset import ParquetDataset
from graphnet.data.dataloader import DataLoader

from graphnet.models import StandardModel
from graphnet.models.gnn import DynEdge
from graphnet.models.graphs import KNNGraph
from graphnet.models.graphs.nodes import NodesAsPulses

from graphnet.models.task import StandardLearnedTask
from graphnet.training.loss_functions import LogCoshLoss
from graphnet.training.callbacks import ProgressBar, GraphnetEarlyStopping, PiecewiseLinearLR

from graphnet.utilities.maths import eps_like


# ----------------------------
# Task + transforms (SS'deki gibi)
# ----------------------------
def logarithm(E: torch.Tensor) -> torch.Tensor:
    """Map ``E`` to ``log10(E)``.

    Every element below ``eps_like(E)`` is first raised to that floor, so
    zeros and negative values never reach ``log10``.

    Args:
        E: Tensor of values to transform.

    Returns:
        Tensor holding the base-10 logarithm of the clamped input.
    """
    floor = eps_like(E)
    clamped = E.clamp(min=floor)
    return clamped.log10()


def exponential(t: torch.Tensor) -> torch.Tensor:
    """Map ``t`` to ``10 ** t`` element-wise (inverse of a log10 transform).

    Args:
        t: Tensor of base-10 exponents.

    Returns:
        Tensor with 10 raised to each element of ``t``.
    """
    base = 10.0
    return torch.pow(base, t)


class DepositedEnergyLog10Task(StandardLearnedTask):
    """Single-output learned task with ``energy`` as its default target."""

    # Default label names used when none are passed to the constructor.
    default_target_labels = ["energy"]
    default_prediction_labels = ["log10_energy_pred"]
    # The task produces exactly one value per event.
    nb_inputs = 1

    def _forward(self, x: Tensor) -> Tensor:
        """Return the first column of ``x``, keeping the column dimension."""
        # Per the original author's note, StandardLearnedTask already maps the
        # latent vector down to nb_inputs, so handing back the single output
        # column is all that is needed here.
        first_column = slice(None, 1)
        return x[:, first_column]


# ----------------------------
# Build graph_definition + datasets + loaders
# ----------------------------
def build_datasets_and_loaders(
    train_path: str,
    val_path: str,
    batch_size: int,
    num_workers: int,
) -> Tuple[object, DataLoader, DataLoader]:
    """Build the graph definition and the train/validation data loaders.

    Args:
        train_path: Path passed to ``ParquetDataset`` for the training set.
        val_path: Path passed to ``ParquetDataset`` for the validation set.
        batch_size: Batch size used by both loaders.
        num_workers: Number of loader worker processes. ``0`` loads data in
            the main process; the multiprocessing-only loader options are
            then left out so loader construction does not fail.

    Returns:
        ``(graph_definition, train_loader, val_loader)``. The training loader
        shuffles, the validation loader does not.
    """
    # Make the directory two levels above this file importable so the custom
    # detector module can be found when this file is run as a script.
    repo_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
    if repo_root not in sys.path:
        sys.path.append(repo_root)

    from MyClasses.detector_pone import PONE  # noqa: E402

    FEATURES = ["dom_x", "dom_y", "dom_z", "dom_time", "charge"]  # (x,y,z,t,q)

    detector = PONE(
        replace_with_identity=["charge", "dom_time", "dom_x", "dom_y", "dom_z"]
    )

    graph_definition = KNNGraph(
        detector=detector,
        node_definition=NodesAsPulses(),
        nb_nearest_neighbours=8,
        distance_as_edge_feature=True,
    )

    def make_dataset(path: str) -> ParquetDataset:
        """Create a dataset for ``path`` with the shared table/column setup."""
        return ParquetDataset(
            path=path,
            pulsemaps="features",
            truth_table="truth",
            features=FEATURES,
            truth=["energy"],  # fresh list per dataset, as in the original
            data_representation=graph_definition,
        )

    # Separate datasets for training and validation.
    train_ds = make_dataset(train_path)
    val_ds = make_dataset(val_path)

    if num_workers > 0:
        worker_kwargs = dict(
            multiprocessing_context="spawn",
            persistent_workers=True,
        )
    else:
        # torch's DataLoader raises ValueError when multiprocessing_context or
        # persistent_workers=True is combined with num_workers == 0.
        # NOTE(review): prefetch_factor=None is passed because torch >= 2.0
        # also rejects a non-None prefetch_factor without workers, and the
        # GraphNeT DataLoader wrapper is believed to default it to 2 — confirm
        # against the installed versions.
        worker_kwargs = dict(
            persistent_workers=False,
            prefetch_factor=None,
        )

    def make_loader(dataset: ParquetDataset, shuffle: bool) -> DataLoader:
        """Wrap ``dataset`` in a GraphNeT DataLoader with the shared options."""
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            **worker_kwargs,
        )

    train_loader = make_loader(train_ds, shuffle=True)
    val_loader = make_loader(val_ds, shuffle=False)

    return graph_definition, train_loader, val_loader


# ----------------------------
# Main
# ----------------------------
def _piecewise_lr_kwargs(
    steps_per_epoch: int,
    epochs: int,
    lr_min: float,
    lr_max: float,
) -> dict:
    """Return ``milestones``/``factors`` for a warm-up + linear-decay schedule.

    The factors multiply a base learning rate of ``lr_max``: the rate starts
    at ``lr_min``, reaches ``lr_max`` after half of the first epoch and goes
    back to ``lr_min`` at the last step.

    Milestones are forced to be strictly increasing. Without that, fewer than
    two steps per epoch yields a zero-length warm-up (duplicate milestone) and
    zero total steps yields a decreasing list.

    Raises:
        ValueError: If ``lr_max`` is not positive.
    """
    if lr_max <= 0:
        raise ValueError("lr_max must be positive")

    warmup_steps = max(1, int(0.5 * steps_per_epoch))  # 50% of the first epoch
    total_steps = max(warmup_steps + 1, steps_per_epoch * epochs)
    factor_min = lr_min / lr_max  # e.g. 1e-5 / 1e-3 = 0.01

    return dict(
        milestones=[0, warmup_steps, total_steps],
        factors=[factor_min, 1.0, factor_min],
    )


def main() -> None:
    """Parse CLI arguments, train DynEdge on the energy task and save outputs.

    Builds the datasets/loaders, the DynEdge backbone, the energy task and a
    ``StandardModel`` with Adam + ``PiecewiseLinearLR``, fits it with early
    stopping and a CSV logger, then writes the final config, state dict and
    model under ``--outdir``.

    Raises:
        RuntimeError: If the training loader does not support ``len()``.
        SystemExit: Via ``argparse`` when arguments are missing or invalid.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--train", required=True, help="Train merged folder path")
    parser.add_argument("--val", required=True, help="Val merged folder path")
    parser.add_argument("--outdir", required=True, help="Output directory for logs/models")

    parser.add_argument("--epochs", type=int, default=30)
    parser.add_argument("--batch-size", type=int, default=4)
    parser.add_argument("--num-workers", type=int, default=4)
    parser.add_argument("--seed", type=int, default=42)

    # LR schedule params (paper)
    parser.add_argument("--lr-max", type=float, default=1e-3)   # peak
    parser.add_argument("--lr-min", type=float, default=1e-5)   # start/end

    args = parser.parse_args()

    # Fail early with a usage message instead of a ZeroDivisionError or a
    # malformed schedule further down.
    if args.epochs < 1:
        parser.error("--epochs must be >= 1")
    if args.lr_max <= 0:
        parser.error("--lr-max must be > 0")
    if args.lr_min < 0:
        parser.error("--lr-min must be >= 0")

    os.makedirs(args.outdir, exist_ok=True)
    seed_everything(args.seed, workers=True)

    graph_definition, train_loader, val_loader = build_datasets_and_loaders(
        train_path=args.train,
        val_path=args.val,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
    )

    # ----------------------------
    # Backbone (same settings as the reference)
    # ----------------------------
    backbone = DynEdge(
        nb_inputs=5,                 # FEATURES length
        nb_neighbours=8,
        global_pooling_schemes=["min", "max", "mean", "sum"],
        add_global_variables_after_pooling=True,
        add_norm_layer=False,
        skip_readout=False,
    )

    # ----------------------------
    # Task (same settings as the reference)
    # ----------------------------
    interval = (1e1, 1e8)
    task = DepositedEnergyLog10Task(
        hidden_size=backbone.nb_outputs,
        loss_function=LogCoshLoss(),
        target_labels=["energy"],
        prediction_labels=["log_energy"],
        transform_prediction_and_target=None,
        transform_target=logarithm,
        transform_inference=exponential,
        transform_support=interval,
        loss_weight=None,
    )

    # ----------------------------
    # LR schedule milestones (step-based)
    # ----------------------------
    try:
        steps_per_epoch = len(train_loader)
    except TypeError as e:
        raise RuntimeError(
            "len(train_loader) bulunamadı. "
            "Bu durumda steps_per_epoch'i elle hesaplayıp koda koyman gerekir."
        ) from e

    # PiecewiseLinearLR: the base LR is lr_max and the factors scale it down
    # to lr_min and back, e.g. with the defaults
    # lr(0)=1e-3*0.01=1e-5, lr(warmup)=1e-3, lr(end)=1e-5.
    scheduler_kwargs = _piecewise_lr_kwargs(
        steps_per_epoch=steps_per_epoch,
        epochs=args.epochs,
        lr_min=args.lr_min,
        lr_max=args.lr_max,
    )

    # ----------------------------
    # StandardModel (GraphNeT)
    # ----------------------------
    model = StandardModel(
        graph_definition=graph_definition,
        backbone=backbone,
        tasks=[task],
        optimizer_class=torch.optim.Adam,
        optimizer_kwargs={"lr": args.lr_max, "eps": 1e-3},
        scheduler_class=PiecewiseLinearLR,
        scheduler_kwargs=scheduler_kwargs,
        scheduler_config={"interval": "step"},
    )

    # ----------------------------
    # Callbacks: progress bar + early stopping (saves/loads the best model)
    # ----------------------------
    early_stopping = GraphnetEarlyStopping(
        save_dir=args.outdir,
        monitor="val_loss",
        mode="min",
        patience=5,
        verbose=True,
    )

    callbacks = [ProgressBar(), early_stopping]

    # CSV logger: writes metrics.csv
    csv_logger = CSVLogger(save_dir=args.outdir, name="pl_logs")

    # Log roughly once per epoch to avoid per-batch spam.
    log_every_n_steps = max(1, steps_per_epoch)

    trainer = Trainer(
        max_epochs=args.epochs,
        logger=csv_logger,
        callbacks=callbacks,
        log_every_n_steps=log_every_n_steps,
        enable_checkpointing=False,  # the early-stopping callback already saves the best model
        accelerator="gpu" if torch.cuda.is_available() else "cpu",
        devices=1,
    )

    trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)

    # Training is done. Per the original author's note, the early-stopping
    # callback has reloaded "best_model.pth" in on_fit_end at this point.
    # Additionally save the final outputs:
    model.save_config(os.path.join(args.outdir, "final_config.yml"))
    model.save_state_dict(os.path.join(args.outdir, "final_state_dict.pth"))
    model.save(os.path.join(args.outdir, "final_model.pth"))

    print(f"\nDONE. Outputs saved under: {args.outdir}\n")


# Run the training entry point only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()


In [None]:
# Example invocation: train/val dataset folders, output directory, batch size,
# loader worker count and number of epochs are passed on the command line.
python train_dynedge_energy.py \
  --train /project/def-nahee/kbas/POM_Response_Parquet/train_merged \
  --val   /project/def-nahee/kbas/POM_Response_Parquet/val_merged \
  --outdir /project/def-nahee/kbas/runs/dynedge_energy_run01 \
  --batch-size 4 \
  --num-workers 4 \
  --epochs 30
