### Set autoreloading
The `autoreload` extension reloads imported modules when their source changes. Edits to packages such as `nugraph` therefore take effect without restarting the kernel.

In [None]:
# Enable IPython's autoreload extension so edits to imported package source
# (e.g. nugraph) are picked up without restarting the kernel.
%load_ext autoreload
# Mode 2: reload all modules (except those excluded via %aimport) before
# executing each cell.
%autoreload 2

### Import packages
We need the `pytorch_lightning` and `nugraph` packages to build and train the model.

In [None]:
import os
from pathlib import Path
import nugraph as ng
import pytorch_lightning as pl

### Configure data module
Declare a data module. As written, `H5DataModule()` is called with no arguments, so its default settings apply. Depending on where you're working, you may need to pass arguments below that point it at a valid data location.

In [None]:
# NOTE(review): H5DataModule is constructed with no arguments, so whatever
# defaults it defines apply here (presumably including the input data
# location — confirm against its signature and pass arguments if needed).
# The model cell below reads semantic_classes / event_classes from this object.
nudata = ng.data.H5DataModule()

### Configure network
Declare a model. You can edit the arguments below to change the network configuration.

In [None]:
# NuGraph3 hyperparameters, grouped so they are easy to find and tweak.
# The class lists are taken from the data module so the model's outputs
# line up with the labels in the dataset.
nugraph_config = dict(
    # feature dimensions
    in_features=5,
    hit_features=128,
    nexus_features=32,
    instance_features=32,
    interaction_features=32,
    # class labels supplied by the data module
    semantic_classes=nudata.semantic_classes,
    event_classes=nudata.event_classes,
    # iteration count
    num_iters=5,
    # which output heads are enabled
    event_head=False,
    semantic_head=True,
    filter_head=True,
    vertex_head=False,
    instance_head=True,
    # checkpointing flag and optimiser learning rate
    use_checkpointing=True,
    lr=0.001,
)
nugraph = ng.models.NuGraph3(**nugraph_config)

### Configure logger and callbacks
Declare a Weights & Biases (W&B) logger and a local output directory under `$NUGRAPH_LOG`, so we can monitor network training. Also define callbacks that monitor learning-rate evolution and checkpoint the model with the lowest `loss/val`.

In [None]:
# Run name: used both as the log sub-directory and as the W&B run name,
# so changing it here keeps the two in sync.
name = "test"
# Fail early with an actionable message instead of a bare KeyError.
if "NUGRAPH_LOG" not in os.environ:
    raise RuntimeError("Set the NUGRAPH_LOG environment variable to the directory "
                       "where training logs should be written.")
logdir = Path(os.environ["NUGRAPH_LOG"]) / name
logdir.mkdir(parents=True, exist_ok=True)
# log_model="all" asks the W&B logger to upload every checkpoint saved during
# training as an artifact.
logger = pl.loggers.WandbLogger(save_dir=logdir, project="nugraph3", name=name,
                                log_model="all")
callbacks = [
    # Record the learning rate at every optimisation step.
    pl.callbacks.LearningRateMonitor(logging_interval="step"),
    # Keep the checkpoint with the lowest "loss/val".
    # NOTE(review): assumes the model logs a metric named "loss/val" — confirm.
    pl.callbacks.ModelCheckpoint(monitor="loss/val", mode="min"),
]

### Declare trainer and run training
First we set the training device. To train with a GPU, pass an integer to `ng.util.configure_device()`; otherwise, it defaults to CPU training. We then instantiate a PyTorch Lightning trainer and run the training stage. Training iterates over all batches in the train and validation datasets to optimise model parameters, writing output metrics to Weights & Biases. Finally, we evaluate the best checkpoint on the test dataset.

In [None]:
# Choose the accelerator and device selection for this machine.
accelerator, devices = ng.util.configure_device()
trainer = pl.Trainer(accelerator=accelerator,
                     devices=devices,
                     max_epochs=80,  # upper bound on training epochs
                     logger=logger,
                     callbacks=callbacks)
trainer.fit(nugraph, datamodule=nudata)
# Evaluate the best checkpoint tracked by the ModelCheckpoint callback rather
# than the final-epoch weights. Passing ckpt_path explicitly makes that choice
# visible. It matches what recent Lightning versions do implicitly when no
# model is given, and silences the warning they emit in that case.
trainer.test(datamodule=nudata, ckpt_path="best")