/
radius_graph.py
62 lines (50 loc) · 1.66 KB
/
radius_graph.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
from typing import Union
import jax
import jax.numpy as jnp
import e3nn_jax as e3nn
def radius_graph(
    pos: Union[e3nn.IrrepsArray, jax.Array],
    r_max: float,
    *,
    batch: Union[jax.Array, None] = None,
    size: Union[int, None] = None,
    loop: bool = False,
    fill_src: int = -1,
    fill_dst: int = -1,
):
    r"""Build the edges connecting every pair of points closer than ``r_max``.

    Try to use ``matscipy.neighbours.neighbour_list`` instead.

    All pairwise distances are computed densely, so time and memory grow as
    :math:`O(n^2)` in the number of points.

    Args:
        pos (`jax.Array` or `e3nn.IrrepsArray`): array of shape ``(n, 3)``.
            An ``IrrepsArray`` is unwrapped to its underlying ``.array``.
        r_max (float): cutoff radius; ``(i, j)`` is an edge when
            ``|pos[i] - pos[j]| < r_max`` (strict inequality).
        batch (`jax.Array`, optional): array of shape ``(n,)`` assigning each
            point to a graph; only pairs with ``batch[i] == batch[j]`` become
            edges. ``None`` treats all points as a single graph.
        size (int, optional): fixed number of edges to return. Surplus edges
            are truncated and missing ones are padded with ``fill_src`` /
            ``fill_dst``. A fixed size is needed when tracing under ``jax.jit``.
        loop (bool): whether to include self-loops ``(i, i)``.
        fill_src (int): padding value for the source indices.
        fill_dst (int): padding value for the destination indices.

    Returns:
        (tuple): tuple containing:
            jax.Array: source indices
            jax.Array: destination indices

        Edges appear in row-major order of ``(src, dst)``; padding (if any)
        comes last.

    Examples:
        >>> key = jax.random.PRNGKey(0)
        >>> pos = jax.random.normal(key, (20, 3))
        >>> batch = jnp.arange(20) < 10
        >>> radius_graph(pos, 0.8, batch=batch)
        (Array([ 3,  7, 10, 11, 12, 18], dtype=int32), Array([ 7,  3, 11, 10, 18, 12], dtype=int32))
    """
    # TODO(mario): replace with the function made for Allan once the project is finished
    if isinstance(pos, e3nn.IrrepsArray):
        pos = pos.array

    n = pos.shape[0]
    # Dense (n, n) matrix of pairwise Euclidean distances.
    r = jnp.linalg.norm(pos[:, None, :] - pos[None, :, :], axis=-1)

    mask = r < r_max
    if not loop:
        # Drop only the diagonal (i == j). Distinct points that happen to
        # coincide (distance 0) are still neighbours of each other.
        mask = mask & ~jnp.eye(n, dtype=bool)
    if batch is not None:
        # Restrict to same-graph pairs *before* selecting indices, so that
        # cross-graph pairs never consume one of the ``size`` output slots and
        # the output keeps a static shape (no boolean-mask indexing).
        batch = jnp.asarray(batch)
        mask = mask & (batch[:, None] == batch[None, :])

    src, dst = jnp.nonzero(mask, size=size, fill_value=-1)
    # Real indices are never negative, so -1 marks exactly the padding slots.
    if fill_src != -1:
        src = jnp.where(src == -1, fill_src, src)
    if fill_dst != -1:
        dst = jnp.where(dst == -1, fill_dst, dst)
    return src, dst