In [None]:
import os
import shutil
import time
import yaml

import lightning
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.loggers import CSVLogger
from lightning.pytorch.utilities.model_summary import ModelSummary
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader, TensorDataset
from torchmetrics.classification import BinaryAccuracy, BinaryAUROC

from source.data_augment import aug_phi_shift
from source.data_preprocess import MCSimData
from source.data_cwola import split_by_pure_random, split_by_jet_flavor
from source.model_cnn import CNN_Baseline, CNN_Light, CNN_EventCNN
from source.model_part import ParT_Baseline, ParT_Medium, ParT_Light, ParT_SuperLight, ParT_ExtremeLight

# Use seaborn's default theme for every figure drawn in this notebook.
sns.set_theme()

# Load the run configuration. The encoding is spelled out so that parsing the
# file does not depend on the platform's locale default.
with open('config.yaml', 'r', encoding='utf-8') as f:
    config = yaml.safe_load(f)

# Seed the random number generators through Lightning before anything
# stochastic (model construction, data splitting, shuffling) happens.
rnd_seed = config['training']['rnd_seed']
lightning.seed_everything(rnd_seed)

In [None]:
class LitDataModule(lightning.LightningDataModule):
    """Lightning data module that builds the CWoLa train/validation/test datasets.

    All of the work happens in ``__init__``: the signal and background samples
    are loaded with ``MCSimData``, optionally preprocessed, converted to the
    requested representation, split according to ``data_info['CWoLa_mode']``
    and wrapped in ``TensorDataset`` objects. Only the training tensors are
    passed through the augmentations.

    Attributes:
        train_dataset, valid_dataset, test_dataset: ``TensorDataset`` objects of
            ``(features, label)``. Tensors returned as ``*_sig`` by the split
            function are labelled 1 and those returned as ``*_bkg`` are
            labelled 0.
        pos_weight: One-element float32 tensor holding ``num_neg / num_pos``,
            counted on the training tensors after augmentation.
    """

    # Values accepted for `data_format` and `data_info['CWoLa_mode']`.
    _DATA_FORMATS = ('image', 'sequence')
    _CWOLA_MODES = ('jet_flavor', 'pure_random')

    def __init__(self, data_format: str, data_info: dict, batch_size: int,
                 preprocessings: list[str] = None, augmentations: dict = None):
        """Load, preprocess, split and augment the data.

        Args:
            data_format: ``'image'`` or ``'sequence'``.
            data_info: Dataset settings. Always read: ``'signal'`` and
                ``'background'`` (each with a ``'path'``), ``'include_decay'``
                and ``'CWoLa_mode'``. For ``'jet_flavor'`` mode also
                ``'branching_ratio'``, ``'luminosity'``, ``'train_fraction'``,
                ``'num_test'`` and, per sample, ``'cross_section'`` and
                ``'preselection_rate'``. For ``'pure_random'`` mode also
                ``'num_train'``, ``'num_valid'`` and ``'num_test'``.
            batch_size: Batch size used by all three dataloaders.
            preprocessings: Names of preprocessing steps; only ``'cop'`` is
                recognised. ``None`` means no preprocessing.
            augmentations: Dict with a ``'functions'`` list (``'phi_uni'`` and
                ``'phi_rand'`` are recognised) and, when one of those is used,
                a ``'rotations'`` entry. ``None`` means no augmentation.

        Raises:
            ValueError: If ``data_format`` or ``data_info['CWoLa_mode']`` is not
                one of the supported values.
        """
        super().__init__()

        # Fail fast on unsupported options, before any data is loaded. Without
        # these checks an unknown value would fall through the branches below
        # and surface later as a NameError on an unbound local.
        if data_format not in self._DATA_FORMATS:
            raise ValueError(
                f"Unsupported data_format {data_format!r}; expected one of {self._DATA_FORMATS}."
            )
        cwola_mode = data_info['CWoLa_mode']
        if cwola_mode not in self._CWOLA_MODES:
            raise ValueError(
                f"Unsupported CWoLa_mode {cwola_mode!r}; expected one of {self._CWOLA_MODES}."
            )

        self.data_format = data_format
        self.data_info = data_info
        self.batch_size = batch_size
        # Fresh containers per instance: mutable default arguments would be
        # shared by every LitDataModule created without these arguments.
        self.preprocessings = [] if preprocessings is None else preprocessings
        self.augmentations = {'functions': []} if augmentations is None else augmentations

        # Information of signal and background datasets
        sig_info = data_info['signal']
        bkg_info = data_info['background']

        # Monte Carlo simulation data
        SIG = MCSimData(sig_info['path'], include_decay=data_info['include_decay'])
        BKG = MCSimData(bkg_info['path'], include_decay=data_info['include_decay'])

        # ***** Preprocessing *****
        SIG = self._data_preprocessings(SIG)
        BKG = self._data_preprocessings(BKG)

        # Choose the representation of the dataset
        if data_format == 'image':
            sig_tensor = SIG.to_image()
            bkg_tensor = BKG.to_image()
        else:  # 'sequence' (validated above)
            sig_tensor = SIG.to_sequence()
            bkg_tensor = BKG.to_sequence()

        # Create mixed dataset for implementing CWoLa
        if cwola_mode == 'jet_flavor':
            train_sig, train_bkg, valid_sig, valid_bkg, test_sig, test_bkg = split_by_jet_flavor(
                sig_tensor=sig_tensor, bkg_tensor=bkg_tensor,
                sig_flavor=SIG.jet_flavor, bkg_flavor=BKG.jet_flavor,
                branching_ratio=data_info['branching_ratio'], luminosity=data_info['luminosity'],
                sig_cross_section=sig_info['cross_section'], bkg_cross_section=bkg_info['cross_section'],
                sig_preselection_rate=sig_info['preselection_rate'], bkg_preselection_rate=bkg_info['preselection_rate'],
                train_fraction=data_info['train_fraction'], num_test=data_info['num_test'],
            )
        else:  # 'pure_random' (validated above)
            train_sig, train_bkg, valid_sig, valid_bkg, test_sig, test_bkg = split_by_pure_random(
                sig_tensor=sig_tensor, bkg_tensor=bkg_tensor,
                num_train=data_info['num_train'], num_valid=data_info['num_valid'], num_test=data_info['num_test'],
            )

        # ***** Augmentation ***** (training tensors only)
        train_sig = self._data_augmentations(train_sig)
        train_bkg = self._data_augmentations(train_bkg)

        # Create torch datasets
        self.train_dataset = self._labeled_dataset(train_sig, train_bkg)
        self.valid_dataset = self._labeled_dataset(valid_sig, valid_bkg)
        self.test_dataset = self._labeled_dataset(test_sig, test_bkg)

        # Calculate positive weight for loss function
        num_pos = len(train_sig)  # y == 1
        num_neg = len(train_bkg)  # y == 0
        self.pos_weight = torch.tensor([num_neg / num_pos], dtype=torch.float32)

    @staticmethod
    def _labeled_dataset(sig: torch.Tensor, bkg: torch.Tensor) -> TensorDataset:
        """Concatenate ``sig`` and ``bkg`` along dim 0 and label them 1 and 0 respectively."""
        features = torch.cat([sig, bkg], dim=0)
        labels = torch.cat([torch.ones(len(sig)), torch.zeros(len(bkg))], dim=0)
        return TensorDataset(features, labels)

    def _data_preprocessings(self, Data: MCSimData) -> MCSimData:
        """Apply the configured preprocessing steps to ``Data`` and return it."""
        if 'cop' in self.preprocessings:
            Data.preprocess_center_of_phi()
        return Data

    def _data_augmentations(self, data: torch.Tensor) -> torch.Tensor:
        """Apply the configured augmentation functions, in order, and return the result.

        Function names other than ``'phi_uni'`` and ``'phi_rand'`` are ignored.
        """
        aug_dict = self.augmentations
        for func in aug_dict['functions']:
            if func == 'phi_uni':
                data = aug_phi_shift(data, mode='uniform', format=self.data_format, rotations=aug_dict['rotations'])
            elif func == 'phi_rand':
                data = aug_phi_shift(data, mode='random', format=self.data_format, rotations=aug_dict['rotations'])
        return data

    def train_dataloader(self):
        """Return the shuffled training dataloader."""
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        """Return the validation dataloader (no shuffling)."""
        return DataLoader(self.valid_dataset, batch_size=self.batch_size, shuffle=False)

    def test_dataloader(self):
        """Return the test dataloader (no shuffling)."""
        return DataLoader(self.test_dataset, batch_size=self.batch_size, shuffle=False)

In [None]:
class BinaryLitModel(lightning.LightningModule):
    """Lightning wrapper that trains ``model`` as a binary classifier on logits.

    The loss is ``BCEWithLogitsLoss`` with the given ``pos_weight``. Accuracy
    and AUROC are tracked separately for the train, validation and test stages
    and logged once per epoch as ``{stage}_accuracy`` and ``{stage}_auc``; the
    loss is logged as ``{stage}_loss``.
    """

    def __init__(self, model: nn.Module, lr: float, pos_weight: torch.Tensor, scheduler_settings: dict = None):
        """Store the network and hyper-parameters and create the metrics.

        Args:
            model: Network whose output is reshaped to one logit per sample.
            lr: Learning rate for the RAdam optimizer.
            pos_weight: Positive-class weight passed to ``BCEWithLogitsLoss``.
            scheduler_settings: ``None`` for no scheduler; otherwise a dict with
                the keys ``'mode'``, ``'factor'``, ``'patience'``,
                ``'threshold'``, ``'monitor'``, ``'interval'`` and
                ``'frequency'`` used to configure ``ReduceLROnPlateau``.
        """
        super().__init__()

        self.model = model
        self.lr = lr
        self.scheduler_settings = scheduler_settings
        self.loss_fn = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

        self.train_accuracy = BinaryAccuracy()
        self.valid_accuracy = BinaryAccuracy()
        self.test_accuracy = BinaryAccuracy()

        self.train_auc = BinaryAUROC()
        self.valid_auc = BinaryAUROC()
        self.test_auc = BinaryAUROC()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the wrapped model's raw output (logits) for ``x``."""
        return self.model(x)

    def configure_optimizers(self):
        """Return the RAdam optimizer, plus a ``ReduceLROnPlateau`` scheduler if configured."""
        optimizer = torch.optim.RAdam(self.parameters(), lr=self.lr)
        if self.scheduler_settings is None:
            return optimizer

        scheduler_settings = self.scheduler_settings
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode=scheduler_settings['mode'],
            factor=scheduler_settings['factor'],
            patience=scheduler_settings['patience'],
            threshold=scheduler_settings['threshold'],
        )
        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                'monitor': scheduler_settings['monitor'],
                'interval': scheduler_settings['interval'],
                'frequency': scheduler_settings['frequency'],
            }
        }

    def _metrics_for(self, mode: str):
        """Return the ``(auc, accuracy)`` metric pair for ``mode``, or ``None`` if ``mode`` is unknown."""
        return {
            'train': (self.train_auc, self.train_accuracy),
            'valid': (self.valid_auc, self.valid_accuracy),
            'test': (self.test_auc, self.test_accuracy),
        }.get(mode)

    def _shared_step(self, batch: tuple[torch.Tensor, torch.Tensor], mode: str):
        """Compute the loss for one batch, update the stage metrics and log the loss.

        Args:
            batch: ``(inputs, labels)`` pair.
            mode: ``'train'``, ``'valid'`` or ``'test'``; selects which metrics
                are updated and the prefix of the logged loss name.

        Returns:
            The scalar loss tensor.
        """
        x, y_true = batch
        logits = self(x).view(-1)
        loss = self.loss_fn(logits, y_true.float())
        y_pred = torch.sigmoid(logits)

        # torchmetrics documents `target` of its binary metrics as an integer
        # tensor, while the labels may arrive as floats (the loss needs floats).
        # Cast a separate copy for the metrics so both consumers get the dtype
        # they expect.
        y_label = y_true.long()

        metrics = self._metrics_for(mode)
        if metrics is not None:
            auc, accuracy = metrics
            auc.update(y_pred, y_label)
            accuracy.update(y_pred, y_label)

        self.log(f"{mode}_loss", loss, on_epoch=True, prog_bar=(mode == 'train'))

        return loss

    def _log_epoch_metrics(self, mode: str):
        """Log the accumulated AUROC and accuracy of ``mode`` and reset both metrics."""
        auc, accuracy = self._metrics_for(mode)
        self.log(f'{mode}_auc', auc.compute(), prog_bar=True)
        self.log(f'{mode}_accuracy', accuracy.compute(), prog_bar=True)
        auc.reset()
        accuracy.reset()

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, mode='train')

    def validation_step(self, batch, batch_idx):
        return self._shared_step(batch, mode='valid')

    def test_step(self, batch, batch_idx):
        return self._shared_step(batch, mode='test')

    def on_train_epoch_end(self):
        self._log_epoch_metrics('train')

    def on_validation_epoch_end(self):
        self._log_epoch_metrics('valid')

    def on_test_epoch_end(self):
        self._log_epoch_metrics('test')

In [None]:
def count_model_parameters(lit_model: lightning.LightningModule, output_dir: str):
    """Write Lightning model summaries for depths 1 to 3 to ``num_params.txt``.

    Args:
        lit_model: Module to summarise with ``ModelSummary``.
        output_dir: Existing directory in which ``num_params.txt`` is created
            (or overwritten).
    """
    report_path = os.path.join(output_dir, 'num_params.txt')
    divider = '=' * 100

    with open(report_path, 'w') as report:
        for depth in (1, 2, 3):
            # Header first, then the summary table, then a divider framed by blank lines.
            report.write(f"Model Summary (max_depth={depth}):" + "\n")
            report.write(str(ModelSummary(lit_model, max_depth=depth)) + "\n")
            report.write("\n" + divider + "\n" + "\n")


def plot_metrics(output_dir: str):
    """Plot the logged training and validation curves and save them as ``metrics.png``.

    Reads ``metrics.csv`` from ``output_dir`` and draws one line plot per
    metric against the ``epoch`` column on a 2x3 grid (training metrics in the
    first row, validation metrics in the second).

    Args:
        output_dir: Directory containing ``metrics.csv``; ``metrics.png`` is
            written to the same directory.

    Returns:
        The ``(fig, ax)`` pair created by ``plt.subplots``.
    """
    metrics_frame = pd.read_csv(os.path.join(output_dir, 'metrics.csv'))
    panel_metrics = (
        'train_loss_epoch', 'train_accuracy', 'train_auc',
        'valid_loss', 'valid_accuracy', 'valid_auc',
    )

    fig, ax = plt.subplots(2, 3, figsize=(10, 6))
    for axis, metric in zip(ax.flat, panel_metrics):
        # Keep only the rows in which this metric actually has a value.
        logged_rows = metrics_frame.dropna(subset=[metric])
        sns.lineplot(data=logged_rows, x='epoch', y=metric, ax=axis).set_title(metric)

    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'metrics.png'))

    return fig, ax

In [None]:
# One timestamp is shared by every run started from this cell, so their
# version strings group together.
current_time = time.strftime('%Y%m%d-%H%M%S', time.localtime())
num_channels = config['dataset']['num_channels']

# Settings that are identical for every run are read once, outside the loop.
save_dir = os.path.join(config['result_dir'], '_'.join(config['tags']))
scheduler_settings = config['training']['lr_scheduler']
earlystop_settings = config['training']['early_stopping']

# The metrics logged by BinaryLitModel are *_loss, *_auc and *_accuracy. A loss
# has to be minimised and the other two maximised, so the checkpoint mode is
# derived from the monitored name instead of always assuming 'max'.
checkpoint_monitor = config['training']['monitor']
checkpoint_mode = 'min' if 'loss' in checkpoint_monitor else 'max'

# Each entry: (data format, model class, learning rate, per-step batch size,
# gradient accumulation steps). Classes are listed rather than instances so
# that a network is only built when its own run starts, instead of all of
# them being constructed up front and kept alive for the whole sweep.
for data_format, model_cls, lr, batch_size_step, batch_accumulate in [
    ('image', CNN_EventCNN, 2e-4, 64, 8),
    ('image', CNN_Baseline, 1e-5, 64, 8),
    ('image', CNN_Light, 5e-4, 64, 8),
    ('sequence', ParT_Baseline, 5e-5, 64, 8),
    ('sequence', ParT_Medium, 1e-4, 64, 8),
    ('sequence', ParT_Light, 5e-4, 64, 8),
    ('sequence', ParT_SuperLight, 1e-3, 64, 8),
    ('sequence', ParT_ExtremeLight, 5e-3, 64, 8),
]:
    # Re-seed at the start of every run. The version string below advertises
    # `seed{rnd_seed}`, which is only accurate if the run really starts from
    # that seed rather than from whatever state the previous runs left behind.
    lightning.seed_everything(rnd_seed)

    # Lightning DataModule & Model & Learning Rate Scheduler
    # The data module is built first so that any random draws it makes start
    # from the freshly seeded state regardless of the architecture being trained.
    lit_data_module = LitDataModule(
        data_format=data_format,
        data_info=config['dataset'],
        batch_size=batch_size_step,
        preprocessings=config['preprocessings'],
        augmentations=config['augmentations'],
    )
    model = model_cls(num_channels=num_channels)
    lit_model = BinaryLitModel(model=model, lr=lr, pos_weight=lit_data_module.pos_weight, scheduler_settings=scheduler_settings)

    # Save directory and name
    name = model.__class__.__name__
    version = f"{current_time}_lr{lr:.0e}_b{batch_size_step}x{batch_accumulate}_seed{rnd_seed}"

    # Lightning Logger & Trainer & Early Stopping
    logger = CSVLogger(save_dir=save_dir, name=name, version=version)
    trainer = lightning.Trainer(
        accelerator=config['training']['device'],
        max_epochs=config['training']['num_epochs'],
        logger=logger,
        accumulate_grad_batches=batch_accumulate,
        callbacks=[
            ModelCheckpoint(
                monitor=checkpoint_monitor,
                mode=checkpoint_mode,
                save_top_k=5,
                save_last=True,
                filename='{epoch}-{valid_auc:.3f}-{valid_accuracy:.3f}',
            ),
            EarlyStopping(
                monitor=earlystop_settings['monitor'],
                min_delta=earlystop_settings['min_delta'],
                patience=earlystop_settings['patience'],
                mode=earlystop_settings['mode'],
            )
        ],
    )

    # Train and test the model
    if config['training']['fit']:
        trainer.fit(lit_model, lit_data_module)
    if config['training']['test']:
        # 'best' refers to the best checkpoint tracked by the ModelCheckpoint above.
        trainer.test(lit_model, datamodule=lit_data_module, ckpt_path='best')

    # Output directory (same location the CSVLogger writes to)
    output_dir = os.path.join(save_dir, name, version)
    os.makedirs(output_dir, exist_ok=True)

    # Summary of the training
    shutil.copy(src='config.yaml', dst=os.path.join(output_dir, 'config.yaml'))
    count_model_parameters(lit_model, output_dir)
    fig, ax = plot_metrics(output_dir)