### Set autoreloading
The `autoreload` extension reloads imported modules before each cell is executed, so any changes to the package source code are picked up without restarting the kernel.

In [None]:
%load_ext autoreload
%autoreload 2

### Import packages
We'll need the `nugraph` package to load the data and the `pynuml` package to plot it, plus `torch` for the tensor operations further down.

In [None]:
import torch
import nugraph as ng
import pynuml

### Configure data module
Declare a data module. Depending on where you're working, you should edit the data path below to point to a valid data location.

In [None]:
nudata = ng.data.H5DataModule(data_path='/raid/uboone/CHEP2023/enhanced-vertex.gnn.h5', batch_size=64)

### Configure plotting utility
Instantiate the **pynuml** utility for plotting graph objects, which will do the heavy lifting for us here!

In [None]:
plot = pynuml.plot.GraphPlot(planes=nudata.planes,
                             classes=nudata.semantic_classes)

## Plot ground truth labels

### Iterable dataset

First we define an iterator over the test dataset:

In [None]:
test_iter = iter(nudata.test_dataset)

### Retrieve the next graph

This block retrieves a graph from the test dataset. Because `test_iter` is an iterator over the dataset, the block can be executed repeatedly, and each execution steps to the next graph.

In [None]:
data = next(test_iter)
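
The stepping behaviour comes from Python's iterator protocol rather than from anything specific to `nugraph`. As a minimal sketch, the plain list below is a hypothetical stand-in for the test dataset:

```python
# A plain list stands in for the dataset (hypothetical stand-in, not real graphs).
toy_dataset = ["graph 0", "graph 1", "graph 2"]
toy_iter = iter(toy_dataset)

first = next(toy_iter)   # the first call returns the first element
second = next(toy_iter)  # each later call steps forward by one
third = next(toy_iter)

# Once the iterator is exhausted, next() raises StopIteration;
# passing a default value returns that default instead.
after_end = next(toy_iter, None)
print(first, second, third, after_end)
```

To start again from the first graph, re-execute the cell that defines `test_iter`.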

### Plot a single graph

We can now use pynuml's plotting utilities to plot the graph as a figure. Each time you call the above block to retrieve a new graph, you can then re-execute the plotting blocks to re-plot with the new graph.

In [None]:
fig = plot.plot(data, target='instance', how='true', filter='true')

### Visualise the plot

Now that we've plotted the graph as a figure, we can visualise that figure interactively simply by calling it here!

In [None]:
fig

### Combine instance truth across all planes

Concatenate the per-hit instance labels from all planes into a single tensor.

In [None]:
y_instance = torch.cat([data[p].y_instance for p in nudata.planes], dim=0)
y_instance
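
To see what the concatenation does, here is a small sketch with made-up labels. The plane names `u`, `v`, `y` and the plain dictionary are assumptions standing in for `nudata.planes` and the graph object; `-1` marks a background hit.

```python
import torch

# Hypothetical per-plane instance labels; -1 denotes a background hit.
toy_planes = ["u", "v", "y"]
toy_labels = {
    "u": torch.tensor([0, 0, -1]),
    "v": torch.tensor([1, -1]),
    "y": torch.tensor([0, 1, 1, -1]),
}

# Concatenate along the hit dimension, in plane order.
toy_y_instance = torch.cat([toy_labels[p] for p in toy_planes], dim=0)
print(toy_y_instance)        # one label per hit, all planes combined
print(toy_y_instance.shape)  # total number of hits: 3 + 2 + 4
```

The result has one entry per hit, so any per-hit quantity used alongside it has to be concatenated in the same plane order.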

### Number of background hits

In the object condensation paper, $n_i$ is a Boolean array with a value of 1 for background hits (i.e. hits that are not part of any instance, labelled `-1` in `y_instance`), and a value of 0 otherwise. Similarly, $N_B$ is the total *number* of background hits.

In [None]:
n_i = (y_instance == -1)
N_B = n_i.sum()
n_i, N_B
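
A small sketch on made-up labels (the values are illustrative only) shows how the mask and the count relate:

```python
import torch

# Made-up instance labels; -1 denotes a background hit.
toy_y_instance = torch.tensor([0, 0, -1, 1, -1, 0, 1, 1, -1])

toy_n_i = (toy_y_instance == -1)  # True where the hit is background
toy_N_B = toy_n_i.sum()           # True counts as 1, so the sum is the count
print(toy_n_i)
print(toy_N_B)
```

Summing a Boolean tensor yields an integer tensor, so `toy_N_B` can be used directly as a normalisation factor.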

### Number of objects

Since the instance indices ascend from 0 and background hits are labelled `-1`, the number of instances in the event, $K$, is simply the maximum instance index plus 1.

In [None]:
# convert to a Python int so K can be used directly as a size
K = int(y_instance.max()) + 1
K
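
The "maximum plus one" rule relies on the indices being contiguous from 0. The sketch below, on made-up labels, cross-checks it against a count of distinct labels and shows what happens for an event with no instances at all:

```python
import torch

# Made-up instance labels; -1 denotes a background hit.
toy_y_instance = torch.tensor([0, 0, -1, 1, -1, 0, 1, 1, -1])
toy_K = int(toy_y_instance.max()) + 1  # largest index is 1, so K = 2

# Cross-check by counting distinct non-background labels; this agrees with
# max + 1 only when the indices are contiguous from 0.
n_unique = toy_y_instance[toy_y_instance >= 0].unique().numel()

# An event containing only background hits has maximum -1, giving K = 0.
only_background = torch.tensor([-1, -1, -1])
K_empty = int(only_background.max()) + 1
print(toy_K, n_unique, K_empty)
```

An event with $K = 0$ would make the $\frac{1}{K}$ factor in the loss below undefined, so it may be worth handling that case explicitly.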

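### Cluster assignment matrix $M_{ik}$

The cluster assignment matrix is a one-hot encoding of instance membership: $M_{ik} = 1$ if hit $i$ belongs to instance $k$, and 0 otherwise. Background hits belong to no instance, so their rows stay all zero.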

In [None]:
M_ik = torch.zeros((y_instance.size(0), K)).long()
M_ik[~n_i,:] = torch.nn.functional.one_hot(y_instance[~n_i], num_classes=K)
M_ik
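
As an illustration, here is the same construction on five made-up hits. Background hits are masked out before the call because `one_hot` does not accept negative class values.

```python
import torch

# Made-up labels: hits 0 and 1 belong to instance 0, hit 3 to instance 1,
# and hits 2 and 4 are background.
toy_y_instance = torch.tensor([0, 0, -1, 1, -1])
toy_n_i = (toy_y_instance == -1)
toy_K = int(toy_y_instance.max()) + 1

# Start from all zeros, then fill in one-hot rows for non-background hits only.
toy_M_ik = torch.zeros((toy_y_instance.size(0), toy_K)).long()
toy_M_ik[~toy_n_i, :] = torch.nn.functional.one_hot(
    toy_y_instance[~toy_n_i], num_classes=toy_K
)
print(toy_M_ik)
# tensor([[1, 0],
#         [1, 0],
#         [0, 0],
#         [0, 1],
#         [0, 0]])
```

Each row sums to 1 for a hit that belongs to an instance and to 0 for a background hit.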

### $\beta_{i}$

The $\beta_{i}$ array will eventually come out of the decoder's convolution block itself, but for testing purposes we can just create a fake one here with random values in $[0, 1)$.

In [None]:
beta_i = torch.rand(y_instance.size(0))
beta_i

### $\beta_{\alpha k}$

The $\beta_{\alpha k}$ tensor holds, for each instance $k$, the largest $\beta_{i}$ among the hits assigned to that instance. The hit that attains this maximum has index $\alpha_k$.

In [None]:
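# max over dim 0 returns (values, indices): the largest beta per instance and the hit holding it
beta_ak, alpha_k = (beta_i[:, None] * M_ik).max(dim=0)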
beta_ak
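
A worked sketch with made-up $\beta$ values may help here. Note that `max(dim=0)` returns both the maximum values and the indices at which they occur, and that multiplying by $M_{ik}$ zeroes out every hit outside its own instance, including background hits.

```python
import torch

# Made-up labels and beta values (illustrative only).
toy_y_instance = torch.tensor([0, 0, -1, 1, -1])
toy_beta_i = torch.tensor([0.2, 0.9, 0.7, 0.4, 0.1])

toy_n_i = (toy_y_instance == -1)
toy_K = int(toy_y_instance.max()) + 1
toy_M_ik = torch.zeros((toy_y_instance.size(0), toy_K)).long()
toy_M_ik[~toy_n_i, :] = torch.nn.functional.one_hot(
    toy_y_instance[~toy_n_i], num_classes=toy_K
)

# Broadcasting beta over the instance columns leaves each hit's beta only in
# the column of its own instance; everything else becomes 0.
masked = toy_beta_i[:, None] * toy_M_ik

# Column-wise maximum: the values are beta_ak, the indices are alpha_k.
toy_beta_ak, toy_alpha_k = masked.max(dim=0)
print(toy_beta_ak)  # tensor([0.9000, 0.4000])
print(toy_alpha_k)  # tensor([1, 3])
```

The background hit with $\beta = 0.7$ does not contribute, because its row in the assignment matrix is all zero. Masking with zeros works here because the $\beta$ values are non-negative.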

### Exercise: define background loss function

Equation 7 in the [object condensation paper](https://arxiv.org/abs/2002.03605) defines the background loss term as

$$L_{\beta} = \frac{1}{K} \sum_{k} (1 - \beta_{\alpha k}) + s_{B} \frac{1}{N_{B}} \sum_{i=1}^{N} n_{i} \beta_{i}$$

where $s_{B}$ is a hyperparameter that sets the strength of the background suppression.

In [None]:
def background_loss(beta, y_instance):
    # implement the code here!
    return 0

# call the function and get a return value
background_loss(beta_i, y_instance)
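
If you want something to compare your answer against, the following is one possible sketch that chains together the steps from the cells above. It is not the reference solution. The function name `background_loss_sketch` and the `s_B` keyword argument (with an arbitrary default of 1) are assumptions, since the exercise signature does not include $s_{B}$.

```python
import torch

def background_loss_sketch(beta, y_instance, s_B=1.0):
    """One possible implementation of the background loss term.

    s_B is an assumed keyword argument; the default of 1 is arbitrary.
    """
    n_i = (y_instance == -1)               # background mask
    N_B = n_i.sum()                        # number of background hits
    K = int(y_instance.max()) + 1          # number of instances

    # One-hot cluster assignment matrix; background rows stay zero.
    M_ik = torch.zeros((y_instance.size(0), K), dtype=torch.long)
    M_ik[~n_i, :] = torch.nn.functional.one_hot(y_instance[~n_i], num_classes=K)

    # Largest beta per instance.
    beta_ak = (beta[:, None] * M_ik).max(dim=0).values

    object_term = (1 - beta_ak).mean()           # (1/K) sum_k (1 - beta_ak)
    noise_term = s_B * (n_i * beta).sum() / N_B  # s_B (1/N_B) sum_i n_i beta_i
    return object_term + noise_term

# Made-up inputs (illustrative only).
toy_y = torch.tensor([0, 0, -1, 1, -1])
toy_beta = torch.tensor([0.2, 0.9, 0.7, 0.4, 0.1])
loss = background_loss_sketch(toy_beta, toy_y)
print(loss)  # object term 0.35 + noise term 0.40
```

The sketch has no guards for events with no instances ($K = 0$) or no background hits ($N_{B} = 0$), both of which leave one of the two terms undefined.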