## _Inference_

**_Inference_** is done using callbacks defined in _LightningModules/GNN/Models/infer.py_. The callbacks run during the _test_step()_, _a.k.a._ model _**evaluation**_.


### _How to Run Inference?_

1. _`traintrack config/pipeline_quickstart.yaml`_: use the `--inference` flag to run only `test_step()` (don't forget to provide the `resume_id` of a checkpoint)
2. _`infer.ipynb`_ notebook runs the _pl.Trainer().test()_

In [None]:
import sys, os, glob, yaml

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
import pprint
from tqdm import tqdm
import trackml.dataset

In [None]:
import torch
import torchmetrics
import pytorch_lightning as pl
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
import itertools

In [None]:
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [None]:
# Set EXATRKX_DATA to the absolute path of the current working directory.
os.environ.update(EXATRKX_DATA=os.path.abspath(os.curdir))

In [None]:
from LightningModules.GNN import InteractionGNN
from LightningModules.GNN import GNNBuilder, GNNMetrics
from LightningModules.GNN.Models.infer import GNNTelemetry

### _Load Checkpoint_

Lightning automatically saves a checkpoint for you in your current working directory, with the state of your last training epoch. We have a checkpoint stored after training finished.

```python
# load a LightningModule along with its weights & hyperparameters from a checkpoint
model = MyLightningModule.load_from_checkpoint("/path/to/checkpoint.ckpt")
print(model.input_dir)
```

Note that we have saved our hyperparameters when our **LightningModule** was initialized i.e. `self.save_hyperparameters(hparams)`

```python
# hyperparameters are saved to the "hyper_parameters" key in the checkpoint; to access them:
checkpoint = torch.load("/path/to/checkpoint.ckpt", map_location=device)
print(checkpoint["hyper_parameters"])
```

One can also initialize the model with different hyperparameters (if they are saved).


For more details, consult [Lightning Checkpointing](https://pytorch-lightning.readthedocs.io/en/stable/common/checkpointing.html).

### _Get Checkpoint Hparams_

- Either from the configs folder 
- Or extract it from the checkpoint, favoured if model is trained and evaluated on two different machines.

In [None]:
# Load the processing config file (trusted source). On a YAML syntax error the
# error is printed and `config` stays None.
config_file = os.path.join(os.curdir, 'LightningModules/GNN/configs/train_alldata_IGNN.yaml')
config = None
with open(config_file) as f:
    try:
        config = yaml.full_load(f)  # same as yaml.load(f, Loader=yaml.FullLoader)
    except yaml.YAMLError as err:
        print(err)

In [None]:
# print(config)

In [None]:
# Model checkpoints from earlier training runs; pick the one to evaluate by key.
checkpoint_paths = {
    # muons with Filtering=True
    "muons_filtered": "run_all/lightning_models/lightning_checkpoints/GNNStudy/a58b2mlx/checkpoints/last.ckpt",
    # fwp with Filtering=True
    "fwp_filtered": "run_all/lightning_models/lightning_checkpoints/HypGNN/uibb0ir9/checkpoints/last.ckpt",
    # fwp with Filtering=False
    "fwp_unfiltered": "run_all/lightning_models/lightning_checkpoints/HypGNN/p9rjmknq/checkpoints/last.ckpt",
}
ckpnt_path = checkpoint_paths["fwp_unfiltered"]

In [None]:
# Read the raw checkpoint (mapped onto `device`) and keep its saved hyperparameters.
# NOTE(review): torch>=2.6 defaults to weights_only=True, which may refuse a full
# Lightning checkpoint; if loading fails, pass weights_only=False (trusted files only).
checkpoint = torch.load(f=ckpnt_path, map_location=device)
hparams = checkpoint["hyper_parameters"]

In [None]:
# View Hyperparameters. Use pprint (imported above) so a long hparams dict wraps
# across lines instead of being dumped as one unreadable line.
pprint.pprint(hparams)

In [None]:
# Override selected hyperparameters for this inference run (other entries keep
# the values stored in the checkpoint).
# NOTE(review): train_split presumably lists [train, val, test] event counts,
# i.e. every event goes to the test split here; confirm against the module.
hparams.update(
    checkpoint_path=ckpnt_path,
    input_dir="run_quick/fwp_feature_store",
    output_dir="run_quick/fwp_gnn_processed",
    artifact_library="lightning_models/lightning_checkpoints",
    train_split=[0, 0, 19803],
    map_location=device,
)

In [None]:
# View Hyperparameters (Modified). pprint wraps the long dict across lines
# for readability.
pprint.pprint(hparams)

### _Get Checkpoint Model_

In [None]:
# Build an InteractionGNN from the (possibly modified) hyperparameters.
model = InteractionGNN(hparams)

In [None]:
# model.hparams

In [None]:
# Load the trained weights. `load_from_checkpoint` is a classmethod that builds and
# returns a new model, so call it on the class. Newer Lightning releases refuse to
# call it on an instance, and the instance created above is discarded anyway.
# Every entry of `hparams` is passed as a keyword argument: `checkpoint_path` and
# `map_location` steer loading, and the rest override the saved hyperparameters.
model = InteractionGNN.load_from_checkpoint(**hparams)

### _(1) - Inference: Callbacks_

* _Test with LightningModule_

In [None]:
# Lightning Trainer with GNNBuilder as its only extra callback; it is invoked when
# the test loop below runs.
builder_callback = GNNBuilder()
trainer = pl.Trainer(callbacks=[builder_callback])

In [None]:
# Run the test loop (test_step plus the attached callbacks) on the loaded model.
trainer.test(model, verbose=True)

* _Test with LightningDataModule_

In [None]:
# from Predict import SttDataModule

In [None]:
# Prepare LightningDataModule
# dm = SttDataModule(config)

In [None]:
# dm.setup(stage='test')
# test_dataloaders = dm.test_dataloader

In [None]:
# Run TestStep with LightningDataModule
# trainer.test(model=model, dataloaders=None, ckpt_path=None, verbose=True, datamodule=dm)

### _(2) - Inference: Manual_

In [None]:
# from Predict import eval_model

- _Get Test Dataset from LightningModule_

In [None]:
# run setup() for datasets
# model.setup(stage="fit")

In [None]:
# get testset or test_dataloader
# testset = model.testset
# test_dataloader = model.test_dataloader()

- _Run `eval_model()` on `test_dataloader()`_

In [None]:
# evaluate model, returns torch tensors
# scores, truths = eval_model(model, test_dataloader)

### _(3) - Inference: GNNBuilder_

_If the **GNNBuilder** callback has been run during training, simply load the data from `gnn_processed/test` (the cell below reads the `pred` split; change `inputdir` as needed), extract `scores` and `y_pid` (the truth labels), and run the following metrics._

In [None]:
# Collect every GNNBuilder output file of the prediction split, sorted by path.
inputdir = "run_all/gnn_processed/pred"
file_pattern = os.path.join(inputdir, "*")
gnn_files = sorted(glob.iglob(file_pattern))
print("Number of Files: ", len(gnn_files))

- _Load all `truth` and `scores` from the `testset` of the `GNN` stage_

In [None]:
# Gather per-event edge scores and truth labels from the GNNBuilder output files.
scoresl = []
truthsl = []

for e, gnn_file in enumerate(gnn_files):
    # progress report every 1000 files (the first file is not reported)
    if e > 0 and e % 1000 == 0:
        print("Processed Batches: ", e)

    gnn_data = torch.load(gnn_file, map_location=device)

    truth = gnn_data.y_pid
    # Keep only as many scores as there are truth labels.
    # NOTE(review): scores appear longer than y_pid (e.g. one entry per edge
    # direction); confirm that the leading slice is the matching half.
    score = gnn_data.scores[:truth.size(0)]

    scoresl.append(score)
    truthsl.append(truth)

In [None]:
# Join the per-event tensors into one flat tensor each.
scores, truths = torch.cat(scoresl), torch.cat(truthsl)

In [None]:
# Save as .npy files. The tensors were loaded with map_location=device, which may be
# 'cuda', and Tensor.numpy() only accepts CPU tensors that do not require grad.
# So detach and move them to host memory first.
np.save("gnn_scores.npy", scores.detach().cpu().numpy())
np.save("gnn_truths.npy", truths.detach().cpu().numpy())

### _Test Dataset_

- _Get Data from LightningModule_

In [None]:
# Method 1: Directly Get Test Dataset
# testset = model.testset

# Get a single batch
# batch = testset[0]

# OR, loop over it (pick one of the two loop forms)
# for index, batch in enumerate(testset):
#    print(index, batch)
# for batch in testset:
#    print(batch)

In [None]:
# Method 2: Directly Get Test Dataloader
# test_dataloader = model.test_dataloader()

# Get a single batch
# batch = next(iter(test_dataloader))

# OR, loop over it (pick one of the two loop forms)
# for batch_idx, batch in enumerate(test_dataloader):
#    print(batch_idx, batch)
# for batch in test_dataloader:
#    print(batch)

- _Get Data from Test Dataset_

In [None]:
# Test-split files written by the GNNBuilder callback, sorted by path.
inputdir = "run_all/gnn_processed/test"
all_events = sorted(glob.iglob(os.path.join(inputdir, "*")))

In [None]:
# Load every test event onto the selected device, with a tqdm progress bar.
loaded_events = [torch.load(event_file, map_location=device) for event_file in tqdm(all_events)]

In [None]:
# Wrap the loaded events in a PyG DataLoader: one event per batch, loaded in the
# main process (no worker processes).
test_dataloader = DataLoader(dataset=loaded_events, batch_size=1, num_workers=0)

In [None]:
# Take the first batch produced by the test DataLoader.
batch_iterator = iter(test_dataloader)
sampled_data = next(batch_iterator)

In [None]:
# Print the first batch taken from the test DataLoader.
print(sampled_data)

### _Plot Test Event_

In [None]:
# Test dataset held by the LightningModule.
# NOTE(review): this attribute presumably exists only after the module's setup()
# has run (e.g. via trainer.test() above); confirm before running this cell alone.
testset = model.testset

In [None]:
# Take the first test event and split its node features column by column.
example_data = testset[0]
hit_features = example_data.x
r, phi, ir = hit_features.T  # requires x to have exactly three columns

In [None]:
# Convert the (r, phi) hit coordinates to Cartesian x, y.
# NOTE(review): phi is multiplied by pi, so it appears to be stored in units of pi; confirm.
phi_rad = phi * np.pi
x = r * np.cos(phi_rad)
y = r * np.sin(phi_rad)

In [None]:
# Scatter all hits of the example event in the x-y plane.
# Separate statements replace the original tuple-expression, which made the cell
# echo a tuple of Text objects; plt.show() renders the figure explicitly.
plt.figure(figsize=(6, 6))
plt.scatter(x, y, s=2)
plt.title("Azimuthal View of Detector", fontsize=24)
plt.xlabel("x", fontsize=18)
plt.ylabel("y", fontsize=18)
plt.show()

In [None]:
# Mark each candidate edge as true when both endpoint hits share the same pid.
# NOTE(review): if pid 0 marks noise hits (as in TrackML), noise-noise edges also
# count as true here; confirm.
e = example_data.edge_index
src, dst = e[0], e[1]
pid = example_data.pid
true_edges = pid[src] == pid[dst]

In [None]:
# Draw the true edges in black on top of the hits.
# Indexing x/y with the (2, E) edge_index gives (2, E) arrays, and plt.plot draws
# one segment per column, i.e. one line per edge. (Fake edges are drawn in the
# next cell.)
plt.figure(figsize=(6, 6))
plt.plot(x[e[:, true_edges]], y[e[:, true_edges]], c="k")
plt.scatter(x, y, s=5)
plt.title("Azimuthal View of Detector", fontsize=24)
plt.xlabel("x", fontsize=18)
plt.ylabel("y", fontsize=18)
plt.show()

In [None]:
# Draw every 5th fake edge (endpoints with different pids) in red on top of the hits.
# `::5` subsamples the whole set; the former `0:-1:5` silently dropped the last edge
# whenever its index was a multiple of 5.
fake_edges = e[:, ~true_edges][:, ::5]
plt.figure(figsize=(6, 6))
plt.plot(x[fake_edges], y[fake_edges], c="r")
plt.scatter(x, y, s=5)
plt.title("Azimuthal View of Detector", fontsize=24)
plt.xlabel("x", fontsize=18)
plt.ylabel("y", fontsize=18)
plt.show()

### _TensorBoard Logger_

In [None]:
# Load TensorBoard notebook extension
# %load_ext tensorboard

In [None]:
# %tensorboard --logdir=run_all/lightning_models/lightning_checkpoints/DNNStudy