# Develop Fixed-radius NN Linear-time Algorithm

Explore the techniques described in [this paper](https://reader.elsevier.com/reader/sd/pii/0020019077900709?token=E45C0E1870EA26C21C1F149B6090CE4630A51269D324BE1206B7BF2764FB48B2DDC93F4B86FBFBD8CBDED63B15BBC6DA&originRegion=us-east-1&originCreation=20210428165528).

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
# System imports
import os
import sys
from time import time as tt
import importlib

# External imports
import matplotlib.pyplot as plt
import scipy as sp
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from torch_geometric.data import DataLoader

from itertools import chain
from random import shuffle, sample
from scipy.optimize import root_scalar as root

from torch.nn import Linear
import torch.nn.functional as F
from torch_cluster import knn_graph, radius_graph
import trackml.dataset
import torch_geometric
from itertools import permutations
import itertools
from sklearn import metrics, decomposition
import pytorch_lightning as pl
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.loggers import WandbLogger
from torch.utils.checkpoint import checkpoint

import faiss

sys.path.append('/global/homes/d/danieltm/ExaTrkX/Tracking-ML-Exa.TrkX/Pipelines/TrackML_Example')
device = "cuda" if torch.cuda.is_available() else "cpu"

## Sample Data

In [3]:
spatial = 10*torch.rand(1000, 3).to(device)

## Get Max K

In [85]:
r_max = 1
r_query = 1
nb = spatial.shape[0]
d = spatial.shape[1]

In [86]:
pos_spatial = (spatial - spatial.min(dim=0)[0].T).half()
spatial_ind = torch.arange(len(pos_spatial), device=device).int()
L_box = pos_spatial.max()

In [87]:
x_cell_ref = (pos_spatial // r_max).int()

In [88]:
%%time
unique_cells, counts = x_cell_ref.unique(dim=0, return_counts=True)

CPU times: user 276 µs, sys: 403 µs, total: 679 µs
Wall time: 446 µs


In [103]:
unique_cells.shape

torch.Size([624, 3])

In [89]:
reshape_dims = [int(L_box // r_max + 1)]*d
cell_index_length = np.product(reshape_dims)
cell_lookup = torch.arange(cell_index_length, device=device, dtype=torch.int).reshape(reshape_dims)

In [90]:
inner_cells = torch.flatten(torch.stack(torch.meshgrid([torch.arange(1, unique_cells.max()).to(device)]*d)), start_dim=1).T
print(inner_cells.shape)

torch.Size([512, 3])


In [91]:
inclusive_nhood = torch.flatten(torch.stack(torch.meshgrid([torch.tensor([-1, 0, 1])]*d)), start_dim=1).T.to(device)
nbhood_map = torch.transpose(inner_cells.expand(len(inclusive_nhood), len(inner_cells), d) + torch.transpose(inclusive_nhood.expand(len(inner_cells), len(inclusive_nhood), d), 1, 0), 0, 1)

In [92]:
cell_nhood_lookup = cell_lookup[nbhood_map.long().chunk(chunks=d, dim=2)].squeeze()

In [97]:
unique_lookup = cell_lookup[unique_cells.long().chunk(chunks=d, dim=1)]

In [98]:
cell_nhood_lookup.shape

torch.Size([512, 27])

In [118]:
reverse_lookup = torch.zeros(cell_nhood_lookup.max()+1).long()

In [120]:
reverse_lookup[unique_lookup.long().squeeze(-1)] = torch.arange(len(unique_lookup))

In [121]:
reverse_lookup

tensor([  0,   1,   2,   0,   3,   4,   5,   0,   0,   6,   7,   8,   0,   0,
          9,  10,  11,  12,  13,  14,   0,   0,   0,  15,  16,   0,  17,  18,
         19,   0,   0,   0,   0,   0,  20,  21,  22,  23,  24,   0,  25,  26,
         27,  28,  29,  30,  31,  32,  33,  34,  35,  36,   0,  37,  38,  39,
         40,  41,  42,  43,  44,   0,  45,  46,  47,  48,  49,   0,  50,   0,
          0,  51,  52,  53,   0,  54,   0,   0,  55,  56,  57,  58,   0,   0,
         59,  60,  61,  62,  63,   0,  64,   0,  65,  66,  67,  68,   0,  69,
          0,  70,  71,  72,   0,  73,  74,  75,  76,  77,  78,   0,  79,   0,
         80,  81,  82,  83,  84,   0,  85,   0,   0,  86,  87,  88,   0,   0,
         89,  90,   0,   0,  91,   0,  92,   0,   0,  93,  94,  95,  96,   0,
          0,  97,   0,  98,  99,   0, 100, 101,   0,   0,   0, 102,   0, 103,
          0,   0,   0, 104,   0, 105,   0, 106, 107,   0,   0, 108, 109, 110,
        111,   0, 112, 113, 114, 115,   0,   0, 116, 117, 118, 1

In [115]:
cell_nhood_lookup.max()

tensor(999, device='cuda:0', dtype=torch.int32)

In [117]:
unique_lookup.max()

tensor(997, device='cuda:0', dtype=torch.int32)

In [122]:
reverse_lookup.shape

torch.Size([1000])

In [123]:
counts[reverse_lookup[cell_nhood_lookup[0].long()]]

tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 1, 1, 1, 1, 3, 1, 1, 2, 1, 1, 2, 1, 2, 1,
        1, 1, 4], device='cuda:0')

In [126]:
counts[reverse_lookup[cell_nhood_lookup.long()]].sum(1).max()

tensor(47, device='cuda:0')

^^ This is the maximum choice of K ^^