In [None]:
from itertools import product
import os
from pathlib import Path
import sys
import time
from typing import Dict, List, Tuple

import h5py
import lightning
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.loggers import CSVLogger
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from torchmetrics import MetricCollection
from torchmetrics.classification import BinaryAccuracy, BinaryAUROC
import yaml

try:
    project_root = Path(__file__).parent.parent
except NameError:
    # A Jupyter kernel defines no __file__; fall back to the parent of the
    # current working directory as the project root.
    project_root = Path.cwd().parent
# Register the project root only once, so re-running this cell does not pile up
# duplicate entries in sys.path.
if project_root.as_posix() not in sys.path:
    sys.path.append(project_root.as_posix())

from src.model_cnn import CNN_EventCNN
from src.model_part import ParT_Light
from src import utils

In [None]:
def wrap_pi(phi: np.ndarray | torch.Tensor) -> np.ndarray | torch.Tensor:
    """Wrap angles to [-pi, pi)."""
    return (phi + np.pi) % (2 * np.pi) - np.pi

In [None]:
class MCSimData:
    """Monte Carlo sample read from an HDF5 file and preprocessed in (pt, eta, phi).

    After construction, ``particle_flow`` holds an array of shape (N, ΣM, 3)
    with the features ``[pt, eta, phi]`` of all channels concatenated along the
    second axis; masked-out entries are NaN. ``jet_flavor`` holds boolean event
    masks for the quark/gluon composition of the two jets.
    """

    def __init__(self, path: str | Path, luminosity: float = 1.0):
        """Load and preprocess one dataset.

        Args:
            path: Location of the HDF5 file. The path must contain ``'diphoton'``
                or ``'zz4l'``; this selects the decay channel.
            luminosity: Stored on the instance; not used elsewhere in this class.

        Raises:
            ValueError: If the path names neither supported dataset.
        """
        self.path = str(path)
        self.luminosity = luminosity

        # -------- Channels --------
        self.detector_channels = ['TOWER', 'TRACK']
        # Column span of every channel inside the concatenated ΣM axis.
        # NOTE(review): the spans are hard-coded (250 TOWER, 150 TRACK entries)
        # rather than read from the file -- confirm they match the HDF5 layout.
        self.slices = [slice(0, 250), slice(250, 400)]
        if 'diphoton' in self.path:
            self.decay_channel = 'PHOTON'
            self.slices.append(slice(400, 402))
        elif 'zz4l' in self.path:
            self.decay_channel = 'LEPTON'
            self.slices.append(slice(400, 404))
        else:
            raise ValueError(f"Unsupported dataset: {self.path}. Supported datasets are 'diphoton' and 'zz4l'.")
        self.channels = self.detector_channels + ([self.decay_channel] if self.decay_channel else [])

        with h5py.File(str(path), 'r') as hdf5_file:
            # -------- Jet flavor --------
            self.jet_flavor = self._extract_jet_flavor(hdf5_file)

            # -------- Particle flow information (pt, eta, phi) --------
            particle_flow = self._extract_particle_flow(hdf5_file, self.channels)

            # -------- Preprocessing --------
            particle_flow = self._preprocess_phi_transformation(particle_flow)
            particle_flow = self._preprocess_center_of_phi(particle_flow)
            particle_flow = self._preprocess_flipping(particle_flow)

        self.particle_flow = particle_flow

    def _extract_jet_flavor(self, hdf5_file: h5py.File) -> Dict[str, np.ndarray | int]:
        """Build gluon/quark composition masks from the J1/J2 flavors.

        Returns:
            Dict with boolean event masks ``'2q0g'``, ``'1q1g'``, ``'0q2g'`` and
            the integer event count ``'total'``.
        """

        J1 = np.asarray(hdf5_file["J1"]["flavor"][:])
        J2 = np.asarray(hdf5_file["J2"]["flavor"][:])
        g1, g2 = (J1 == 21), (J2 == 21)  # 21 == gluon

        mask_2q0g = (~g1) & (~g2)
        mask_1q1g = ((~g1) & g2) | (g1 & (~g2))
        mask_0q2g = g1 & g2

        return {
            "2q0g": mask_2q0g,
            "1q1g": mask_1q1g,
            "0q2g": mask_0q2g,
            "total": len(J1),
        }

    def _extract_particle_flow(self, hdf5_file: h5py.File, channels: List[str]) -> np.ndarray:
        """Load pt/eta/phi of every channel and blank out masked entries.

        Returns:
            Array of shape (N, ΣM, 3) with features ``[pt, eta, phi]``; entries
            whose mask is False are NaN in all three features. A channel that
            has no ``'mask'`` dataset is treated as fully valid.
        """

        # -------- Particle flow array (N, ΣM, 3) --------
        pts = np.concatenate([np.asarray(hdf5_file[channel]['pt'][:],  dtype=np.float32) for channel in channels], axis=1)  # (N, ΣM)
        etas = np.concatenate([np.asarray(hdf5_file[channel]['eta'][:], dtype=np.float32) for channel in channels], axis=1)  # (N, ΣM)
        phis = np.concatenate([np.asarray(hdf5_file[channel]['phi'][:], dtype=np.float32) for channel in channels], axis=1)  # (N, ΣM)
        phis = wrap_pi(phis)
        particle_flow = np.stack([pts, etas, phis], axis=-1)  # (N, ΣM, 3)

        # -------- Mask array (N, ΣM) --------
        mask = []
        for channel in channels:
            if 'mask' in hdf5_file[channel]:
                _mask = np.asarray(hdf5_file[channel]['mask'][:], dtype=bool)  # (N, M)
            else:
                _mask = np.ones_like(hdf5_file[channel]['pt'][:], dtype=bool)
            mask.append(_mask)
        mask = np.concatenate(mask, axis=1)  # (N, ΣM)

        # -------- Apply mask to particle flow --------
        particle_flow = np.where(mask[..., None], particle_flow, np.nan)

        return particle_flow

    def _preprocess_phi_transformation(self, particle_flow: np.ndarray) -> np.ndarray:
        """Rotate phi by pi for events whose phi variance exceeds 0.5.

        The variance is taken per event with masked (NaN) entries counted as 0.
        Modifies ``particle_flow`` in place and returns it.
        """

        phi = particle_flow[..., 2]
        phi_var = np.var(np.nan_to_num(phi, nan=0.0), axis=-1, keepdims=True)
        phi = np.where(phi_var > 0.5, phi + np.pi, phi)
        phi = wrap_pi(phi)
        particle_flow[..., 2] = phi

        return particle_flow

    def _preprocess_center_of_phi(self, particle_flow: np.ndarray, eps: float = 1e-8) -> np.ndarray:
        """Shift phi so that the pt-weighted mean phi of every event is 0.

        ``eps`` keeps the division finite when the pt sum of an event is 0.
        Modifies ``particle_flow`` in place and returns it.
        """

        pt = particle_flow[..., 0]  # (N, ΣM)
        phi = particle_flow[..., 2]  # (N, ΣM)

        pt_phi = np.nansum(pt * phi, axis=-1, keepdims=True)  # (N, 1)
        center_of_phi = pt_phi / (np.nansum(pt, axis=-1, keepdims=True) + eps)  # (N, 1)
        phi = wrap_pi(phi - center_of_phi)
        particle_flow[..., 2] = phi

        return particle_flow

    def _preprocess_flipping(self, particle_flow: np.ndarray) -> np.ndarray:
        """Mirror each event so the quadrant with the largest pt sum is (eta > 0, phi > 0).

        Modifies ``particle_flow`` in place and returns it.
        """

        pt = particle_flow[..., 0]  # (N, ΣM)
        eta = particle_flow[..., 1]  # (N, ΣM)
        phi = particle_flow[..., 2]  # (N, ΣM)

        # -------- Quadrant pT sums (0: ++, 1: +-, 2: --, 3: -+) --------
        # Comparisons with NaN are False, so masked entries fall in no quadrant.
        cond0 = (eta > 0) & (phi > 0)
        cond1 = (eta > 0) & (phi < 0)
        cond2 = (eta < 0) & (phi < 0)
        cond3 = (eta < 0) & (phi > 0)
        conds = np.stack([cond0, cond1, cond2, cond3], axis=-1)  # (N, ΣM, 4)
        pt_quadrants = np.nansum(pt[..., None] * conds, axis=1)  # (N, 4)

        # -------- Decide flips per event --------
        q_argmax = np.argmax(pt_quadrants, axis=1)  # (N,)
        phi_flip = np.where((q_argmax == 1) | (q_argmax == 2), -1.0, 1.0)[:, None]  # (N, 1)
        eta_flip = np.where((q_argmax == 2) | (q_argmax == 3), -1.0, 1.0)[:, None]  # (N, 1)

        # -------- Apply flips to particle flow --------
        eta = eta * eta_flip
        phi = wrap_pi(phi * phi_flip)
        particle_flow[..., 1] = eta
        particle_flow[..., 2] = phi

        return particle_flow

    def to_image(self, particle_flow: np.ndarray, include_decay: bool, grid_size: int = 40, eps: float=1e-8) -> torch.Tensor:
        """Convert particle flow data to pt images of shape (N, C, H, W).

        Each channel is histogrammed on a ``grid_size`` x ``grid_size`` grid
        (eta in [-5, 5] along H, phi in [-pi, pi] along W); values outside the
        range land in the edge bins. Every (event, channel) image is then
        standardized over its pixels.

        Args:
            particle_flow: Array (N, ΣM, 3) laid out according to ``self.slices``.
            include_decay: If True, emit one image per channel including the
                decay channel. If False, emit only the detector channels and
                subtract the decay image from TOWER (diphoton) or TRACK (zz4l).
            grid_size: Number of bins per axis.
            eps: Lower bound for the per-image standard deviation.
        """

        # Masked entries become (pt=0, eta=0, phi=0) and therefore add nothing.
        particle_flow = np.where(np.isnan(particle_flow), 0.0, particle_flow)  # (N, ΣM, 3)

        phi_bins = np.linspace(-np.pi, np.pi, grid_size + 1, dtype=np.float32)
        eta_bins = np.linspace(-5.0, 5.0, grid_size + 1, dtype=np.float32)

        def array_to_image(array) -> np.ndarray:
            """Convert one channel to image (N, H, W)."""

            pt, eta, phi = array[..., 0], array[..., 1], array[..., 2]  # (N, M)

            N, M = pt.shape
            image = np.zeros((N, grid_size, grid_size), dtype=np.float32)

            phi_idx = np.digitize(phi, phi_bins, right=False) - 1
            eta_idx = np.digitize(eta, eta_bins, right=False) - 1
            phi_idx = np.clip(phi_idx, 0, grid_size - 1)
            eta_idx = np.clip(eta_idx, 0, grid_size - 1)

            # np.add.at accumulates correctly when several hits share a pixel.
            event_idx = np.repeat(np.arange(N, dtype=np.int64), M)
            np.add.at(image, (event_idx, eta_idx.ravel(), phi_idx.ravel()), pt.ravel())

            return image

        images = []
        if include_decay:
            for _slice in self.slices:
                array = particle_flow[:, _slice, :]  # (N, M, 3)
                images.append(array_to_image(array))
        else:
            decay_image = array_to_image(particle_flow[:, self.slices[-1], :])
            for i, channel in enumerate(self.detector_channels):
                array = particle_flow[:, self.slices[i], :]  # (N, M, 3)
                image = array_to_image(array)
                if ('diphoton' in self.path and channel == 'TOWER') or ('zz4l' in self.path and channel == 'TRACK'):
                    image = image - decay_image
                images.append(image)
        images = np.stack(images, axis=1)  # (N, C, H, W)

        # --- pt normalisation per (N, C) across H*W ---
        N, C, H, W = images.shape
        flat = images.reshape(N, C, -1)
        mean = flat.mean(axis=-1, keepdims=True)
        std = flat.std(axis=-1, keepdims=True)
        std = np.clip(std, a_min=eps, a_max=None)
        images = (flat - mean) / std
        images = images.reshape(N, C, H, W)

        return torch.from_numpy(images).float()

    def to_sequence(self, particle_flow: np.ndarray, include_decay: bool, eps: float = 0.0,
                    std_eps: float = 1e-8) -> torch.Tensor:
        """Convert particle flow features to sequences of shape (N, ΣM_selected, 3+C).

        The last axis is ``[pt, eta, phi]`` followed by a one-hot channel
        indicator; pt is standardized per event and channel. Invalid entries
        are NaN in every feature. The input array is never modified.

        Args:
            particle_flow: Array (N, ΣM, 3) laid out according to ``self.slices``.
            include_decay: If True, emit all channels. If False, emit only the
                detector channels and drop detector hits that coincide with a
                decay object in (eta, phi).
            eps: Matching radius in (eta, phi); 0.0 requires an exact match.
            std_eps: Lower bound for the per-channel pt standard deviation, so
                a channel with a single valid hit (std == 0) is normalized to
                pt = 0 instead of turning into NaN.
        """

        # Choose which channel names / spans to emit
        if include_decay:
            channel_slices = self.slices
        else:
            channel_slices = self.slices[:-1]

            # The matched hits are blanked through slice views below; work on a
            # private copy so the caller's array is left untouched.
            particle_flow = particle_flow.copy()

            # --- Remove detector hits that match decay objects ---
            decay_slice = self.slices[-1]
            decay_eta = particle_flow[:, decay_slice, 1]  # (N, M_dec)
            decay_phi = particle_flow[:, decay_slice, 2]  # (N, M_dec)

            # NaNs compare False in <= / ==, so NaN decay entries are ignored automatically
            for detector_slice in channel_slices:
                detector_eta = particle_flow[:, detector_slice, 1]  # (N, M_det)
                detector_phi = particle_flow[:, detector_slice, 2]  # (N, M_det)

                # broadcast (N, M_det, 1) vs (N, 1, M_dec) -> (N, M_det, M_dec)
                eta_diff2 = (detector_eta[:, :, None] - decay_eta[:, None, :]) ** 2
                phi_diff2 = (detector_phi[:, :, None] - decay_phi[:, None, :]) ** 2
                dist2 = eta_diff2 + phi_diff2

                if eps == 0.0:
                    matched = (dist2 == 0.0).any(axis=-1)  # (N, M_det)
                else:
                    matched = (dist2 <= (eps * eps)).any(axis=-1)

                # Set matched detector hits to NaN across [pt, eta, phi]
                detector_view = particle_flow[:, detector_slice, :]
                detector_view[matched] = np.nan

        # --- Build sequences with per-event pt normalization and one-hot channel indicator ---
        C = len(self.channels) if include_decay else len(self.detector_channels)
        sequences = []

        for one_hot_index, detector_slice in enumerate(channel_slices):
            pt  = particle_flow[:, detector_slice, 0]  # (N, M)
            eta = particle_flow[:, detector_slice, 1]  # (N, M)
            phi = particle_flow[:, detector_slice, 2]  # (N, M)

            # mask of valid hits for this channel
            valid = ~np.isnan(pt)  # (N, M)

            # per-event normalization (ignore NaNs); the std is bounded from
            # below exactly like in to_image to avoid a 0/0 division
            pt_mean = np.nanmean(pt, axis=-1, keepdims=True)  # (N,1)
            pt_std  = np.nanstd(pt,  axis=-1, keepdims=True)  # (N,1)
            pt_std  = np.clip(pt_std, a_min=std_eps, a_max=None)
            pt = (pt - pt_mean) / pt_std

            feat = np.stack([pt, eta, phi], axis=-1)  # (N, M, 3)

            # one-hot channel id
            one_hot = np.zeros((1, 1, C), dtype=feat.dtype)  # (1,1,C)
            one_hot[..., one_hot_index] = 1.0
            one_hot = np.broadcast_to(one_hot, (feat.shape[0], feat.shape[1], C))
            feat_oh = np.concatenate([feat, one_hot], axis=-1)  # (N, M, 3+C)

            # keep NaNs where invalid
            feat_oh = np.where(valid[..., None], feat_oh, np.nan)
            sequences.append(feat_oh)

        # concat channels along the sequence axis
        sequences = np.concatenate(sequences, axis=1)  # (N, ΣM_selected, 3+C)

        return torch.from_numpy(sequences).float()



In [None]:
class LitDataModule(lightning.LightningDataModule):
    """Lightning data module that builds train/valid/test tensors from two MC samples.

    The signal and background samples are loaded, split either by jet flavor
    (mixed samples for CWoLa) or by true label (supervised), optionally
    augmented by random phi rotations and converted to images or sequences.
    ``pos_weight`` (num_neg / num_pos of the training set) is exposed for the
    loss function.
    """

    def __init__(self, batch_size: int, data_mode: str, data_format: str, data_info: dict,
                 include_decay: bool, luminosity: float | None = None, num_phi_augmentation: int = 0,
                 **kwargs):
        """Load, split, augment and convert the data.

        Args:
            batch_size: Batch size of all three dataloaders.
            data_mode: ``'jet_flavor'`` or ``'supervised'``.
            data_format: ``'image'`` or ``'sequence'``.
            data_info: Dataset configuration; ``data_info['signal']['path']`` and
                ``data_info['background']['path']`` are resolved against
                ``project_root``. The jet-flavor mode reads further keys.
            include_decay: Forwarded to ``MCSimData.to_image`` / ``to_sequence``.
            luminosity: Used to scale the event counts in the jet-flavor mode.
            num_phi_augmentation: Number of extra phi-rotated copies of every
                training event (0 disables augmentation).
            **kwargs: Accepted and ignored.

        Raises:
            ValueError: If ``data_mode`` or ``data_format`` is not supported.
        """
        super().__init__()

        self.save_hyperparameters()
        self.batch_size = batch_size

        # Reject an unknown format before the expensive data loading; otherwise
        # the numpy arrays would reach torch.cat below and fail there with an
        # unrelated error message.
        if data_format not in ('image', 'sequence'):
            raise ValueError(f"Unsupported data format: {data_format}. Supported data formats are 'image' and 'sequence'.")

        # Monte Carlo simulation data
        sig_data = MCSimData(project_root / data_info['signal']['path'])
        bkg_data = MCSimData(project_root / data_info['background']['path'])

        # Create mixed dataset for implementing CWoLa
        if data_mode == 'jet_flavor':
            train_sig, train_bkg, valid_sig, valid_bkg, test_sig, test_bkg = self.split_by_jet_flavor(luminosity, data_info, sig_data, bkg_data)
        elif data_mode == 'supervised':
            train_sig, train_bkg, valid_sig, valid_bkg, test_sig, test_bkg = self.split_by_supervised(sig_data, bkg_data)
        else:
            raise ValueError(f"Unsupported data mode: {data_mode}. Supported data modes are 'jet_flavor' and 'supervised'.")

        # Data augmentation by random phi rotation (training set only)
        if num_phi_augmentation > 0:
            train_sig = self.augment_phi_per_event(train_sig, num_phi_augmentation)
            train_bkg = self.augment_phi_per_event(train_bkg, num_phi_augmentation)

        # Transform to desired data format
        if data_format == 'image':
            convert_sig, convert_bkg = sig_data.to_image, bkg_data.to_image
        else:  # 'sequence' (validated above)
            convert_sig, convert_bkg = sig_data.to_sequence, bkg_data.to_sequence
        train_sig = convert_sig(train_sig, include_decay)
        train_bkg = convert_bkg(train_bkg, include_decay)
        valid_sig = convert_sig(valid_sig, include_decay)
        valid_bkg = convert_bkg(valid_bkg, include_decay)
        test_sig  = convert_sig(test_sig, include_decay)
        test_bkg  = convert_bkg(test_bkg, include_decay)

        # For tracking number of data samples
        self.train_sig, self.train_bkg = train_sig, train_bkg
        self.valid_sig, self.valid_bkg = valid_sig, valid_bkg
        self.test_sig, self.test_bkg = test_sig, test_bkg

        # Create torch datasets (label 1 = signal side, label 0 = background side)
        self.train_dataset = TensorDataset(torch.cat([train_sig, train_bkg], dim=0), torch.cat([torch.ones(len(train_sig)), torch.zeros(len(train_bkg))], dim=0))
        self.valid_dataset = TensorDataset(torch.cat([valid_sig, valid_bkg], dim=0), torch.cat([torch.ones(len(valid_sig)), torch.zeros(len(valid_bkg))], dim=0))
        self.test_dataset  = TensorDataset(torch.cat([test_sig, test_bkg], dim=0), torch.cat([torch.ones(len(test_sig)), torch.zeros(len(test_bkg))], dim=0))

        # Calculate positive weight for loss function
        num_pos = len(train_sig)  # y == 1
        num_neg = len(train_bkg)  # y == 0
        self.pos_weight = torch.tensor([num_neg / num_pos], dtype=torch.float32)

    def split_by_supervised(self, sig_data: MCSimData, bkg_data: MCSimData):
        """Split both samples by true label into fixed-size train/valid/test sets.

        Each sample is shuffled with ``np.random.permutation`` and cut into
        100000 / 25000 / 25000 events. A sample with fewer events yields
        correspondingly shorter (possibly empty) later splits.

        Returns:
            ``(train_sig, train_bkg, valid_sig, valid_bkg, test_sig, test_bkg)``.
        """

        NUM_TRAIN, NUM_VALID, NUM_TEST = 100000, 25000, 25000

        perm_sig = np.random.permutation(len(sig_data.particle_flow))
        perm_bkg = np.random.permutation(len(bkg_data.particle_flow))
        sig_array = sig_data.particle_flow[perm_sig]
        bkg_array = bkg_data.particle_flow[perm_bkg]

        train_sig = sig_array[:NUM_TRAIN]
        train_bkg = bkg_array[:NUM_TRAIN]
        valid_sig = sig_array[NUM_TRAIN: NUM_TRAIN + NUM_VALID]
        valid_bkg = bkg_array[NUM_TRAIN: NUM_TRAIN + NUM_VALID]
        test_sig = sig_array[NUM_TRAIN + NUM_VALID: NUM_TRAIN + NUM_VALID + NUM_TEST]
        test_bkg = bkg_array[NUM_TRAIN + NUM_VALID: NUM_TRAIN + NUM_VALID + NUM_TEST]

        return train_sig, train_bkg, valid_sig, valid_bkg, test_sig, test_bkg

    def split_by_jet_flavor(self, luminosity: float, data_info:dict, sig_data: MCSimData, bkg_data: MCSimData):
        """Split data by jet flavor composition for CWoLa training.

        Events with two quark jets ('2q0g') form the signal region, all other
        events ('1q1g' or '0q2g') the background region. Training and
        validation sets are the *mixed* regions (true signal and background
        together); the test sets are separated by true label. The number of
        events drawn per process and region is the expected yield
        ``cross_section * efficiency * branching_ratio * luminosity``.

        Returns:
            ``(train_sig, train_bkg, valid_sig, valid_bkg, test_sig, test_bkg)``.

        Raises:
            ValueError: From ``np.random.choice`` if a sample holds fewer events
                than the requested yield plus its test share.
        """

        def get_event_counts(data_type: str, cut_info_key: str):
            """Expected number of events of ``data_type`` passing ``cut_info_key``."""
            cut_info_path = project_root / data_info[data_type]['cut_info']
            # NOTE(security): allow_pickle=True executes pickled code while
            # loading; only use cut-flow files from a trusted source.
            cut_info_npy = np.load(cut_info_path, allow_pickle=True)
            cut_info = cut_info_npy.item()['cutflow_number']
            N = int(data_info[data_type]['cross_section'] * cut_info[cut_info_key] / cut_info['Total'] * data_info['branching_ratio'] * luminosity)
            print(f"[CWoLa-Log] [{data_type}] {cut_info_key}: {N} events")
            return N

        NUM_TEST = 10000
        num_sig_in_sig = get_event_counts('signal', 'two quark jet: sig region')
        num_sig_in_bkg = get_event_counts('signal', 'two quark jet: bkg region')
        num_bkg_in_sig = get_event_counts('background', 'two quark jet: sig region')
        num_bkg_in_bkg = get_event_counts('background', 'two quark jet: bkg region')

        sig_in_sig_mask = sig_data.jet_flavor['2q0g']
        sig_in_bkg_mask = sig_data.jet_flavor['1q1g'] | sig_data.jet_flavor['0q2g']
        bkg_in_sig_mask = bkg_data.jet_flavor['2q0g']
        bkg_in_bkg_mask = bkg_data.jet_flavor['1q1g'] | bkg_data.jet_flavor['0q2g']
        # The test events of every process keep the region fractions of the full sample.
        num_test_sig_in_sig = int(NUM_TEST * np.sum(sig_in_sig_mask) / (np.sum(sig_in_sig_mask) + np.sum(sig_in_bkg_mask)))
        num_test_sig_in_bkg = NUM_TEST - num_test_sig_in_sig
        num_test_bkg_in_sig = int(NUM_TEST * np.sum(bkg_in_sig_mask) / (np.sum(bkg_in_sig_mask) + np.sum(bkg_in_bkg_mask)))
        num_test_bkg_in_bkg = NUM_TEST - num_test_bkg_in_sig

        sig_in_sig = sig_data.particle_flow[sig_in_sig_mask]
        sig_in_bkg = sig_data.particle_flow[sig_in_bkg_mask]
        bkg_in_sig = bkg_data.particle_flow[bkg_in_sig_mask]
        bkg_in_bkg = bkg_data.particle_flow[bkg_in_bkg_mask]

        # Draw train+valid and test events together without replacement so the
        # test events never overlap with the training events.
        idx_sig_in_sig = np.random.choice(len(sig_in_sig), num_sig_in_sig + num_test_sig_in_sig, replace=False)
        idx_sig_in_bkg = np.random.choice(len(sig_in_bkg), num_sig_in_bkg + num_test_sig_in_bkg, replace=False)
        idx_bkg_in_sig = np.random.choice(len(bkg_in_sig), num_bkg_in_sig + num_test_bkg_in_sig, replace=False)
        idx_bkg_in_bkg = np.random.choice(len(bkg_in_bkg), num_bkg_in_bkg + num_test_bkg_in_bkg, replace=False)

        sig_in_sig = sig_in_sig[idx_sig_in_sig]
        sig_in_bkg = sig_in_bkg[idx_sig_in_bkg]
        bkg_in_sig = bkg_in_sig[idx_bkg_in_sig]
        bkg_in_bkg = bkg_in_bkg[idx_bkg_in_bkg]

        def split_data(data: np.array, num_samples: int, num_test: int):
            """Cut ``data`` into 80% train / 20% valid of ``num_samples`` plus ``num_test`` test events."""
            TRAIN_SIZE_RATIO = 0.8
            train_size = int(num_samples * TRAIN_SIZE_RATIO)
            return (
                data[:train_size],
                data[train_size:num_samples],
                data[num_samples:num_samples + num_test]
            )

        train_sig_in_sig, valid_sig_in_sig, test_sig_in_sig = split_data(sig_in_sig, num_sig_in_sig, num_test_sig_in_sig)
        train_sig_in_bkg, valid_sig_in_bkg, test_sig_in_bkg = split_data(sig_in_bkg, num_sig_in_bkg, num_test_sig_in_bkg)
        train_bkg_in_sig, valid_bkg_in_sig, test_bkg_in_sig = split_data(bkg_in_sig, num_bkg_in_sig, num_test_bkg_in_sig)
        train_bkg_in_bkg, valid_bkg_in_bkg, test_bkg_in_bkg = split_data(bkg_in_bkg, num_bkg_in_bkg, num_test_bkg_in_bkg)

        # Train/valid: grouped by region (mixed samples). Test: grouped by true label.
        train_sig = np.concatenate([train_sig_in_sig, train_bkg_in_sig], axis=0)
        train_bkg = np.concatenate([train_sig_in_bkg, train_bkg_in_bkg], axis=0)
        valid_sig = np.concatenate([valid_sig_in_sig, valid_bkg_in_sig], axis=0)
        valid_bkg = np.concatenate([valid_sig_in_bkg, valid_bkg_in_bkg], axis=0)
        test_sig  = np.concatenate([test_sig_in_sig, test_sig_in_bkg], axis=0)
        test_bkg  = np.concatenate([test_bkg_in_sig, test_bkg_in_bkg], axis=0)

        return train_sig, train_bkg, valid_sig, valid_bkg, test_sig, test_bkg

    def augment_phi_per_event(self, data: np.ndarray, k: int) -> np.ndarray:
        """Append ``k`` randomly phi-rotated copies of every event.

        Args:
            data: Array (N, M, C) whose feature index 2 is phi.
            k: Number of rotated copies per event.

        Returns:
            Array ((k + 1) * N, M, C): the original events followed by ``k``
            blocks in which every event is rotated by its own uniform random
            angle. The output buffer is allocated once and filled in place.
        """

        N, M, C = data.shape
        out = np.empty(((k + 1) * N, M, C), dtype=data.dtype)
        out[:N] = data
        for i in range(k):
            s = slice((i + 1) * N, (i + 2) * N)
            out[s] = data
            shift = np.random.uniform(-np.pi, np.pi, size=(N, 1)).astype(data.dtype)
            out[s, :, 2] = wrap_pi(out[s, :, 2] + shift)
        return out

    def train_dataloader(self):
        """Shuffled loader over the training set."""
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        """Ordered loader over the validation set."""
        return DataLoader(self.valid_dataset, batch_size=self.batch_size, shuffle=False)

    def test_dataloader(self):
        """Ordered loader over the test set."""
        return DataLoader(self.test_dataset, batch_size=self.batch_size, shuffle=False)

In [None]:
class BinaryLitModel(lightning.LightningModule):
    """Lightning wrapper for a binary classifier that outputs one logit per sample.

    Trains with ``BCEWithLogitsLoss`` and tracks AUROC and accuracy separately
    for the train, valid and test splits.
    """

    def __init__(self, model: nn.Module, lr: float, pos_weight: torch.Tensor | None = None,
                 optimizer_settings: dict | None = None):
        """
        Args:
            model: Network mapping a batch to logits; a trailing singleton
                dimension is squeezed away before the loss.
            lr: Learning rate handed to the optimizer.
            pos_weight: Optional weight of the positive class for the loss.
            optimizer_settings: Dict with the key ``'optimizer'`` (name of a class
                in ``torch.optim``) and optionally ``'lr_scheduler'`` (name of a
                class in ``torch.optim.lr_scheduler``). When a scheduler is set,
                its keyword arguments are read from the entry named like the
                scheduler class, and ``'lightning_monitor'`` is merged into the
                Lightning scheduler configuration.
        """
        super().__init__()
        self.save_hyperparameters()

        self.model = model
        self.lr = lr
        self.optimizer_settings = optimizer_settings
        self.loss_fn = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

        base_metrics = MetricCollection({"auc": BinaryAUROC(), "accuracy": BinaryAccuracy()})
        # A plain dict keeps the metrics out of the module tree, so they do not
        # follow the model across devices automatically; the on_*_start hooks
        # below move them explicitly.
        self.metrics = {
            "train": base_metrics.clone(prefix="train_"),
            "valid": base_metrics.clone(prefix="valid_"),
            "test": base_metrics.clone(prefix="test_"),
        }

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the raw logits of the wrapped model."""
        return self.model(x)

    def configure_optimizers(self):
        """Build the optimizer and, if configured, the learning-rate scheduler."""
        optimizer_settings = self.optimizer_settings
        optimizer = getattr(torch.optim, optimizer_settings['optimizer'])
        optimizer = optimizer(self.parameters(), lr=self.lr)
        # A missing 'lr_scheduler' entry means "no scheduler", just like None.
        if optimizer_settings.get('lr_scheduler') is None:
            return optimizer
        else:
            scheduler = getattr(torch.optim.lr_scheduler, optimizer_settings['lr_scheduler'])
            scheduler = scheduler(optimizer, **optimizer_settings[scheduler.__name__])
            lr_scheduler: dict = {'scheduler': scheduler}
            lr_scheduler.update(optimizer_settings['lightning_monitor'])
            return {'optimizer': optimizer, 'lr_scheduler': lr_scheduler}

    def _move_metrics_to_device(self):
        """Move every metric collection to the device the model lives on."""
        for split in self.metrics:
            self.metrics[split] = self.metrics[split].to(self.device)

    def on_fit_start(self):
        self._move_metrics_to_device()

    def on_validation_start(self):
        # Also covers a stand-alone trainer.validate(), where on_fit_start never runs.
        self._move_metrics_to_device()

    def on_test_start(self):
        # Also covers a stand-alone trainer.test(), where on_fit_start never runs.
        self._move_metrics_to_device()

    def _shared_step(self, batch: Tuple[torch.Tensor, torch.Tensor], split: str):
        """Compute the loss of one batch, update the metrics of ``split`` and log the loss."""
        x, y_true = batch
        logits    = self(x).squeeze(-1)
        loss      = self.loss_fn(logits, y_true.float())
        y_pred    = torch.sigmoid(logits)
        self.metrics[split].update(y_pred, y_true.int())
        self.log(f"{split}_loss", loss, on_epoch=True, on_step=False, prog_bar=(split == "train"), batch_size=y_true.size(0))
        return loss

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        return self._shared_step(batch, "valid")

    def test_step(self, batch, batch_idx):
        return self._shared_step(batch, "test")

    def _compute_and_log_split(self, split: str, prog_bar: bool = False):
        """Log the epoch-level metrics of ``split`` and reset their state."""
        computed = self.metrics[split].compute()
        self.log_dict(computed, on_epoch=True, on_step=False, prog_bar=prog_bar)
        self.metrics[split].reset()

    def on_train_epoch_end(self):
        self._compute_and_log_split("train", prog_bar=True)

    def on_validation_epoch_end(self):
        self._compute_and_log_split("valid", prog_bar=True)

    def on_test_epoch_end(self):
        self._compute_and_log_split("test", prog_bar=False)

In [None]:
def keras_like_init(module: nn.Module):
    """Re-initialize one module the way Keras does by default.

    Meant to be passed to ``nn.Module.apply``. Linear and (transposed)
    convolution layers get Glorot-uniform weights, BatchNorm and LayerNorm
    layers get a scale of one, and every handled layer gets a zero bias.
    Running statistics of BatchNorm are left untouched; any other module type
    is ignored.
    """
    glorot_layers = (
        nn.Linear,
        nn.Conv1d, nn.Conv2d, nn.Conv3d,
        nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d,
    )
    norm_layers = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.LayerNorm)

    uses_glorot = isinstance(module, glorot_layers)
    is_norm = isinstance(module, norm_layers)
    if not (uses_glorot or is_norm):
        return

    # Layers built without bias or affine parameters hold None here.
    weight = getattr(module, "weight", None)
    bias = getattr(module, "bias", None)

    if weight is not None:
        if uses_glorot:
            nn.init.xavier_uniform_(weight)  # = Keras glorot_uniform
        if is_norm:
            nn.init.ones_(weight)  # gamma
    if bias is not None:
        nn.init.zeros_(bias)  # Keras zeros / beta


def training(
        data_mode: str, data_format: str, data_info: dict, include_decay: bool,
        model_cls: type[nn.Module], lr: float, keras_init: bool,
        tags: List[str], rnd_seed: int, date_time: str, **kwargs
    ):
    """Train and test one model for one configuration.

    Results are written below
    ``project_root/output/<decay channel>/<tags>/<name>/rnd_seed-<seed>``. If
    that directory already exists, the run is skipped with a warning.

    Args:
        data_mode: ``'jet_flavor'`` or ``'supervised'``.
        data_format: Forwarded to ``LitDataModule``.
        data_info: Dataset configuration; ``data_info['decay_channel']`` names
            the output sub-directory.
        include_decay: Whether the decay channel is part of the input; selects
            3 instead of 2 input channels.
        model_cls: Model class, called as
            ``model_cls(num_channels=..., keras_init=...)``.
        lr: Learning rate.
        keras_init: If True, apply ``keras_like_init`` to the model.
        tags: Joined with ``'_'`` to name the output sub-directory.
        rnd_seed: Seed passed to ``lightning.seed_everything``.
        date_time: Time stamp that becomes part of the run name.
        **kwargs: Forwarded to ``LitDataModule``; the ``'jet_flavor'`` mode
            requires ``luminosity``.

    Raises:
        ValueError: If ``data_mode`` is not supported.
    """

    # Fail fast with a clear message; without this check an unknown mode only
    # surfaces later as an unbound run name.
    if data_mode not in ('jet_flavor', 'supervised'):
        raise ValueError(f"Unsupported data mode: {data_mode}. Supported data modes are 'jet_flavor' and 'supervised'.")

    lightning.seed_everything(rnd_seed)

    num_channels = 3 if include_decay else 2
    model = model_cls(num_channels=num_channels, keras_init=keras_init)
    if keras_init:
        model.apply(keras_like_init)

    # Output and log directories
    save_dir = project_root / 'output' / (('' if include_decay else 'ex-') + data_info['decay_channel']) / ('_'.join(tags))
    if data_mode == 'jet_flavor':
        name = f"{model.__class__.__name__}-{date_time}-L{kwargs['luminosity']}"
    else:  # 'supervised' (validated above)
        name = f"{model.__class__.__name__}-{date_time}-SV"
    version = f"rnd_seed-{rnd_seed}"
    output_dir = save_dir / name / version
    if os.path.exists(output_dir):
        print(f"[Warning] Output directory {output_dir} already exists.")
        return

    # Lightning data setup
    BATCH_SIZE = 512
    lit_data_module = LitDataModule(
        batch_size=BATCH_SIZE,
        data_mode=data_mode,
        data_format=data_format,
        data_info=data_info,
        include_decay=include_decay,
        **kwargs
    )

    # Lightning model setup
    with open(project_root / 'config' / 'training.yml', 'r') as f:
        training_config = yaml.safe_load(f)
    lit_model = BinaryLitModel(
        model=model,
        lr=lr,
        pos_weight=lit_data_module.pos_weight,
        optimizer_settings=training_config['optimizer_settings']
    )

    # Lightning loggers
    logger = CSVLogger(save_dir=save_dir, name=name, version=version)
    hparams = {}
    hparams.update(lit_data_module.hparams)
    hparams.update(training_config)
    logger.log_hyperparams(hparams)

    # Lightning trainer
    trainer = lightning.Trainer(
        max_epochs=500,
        accelerator='gpu' if torch.cuda.is_available() else 'cpu',
        logger=logger,
        callbacks=[
            ModelCheckpoint(**training_config['ModelCheckpoint']),
            EarlyStopping(**training_config['EarlyStopping']),
        ],
    )

    # Lightning training and testing (the test uses the best checkpoint)
    trainer.fit(lit_model, lit_data_module)
    trainer.test(lit_model, datamodule=lit_data_module, ckpt_path='best')
    os.makedirs(output_dir, exist_ok=True)
    utils.count_number_of_data(lit_data_module, output_dir)
    utils.count_model_parameters(lit_model, output_dir)
    utils.plot_metrics(output_dir)

In [None]:
# Run-wide settings shared by every training job launched below.
KERAS_INIT = False
N_AUG_ROT = 5
# One time stamp per session, so all runs of a scan share the same name prefix.
DATETIME = time.strftime("%Y%m%d_%H%M%S")

if __name__ == '__main__':

    rnd_seeds = list(range(123, 123 + 100 * 10, 100))
    data_modes = ['jet_flavor']
    decay_options = [True, False]
    # (data_format, model_cls, lr) combinations to train
    model_setups = [
        ('image', CNN_EventCNN, 1e-4),
        # ('sequence', ParT_Light, 4e-4),
    ]
    luminosities = [100, 300, 900, 1800, 3000]
    data_config_path = project_root / 'config' / 'data_diphoton.yml'

    for rnd_seed, data_mode, include_decay in product(rnd_seeds, data_modes, decay_options):

        # Reload the dataset configuration for every outer iteration.
        with open(data_config_path, 'r') as f:
            data_info = yaml.safe_load(f)

        # product() iterates the luminosities fastest, i.e. all luminosities
        # are scanned for one model setup before moving on to the next.
        for (data_format, model_cls, lr), luminosity in product(model_setups, luminosities):
            training(
                data_mode=data_mode,
                data_format=data_format,
                data_info=data_info,
                include_decay=include_decay,
                model_cls=model_cls,
                lr=lr,
                keras_init=KERAS_INIT,
                tags=[data_mode],
                rnd_seed=rnd_seed,
                date_time=DATETIME,
                luminosity=luminosity,
                num_phi_augmentation=N_AUG_ROT,
            )