In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Linear, Embedding
from torch_geometric.nn import MessagePassing
from torch_geometric.data import Data
import numpy as np
import math
from typing import Optional, Tuple, Dict, Any
from rdkit import Chem
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import seaborn as sns
from pathlib import Path
import warnings
warnings.filterwarnings('ignore')

In [None]:
class GaussianRadialBasisFunction(nn.Module):
    """Expand scalar distances into a vector of Gaussian radial basis features.

    ``num_rbf`` Gaussians are centred on an evenly spaced grid from ``rbf_min``
    to ``rbf_max`` and share one width, equal to the spacing between adjacent
    centres. A distance ``d`` is mapped to ``exp(-((d - mu_k) / width) ** 2)``
    for every centre ``mu_k``.

    Args:
        num_rbf: Number of Gaussian basis functions (must be >= 1).
        rbf_max: Position of the last centre.
        rbf_min: Position of the first centre (must be < ``rbf_max``).

    Raises:
        ValueError: If ``num_rbf < 1`` or ``rbf_max <= rbf_min`` (the latter
            would give a zero width and NaN/inf features).
    """

    def __init__(self, num_rbf: int = 50, rbf_max: float = 10.0, rbf_min: float = 0.0):
        super().__init__()
        if num_rbf < 1:
            raise ValueError(f"num_rbf must be >= 1, got {num_rbf}")
        if rbf_max <= rbf_min:
            raise ValueError(
                f"rbf_max ({rbf_max}) must be greater than rbf_min ({rbf_min})"
            )
        self.num_rbf = num_rbf
        self.rbf_max = rbf_max
        self.rbf_min = rbf_min

        # Evenly spaced centres, registered as a buffer (not a trainable parameter).
        centers = torch.linspace(rbf_min, rbf_max, num_rbf)
        self.register_buffer('centers', centers)

        # Shared width = spacing between neighbouring centres. A single Gaussian
        # has no spacing, so use the full range instead of dividing by zero.
        span = rbf_max - rbf_min
        self.width = span / (num_rbf - 1) if num_rbf > 1 else span

    def forward(self, distances: torch.Tensor) -> torch.Tensor:
        """Encode distances with the Gaussian basis.

        Args:
            distances: Tensor of distances, e.g. ``(n_edges,)``; any number of
                leading dimensions is accepted.

        Returns:
            Tensor of shape ``(*distances.shape, num_rbf)`` with values in
            ``[0, 1]``. A 0-d input yields shape ``(1, num_rbf)``.
        """
        # Trailing singleton axis broadcasts each distance against every centre.
        d = distances.unsqueeze(-1)
        centers = self.centers.unsqueeze(0)  # (1, num_rbf)

        return torch.exp(-((d - centers) / self.width) ** 2)

def spherical_harmonics_l1(pos: torch.Tensor) -> torch.Tensor:
    """Evaluate real spherical harmonics of degree l = 0 and l = 1 on directions.

    Only the direction of each vector matters. ``pos`` is first scaled to unit
    length, with 1e-8 added to the norm so that zero vectors stay finite and map
    to zero in the l = 1 components.

    Args:
        pos: Tensor whose last axis holds the (x, y, z) components; expected
            shape ``(..., 3)``.

    Returns:
        Tensor of shape ``(..., 4)`` ordered as
        ``[Y_0^0, Y_1^{-1}, Y_1^0, Y_1^1] = [c0, c1*y, c1*z, c1*x]``,
        with ``c0 = 1/(2*sqrt(pi))`` and ``c1 = sqrt(3/(4*pi))``.
    """
    unit = pos / (pos.norm(dim=-1, keepdim=True) + 1e-8)

    # Degree 0: constant over the sphere, one value per input vector.
    degree0 = torch.full_like(unit[..., 0], 0.28209479177)  # 1/(2*sqrt(pi))

    # Degree 1: components gathered in (y, z, x) order to match m = -1, 0, +1.
    degree1 = 0.48860251190 * unit[..., [1, 2, 0]]  # sqrt(3/(4*pi))

    return torch.cat([degree0.unsqueeze(-1), degree1], dim=-1)

* **Gaussian radial basis functions (RBF):** embed interatomic (atom-pair) distances as a smooth vector of Gaussian activations.
* **Spherical harmonics:** use low-order real spherical harmonics ($l = 0, 1$) to encode the direction (angular information) of each vector.