# Benchmarking FRNN Performance

Explore the techniques described in [this paper](https://reader.elsevier.com/reader/sd/pii/0020019077900709?token=E45C0E1870EA26C21C1F149B6090CE4630A51269D324BE1206B7BF2764FB48B2DDC93F4B86FBFBD8CBDED63B15BBC6DA&originRegion=us-east-1&originCreation=20210428165528).

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
# System imports
import os
import sys
from time import time as tt
import importlib

# External imports
import matplotlib.pyplot as plt
import scipy as sp
import numpy as np
import pandas as pd
import seaborn as sns
import torch
# from torch_geometric.data import DataLoader

from itertools import chain
from random import shuffle, sample
from scipy.optimize import root_scalar as root

from torch.nn import Linear
import torch.nn.functional as F
import trackml.dataset
from itertools import permutations
import itertools
from sklearn import metrics, decomposition
from torch.utils.checkpoint import checkpoint

import faiss

device = "cuda" if torch.cuda.is_available() else "cpu"

exatrkx_pipeline = '/global/homes/d/danieltm/ExaTrkX/Tracking-ML-Exa.TrkX/Pipelines/TrackML_Example'
sys.path.append(exatrkx_pipeline)

from LightningModules.Embedding.utils import graph_intersection, build_edges

# Load Model and Dataset

Load the lightning module and setup the model to get the dataset

CHANGE THE EXATRKX PIPELINE LOCATION TO YOUR PARTICULAR LOCATION

In [None]:
exatrkx_pipeline = '/global/homes/d/danieltm/ExaTrkX/Tracking-ML-Exa.TrkX/Pipelines/TrackML_Example'
sys.path.append(exatrkx_pipeline)

[autoreload of torch.serialization failed: Traceback (most recent call last):
  File "/global/homes/d/danieltm/.conda/envs/frnn-test/lib/python3.8/site-packages/IPython/extensions/autoreload.py", line 245, in check
    superreload(m, reload, self.old_objects)
  File "/global/homes/d/danieltm/.conda/envs/frnn-test/lib/python3.8/site-packages/IPython/extensions/autoreload.py", line 394, in superreload
    module = reload(module)
  File "/global/homes/d/danieltm/.conda/envs/frnn-test/lib/python3.8/imp.py", line 314, in reload
    return importlib.reload(module)
  File "/global/homes/d/danieltm/.conda/envs/frnn-test/lib/python3.8/importlib/__init__.py", line 169, in reload
    _bootstrap._exec(spec, module)
  File "<frozen importlib._bootstrap>", line 604, in _exec
  File "<frozen importlib._bootstrap_external>", line 848, in exec_module
  File "<frozen importlib._bootstrap>", line 219, in _call_with_frames_removed
  File "/global/homes/d/danieltm/.conda/envs/frnn-test/lib/python3.8/site-p

In [None]:
from LightningModules.Embedding.Models.layerless_embedding import LayerlessEmbedding
from LightningModules.Embedding.utils import graph_intersection, build_edges

## Load model for inference (if data not already saved)

CHANGE THE CHECKPOINT DIRECTORY TO YOUR PARTICULAR LOCATION

In [4]:
chkpt_dir = "/global/cscratch1/sd/danieltm/ExaTrkX/lightning_checkpoints/CodaEmbeddingStudy/pbn07koj"
chkpt_file = "last.ckpt"
chkpt_path = os.path.join(chkpt_dir, chkpt_file)

In [5]:
model = LayerlessEmbedding.load_from_checkpoint(chkpt_path)

In [6]:
model.hparams["train_split"] = [100,10,10]

In [7]:
model.setup(stage="fit")

In [8]:
model = model.to(device)

# FRNN Testing

## Prepare and save input data

In [82]:
batch = model.trainset[0].to(device)
with torch.no_grad():
    spatial = model(torch.cat([batch.cell_data, batch.x], axis=-1))

In [87]:
batch.spatial = spatial
torch.save(batch, "example_event.pkl")

## Load saved input data

In [18]:
batch = torch.load("example_event.pkl")

In [34]:
spatial = batch.spatial

## Run FRNN algorithm

In [36]:
import frnn

In [126]:
r = 2.
K = 1000

In [127]:
%%time
# first time there is no cached grid
dists, idxs, nn, grid = frnn.frnn_grid_points(
    points1=spatial.unsqueeze(0), points2=spatial.unsqueeze(0), lengths1=None, lengths2=None, K=K, r=r, grid=None, return_nn=False, return_sorted=True
)

PC2 Grid Off: tensor([[    0,     0,     0,  ..., 85374, 85374, 85374]], device='cuda:0',
       dtype=torch.int32)
CPU times: user 1.55 ms, sys: 1.39 ms, total: 2.95 ms
Wall time: 2.49 ms


In [128]:
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

start.record()

dists, idxs, nn, grid = frnn.frnn_grid_points(
    points1=spatial.unsqueeze(0), points2=spatial.unsqueeze(0), lengths1=None, lengths2=None, K=K, r=r, grid=grid, return_nn=False, return_sorted=True
)

# Remove the unneccessary batch dimension
idxs = idxs.squeeze()

ind = torch.Tensor.repeat(torch.arange(idxs.shape[0], device=device), (idxs.shape[1], 1), 1).T
positive_idxs = idxs >= 0
edge_list = torch.stack([ind[positive_idxs], idxs[positive_idxs]])

# Remove self-loops
e_spatial = edge_list[:, edge_list[0] != edge_list[1]]

torch.cuda.synchronize()
end.record()
time = start.elapsed_time(end)
print(f"Time taken: {time:.5}ms")

PC2 Grid Off: tensor([[    0,     0,     0,  ..., 85374, 85374, 85374]], device='cuda:0',
       dtype=torch.int32)
Time taken: 50.731ms


## Accuracy Performance

Find truth

In [129]:
e_bidir = torch.cat([batch.layerless_true_edges, batch.layerless_true_edges.flip(0)], axis=-1) 
e_spatial, y_cluster = graph_intersection(e_spatial, e_bidir, using_weights=False)

In [130]:
t = e_bidir.shape[1]
tp = y_cluster.sum()
p = e_spatial.shape[1]

In [131]:
print(f'Efficiency: {tp / t}, Purity: {tp / p}')

Efficiency: 0.9823640584945679, Purity: 0.010063880123198032


# FAISS Compare

In [74]:
import faiss
import torch
import faiss.contrib.torch_utils

In [113]:
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

start.record()

res = faiss.StandardGpuResources()
D, I = faiss.knn_gpu(res, spatial, spatial, K)

ind = torch.Tensor.repeat(
    torch.arange(I.shape[0], device=device), (I.shape[1], 1), 1
).T
edge_truth = torch.stack([ind[D <= r ** 2], I[D <= r ** 2]])

# Remove self-loops
e_spatial = edge_truth[:, edge_truth[0] != edge_truth[1]]

torch.cuda.synchronize()
end.record()
time = start.elapsed_time(end)
print(f"Time taken: {time:.5}ms")

Time taken: 386.58ms


# Tweaking FRNN Library

## Generate Data

In [6]:
import frnn

In [7]:
N = 100000
d = 8
spatial = torch.rand(N, d).to(device)

In [8]:
r = 0.4
K = 500

### Sorted

In [73]:
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

start.record()

# first time there is no cached grid
dists, idxs, nn, grid = frnn.frnn_grid_points(
    points1=spatial.unsqueeze(0), points2=spatial.unsqueeze(0), lengths1=None, lengths2=None, K=K, r=r, grid=None, return_nn=False, return_sorted=True
)

torch.cuda.synchronize()
end.record()
time = start.elapsed_time(end)
print(f"Time taken: {time:.5}ms")

Time taken: 1312.5ms


In [74]:
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

start.record()

dists, idxs, nn, grid = frnn.frnn_grid_points(
    points1=spatial.unsqueeze(0), points2=spatial.unsqueeze(0), lengths1=None, lengths2=None, K=K, r=r, grid=grid, return_nn=False, return_sorted=True
)

torch.cuda.synchronize()
end.record()
time = start.elapsed_time(end)
print(f"Time taken: {time:.5}ms")

Time taken: 1291.1ms


### Unsorted

In [6]:
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

start.record()

# first time there is no cached grid
dists, idxs, nn, grid = frnn.frnn_grid_points(
    points1=spatial.unsqueeze(0), points2=spatial.unsqueeze(0), lengths1=None, lengths2=None, K=K, r=r, grid=None, return_nn=False, return_sorted=True
)

torch.cuda.synchronize()
end.record()
time = start.elapsed_time(end)
print(f"Time taken: {time:.5}ms")

Time taken: 216.5ms


In [7]:
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

start.record()

dists, idxs, nn, grid = frnn.frnn_grid_points(
    points1=spatial.unsqueeze(0), points2=spatial.unsqueeze(0), lengths1=None, lengths2=None, K=K, r=r, grid=grid, return_nn=False, return_sorted=True
)

torch.cuda.synchronize()
end.record()
time = start.elapsed_time(end)
print(f"Time taken: {time:.5}ms")

Time taken: 202.67ms


### No KNN

In [69]:
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

start.record()

# first time there is no cached grid
dists, idxs, nn, grid = frnn.frnn_grid_points(
    points1=spatial.unsqueeze(0), points2=spatial.unsqueeze(0), lengths1=None, lengths2=None, K=K, r=r, grid=None, return_nn=False, return_sorted=True
)

torch.cuda.synchronize()
end.record()
time = start.elapsed_time(end)
print(f"Time taken: {time:.5}ms")

Time taken: 55.207ms


In [70]:
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

start.record()

dists, idxs, nn, grid = frnn.frnn_grid_points(
    points1=spatial.unsqueeze(0), points2=spatial.unsqueeze(0), lengths1=None, lengths2=None, K=K, r=r, grid=grid, return_nn=False, return_sorted=True
)

torch.cuda.synchronize()
end.record()
time = start.elapsed_time(end)
print(f"Time taken: {time:.5}ms")

Time taken: 54.697ms


# FAISS vs FRNN

### FRNN

In [13]:
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

start.record()

# first time there is no cached grid
dists, idxs, nn, grid = frnn.frnn_grid_points(
    points1=spatial.unsqueeze(0), points2=spatial.unsqueeze(0), lengths1=None, lengths2=None, K=K, r=r, grid=None, return_nn=False, return_sorted=True
)

torch.cuda.synchronize()
end.record()
time = start.elapsed_time(end)
print(f"Time taken: {time:.5}ms")

PC2 Grid Off: tensor([[    0,   769,  1561,  2373,  3164,  3958,  4683,  5453,  6253,  7067,
          7880,  8689,  9444, 10262, 11047, 11816, 12575, 13327, 14099, 14877,
         15662, 16505, 17311, 18184, 19016, 19807, 20544, 21323, 22165, 22971,
         23740, 24534, 25293, 26078, 26863, 27658, 28468, 29296, 30101, 30956,
         31747, 32525, 33307, 34093, 34981, 35734, 36558, 37289, 38057, 38859,
         39676, 40519, 41318, 42167, 42997, 43767, 44595, 45380, 46197, 47001,
         47777, 48581, 49401, 50200, 51001, 51836, 52624, 53430, 54179, 54951,
         55744, 56565, 57368, 58163, 59017, 59798, 60656, 61466, 62262, 63047,
         63868, 64674, 65527, 66332, 67143, 67956, 68743, 69583, 70415, 71238,
         72072, 72897, 73738, 74565, 75324, 76108, 76925, 77746, 78570, 79359,
         80136, 80945, 81779, 82566, 83325, 84087, 84894, 85681, 86496, 87305,
         88144, 88976, 89807, 90596, 91377, 92150, 92954, 93741, 94585, 95338,
         96077, 96826, 97637, 98431, 9

In [14]:
%%time
idxs = idxs.squeeze()
ind = torch.Tensor.repeat(torch.arange(idxs.shape[0], device=device), (idxs.shape[1], 1), 1).T
positive_idxs = idxs >= 0
edge_list = torch.stack([ind[positive_idxs], idxs[positive_idxs]])

# Remove self-loops
edge_pred = edge_list[:, edge_list[0] != edge_list[1]]

CPU times: user 4.15 ms, sys: 4.22 ms, total: 8.37 ms
Wall time: 6.86 ms


### FAISS

In [24]:
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

start.record()

res = faiss.StandardGpuResources()
D, I = faiss.knn_gpu(res, spatial, spatial, K)

torch.cuda.synchronize()
end.record()
time = start.elapsed_time(end)
print(f"Time taken: {time:.5}ms")

Time taken: 512.49ms


In [25]:
%%time
ind = torch.Tensor.repeat(
    torch.arange(I.shape[0], device=device), (I.shape[1], 1), 1
).T
edge_truth = torch.stack([ind[D <= r ** 2], I[D <= r ** 2]])

# Remove self-loops
edge_truth = edge_truth[:, edge_truth[0] != edge_truth[1]]

CPU times: user 9.82 ms, sys: 7.63 ms, total: 17.5 ms
Wall time: 17.6 ms


### Compare Perf

In [26]:
edge_pred, y_cluster = graph_intersection(edge_pred, edge_truth, using_weights=False)

In [27]:
t = edge_truth.shape[1]
tp = y_cluster.sum().item()
p = edge_pred.shape[1]

In [28]:
print(f'Efficiency: {tp / t}, Purity: {tp / p}')

Efficiency: 0.9999970670879795, Purity: 0.999998166927971
