# Benchmarking FRNN Performance

Explore the fixed-radius nearest-neighbour (FRNN) search techniques described in [this paper](https://www.sciencedirect.com/science/article/pii/0020019077900709).

In [1]:
# Reload imported modules before each cell executes, so edits to imported
# source files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
# System imports
import os
import sys
from time import time as tt
import importlib

# External imports
import matplotlib.pyplot as plt
import scipy as sp
import numpy as np
import pandas as pd
import seaborn as sns
import torch
# from torch_geometric.data import DataLoader

from itertools import chain
from random import shuffle, sample
from scipy.optimize import root_scalar as root

from torch.nn import Linear
import torch.nn.functional as F
import trackml.dataset
from itertools import permutations
import itertools
from sklearn import metrics, decomposition
from torch.utils.checkpoint import checkpoint

import faiss

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load Model and Dataset

Load the Lightning module and set up the model to get the dataset.

**Change `exatrkx_pipeline` below to the location of your own ExaTrkX pipeline checkout.**

In [None]:
# Location of the ExaTrkX TrackML_Example pipeline checkout. Set the
# EXATRKX_PIPELINE environment variable to point at your own copy; the
# original author's path is kept as the fallback so existing setups still work.
exatrkx_pipeline = os.environ.get(
    "EXATRKX_PIPELINE",
    '/global/homes/d/danieltm/ExaTrkX/Tracking-ML-Exa.TrkX/Pipelines/TrackML_Example',
)
# Make the pipeline's packages importable. The guard keeps sys.path free of
# duplicate entries when this cell is re-run.
if exatrkx_pipeline not in sys.path:
    sys.path.append(exatrkx_pipeline)

In [3]:
from LightningModules.Embedding.Models.layerless_embedding import LayerlessEmbedding
from LightningModules.Embedding.utils import graph_intersection, build_edges

## Load model for inference (if data not already saved)

**Change `chkpt_dir` below to your own checkpoint directory.**

In [4]:
# Directory holding the embedding checkpoint. Set the EXATRKX_CHECKPOINT_DIR
# environment variable to use your own; the original author's path is kept as
# the fallback so existing setups still work.
chkpt_dir = os.environ.get(
    "EXATRKX_CHECKPOINT_DIR",
    "/global/cscratch1/sd/danieltm/ExaTrkX/lightning_checkpoints/CodaEmbeddingStudy/pbn07koj",
)
chkpt_file = "last.ckpt"
# Full path handed to load_from_checkpoint.
chkpt_path = os.path.join(chkpt_dir, chkpt_file)

In [5]:
# Restore the LayerlessEmbedding model from the checkpoint file.
model = LayerlessEmbedding.load_from_checkpoint(chkpt_path)

In [6]:
# Override the dataset split stored with the checkpoint.
# NOTE(review): presumably [train, val, test] event counts — confirm against
# LayerlessEmbedding.
model.hparams["train_split"] = [100,10,10]

In [7]:
# Set up the model so the dataset becomes available (model.trainset is read
# further down). NOTE(review): assumes setup(stage="fit") is what populates
# trainset — confirm.
model.setup(stage="fit")

In [8]:
# Move the model to the selected device before running inference.
model = model.to(device)

# FRNN Testing

## Prepare and save input data

In [82]:
# Run the model on the first training event. Gradients are switched off
# because the output is only used as input to the neighbour search.
batch = model.trainset[0].to(device)
with torch.no_grad():
    spatial = model(torch.cat((batch.cell_data, batch.x), dim=-1))

In [87]:
# Attach the model output to the event and cache the event on disk, so later
# sessions can start from the saved file instead of loading the model.
batch.spatial = spatial
torch.save(batch, "example_event.pkl")

## Load saved input data

In [88]:
# Restore the cached event. map_location remaps the stored tensors onto the
# current device, so a file written from a GPU session also loads on a
# CPU-only machine.
# NOTE: torch.load unpickles arbitrary objects — only open files you created.
batch = torch.load("example_event.pkl", map_location=device)

In [89]:
# Model output that was stored on the event before it was saved.
spatial = batch.spatial

## Run FRNN algorithm

In [60]:
import frnn

In [80]:
# Search radius, passed as r= to frnn.frnn_grid_points below.
r = 0.3

In [81]:
%%time
# No grid has been built yet on the first call, so grid=None is passed and
# the grid that comes back is kept for reuse.
# NOTE(review): the recorded run of this cell raised
# "ValueError: for now only grid in 2D/3D is supported".
dists, idxs, nn, grid = frnn.frnn_grid_points(
    points1=spatial.unsqueeze(0),
    points2=spatial.unsqueeze(0),
    lengths1=None,
    lengths2=None,
    K=32,
    r=r,
    grid=None,
    return_nn=False,
    return_sorted=True,
)

ValueError: for now only grid in 2D/3D is supported

In [63]:
%%time
# Same query as before, but reusing the grid returned by an earlier call.
# NOTE(review): relies on `grid` already existing in the kernel.
dists, idxs, nn, grid = frnn.frnn_grid_points(
    points1=spatial.unsqueeze(0),
    points2=spatial.unsqueeze(0),
    lengths1=None,
    lengths2=None,
    K=32,
    r=r,
    grid=grid,
    return_nn=False,
    return_sorted=True,
)

CPU times: user 1.53 ms, sys: 0 ns, total: 1.53 ms
Wall time: 1.09 ms


In [64]:
# Remove the unnecessary batch dimension. Only dim 0 is squeezed: a bare
# squeeze() would also drop the point or neighbour axis whenever it has
# size 1, which breaks the idxs.shape[1] lookup used to build the edge list.
idxs = idxs.squeeze(0)

Convert the array of IDs into an edge list

In [75]:
# ind[i, j] == i: the query-point index for every neighbour slot in idxs.
# Built as a broadcast view (no N x K copy) on the same device as idxs.
ind = torch.arange(idxs.shape[0], device=idxs.device).unsqueeze(1).expand_as(idxs)
# Negative entries are filtered out as "no neighbour" slots.
positive_idxs = idxs >= 0
# Row 0: query point, row 1: neighbour found for it.
edge_list = torch.stack([ind[positive_idxs], idxs[positive_idxs]])

# Remove self-loops
e_spatial = edge_list[:, edge_list[0] != edge_list[1]]

## Accuracy Performance

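Find the truth label of each candidate edge by intersecting the candidates with the bidirectional truth edges.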

In [76]:
# Append the flipped truth edges so both orientations of each edge are present.
e_bidir = torch.cat(
    (batch.layerless_true_edges, batch.layerless_true_edges.flip(0)),
    dim=-1,
)
# NOTE(review): y_cluster is presumably a per-edge truth flag for the returned
# e_spatial (it is summed to count true positives) — confirm in
# graph_intersection.
e_spatial, y_cluster = graph_intersection(e_spatial, e_bidir, using_weights=False)

In [77]:
# Counts behind the efficiency / purity figures.
t = e_bidir.size(1)      # truth edges (both directions)
tp = y_cluster.sum()     # sum of the truth flags on the candidate edges
p = e_spatial.size(1)    # candidate edges

In [78]:
# Efficiency = tp / t (share of truth edges recovered);
# purity = tp / p (share of candidate edges that are true).
print(f'Efficiency: {tp / t}, Purity: {tp / p}')

Efficiency: 0.027304455637931824, Purity: 0.0021169965621083975


# FAISS Compare

In [8]:
import faiss
import torch
import faiss.contrib.torch_utils

In [None]:
# Time a FAISS GPU k-NN search over the same points for comparison with FRNN.
# K is defined here because nothing earlier in the notebook defines it;
# 32 matches the K passed to frnn_grid_points in the FRNN test above.
K = 32

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

start.record()

# NOTE(review): creating the GPU resources is inside the timed region.
res = faiss.StandardGpuResources()
D, I = faiss.knn_gpu(res, spatial, spatial, K)

# Record the end event first, then wait for it: elapsed_time() is only valid
# once both events have completed on the GPU.
end.record()
torch.cuda.synchronize()
time = start.elapsed_time(end)
print(f"Time taken: {time:.5}ms")

# Tweaking FRNN Library

## Generate Data

In [5]:
import frnn

In [26]:
N = 100000  # number of random points
d = 3       # dimensionality of each point
# Uniform random points in [0, 1)^d. A local, fixed-seed generator makes the
# benchmark and the resulting edge counts reproducible without touching the
# global RNG state. This replaces the `spatial` loaded from the saved event.
spatial = torch.rand(N, d, generator=torch.Generator().manual_seed(0)).to(device)

In [30]:
# Parameters passed as r= and K= to frnn.frnn_grid_points below.
r = 0.3
# NOTE(review): K is presumably the cap on neighbours returned per point — confirm.
K = 1000

# Hacking Library

In [31]:
# Time the full FRNN pipeline: neighbour search plus edge-list construction.
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

start.record()

# first time there is no cached grid
dists, idxs, nn, grid = frnn.frnn_grid_points(
    points1=spatial.unsqueeze(0),
    points2=spatial.unsqueeze(0),
    lengths1=None,
    lengths2=None,
    K=K,
    r=r,
    grid=None,
    return_nn=False,
    return_sorted=True,
)

# Remove the unnecessary batch dimension. Only dim 0 is squeezed, so a
# size-1 point or neighbour axis is never dropped by accident.
idxs = idxs.squeeze(0)

# ind[i, j] == i: the query-point index for every neighbour slot in idxs.
# A broadcast view avoids materialising a full N x K index matrix.
ind = torch.arange(idxs.shape[0], device=idxs.device).unsqueeze(1).expand_as(idxs)
# Negative entries are filtered out as "no neighbour" slots.
positive_idxs = idxs >= 0
edge_list = torch.stack([ind[positive_idxs], idxs[positive_idxs]])

# Remove self-loops
e_spatial = edge_list[:, edge_list[0] != edge_list[1]]

# Record the end event first, then wait for it: elapsed_time() is only valid
# once both events have completed on the GPU.
end.record()
torch.cuda.synchronize()
time = start.elapsed_time(end)
print(f"Time taken: {time:.5}ms")

PC2 Grid Count: tensor([[313, 332, 327, 347, 345, 317, 220, 312, 363, 373, 327, 340, 326, 212,
         360, 319, 334, 324, 354, 341, 227, 300, 324, 326, 318, 346, 333, 237,
         332, 363, 327, 335, 320, 324, 233, 336, 342, 352, 363, 364, 325, 228,
         240, 236, 205, 221, 239, 222, 150, 330, 349, 332, 326, 366, 318, 220,
         322, 321, 357, 330, 334, 347, 238, 323, 308, 370, 375, 350, 329, 240,
         326, 348, 321, 340, 363, 310, 207, 367, 368, 342, 303, 372, 322, 224,
         345, 325, 350, 334, 287, 322, 208, 254, 224, 216, 236, 211, 221, 136,
         346, 340, 308, 338, 358, 354, 222, 342, 339, 366, 327, 354, 319, 235,
         351, 327, 322, 341, 326, 314, 238, 343, 356, 315, 333, 312, 344, 273,
         357, 306, 322, 367, 342, 336, 223, 333, 364, 358, 332, 329, 304, 217,
         226, 240, 246, 220, 225, 221, 146, 344, 357, 302, 352, 344, 335, 228,
         361, 338, 343, 327, 337, 358, 231, 377, 358, 341, 349, 352, 325, 240,
         338, 357, 327, 335, 329, 32

TODO: Can we estimate worst-case from grid in 8 dimensions, assign K to that worst-case, then scan in 3 dimensions? This would give the most accurate K, while only running in low-D time.

In [32]:
# Edge-list shape: 2 rows (query point, neighbour) by number of edges kept.
e_spatial.shape

torch.Size([2, 99996773])

In [29]:
# Edge-list shape: 2 rows (query point, neighbour) by number of edges kept.
# NOTE(review): this cell's execution count (29) predates the timing cell
# (31), so the recorded output likely comes from an earlier run with
# different settings — confirm or re-run.
e_spatial.shape

torch.Size([2, 9999900])