In [57]:
import jax
import jax.numpy as jnp
import functools as ft

def loop_body(i, carry):
    centers, edges_to_nef, nef_to_edges_neighbor, nef_mask, node_counter = carry
    center = centers[i]
    edges_to_nef = edges_to_nef.at[center, node_counter[center]].set(i)
    nef_mask = nef_mask.at[center, node_counter[center]].set(True)
    nef_to_edges_neighbor = nef_to_edges_neighbor.at[i].set(node_counter[center])
    node_counter = node_counter.at[center].add(1)
    return centers, edges_to_nef, nef_to_edges_neighbor, nef_mask, node_counter

@ft.partial(jax.jit, static_argnums=(1, 2))
def get_nef_indices(centers, n_nodes: int, n_edges_per_node: int):
    int_dtype = jnp.int64 if jax.config.jax_enable_x64 else jnp.int32
    n_edges = len(centers)
    edges_to_nef = jnp.zeros((n_nodes, n_edges_per_node), dtype=int_dtype)
    nef_to_edges_neighbor = jnp.empty((n_edges,), dtype=int_dtype)
    node_counter = jnp.zeros((n_nodes,), dtype=int_dtype)
    nef_mask = jnp.full((n_nodes, n_edges_per_node), False, dtype=bool)
    return jax.lax.fori_loop(
        0, n_edges, loop_body, (centers, edges_to_nef, nef_to_edges_neighbor, nef_mask, node_counter)
    )[1:4]

In [58]:
centers = jnp.array([0, 4, 3, 1, 0, 0, 3, 3, 3, 4])
nef_indices, nef_to_edges_neighbor, nef_mask = get_nef_indices(centers, 5, 3)

In [59]:
def edge_array_to_nef(edge_array, nef_indices, mask=None, fill_value=0.0):
    if mask is None:
        return edge_array[nef_indices]
    else:
        return jnp.where(mask, edge_array[nef_indices], fill_value)

def nef_array_to_edges(nef_array, centers, nef_to_edges_neighbor):
    return nef_array[centers, nef_to_edges_neighbor]

In [60]:
nef_centers = edge_array_to_nef(centers, nef_indices)
nef_centers

Array([[0, 0, 0],
       [1, 0, 0],
       [0, 0, 0],
       [3, 3, 3],
       [4, 4, 0]], dtype=int32)

In [61]:
nef_array_to_edges(nef_centers, centers, nef_to_edges_neighbor)

Array([0, 4, 3, 1, 0, 0, 3, 3, 3, 4], dtype=int32)