# Develop Fixed-radius NN Linear-time Algorithm

Explore the techniques described in [this paper](https://reader.elsevier.com/reader/sd/pii/0020019077900709?token=E45C0E1870EA26C21C1F149B6090CE4630A51269D324BE1206B7BF2764FB48B2DDC93F4B86FBFBD8CBDED63B15BBC6DA&originRegion=us-east-1&originCreation=20210428165528).

In [3]:
%load_ext autoreload
%autoreload 2

In [4]:
# System imports
import os
import sys
from time import time as tt
import importlib

# External imports
import matplotlib.pyplot as plt
import scipy as sp
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from torch_geometric.data import DataLoader

from itertools import chain
from random import shuffle, sample
from scipy.optimize import root_scalar as root

from torch.nn import Linear
import torch.nn.functional as F
from torch_cluster import knn_graph, radius_graph
import trackml.dataset
import torch_geometric
from itertools import permutations
import itertools
from sklearn import metrics, decomposition
import pytorch_lightning as pl
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.loggers import WandbLogger
from torch.utils.checkpoint import checkpoint

import faiss

sys.path.append('/global/homes/d/danieltm/ExaTrkX/Tracking-ML-Exa.TrkX/Pipelines/TrackML_Example')
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load Model and Dataset

Load the lightning module and setup the model to get the dataset

In [5]:
from LightningModules.Embedding.Models.layerless_embedding import LayerlessEmbedding
from LightningModules.Embedding.utils import graph_intersection, build_edges

In [6]:
chkpt_dir = "/global/cscratch1/sd/danieltm/ExaTrkX/lightning_checkpoints/CodaEmbeddingStudy/pbn07koj"
chkpt_file = "last.ckpt"
chkpt_path = os.path.join(chkpt_dir, chkpt_file)

In [7]:
model = LayerlessEmbedding.load_from_checkpoint(chkpt_path)

In [8]:
model.hparams["train_split"] = [100,10,10]

In [9]:
model.setup(stage="fit")

In [10]:
model = model.to(device)

# Packaged and Tested

## Function Definitions

In [22]:
def pca_transform(spatial, pca):
    
    spatial_pca = torch.from_numpy(pca.transform(spatial.cpu())).float().to(device)

    pos_spatial = (spatial_pca - spatial_pca.min(dim=0)[0].T).half()
    half_spatial = spatial.half()
    spatial_ind = torch.arange(len(pos_spatial), device=device).int()
    
    L_box = pos_spatial.max()
    
    return pos_spatial, half_spatial, spatial_ind, L_box

def point_locations(pos_spatial, r_max):
    
    x_cell_ref = (pos_spatial // r_max).int()
    
    return x_cell_ref
    
def build_cell_lookup(r_max, L_box, projection_d):
    
    reshape_dims = [int(L_box // r_max + 1)]*projection_d
    cell_index_length = np.product(reshape_dims)
    cell_lookup = torch.arange(cell_index_length, device=device, dtype=torch.int).reshape(reshape_dims)

    return cell_lookup
    
def build_point_lookup(x_cell_ref, cell_lookup, nb, projection_d):
    
    inclusive_nhood = torch.flatten(torch.stack(torch.meshgrid([torch.tensor([-1, 0])]*projection_d)), start_dim=1).T.to(device)
    nbhood_map = torch.transpose(x_cell_ref.expand(len(inclusive_nhood), nb, projection_d) + torch.transpose(inclusive_nhood.expand(nb, len(inclusive_nhood), projection_d), 1, 0), 0, 1)
    hit_nhood_lookup = cell_lookup[nbhood_map.long().chunk(chunks=projection_d, dim=2)].squeeze()
    hit_lookup = cell_lookup[x_cell_ref.long().chunk(chunks=projection_d, dim=1)].squeeze()
    
    return hit_lookup, hit_nhood_lookup

def find_non_empty_cells(hit_lookup, hit_nhood_lookup):
    
    _, flat_inverses, flat_counts = torch.unique(hit_nhood_lookup, return_inverse=True, return_counts=True)
    non_empty_cells = torch.unique(hit_lookup[(flat_counts[flat_inverses] > 1).any(1)])
    
    return non_empty_cells
    
def run_search(r_query, hit_lookup, hit_nhood_lookup, non_empty_cells, half_spatial, spatial_ind):
    
    all_radius_edges = []
    
    for cell in non_empty_cells:
        x_in_cell = spatial_ind[(hit_lookup == cell)]
        x_in_nhood = spatial_ind[(hit_nhood_lookup == cell).any(1)]
        if len(x_in_nhood)>0:
            all_cell_combinations = torch.flatten(torch.stack(torch.meshgrid(x_in_cell, x_in_nhood)), start_dim=1)
            all_radius_edges.append(all_cell_combinations[:, torch.sum( (half_spatial[all_cell_combinations[0].long()] - half_spatial[all_cell_combinations[1].long()])**2, dim=1) < r_query**2])
    
    return all_radius_edges
    
def postprocess(all_radius_edges):
    
    all_radius_edges = torch.cat(all_radius_edges, dim = 1)
    all_radius_edges = all_radius_edges[:, all_radius_edges[0] != all_radius_edges[1]]
    all_radius_edges = torch.cat([all_radius_edges, all_radius_edges.flip(0)], dim=1)
    
    return all_radius_edges
    
def build_edges_grid(spatial, r_query, pca, r_max=None):
    
    nb = spatial.shape[0] # The number of hits in the event
    d = spatial.shape[1] # The dimension of the embedding space
    projection_d = pca.n_components # The dimension of the PCA projection
    
    if r_max is None:
        r_max = r_query

    # 1. Run PCA transform
    pos_spatial, half_spatial, spatial_ind, L_box = pca_transform(spatial, pca)

    # 2. Get point locations in search grid
    x_cell_ref = point_locations(pos_spatial, r_max)
    
    # 3. Build cell lookup table
    cell_lookup = build_cell_lookup(r_max, L_box, projection_d)

    # 4. Build hit lookup table (basically the inverse of the cell lookup table)
    hit_lookup, hit_nhood_lookup = build_point_lookup(x_cell_ref, cell_lookup, nb, projection_d)

    # 5. Find cells that are not empty
    non_empty_cells = find_non_empty_cells(hit_lookup, hit_nhood_lookup)

    # 6. Run the search loop over each cell in the grid
    all_radius_edges = run_search(r_query, hit_lookup, hit_nhood_lookup, non_empty_cells, half_spatial, spatial_ind)
    
    # 7. Postprocess the edges to make them symmetrical, and remove self-edges
    all_radius_edges = postprocess(all_radius_edges)
    
    return all_radius_edges

## Pretrain PCA

Load an example batch

In [13]:
batch = model.trainset[0].to(device)
with torch.no_grad():
    spatial = model(torch.cat([batch.cell_data, batch.x], axis=-1))

Define a 2D PCA projection to fit

In [14]:
#PCA construction
projection_d = 2
pca = decomposition.PCA(n_components = projection_d)

In [15]:
%%time
pca.fit(spatial.cpu())

CPU times: user 162 ms, sys: 126 ms, total: 289 ms
Wall time: 162 ms


PCA(copy=True, iterated_power='auto', n_components=2, random_state=None,
    svd_solver='auto', tol=0.0, whiten=False)

## Time the Grid Search

Set the `r_cell` (which is the size of the grid spacing) and `r_max` (which is the radius to construct a graph from).

In [16]:
r_cell = 1
r_max = 1

In [26]:
%%time
timelist = []
with torch.no_grad():
    for i in range(100):
        tic = tt()
        batch = model.trainset[i].to(device)
        spatial = model(torch.cat([batch.cell_data, batch.x], axis=-1))
        e_spatial = build_edges_grid(spatial, r_max, pca, r_cell)
        timelist.append(tt() - tic)

CPU times: user 30.7 s, sys: 12.2 s, total: 42.9 s
Wall time: 31.7 s


In [27]:
print(f'Time mean: {np.mean(timelist)} +- {np.std(timelist)}')

Time mean: 0.3173021197319031 +- 0.060290047426455766
