In [None]:
import argparse
import time
import yaml

import lightning as L
import torch
import torch.nn as nn
import wandb

from src.data.preprocessing import FatJetEvents
from src.data.datamodule import JetTorchDataModule, JetGraphDataModule
from src.models.qcgnn import QuantumRotQCGNN
from src.models.mpgnn import ClassicalMPGNN
from src.training.litmodel import TorchLightningModule, GraphLightningModel
from src.training.loggers import csv_logger, wandb_logger

# Load the experiment configuration shared by every helper below.
with open('configs/config.yaml') as config_file:
    config = yaml.safe_load(config_file)

# Local-time timestamp taken once when this cell runs; every run launched from
# this kernel session reuses it in its run name.
config['date'] = time.strftime('%Y%m%d_%H%M%S')

In [None]:
def create_data_module(graph: bool) -> L.LightningDataModule:
    """Build a Lightning data module from signal and background fat-jet events.

    Both channels are read with the settings in ``config['Data']``. Each channel
    is then resampled by ``generate_uniform_pt_events``, which is the step that
    depends on the random seed.

    Args:
        graph: If True, build a ``JetGraphDataModule``. Otherwise build a
            ``JetTorchDataModule``.

    Returns:
        The data module, constructed with ``**config['Data']``.
    """
    data_cfg = config['Data']

    # Reading the events is not affected by the random seed.
    sig_raw = FatJetEvents(channel=data_cfg['sig'], **data_cfg)
    bkg_raw = FatJetEvents(channel=data_cfg['bkg'], **data_cfg)

    # Random resampling (seed-dependent). Keep the sig-then-bkg order so that
    # RNG consumption is the same for a given seed.
    sig_uniform = sig_raw.generate_uniform_pt_events()
    bkg_uniform = bkg_raw.generate_uniform_pt_events()

    module_cls = JetGraphDataModule if graph else JetTorchDataModule
    return module_cls(sig_uniform, bkg_uniform, **data_cfg)


def create_training_info(model: nn.Module, model_description: str, model_hparams: dict, random_seed: int) -> dict:
    """Create the dictionary of hyperparameters and metadata recorded for a run.

    The result merges ``config['Data']``, ``config['Train']``, ``config['Model']``
    and ``model_hparams``, in that order, so later sources win on key clashes.
    It then adds the keys ``model``, ``random_seed``, ``model_description``,
    ``data_description`` and ``name``.

    The run name has the form
    ``<ModelClass>-<model_description>-Maxptc<max_num_ptcs>_N<num_bins*num_data_per_bin>-<date>_<seed>``.
    When ``config['Settings']['suffix']`` is set, ``-<suffix>`` is appended.

    Args:
        model: The network being trained; only its class name is used.
        model_description: Short hyperparameter tag embedded in the run name.
        model_hparams: Model hyperparameters to record.
        random_seed: Seed of this run; recorded and embedded in the run name.

    Returns:
        A new dict; ``config`` and ``model_hparams`` are not modified.
    """
    data_cfg = config['Data']

    # Hyperparameters and configurations that will be recorded.
    training_info = {
        **data_cfg,
        **config['Train'],
        **config['Model'],
        **model_hparams,
    }

    # Data information.
    max_num_ptcs_per_jet = data_cfg['max_num_ptcs']
    num_data_per_channel = data_cfg['num_bins'] * data_cfg['num_data_per_bin']
    data_description = f"Maxptc{max_num_ptcs_per_jet}_N{num_data_per_channel}"

    # Name of the model training.
    model_name = model.__class__.__name__
    name = '-'.join([
        model_name,
        model_description,
        data_description,
        f"{config['date']}_{random_seed}",
    ])

    # Optional suffix. An empty YAML value (``suffix:``) loads as None, so test
    # truthiness rather than ``!= ''``. Format with an f-string so non-string
    # suffixes (e.g. numbers) do not raise TypeError.
    suffix = config['Settings'].get('suffix')
    if suffix:
        name = f"{name}-{suffix}"

    training_info['model'] = model_name
    training_info['random_seed'] = random_seed
    training_info['model_description'] = model_description
    training_info['data_description'] = data_description
    training_info['name'] = name

    return training_info

    
def create_lightning_model(model: nn.Module, graph: bool) -> L.LightningModule:
    """Wrap ``model`` in the matching Lightning module with an Adam optimizer.

    Args:
        model: The network to train.
        graph: If True, wrap it in ``GraphLightningModel``. Otherwise wrap it in
            ``TorchLightningModule``.

    Returns:
        The Lightning module, configured with ``score_dim`` from ``config['Model']``.

    Raises:
        ValueError: If ``config['Train']['lr']`` cannot be read as a number.
    """
    # The learning rate may arrive from YAML either as a number or as a string:
    # PyYAML loads e.g. ``1e-3`` (no dot) as a string. The previous ``eval`` call
    # crashed on numeric values and would execute arbitrary config text, so
    # convert explicitly instead.
    raw_lr = config['Train']['lr']
    try:
        lr = float(raw_lr)
    except (TypeError, ValueError) as err:
        raise ValueError(
            f"config['Train']['lr'] must be a number or numeric string, got {raw_lr!r}"
        ) from err

    # Optimizer.
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    # Create lightning model depends on graph or not.
    score_dim = config['Model']['score_dim']
    if graph:
        return GraphLightningModel(model, optimizer=optimizer, score_dim=score_dim)
    return TorchLightningModule(model, optimizer=optimizer, score_dim=score_dim)
    

def create_trainer(model: nn.Module, training_info: dict, accelerator: str) -> L.Trainer:
    """Build the Lightning ``Trainer`` used to fit one model.

    The logger depends on ``config['Settings']['use_wandb']``:

    - If truthy: call ``wandb.login()``, build a wandb logger from
      ``training_info``, and register ``model`` with ``logger.watch``.
    - Otherwise: use a CSV logger built from ``training_info``.

    Epoch count, logging interval and sanity-check steps come from ``config['Train']``.
    """
    if not config['Settings']['use_wandb']:
        logger = csv_logger(training_info)
    else:
        wandb.login()
        logger = wandb_logger(training_info)
        logger.watch(model)

    train_cfg = config['Train']
    trainer_kwargs = {
        'logger': logger,
        'accelerator': accelerator,
        'max_epochs': train_cfg['max_epochs'],
        'log_every_n_steps': train_cfg['log_every_n_steps'],
        'num_sanity_val_steps': train_cfg['num_sanity_val_steps'],
    }
    return L.Trainer(**trainer_kwargs)

In [None]:
def train(
        model: nn.Module, model_description: str, model_hparams: dict,
        random_seed: int, accelerator: str, graph: bool,
    ):
    """Run one seeded training of ``model`` and log it.

    Args:
        model: The network to train.
        model_description: Short hyperparameter tag used in the run name.
        model_hparams: Model hyperparameters to record with the run.
        random_seed: Seed passed to ``L.seed_everything`` before the data is generated.
        accelerator: Lightning accelerator string (e.g. ``'cpu'`` or ``'gpu'``).
        graph: Selects the graph vs. plain-tensor data module and Lightning wrapper.
    """
    # Fix all random stuff (seeded before the random data resampling below).
    L.seed_everything(random_seed)

    # Traditional training procedure.
    data_module = create_data_module(graph=graph)
    lightning_model = create_lightning_model(model=model, graph=graph)
    training_info = create_training_info(model, model_description, model_hparams, random_seed)

    try:
        trainer = create_trainer(model, training_info, accelerator)
        trainer.fit(lightning_model, datamodule=data_module)
    finally:
        # Finish wandb even when training raises or is interrupted, so a failed
        # run is not left open for the next training in this kernel session.
        if config['Settings']['use_wandb']:
            wandb.finish()

def train_qcgnn(random_seed: int, model_hparams: dict):
    """Build a ``QuantumRotQCGNN`` and train it on CPU with ``random_seed``.

    Args:
        random_seed: Seed used for model construction and for ``train``.
        model_hparams: QCGNN hyperparameters. After merging with ``config['Model']``
            (config values win on clashes), the result must contain ``num_ir_qubits``,
            ``num_nr_qubits``, ``num_layers`` and ``num_reupload``. The caller's dict
            is not modified.
    """
    # Merge into a new dict instead of mutating the caller's argument; the
    # precedence (config['Model'] overrides) matches the previous ``update``.
    model_hparams = {**model_hparams, **config['Model']}

    # Seed before construction so any random parameter initialization is tied to
    # ``random_seed``. ``train`` seeds again before the data is generated, so data
    # sampling is unchanged.
    L.seed_everything(random_seed)
    model = QuantumRotQCGNN(**model_hparams)

    model_description = (
        f"nI{model_hparams['num_ir_qubits']}"
        f"_nQ{model_hparams['num_nr_qubits']}"
        f"_l{model_hparams['num_layers']}"
        f"_r{model_hparams['num_reupload']}"
    )

    # The QCGNN is always trained on CPU here.
    accelerator = 'cpu'
    train(model, model_description, model_hparams, random_seed, accelerator, graph=False)

def train_mpgnn(random_seed: int, model_hparams: dict):
    """Build a ``ClassicalMPGNN`` and train it (on GPU when available) with ``random_seed``.

    Args:
        random_seed: Seed used for model construction and for ``train``.
        model_hparams: MPGNN hyperparameters. After merging with ``config['Model']``
            (config values win on clashes), the result must contain ``gnn_out``,
            ``gnn_hidden`` and ``gnn_layers``. The caller's dict is not modified.
    """
    # Merge into a new dict instead of mutating the caller's argument; the
    # precedence (config['Model'] overrides) matches the previous ``update``.
    model_hparams = {**model_hparams, **config['Model']}

    # Seed before construction so any random parameter initialization is tied to
    # ``random_seed``. ``train`` seeds again before the data is generated, so data
    # sampling is unchanged.
    L.seed_everything(random_seed)
    model = ClassicalMPGNN(**model_hparams)

    model_description = (
        f"go{model_hparams['gnn_out']}"
        f"_gh{model_hparams['gnn_hidden']}"
        f"_gl{model_hparams['gnn_layers']}"
    )

    accelerator = 'gpu' if torch.cuda.is_available() else 'cpu'
    train(model, model_description, model_hparams, random_seed, accelerator, graph=True)

In [None]:
# Experiment grid.
NUM_RANDOM_SEEDS = 3  # seeds 0 .. NUM_RANDOM_SEEDS - 1
GNN_IN_FEATURES = 6  # value passed as ``gnn_in`` to every MPGNN
# Widths (gnn_out == gnn_hidden) of the classical MPGNNs. The last entry is the
# very wide ("super") MPGNN.
MPGNN_DIMS = [3, 5, 7, 1024]
# Values of ``num_nr_qubits`` for the QCGNNs (also used as ``num_reupload``).
QCGNN_NUM_NR_QUBITS = [3, 5, 7]

for random_seed in range(NUM_RANDOM_SEEDS):
    print('=' * 10 + f"Random Seed {random_seed}" + '=' * 10)

    # MPGNN (including the super-wide one).
    for gnn_dim in MPGNN_DIMS:
        print(f" * Train MPGNN {gnn_dim}.")
        mpgnn_hparams = {'gnn_in': GNN_IN_FEATURES, 'gnn_out': gnn_dim, 'gnn_layers': 2, 'gnn_hidden': gnn_dim}
        train_mpgnn(random_seed=random_seed, model_hparams=mpgnn_hparams)

    # QCGNN. Use the loop's seed: a hard-coded seed would repeat the same run for
    # every iteration and produce colliding run names.
    for num_nr_qubits in QCGNN_NUM_NR_QUBITS:
        print(f" * Train QCGNN {num_nr_qubits}.")
        qcgnn_hparams = {'num_ir_qubits': 4, 'num_nr_qubits': num_nr_qubits, 'num_layers': 1, 'num_reupload': num_nr_qubits}
        train_qcgnn(random_seed=random_seed, model_hparams=qcgnn_hparams)