# Test a pre-trained QCGNN model on IBMQ

1. You must have an existing [IBMQ](https://quantum.ibm.com) account.
2. Create a local `./config.toml` file so that PennyLane can connect to IBMQ; see [PennyLane configuration file](https://docs.pennylane.ai/en/latest/introduction/configuration.html#format) for further details.

In [None]:
from itertools import product
import os
import yaml

import lightning as L
from lightning.pytorch.loggers import CSVLogger
import pennylane as qml
import torch

from source.data.datamodule import JetTorchDataModule
from source.data.opendata import TopQuarkEvents
from source.models.qcgnn import QuantumRotQCGNN
from source.training.litmodel import TorchLightningModule

In [None]:
dataset = 'TopQCD'

# Global configuration; later cells read its 'Data' section and the section
# named after `dataset` (e.g. config['TopQCD']).
with open("configs/config.yaml", 'r', encoding='utf-8') as file:
    config = yaml.safe_load(file)

# IBMQ / noise-test configuration; later cells read its 'Pretrain' and 'Data'
# sections.
with open("configs/ibmq.yaml", 'r', encoding='utf-8') as file:
    ibmq_config = yaml.safe_load(file)

# Output directory for the CSV test logs of every noise level.  It does not
# depend on the opened file, so it is created outside the `with` block.
save_dir = os.path.join('ibmq_result', 'noise')
os.makedirs(save_dir, exist_ok=True)

In [None]:
# Evaluate the pre-trained QCGNN models on the noisy 'default.mixed' device.
# For every (n_Q, seed) pair the test events are regenerated after seeding, the
# best checkpoint of the matching training run is loaded ONCE, and the model is
# tested at each noise probability.


def _load_best_state_dict(last_ckpt_path, prefix='model.'):
    """Return the bare-model weights of the best checkpoint referenced by a last checkpoint.

    Args:
        last_ckpt_path: Path to the Lightning ``last.ckpt`` file of a training run.
        prefix: Key prefix to strip from the stored state dict. It is presumably
            added because the Lightning wrapper holds the network as ``model``
            (TODO confirm against TorchLightningModule).

    Returns:
        A state dict whose keys match the unwrapped QCGNN module.

    Raises:
        KeyError: If no callback state in the last checkpoint records a
            ``best_model_path``.
    """
    # Load onto the CPU so GPU-saved checkpoints work with the CPU-only trainer.
    last_ckpt = torch.load(last_ckpt_path, map_location='cpu')

    # Use the first callback state that actually records a best checkpoint
    # (the ModelCheckpoint state), rather than blindly taking the first callback.
    best_ckpt_path = next(
        (
            state['best_model_path']
            for state in last_ckpt['callbacks'].values()
            if 'best_model_path' in state
        ),
        None,
    )
    if best_ckpt_path is None:
        raise KeyError(f"No 'best_model_path' found in callbacks of {last_ckpt_path}")

    state_dict = torch.load(best_ckpt_path, map_location='cpu')['state_dict']

    # Strip only a LEADING prefix; str.replace would also mangle inner names
    # such as 'model.submodel.w' -> 'subw'.
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


# Settings that do not change across the (n_Q, seed) sweep.
n_I = ibmq_config['Pretrain']['n_I']
num_data = ibmq_config['Data']['num_data']
num_ptcs = ibmq_config['Data']['num_ptcs']
num_reupload = 2
noise_probs = [0, 1E-6, 1E-5, 1E-4, 1E-3, 1E-2]

# Dataset options: generic 'Data' settings overridden by the dataset-specific
# section, with a fixed particle multiplicity (min == max == num_ptcs).
dataset_config = {**config['Data'], **config[dataset]}
dataset_config['min_num_ptcs'] = num_ptcs
dataset_config['max_num_ptcs'] = num_ptcs

for n_Q, rnd_seed in product([3, 6], range(5)):

    # Seed before generating events; presumably this makes the event
    # selection reproducible per seed.
    L.seed_everything(rnd_seed)

    # Get 'TopQCD' events, one list entry per channel.
    # NOTE(review): 'Top' is paired with is_signal_new=0 and 'QCD' with 1;
    # confirm against TopQuarkEvents that the printed channel names are right.
    events = []
    for y, channel in enumerate(['Top', 'QCD']):
        top_qcd_events = TopQuarkEvents(mode='test', is_signal_new=y, **dataset_config)
        events.append(top_qcd_events.generate_events(num_data))
        print(f"{channel} has {len(top_qcd_events.events)} events -> selected = {num_data}\n")

    # Turn into a test-only data-module.
    data_module = JetTorchDataModule(
        events=events,
        num_train=0,
        num_valid=0,
        num_test=num_data * config[dataset]['num_classes'],
        batch_size=ibmq_config['Data']['batch_size'],
        max_num_ptcs=num_ptcs,
        pi_scale=True
    )

    num_layers = n_Q // 3

    # The pre-trained weights depend only on (n_Q, seed), so load them once and
    # reuse them for every noise level.  The run name is built from the same
    # hyper-parameters that configure the model below.
    run_name = (
        f"QuantumRotQCGNN-nI{n_I}_nQ{n_Q}_L{num_layers}_R{num_reupload}_D0.00-"
        f"{dataset}_P16_N25000-CQ0715-{rnd_seed}"
    )
    best_state_dict = _load_best_state_dict(
        f"./training_logs/WandbLogger/QCGNN_Rev1/{run_name}/checkpoints/last.ckpt"
    )

    for noise_prob in noise_probs:
        # Create the QCGNN model on the noisy mixed-state simulator.
        model = QuantumRotQCGNN(
            num_ir_qubits=n_I,
            num_nr_qubits=n_Q,
            num_layers=num_layers,
            num_reupload=num_reupload,
            vqc_ansatz=qml.StronglyEntanglingLayers,
            score_dim=1,
            qdevice='default.mixed',
            noise_prob=noise_prob,
        )
        # load_state_dict copies the tensors into the model's parameters, so
        # the same dict can be shared by every noise-level model.
        model.load_state_dict(best_state_dict)

        # Turn into a lightning model (evaluation only, hence no optimizer).
        model.eval()
        lit_model = TorchLightningModule(model=model, optimizer=None, score_dim=1, print_log=False)

        # Start testing; one CSV log directory per (noise, n_Q, seed).
        logger = CSVLogger(save_dir=save_dir, name=f"Noise_{noise_prob:.0e}-Q{n_Q}-{rnd_seed}")
        trainer = L.Trainer(accelerator='cpu', default_root_dir=save_dir, logger=logger)
        trainer.test(model=lit_model, datamodule=data_module)