In [1]:
import torch

In [14]:
# Toy problem sizes: batch B, nodes N, neighbors per node K, feature channels C.
B, N, K, C = 1, 16, 4, 320
SEED = 0

# Seed the global RNG so the random inputs (and everything derived from them)
# are reproducible under Restart Kernel -> Run All.
torch.manual_seed(SEED)

h_V = torch.randn(B, N, C)  # node features [B, N, C]
h_E = torch.randn(B, N, K, C)  # per-neighbor (edge) features [B, N, K, C]
E_idx = torch.randint(0, N, (B, N, K))  # neighbor indices in [0, N), shape [B, N, K]




In [15]:
def gather_nodes(nodes: torch.Tensor, neighbor_idx: torch.Tensor) -> torch.Tensor:
    """
    Gather node features from a neighbor index.

    Parameters
    ----------
    nodes: torch.Tensor
        Node features with shape [B, N, C].

    neighbor_idx: torch.Tensor
        Integer (int64) neighbor indices with shape [B, N, K]; every entry must
        lie in [0, nodes.size(1)). The tensor may be non-contiguous.

    Returns
    -------
    neighbor_features: torch.Tensor
        Gathered neighbor features with shape [B, N, K, C], where
        ``neighbor_features[b, i, k] == nodes[b, neighbor_idx[b, i, k]]``.

    Raises
    ------
    ValueError
        If ``neighbor_idx`` is not 3-dimensional.

    """
    if neighbor_idx.dim() != 3:
        raise ValueError(
            f"neighbor_idx must have shape [B, N, K], got {tuple(neighbor_idx.shape)}"
        )
    batch_size, num_nodes, num_neighbors = neighbor_idx.shape
    channels = nodes.size(2)
    # flatten neighbor indices [B, N, K] => [B, NK]. `reshape` (not `view`) accepts
    # non-contiguous index tensors, and explicit sizes (not -1) keep the reshape
    # unambiguous when the tensor is empty (K == 0).
    neighbors_flat = neighbor_idx.reshape(batch_size, num_nodes * num_neighbors)
    # expand flattened indices [B, NK] => [B, NK, C] so gather copies whole feature rows
    neighbors_flat = neighbors_flat.unsqueeze(-1).expand(-1, -1, channels)
    # gather nodes along the node axis: [B, N, C] => [B, NK, C]
    neighbor_features = torch.gather(nodes, 1, neighbors_flat)
    # restore the neighbor axis [B, NK, C] => [B, N, K, C]; explicit C avoids the
    # ambiguous -1 when C == 0
    return neighbor_features.reshape(batch_size, num_nodes, num_neighbors, channels)



def cat_neighbors_nodes(
    h_nodes: torch.Tensor, h_neighbors: torch.Tensor, E_idx: torch.Tensor
) -> torch.Tensor:
    """
    Gather node features from a neighbor index and concatenate them after the edge features.

    Parameters
    ----------
    h_nodes: torch.Tensor
        Node features with shape [B, N, C_V].

    h_neighbors: torch.Tensor
        Per-neighbor (edge) features with shape [B, N, K, C_E].

    E_idx: torch.Tensor
        Indices of the neighbors for each node with shape [B, N, K].

    Returns
    -------
    h_nn: torch.Tensor
        ``[h_neighbors, gathered h_nodes]`` concatenated along the last axis,
        shape [B, N, K, C_E + C_V] (i.e. [B, N, K, 2C] when C_E == C_V == C).

    """
    # gather node features [B, N, C_V] with neighbor indices [B, N, K] => [B, N, K, C_V]
    h_gathered = gather_nodes(h_nodes, E_idx)
    # edge features first, then neighbor-node features => [B, N, K, C_E + C_V]
    h_nn = torch.cat([h_neighbors, h_gathered], dim=-1)
    return h_nn

In [18]:
# per-edge features [h_E_ij, h_V_j]: edge features followed by the gathered neighbor-node features
h_EV = cat_neighbors_nodes(h_V, h_E, E_idx)
assert h_EV.shape == (*h_E.shape[:-1], h_E.size(-1) + h_V.size(-1))

In [19]:
# [B, N, K, C_E + C_V]
h_EV.size()

torch.Size([1, 16, 4, 640])

In [20]:
# each node's own features broadcast over its K neighbor slots: [B, N, C] => [B, N, K, C]
h_V[:, :, None, :].expand(-1, -1, h_EV.shape[-2], -1).size()

torch.Size([1, 16, 4, 320])

In [21]:
# Prepend each node's own features to every neighbor slot, giving per-edge features
# [h_V_i, h_E_ij, h_V_j] with shape [B, N, K, C_V + C_E + C_V].
# Rebuilt from the inputs rather than from the previous h_EV, so re-running this
# cell does not keep growing the channel dimension.
h_V_self = h_V.unsqueeze(-2).expand(-1, -1, E_idx.size(-1), -1)
h_EV = torch.cat([h_V_self, cat_neighbors_nodes(h_V, h_E, E_idx)], -1)

In [22]:
# [B, N, K, C_V + C_E + C_V]
h_EV.size()

torch.Size([1, 16, 4, 960])