In [None]:
import os
import re
import sys
import warnings
from pathlib import Path
from typing import Iterable, Optional

import lightning
import pandas as pd
import torch
import torch.nn as nn
import yaml
from torch.utils.data import DataLoader, TensorDataset
from torchmetrics.classification import BinaryAccuracy, BinaryAUROC

# Resolve the repository root so the local `src` package can be imported.
# As a script, this file sits two directories below the root. A Jupyter kernel
# has no `__file__`, so fall back to the parent of the working directory
# (assumes the notebook is started from a sub-directory of the repo — TODO confirm).
try:
    project_root = Path(__file__).parent.parent
except NameError:
    project_root = Path.cwd().parent

# Only add the root once, so re-running this cell in the same kernel does not
# keep growing sys.path with duplicates.
if project_root.as_posix() not in sys.path:
    sys.path.append(project_root.as_posix())

from src.data_preprocess import MCSimData
from src.data_cwola import split_by_sv
from src.model_cnn import CNN_Baseline, CNN_EventCNN
from src.model_part import ParT_Baseline, ParT_Light

In [None]:
class LitDataModule(lightning.LightningDataModule):
    """Lightning data module that serves only a held-out signal/background test set.

    Signal and background Monte Carlo samples are loaded with ``MCSimData`` and
    converted to the requested representation. They are then split with
    ``split_by_sv`` with no training or validation events, keeping only the test
    portion. Signal events are labelled 1 and background events 0.

    Args:
        data_format: ``'image'`` (uses ``MCSimData.to_image``) or ``'sequence'``
            (uses ``MCSimData.to_sequence``).
        data_info: Parsed data config. Must contain ``'signal'`` and
            ``'background'`` entries, each with a ``'path'`` key.
        include_decay: Forwarded to ``MCSimData``.
        num_test: Forwarded to ``split_by_sv`` as ``num_test``. The meaning of
            ``None`` is defined by ``split_by_sv``.
        batch_size: Batch size of the test dataloader (default 64).

    Raises:
        ValueError: If ``data_format`` is neither ``'image'`` nor ``'sequence'``.
    """

    def __init__(self, data_format: str, data_info: dict, include_decay: bool,
                 num_test: Optional[int] = None, batch_size: int = 64):
        super().__init__()

        # Validate before loading the (potentially expensive) simulation data.
        # An unknown format would otherwise only fail later, as an UnboundLocalError.
        if data_format not in ('image', 'sequence'):
            raise ValueError(
                f"data_format must be 'image' or 'sequence', got {data_format!r}"
            )

        self.data_format = data_format
        self.data_info = data_info
        self.include_decay = include_decay
        self.num_test = num_test
        self.batch_size = batch_size

        # Information of signal and background datasets
        sig_info = data_info['signal']
        bkg_info = data_info['background']

        # Monte Carlo simulation data
        sig_data = MCSimData(sig_info['path'], include_decay=include_decay)
        bkg_data = MCSimData(bkg_info['path'], include_decay=include_decay)

        # Choose the representation of the dataset
        if data_format == 'image':
            sig_tensor = sig_data.to_image()
            bkg_tensor = bkg_data.to_image()
        else:  # 'sequence'
            sig_tensor = sig_data.to_sequence()
            bkg_tensor = bkg_data.to_sequence()

        # Only keep testing data
        _, _, _, _, test_sig, test_bkg = split_by_sv(
            sig_tensor=sig_tensor, bkg_tensor=bkg_tensor,
            num_train=0, num_valid=0, num_test=num_test,
        )

        # Signal first (label 1), then background (label 0).
        self.test_dataset = TensorDataset(
            torch.cat([test_sig, test_bkg], dim=0),
            torch.cat([torch.ones(len(test_sig)), torch.zeros(len(test_bkg))], dim=0),
        )

    def test_dataloader(self):
        """Return an unshuffled DataLoader over the test set."""
        return DataLoader(self.test_dataset, batch_size=self.batch_size, shuffle=False)

In [None]:
class BinaryLitModel(lightning.LightningModule):
    """Lightning wrapper that evaluates a binary classifier on a test set.

    The wrapped network's output is flattened to one logit per example and passed
    through a sigmoid. Accuracy and ROC AUC are accumulated over the test epoch and
    logged as ``test_accuracy`` and ``test_auc``.

    The train/valid metric objects are not used by the test loop.
    NOTE(review): they presumably mirror the training-time module; confirm before removing.

    Args:
        model: Network mapping a batch of inputs to binary-classification logits.
    """

    def __init__(self, model: nn.Module):
        super().__init__()

        self.model = model

        self.train_accuracy = BinaryAccuracy()
        self.valid_accuracy = BinaryAccuracy()
        self.test_accuracy = BinaryAccuracy()

        self.train_auc = BinaryAUROC()
        self.valid_auc = BinaryAUROC()
        self.test_auc = BinaryAUROC()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the raw logits of the wrapped network."""
        return self.model(x)

    def test_step(self, batch, batch_idx):
        """Accumulate test metrics for one ``(inputs, labels)`` batch."""
        x, y_true = batch
        logits: torch.Tensor = self(x)
        # reshape (not view) also works when the network returns a non-contiguous tensor.
        y_pred = torch.sigmoid(logits.reshape(-1))
        # The data module builds labels with torch.ones/zeros (float), but torchmetrics'
        # binary AUROC validation requires integer targets.
        y_true = y_true.reshape(-1).long()

        self.test_auc.update(y_pred, y_true)
        self.test_accuracy.update(y_pred, y_true)

    def on_test_epoch_end(self):
        """Log the epoch-level test metrics, then clear their accumulated state."""
        self.log('test_auc', self.test_auc.compute(), prog_bar=True, on_epoch=True, on_step=False)
        self.log('test_accuracy', self.test_accuracy.compute(), prog_bar=True, on_epoch=True, on_step=False)
        # Lightning does not auto-reset metrics that are computed manually and logged as
        # tensors, so reset them here; a second trainer.test() then starts from scratch.
        self.test_auc.reset()
        self.test_accuracy.reset()

In [None]:
def get_best_ckpt(ckpt_dir: str, metric: str = 'valid_auc', mode: str = 'max') -> str:
    """Find the best checkpoint file based on the specified metric.

    Only files whose names start with ``epoch=`` and end with ``.ckpt`` are
    considered. The metric value is parsed from the ``<metric>=<number>`` field
    of the ``-``-separated file name, e.g.
    ``epoch=7-valid_auc=0.9312-valid_loss=0.2214.ckpt``. Files that do not
    report ``metric`` are skipped.

    Args:
        ckpt_dir: Directory containing the checkpoint files (str or path-like).
        metric: Name of the metric embedded in the file names.
        mode: ``'max'`` selects the highest value, ``'min'`` the lowest.

    Returns:
        The file name (not the full path) of the best checkpoint.

    Raises:
        ValueError: If ``mode`` is not ``'max'`` or ``'min'``.
        FileNotFoundError: If no checkpoint in ``ckpt_dir`` reports ``metric``.
    """
    if mode not in ('max', 'min'):
        raise ValueError(f"mode must be 'max' or 'min', got {mode!r}")

    # Match `<metric>=<number>` as a whole name field. The number may be signed or in
    # exponent form, which a plain split on '-' would break. Using float() instead of
    # eval() avoids executing text taken from file names.
    field = re.compile(
        rf'(?:^|-){re.escape(metric)}=([-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?)(?=-|$)'
    )
    sign = 1.0 if mode == 'max' else -1.0

    best_ckpt = None
    best_value = None

    # Sorted so that ties resolve deterministically (os.listdir order is arbitrary).
    for ckpt_file in sorted(os.listdir(ckpt_dir)):
        if not (ckpt_file.startswith('epoch=') and ckpt_file.endswith('.ckpt')):
            continue
        match = field.search(ckpt_file[:-len('.ckpt')])
        if match is None:
            continue
        metric_value = float(match.group(1))
        if best_value is None or sign * metric_value > sign * best_value:
            best_value = metric_value
            best_ckpt = ckpt_file

    if best_ckpt is None:
        raise FileNotFoundError(
            f"No 'epoch=*.ckpt' file reporting {metric!r} found in {ckpt_dir}"
        )

    print(f"The best checkpoint is: {best_ckpt} with {metric} = {best_value}")

    return best_ckpt

In [None]:
def _run_tag(training_channel: str, include_decay: bool) -> str:
    """Return the run tag used in paths: the channel, prefixed with 'ex-' when include_decay is false."""
    return training_channel if include_decay else f"ex-{training_channel}"


def _load_model_weights(model: nn.Module, ckpt_path: Path) -> None:
    """Load the ``model.``-prefixed weights of a Lightning checkpoint into ``model`` in place.

    Raises:
        RuntimeError: If the checkpoint does not provide every entry of ``model.state_dict()``.
    """
    # map_location='cpu' lets checkpoints saved on a GPU be opened on CPU-only hosts.
    # load_state_dict then copies the tensors onto whatever device the model is on.
    ckpt = torch.load(ckpt_path, map_location='cpu')

    # Lightning stores the wrapped network under 'model.<name>'. Strip only that
    # leading prefix: str.replace would also rewrite nested names containing 'model.'.
    prefix = 'model.'
    state_dict = {
        key[len(prefix):]: value
        for key, value in ckpt['state_dict'].items()
        if key.startswith(prefix)
    }

    # Extra entries are tolerated. Missing ones would silently keep stale weights,
    # because the same `model` object is reused across seeds, so fail loudly.
    incompatible = model.load_state_dict(state_dict, strict=False)
    if incompatible.missing_keys:
        raise RuntimeError(
            f"Checkpoint {ckpt_path} is missing weights for {type(model).__name__}: "
            f"{incompatible.missing_keys}"
        )
    if incompatible.unexpected_keys:
        warnings.warn(f"Ignoring unexpected keys in {ckpt_path}: {incompatible.unexpected_keys}")


def inference(training_channel, training_mode, training_date, inference_channel, include_decay,
              information=None, num_test: int = 10000, rnd_seeds: Iterable[int] = range(1, 6)):
    """Run inference on the specified model and dataset.

    Four architectures are evaluated: CNN_Baseline and CNN_EventCNN on images,
    ParT_Baseline and ParT_Light on sequences. For each one and each seed, the
    checkpoint with the highest ``valid_auc`` is loaded from
    ``output/<run>/<training_mode>/<Model>/<training_date>-rnd_seed<seed>/checkpoints``
    and tested on ``inference_channel``. Here ``<run>`` is ``training_channel``,
    prefixed with ``ex-`` when ``include_decay`` is false.

    Results are written to
    ``output/inference/<run>_to_<inference_channel>-<training_mode>-<training_date>[-<information>].csv``.

    Args:
        training_channel: Channel the models were trained on (part of the checkpoint path).
        training_mode: Training-mode directory in the checkpoint path.
        training_date: Timestamp prefix of the per-seed run directories.
        inference_channel: Channel to evaluate on; selects ``config/data_<channel>.yml``.
        include_decay: Forwarded to the data module. It also adds one model input
            channel and removes the ``ex-`` prefix from the run tag.
        information: Optional tag appended (after ``-``) to the output file name.
        num_test: Forwarded to ``LitDataModule`` as ``num_test``.
        rnd_seeds: Seeds whose training runs are evaluated (also used to seed Lightning).

    Returns:
        DataFrame with columns ``model``, ``rnd_seed``, ``test_auc`` and ``test_accuracy``.
    """
    num_channels = 2 + int(include_decay)
    run_tag = _run_tag(training_channel, include_decay)

    with open(project_root / 'config' / f'data_{inference_channel}.yml', 'r') as f:
        data_info = yaml.safe_load(f)

    # Collect plain records and build the frame once, instead of growing it with concat.
    records = []

    for data_format, model in [
        ('image', CNN_Baseline(num_channels=num_channels)),
        ('image', CNN_EventCNN(num_channels=num_channels)),
        ('sequence', ParT_Baseline(num_channels=num_channels)),
        ('sequence', ParT_Light(num_channels=num_channels)),
    ]:
        model_name = type(model).__name__
        lit_data_module = LitDataModule(
            data_format=data_format,
            data_info=data_info,
            include_decay=include_decay,
            num_test=num_test,
        )

        for rnd_seed in rnd_seeds:
            # Set random seed for reproducibility
            lightning.seed_everything(rnd_seed)

            # Load the best checkpoint of this seed's training run
            ckpt_dir = (
                project_root / 'output' / run_tag / training_mode / model_name
                / f'{training_date}-rnd_seed{rnd_seed}' / 'checkpoints'
            )
            ckpt_name = get_best_ckpt(ckpt_dir, metric='valid_auc')
            _load_model_weights(model, ckpt_dir / ckpt_name)

            lit_model = BinaryLitModel(model=model)
            lit_model.eval()
            trainer = lightning.Trainer(logger=False)
            # trainer.test returns one metrics dict per test dataloader; there is only one.
            metrics = trainer.test(lit_model, datamodule=lit_data_module)[0]

            records.append({
                'model': model_name,
                'rnd_seed': rnd_seed,
                'test_auc': metrics['test_auc'],
                'test_accuracy': metrics['test_accuracy'],
            })

    results = pd.DataFrame(records, columns=['model', 'rnd_seed', 'test_auc', 'test_accuracy'])

    output_dir = project_root / 'output' / 'inference'
    output_dir.mkdir(parents=True, exist_ok=True)
    # `information` is optional; the previous string concatenation crashed on None.
    suffix = f"-{information}" if information is not None else ""
    results.to_csv(
        output_dir / f"{run_tag}_to_{inference_channel}-{training_mode}-{training_date}{suffix}.csv",
        index=False,
    )

    return results

In [None]:
# Evaluate the diphoton-trained models on the 'zz4l' channel for every training mode
# and every training run listed below. The call order matches the original
# copy-pasted calls: runs are iterated inside each training mode.
TRAINING_MODES = ['jet_flavor', 'jet_flavor_uni5', 'jet_flavor_uni10', 'jet_flavor_uni15']

# (training_date, include_decay, information) for each training run.
# `information` is appended to the output CSV name.
TRAINING_RUNS = [
    ('20250723_173318', True, 'L=3000'),
    ('20250729_154839', True, 'L=300'),
    ('20250721_121840', False, 'L=3000'),
    ('20250731_015137', False, 'L=300'),
]

for training_mode in TRAINING_MODES:
    for training_date, include_decay, information in TRAINING_RUNS:
        inference(
            training_channel='diphoton',
            training_mode=training_mode,
            training_date=training_date,
            inference_channel='zz4l',
            include_decay=include_decay,
            information=information,
        )