### Set autoreloading
This extension automatically reloads imported modules before each cell is executed, so changes to package source code take effect without restarting the kernel.

In [2]:
%load_ext autoreload
%autoreload 2

### Import packages
We'll need `torch`, along with the `nugraph` and `pynuml` packages, to load the data and plot it.

In [3]:
import torch
import nugraph as ng
import pynuml

### Configure data module
Declare a data module. Depending on where you're working, you should edit the data path below to point to a valid data location.

In [4]:
nudata = ng.data.H5DataModule(data_path='/raid/uboone/CHEP2023/enhanced-vertex.gnn.h5', batch_size=64)

### Configure plotting utility
Instantiate the **pynuml** utility for plotting graph objects, which will do the heavy lifting for us here!

In [5]:
plot = pynuml.plot.GraphPlot(planes=nudata.planes,
                             classes=nudata.semantic_classes)

## Plot ground truth labels

### Iterable dataset

First we define an iterator over the test dataset:

In [6]:
test_iter = iter(nudata.test_dataset)

### Retrieve the next graph

This block retrieves a graph from the test dataset. Because `test_iter` is an iterator over the dataset, the block can be executed repeatedly, and each execution steps to the next graph in the dataset.

In [7]:
data = next(test_iter)
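As a minimal sketch of the stepping behaviour described above, the following uses a plain list as a stand-in for the test dataset (the list and its string values are hypothetical placeholders, not real graphs). It also shows that passing a default to `next` avoids a `StopIteration` error once the iterator is exhausted.

```python
# A plain list stands in for the test dataset (hypothetical placeholder values).
toy_dataset = ["graph_0", "graph_1", "graph_2"]
toy_iter = iter(toy_dataset)

first = next(toy_iter)    # first element of the dataset
second = next(toy_iter)   # calling next again steps to the following element

# Supplying a default returns it instead of raising StopIteration
# once every element has been consumed.
third = next(toy_iter, None)
exhausted = next(toy_iter, None)

print(first, second, third, exhausted)  # graph_0 graph_1 graph_2 None
```

To start again from the first graph, re-run the cell that creates the iterator.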

### Plot a single graph

We can now use pynuml's plotting utilities to plot the graph as a figure. Each time you call the above block to retrieve a new graph, you can then re-execute the plotting blocks to re-plot with the new graph.

In [8]:
fig = plot.plot(data, target='instance', how='true', filter='true')

### Visualise the plot

Now that we've plotted the graph as a figure, we can visualise that figure interactively simply by calling it here!

In [17]:
fig

[Output: interactive `FigureWidget` with one scatter trace per plane (`u`, `v`, ...; output truncated), wire on the x axis, time on the y axis, markers coloured by instance label]

### Combine instance truth across all planes

Concatenate the instance truth labels of the hits from all planes into a single tensor.

In [10]:
y_instance = torch.cat([data[p].y_instance for p in nudata.planes], dim=0)
y_instance

tensor([-1, -1, -1,  ..., -1, -1, -1])
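To make the concatenation concrete, here is a small sketch on made-up data. The plane names and label values are hypothetical, and a plain dictionary stands in for the graph object.

```python
import torch

# Hypothetical per-plane instance labels; -1 marks background hits.
toy_planes = ["u", "v", "y"]
toy_labels = {
    "u": torch.tensor([-1, 0, 0]),
    "v": torch.tensor([1, -1]),
    "y": torch.tensor([0, 1, 1, -1]),
}

# Concatenate along the hit dimension, following the order of toy_planes.
toy_y_instance = torch.cat([toy_labels[p] for p in toy_planes], dim=0)
print(toy_y_instance.tolist())  # [-1, 0, 0, 1, -1, 0, 1, 1, -1]
```

The result has one entry per hit, so its length is the sum of the per-plane hit counts.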

### Number of background hits

In the object condensation paper, $n_i$ is a Boolean array with a value of 1 for background hits (i.e. hits that are not part of any instance) and 0 otherwise. Similarly, $N_B$ is the total *number* of background hits. In this dataset, background hits carry the instance label -1.

In [11]:
n_i = (y_instance == -1)
N_B = n_i.sum()
n_i, N_B

(tensor([True, True, True,  ..., True, True, True]), tensor(1102))
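The same mask and count can be checked by hand on a short, made-up label tensor (the values below are hypothetical):

```python
import torch

# Hypothetical instance labels for nine hits; -1 marks background.
toy_y_instance = torch.tensor([-1, 0, 0, 1, -1, 0, 1, 1, -1])

toy_n_i = (toy_y_instance == -1)   # Boolean background mask
toy_N_B = toy_n_i.sum()            # True counts as 1 when summed

print(toy_n_i.tolist())  # [True, False, False, False, True, False, False, False, True]
print(toy_N_B.item())    # 3
```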

### Number of objects

Since the instance indices are ascending and begin at 0, the number of instances in the event, $K$, is simply the maximum instance index plus 1.

In [12]:
K = y_instance.max() + 1
K

tensor(2)
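The "maximum plus 1" rule relies on the labels being contiguous. The sketch below, on hypothetical labels, cross-checks it against a count of distinct non-background labels and shows how the two disagree when an index is skipped.

```python
import torch

# Hypothetical contiguous labels: instances 0 and 1, plus background (-1).
toy_y_instance = torch.tensor([-1, 0, 0, 1, -1, 0, 1, 1, -1])
K_from_max = int(toy_y_instance.max()) + 1
K_from_unique = torch.unique(toy_y_instance[toy_y_instance != -1]).numel()
print(K_from_max, K_from_unique)  # 2 2

# Hypothetical labels with a gap: index 1 is never used.
gappy = torch.tensor([-1, 0, 2, 2])
gappy_from_max = int(gappy.max()) + 1
gappy_from_unique = torch.unique(gappy[gappy != -1]).numel()
print(gappy_from_max, gappy_from_unique)  # 3 2
```

If the two counts differ for a real event, the labels are not contiguous and the shortcut overestimates $K$.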

### Cluster assignment matrix $M_{ik}$

The cluster assignment matrix maps each hit onto its instance index using a one-hot tensor. Rows for background hits are left as all zeros.

In [13]:
M_ik = torch.zeros((y_instance.size(0), K)).long()
M_ik[~n_i,:] = torch.nn.functional.one_hot(y_instance[~n_i], num_classes=K)
M_ik

tensor([[0, 0],
        [0, 0],
        [0, 0],
        ...,
        [0, 0],
        [0, 0],
        [0, 0]])
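Because the real output above is dominated by background rows, a small made-up example shows the one-hot structure more clearly. The label values are hypothetical.

```python
import torch
import torch.nn.functional as F

# Hypothetical labels for five hits: two background, one in instance 0, two in instance 1.
toy_y_instance = torch.tensor([-1, 0, 1, 1, -1])
toy_n_i = (toy_y_instance == -1)
toy_K = int(toy_y_instance.max()) + 1

# Start from all zeros so that background rows stay empty,
# then fill the rows of the non-background hits with their one-hot encoding.
toy_M_ik = torch.zeros((toy_y_instance.size(0), toy_K), dtype=torch.long)
toy_M_ik[~toy_n_i] = F.one_hot(toy_y_instance[~toy_n_i], num_classes=toy_K)

print(toy_M_ik.tolist())  # [[0, 0], [1, 0], [0, 1], [0, 1], [0, 0]]
```

Each non-background row sums to 1, and each column sums to the number of hits in that instance.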

### $\beta_{i}$

The $\beta_{i}$ array will come out of the decoder's convolution block itself, but for the purposes of testing we can create a fake one here with random values.

In [14]:
beta_i = torch.rand(y_instance.size(0))
beta_i

tensor([0.3789, 0.9288, 0.4104,  ..., 0.9905, 0.2571, 0.6387])

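### $\beta_{\alpha k}$

The $\beta_{\alpha k}$ tensor holds, for each instance $k$, the largest $\beta_{i}$ among the hits assigned to that instance. Note that `max(dim=0)` returns both the maximum values and the indices of the hits that attain them.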

In [15]:
beta_ak = (beta_i[:,None] * M_ik).max(dim=0)
beta_ak

torch.return_types.max(
values=tensor([0.9927, 0.9982]),
indices=tensor([1116,  319]))
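The following sketch repeats the calculation on made-up values (all numbers are hypothetical) and unpacks the two fields of the result. The background hit with the largest $\beta_{i}$ is ignored, because its row in the assignment matrix is all zeros.

```python
import torch

# Hypothetical beta values and assignment matrix for five hits and two instances.
# Hits 0 and 4 are background, hit 1 belongs to instance 0, hits 2 and 3 to instance 1.
toy_beta_i = torch.tensor([0.9, 0.2, 0.7, 0.4, 0.1])
toy_M_ik = torch.tensor([[0, 0], [1, 0], [0, 1], [0, 1], [0, 0]])

# Broadcasting gives a (hits, instances) tensor that is zero outside each instance.
result = (toy_beta_i[:, None] * toy_M_ik).max(dim=0)

toy_beta_ak = result.values    # largest beta per instance
toy_alpha_k = result.indices   # index of the hit that attains it

print(toy_alpha_k.tolist())  # [1, 2]
```

Only the `.values` field is needed when the result enters a loss term; the `.indices` field identifies the hit with the largest $\beta_{i}$ in each instance.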

### Exercise: define the background loss function

Equation 7 in the [object condensation paper](https://arxiv.org/abs/2002.03605) defines the background loss term as $L_{\beta} = \frac{1}{K} \sum_{k} (1 - \beta_{\alpha k}) + s_{B} \frac{1}{N_{B}} \sum_{i}^{N} n_{i} \beta_{i}$, where $s_{B}$ is a hyperparameter that sets the strength of the background suppression term.

In [16]:
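import torch

def background_loss(y_instance, beta_i, s_B=1.0):
    # background hits carry the instance label -1
    n_i = (y_instance == -1)
    N_B = n_i.sum()
    # number of instances, computed here rather than read from a global variable
    K = int(y_instance.max()) + 1
    # one-hot cluster assignment matrix; background rows stay all-zero
    M_ik = torch.zeros((y_instance.size(0), K)).long()
    M_ik[~n_i,:] = torch.nn.functional.one_hot(y_instance[~n_i], num_classes=K)
    # largest beta among the hits of each instance
    beta_ak = (beta_i[:,None] * M_ik).max(dim=0).values
    # first term: mean of (1 - beta_ak) over instances
    L_beta_1 = (1 - beta_ak).sum() / K
    # second term: mean beta over background hits, scaled by s_B
    L_beta_2 = (s_B / N_B) * (n_i * beta_i).sum()
    return L_beta_1 + L_beta_2

# call the function and get a return value
background_loss(y_instance, beta_i)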

tensor(0.4914)
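One way to sanity-check a vectorised solution is to compare it against a slow, loop-based evaluation of the same formula on an input small enough to work out by hand. The sketch below is such a reference; the function name `background_loss_reference` and the toy values are hypothetical, and it assumes that every instance index from 0 to $K-1$ has at least one hit and that at least one background hit exists.

```python
import torch

def background_loss_reference(y_instance, beta_i, s_B=1.0):
    """Loop-based evaluation of L_beta, written for readability rather than speed."""
    labels = y_instance.tolist()
    betas = beta_i.tolist()
    K = max(labels) + 1

    # First term: average of (1 - largest beta in instance k) over all instances.
    first_term = 0.0
    for k in range(K):
        beta_alpha_k = max(b for l, b in zip(labels, betas) if l == k)
        first_term += 1.0 - beta_alpha_k
    first_term /= K

    # Second term: s_B times the mean beta over background hits (label -1).
    background = [b for l, b in zip(labels, betas) if l == -1]
    second_term = s_B * sum(background) / len(background)

    return first_term + second_term

# Hypothetical toy event: hits 0 and 4 are background.
toy_y_instance = torch.tensor([-1, 0, 1, 1, -1])
toy_beta_i = torch.tensor([0.9, 0.2, 0.7, 0.4, 0.1])

toy_loss = background_loss_reference(toy_y_instance, toy_beta_i)
print(round(toy_loss, 4))  # 1.05
```

By hand, the first term is $((1 - 0.2) + (1 - 0.7)) / 2 = 0.55$ and the second is $(0.9 + 0.1) / 2 = 0.5$, giving $1.05$. A vectorised implementation should return the same value for the same inputs, up to floating-point rounding.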