In [None]:
import os

import lightning
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import CSVLogger
from lightning.pytorch.utilities.model_summary import ModelSummary
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torchmetrics.classification import BinaryAccuracy, BinaryAUROC

from source.preprocessing import hdf5_to_seq
from source.part import ParticleTransformer

# Global configuration: seaborn styling for every figure in the notebook, and a
# single seed handed to Lightning. `rnd_seed` is also embedded in the run name
# used by the training cell.
sns.set_theme()
rnd_seed = 42
lightning.seed_everything(seed=rnd_seed)

Seed set to 42


42

In [None]:
class CWoLaDataset(Dataset):
    """Binary-labelled dataset with optional CWoLa-style label mixing.

    Events from ``signal`` get label 0 and events from ``background`` get
    label 1 (in this notebook ``signal`` is GGF and ``background`` is VBF).
    With a non-zero ``CWoLa_ratio``, a random subset of ``int(CWoLa_ratio * n)``
    events of each input has its label flipped, so each label marks a mixture
    of both processes instead of a pure sample.

    Args:
        signal: Events labelled 0. They are concatenated with ``background``
            along dim 0, so presumably shaped (num_events, ...) — TODO confirm.
        background: Events labelled 1.
        CWoLa_ratio: Fraction of each input whose label is flipped; must lie
            in [0, 1].

    Raises:
        ValueError: If ``CWoLa_ratio`` is outside [0, 1] (or NaN).
    """

    def __init__(self, signal: torch.Tensor, background: torch.Tensor, CWoLa_ratio: float=0):
        # A negative ratio would turn the slice `[:k]` below into `[:-|k|]` and
        # flip nearly every label; a ratio > 1 would silently flip all of them.
        if not 0 <= CWoLa_ratio <= 1:
            raise ValueError(f"CWoLa_ratio must be in [0, 1], got {CWoLa_ratio}")

        self.x = torch.cat([signal, background], dim=0)

        num_sig = len(signal)
        num_bkg = len(background)
        # Float labels (torch default dtype); consumers cast as needed.
        sig_y = torch.zeros(num_sig)
        bkg_y = torch.ones(num_bkg)
        # Flip the labels of a random fraction of each input.
        sig_y[torch.randperm(num_sig)[:int(CWoLa_ratio * num_sig)]] = 1
        bkg_y[torch.randperm(num_bkg)[:int(CWoLa_ratio * num_bkg)]] = 0
        self.y = torch.cat([sig_y, bkg_y], dim=0)

    def __len__(self):
        """Total number of events (signal + background)."""
        return len(self.x)

    def __getitem__(self, idx):
        """Return the ``(event, label)`` pair at position ``idx``."""
        return self.x[idx], self.y[idx]


class LitDataModule(lightning.LightningDataModule):
    """Lightning data module for GGF-vs-VBF classification with optional CWoLa mixing.

    GGF events are label 0 (negative class) and VBF events label 1 (positive
    class). Each sample is shuffled and split independently into
    train/validation subsets. CWoLa label mixing is applied to the training
    subset only, so validation always uses the true labels.

    Args:
        batch_size: Mini-batch size for both data loaders.
        val_split: Fraction of each sample held out for validation; must lie in (0, 1).
        CWoLa_ratio: Fraction of each training sample whose label is flipped
            (passed to ``CWoLaDataset``).
        ggf_path: File with the GGF events, read by ``hdf5_to_seq``.
        vbf_path: File with the VBF events, read by ``hdf5_to_seq``.

    Attributes:
        pos_weight: 1-element tensor ``#negative / #positive`` computed from the
            training labels, for ``BCEWithLogitsLoss(pos_weight=...)``.

    Raises:
        ValueError: If ``val_split`` is outside (0, 1) or the training set has no
            positive labels.
    """

    def __init__(self, batch_size=64, val_split=0.1, CWoLa_ratio: float=0,
                 ggf_path: str = 'GGF.h5', vbf_path: str = 'VBF.h5'):
        super().__init__()
        if not 0 < val_split < 1:
            raise ValueError(f"val_split must be in (0, 1), got {val_split}")
        self.batch_size = batch_size
        self.val_split = val_split

        # NOTE(review): assumes hdf5_to_seq returns a tensor indexable by an
        # index tensor along dim 0 (num_events, ...) — TODO confirm.
        GGF = hdf5_to_seq(ggf_path)
        VBF = hdf5_to_seq(vbf_path)

        # Same RNG call order as before: GGF permutation first, then VBF.
        GGF_train, GGF_valid = self._random_split(GGF)
        VBF_train, VBF_valid = self._random_split(VBF)

        self.train_dataset = CWoLaDataset(GGF_train, VBF_train, CWoLa_ratio)
        self.valid_dataset = CWoLaDataset(GGF_valid, VBF_valid)

        # BCE pos_weight should balance the *labels* the loss actually sees.
        # With CWoLa mixing those differ from the raw GGF/VBF counts; with
        # CWoLa_ratio == 0 this equals ggf_train_len / vbf_train_len.
        train_labels = self.train_dataset.y
        num_pos = int((train_labels == 1).sum())
        num_neg = len(train_labels) - num_pos
        if num_pos == 0:
            raise ValueError("training set contains no positive (label 1) samples")
        self.pos_weight = torch.tensor([num_neg / num_pos], dtype=torch.float)

    def _random_split(self, events):
        """Shuffle ``events`` along dim 0 and return ``(train, valid)`` per ``self.val_split``."""
        num_events = len(events)
        train_len = int((1 - self.val_split) * num_events)
        perm = torch.randperm(num_events)
        return events[perm[:train_len]], events[perm[train_len:]]

    def train_dataloader(self):
        """Shuffled loader over the (possibly label-mixed) training set."""
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        """Deterministic-order loader over the validation set (true labels)."""
        return DataLoader(self.valid_dataset, batch_size=self.batch_size, shuffle=False)


In [3]:
class BinaryLitModel(lightning.LightningModule):
    """Lightning wrapper that trains a single-logit binary classifier.

    Optimizes ``BCEWithLogitsLoss`` with RAdam (lr=1e-3). It logs
    ``<mode>_loss`` per step/epoch and ``<mode>_auc`` / ``<mode>_accuracy`` at
    the end of every train and validation epoch.

    Args:
        model: Network whose output is flattened to one logit per sample.
        pos_weight: Positive-class weight forwarded to ``BCEWithLogitsLoss``.
    """

    def __init__(self, model: nn.Module, pos_weight: torch.Tensor):
        super().__init__()

        self.model = model
        self.loss_fn = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

        self.train_accuracy = BinaryAccuracy()
        self.valid_accuracy = BinaryAccuracy()
        self.train_auc = BinaryAUROC()
        self.valid_auc = BinaryAUROC()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the raw logits of the wrapped model."""
        return self.model(x)

    def configure_optimizers(self):
        """RAdam over all parameters with a fixed learning rate of 1e-3."""
        return torch.optim.RAdam(self.parameters(), lr=1e-3)

    def _shared_step(self, batch: tuple[torch.Tensor, torch.Tensor], mode: str):
        """Compute the loss for ``batch``, update the ``mode`` metrics and log the loss."""
        x, y_true = batch
        logits = self(x).view(-1)
        loss = self.loss_fn(logits, y_true.float())
        y_pred = torch.sigmoid(logits)

        # torchmetrics' binary metrics reject floating-point targets, and the
        # labels may arrive as floats (CWoLaDataset builds them with
        # torch.zeros/torch.ones), so cast to integer class labels here.
        if mode in ('train', 'valid'):
            y_label = y_true.long()
            getattr(self, f'{mode}_auc').update(y_pred, y_label)
            getattr(self, f'{mode}_accuracy').update(y_pred, y_label)

        self.log(f"{mode}_loss", loss, on_epoch=True, prog_bar=(mode == 'train'))

        return loss

    def _log_epoch_metrics(self, mode: str):
        """Log and reset the accumulated AUC / accuracy for ``mode``."""
        auc_metric = getattr(self, f'{mode}_auc')
        acc_metric = getattr(self, f'{mode}_accuracy')
        self.log(f'{mode}_auc', auc_metric.compute(), prog_bar=True)
        self.log(f'{mode}_accuracy', acc_metric.compute(), prog_bar=True)
        auc_metric.reset()
        acc_metric.reset()

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, mode='train')

    def validation_step(self, batch, batch_idx):
        return self._shared_step(batch, mode='valid')

    def on_train_epoch_end(self):
        self._log_epoch_metrics('train')

    def on_validation_epoch_end(self):
        self._log_epoch_metrics('valid')

In [4]:
def _part_hyperparameters(embed_dim, num_heads, fc_dim, num_par_blocks, num_class_blocks):
    """Build the ``parameters`` dict passed to ``ParticleTransformer``.

    The input is 3 kinematic features (pt, eta, phi) plus a 3-wide one-hot
    encoding. Particle-attention blocks use dropout 0.1; class-attention blocks
    use none. Both block types share ``num_heads`` and ``fc_dim``. A fresh dict
    is returned on every call.
    """
    return {
        "ParEmbed": {
            "input_dim": 3 + 3,  # (pt, eta, phi) + one-hot_encoding
            "embed_dim": list(embed_dim),
        },
        "ParAtteBlock": {"num_heads": num_heads, "fc_dim": fc_dim, "dropout": 0.1},
        "ClassAtteBlock": {"num_heads": num_heads, "fc_dim": fc_dim, "dropout": 0.0},
        "num_ParAtteBlock": num_par_blocks,
        "num_ClassAtteBlock": num_class_blocks,
    }


class ParT_Baseline(ParticleTransformer):
    """Full-size Particle Transformer with a single output score.

    Uses 8 particle-attention blocks and 2 class-attention blocks (8 heads, width 512).
    """

    def __init__(self):
        super().__init__(
            score_dim=1,
            parameters=_part_hyperparameters(
                embed_dim=[128, 512, 128],
                num_heads=8,
                fc_dim=512,
                num_par_blocks=8,
                num_class_blocks=2,
            ),
        )


class ParT_Light(ParticleTransformer):
    """Reduced Particle Transformer with a single output score.

    Uses 4 particle-attention blocks and 1 class-attention block (4 heads, width 64).
    """

    def __init__(self):
        super().__init__(
            score_dim=1,
            parameters=_part_hyperparameters(
                embed_dim=[64, 64, 64],
                num_heads=4,
                fc_dim=64,
                num_par_blocks=4,
                num_class_blocks=1,
            ),
        )

In [None]:
# Setup: CWoLa label-mixing fraction and network architecture.
CWoLa_ratio = 0.0
model = ParT_Light()

# Save directory and name. CSVLogger (below) writes to <save_dir>/<name>/<version>.
save_dir = 'training_logs'
preprocessing_mode = "SV"
name = f"ParT_{preprocessing_mode}_{rnd_seed}"
version = f"{model.__class__.__name__}_{CWoLa_ratio}"
log_dir = os.path.join(save_dir, name, version)

# --- Training ---
# Lightning DataModule & Model.
lit_data_module = LitDataModule(CWoLa_ratio=CWoLa_ratio)
lit_model = BinaryLitModel(model=model, pos_weight=lit_data_module.pos_weight)

# Record parameter counts at several summary depths. On a fresh run nothing has
# created the log directory yet, so opening the file would raise
# FileNotFoundError without this makedirs.
os.makedirs(log_dir, exist_ok=True)
with open(os.path.join(log_dir, 'num_params.txt'), 'w') as file_num_params:
    for depth in range(1, 4):
        print(f"Model Summary (max_depth={depth}):", file=file_num_params)
        print(ModelSummary(lit_model, max_depth=depth), file=file_num_params)
        print(f"\n{'='*100}\n", file=file_num_params)

# Lightning Logger & Trainer. 'auto' selects the GPU when one is available
# instead of failing on CPU-only machines.
logger = CSVLogger(save_dir=save_dir, name=name, version=version)
trainer = lightning.Trainer(
    accelerator='auto',
    max_epochs=100,
    logger=logger,
    callbacks=[ModelCheckpoint(
            monitor='valid_auc',
            mode='max',
            save_top_k=5,
            save_last=True,
            filename='{epoch}-{valid_auc:.3f}-{valid_accuracy:.3f}',
        )],
)

trainer.fit(lit_model, lit_data_module)

# --- Plot Metrics ---
# Names match what BinaryLitModel logs: the training loss is logged per step and
# per epoch (hence the `_epoch` suffix), the validation loss only per epoch.
metrics_csv = os.path.join(log_dir, 'metrics.csv')
df = pd.read_csv(metrics_csv)

fig, ax = plt.subplots(2, 3, figsize=(10, 6))
metrics = ['train_loss_epoch', 'train_accuracy', 'train_auc', 'valid_loss', 'valid_accuracy', 'valid_auc']

for axis, metric in zip(ax.flat, metrics):
    sns.lineplot(data=df.dropna(subset=[metric]), x='epoch', y=metric, ax=axis)
    axis.set_title(metric)

plt.tight_layout()
plt.savefig(os.path.join(log_dir, 'metrics.png'))

# --- Evaluate the best checkpoint ---
# There is no test split: BinaryLitModel has no test_step and LitDataModule no
# test_dataloader, so `trainer.test` would fail. Evaluate the best checkpoint
# (by valid_auc) on the validation split instead. This runs after plotting
# because validate() logs through the same CSVLogger and would presumably
# append an extra valid_* row to the curves above.
best_valid_metrics = trainer.validate(lit_model, datamodule=lit_data_module, ckpt_path='best')