In [None]:
import os
import warnings

import lightning
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import CSVLogger
from lightning.pytorch.utilities.model_summary import ModelSummary
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torchmetrics.classification import BinaryAccuracy, BinaryAUROC

from source.preprocessing import hdf5_to_seq, hdf5_jet_flavor
from source.part import ParticleTransformer

rnd_seed = 42
sns.set_theme()

# Seed the global RNGs through Lightning so the torch.randperm-based data split
# and the model weights created further down are reproducible across runs.
lightning.seed_everything(seed=rnd_seed)

In [None]:
class TorchDataset(Dataset):
    """In-memory binary-classification dataset made of two stacked tensors.

    The rows of ``sig`` come first and are labelled 1.0; the rows of ``bkg``
    follow and are labelled 0.0. Labels use torch's default floating dtype.

    Args:
        sig: Signal samples, stacked along dim 0.
        bkg: Background samples, stacked along dim 0. Must match ``sig`` in
            every dimension except the first.
    """

    def __init__(self, sig: torch.Tensor, bkg: torch.Tensor):
        self.x = torch.cat([sig, bkg], dim=0)
        num_signal = len(sig)
        labels = torch.zeros(len(self.x))
        labels[:num_signal] = 1.0
        self.y = labels

    def __len__(self):
        return self.x.shape[0]

    def __getitem__(self, idx):
        features, label = self.x[idx], self.y[idx]
        return features, label


class LitDataModule(lightning.LightningDataModule):
    """CWoLa-style data module built from the GGF and VBF HDF5 samples.

    Only events in one of the flavor categories ``2q0g``, ``1q1g`` or ``0q2g``
    (presumably the quark/gluon content of the two jets) are used.

    * Train / valid use the mixed CWoLa labels: ``2q0g`` events are labelled 1,
      and ``1q1g`` / ``0q2g`` events are labelled 0. GGF and VBF are pooled.
    * Test uses the true process labels: VBF events are labelled 1, GGF events 0.

    The number of train+valid events per process is the expected yield after
    selection at integrated luminosity ``L``
    (``cross_section * BR_Haa * L * selection_efficiency``). That yield is scaled
    by the fraction of the process's events that fall in the three flavor
    categories. These events are split ``train_fraction : 1 - train_fraction``
    into train and valid. The next ``num_test`` shuffled events per process form
    the test set. A warning is issued when a file has fewer events than requested.

    Args:
        batch_size: Batch size of all three dataloaders.
        L: Integrated luminosity (given in fb^-1 in the notebook's setup cell).
        num_test: Requested number of test events per process.
        ggf_path: HDF5 file with the GGF sample.
        vbf_path: HDF5 file with the VBF sample.
        train_fraction: Fraction of the luminosity-scaled events used for training.

    Raises:
        ValueError: If the training set contains no positive (``2q0g``) events,
            so ``pos_weight`` would be infinite.
    """

    def __init__(self, batch_size=64, L=3000, num_test=10000,
                 ggf_path='GGF.h5', vbf_path='VBF.h5', train_fraction=0.8):
        super().__init__()
        self.batch_size = batch_size
        print(f"{'='* 20} Data Size Information {'='* 20}")

        # Sequence data
        GGF = hdf5_to_seq(ggf_path)
        VBF = hdf5_to_seq(vbf_path)
        print(f'{ggf_path} shape: {GGF.shape}, {vbf_path} shape: {VBF.shape}')

        # Jet flavor
        GGF_info = hdf5_jet_flavor(ggf_path)
        VBF_info = hdf5_jet_flavor(vbf_path)

        # Expected number of events at luminosity L
        BR_Haa = 0.00227
        cross_section_GGF = 54.67 * 1000
        cross_section_VBF = 4.278 * 1000

        # Selection efficiencies 0.09 (GGF) and 0.41 (VBF), from 謝豐仰.
        # NOTE(review): the previous comment said 0.9 for GGF while the code used
        # 0.09; the code value is kept. Confirm which one is intended.
        GGF_after_selection = cross_section_GGF * BR_Haa * L * 0.09
        VBF_after_selection = cross_section_VBF * BR_Haa * L * 0.41

        # CWoLa: (2q0g) v.s. (1q1g + 0q2g)
        num_GGF_real = self._num_flavor_events(GGF_info, GGF_after_selection)
        num_VBF_real = self._num_flavor_events(VBF_info, VBF_after_selection)
        print(f'GGF after selection: {num_GGF_real}, VBF after selection: {num_VBF_real}')

        # Shuffle the candidate events. GGF is permuted before VBF, so a seeded
        # run draws the random numbers in the same order and reproduces the split.
        GGF_index = self._shuffled_flavor_index(GGF_info)
        VBF_index = self._shuffled_flavor_index(VBF_info)

        GGF_train_index, GGF_valid_index, GGF_test_index = self._split_index(
            GGF_index, num_GGF_real, num_test, train_fraction, 'GGF')
        VBF_train_index, VBF_valid_index, VBF_test_index = self._split_index(
            VBF_index, num_VBF_real, num_test, train_fraction, 'VBF')

        # Mixed samples: 2q0g -> signal, 1q1g / 0q2g -> background.
        GGF_train_sig, GGF_train_bkg = self._cwola_split(GGF, GGF_info, GGF_train_index)
        VBF_train_sig, VBF_train_bkg = self._cwola_split(VBF, VBF_info, VBF_train_index)
        GGF_valid_sig, GGF_valid_bkg = self._cwola_split(GGF, GGF_info, GGF_valid_index)
        VBF_valid_sig, VBF_valid_bkg = self._cwola_split(VBF, VBF_info, VBF_valid_index)

        train_sig = torch.cat((GGF_train_sig, VBF_train_sig), dim=0)
        train_bkg = torch.cat((GGF_train_bkg, VBF_train_bkg), dim=0)
        valid_sig = torch.cat((GGF_valid_sig, VBF_valid_sig), dim=0)
        valid_bkg = torch.cat((GGF_valid_bkg, VBF_valid_bkg), dim=0)

        # The test set uses the true process labels: VBF = signal, GGF = background.
        test_sig = VBF[VBF_test_index]
        test_bkg = GGF[GGF_test_index]

        print(f'Train signal shape: {train_sig.shape}, Train background shape: {train_bkg.shape}')
        print(f'Valid signal shape: {valid_sig.shape}, Valid background shape: {valid_bkg.shape}')
        print(f'Test signal shape: {test_sig.shape}, Test background shape: {test_bkg.shape}')
        print(f"{'='* 50}")

        # Create datasets
        self.train_dataset = TorchDataset(train_sig, train_bkg)
        self.valid_dataset = TorchDataset(valid_sig, valid_bkg)
        self.test_dataset  = TorchDataset(test_sig, test_bkg)

        # Positive-class weight (#neg / #pos) for BCEWithLogitsLoss.
        num_pos = torch.sum(self.train_dataset.y == 1)
        num_neg = torch.sum(self.train_dataset.y == 0)
        if num_pos == 0:
            raise ValueError(
                'Training set contains no positive (2q0g) events; cannot compute pos_weight.'
            )
        self.pos_weight = torch.tensor([num_neg / num_pos], dtype=torch.float32)

    @staticmethod
    def _num_flavor_events(info, num_after_selection):
        """Scale the selected yield by the fraction of events in the 2q0g/1q1g/0q2g categories."""
        num_flavor = info['2q0g'].sum() + info['1q1g'].sum() + info['0q2g'].sum()
        return int(num_flavor / info['total'] * num_after_selection)

    @staticmethod
    def _shuffled_flavor_index(info):
        """Return the randomly permuted indices of events in any of the three flavor categories."""
        # flatten() (not squeeze()) keeps a 1-D index even when exactly one event matches.
        index = torch.nonzero(info['2q0g'] | info['1q1g'] | info['0q2g']).flatten()
        return index[torch.randperm(len(index))]

    @staticmethod
    def _split_index(index, num_real, num_test, train_fraction, label):
        """Cut shuffled indices into consecutive train / valid / test slices, warning on shortfall."""
        num_train = int(num_real * train_fraction)
        num_requested = num_real + num_test
        if len(index) < num_requested:
            warnings.warn(
                f'{label}: requested {num_requested} events (train+valid {num_real}, '
                f'test {num_test}) but only {len(index)} are available; '
                f'the later splits will be truncated.'
            )
        return index[:num_train], index[num_train:num_real], index[num_real:num_requested]

    @staticmethod
    def _cwola_split(data, info, index):
        """Return ``(sig, bkg)``: the rows of ``data`` at ``index`` that are 2q0g, or 1q1g/0q2g."""
        sig_mask = info['2q0g'][index]
        bkg_mask = info['1q1g'][index] | info['0q2g'][index]
        return data[index[sig_mask]], data[index[bkg_mask]]

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.valid_dataset, batch_size=self.batch_size, shuffle=False)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.batch_size, shuffle=False)

In [None]:
class BinaryLitModel(lightning.LightningModule):
    """Lightning wrapper that trains a single-logit binary classifier.

    The wrapped network's output is flattened to one logit per sample. It is
    optimised with ``BCEWithLogitsLoss(pos_weight=pos_weight)``. Accuracy and
    AUROC are accumulated separately for the train / valid / test stages and
    logged as ``<stage>_accuracy`` / ``<stage>_auc`` at the end of each epoch.
    The loss is logged as ``<stage>_loss``.

    Args:
        model: Network mapping an input batch to logits.
        pos_weight: Weight of the positive class in the BCE loss.
    """

    def __init__(self, model: nn.Module, pos_weight: torch.Tensor):
        super().__init__()

        self.model = model
        self.loss_fn = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

        self.train_accuracy = BinaryAccuracy()
        self.valid_accuracy = BinaryAccuracy()
        self.test_accuracy = BinaryAccuracy()

        self.train_auc = BinaryAUROC()
        self.valid_auc = BinaryAUROC()
        self.test_auc = BinaryAUROC()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the raw logits of the wrapped model."""
        return self.model(x)

    def configure_optimizers(self):
        """RAdam over all parameters with a fixed learning rate of 3e-4."""
        return torch.optim.RAdam(self.parameters(), lr=3e-4)

    def _stage_metrics(self, mode: str):
        """Return the ``(auc, accuracy)`` metrics of ``mode``, or ``None`` for an unknown mode."""
        return {
            'train': (self.train_auc, self.train_accuracy),
            'valid': (self.valid_auc, self.valid_accuracy),
            'test': (self.test_auc, self.test_accuracy),
        }.get(mode)

    def _shared_step(self, batch: tuple[torch.Tensor, torch.Tensor], mode: str):
        """Compute the loss of one batch, update the stage metrics and log ``<mode>_loss``."""
        x, y_true = batch
        # reshape (unlike view) also accepts non-contiguous model outputs.
        logits = self(x).reshape(-1)
        loss = self.loss_fn(logits, y_true.float())
        y_pred = torch.sigmoid(logits)

        metrics = self._stage_metrics(mode)
        if metrics is not None:
            # torchmetrics' binary AUROC rejects floating-point targets, and the
            # dataset labels are floats, so pass them as integer class labels.
            target = y_true.long()
            for metric in metrics:
                metric.update(y_pred, target)

        self.log(f"{mode}_loss", loss, on_epoch=True, prog_bar=(mode == 'train'))

        return loss

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, mode='train')

    def validation_step(self, batch, batch_idx):
        return self._shared_step(batch, mode='valid')

    def test_step(self, batch, batch_idx):
        return self._shared_step(batch, mode='test')

    def _log_epoch_metrics(self, mode: str):
        """Log the accumulated ``<mode>_auc`` / ``<mode>_accuracy`` and reset both metrics."""
        auc, accuracy = self._stage_metrics(mode)
        self.log(f'{mode}_auc', auc.compute(), prog_bar=True)
        self.log(f'{mode}_accuracy', accuracy.compute(), prog_bar=True)
        auc.reset()
        accuracy.reset()

    def on_train_epoch_end(self):
        self._log_epoch_metrics('train')

    def on_validation_epoch_end(self):
        self._log_epoch_metrics('valid')

    def on_test_epoch_end(self):
        self._log_epoch_metrics('test')

In [None]:
class ParT_Baseline(ParticleTransformer):
    """Full-size Particle Transformer configuration producing a single score.

    Each particle is described by (pt, eta, phi) plus a 3-dim one-hot encoding.
    The network uses eight particle-attention blocks followed by two
    class-attention blocks.
    """

    def __init__(self):
        particle_embedding = {"input_dim": 3 + 3, "embed_dim": [128, 512, 128]}
        particle_attention = {"num_heads": 8, "fc_dim": 512, "dropout": 0.1}
        class_attention = {"num_heads": 8, "fc_dim": 512, "dropout": 0.0}

        config = dict(
            ParEmbed=particle_embedding,
            ParAtteBlock=particle_attention,
            ClassAtteBlock=class_attention,
            num_ParAtteBlock=8,
            num_ClassAtteBlock=2,
        )

        super().__init__(score_dim=1, parameters=config)

class ParT_Light(ParticleTransformer):
    """Lightweight Particle Transformer configuration producing a single score.

    Each particle is described by (pt, eta, phi) plus a 3-dim one-hot encoding.
    The network uses four narrow (64-dim) particle-attention blocks followed by
    one class-attention block.
    """

    def __init__(self):
        particle_embedding = {"input_dim": 3 + 3, "embed_dim": [64, 64, 64]}
        particle_attention = {"num_heads": 4, "fc_dim": 64, "dropout": 0.1}
        class_attention = {"num_heads": 4, "fc_dim": 64, "dropout": 0.0}

        config = dict(
            ParEmbed=particle_embedding,
            ParAtteBlock=particle_attention,
            ClassAtteBlock=class_attention,
            num_ParAtteBlock=4,
            num_ClassAtteBlock=1,
        )

        super().__init__(score_dim=1, parameters=config)

In [None]:
# Setup.
model = ParT_Light()
batch_size = 256  # Batch size for training and validation
L = 3000  # Luminosity in fb^-1
num_test = 10000  # Number of testing samples

# Save directory and name.
save_dir = 'training_logs'
preprocessing_mode = "R"
name = f"ParT_{preprocessing_mode}_{rnd_seed}"
version = f"{model.__class__.__name__}_B{batch_size}"
# CSVLogger(save_dir, name, version) writes into <save_dir>/<name>/<version>;
# the parameter summary and the metrics plot below are saved there as well.
log_dir = os.path.join(save_dir, name, version)

# ---- Training ----
# Lightning DataModule & Model.
lit_data_module = LitDataModule(batch_size=batch_size, L=L, num_test=num_test)
lit_model = BinaryLitModel(model=model, pos_weight=lit_data_module.pos_weight)

# Lightning Logger & Trainer.
logger = CSVLogger(save_dir=save_dir, name=name, version=version)
trainer = lightning.Trainer(
    accelerator='gpu',
    max_epochs=100,
    logger=logger,
    callbacks=[ModelCheckpoint(
            monitor='valid_auc',
            mode='max',
            save_top_k=5,
            save_last=True,
            filename='{epoch}-{valid_auc:.3f}-{valid_accuracy:.3f}',
        )],
)

# Train, then evaluate the checkpoint with the best valid_auc on the test set.
trainer.fit(lit_model, lit_data_module)
trainer.test(lit_model, datamodule=lit_data_module, ckpt_path='best')

# Summary of the number of parameters at module depths 1-3.
with open(os.path.join(log_dir, 'num_params.txt'), 'w') as file_num_params:
    for depth in range(1, 4):
        print(f"Model Summary (max_depth={depth}):", file=file_num_params)
        print(ModelSummary(lit_model, max_depth=depth), file=file_num_params)
        print(f"\n{'='*100}\n", file=file_num_params)

# ---- Plot Metrics ----
metrics_csv = os.path.join(log_dir, 'metrics.csv')
df = pd.read_csv(metrics_csv)

fig, ax = plt.subplots(2, 3, figsize=(10, 6))
# The training loss is logged per step and per epoch, so Lightning suffixes it
# with _step / _epoch; the remaining names are the keys BinaryLitModel logs.
metrics = ['train_loss_epoch', 'train_accuracy', 'train_auc', 'valid_loss', 'valid_accuracy', 'valid_auc']

for axis, metric in zip(ax.flat, metrics):
    # Each CSV row only holds the metrics logged at that point; skip the empty ones.
    data = df[df[metric].notna()]
    plot = sns.lineplot(data=data, x='epoch', y=metric, ax=axis)
    plot.set_title(metric)

plt.tight_layout()
plt.savefig(os.path.join(log_dir, 'metrics.png'))