### Set autoreloading
The `autoreload` extension reloads imported modules before each cell is executed, so changes to package source code take effect without restarting the kernel.

In [None]:
%load_ext autoreload
%autoreload 2

### Import packages
Import `torch` along with the `nugraph` and `pynuml` packages, which provide the data module, loss function, and plotting utility used below.

In [None]:
import torch
import nugraph as ng
import pynuml

### Configure data module
Declare a data module. Depending on where you're working, you should edit the data path below to point to a valid data location.

In [None]:
nudata = ng.data.H5DataModule(data_path='/raid/uboone/CHEP2023/enhanced-vertex.gnn.h5', batch_size=64)
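Since the data path differs between machines, it can help to check the location before constructing the data module. The following is a minimal sketch, not part of `nugraph`: the helper `resolve_data_path` and the environment variable name `NUGRAPH_DATA` are hypothetical names introduced here for illustration.

```python
import os
from pathlib import Path


def resolve_data_path(default, env_var="NUGRAPH_DATA"):
    """Return a usable data file path as a string.

    `env_var` is a hypothetical environment variable used for illustration;
    if it is set, its value overrides `default`.
    """
    candidate = Path(os.environ.get(env_var, default)).expanduser()
    # Fail early with a readable message instead of deep inside the loader.
    if not candidate.is_file():
        raise FileNotFoundError(
            f"No data file found at '{candidate}'. "
            f"Edit the path or set the {env_var} environment variable."
        )
    return str(candidate)
```

The returned string could then be passed as the `data_path` argument in the cell above, so that a wrong location fails immediately with a clear error.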

### Configure plotting utility
Instantiate the **pynuml** utility for plotting graph objects, which will do the heavy lifting for us here!

In [None]:
plot = pynuml.plot.GraphPlot(planes=nudata.planes,
                             classes=nudata.semantic_classes)

## Plot ground truth labels

### Iterable dataset

First we define an iterator over the test dataset:

In [None]:
test_iter = iter(nudata.test_dataset)

### Retrieve the next graph

This block retrieves a graph from the testing dataset. Since we defined `test_iter` as an iterator over the dataset, the following block can be executed multiple times, and each time it's executed, it will step to the next graph in the dataset.

In [None]:
data = next(test_iter)
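The stepping behaviour described above is standard Python iterator semantics. As a small sketch, the list `toy_dataset` below is a hypothetical stand-in for the real test dataset. It shows that each `next` call advances by one item, that a default value can be supplied to avoid a `StopIteration` error at the end, and that calling `iter` again restarts from the first item.

```python
# Hypothetical stand-in for the test dataset: any sequence behaves the same way.
toy_dataset = ["graph_0", "graph_1", "graph_2"]

toy_iter = iter(toy_dataset)

# Each call to next() steps forward by exactly one item.
first = next(toy_iter)
second = next(toy_iter)
third = next(toy_iter)

# Once the iterator is exhausted, a bare next() raises StopIteration.
# Passing a default value returns that default instead.
exhausted = next(toy_iter, None)

# Creating a fresh iterator starts again from the first item.
toy_iter = iter(toy_dataset)
restarted = next(toy_iter)
```

In the notebook, this means re-running the cell that defines `test_iter` returns to the first graph of the test dataset.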

### Instantiate loss function and calculate value

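Instantiate the loss function, then define random placeholder values for the inputs `x` and `beta`, and pass them along with the true instance labels `y` into the loss function to calculate the loss.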

In [None]:
loss_func = ng.util.ObjCondensationLoss()

In [None]:
y = torch.cat([data[p].y_instance for p in nudata.planes], dim=0)
x = torch.rand(y.size(0), 3)
beta = torch.rand(y.size(0))
loss_func(x, beta, y)