### Set GPU device
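Set the CUDA device to restrict training to a single GPU. This step is essential in a multi-GPU environment such as the Heimdall cluster. Do this before importing `torch` or any other ML-related package: the CUDA runtime reads `CUDA_VISIBLE_DEVICES` when it initialises, and changes made after that point are ignored.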

In [None]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
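
As a guard against running this cell too late, the sketch below wraps the same assignment in a small helper that warns when `torch` is already loaded. The name `select_gpu` is hypothetical and not part of `nugraph`. Checking `sys.modules` is only a heuristic, since the setting can still take effect if CUDA has not been initialised yet.

```python
import os
import sys


def select_gpu(index=0):
    """Expose a single GPU to this process (hypothetical helper)."""
    if "torch" in sys.modules:
        # Heuristic only: if CUDA was already initialised, the new value is ignored.
        print("warning: torch is already imported; restart the kernel to be safe")
    os.environ["CUDA_VISIBLE_DEVICES"] = str(index)
    return os.environ["CUDA_VISIBLE_DEVICES"]


visible = select_gpu(0)
```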

### Import packages
We need to import the `nugraph` and `pytorch_lightning` packages in order to train.

In [None]:
import nugraph as ng
import pytorch_lightning as pl

### Configure data module
Declare a data module. Edit the data path below so that it points to a valid data file on the system you're working on.

In [None]:
nudata = ng.data.H5DataModule(data_path='/raid/uboone/CHEP2023/enhanced.gnn.h5', batch_size=64)
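
Because the data path differs between systems, it can help to check it before constructing the data module. The following is a minimal sketch, assuming a hypothetical helper `resolve_data_path` and a hypothetical environment variable `NUGRAPH_DATA` for overriding the default. Neither is provided by `nugraph`.

```python
import os


def resolve_data_path(default, env_var="NUGRAPH_DATA"):
    """Return a data path, preferring an environment override (hypothetical helper)."""
    # NUGRAPH_DATA is an illustrative variable name, not something nugraph reads.
    path = os.environ.get(env_var, default)
    if not os.path.isfile(path):
        raise FileNotFoundError(f"no data file at {path!r}; edit the path or set {env_var}")
    return path
```

The result could then be passed as `data_path` when declaring the data module, so that a wrong path fails immediately with a clear message.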

### Configure network
Declare a model. You can edit the arguments below to change the network configuration.

In [None]:
nugraph = ng.models.NuGraph2(
    in_features=4,
    node_features=64,
    edge_features=16,
    sp_features=16,
    planes=nudata.planes,
    classes=nudata.classes,
    num_iters=5,
    event_head=False,
    semantic_head=True,
    filter_head=False,
    checkpoint=True,
    lr=0.001)

### Configure logger and callbacks
Declare a TensorBoard logger and define its output directory so we can monitor network training. Also define a callback to track how the learning rate evolves.

In [None]:
logger = pl.loggers.TensorBoardLogger(save_dir='/raid/vhewes/logs', name='semantic', version='64_16_16')
callbacks = [pl.callbacks.LearningRateMonitor(logging_interval='step')]
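
The version string `'64_16_16'` appears to mirror the `node_features`, `edge_features` and `sp_features` values used for the model, although the notebook does not state this. Under that assumption, the sketch below derives the tag from the hyperparameters and shows where the logs would be expected to land, assuming the logger writes to `save_dir/name/version` when the version is a string. Both function names are hypothetical.

```python
import os


def version_tag(node_features, edge_features, sp_features):
    """Build a version label from the feature sizes (assumed naming convention)."""
    return f"{node_features}_{edge_features}_{sp_features}"


def expected_log_dir(save_dir, name, version):
    """Directory layout assumed for a string version: save_dir/name/version."""
    return os.path.join(save_dir, name, version)


tag = version_tag(64, 16, 16)
log_dir = expected_log_dir('/raid/vhewes/logs', 'semantic', tag)
```

Deriving the tag this way keeps the log directory in step with the network configuration when the feature sizes are changed.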

### Declare trainer and run training
Instantiate a PyTorch Lightning trainer, then run the training stage. Each epoch iterates over all batches in the train and validation datasets to optimise the model parameters, and writes output metrics to TensorBoard.

In [None]:
trainer = pl.Trainer(max_epochs=80,
                     logger=logger,
                     callbacks=callbacks)
trainer.fit(nugraph, datamodule=nudata)
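
To get a feel for the length of the run, and for how many points a step-level learning rate log will contain, the number of optimiser steps can be estimated from the batch size and epoch count used above. This is a rough sketch: `n_train_samples` is a hypothetical input, since the dataset size is not given here, and the estimate assumes a single device, no gradient accumulation, and that the final partial batch of each epoch is kept.

```python
import math


def total_training_steps(n_train_samples, batch_size=64, max_epochs=80):
    """Estimate optimiser steps for a full run (illustrative arithmetic only)."""
    # One step per batch; the last batch of an epoch may be smaller than batch_size.
    steps_per_epoch = math.ceil(n_train_samples / batch_size)
    return steps_per_epoch * max_epochs
```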