In [1]:
import torch
import torch.nn as nn
import numpy as np
from torch import Tensor, tensor

In [2]:
# Toy molecular graph: 5 nodes with random scalar features and 3D positions.
node_feat = torch.from_numpy(np.random.rand(5, 10)).float()
node_pos = torch.from_numpy(np.random.rand(5, 3)).float()

# Undirected bonds; every bond is expanded into both directed edges (i -> j, j -> i).
undirected_pairs = [
    (0, 1), (0, 2), (0, 3), (0, 4),
    (1, 2), (1, 4),
    (2, 3), (2, 4),
    (3, 4),
]
directed_edges = [edge for i, j in undirected_pairs for edge in ((i, j), (j, i))]
# Shape [2, E]: row 0 holds the first node of each edge, row 1 the second.
edge_index = torch.from_numpy(np.array(directed_edges).T).to(torch.long)

In [3]:
# Random vectorial features, shape [N=5, embedding_dim=10, 3].
v_feat = torch.from_numpy(np.random.rand(5, 10, 3)).float()

In [4]:
# One random integer in [0, 95) per node, used as the atomic number; shape [5].
atomic_number = torch.randint(low=0, high=95, size=(5,))

In [5]:
class EmbeddingBlock(nn.Module):
    r"""
    Look-up table mapping atomic numbers to learnable scalar embeddings.

    Parameters:
        embedding_dim (int):
            Size of each embedding vector.
        num_embeddings (int):
            Number of rows of the table, i.e. valid indices are
            ``0 .. num_embeddings - 1``. Defaults to 95 (the value that was
            previously hard-coded). Index 0 is the padding index and always
            maps to the zero vector.
    """

    def __init__(self, embedding_dim: int, num_embeddings: int = 95):
        super(EmbeddingBlock, self).__init__()
        self.atomic_num_embedding = nn.Embedding(num_embeddings, embedding_dim, padding_idx = 0)

    def forward(self, atomic_num) -> Tensor:
        r"""
        Initialize the scalar representation of every atom from its atomic number.

        Parameters:
            atomic_num (torch.Tensor):
                Integer tensor of atomic numbers. Shape [N]

        Returns:
            scalar_feat (torch.Tensor):
                Shape [N, embedding_dim]
        """
        scalar_feat = self.atomic_num_embedding(atomic_num)
        return scalar_feat

In [6]:
class ScaledSiLU(nn.Module):
    r"""SiLU activation whose output is multiplied by the constant 1 / 0.6."""

    def __init__(self):
        super().__init__()
        self.activation = nn.SiLU()
        self.scale_factor = 1 / 0.6

    def forward(self, x):
        # Apply SiLU first, then rescale the activated values.
        activated = self.activation(x)
        return activated * self.scale_factor

In [7]:
class RadialBasisFunction(nn.Module):
    r"""
    Sine radial basis: ``sin(n * pi * d / cut_off) / d`` for ``n = 1 .. num_radial_basis``.

    Parameters:
        num_radial_basis (int):
            Number of basis functions.
        cut_off (float):
            Cutoff radius used to scale the distances.
        trainable (bool):
            If True the frequencies ``n * pi`` are a learnable ``nn.Parameter``;
            otherwise they are stored as a (non-trainable) buffer. In both cases
            they appear in the state dict under the key ``expanded_distance``.
    """

    def __init__(self, num_radial_basis: int, cut_off: float, trainable: bool = False):
        super(RadialBasisFunction, self).__init__()
        self.num_radial_basis = num_radial_basis
        self.cut_off = cut_off

        # Frequencies n * pi, n = 1 .. num_radial_basis. Shape [num_radial_basis]
        frequencies = torch.arange(1, num_radial_basis + 1, dtype = torch.float32) * torch.pi

        # Register exactly once. Assigning an nn.Parameter to an attribute after
        # register_buffer would silently turn the buffer back into a trainable
        # parameter, so the two cases must stay separate.
        if trainable:
            self.expanded_distance = nn.Parameter(frequencies)
        else:
            self.register_buffer("expanded_distance", frequencies)

    def forward(self, distance: Tensor) -> Tensor:
        r"""
        Construct radial basis for distance

        Parameters:
            distance (torch.Tensor):
                Interatomic distance. Shape [E]. Entries must be non-zero,
                since the basis divides by the distance.

        Returns:
            expanded_distance (torch.Tensor):
                Shape [E, num_radial_basis]
        """
        distance_scaled = distance / self.cut_off

        # shape [E, num_radial_basis]
        expanded_distance = torch.sin(self.expanded_distance * distance_scaled.unsqueeze(-1)) / distance.unsqueeze(-1)

        return expanded_distance

In [8]:
class CosineCutoff(nn.Module):
    r"""Cosine envelope ``0.5 * (1 + cos(pi * x / cut_off))``, set to zero where ``x >= cut_off``."""

    def __init__(self, cut_off: float):
        super().__init__()
        self.cut_off = cut_off

    def forward(self, x):
        # 1 where x is strictly inside the cutoff, 0 elsewhere (same dtype as x).
        inside = (x < self.cut_off).to(x.dtype)
        envelope = 0.5 * (1 + torch.cos(x * torch.pi / self.cut_off))
        return envelope * inside

In [9]:
class MessageBlock(nn.Module):
    r"""
    PaiNN-style message block.

    For every directed edge the projected scalar features of the source node
    are multiplied by a distance-dependent filter, split into three gates and
    summed onto the target node as a scalar message and a vectorial message.

    Parameters:
        embedding_dim (int):
            Number of feature channels.
        num_radial_basis (int):
            Number of radial basis functions used to expand the distances.
        cut_off (float):
            Cutoff radius passed to the radial basis and the cosine cutoff.
    """

    def __init__(self, embedding_dim, num_radial_basis, cut_off):
        super(MessageBlock, self).__init__()
        self.embedding_dim = embedding_dim
        self.scalar_feat_proj = nn.Sequential(
            nn.Linear(embedding_dim, embedding_dim),
            ScaledSiLU(),
            nn.Linear(embedding_dim, embedding_dim * 3)
        )
        self.radial_basis = RadialBasisFunction(num_radial_basis, cut_off)
        self.rbf_proj = nn.Linear(num_radial_basis, embedding_dim * 3)
        self.cosine_cut_off = CosineCutoff(cut_off)

        self.reset_parameters()

    def reset_parameters(self):
        r"""Xavier-uniform weights and zero biases for all linear layers of the block."""
        nn.init.xavier_uniform_(self.scalar_feat_proj[0].weight)
        self.scalar_feat_proj[0].bias.data.fill_(0)
        nn.init.xavier_uniform_(self.scalar_feat_proj[2].weight)
        self.scalar_feat_proj[2].bias.data.fill_(0)
        nn.init.xavier_uniform_(self.rbf_proj.weight)
        self.rbf_proj.bias.data.fill_(0)

    def forward(self, vectorial_feat: Tensor, scalar_feat: Tensor, node_pos: Tensor, edge_index: Tensor):
        r"""
        Parameters:
            vectorial_feat (torch.Tensor):
                Vectorial representations. Shape [N, embedding_dim, 3]
            scalar_feat (torch.Tensor):
                Scalar representations. Shape [N, embedding_dim]
            node_pos (torch.Tensor):
                Atom's 3D coordinates. Shape [N, 3]
            edge_index (torch.Tensor):
                Shape [2, E]. Row 0 is the source node and row 1 the target
                node of each edge; messages are summed onto the target.

        Returns:
            vectorial_message (torch.Tensor):
                Aggregated vectorial messages. Shape [N, embedding_dim, 3]
            scalar_message (torch.Tensor):
                Aggregated scalar messages. Shape [N, embedding_dim]
        """
        num_nodes, num_edges = node_pos.shape[0], edge_index.shape[-1]
        # shape [N, embedding_dim] -> [N, embedding_dim * 3]
        scalar_feat = self.scalar_feat_proj(scalar_feat)

        source, target = edge_index
        # r_target - r_source. Shape [E, 3]
        relative_distance = node_pos[target] - node_pos[source]
        # Shape [E]
        distance = torch.sum(relative_distance ** 2, dim = -1) ** 0.5
        # Shape [E, num_radial_basis]
        expanded_distance = self.radial_basis(distance)
        # Shape [E, num_radial_basis] -> [E, embedding_dim * 3].
        # The cutoff envelope is a function of the *distance* and scales the
        # learned filter, so contributions vanish smoothly at the cutoff radius.
        edge_filter = self.rbf_proj(expanded_distance) * self.cosine_cut_off(distance).unsqueeze(-1)

        # Shape [E, embedding_dim * 3] -> [E, embedding_dim, 3].
        # num_edges is passed explicitly because view(-1, ...) is ambiguous when E == 0.
        message = (scalar_feat[source] * edge_filter).view(num_edges, self.embedding_dim, 3)
        # Three gates, each of shape [E, embedding_dim]
        scalar_message, equivariant_vectorial_message, invariant_vectorial_message = message.unbind(dim = -1)

        # Aggregate scalar messages on the target nodes. Shape [N, embedding_dim].
        # new_zeros keeps the accumulator on the same device/dtype as the messages.
        scalar_message = scalar_message.new_zeros(num_nodes, self.embedding_dim).index_add_(0, target, scalar_message)

        # Unit direction of every edge, broadcast over channels. Shape [E, 1, 3]
        direction = (relative_distance / distance.unsqueeze(-1)).unsqueeze(1)

        # shape [E, embedding_dim] -> [E, embedding_dim, 3]
        invariant_vectorial_message = invariant_vectorial_message.unsqueeze(-1) * direction
        equivariant_vectorial_message = equivariant_vectorial_message.unsqueeze(-1) * vectorial_feat[source]
        vectorial_message = invariant_vectorial_message + equivariant_vectorial_message

        # Aggregate vectorial messages on the target nodes. Shape [N, embedding_dim, 3]
        vectorial_message = vectorial_message.new_zeros(num_nodes, self.embedding_dim, 3).index_add_(0, target, vectorial_message)

        return vectorial_message, scalar_message

In [10]:
# Message block with 10 channels, 20 radial basis functions and a cutoff radius of 0.5.
m = MessageBlock(embedding_dim=10, num_radial_basis=20, cut_off=0.5)

In [11]:
# Call the module itself rather than .forward so nn.Module.__call__ runs any registered hooks.
m(v_feat, node_feat, node_pos, edge_index)

(tensor([[[ 7.2948e-02,  3.0525e-01,  2.0908e-01],
          [ 2.1422e-01,  2.3073e-01,  2.6142e-01],
          [ 4.3262e-01,  4.5471e-01,  2.8785e-01],
          [ 3.4445e-04, -2.1261e-02, -3.0663e-02],
          [-1.5536e-01,  4.6588e-01,  5.6555e-01],
          [-1.5878e-01,  3.0535e-02,  1.3734e-01],
          [-7.1694e-01,  2.6245e-01, -2.2045e-01],
          [ 5.1232e-02, -7.1793e-01, -5.6427e-01],
          [-5.5384e-01, -5.2732e-01, -4.4876e-01],
          [-3.0370e-01, -3.7373e-01, -2.8933e-01]],
 
         [[ 5.6659e-02, -2.8151e-01, -2.5984e-01],
          [-8.8295e-02, -1.6611e-02, -1.0053e-01],
          [ 6.3831e-01,  4.7819e-01,  1.1651e-01],
          [ 3.2418e-03,  5.4691e-02,  7.3105e-03],
          [-2.2864e-01, -7.8132e-02, -2.4717e-01],
          [-3.7047e-01, -3.7912e-01, -2.5769e-01],
          [-1.3032e-01, -2.9274e-02, -1.0830e-01],
          [-1.4560e-01,  8.0079e-02, -4.3078e-02],
          [-2.0663e-01, -2.2244e-01, -2.4056e-01],
          [ 5.2094e-02, -8.4

In [12]:
class UpdateBlock(nn.Module):
    r"""
    PaiNN-style update block.

    Mixes the channels of the vectorial features into two sets ``U`` and ``V``,
    derives three gates from the scalar features and the per-channel norm of
    ``V``, and returns residual updates for both representations.

    Parameters:
        embedding_dim (int):
            Number of feature channels.
    """

    # Added under the square root so the norm of V has a finite gradient at V == 0
    # (the vectorial features start as all zeros).
    _NORM_EPS = 1e-8

    def __init__(self, embedding_dim):
        super(UpdateBlock, self).__init__()

        self.scalar_feat_proj = nn.Sequential(
                nn.Linear(2 * embedding_dim, embedding_dim),
                ScaledSiLU(),
                nn.Linear(embedding_dim, 3 * embedding_dim)
        )

        self.vectorial_feat_proj = nn.Linear(embedding_dim, 2 * embedding_dim, bias = False)

        self.reset_parameters()

    def reset_parameters(self):
        r"""Xavier-uniform weights and zero biases for all linear layers of the block."""
        nn.init.xavier_uniform_(self.scalar_feat_proj[0].weight)
        self.scalar_feat_proj[0].bias.data.fill_(0)
        nn.init.xavier_uniform_(self.scalar_feat_proj[2].weight)
        self.scalar_feat_proj[2].bias.data.fill_(0)
        nn.init.xavier_uniform_(self.vectorial_feat_proj.weight)

    def forward(self, vectorial_feat: Tensor, scalar_feat: Tensor):
        r"""
        Parameters:
            vectorial_feat (torch.Tensor):
                Vectorial representations. Shape [N, embedding_dim, 3]
            scalar_feat (torch.Tensor):
                Scalar representations. Shape [N, embedding_dim]

        Returns:
            vectorial_update (torch.Tensor):
                Residual update of the vectorial representations. Shape [N, embedding_dim, 3]
            scalar_update (torch.Tensor):
                Residual update of the scalar representations. Shape [N, embedding_dim]
        """
        num_nodes, embedding_dim = vectorial_feat.shape[0], vectorial_feat.shape[1]
        # shape [N, embedding_dim, 3] -> [N, 2 * embedding_dim, 3]
        vectorial_feat = self.vectorial_feat_proj(vectorial_feat.permute(0, 2, 1)).permute(0, 2, 1)
        # shape [N, 2 * embedding_dim, 3] -> [N, embedding_dim, 2, 3] -> two tensors [N, embedding_dim, 3].
        # Explicit sizes (instead of -1) keep the reshape well defined when N == 0.
        U, V = vectorial_feat.reshape(num_nodes, embedding_dim, 2, 3).unbind(dim = 2)

        # Rotation-invariant per-channel length of V. Shape [N, embedding_dim].
        # A plain sum over x, y, z would change when the molecule is rotated.
        V_norm = torch.sqrt(torch.sum(V ** 2, dim = -1) + self._NORM_EPS)

        # shape [N, 2 * embedding_dim] -> [N, 3 * embedding_dim]
        a = self.scalar_feat_proj(torch.cat([scalar_feat, V_norm], dim = -1))

        # shape [N, 3 * embedding_dim] -> [N, embedding_dim, 3] -> three gates [N, embedding_dim]
        a_vv, a_sv, a_ss = a.view(num_nodes, embedding_dim, 3).unbind(dim = -1)

        # [N, embedding_dim]
        scalar_product = torch.sum(U * V, dim = -1)
        # [N, embedding_dim]
        scalar_update = a_ss + a_sv * scalar_product

        # [N, embedding_dim, 3]
        vectorial_update = U * a_vv.unsqueeze(-1)

        return vectorial_update, scalar_update

In [13]:
# Update block with 10 feature channels.
u = UpdateBlock(embedding_dim=10)

In [16]:
class PaiNN(nn.Module):
    r"""
    Stack of message and update blocks followed by a per-atom output head.

    Parameters:
        embedding_dim (int):
            Number of feature channels.
        num_blocks (int):
            Number of (message block, update block) pairs.
        num_radial_basis (int):
            Number of radial basis functions used inside each message block.
        cut_off (float):
            Cutoff radius passed to each message block.
        out_dim (int):
            Size of the per-atom output.
    """

    def __init__(self, embedding_dim: int, num_blocks: int, num_radial_basis: int, cut_off: float, out_dim: int):
        super(PaiNN, self).__init__()
        self.embedding_dim = embedding_dim
        self.embedding_block = EmbeddingBlock(embedding_dim)

        self.message_blocks = nn.ModuleList([
            MessageBlock(embedding_dim, num_radial_basis, cut_off) for _ in range(num_blocks)
        ])

        self.update_blocks = nn.ModuleList([
            UpdateBlock(embedding_dim) for _ in range(num_blocks)
        ])

        self.out_proj = nn.Sequential(
            nn.Linear(embedding_dim, embedding_dim // 2),
            ScaledSiLU(),
            nn.Linear(embedding_dim // 2, out_dim)
        )

        self.reset_parameters()

    def reset_parameters(self):
        r"""Re-initialize the output head and every message/update block."""
        nn.init.xavier_uniform_(self.out_proj[0].weight)
        self.out_proj[0].bias.data.fill_(0)
        nn.init.xavier_uniform_(self.out_proj[2].weight)
        self.out_proj[2].bias.data.fill_(0)

        for message_block, update_block in zip(self.message_blocks, self.update_blocks):
            message_block.reset_parameters()
            update_block.reset_parameters()

    def forward(self, atomic_num: Tensor, node_pos: Tensor, edge_index: Tensor) -> Tensor:
        r"""
        Parameters:
            atomic_num (torch.Tensor):
                Atomic number of each atom in the molecular graph. Shape [N]
            node_pos (torch.Tensor):
                Atoms' 3D coordinates. Shape [N, 3]
            edge_index (torch.Tensor):
                Shape [2, E]

        Returns:
            scalar (torch.Tensor):
                Per-atom output. Shape [N, out_dim]
        """
        # Use the argument, not a notebook-level global, to size the state.
        num_nodes = atomic_num.shape[0]
        # Initialize scalar representations from the embedding and vectorial
        # representations with zeros on the same device/dtype.
        # Shape [N, embedding_dim], [N, embedding_dim, 3]
        scalar_feat = self.embedding_block(atomic_num)
        vectorial_feat = scalar_feat.new_zeros(num_nodes, self.embedding_dim, 3)

        # Every block applies a residual message step followed by a residual
        # update step. The first block is treated like all others: its vectorial
        # message is non-zero even for zero vectorial features.
        for message_block, update_block in zip(self.message_blocks, self.update_blocks):
            vectorial_message, scalar_message = message_block(vectorial_feat, scalar_feat, node_pos, edge_index)
            vectorial_feat = vectorial_feat + vectorial_message
            scalar_feat = scalar_feat + scalar_message

            vectorial_update, scalar_update = update_block(vectorial_feat, scalar_feat)
            vectorial_feat = vectorial_feat + vectorial_update
            scalar_feat = scalar_feat + scalar_update

        # [N, out_dim]
        scalar = self.out_proj(scalar_feat)

        return scalar

In [17]:
# Small PaiNN: 10 channels, 5 blocks, 20 radial basis functions, cutoff 0.5, scalar output.
p = PaiNN(embedding_dim=10, num_blocks=5, num_radial_basis=20, cut_off=0.5, out_dim=1)

In [18]:
# Call the module itself rather than .forward so nn.Module.__call__ runs any registered hooks.
p(atomic_number, node_pos, edge_index).shape

torch.Size([5, 1])

In [27]:
# Sum of all elements of a small integer tensor.
torch.tensor([1, 2, 3]).sum()

tensor(6)