In [None]:
import os
import yaml

import lightning as L
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import CSVLogger
import torch
import tqdm

from source.data import jetclass, jetnet, topqcd
from source.data.datamodule import JetLightningDataModule
from source.models.litmodel import TorchLightningModule
from source.models.part import ParticleTransformer, AttentionBlock
from source.models.pnet import ParticleNet

# Notebook-level switches, read by the cells below.
# `dataset` is one of the imported dataset modules (its `channels` and `__name__` are used later),
# `model_class` is the network class to instantiate, and `train_model` gates the
# `trainer.fit` / `trainer.test` calls at the end of the training cell.
train_model = False
model_class = ParticleTransformer
dataset = jetclass

# Seed the run through Lightning so results are repeatable.
L.seed_everything(42)

### Data setup and `data_module`

In [4]:
def lightning_data_module(dataset: str, num_train: int, num_valid: int, num_test: int, batch_size: int) -> JetLightningDataModule:
    """Create a Lightning DataModule for the given dataset.

    Args:
        dataset: Name of the dataset to load; one of 'jetclass', 'jetnet' or 'topqcd'.
        num_train: Number of training events. Forwarded to `JetLightningDataModule`;
            for 'topqcd' it is also the `num_data` used when loading the 'train' split.
        num_valid: Number of validation events (same handling as `num_train`, 'valid' split).
        num_test: Number of test events (same handling as `num_train`, 'test' split).
        batch_size: Batch size forwarded to `JetLightningDataModule`.

    Returns:
        A `JetLightningDataModule` built from one `JetEvents` object per channel.

    Raises:
        ValueError: If `dataset` is not one of the supported dataset names.
    """

    # Jet Class Dataset
    if dataset == 'jetclass':
        tqdm_channels = tqdm.tqdm(jetclass.channels, desc='Loading JetClass Dataset')
        jet_events_list = [jetclass.JetEvents(channel, num_root=1) for channel in tqdm_channels]

    # Jet Net Dataset
    elif dataset == 'jetnet':
        tqdm_channels = tqdm.tqdm(jetnet.channels, desc='Loading JetNet Dataset')
        jet_events_list = [jetnet.JetEvents(channel) for channel in tqdm_channels]

    # Top QCD Dataset
    elif dataset == 'topqcd':
        tqdm_channels = tqdm.tqdm(topqcd.channels, desc='Loading TopQCD Dataset')
        # The three splits are loaded separately and concatenated per channel.
        jet_events_list = [
            (
                topqcd.JetEvents(channel, mode='train', num_data=num_train) +
                topqcd.JetEvents(channel, mode='valid', num_data=num_valid) +
                topqcd.JetEvents(channel, mode='test', num_data=num_test)
            ) for channel in tqdm_channels
        ]

    # Fail loudly on an unknown name instead of hitting an unbound `jet_events_list` below.
    else:
        raise ValueError(
            f"Unknown dataset {dataset!r}; expected one of 'jetclass', 'jetnet', 'topqcd'."
        )

    return JetLightningDataModule(jet_events_list, num_train, num_valid, num_test, batch_size)

In [5]:
# Number of scores the model outputs: one per channel (class) of the chosen dataset
score_dim: int = len(dataset.channels)

# `dataset` is the python module `source.data.<dataset_name>`; keep only the part after the last dot
dataset_name: str = dataset.__name__.rpartition('.')[2]

# Keyword arguments passed to `lightning_data_module`
data_module_config = dict(
    dataset=dataset_name,
    num_train=1000,
    num_valid=100,
    num_test=100,
    batch_size=32,
)

# Build the Lightning DataModule from the configuration above
data_module: JetLightningDataModule = lightning_data_module(**data_module_config)

Loading JetClass Dataset: 100%|██████████| 10/10 [00:14<00:00,  1.42s/it]
Creating JetLightningDataModule: 100%|██████████| 10/10 [00:28<00:00,  2.88s/it]


### Model training (trained with `LightningModule`)

In [6]:
# Hyperparameters: read the entry keyed by the model class name from its `yaml` configuration file
model_name = model_class.__name__
with open(f"configs/{model_name}.yaml", 'r') as file:
    hparams = yaml.safe_load(file)[model_name]

# Bare model (built once the configuration file has been read and closed)
model = model_class(score_dim=score_dim, parameters=hparams)

# Accelerator choice, optimizer, and the Lightning wrapper around the model
if torch.cuda.is_available():
    accelerator = 'gpu'
else:
    accelerator = 'cpu'
optimizer = torch.optim.RAdam(model.parameters(), lr=1E-3)
lightning_model = TorchLightningModule(model, optimizer=optimizer, score_dim=score_dim, print_log=False)

# CSV logger writing under `training_logs/<dataset_name>/<model_name>/`
# NOTE(review): 'lastest_run' looks like a typo for 'latest_run'; it is kept unchanged
# because it names the on-disk log directory that other code may read from.
save_dir = os.path.join('training_logs', dataset_name)
os.makedirs(save_dir, exist_ok=True)
logger = CSVLogger(save_dir=save_dir, name=model_name, version='lastest_run')

# Checkpointing: evaluated every epoch, monitoring `valid_auc` with mode 'max'
checkpoint_callback = ModelCheckpoint(
    monitor='valid_auc',
    mode='max',
    every_n_epochs=1,
    save_last=True,
    save_top_k=-1,
    filename='{epoch}',
)

# Lightning Trainer (sanity validation steps disabled)
trainer = L.Trainer(
    accelerator=accelerator,
    max_epochs=10,
    logger=logger,
    num_sanity_val_steps=0,
    callbacks=[checkpoint_callback],
)

# Fit and test only when training is switched on in the configuration cell
if train_model:
    trainer.fit(lightning_model, data_module)
    trainer.test(lightning_model, data_module)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
