In [1]:
import torch
import torch.nn.functional as F
from torch_geometric.datasets import QM9
from torch_geometric.data import Batch
from torch_cluster import knn_graph

In [2]:
dataset = QM9("../data/qm9")
batch_size = 2
batch = Batch.from_data_list(list(dataset[:batch_size]))
batch

DataBatch(x=[9, 11], edge_index=[2, 14], edge_attr=[14, 4], y=[2, 19], pos=[9, 3], idx=[2], name=[2], z=[9], batch=[9], ptr=[3])

In [3]:
batch.pos, batch.batch

(tensor([[-1.2700e-02,  1.0858e+00,  8.0000e-03],
         [ 2.2000e-03, -6.0000e-03,  2.0000e-03],
         [ 1.0117e+00,  1.4638e+00,  3.0000e-04],
         [-5.4080e-01,  1.4475e+00, -8.7660e-01],
         [-5.2380e-01,  1.4379e+00,  9.0640e-01],
         [-4.0400e-02,  1.0241e+00,  6.2600e-02],
         [ 1.7300e-02,  1.2500e-02, -2.7400e-02],
         [ 9.1580e-01,  1.3587e+00, -2.8800e-02],
         [-5.2030e-01,  1.3435e+00, -7.7550e-01]]),
 tensor([0, 0, 0, 0, 0, 1, 1, 1, 1]))

In [4]:
def knn_graph_static(x, k, batch, cutoff=6.0):
    r_ij = torch.norm(x[:, None] - x, dim=-1)
    rows = torch.arange(batch.shape[0]).view(-1, 1)
    cols = rows.view(1, -1)
    mask = (rows == cols) | (batch.view(-1, 1) != batch.view(1, -1))
    r_ij = r_ij.masked_fill(mask, cutoff)
    r_ij = r_ij.clamp(max=cutoff)
    vals, indices = torch.topk(r_ij, k=k, largest=False)
    print(vals)
    print(indices)
    rows = rows.expand_as(indices)
    edge_index = torch.vstack([indices.flatten(), rows.flatten()])
    return edge_index

In [5]:
k = 3
torch.allclose(
    knn_graph(batch.pos, k, batch.batch), knn_graph_static(batch.pos, k, batch.batch)
)

tensor([[1.0919, 1.0919, 1.0919],
        [1.0919, 1.7831, 1.7831],
        [1.0919, 1.7831, 1.7831],
        [1.0919, 1.7831, 1.7831],
        [1.0919, 1.7831, 1.7831],
        [1.0172, 1.0172, 1.0172],
        [1.0172, 1.6185, 1.6187],
        [1.0172, 1.6185, 1.6187],
        [1.0172, 1.6187, 1.6187]])
tensor([[3, 1, 4],
        [0, 2, 3],
        [0, 1, 4],
        [0, 1, 4],
        [0, 2, 1],
        [7, 8, 6],
        [5, 7, 8],
        [5, 6, 8],
        [5, 7, 6]])


True

In [6]:
knn_graph_static(batch.pos, 4, batch.batch, 30.0)

tensor([[ 1.0919,  1.0919,  1.0919,  1.0919],
        [ 1.0919,  1.7831,  1.7831,  1.7831],
        [ 1.0919,  1.7831,  1.7831,  1.7831],
        [ 1.0919,  1.7831,  1.7831,  1.7831],
        [ 1.0919,  1.7831,  1.7831,  1.7831],
        [ 1.0172,  1.0172,  1.0172, 30.0000],
        [ 1.0172,  1.6185,  1.6187, 30.0000],
        [ 1.0172,  1.6185,  1.6187, 30.0000],
        [ 1.0172,  1.6187,  1.6187, 30.0000]])
tensor([[3, 1, 4, 2],
        [0, 2, 3, 4],
        [0, 1, 4, 3],
        [0, 1, 4, 2],
        [0, 2, 1, 3],
        [7, 8, 6, 4],
        [5, 7, 8, 4],
        [5, 6, 8, 7],
        [5, 7, 6, 4]])


tensor([[3, 1, 4, 2, 0, 2, 3, 4, 0, 1, 4, 3, 0, 1, 4, 2, 0, 2, 1, 3, 7, 8, 6, 4,
         5, 7, 8, 4, 5, 6, 8, 7, 5, 7, 6, 4],
        [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5,
         6, 6, 6, 6, 7, 7, 7, 7, 8, 8, 8, 8]])

In [7]:
knn_graph(batch.pos, k, batch.batch)

tensor([[3, 1, 4, 0, 2, 3, 0, 1, 4, 0, 1, 4, 0, 2, 1, 7, 8, 6, 5, 7, 8, 5, 6, 8,
         5, 7, 6],
        [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6, 7, 7, 7,
         8, 8, 8]])

In [8]:
batch.batch

tensor([0, 0, 0, 0, 0, 1, 1, 1, 1])

In [9]:
def fully_connected(num_nodes):
    node_ids = torch.arange(0, num_nodes)
    all_edges = torch.cartesian_prod(node_ids, node_ids)

    splits = [1, num_nodes] * (num_nodes - 1)
    splits.append(1)
    splits = torch.split(torch.arange(num_nodes * num_nodes), splits)
    keepers = splits[1:-1:2]
    return all_edges[torch.hstack(keepers)].t()

In [24]:
x = batch.pos
num_nodes = x.shape[0]
u = x.unsqueeze(1)
u = u.repeat(1, num_nodes, 1)

v = x.unsqueeze(0)
v = v.repeat(num_nodes, 1, 1)
torch.sum((u - v).pow(2), dim=-1).sqrt()

tensor([[0.0000, 1.0919, 1.0919, 1.0919, 1.0919, 0.0869, 1.0743, 0.9685, 0.9685],
        [1.0919, 0.0000, 1.7831, 1.7831, 1.7831, 1.0328, 0.0379, 1.6426, 1.6428],
        [1.0919, 1.7831, 0.0000, 1.7831, 1.7831, 1.1420, 1.7595, 0.1452, 1.7214],
        [1.0919, 1.7831, 1.7831, 0.0000, 1.7831, 1.1453, 1.7584, 1.6877, 0.1465],
        [1.0919, 1.7831, 1.7831, 1.7831, 0.0000, 1.0568, 1.7879, 1.7185, 1.6846],
        [0.0869, 1.0328, 1.1420, 1.1453, 1.0568, 0.0000, 1.0172, 1.0172, 1.0172],
        [1.0743, 0.0379, 1.7595, 1.7584, 1.7879, 1.0172, 0.0000, 1.6185, 1.6187],
        [0.9685, 1.6426, 0.1452, 1.6877, 1.7185, 1.0172, 1.6185, 0.0000, 1.6187],
        [0.9685, 1.6428, 1.7214, 0.1465, 1.6846, 1.0172, 1.6187, 1.6187, 0.0000]])

In [25]:
torch.norm(x - x[:, None], dim=-1)

tensor([[0.0000, 1.0919, 1.0919, 1.0919, 1.0919, 0.0869, 1.0743, 0.9685, 0.9685],
        [1.0919, 0.0000, 1.7831, 1.7831, 1.7831, 1.0328, 0.0379, 1.6426, 1.6428],
        [1.0919, 1.7831, 0.0000, 1.7831, 1.7831, 1.1420, 1.7595, 0.1452, 1.7214],
        [1.0919, 1.7831, 1.7831, 0.0000, 1.7831, 1.1453, 1.7584, 1.6877, 0.1465],
        [1.0919, 1.7831, 1.7831, 1.7831, 0.0000, 1.0568, 1.7879, 1.7185, 1.6846],
        [0.0869, 1.0328, 1.1420, 1.1453, 1.0568, 0.0000, 1.0172, 1.0172, 1.0172],
        [1.0743, 0.0379, 1.7595, 1.7584, 1.7879, 1.0172, 0.0000, 1.6185, 1.6187],
        [0.9685, 1.6426, 0.1452, 1.6877, 1.7185, 1.0172, 1.6185, 0.0000, 1.6187],
        [0.9685, 1.6428, 1.7214, 0.1465, 1.6846, 1.0172, 1.6187, 1.6187, 0.0000]])

In [28]:
u.flatten().view(-1, 1) @ torch.ones(u.numel()).view(1, -1)

tensor([[-0.0127, -0.0127, -0.0127,  ..., -0.0127, -0.0127, -0.0127],
        [ 1.0858,  1.0858,  1.0858,  ...,  1.0858,  1.0858,  1.0858],
        [ 0.0080,  0.0080,  0.0080,  ...,  0.0080,  0.0080,  0.0080],
        ...,
        [-0.5203, -0.5203, -0.5203,  ..., -0.5203, -0.5203, -0.5203],
        [ 1.3435,  1.3435,  1.3435,  ...,  1.3435,  1.3435,  1.3435],
        [-0.7755, -0.7755, -0.7755,  ..., -0.7755, -0.7755, -0.7755]])

In [30]:
u[0, 0]

tensor([-0.0127,  1.0858,  0.0080])

In [12]:
import torch

torch.manual_seed(0)
M = 16
N = 4
L = 8

out = torch.zeros(L, N)
src = torch.randn(M, N)
index = torch.randint(0, L, (M,))

out.index_add(dim=0, index=index, source=src)

tensor([[-0.7931, -0.3459,  2.3242,  0.9874],
        [ 0.0766,  3.0462,  0.0521,  1.8238],
        [-1.2089, -2.4975,  2.1697, -0.2434],
        [ 0.0000,  0.0000,  0.0000,  0.0000],
        [ 1.1149, -0.4624, -0.4037, -0.6049],
        [ 2.1396,  1.0008,  0.7729, -0.6602],
        [-0.6136,  0.0316, -0.4927,  0.2484],
        [-1.7989, -0.2795,  0.8048, -0.2560]])

In [20]:
out = torch.zeros(L, N)
broadcasted_index = index.view(-1, 1).expand_as(src)
for i in range(M):
    for j in range(N):
        out[broadcasted_index[i, j], j] += src[i, j]

out

tensor([[-0.7931, -0.3459,  2.3242,  0.9874],
        [ 0.0766,  3.0462,  0.0521,  1.8238],
        [-1.2089, -2.4975,  2.1697, -0.2434],
        [ 0.0000,  0.0000,  0.0000,  0.0000],
        [ 1.1149, -0.4624, -0.4037, -0.6049],
        [ 2.1396,  1.0008,  0.7729, -0.6602],
        [-0.6136,  0.0316, -0.4927,  0.2484],
        [-1.7989, -0.2795,  0.8048, -0.2560]])

In [21]:
out = torch.zeros(L, N)
for i in range(M):
    out[index[i], :] += src[i, :]

out

tensor([[-0.7931, -0.3459,  2.3242,  0.9874],
        [ 0.0766,  3.0462,  0.0521,  1.8238],
        [-1.2089, -2.4975,  2.1697, -0.2434],
        [ 0.0000,  0.0000,  0.0000,  0.0000],
        [ 1.1149, -0.4624, -0.4037, -0.6049],
        [ 2.1396,  1.0008,  0.7729, -0.6602],
        [-0.6136,  0.0316, -0.4927,  0.2484],
        [-1.7989, -0.2795,  0.8048, -0.2560]])