In [None]:
from itertools import product
import os
from pathlib import Path
import sys
import time
from typing import Dict, List, Tuple

import h5py
import lightning
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.loggers import CSVLogger
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from torchmetrics import MetricCollection
from torchmetrics.classification import BinaryAccuracy, BinaryAUROC
import yaml

try:
    # Running as a script: the project root is the parent of the directory that
    # contains this file.
    project_root = Path(__file__).parent.parent
except NameError:
    '''Jupyter notebook environment has no __file__ attribute.'''
    # Fall back to the parent of the current working directory. This assumes the
    # notebook is launched from a direct sub-directory of the project root — TODO confirm.
    project_root = Path.cwd().parent
# Make the project root importable so that the `src.*` imports below resolve.
sys.path.append(project_root.as_posix())

from src.model_cnn import CNN_EventCNN
from src.model_part import ParT_Light
from src import utils

In [None]:
def wrap_pi(phi: np.ndarray | torch.Tensor) -> np.ndarray | torch.Tensor:
    """Wrap angles to [-pi, pi)."""
    return (phi + np.pi) % (2 * np.pi) - np.pi

In [None]:
class MCSimData:
    """Preprocessed particle-flow data from one Monte Carlo simulation HDF5 file.

    Construction reads ``path`` once and extracts two attributes.

    * ``jet_flavor`` holds boolean event masks for the gluon/quark composition of
      jets J1/J2 (``'2q0g'``, ``'1q1g'``, ``'0q2g'``) plus the event count
      (``'total'``).
    * ``particle_flow`` holds, for every channel (``'TOWER'``, ``'TRACK'`` and
      the decay channel), a dict of ``'pt'``, ``'eta'``, ``'phi'`` and boolean
      ``'mask'`` arrays. Three event-wise preprocessing steps are applied jointly
      to all channels: phi variance reduction, pt-weighted phi centring and
      quadrant flipping.

    ``to_image`` and ``to_sequence`` build model inputs from ``particle_flow``
    without modifying it.

    Args:
        path: HDF5 file path. The decay channel is inferred from it:
            ``'diphoton'`` -> ``'PHOTON'``, ``'zz4l'`` -> ``'LEPTON'``.
        include_decay: whether the decay objects are kept in the representations
            returned by ``to_image`` / ``to_sequence``.

    Raises:
        ValueError: if ``path`` contains neither ``'diphoton'`` nor ``'zz4l'``.
    """

    def __init__(self, path: str | Path, include_decay=True):
        self.path = str(path)
        self.include_decay = include_decay

        # -------- Channels (inferred from the file name) --------
        # Decided before opening the file, so an unsupported sample fails fast with a
        # clear message instead of an UnboundLocalError further down.
        if 'diphoton' in self.path:
            decay_channel = 'PHOTON'
        elif 'zz4l' in self.path:
            decay_channel = 'LEPTON'
        else:
            raise ValueError(
                f"Cannot infer the decay channel from path {self.path!r}: "
                "expected it to contain 'diphoton' or 'zz4l'."
            )
        detector_channels = ['TOWER', 'TRACK']
        self.decay_channel = decay_channel
        self.detector_channels = detector_channels

        with h5py.File(self.path, 'r') as hdf5_file:
            # -------- Jet flavor --------
            self.jet_flavor = self._extract_jet_flavor(hdf5_file)

            # -------- Particle flow information (pt, eta, phi, mask) --------
            particle_flow = self._extract_particle_flow(hdf5_file, detector_channels + [decay_channel])

        # -------- Preprocessing (in-memory NumPy only; the file is already closed) --------
        particle_flow = self._preprocess_phi_transformation(particle_flow)
        particle_flow = self._preprocess_center_of_phi(particle_flow)
        particle_flow = self._preprocess_flipping(particle_flow)
        self.particle_flow = particle_flow

    def _extract_jet_flavor(self, hdf5_file: h5py.File) -> Dict[str, torch.Tensor | int]:
        """Build gluon/quark composition masks from J1/J2 flavors.

        A flavor value of 21 is treated as a gluon; any other value counts as a quark.
        Returns boolean (N,) tensors for ``'2q0g'``, ``'1q1g'`` and ``'0q2g'``, and the
        number of events under ``'total'``.
        """

        J1 = torch.as_tensor(hdf5_file["J1"]["flavor"][:], dtype=torch.long)
        J2 = torch.as_tensor(hdf5_file["J2"]["flavor"][:], dtype=torch.long)
        g1, g2 = (J1 == 21), (J2 == 21)  # 21 == gluon

        mask_2q0g = (~g1) & (~g2)
        mask_1q1g = ((~g1) & g2) | (g1 & (~g2))
        mask_0q2g = g1 & g2

        return {
            "2q0g": mask_2q0g,
            "1q1g": mask_1q1g,
            "0q2g": mask_0q2g,
            "total": len(J1),
        }

    def _extract_particle_flow(self, hdf5_file: h5py.File, channels: List[str]) -> Dict[str, Dict[str, np.ndarray]]:
        """Load pt/eta/phi/mask arrays for each channel.

        pt/eta/phi are float32 and phi is wrapped into [-pi, pi). When a channel has
        no ``'mask'`` dataset, every entry is treated as valid.
        """

        particle_flow: Dict[str, Dict[str, np.ndarray]] = {}

        for channel in channels:
            pt  = np.asarray(hdf5_file[channel]['pt'][:],  dtype=np.float32)
            eta = np.asarray(hdf5_file[channel]['eta'][:], dtype=np.float32)
            phi = np.asarray(hdf5_file[channel]['phi'][:], dtype=np.float32)
            phi = wrap_pi(phi)

            entry = {'pt': pt, 'eta': eta, 'phi': phi}
            if 'mask' in hdf5_file[channel]:
                mask = np.asarray(hdf5_file[channel]['mask'][:], dtype=bool)
            else:
                mask = np.ones_like(pt, dtype=bool)
            entry['mask'] = mask
            particle_flow[channel] = entry

        return particle_flow

    def _exclude_decay_information(self, particle_flow: Dict[str, Dict[str, np.ndarray]], eps: float = 0.) -> Dict[str, Dict[str, np.ndarray]]:
        """Mask out detector hits that coincide with decay objects.

        A detector hit matches a decay object when their (eta, phi) distance is
        exactly zero (``eps == 0``) or at most ``eps`` (phi difference wrapped). A matched
        hit gets ``mask=False``, and pt/eta/phi are set to 0 wherever the mask is False.

        Returns a new outer dict holding new per-channel dicts; the dicts passed in
        (e.g. ``self.particle_flow``) are left unmodified.
        """
        # Copy the per-channel dicts so the caller's data (typically self.particle_flow)
        # is not overwritten below.
        particle_flow = {channel: dict(entry) for channel, entry in particle_flow.items()}

        decay_pt  = particle_flow[self.decay_channel]['pt']
        decay_eta = particle_flow[self.decay_channel]['eta']
        decay_phi = particle_flow[self.decay_channel]['phi']

        for channel in self.detector_channels:
            pt   = particle_flow[channel]['pt']
            eta  = particle_flow[channel]['eta']
            phi  = particle_flow[channel]['phi']
            mask = particle_flow[channel]['mask']

            d_eta = eta[:, :, np.newaxis] - decay_eta[:, np.newaxis, :]  # (N, M, K)
            d_phi = phi[:, :, np.newaxis] - decay_phi[:, np.newaxis, :]  # (N, M, K)

            # A detector hit is "matched to a decay" if any decay object is within tolerance
            if eps == 0.0:
                dist2 = d_eta ** 2 + d_phi ** 2
                matched = (dist2 == 0.0).any(axis=-1)  # (N, M)
            else:
                # phi is periodic: wrap the difference so hits on either side of +-pi can match.
                dist2 = d_eta ** 2 + wrap_pi(d_phi) ** 2
                matched = (dist2 <= (eps ** 2)).any(axis=-1)  # (N, M)
            non_decay = ~matched
            mask = mask & non_decay
            particle_flow[channel]['mask'] = mask
            particle_flow[channel]['pt'] = np.where(mask, pt, 0.0)
            particle_flow[channel]['eta'] = np.where(mask, eta, 0.0)
            particle_flow[channel]['phi'] = np.where(mask, phi, 0.0)

            # Diagnostic: fraction of events in which exactly M - K hits stay unmatched,
            # i.e. exactly K decay matches (M detector slots, K decay slots) were excluded.
            num_decay_match = np.sum(non_decay, axis=-1) == (pt.shape[-1] - decay_pt.shape[-1])
            purity = np.sum(num_decay_match) / num_decay_match.shape[0]
            print(f'[{self.__class__.__name__} Log] {self.path}-{channel} has purity {100 * purity:.4f}%')

        return particle_flow

    def _preprocess_phi_transformation(self, particle_flow: Dict[str, Dict[str, np.ndarray]]) -> Dict[str, Dict[str, np.ndarray]]:
        """Per event, shift phi by pi in every channel if var(phi) over all channels > 0.5.

        Phi is re-wrapped into [-pi, pi) afterwards. The input dict is updated in place
        and returned.
        """

        global_phi = np.concatenate([particle_flow[channel]['phi'] for channel in particle_flow], axis=-1)
        global_phi_var = np.var(global_phi, axis=-1, keepdims=True)  # (N, 1)

        for channel in particle_flow:
            phi = particle_flow[channel]['phi']
            phi = np.where(global_phi_var > 0.5, phi + np.pi, phi)
            phi = wrap_pi(phi)
            particle_flow[channel]['phi'] = phi

        return particle_flow

    def _preprocess_center_of_phi(self, particle_flow: Dict[str, Dict[str, np.ndarray]], eps: float = 1e-8) -> Dict[str, Dict[str, np.ndarray]]:
        """Shift phi so that each event's pt-weighted mean phi (all channels) is 0.

        The mean is an ordinary (not circular) weighted average; ``eps`` guards against
        events with zero total pt. Phi is re-wrapped afterwards; the input dict is
        updated in place and returned.
        """

        global_pt = np.concatenate([particle_flow[channel]['pt'] for channel in particle_flow], axis=-1)
        global_phi = np.concatenate([particle_flow[channel]['phi'] for channel in particle_flow], axis=-1)
        global_pt_phi = np.sum(global_pt * global_phi, axis=-1, keepdims=True)
        center_of_phi = global_pt_phi / (np.sum(global_pt, axis=-1, keepdims=True) + eps)

        for channel in particle_flow:
            phi = particle_flow[channel]['phi']
            phi = phi - center_of_phi
            phi = wrap_pi(phi)
            particle_flow[channel]['phi'] = phi

        return particle_flow

    def _preprocess_flipping(self, particle_flow: Dict[str, Dict[str, np.ndarray]]) -> Dict[str, Dict[str, np.ndarray]]:
        """Flip eta and/or phi per event so the highest-pt quadrant becomes (eta > 0, phi > 0).

        Entries lying exactly on eta == 0 or phi == 0 count towards no quadrant; ties
        resolve to the first quadrant in the order ++, +-, --, -+. The input dict is
        updated in place and returned.
        """

        global_pt = np.concatenate([particle_flow[channel]['pt'] for channel in particle_flow], axis=-1)
        global_phi = np.concatenate([particle_flow[channel]['phi'] for channel in particle_flow], axis=-1)
        global_eta = np.concatenate([particle_flow[channel]['eta'] for channel in particle_flow], axis=-1)

        # -------- Quadrant pT sums (0: ++, 1: +-, 2: --, 3: -+) --------
        cond0 = (global_eta > 0) & (global_phi > 0)
        cond1 = (global_eta > 0) & (global_phi < 0)
        cond2 = (global_eta < 0) & (global_phi < 0)
        cond3 = (global_eta < 0) & (global_phi > 0)
        conds = np.stack([cond0, cond1, cond2, cond3], axis=-1)    # (N, ΣM, 4)
        pt_quadrants = (global_pt[..., None] * conds).sum(axis=1)  # (N, 4)

        # -------- Decide flips per event --------
        q_argmax = np.argmax(pt_quadrants, axis=1)                 # (N,)
        phi_flip = np.where((q_argmax == 1) | (q_argmax == 2), -1.0, 1.0)[:, None]  # (N,1)
        eta_flip = np.where((q_argmax == 2) | (q_argmax == 3), -1.0, 1.0)[:, None]  # (N,1)

        # -------- Apply flips to every channel --------
        for channel in particle_flow:
            particle_flow[channel]['eta'] = particle_flow[channel]['eta'] * eta_flip
            particle_flow[channel]['phi'] = wrap_pi(particle_flow[channel]['phi'] * phi_flip)

        return particle_flow

    def to_image(self, grid_size: int = 40, eps=1e-8) -> torch.Tensor:
        """Convert the particle flow data to images (N, C, H, W).

        Valid entries are histogrammed with pt as weight on a ``grid_size`` x
        ``grid_size`` grid. Rows are phi bins over [-pi, pi] and columns are eta bins
        over [-5, 5]; out-of-range values are clipped into the edge bins.

        With ``include_decay`` the channels are TOWER, TRACK and the decay channel.
        Without it, only the detector channels are kept, with pixels zeroed wherever the
        decay image is non-zero.

        Each (event, channel) image is then standardised over its H*W pixels, with
        the standard deviation clipped below at ``eps``.
        """

        particle_flow = self.particle_flow

        phi_bins = np.linspace(-np.pi, np.pi, grid_size + 1, dtype=np.float32)
        eta_bins = np.linspace(-5.0, 5.0, grid_size + 1, dtype=np.float32)

        def array_to_image(channel: str) -> np.ndarray:
            """Convert one channel to image (N, H, W)."""

            mask = particle_flow[channel]['mask']  # (N, M)
            pt   = np.where(mask, particle_flow[channel]['pt'], 0.0)   # (N, M)
            eta  = np.where(mask, particle_flow[channel]['eta'], 0.0)  # (N, M)
            phi  = np.where(mask, particle_flow[channel]['phi'], 0.0)  # (N, M)

            N, M = pt.shape
            image = np.zeros((N, grid_size, grid_size), dtype=np.float32)

            phi_idx = np.digitize(phi, phi_bins, right=False) - 1
            eta_idx = np.digitize(eta, eta_bins, right=False) - 1
            phi_idx = np.clip(phi_idx, 0, grid_size - 1)
            eta_idx = np.clip(eta_idx, 0, grid_size - 1)

            # Unbuffered scatter-add: several entries can fall into the same pixel.
            event_idx = np.repeat(np.arange(N, dtype=np.int64), M)
            np.add.at(image, (event_idx, phi_idx.ravel(), eta_idx.ravel()), pt.ravel())

            return image

        images = []
        if self.include_decay:
            for channel in self.detector_channels + [self.decay_channel]:
                images.append(array_to_image(channel))
        else:
            decay_image = array_to_image(self.decay_channel)
            for channel in self.detector_channels:
                detector_image = array_to_image(channel)
                images.append(np.where(decay_image == 0.0, detector_image, 0.0))
        images = np.stack(images, axis=1)  # (N, C, H, W)

        # --- pt normalisation per (N, C) across H*W ---
        N, C, H, W = images.shape
        flat = images.reshape(N, C, -1)
        mean = flat.mean(axis=-1, keepdims=True)
        std = flat.std(axis=-1, keepdims=True)
        std = np.clip(std, a_min=eps, a_max=None)
        images = (flat - mean) / std
        images = images.reshape(N, C, H, W)

        return torch.from_numpy(images).float()

    def to_sequence(self, eps=1e-8) -> torch.Tensor:
        """Convert the particle flow data to sequences (N, ΣM, 3+|C|).

        Each entry is ``[pt, eta, phi, one-hot(channel)]``. pt is standardised per
        event and per channel, using only that channel's valid (masked-in) entries.
        Invalid entries are NaN in every feature.

        Without ``include_decay``, detector hits coinciding exactly with a decay
        object are masked out and the decay channel itself is dropped.

        Args:
            eps: lower bound applied to the per-event pt standard deviation.
        """

        particle_flow = self.particle_flow.copy()

        if self.include_decay:
            channels = self.detector_channels + [self.decay_channel]
        else:
            # Returns fresh per-channel dicts; self.particle_flow is left untouched.
            particle_flow = self._exclude_decay_information(particle_flow, eps=0.0)
            channels = self.detector_channels

        seqs = []
        C = len(channels)

        for one_hot_index, channel in enumerate(channels):
            mask = particle_flow[channel]['mask'].astype(bool, copy=False)  # (N, M)
            pt   = particle_flow[channel]['pt']
            eta  = particle_flow[channel]['eta']
            phi  = particle_flow[channel]['phi']

            # --- pt normalization (per event, over valid entries only) ---
            # Masked mean/std along the particle axis. Events without any valid entry get
            # mean 0; their entries are replaced by NaN below anyway.
            n_valid = np.maximum(mask.sum(axis=-1, keepdims=True), 1).astype(pt.dtype)  # (N, 1)
            pt_mean = np.where(mask, pt, 0.0).sum(axis=-1, keepdims=True) / n_valid
            pt_var  = np.where(mask, (pt - pt_mean) ** 2, 0.0).sum(axis=-1, keepdims=True) / n_valid
            pt_std  = np.clip(np.sqrt(pt_var), a_min=eps, a_max=None)
            pt = (pt - pt_mean) / pt_std

            # --- one-hot channel encoding ---
            feat = np.stack([pt, eta, phi], axis=-1)  # (N, M, 3)
            N, M, _ = feat.shape
            one_hot = np.zeros((1, 1, C), dtype=feat.dtype)
            one_hot[..., one_hot_index] = 1.0
            one_hot = np.broadcast_to(one_hot, (N, M, C))
            feat_oh = np.concatenate([feat, one_hot], axis=-1)  # (N, M, 3+C)

            # --- masking: invalid entries become NaN in every feature ---
            feat_oh = np.where(mask[..., None], feat_oh, np.nan)

            seqs.append(feat_oh)

        seqs = np.concatenate(seqs, axis=1)  # (N, ΣM, 3+C)

        return torch.from_numpy(seqs).float()

In [None]:
class LitDataModule(lightning.LightningDataModule):
    """Lightning data module holding labelled train/valid/test datasets.

    Signal and background samples are loaded with ``MCSimData`` and converted to
    ``data_format`` (``'image'`` or ``'sequence'``). They are then split according
    to ``data_mode``.

    * ``'supervised'``: true labels (signal = 1, background = 0), with fixed-size
      train/valid/test slices after shuffling.
    * ``'jet_flavor'``: CWoLa mixed samples.
      - The label-1 training/validation sample holds the 2-quark-jet events of both
        processes; the label-0 sample holds events with at least one gluon jet.
      - Event counts come from cross-section x cut efficiency x branching ratio x
        ``luminosity``.
      - The test set uses the true signal/background labels.

    With ``num_phi_augmentation > 0`` the training tensors are extended by that many
    randomly phi-rotated copies.

    Raises:
        ValueError: unsupported ``data_format``/``data_mode`` (checked before any data
            is loaded); in ``'jet_flavor'`` mode, a missing ``luminosity`` or more
            requested events than available; no positive training samples.
    """

    def __init__(self, batch_size: int, data_mode: str, data_format: str, data_info: dict,
                 include_decay: bool, luminosity: float = None, num_phi_augmentation: int = 0,
                 **kwargs):
        super().__init__()

        self.save_hyperparameters()

        # Validate the configuration before the (expensive) HDF5 loading below.
        if data_format not in ('image', 'sequence'):
            raise ValueError(f"Unsupported data format: {data_format}. Supported formats are 'image' and 'sequence'.")
        if data_mode not in ('jet_flavor', 'supervised'):
            raise ValueError(f"Unsupported data mode: {data_mode}. Supported data modes are 'jet_flavor' and 'supervised'.")

        # Monte Carlo simulation data
        sig_data = MCSimData(project_root / data_info['signal']['path'], include_decay=include_decay)
        bkg_data = MCSimData(project_root / data_info['background']['path'], include_decay=include_decay)

        # Choose the representation of the dataset
        if data_format == 'image':
            self.sig_tensor, self.bkg_tensor = sig_data.to_image(), bkg_data.to_image()
        else:
            self.sig_tensor, self.bkg_tensor = sig_data.to_sequence(), bkg_data.to_sequence()

        # Create mixed dataset for implementing CWoLa
        if data_mode == 'jet_flavor':
            train_sig, train_bkg, valid_sig, valid_bkg, test_sig, test_bkg = self.split_by_jet_flavor(data_info=data_info, sig_flavor=sig_data.jet_flavor, bkg_flavor=bkg_data.jet_flavor)
        else:
            train_sig, train_bkg, valid_sig, valid_bkg, test_sig, test_bkg = self.split_by_supervised()

        if num_phi_augmentation > 0:
            train_sig = self.phi_augmentations(train_sig, num_phi_augmentation)
            train_bkg = self.phi_augmentations(train_bkg, num_phi_augmentation)

        # For tracking number of data samples
        self.train_sig, self.train_bkg = train_sig, train_bkg
        self.valid_sig, self.valid_bkg = valid_sig, valid_bkg
        self.test_sig, self.test_bkg = test_sig, test_bkg

        # Create torch datasets
        self.train_dataset = self._labelled_dataset(train_sig, train_bkg)
        self.valid_dataset = self._labelled_dataset(valid_sig, valid_bkg)
        self.test_dataset  = self._labelled_dataset(test_sig, test_bkg)

        # Positive-class weight for BCEWithLogitsLoss: (# label 0) / (# label 1).
        num_pos = len(train_sig)  # y == 1
        num_neg = len(train_bkg)  # y == 0
        if num_pos == 0:
            raise ValueError("No positive (label 1) training samples; cannot compute pos_weight.")
        self.pos_weight = torch.tensor([num_neg / num_pos], dtype=torch.float32)

    @staticmethod
    def _labelled_dataset(sig: torch.Tensor, bkg: torch.Tensor) -> TensorDataset:
        """Stack ``sig`` (label 1.0) and ``bkg`` (label 0.0) into one TensorDataset."""
        features = torch.cat([sig, bkg], dim=0)
        labels = torch.cat([torch.ones(len(sig)), torch.zeros(len(bkg))], dim=0)
        return TensorDataset(features, labels)

    def phi_augmentations(self, data: torch.Tensor, rotations: int) -> torch.Tensor:
        """Return ``data`` followed by ``rotations`` randomly phi-rotated copies.

        * Image format: each copy is rolled by a random number of bins along axis -2.
          ``MCSimData.to_image`` fills that axis with phi bins.
        * Sequence format: each copy's phi feature (column 2) is shifted by a random
          angle in [0, 2*pi) and re-wrapped.

        One random draw per copy, from NumPy's global RNG.
        """

        augmented_data = [data]

        for _ in range(rotations):
            new_data = data.clone()

            if self.hparams.data_format == 'image':
                shift = np.random.randint(1, new_data.shape[-2])
                new_data = torch.roll(new_data, shifts=shift, dims=-2)
            elif self.hparams.data_format == 'sequence':
                phi_shift = 2 * np.pi * np.random.rand()
                phi_column = 2  # feature layout: [pt, eta, phi, one-hot...]
                new_data[..., phi_column] = wrap_pi(new_data[..., phi_column] + phi_shift)

            augmented_data.append(new_data)

        return torch.cat(augmented_data, dim=0)

    def split_by_supervised(self):
        """Shuffle each sample and slice fixed-size train/valid/test sets (true labels).

        Slices are silently shorter when a sample has fewer events than requested.
        """

        NUM_TRAIN, NUM_VALID, NUM_TEST = 100000, 25000, 25000

        sig_tensor = self.sig_tensor[torch.randperm(len(self.sig_tensor))]
        bkg_tensor = self.bkg_tensor[torch.randperm(len(self.bkg_tensor))]

        train_sig = sig_tensor[:NUM_TRAIN]
        train_bkg = bkg_tensor[:NUM_TRAIN]
        valid_sig = sig_tensor[NUM_TRAIN: NUM_TRAIN + NUM_VALID]
        valid_bkg = bkg_tensor[NUM_TRAIN: NUM_TRAIN + NUM_VALID]
        test_sig = sig_tensor[NUM_TRAIN + NUM_VALID: NUM_TRAIN + NUM_VALID + NUM_TEST]
        test_bkg = bkg_tensor[NUM_TRAIN + NUM_VALID: NUM_TRAIN + NUM_VALID + NUM_TEST]

        return train_sig, train_bkg, valid_sig, valid_bkg, test_sig, test_bkg

    def split_by_jet_flavor(self, data_info: dict, sig_flavor: torch.Tensor, bkg_flavor: torch.Tensor):
        """Build CWoLa mixed samples from jet-flavor regions (see class docstring).

        Returns ``(train_sig, train_bkg, valid_sig, valid_bkg, test_sig, test_bkg)``.
        Here "sig"/"bkg" for train/valid mean the label-1/label-0 *regions*, and for
        test the true processes.
        """

        L = self.hparams.luminosity
        if L is None:
            raise ValueError("luminosity must be set when data_mode == 'jet_flavor'.")

        def get_event_counts(data_type: str, cut_info_key: str):
            """Expected event count = xsec * cut efficiency * branching ratio * luminosity."""
            cut_info_path = project_root / data_info[data_type]['cut_info']
            # NOTE(review): allow_pickle=True executes pickled content; only load trusted files.
            cut_info_npy = np.load(cut_info_path, allow_pickle=True)
            cut_info = cut_info_npy.item()['cutflow_number']
            N = int(data_info[data_type]['cross_section'] * cut_info[cut_info_key] / cut_info['Total'] * data_info['branching_ratio'] * L)
            print(f"[CWoLa-Log] [{data_type}] {cut_info_key}: {N} events")
            return N

        def sample_indices(pool: torch.Tensor, num: int, label: str) -> np.ndarray:
            """Draw ``num`` distinct row indices of ``pool``; raise a clear error if too few rows."""
            if num > len(pool):
                raise ValueError(
                    f"[CWoLa] {label}: requested {num} events but only {len(pool)} are available "
                    f"(luminosity={L})."
                )
            return np.random.choice(len(pool), num, replace=False)

        NUM_TEST = 10000
        num_sig_in_sig = get_event_counts('signal', 'two quark jet: sig region')
        num_sig_in_bkg = get_event_counts('signal', 'two quark jet: bkg region')
        num_bkg_in_sig = get_event_counts('background', 'two quark jet: sig region')
        num_bkg_in_bkg = get_event_counts('background', 'two quark jet: bkg region')
        # Test events of each process are split between regions proportionally to the expected counts.
        num_test_sig_in_sig = int(NUM_TEST * num_sig_in_sig / (num_sig_in_sig + num_sig_in_bkg))
        num_test_sig_in_bkg = NUM_TEST - num_test_sig_in_sig
        num_test_bkg_in_sig = int(NUM_TEST * num_bkg_in_sig / (num_bkg_in_sig + num_bkg_in_bkg))
        num_test_bkg_in_bkg = NUM_TEST - num_test_bkg_in_sig

        sig_in_sig = self.sig_tensor[sig_flavor['2q0g']]
        sig_in_bkg = self.sig_tensor[sig_flavor['1q1g'] | sig_flavor['0q2g']]
        bkg_in_sig = self.bkg_tensor[bkg_flavor['2q0g']]
        bkg_in_bkg = self.bkg_tensor[bkg_flavor['1q1g'] | bkg_flavor['0q2g']]

        idx_sig_in_sig = sample_indices(sig_in_sig, num_sig_in_sig + num_test_sig_in_sig, 'signal in sig region')
        idx_sig_in_bkg = sample_indices(sig_in_bkg, num_sig_in_bkg + num_test_sig_in_bkg, 'signal in bkg region')
        idx_bkg_in_sig = sample_indices(bkg_in_sig, num_bkg_in_sig + num_test_bkg_in_sig, 'background in sig region')
        idx_bkg_in_bkg = sample_indices(bkg_in_bkg, num_bkg_in_bkg + num_test_bkg_in_bkg, 'background in bkg region')

        sig_in_sig = sig_in_sig[idx_sig_in_sig]
        sig_in_bkg = sig_in_bkg[idx_sig_in_bkg]
        bkg_in_sig = bkg_in_sig[idx_bkg_in_sig]
        bkg_in_bkg = bkg_in_bkg[idx_bkg_in_bkg]

        def split_data(data: torch.Tensor, num_samples: int, num_test: int):
            """First ``num_samples`` rows -> 80/20 train/valid; next ``num_test`` rows -> test."""
            TRAIN_SIZE_RATIO = 0.8
            train_size = int(num_samples * TRAIN_SIZE_RATIO)
            return (
                data[:train_size],
                data[train_size:num_samples],
                data[num_samples:num_samples + num_test]
            )

        train_sig_in_sig, valid_sig_in_sig, test_sig_in_sig = split_data(sig_in_sig, num_sig_in_sig, num_test_sig_in_sig)
        train_sig_in_bkg, valid_sig_in_bkg, test_sig_in_bkg = split_data(sig_in_bkg, num_sig_in_bkg, num_test_sig_in_bkg)
        train_bkg_in_sig, valid_bkg_in_sig, test_bkg_in_sig = split_data(bkg_in_sig, num_bkg_in_sig, num_test_bkg_in_sig)
        train_bkg_in_bkg, valid_bkg_in_bkg, test_bkg_in_bkg = split_data(bkg_in_bkg, num_bkg_in_bkg, num_test_bkg_in_bkg)

        # Train/valid: labelled by region (CWoLa mixed samples).
        train_sig = torch.cat([train_sig_in_sig, train_bkg_in_sig], dim=0)
        train_bkg = torch.cat([train_sig_in_bkg, train_bkg_in_bkg], dim=0)
        valid_sig = torch.cat([valid_sig_in_sig, valid_bkg_in_sig], dim=0)
        valid_bkg = torch.cat([valid_sig_in_bkg, valid_bkg_in_bkg], dim=0)
        # Test: labelled by true process.
        test_sig  = torch.cat([test_sig_in_sig, test_sig_in_bkg], dim=0)
        test_bkg  = torch.cat([test_bkg_in_sig, test_bkg_in_bkg], dim=0)

        return train_sig, train_bkg, valid_sig, valid_bkg, test_sig, test_bkg

    def train_dataloader(self):
        """Shuffled training loader."""
        return DataLoader(self.train_dataset, batch_size=self.hparams.batch_size, shuffle=True)

    def val_dataloader(self):
        """Validation loader (no shuffling)."""
        return DataLoader(self.valid_dataset, batch_size=self.hparams.batch_size, shuffle=False)

    def test_dataloader(self):
        """Test loader (no shuffling)."""
        return DataLoader(self.test_dataset, batch_size=self.hparams.batch_size, shuffle=False)

In [None]:
class BinaryLitModel(lightning.LightningModule):
    """Lightning wrapper for a binary classifier that outputs one logit per sample.

    Trains with ``nn.BCEWithLogitsLoss`` (optionally class-weighted via
    ``pos_weight``). It tracks AUROC and accuracy separately for the train,
    validation and test splits, logged as ``<split>_auc`` / ``<split>_accuracy`` and
    ``<split>_loss``.

    Args:
        model: network mapping a batch ``x`` to logits; a trailing singleton dim is
            squeezed away.
        lr: learning rate passed to the optimizer.
        pos_weight: forwarded to ``nn.BCEWithLogitsLoss``.
        optimizer_settings: dict with these keys.
            * ``'optimizer'``: a ``torch.optim`` class name.
            * ``'lr_scheduler'``: a ``torch.optim.lr_scheduler`` class name, or None.
            * When a scheduler is used: a kwargs dict keyed by the scheduler class
              name, plus ``'lightning_monitor'`` (extra keys merged into Lightning's
              lr-scheduler config).
    """

    def __init__(self, model: nn.Module, lr: float, pos_weight: torch.Tensor = None, optimizer_settings: dict = None):
        super().__init__()
        # `model` is an nn.Module whose weights are already checkpointed through the
        # state_dict; Lightning recommends keeping such modules out of the hyperparameters.
        self.save_hyperparameters(ignore=['model'])

        self.model = model
        self.lr = lr
        self.optimizer_settings = optimizer_settings
        self.loss_fn = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

        base_metrics = MetricCollection({"auc": BinaryAUROC(), "accuracy": BinaryAccuracy()})
        # nn.ModuleDict registers the collections as sub-modules, so Lightning moves them
        # to the model's device for fit, validate and test alike. (A plain dict was only
        # moved in on_fit_start.) Keys carry a suffix because nn.ModuleDict rejects a key
        # named "train", which clashes with nn.Module.train.
        self.metrics = nn.ModuleDict({
            f"{split}_metrics": base_metrics.clone(prefix=f"{split}_")
            for split in ("train", "valid", "test")
        })

    def _split_metrics(self, split: str) -> MetricCollection:
        """Metric collection for ``split`` (``'train'``, ``'valid'`` or ``'test'``)."""
        return self.metrics[f"{split}_metrics"]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x)

    def configure_optimizers(self):
        """Build the optimizer, and optionally an LR scheduler, from ``optimizer_settings``."""
        optimizer_settings = self.optimizer_settings
        if optimizer_settings is None:
            raise ValueError("optimizer_settings is required to configure the optimizer.")
        optimizer_cls = getattr(torch.optim, optimizer_settings['optimizer'])
        optimizer = optimizer_cls(self.parameters(), lr=self.lr)

        scheduler_name = optimizer_settings.get('lr_scheduler')
        if scheduler_name is None:
            return optimizer

        scheduler_cls = getattr(torch.optim.lr_scheduler, scheduler_name)
        scheduler = scheduler_cls(optimizer, **optimizer_settings[scheduler_cls.__name__])
        lr_scheduler: dict = {'scheduler': scheduler}
        lr_scheduler.update(optimizer_settings['lightning_monitor'])
        return {'optimizer': optimizer, 'lr_scheduler': lr_scheduler}

    def on_fit_start(self):
        # The metric collections are registered sub-modules, so Lightning already moves
        # them with the model; this explicit move is only a harmless safeguard.
        self.metrics.to(self.device)

    def _shared_step(self, batch: Tuple[torch.Tensor, torch.Tensor], split: str):
        """Compute the loss, update the split's metrics and log ``<split>_loss`` per epoch."""
        x, y_true = batch
        logits    = self(x).squeeze(-1)
        loss      = self.loss_fn(logits, y_true.float())
        y_pred    = torch.sigmoid(logits)
        self._split_metrics(split).update(y_pred, y_true.int())
        self.log(f"{split}_loss", loss, on_epoch=True, on_step=False, prog_bar=(split == "train"), batch_size=y_true.size(0))
        return loss

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        return self._shared_step(batch, "valid")

    def test_step(self, batch, batch_idx):
        return self._shared_step(batch, "test")

    def _compute_and_log_split(self, split: str, prog_bar: bool = False):
        """Log the epoch-level metrics of ``split`` and reset them for the next epoch."""
        metrics = self._split_metrics(split)
        computed = metrics.compute()
        self.log_dict(computed, on_epoch=True, on_step=False, prog_bar=prog_bar)
        metrics.reset()

    def on_train_epoch_end(self):
        self._compute_and_log_split("train", prog_bar=True)

    def on_validation_epoch_end(self):
        self._compute_and_log_split("valid", prog_bar=True)

    def on_test_epoch_end(self):
        self._compute_and_log_split("test", prog_bar=False)

In [None]:
def keras_like_init(module: nn.Module):
    """Re-initialise ``module`` in place with Keras' default initialisers.

    * Conv / ConvTranspose (1-3D) and Linear: Glorot-uniform (Xavier-uniform) weights
      and zero bias.
    * BatchNorm (1-3D): gamma = 1, beta = 0. PyTorch's running stats already start
      at mean 0, var 1.

    Missing weight/bias tensors are skipped, and any other module type is left
    untouched. Meant to be used as ``net.apply(keras_like_init)``.
    """
    weighted_layers = (
        nn.Conv1d, nn.Conv2d, nn.Conv3d,
        nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d,
        nn.Linear,
    )
    batch_norms = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)

    if isinstance(module, weighted_layers):
        weight_init, bias_init = nn.init.xavier_uniform_, nn.init.zeros_  # Keras: glorot_uniform / zeros
    elif isinstance(module, batch_norms):
        weight_init, bias_init = nn.init.ones_, nn.init.zeros_            # gamma / beta
    else:
        return

    if module.weight is not None:
        weight_init(module.weight)
    if module.bias is not None:
        bias_init(module.bias)

In [None]:
DATETIME = time.strftime("%Y%m%d_%H%M%S", time.localtime())

def training(
        data_mode: str, data_format: str, data_info: dict, include_decay: bool,
        model_cls: type[nn.Module], lr: float, tags: List[str], rnd_seed: int, **kwargs
    ):
    """Train and test one model configuration, then write summary artefacts.

    Steps:
    1. Seed all RNGs with ``rnd_seed``.
    2. Build ``model_cls(num_channels=...)`` with Keras-like initialisation: 3 input
       channels with the decay channel, 2 without.
    3. Build the ``LitDataModule``.
    4. Train with checkpointing and early stopping, using settings from
       ``config/training.yml``, and test the best checkpoint.
    5. Write data counts, parameter counts and metric plots into
       ``output/<[ex-]decay_channel>/<tags>/<run name>/rnd_seed-<seed>``.

    Args:
        data_mode: ``'jet_flavor'`` (CWoLa; requires ``kwargs['luminosity']``) or
            ``'supervised'``.
        data_format: ``'image'`` or ``'sequence'`` (validated by ``LitDataModule``).
        data_info: data configuration dict (sample paths, cut info, ``decay_channel``, ...).
        include_decay: keep the decay objects in the model inputs.
        model_cls: network class, called as ``model_cls(num_channels=...)``.
        lr: learning rate.
        tags: joined with ``'_'`` to form an output sub-directory.
        rnd_seed: global random seed, also used as the logger version.
        **kwargs: forwarded to ``LitDataModule`` (e.g. ``luminosity``,
            ``num_phi_augmentation``).

    Raises:
        ValueError: unsupported ``data_mode``.
        KeyError: ``data_mode == 'jet_flavor'`` without a ``luminosity`` keyword.
    """

    # Validate the run configuration before seeding, building the model or loading data;
    # an unknown mode previously surfaced as an UnboundLocalError on `name`.
    if data_mode == 'jet_flavor':
        run_tag = f"L{kwargs['luminosity']}"
    elif data_mode == 'supervised':
        run_tag = "SV"
    else:
        raise ValueError(f"Unsupported data mode: {data_mode}. Supported data modes are 'jet_flavor' and 'supervised'.")

    lightning.seed_everything(rnd_seed)

    num_channels = 3 if include_decay else 2
    model = model_cls(num_channels=num_channels)
    model.apply(keras_like_init)

    # Output and log directories
    save_dir = project_root / 'output' / (('' if include_decay else 'ex-') + data_info['decay_channel']) / ('_'.join(tags))
    name = f"{model.__class__.__name__}-{DATETIME}-{run_tag}"
    version = f"rnd_seed-{rnd_seed}"
    output_dir = save_dir / name / version

    # Lightning data setup
    BATCH_SIZE = 512
    lit_data_module = LitDataModule(
        batch_size=BATCH_SIZE,
        data_mode=data_mode,
        data_format=data_format,
        data_info=data_info,
        include_decay=include_decay,
        **kwargs
    )

    # Lightning model setup
    with open(project_root / 'config' / 'training.yml', 'r') as f:
        training_config = yaml.safe_load(f)
    lit_model = BinaryLitModel(
        model=model,
        lr=lr,
        pos_weight=lit_data_module.pos_weight,
        optimizer_settings=training_config['optimizer_settings']
    )

    # Lightning loggers
    logger = CSVLogger(save_dir=save_dir, name=name, version=version)
    hparams = {}
    hparams.update(lit_data_module.hparams)
    hparams.update(training_config)
    logger.log_hyperparams(hparams)

    # Lightning trainer
    trainer = lightning.Trainer(
        max_epochs=200,
        accelerator='gpu' if torch.cuda.is_available() else 'cpu',
        logger=logger,
        callbacks=[
            ModelCheckpoint(**training_config['ModelCheckpoint']),
            EarlyStopping(**training_config['EarlyStopping']),
        ],
    )

    # Lightning training and testing (test uses the best checkpoint)
    trainer.fit(lit_model, lit_data_module)
    trainer.test(lit_model, datamodule=lit_data_module, ckpt_path='best')
    os.makedirs(output_dir, exist_ok=True)
    utils.count_number_of_data(lit_data_module, output_dir)
    utils.count_model_parameters(lit_model, output_dir)
    utils.plot_metrics(output_dir)

In [None]:
if __name__ == '__main__':

    # Experiment grid: each (seed, mode, decay option) is run for every model
    # configuration and every luminosity, in this nesting order.
    rnd_seeds = [23 + 100 * i for i in range(10)]
    data_modes = ['jet_flavor']
    decay_options = [True, False]
    model_configs = [
        # (data_format, model class, learning rate)
        ('image', CNN_EventCNN, 1e-4),
        # ('sequence', ParT_Light, 4e-4),
    ]
    luminosities = [100, 300, 900, 1800, 3000]

    for rnd_seed, data_mode, include_decay in product(rnd_seeds, data_modes, decay_options):

        # The data configuration is re-read for every (seed, mode, decay) combination.
        with open(project_root / 'config' / 'data_diphoton.yml', 'r') as f:
            data_info = yaml.safe_load(f)

        for data_format, model_cls, lr in model_configs:
            for luminosity in luminosities:
                training(
                    data_mode=data_mode,
                    data_format=data_format,
                    data_info=data_info,
                    include_decay=include_decay,
                    model_cls=model_cls,
                    lr=lr,
                    tags=[data_mode],
                    rnd_seed=rnd_seed,
                    luminosity=luminosity,
                )