In [1]:
# schnet
# https://arxiv.org/pdf/1706.08566.pdf

In [2]:
import torch
import torch.nn as nn
import numpy as np

In [None]:
# interaction block
# cfconv
# ssp(x) = ln(0.5 * (exp(x) + 1)) = ln(0.5) + ln(exp(x) + 1) = ln(exp(x) + 1) - ln(2)

In [None]:
# input shape: 
# in_node_feat: [N, feat_dim]
# node_pos: [N, 3]

In [13]:
# Toy graph used to exercise the layers: 5 nodes, 10 random features each, random 3D positions.
node_feat = torch.from_numpy(np.random.rand(5, 10)).float()
node_pos = torch.from_numpy(np.random.rand(5, 3)).float()

# Undirected node pairs; every pair is expanded into both directed edges (u, v) and (v, u).
undirected_pairs = [(0, 1), (0, 2), (0, 3), (0, 4),
                    (1, 2), (1, 4),
                    (2, 3), (2, 4),
                    (3, 4)]
directed_pairs = [edge for u, v in undirected_pairs for edge in ((u, v), (v, u))]
# shape [2, E]: row 0 is unpacked as `source`, row 1 as `target` further down
edge_index = torch.from_numpy(np.array(directed_pairs).T).to(torch.long)

In [69]:
# ssp(x) = ln(0.5 * (exp(x) + 1)) = ln(0.5) + ln(exp(x) + 1)
# why using this activation function? 
def shifted_softplus(x, shift=0.5):
    r"""
    Shifted softplus activation: ssp(x) = ln(shift * (exp(x) + 1)) = softplus(x) + ln(shift).

    Parameters:
        x: torch.tensor
            Input tensor.
        shift: float
            Positive factor inside the logarithm. ``None`` (or the default 0.5)
            gives ln(exp(x) + 1) - ln(2), for which ssp(0) = 0.

    Returns:
        ln(exp(x) + 1) + ln(shift), with the shape and dtype of x
    """
    import math

    if shift is None:
        shift = 0.5
    # ln(shift) is a double-precision Python scalar, so the addition adopts the
    # dtype/device of x instead of going through a float32 CPU tensor.
    return nn.functional.softplus(x) + math.log(shift)

In [18]:
# row 0 of edge_index -> source node ids, row 1 -> target node ids (each of shape [E])
source, target = edge_index[0], edge_index[1]

In [25]:
# 300 evenly spaced RBF centers covering [0, 30]
mu = torch.linspace(start=0, end=30, steps=300)

In [43]:
# centers repeated for every edge: [300] -> [1, 300] -> [E, 300]
a = mu[None, :].expand(edge_index.size(-1), 300)


In [46]:
# Euclidean distance between the two endpoints of every edge, shape [E]
edge_vec = node_pos[source] - node_pos[target]
d = edge_vec.pow(2).sum(dim=-1).pow(0.5)
# [E] -> [E, 1] -> [E, 300]
d = d[:, None].expand(edge_index.size(-1), 300)

In [58]:
# why fix gamma and split mu evenly?
def radial_basis_function(x, gamma, mu):
    r"""
    Gaussian radial basis exp(-gamma * (x - mu)^2), evaluated element-wise.

    Parameters:
        x (torch.tensor):
            values to expand
        gamma (float):
            factor multiplying the squared offset
        mu (torch.tensor):
            centers; combined with x through ``x - mu``, so it must broadcast against x
    """
    offset = x - mu
    exponent = offset.pow(2).mul(-gamma)
    return exponent.exp()

In [63]:
from torch import Tensor

class FilterGeneratingNetworks(nn.Module):
    r"""
    Expand the length of every edge on a grid of radial basis functions.

    Args:
        num_filters (int):
            number of filters
    """
    def __init__(self, num_filters: int):
        super(FilterGeneratingNetworks, self).__init__()
        self.num_filters = num_filters
        self.rbf = radial_basis_function

    def forward(self, node_pos: Tensor, edge_index: Tensor, lower_bound: float, upper_bound: float, gamma: float) -> Tensor:
        r"""
        Parameters:
            lower_bound (float):
                lower bound for mu values
            upper_bound (float):
                upper bound for mu values
            gamma (float)
            node_pos (torch.tensor):
                3D coordinates of nodes. Shape [N, 3]
            edge_index (torch.tensor)
                Shape [2, E]

        Returns:
            expanded_distance (torch.tensor):
                Shape [E, num_filters]
        """
        source, target = edge_index
        # shape [E]
        distance = torch.sum((node_pos[source] - node_pos[target]) ** 2, dim=-1) ** 0.5

        # shape [num_filters]; built with the dtype/device of the distances so the
        # expansion also works for GPU inputs and keeps full precision in float64
        mu = torch.linspace(lower_bound, upper_bound, self.num_filters,
                            dtype=distance.dtype, device=distance.device)

        # [E, 1] against [1, num_filters] broadcasts to [E, num_filters]
        expanded_distance = self.rbf(distance.unsqueeze(-1), gamma, mu.unsqueeze(0))

        return expanded_distance

In [76]:
class CFConv(nn.Module):
    r"""
    Continuous-filter convolution: x_i' = sum_j x_j * W(d_ij), where the filter
    W is produced from the RBF-expanded edge length by two dense layers.

    Args:
        num_filters (int):
            number of filters filter-generating networks
        hidden_dim (int):
            hidden dim for linear layer
        out_dim (int):
            final output dimension for CFConv
    """
    def __init__(self, num_filters: int, hidden_dim: int, out_dim: int):
        super(CFConv, self).__init__()
        self.num_filters = num_filters
        self.out_dim = out_dim
        self.filter_generating = FilterGeneratingNetworks(num_filters)
        self.linear_1 = nn.Linear(num_filters, hidden_dim)
        self.linear_2 = nn.Linear(hidden_dim, out_dim)
        self.activation = shifted_softplus

        self.reset_parameters()

    def reset_parameters(self):
        r"""Xavier-uniform weights and zero biases for both dense layers."""
        for layer in (self.linear_1, self.linear_2):
            nn.init.xavier_uniform_(layer.weight)
            nn.init.zeros_(layer.bias)

    def forward(self, in_node_feat: Tensor, node_pos: Tensor, edge_index: Tensor, lower_bound: float, upper_bound: float, gamma: float, shift = None) -> Tensor:
        r"""
        Apply Continuous-filter convolution on input node features

        Parameters:
            in_node_feat (torch.tensor):
                Input node features. Shape [N, feat_dim]; feat_dim must equal out_dim
            node_pos (torch.tensor):
                3D coordinates of nodes. Shape [N, 3]
            edge_index (torch.tensor):
                Shape [2, E]. Row 0 holds the source node of every edge, row 1 the target node
            upper_bound (float):
                Upper bound for mu values
            lower_bound (float):
                Lower bound for mu values
            gamma (float):
                gamma value for Radial Basis Function
            shift (float):
                shift value for Shifted Softplus function. If None then shift is assigned to 0.5

        Returns:
            out_node_feat (torch.tensor):
                Output node features. Shape [N, out_dim]
        """
        assert in_node_feat.shape[-1] == self.out_dim
        if shift is None:
            shift = 0.5

        # shape [E, num_filters]
        filters = self.filter_generating(node_pos, edge_index, lower_bound, upper_bound, gamma)
        # shape [E, num_filters] -> [E, hidden_dim]
        filters = self.activation(self.linear_1(filters), shift)
        # shape [E, hidden_dim] -> [E, out_dim]
        filters = self.activation(self.linear_2(filters), shift)

        source_index, target_index = edge_index
        # message carried by edge (source -> target): the *source* feature modulated
        # by the filter of that edge. shape [E, out_dim]
        message = in_node_feat[source_index] * filters

        # sum the incoming messages of every target node. shape [N, out_dim]
        # new_zeros keeps the buffer on the dtype and device of the messages
        out_node_feat = message.new_zeros(in_node_feat.shape[0], self.out_dim)
        out_node_feat.index_add_(0, target_index, message)

        return out_node_feat

In [84]:
class InteractionBlock(nn.Module):
    r"""
    Interaction block: atom-wise linear -> CFConv -> atom-wise linear ->
    shifted softplus -> atom-wise linear, with the block input added back
    as a residual.
    """
    def __init__(self, in_dim: int, num_filters: int, hidden_dim: int):
        super().__init__()

        # every layer maps in_dim -> in_dim so the residual sum is well defined
        self.atom_wise_layer_1 = nn.Linear(in_dim, in_dim)
        self.cfconv = CFConv(num_filters, hidden_dim, in_dim)
        self.atom_wise_layer_2 = nn.Linear(in_dim, in_dim)
        self.atom_wise_layer_3 = nn.Linear(in_dim, in_dim)
        self.activation = shifted_softplus

        self.reset_parameters()

    def reset_parameters(self):
        r"""Xavier-uniform weights and zero biases for the three atom-wise layers."""
        atom_wise_layers = (self.atom_wise_layer_1, self.atom_wise_layer_2, self.atom_wise_layer_3)
        for layer in atom_wise_layers:
            nn.init.xavier_uniform_(layer.weight)
            nn.init.zeros_(layer.bias)

    def forward(self, in_node_feat: Tensor, node_pos: Tensor, edge_index: Tensor, lower_bound: Tensor, upper_bound: Tensor, gamma: Tensor, shift = None) -> Tensor:
        r"""
        Parameters:
            in_node_feat: node features, [N, in_dim]
            node_pos, edge_index, lower_bound, upper_bound, gamma:
                forwarded unchanged to the CFConv layer
            shift: value handed to the activation and to CFConv; None means 0.5

        Returns:
            updated node features, [N, in_dim]
        """
        shift = 0.5 if shift is None else shift

        # [N, in_dim] -> [N, in_dim]
        hidden = self.atom_wise_layer_1(in_node_feat)
        hidden = self.cfconv(hidden, node_pos, edge_index, lower_bound, upper_bound, gamma, shift)
        hidden = self.activation(self.atom_wise_layer_2(hidden), shift)
        update = self.atom_wise_layer_3(hidden)

        # residual connection
        return update + in_node_feat

In [85]:
# feature width taken from the toy features; 300 RBF filters, 64 hidden units in the filter network
i = InteractionBlock(in_dim=node_feat.shape[-1], num_filters=300, hidden_dim=64)

In [87]:
# call the module itself rather than .forward() so that nn.Module.__call__ (and any registered hooks) runs
i(node_feat, node_pos, edge_index, 0, 300, 10, 0.5).shape

torch.Size([5, 10])

In [1]:
import torch 
import torch.nn as nn
import numpy as np
from torch import Tensor


class ShiftedSoftplus(nn.Module):
    r"""
    Shifted softplus activation: ssp(x) = ln(0.5 * (exp(x) + 1)) = softplus(x) - ln(2),
    so that ssp(0) = 0.
    """
    def __init__(self):
        super(ShiftedSoftplus, self).__init__()
        import math

        # ln(2) as a double-precision Python scalar: computed once, and subtracting it
        # adopts the dtype/device of the input instead of passing through float32
        self.shift = math.log(2.0)

    def forward(self, x):
        r"""
        Parameters:
            x (torch.tensor):
                input tensor

        Returns:
            softplus(x) - ln(2), with the shape and dtype of x
        """
        return nn.functional.softplus(x) - self.shift

class RadialBasisFunction(nn.Module):
    r"""Parameter-free module evaluating the Gaussian radial basis exp(-gamma * (x - mu)^2)."""
    def __init__(self):
        super().__init__()

    def forward(self, x, gamma, mu):
        r"""
        Parameters:
            x (torch.tensor):
                values to expand
            gamma (float):
                factor multiplying the squared offset
            mu (torch.tensor):
                centers; combined with x through ``x - mu``, so it must broadcast against x
        """
        offset = x - mu
        exponent = offset.pow(2).mul(-gamma)
        return exponent.exp()

class EmbeddingBlock(nn.Module):
    r"""
    Look-up table turning atomic numbers into initial atom representations.

    Args:
        embedding_dim (int):
            size of each atom embedding
        num_embeddings (int):
            size of the look-up table, i.e. atomic numbers 0 .. num_embeddings - 1
            are accepted. Defaults to 95. Index 0 is the padding index and
            always maps to the zero vector.
    """
    def __init__(self, embedding_dim: int, num_embeddings: int = 95):
        super(EmbeddingBlock, self).__init__()

        self.atomic_num_embedding = nn.Embedding(num_embeddings, embedding_dim, padding_idx = 0)

    def forward(self, atomic_num: Tensor):
        r"""
        Initialize atom's representation

        Parameters:
            atomic_num (torch.Tensor):
                Shape [N]

        Returns:
            node_feat (torch.Tensor):
                Shape [N, embedding_dim]
        """

        return self.atomic_num_embedding(atomic_num)

class FilterGeneratingNetworks(nn.Module):
    r"""
    Expand the length of every edge on a grid of radial basis functions.

    Args:
        num_filters (int):
            number of filters
    """
    def __init__(self, num_filters):
        super(FilterGeneratingNetworks, self).__init__()
        self.num_filters = num_filters
        self.rbf = RadialBasisFunction()

    def forward(self, node_pos, edge_index, lower_bound, upper_bound, gamma):
        r"""
        Parameters:
            lower_bound (float):
                lower bound for mu values
            upper_bound (float):
                upper bound for mu values
            gamma (float)
            node_pos (torch.tensor):
                3D coordinates of nodes. Shape [N, 3]
            edge_index (torch.tensor)
                Shape [2, E]

        Returns:
            expanded_distance (torch.tensor):
                Shape [E, num_filters]
        """
        source, target = edge_index
        # shape [E]
        distance = torch.sum((node_pos[source] - node_pos[target]) ** 2, dim=-1) ** 0.5

        # shape [num_filters]; built with the dtype/device of the distances so the
        # expansion also works for GPU inputs and keeps full precision in float64
        mu = torch.linspace(lower_bound, upper_bound, self.num_filters,
                            dtype=distance.dtype, device=distance.device)

        # [E, 1] against [1, num_filters] broadcasts to [E, num_filters]
        expanded_distance = self.rbf(distance.unsqueeze(-1), gamma, mu.unsqueeze(0))

        return expanded_distance
    
# Why continuous filters? (the left/right references point to the corresponding figure in the SchNet paper)
# The discrete filter (left) is not able to capture the subtle positional changes of the atoms, resulting in
# discontinuous energy predictions Ê (bottom left). The continuous filter captures these changes and yields
# smooth energy predictions (bottom right).

class CFConv(nn.Module):
    r"""
    Continuous-filter convolution: x_i' = sum_j x_j * W(d_ij), where the filter
    W is produced from the RBF-expanded edge length by two dense layers.

    Args:
        num_filters (int):
            number of filters filter-generating networks
        hidden_dim (int):
            hidden dim for linear layer
        out_dim (int):
            final output dimension for CFConv
    """
    def __init__(self, num_filters, hidden_dim, out_dim):
        super(CFConv, self).__init__()
        self.num_filters = num_filters
        self.out_dim = out_dim
        self.filter_generating = FilterGeneratingNetworks(num_filters)
        self.linear_1 = nn.Linear(num_filters, hidden_dim)
        self.linear_2 = nn.Linear(hidden_dim, out_dim)
        self.activation = ShiftedSoftplus()

        self.reset_parameters()

    def reset_parameters(self):
        r"""Xavier-uniform weights and zero biases for both dense layers."""
        for layer in (self.linear_1, self.linear_2):
            nn.init.xavier_uniform_(layer.weight)
            nn.init.zeros_(layer.bias)

    def forward(self, in_node_feat, node_pos, edge_index, lower_bound, upper_bound, gamma):
        r"""
        Apply Continuous-filter convolution on input node features

        Parameters:
            in_node_feat (torch.tensor):
                Input node features. Shape [N, feat_dim]; feat_dim must equal out_dim
            node_pos (torch.tensor):
                3D coordinates of nodes. Shape [N, 3]
            edge_index (torch.tensor):
                Shape [2, E]. Row 0 holds the source node of every edge, row 1 the target node
            upper_bound (float):
                Upper bound for mu values
            lower_bound (float):
                Lower bound for mu values
            gamma (float):
                gamma value for Radial Basis Function

        Returns:
            out_node_feat (torch.tensor):
                Output node features. Shape [N, out_dim]
        """
        assert in_node_feat.shape[-1] == self.out_dim

        # shape [E, num_filters]
        filters = self.filter_generating(node_pos, edge_index, lower_bound, upper_bound, gamma)
        # shape [E, num_filters] -> [E, hidden_dim]
        filters = self.activation(self.linear_1(filters))
        # shape [E, hidden_dim] -> [E, out_dim]
        filters = self.activation(self.linear_2(filters))

        source_index, target_index = edge_index
        # message carried by edge (source -> target): the *source* feature modulated
        # by the filter of that edge. shape [E, out_dim]
        message = in_node_feat[source_index] * filters

        # sum the incoming messages of every target node. shape [N, out_dim]
        # new_zeros keeps the buffer on the dtype and device of the messages
        out_node_feat = message.new_zeros(in_node_feat.shape[0], self.out_dim)
        out_node_feat.index_add_(0, target_index, message)

        return out_node_feat

class InteractionBlock(nn.Module):
    r"""
    Interaction block: atom-wise linear -> CFConv -> atom-wise linear ->
    shifted softplus -> atom-wise linear, with the block input added back
    as a residual.
    """
    def __init__(self, in_dim, num_filters, hidden_dim):
        super().__init__()

        # every layer maps in_dim -> in_dim so the residual sum is well defined
        self.atom_wise_layer_1 = nn.Linear(in_dim, in_dim)
        self.cfconv = CFConv(num_filters, hidden_dim, in_dim)
        self.atom_wise_layer_2 = nn.Linear(in_dim, in_dim)
        self.atom_wise_layer_3 = nn.Linear(in_dim, in_dim)
        self.activation = ShiftedSoftplus()

        self.reset_parameters()

    def reset_parameters(self):
        r"""Xavier-uniform weights and zero biases for the three atom-wise layers."""
        atom_wise_layers = (self.atom_wise_layer_1, self.atom_wise_layer_2, self.atom_wise_layer_3)
        for layer in atom_wise_layers:
            nn.init.xavier_uniform_(layer.weight)
            nn.init.zeros_(layer.bias)

    def forward(self, in_node_feat, node_pos, edge_index, lower_bound, upper_bound, gamma):
        r"""
        Parameters:
            in_node_feat: node features, [N, in_dim]
            node_pos, edge_index, lower_bound, upper_bound, gamma:
                forwarded unchanged to the CFConv layer

        Returns:
            updated node features, [N, in_dim]
        """
        # [N, in_dim] -> [N, in_dim]
        hidden = self.atom_wise_layer_1(in_node_feat)
        hidden = self.cfconv(hidden, node_pos, edge_index, lower_bound, upper_bound, gamma)
        hidden = self.activation(self.atom_wise_layer_2(hidden))
        update = self.atom_wise_layer_3(hidden)

        # residual connection
        return update + in_node_feat

class SchNet(nn.Module):
    r"""
    SchNet energy model: atom embedding -> stacked interaction blocks ->
    atom-wise read-out (linear, shifted softplus, linear) -> sum over atoms.
    """
    def __init__(self, num_interaction_block: int, hidden_dim: int, num_filters):
        super().__init__()

        self.embedding_block = EmbeddingBlock(hidden_dim)

        blocks = []
        for _ in range(num_interaction_block):
            blocks.append(InteractionBlock(hidden_dim, num_filters, hidden_dim))
        self.interaction_blocks = nn.ModuleList(blocks)

        half_dim = hidden_dim // 2
        self.linear_1 = nn.Linear(hidden_dim, half_dim)
        self.activation = ShiftedSoftplus()
        self.linear_2 = nn.Linear(half_dim, 1)

    def forward(self, atomic_num: Tensor, node_pos: Tensor, edge_index: Tensor,
            lower_bound: float = 0.0, upper_bound: float = 30.0, gamma: float = 10.0):
        r"""
        Parameters:
            atomic_num: input of the embedding block
            node_pos, edge_index, lower_bound, upper_bound, gamma:
                forwarded unchanged to every interaction block

        Returns:
            energy: sum of the per-atom read-out values
        """
        # Shape [N, hidden_dim]
        node_repr = self.embedding_block(atomic_num)

        for block in self.interaction_blocks:
            node_repr = block(node_repr, node_pos, edge_index, lower_bound, upper_bound, gamma)

        # Shape [N, hidden_dim] -> [N, hidden_dim // 2]
        node_repr = self.activation(self.linear_1(node_repr))

        # Shape [N, hidden_dim // 2] -> [N, 1]
        atom_energy = self.linear_2(node_repr)

        return atom_energy.squeeze().sum()

# Toy molecule used to smoke-test the model: 5 atoms, random features and random 3D positions.
node_feat = torch.from_numpy(np.random.rand(5, 10)).float()
node_pos = torch.from_numpy(np.random.rand(5, 3)).float()

# Undirected atom pairs; every pair is expanded into both directed edges (u, v) and (v, u).
undirected_pairs = [(0, 1), (0, 2), (0, 3), (0, 4),
                    (1, 2), (1, 4),
                    (2, 3), (2, 4),
                    (3, 4)]
directed_pairs = [edge for u, v in undirected_pairs for edge in ((u, v), (v, u))]
# shape [2, E]
edge_index = torch.from_numpy(np.array(directed_pairs).T).to(torch.long)

# one random integer in [0, 95) per atom
atomic_num = torch.randint(low=0, high=95, size=(5,))

In [2]:
# 10 interaction blocks, 15-dimensional atom features, 8 radial basis filters
s = SchNet(num_interaction_block=10, hidden_dim=15, num_filters=8)

In [3]:
# call the module itself rather than .forward() so that nn.Module.__call__ (and any registered hooks) runs
s(atomic_num, node_pos, edge_index)

tensor(-1.2510, grad_fn=<SumBackward0>)