In [None]:
import numpy as np

# Seed NumPy's legacy global RNG so Restart & Run All reproduces the same point cloud
# (and any later np.random draws in this notebook).
SEED = 42
N_POINTS = 100
np.random.seed(SEED)
# N_POINTS samples drawn uniformly from [0, 1) x [0, 1); shape (N_POINTS, 2).
points = np.random.rand(N_POINTS, 2)

In [None]:
import warnings

from scipy.spatial import KDTree
import matplotlib.pyplot as plt
import jax.numpy as jnp

MAX_NEIGHBORS = 20
SEARCH_RADIUS = 0.1
POINT = np.array([0.5, 0.5])  # NOTE(review): not used by any cell in this notebook

tree = KDTree(points)

# Fixed-width neighbour table: row i holds the indices of the (up to) MAX_NEIGHBORS nearest
# points within SEARCH_RADIUS of point i, nearest first (so i itself comes first), padded with -1.
# `query` returns neighbours sorted by distance, so a dense neighbourhood loses its farthest
# members rather than an arbitrary subset (query_ball_tree lists are not distance-ordered).
# Neighbours that do not exist are reported by SciPy as index len(points).
_, nearest_idx = tree.query(
    points,
    # A list (rather than the int MAX_NEIGHBORS) keeps the 2-D (n, MAX_NEIGHBORS) shape even if MAX_NEIGHBORS == 1.
    k=list(range(1, MAX_NEIGHBORS + 1)),
    # Nudged up one ulp so a point at exactly SEARCH_RADIUS is kept, like query_ball_point's inclusive r.
    distance_upper_bound=np.nextafter(SEARCH_RADIUS, np.inf),
)
nearest_idx = np.where(nearest_idx == len(points), -1, nearest_idx)

# Make truncation visible instead of silently dropping neighbours.
neighbor_counts = tree.query_ball_point(points, r=SEARCH_RADIUS, return_length=True)
n_truncated = int(np.sum(neighbor_counts > MAX_NEIGHBORS))
if n_truncated:
    warnings.warn(
        f"{n_truncated} point(s) have more than {MAX_NEIGHBORS} neighbours within "
        f"{SEARCH_RADIUS}; only the {MAX_NEIGHBORS} nearest are kept"
    )

# NOTE: -1 is a padding sentinel, not a neighbour. JAX (like NumPy) reads index -1 as
# "last element", so mask with `neighborhood_matrix >= 0` before gathering with it.
neighborhood_matrix = jnp.array(nearest_idx)

In [None]:
# Preview only the first rows; the full table has one row per point and MAX_NEIGHBORS columns (-1 = padding).
neighborhood_matrix[:5]

In [None]:
# Alternatively, bucket the points into a uniform grid of square cells with side SEARCH_RADIUS:
# every point within SEARCH_RADIUS of p then lies in p's cell or one of the 8 cells around it.
DOMAIN_SIZE = 1.0
# +1 so a trailing partial cell, or a point lying exactly on DOMAIN_SIZE, still gets a bucket.
GRID_SIZE = int(DOMAIN_SIZE / SEARCH_RADIUS) + 1
grid = np.empty((GRID_SIZE, GRID_SIZE), dtype=object)
for cell in np.ndindex(grid.shape):
    grid[cell] = []

# grid[x_cell, y_cell] collects the coordinates of the points falling in that cell.
for p in points:
    grid_x = int(p[0] // SEARCH_RADIUS)
    grid_y = int(p[1] // SEARCH_RADIUS)
    grid[grid_x, grid_y].append(p)

# Visualise the bucketing: one random colour per non-empty cell, with the cell boundaries drawn.
fig, ax = plt.subplots()
for cell in np.ndindex(grid.shape):
    cell_points = grid[cell]
    if not cell_points:
        continue
    cell_points = np.asarray(cell_points)
    # `color=`, not `c=`: a length-3 `c` array is colour-mapped as three values
    # whenever the cell holds exactly three points, instead of being read as one RGB colour.
    ax.scatter(cell_points[:, 0], cell_points[:, 1], color=np.random.rand(3), s=5, alpha=0.7)

cell_edges = np.arange(GRID_SIZE + 1) * SEARCH_RADIUS
ax.set_xticks(cell_edges)
ax.set_yticks(cell_edges)
ax.grid(True, linewidth=0.3)
ax.set(title="Points bucketed into SEARCH_RADIUS grid cells", xlabel="x", ylabel="y", aspect="equal")
plt.show()

In [None]:
def get_neighborhood_mask(points, domain_size, search_radius):
    """Return a boolean grid-cell neighbour mask for every pair of 2-D points.

    The points are bucketed into square cells of side ``search_radius``. Point ``j`` is
    marked as a neighbour of point ``i`` when ``j``'s cell is ``i``'s own cell or one of
    the (up to) eight cells surrounding it. Every point within ``search_radius`` of ``i``
    is therefore included, but the mask is a *candidate* superset: it can also contain
    points up to ``2 * sqrt(2) * search_radius`` away. Filter by distance if exact
    neighbours are needed.

    Parameters
    ----------
    points : sequence of 2-D coordinates, e.g. an array of shape (n, 2)
        Only the first two coordinates of each point are read. Expected to lie in
        ``[0, domain_size]`` along both axes.
    domain_size : float
        Side length of the square domain; sets the grid to
        ``int(domain_size / search_radius) + 1`` cells per axis.
    search_radius : float
        Cell side length.

    Returns
    -------
    numpy.ndarray of bool, shape (n, n)
        ``mask[i, j]`` is True when ``j`` is a grid neighbour of ``i``. The relation is
        symmetric and the diagonal is always True (a point shares its own cell).

    Raises
    ------
    ValueError
        If a point falls in a cell outside the grid (e.g. a negative coordinate).
    """
    grid_size = int(domain_size / search_radius) + 1

    # (cell_x, cell_y) of every point, in input order.
    cells = [(int(p[0] // search_radius), int(p[1] // search_radius)) for p in points]

    # Bucket point indices by cell. Only non-empty cells are stored, so looking up a
    # neighbour cell beyond the grid edge finds nothing: no negative-index wrap-around
    # to the opposite edge and no IndexError past the last row/column.
    buckets = {}
    for idx, (cell_x, cell_y) in enumerate(cells):
        if not (0 <= cell_x < grid_size and 0 <= cell_y < grid_size):
            raise ValueError(
                f"point {idx} falls in cell ({cell_x}, {cell_y}), outside the "
                f"{grid_size}x{grid_size} grid for domain_size={domain_size}"
            )
        buckets.setdefault((cell_x, cell_y), []).append(idx)

    neighborhood_mask = np.zeros((len(cells), len(cells)), dtype=bool)
    offsets = (-1, 0, 1)
    for (cell_x, cell_y), rows in buckets.items():
        # Every point in the 3x3 block of cells around this one neighbours every point in it,
        # so fill the whole rows x cols sub-block at once.
        cols = [
            j
            for dx in offsets
            for dy in offsets
            for j in buckets.get((cell_x + dx, cell_y + dy), ())
        ]
        neighborhood_mask[np.ix_(rows, cols)] = True

    return neighborhood_mask

neighborhood_mask = get_neighborhood_mask(points, DOMAIN_SIZE, SEARCH_RADIUS)
NEIGHBOR_INDEX = 1  # index of the query point whose grid neighbourhood is highlighted

query_point = points[NEIGHBOR_INDEX]
# Row NEIGHBOR_INDEX of the mask selects the candidates; it includes the query point itself.
candidates = points[neighborhood_mask[NEIGHBOR_INDEX]]

fig, ax = plt.subplots()
ax.scatter(points[:, 0], points[:, 1], c='blue', label='Points', s=5, alpha=0.7)
ax.scatter(candidates[:, 0], candidates[:, 1], c='green', label='Neighbors', s=5, alpha=0.7)
ax.scatter(query_point[0], query_point[1], c='red', label='Query Point')
# The mask comes from the 3x3 block of grid cells, so it can reach past SEARCH_RADIUS;
# draw the true search radius for comparison.
ax.add_patch(
    plt.Circle(query_point, SEARCH_RADIUS, fill=False, color='red', linestyle='--', label='Search radius')
)
ax.set(title=f"Grid-cell neighbours of point {NEIGHBOR_INDEX}", xlabel="x", ylabel="y", aspect="equal")
ax.legend()
plt.show()