In [1]:
import os
import yaml

import lightning as L
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import CSVLogger
import torch
import tqdm

from source.dataset import JetEvents, JetClass, JetNet, JetLightningDataModule
from source.litmodel import TorchLightningModule
from source.models.part import ParticleTransformer

### `yaml` configurations

In [None]:
# Load the run configuration. The encoding is pinned to UTF-8 so the file is
# decoded the same way on every platform instead of depending on the locale's
# default encoding.
with open('config.yaml', 'r', encoding='utf-8') as file:
    config = yaml.safe_load(file)

# An empty YAML document parses to `None`; fail here with a clear message
# rather than with an opaque `TypeError` at the first key lookup below.
if not isinstance(config, dict):
    raise ValueError(
        f"'config.yaml' must contain a top-level mapping, got {type(config).__name__}."
    )

# Seed the random number generators through Lightning before any data or
# model object is created.
rnd_seed = config['rnd_seed']
L.seed_everything(rnd_seed)

### Dataset

In [None]:
# Resolve the dataset class named in the configuration. An explicit registry
# replaces the former `globals()` lookup so that only dataset classes can be
# selected and a misspelled name produces a readable error.
dataset_str_name: str = config['dataset']
supported_datasets = {'JetClass': JetClass, 'JetNet': JetNet}
if dataset_str_name not in supported_datasets:
    raise ValueError(
        f"Unknown dataset {dataset_str_name!r} in config; "
        f"expected one of {sorted(supported_datasets)}."
    )
# `DATASET` is the class itself, not an instance of it.
DATASET: "type[JetEvents]" = supported_datasets[dataset_str_name]

# Load one dataset object per channel. `channels` is kept as the plain
# sequence (other cells only need `len(channels)`); the progress bar wraps
# the loop alone instead of being stored under that name.
channels = DATASET.CHANNELS
dataset_kwargs = config[DATASET.__name__]
jet_events_list = [
    DATASET(channel, **dataset_kwargs)
    for channel in tqdm.tqdm(channels, desc='Loading dataset')
]

# Create the data module from the per-channel datasets. `pad_num_ptcs` comes
# from the dataset's own config section; everything else from 'DataModule'.
print('Creating data module.')
data_module = JetLightningDataModule(
    jet_events_list=jet_events_list,
    pad_num_ptcs=dataset_kwargs['pad_num_ptcs'],
    **config['DataModule'],
)

### Model

In [4]:
# Hyperparameters handed to `ParticleTransformer`, grouped by sub-module.
# NOTE(review): the meaning of each key is defined by `ParticleTransformer`
# in `source.models.part`, which is not visible from this cell.
hparams_part = dict(
    ParEmbed=dict(
        input_dim=DATASET.INPUT_DIM,
        embed_dim=[128, 512, 128],
    ),
    IntEmbed=dict(
        # Four input features when `DATASET.INCLUDE_MASS` is truthy, else three.
        input_dim=(4 if DATASET.INCLUDE_MASS else 3),
        embed_dim=[64, 64, 64],
    ),
    ParAtteBlock=dict(num_heads=8, fc_dim=512, dropout=0.1),
    ClassAtteBlock=dict(num_heads=8, fc_dim=512, dropout=0.0),
    num_ParAtteBlock=8,
    num_ClassAtteBlock=2,
)

# The score dimension equals the number of channels.
model = ParticleTransformer(parameters=hparams_part, score_dim=len(channels))

### Training

In [None]:
# Wrap the model and its RAdam optimizer in the project's LightningModule.
optimizer = torch.optim.RAdam(model.parameters(), lr=config['lr'])
lit_model = TorchLightningModule(
    model,
    optimizer=optimizer,
    score_dim=len(channels),
    print_log=False,
)

# CSV logger rooted at training_logs/<dataset name>, with the run named after
# the model class and the random seed.
# NOTE(review): 'lastest_run' looks like a typo for 'latest_run'; it is kept
# unchanged because it is part of the on-disk log path.
save_dir = os.path.join('training_logs', DATASET.__name__)
run_name = f"{model.__class__.__name__}_{rnd_seed}"
logger = CSVLogger(save_dir=save_dir, name=run_name, version='lastest_run')

# Checkpoint every epoch and keep all of them (save_top_k=-1) plus a 'last'
# checkpoint; 'valid_auc' is the monitored metric, with higher being better.
checkpoint_callback = ModelCheckpoint(
    filename='{epoch}',
    monitor='valid_auc',
    mode='max',
    save_top_k=-1,
    save_last=True,
    every_n_epochs=1,
)

# Train on a GPU when CUDA is available, otherwise on the CPU.
accelerator = 'gpu' if torch.cuda.is_available() else 'cpu'
trainer = L.Trainer(
    logger=logger,
    callbacks=[checkpoint_callback],
    accelerator=accelerator,
    max_epochs=config['num_epochs'],
    num_sanity_val_steps=0,
)

# Save the whole model object next to the logs for quick loading. This runs
# before `trainer.fit`, so the file holds the model as it is prior to this fit.
os.makedirs(logger.log_dir, exist_ok=True)
torch.save(model, f"{logger.log_dir}/model.pt")

# Fit, then run the test loop on the same `lit_model` (no `ckpt_path` given).
trainer.fit(lit_model, data_module)
trainer.test(lit_model, data_module)