# SPA-Net on Jupyter Notebooks

**Author: Shahzad Sanjrani**

**Date: 21.11.24**

This notebook checks whether it is, in principle, possible to run SPA-Net from a Jupyter notebook. Whether it is efficient is another story (we can always ask for a GPU node, right?).

In [7]:
import spanet
import os

## Training

This is essentially `train.py` copied into Jupyter, so that we can see what happens at each step:

1. Initialise variables
2. Import extra modules
3. Training script

### Initial variables

In [8]:
# Root directory under which the training/validation .h5 files live.
store_dir = "/nfs/dust/cms/user/sanjrani/SPANet_Investigations/investigation2/pepper_analysis/output/h4t_systematics/spanet/input"
# Sample sub-directory; it is joined onto `store_dir` when the .h5 file paths are built.
sample_dir = "genstudies_2017_jpt20_GENRECO_training/TTZprimeToTT_M-500_Width4_TuneCP5_13TeV-madgraph-pythia8"

In [14]:
### --- WHO, WHAT, HOW ARE WE TRAINING --- ###
# Event description (YAML); passed to Options and later copied next to the logs as event.yaml.
event_file = "./event_files/round2/full_hadronic_tttt_reco_tops.yaml"
# JSON file whose contents are applied on top of the Options; set to None to skip it.
options_file = "./options_files/round2/reco_four_tops/full_hadronic_tttt_reconstruct_1.json"
# Training and validation datasets, both passed to Options.
training_file = os.path.join(store_dir, sample_dir, "TTZprimeToTT_M-500_Width4_TuneCP5_13TeV-madgraph-pythia8_even_train.h5")
validation_file = os.path.join(store_dir, sample_dir, "TTZprimeToTT_M-500_Width4_TuneCP5_13TeV-madgraph-pythia8_even_val.h5")
# Written to options.num_gpu; a value > 0 selects the "gpu" accelerator, otherwise the Trainer uses "auto".
gpus = 0
# Written to options.epochs, which becomes the Trainer's max_epochs (unless profiling is enabled).
epochs = 30
# Written to options.batch_size.
batch_size = 512

### -- WHERE TO PUT OUTPUT --- ###
# Logger save directory and run name; logs end up in {log_dir}/{name}/version_i.
# If log_dir is None the current working directory is used instead.
log_dir = "/nfs/dust/cms/user/sanjrani/SPANet_Investigations/investigation2/pepper_analysis/output/h4t_systematics/spanet/models"
name = "spanet_output"

### --- EXTRAS (IGNORE) --- ###
checkpoint = None # load from a training state (passed to trainer.fit as ckpt_path)
state_dict = None # path to a checkpoint; only its model weights ("state_dict" entry) are loaded
freeze_state_dict = False # freeze weights loaded from state_dict (for finetuning new layers)

torch_script = False # compile using torch_script (second argument of JetReconstructionModel)
fp16 = False # use torch AMP for training ("16-mixed" precision instead of "32-true")
verbose = False # written to options.verbose_output
full_events = False # if True, partial_events and balance_particles are switched off
profile = False # profile network for single training epoch
time_limit = None # passed to the Trainer as max_time
limit_dataset = None # percentage of the dataset to use; divided by 100 before being stored
random_seed = 0 # only applied (as options.dataset_randomization) when > 0

### Load in extra modules

In [10]:
from argparse import ArgumentParser
from typing import Optional
from os import getcwd, makedirs, environ
import shutil
import json

import torch
import pytorch_lightning as pl
from pytorch_lightning.profilers import PyTorchProfiler
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks.progress.rich_progress import _RICH_AVAILABLE
from pytorch_lightning.loggers.wandb import _WANDB_AVAILABLE, WandbLogger

from pytorch_lightning.callbacks import (
    LearningRateMonitor,
    ModelCheckpoint,
    RichProgressBar,
    RichModelSummary,
    DeviceStatsMonitor,
    ModelSummary,
    TQDMProgressBar
)

from spanet import JetReconstructionModel, Options

### Train

#### Setup the situation

In [16]:
# This process is the master run unless NODE_RANK is present in the
# environment, in which case it is treated as a worker.
master = "NODE_RANK" not in environ

# -------------------------------------------------------------------------------------------------------
# Build the Options object from the event/training/validation files, then
# layer the contents of the JSON options file (if one was given) on top.
# -------------------------------------------------------------------------------------------------------
options = Options(event_file, training_file, validation_file)

if options_file is not None:
    # Read the whole JSON document first, then apply it once the file is closed.
    with open(options_file, 'r') as json_file:
        extra_options = json.load(json_file)
    options.update_options(extra_options)

# -------------------------------------------------------------------------------------------------------
# Overrides for common option values. In this notebook they come from the
# configuration variables defined earlier rather than from command line flags.
# Every override is announced on the master process only.
# -------------------------------------------------------------------------------------------------------
options.verbose_output = verbose
if master and verbose:
    print("Verbose output activated.")

# Restrict training to full events: partial events and particle balancing are both disabled.
if full_events:
    if master:
        print("Overriding: Only using full events")
    options.partial_events = False
    options.balance_particles = False

if gpus is not None:
    if master:
        print(f"Overriding GPU count: {gpus}")
    options.num_gpu = gpus

if batch_size is not None:
    if master:
        print(f"Overriding Batch Size: {batch_size}")
    options.batch_size = batch_size

# limit_dataset is given as a percentage; the option stores it as a fraction.
if limit_dataset is not None:
    if master:
        print(f"Overriding Dataset Limit: {limit_dataset}%")
    options.dataset_limit = limit_dataset / 100

if epochs is not None:
    if master:
        print(f"Overriding Number of Epochs: {epochs}")
    options.epochs = epochs

# A seed of 0 (or below) leaves the option at whatever value it already has.
if random_seed > 0:
    options.dataset_randomization = random_seed

# -------------------------------------------------------------------------------------------------------
# Print the full hyperparameter list
# -------------------------------------------------------------------------------------------------------
if master:
    options.display()

# -------------------------------------------------------------------------------------------------------
# Begin the training loop
# -------------------------------------------------------------------------------------------------------

# Create the initial model on the CPU
model = JetReconstructionModel(options, torch_script)

# Optionally initialise the model from the weights stored in another checkpoint.
if state_dict is not None:
    if master:
        print(f"Loading state dict from: {state_dict}")

    # Keep the loaded weights in their own variable. `state_dict` is the path
    # setting from the configuration cell; overwriting it with the loaded
    # dictionary would make a second run of this cell call torch.load on a dict.
    loaded_state_dict = torch.load(state_dict, map_location="cpu")["state_dict"]

    # strict=False: keys that do not match are reported below instead of raising.
    missing_keys, unexpected_keys = model.load_state_dict(loaded_state_dict, strict=False)

    if master:
        print(f"Missing Keys: {missing_keys}")
        print(f"Unexpected Keys: {unexpected_keys}")

    # Freeze exactly those parameters whose names were present in the loaded weights,
    # so that only the remaining (new) parameters keep requiring gradients.
    if freeze_state_dict:
        for pname, parameter in model.named_parameters():
            if pname in loaded_state_dict:
                parameter.requires_grad_(False)

# Set up the logger for this training run. Logs will be saved in {log_dir}/{name}/version_i.
# When no log directory was configured, fall back to the current working directory.
if log_dir is None:
    log_dir = getcwd()

# TensorBoard is used unconditionally here. The Weights & Biases alternative that was
# considered is WandbLogger(name=name, save_dir=log_dir), chosen when _WANDB_AVAILABLE is true.
logger = TensorBoardLogger(save_dir=log_dir, name=name)

# Checkpointing: keep the three checkpoints with the highest 'validation_accuracy',
# plus the most recent one.
checkpoint_callback = ModelCheckpoint(
    verbose=options.verbose_output,
    monitor='validation_accuracy',
    save_top_k=3,
    mode='max',
    save_last=True
)
lr_monitor = LearningRateMonitor()
device_stats = DeviceStatsMonitor()

# Use the rich-based progress bar and model summary when rich is available,
# otherwise fall back to the TQDM bar and the plain summary.
if _RICH_AVAILABLE:
    progress_bar = RichProgressBar()
    model_summary = RichModelSummary(max_depth=1)
else:
    progress_bar = TQDMProgressBar()
    model_summary = ModelSummary(max_depth=1)

callbacks = [
    checkpoint_callback,
    lr_monitor,
    device_stats,
    progress_bar,
    model_summary,
]

# Normal runs train for options.epochs epochs without a profiler.
# A profiling run is cut down to a single epoch and attaches a PyTorchProfiler.
epochs = 1 if profile else options.epochs
profiler = PyTorchProfiler(emit_nvtx=True) if profile else None

# Build the pytorch-lightning Trainer.
# Hardware: a positive GPU count requests exactly that many GPUs (with DDP when
# more than one is used); otherwise every hardware choice is left on "auto".
gpu_count = options.num_gpu
on_gpu = gpu_count > 0
clip_value = options.gradient_clip

trainer = pl.Trainer(
    accelerator="gpu" if on_gpu else "auto",
    devices=gpu_count if on_gpu else "auto",
    strategy="ddp" if gpu_count > 1 else "auto",
    precision="16-mixed" if fp16 else "32-true",

    # A non-positive clip value disables gradient clipping.
    gradient_clip_val=clip_value if clip_value > 0 else None,
    max_epochs=epochs,
    max_time=time_limit,

    logger=logger,
    profiler=profiler,
    callbacks=callbacks
)

# On the master process only: record what this run was trained with by writing
# the options (as JSON) and a copy of the event file into the run's log directory.
if master:
    print(f"Training Version {trainer.logger.version}")

    run_dir = trainer.logger.log_dir
    makedirs(run_dir, exist_ok=True)

    with open(f"{run_dir}/options.json", 'w') as json_file:
        json.dump(options.__dict__, json_file, indent=4)

    # copy2 copies the event file together with its metadata.
    shutil.copy2(options.event_info_file, f"{run_dir}/event.yaml")

# Start training. `checkpoint` is passed as ckpt_path, so it is None unless a
# training state was configured. No dataloaders are passed here —
# presumably the model supplies its own (TODO confirm).
trainer.fit(model, ckpt_path=checkpoint)
# -------------------------------------------------------------------------------------------------------


### Verdict?

**Training** SPA-Net works in Jupyter. Is this useful? Not entirely sure: we want to explore as much as possible, and it takes forever to run through these. I think we'll keep training on NAF, except for small tests here and there...