In [None]:
import math
import time
import yaml

import lightning as L
import torch
import torch.nn as nn
import wandb

from source.data.preprocessing import FatJetEvents
from source.data.datamodule import JetTorchDataModule, JetGraphDataModule
from source.models.qcgnn import QuantumRotQCGNN, HybridQCGNN
from source.models.mpgnn import ClassicalMPGNN
from source.models.part import ParticleTransformer
from source.models.pnet import ParticleNet
from source.training.litmodel import TorchLightningModule, GraphLightningModel
from source.training.loggers import csv_logger, wandb_logger
from source.training.result import plot_metrics

with open('configs/config.yaml', 'r') as file:
    config = yaml.safe_load(file)

# Run timestamp (local time), computed once when this cell runs and reused for every run name.
config['date'] = time.strftime('%Y%m%d_%H%M%S', time.localtime())

In [None]:
def create_data_module(graph: bool) -> L.LightningDataModule:
    """Sample jet events and wrap them in a Lightning data module.

    For each of the channels ``VzToQCD``, ``VzToZH`` and ``VzToTT`` a
    ``FatJetEvents`` reader is built from ``config['Data']`` and
    ``generate_uniform_pt_events`` is called on it. Reading the jet data is
    not affected by the random seed.

    Args:
        graph: If True, build a ``JetGraphDataModule``; otherwise a
            ``JetTorchDataModule``.

    Returns:
        The data module constructed from the sampled events and ``config['Data']``.
    """
    data_config = config['Data']
    channels = ['VzToQCD', 'VzToZH', 'VzToTT']

    # Total number of requested samples divided evenly across pt bins (rounded up).
    num_requested = sum(data_config[key] for key in ('num_train', 'num_valid', 'num_test'))
    num_data_per_bin = math.ceil(num_requested / data_config['num_bins'])

    # Build every reader first, then sample from each one in turn.
    readers = [
        FatJetEvents(channel=channel, num_data_per_bin=num_data_per_bin, **data_config)
        for channel in channels
    ]
    sampled_events = [reader.generate_uniform_pt_events() for reader in readers]

    data_module_cls = JetGraphDataModule if graph else JetTorchDataModule
    return data_module_cls(sampled_events, **data_config)


def create_training_info(model: nn.Module, model_description: str, model_hparams: dict, random_seed: int) -> dict:
    """Create the run-metadata dictionary that is handed to the loggers.

    Args:
        model: The model being trained; its class name is used in run names.
        model_description: Short string identifying the model hyperparameters.
        model_hparams: Model hyperparameters merged into the record.
        random_seed: Seed of this run, appended to the run name.

    Returns:
        A shallow copy of ``config`` updated with ``model_hparams`` and the run
        identifiers (``name``, ``group_rnd``, ``model``, ``random_seed``, ...).
    """
    model_name = model.__class__.__name__

    # Data information.
    data_description = f"P{config['Data']['max_num_ptcs']}_N{config['Data']['num_train']}"

    # Identifier shared by runs that differ only in their random seed.
    group_rnd = '-'.join([model_name, model_description, data_description])

    # Name of this particular run: group, optional suffix, then seed.
    # A YAML `suffix:` with no value loads as None; test truthiness instead of
    # comparing with '' so None/missing suffixes are skipped rather than
    # crashing on string concatenation.
    suffix = config['Settings'].get('suffix')
    name_parts = [group_rnd]
    if suffix:
        name_parts.append(str(suffix))
    name_parts.append(str(random_seed))
    name = '-'.join(name_parts)

    # Hyperparameters and configurations that will be recorded.
    training_info = config.copy()
    training_info.update(model_hparams)
    training_info.update({
        'model_description': model_description,
        'data_description': data_description,
        'group_rnd': group_rnd,
        'name': name,
        'date': config['date'],
        'model': model_name,
        'random_seed': random_seed,
    })

    return training_info

    
def _parse_learning_rate(raw_lr) -> float:
    """Convert the configured learning rate (``config['Train']['lr']``) to a float.

    PyYAML (YAML 1.1) loads ``lr: 0.001`` as a float but ``lr: 1e-3`` as the
    string ``'1e-3'``. Both forms are accepted here; a bare ``eval`` would
    raise TypeError on the float form.
    """
    try:
        return float(raw_lr)
    except (TypeError, ValueError):
        # NOTE(review): fallback kept only for backward compatibility with
        # expression strings such as "1/1000". eval executes arbitrary code
        # from the config file -- prefer writing a plain number there.
        return float(eval(raw_lr))


def create_lightning_model(model: nn.Module, graph: bool) -> L.LightningModule:
    """Wrap ``model`` and a RAdam optimizer in the Lightning module used for training.

    Args:
        model: The network to train; all of its parameters are optimized.
        graph: If True, wrap in ``GraphLightningModel``; otherwise in
            ``TorchLightningModule``.

    Returns:
        The Lightning module holding the model and its optimizer.
    """
    # Optimizer.
    lr = _parse_learning_rate(config['Train']['lr'])
    optimizer = torch.optim.RAdam(model.parameters(), lr=lr)

    # Create lightning model depending on whether the data is graph-structured.
    score_dim = config['Model']['score_dim']
    print_log = config['Settings']['print_log']
    if graph:
        return GraphLightningModel(model, optimizer=optimizer, score_dim=score_dim, print_log=print_log)
    else:
        return TorchLightningModule(model, optimizer=optimizer, score_dim=score_dim, print_log=print_log)
    

def create_trainer(model: nn.Module, training_info: dict, accelerator: str) -> L.Trainer:
    """Build the Lightning trainer with CSV logging, plus wandb when enabled.

    Args:
        model: Model passed to the wandb logger's ``watch`` when
            ``config['Settings']['use_wandb']`` is set.
        training_info: Run metadata passed to the logger factories.
        accelerator: Lightning accelerator name (``'cpu'`` or ``'gpu'`` in this notebook).

    Returns:
        An ``L.Trainer`` configured from ``config['Train']``.
    """
    if not config['Settings']['use_wandb']:
        # A single CSV logger, not wrapped in a list.
        loggers = csv_logger(training_info)
    else:
        wandb.login()
        wandb_run_logger = wandb_logger(training_info)
        wandb_run_logger.watch(model)
        loggers = [wandb_run_logger, csv_logger(training_info)]

    train_config = config['Train']
    trainer_kwargs = dict(
        logger=loggers,
        accelerator=accelerator,
        max_epochs=train_config['max_epochs'],
        log_every_n_steps=train_config['log_every_n_steps'],
        num_sanity_val_steps=train_config['num_sanity_val_steps'],
    )
    return L.Trainer(**trainer_kwargs)

In [None]:
def train(
        model: nn.Module, model_description: str, model_hparams: dict,
        random_seed: int, accelerator: str, graph: bool,
    ):
    """Train ``model`` for one random seed, with stages chosen by ``config['Settings']['mode']``.

    Modes (presumably train/valid/test flags -- confirm against config docs):
        ``'100'``: ``fit`` on the data module's training dataloader only.
        ``'110'``: ``fit`` with the whole data module.
        ``'111'``: ``fit`` with the whole data module, then ``test``.

    Args:
        model: The network to train.
        model_description: Short hyperparameter description used in the run name.
        model_hparams: Hyperparameters recorded with the run.
        random_seed: Seed passed to ``L.seed_everything`` before any data is built.
        accelerator: Lightning accelerator name.
        graph: Whether to build graph data / graph Lightning module.

    Returns:
        The run name (``training_info['name']``).

    Raises:
        ValueError: If the configured mode is not one of the values above.
    """
    # YAML loads an unquoted `mode: 111` as the integer 111; normalise to str so
    # it still matches, and fail loudly instead of silently skipping training.
    mode = str(config['Settings']['mode'])
    if mode not in ('100', '110', '111'):
        raise ValueError(f"Unknown training mode {mode!r}; expected '100', '110' or '111'.")

    # Fix all random stuff.
    L.seed_everything(random_seed)

    # Traditional training procedure.
    data_module = create_data_module(graph=graph)
    lightning_model = create_lightning_model(model=model, graph=graph)
    training_info = create_training_info(model, model_description, model_hparams, random_seed)

    try:
        trainer = create_trainer(model, training_info, accelerator)

        if mode == '100':
            trainer.fit(lightning_model, train_dataloaders=data_module.train_dataloader())
        else:
            trainer.fit(lightning_model, datamodule=data_module)

        if mode == '111':
            print('\nTesting the model.\n')
            trainer.test(lightning_model, datamodule=data_module)
    finally:
        # Close the wandb run even if training fails, so a later run in the
        # same kernel does not inherit it.
        if config['Settings']['use_wandb']:
            wandb.finish()

    return training_info['name']

def train_quantum(random_seed: int, model_class: nn.Module, model_hparams: dict):
    """Instantiate and train a quantum model class on CPU.

    Args:
        random_seed: Seed forwarded to ``train``.
        model_class: Model class to instantiate (e.g. ``QuantumRotQCGNN``); it is
            called with the merged hyperparameters as keyword arguments.
        model_hparams: Must contain ``num_ir_qubits``, ``num_nr_qubits``,
            ``num_layers`` and ``num_reupload``. Keys also present in
            ``config['Model']`` are overridden by the config. The caller's
            dict is left unmodified.

    Returns:
        The run name returned by ``train``.
    """
    # Merge into a new dict (config values win, as before) instead of
    # mutating the caller's dict in place.
    hparams = {**model_hparams, **config['Model']}
    model = model_class(**hparams)

    model_description = (
        f"nI{hparams['num_ir_qubits']}_nQ{hparams['num_nr_qubits']}"
        f"_l{hparams['num_layers']}_r{hparams['num_reupload']}"
    )

    # Quantum models are always trained on CPU here.
    accelerator = 'cpu'
    return train(model, model_description, hparams, random_seed, accelerator, graph=False)

def train_mpgnn(random_seed: int, model_hparams: dict):
    """Instantiate and train a ``ClassicalMPGNN`` on graph data.

    Args:
        random_seed: Seed forwarded to ``train``.
        model_hparams: Must contain ``gnn_out``, ``gnn_hidden`` and
            ``gnn_layers``. Keys also present in ``config['Model']`` are
            overridden by the config. The caller's dict is left unmodified.

    Returns:
        The run name returned by ``train``.
    """
    # Merge into a new dict (config values win, as before) instead of
    # mutating the caller's dict in place.
    hparams = {**model_hparams, **config['Model']}
    model = ClassicalMPGNN(**hparams)

    model_description = f"go{hparams['gnn_out']}_gh{hparams['gnn_hidden']}_gl{hparams['gnn_layers']}"

    # Use the GPU when one is available.
    accelerator = 'gpu' if torch.cuda.is_available() else 'cpu'
    return train(model, model_description, hparams, random_seed, accelerator, graph=True)

def train_benchmark(random_seed: int, model_class: nn.Module):
    """Train a classical benchmark model configured from ``configs/benchmark.yaml``.

    The YAML section named after ``model_class`` (e.g. ``ParticleNet``) is used
    both to build the model and as the recorded hyperparameters; the class
    name also serves as the model description.

    Returns:
        The run name returned by ``train``.
    """
    model_name = model_class.__name__

    with open('configs/benchmark.yaml', 'r') as file:
        all_benchmark_hparams = yaml.safe_load(file)
    hparams = all_benchmark_hparams[model_name]

    model = model_class(score_dim=config['Model']['score_dim'], parameters=hparams)

    accelerator = 'cpu'
    if torch.cuda.is_available():
        accelerator = 'gpu'

    return train(model, model_name, hparams, random_seed, accelerator, graph=False)

In [None]:
# name_list = []

# for random_seed in range(3):
#     print('=' * 10 + f"Random Seed {random_seed}" + '-' * 10)

#     # MPGNN
#     for gnn_dim in [3, 5, 7]:
#         print(f" * Train MPGNN {gnn_dim}.")
#         mpgnn_hparams = {'gnn_in': 6, 'gnn_out': gnn_dim, 'gnn_layers': 2, 'gnn_hidden': gnn_dim}
#         name = train_mpgnn(random_seed=random_seed, model_hparams=mpgnn_hparams)
#         name_list.append(name)

#     # Super MPGNN
#     print(f" * Train MPGNN {1024}.")
#     mpgnn_hparams = {'gnn_in': 6, 'gnn_out': 1024, 'gnn_layers': 2, 'gnn_hidden': 1024}
#     name = train_mpgnn(random_seed=random_seed, model_hparams=mpgnn_hparams)
#     name_list.append(name)

#     # Particle Transformer (without interaction).
#     print(f" * Train Particle Transformer.")
#     name = train_benchmark(random_seed=random_seed, model_class=ParticleTransformer)
#     name_list.append(name)

#     # Particle Net.
#     print(f" * Train Particle Net.")
#     name = train_benchmark(random_seed=random_seed, model_class=ParticleNet)
#     name_list.append(name)

#     # QCGNN
#     for num_nr_qubits in [3, 5, 7]:
#         print(f" * Train QCGNN {num_nr_qubits}.")
#         qcgnn_hparams = {'num_ir_qubits': 4, 'num_nr_qubits': num_nr_qubits, 'num_layers': 1, 'num_reupload': num_nr_qubits}
#         name = train_quantum(random_seed=random_seed, model_class=QuantumRotQCGNN, model_hparams=qcgnn_hparams)
#         name_list.append(name)
    
#     # Hybrid
#     for num_nr_qubits in [3, 5, 7]:
#         print(f" * Train Hybrid {num_nr_qubits}.")
#         qcgnn_hparams = {'num_ir_qubits': 4, 'num_nr_qubits': num_nr_qubits, 'num_layers': 1, 'num_reupload': num_nr_qubits}
#         name = train_quantum(random_seed=random_seed, model_class=HybridQCGNN, model_hparams=qcgnn_hparams)
#         name_list.append(name)

In [None]:
# # Super MPGNN
# print(f" * Train MPGNN {1024}.")
# mpgnn_hparams = {'gnn_in': 6, 'gnn_out': 1024, 'gnn_layers': 2, 'gnn_hidden': 1024}
# name = train_mpgnn(random_seed=0, model_hparams=mpgnn_hparams)

# plot_metrics(name, num_classes=3)

In [None]:
# # Particle Transformer (without interaction).
# print(f" * Train Particle Transformer.")
# name = train_benchmark(random_seed=0, model_class=ParticleTransformer)

# plot_metrics(name, num_classes=3)

In [None]:
# # Particle Net.
# print(f" * Train Particle Net.")
# name = train_benchmark(random_seed=0, model_class=ParticleNet)

# plot_metrics(name, num_classes=3)