In [None]:
# Install libraries
!pip install --pre dgl-cu101

# Import libraries
import math

import torch
from torch import nn
import torch.nn.functional as F

import dgl
import dgl.function as fn
from dgl import DGLGraph

from scipy.special import sph_harm as sph_harm_func

In [None]:
from numpy import sqrt
from scipy.special import factorial


def clebsch_gordan(j1,j2,m1,m2,J,M=None):
  '''Clebsch–Gordan coefficient <j1 m1; j2 m2 | J M>.

  Evaluates the explicit finite-sum formula from Wikipedia:
  https://en.wikipedia.org/wiki/Table_of_Clebsch–Gordan_coefficients.
  Arguments may be integers or half-integers. M defaults to m1+m2, and the
  coefficient is 0 whenever M != m1+m2.

  Negative M and j1 < j2 are first mapped onto M >= 0, j1 >= j2 with the
  standard symmetry relations, each contributing a factor (-1)**(J-j1-j2).

  Uses floating-point factorials, so it is neither fast nor numerically
  stable for large momenta (speed isn't too important).
  TODO: Possibly find more stable implementation'''
  if M is None:
    M = m1 + m2

  # Symmetry reductions: flip every projection, or swap the two momenta.
  if M < 0:
    phase = (-1)**(J-j1-j2)
    return phase*clebsch_gordan(j1,j2,-m1,-m2,J,-M)
  if j1 < j2:
    phase = (-1)**(J-j1-j2)
    return phase*clebsch_gordan(j2,j1,m2,m1,J,M)
  if M != m1+m2:
    return 0

  # Normalisation prefactors of the formula.
  norm_j = sqrt((2*J+1)*factorial(J+j1-j2)*factorial(J-j1+j2)*factorial(j1+j2-J)/factorial(j1+j2+J+1))
  norm_m = sqrt(factorial(J+M)*factorial(J-M)*factorial(j1-m1)*factorial(j1+m1)*factorial(j2-m2)*factorial(j2+m2))

  # Summation limits: the k for which every factorial argument in term(k)
  # is non-negative.
  k_hi = min([j1+j2-J, j1-m1, j2+m2])
  k_lo = max([0, -(J-j2+m1), -(J-j1-m2)])

  def term(k):
    return (-1)**k/(factorial(k)*factorial(j1+j2-J-k)*factorial(j1-m1-k)*factorial(j2+m2-k)*factorial(J-j2+m1+k)*factorial(J-j1-m2+k))

  series = sum(term(k) for k in range(int(k_lo), int(k_hi)+1))
  return norm_j*norm_m*series

def clebsch_gordan_mat(j1,j2,J):
  '''Return the (2*j1+1, 2*j2+1) matrix of <j1 m1; j2 m2 | J m1+m2>.

  Row i corresponds to m1 = -j1 + i and column j to m2 = -j2 + j.
  j1, j2 and J may be integers or half-integers.
  '''
  rows, cols = int(2*j1+1), int(2*j2+1)
  mat = torch.zeros((rows, cols))
  for i in range(rows):
    # Plain Python numbers rather than 0-d torch tensors: clebsch_gordan hands
    # these to scipy.special.factorial, and recent SciPy releases raise on
    # scalar types that are not NumPy/Python ints or floats.
    m1 = -j1 + i
    for j in range(cols):
      m2 = -j2 + j
      mat[i,j] = clebsch_gordan(j1,j2,m1,m2,J,m1+m2)
  return mat

def clebsch_gordan_mats(j1,j2):
  '''Stack the Clebsch–Gordan matrices for every allowed total momentum J.

  Returns a tensor of shape (j1+j2-|j1-j2|+1, 2*j1+1, 2*j2+1) whose entry
  [x, y, z] is <j1 m1; j2 m2 | J m1+m2> with J = |j1-j2| + x,
  m1 = -j1 + y and m2 = -j2 + z. Works for integer and half-integer j1, j2.
  '''
  J_min = abs(j1-j2)
  # int(): sizes are float-valued when j1/j2 are half-integers, which
  # torch.zeros does not accept.
  J_size = int(j1+j2-J_min+1)
  dim1, dim2 = int(2*j1+1), int(2*j2+1)
  mats = torch.zeros(J_size, dim1, dim2)
  for x in range(J_size):
    J = J_min + x
    for y in range(dim1):
      m1 = -j1 + y
      for z in range(dim2):
        m2 = -j2 + z
        # Python numbers (not torch scalars) keep scipy's factorial happy.
        mats[x,y,z] = clebsch_gordan(j1,j2,m1,m2,J)
  return mats

# Sanity check against a known value: <1/2 -1/2; 1/2 1/2 | 1 0> = 1/sqrt(2).
cg_check = clebsch_gordan(1/2,1/2,-1/2,1/2,1,0)
assert abs(cg_check - 1/sqrt(2)) < 1e-12, cg_check
cg_check

In [None]:
class TFNLayer(nn.Module):
  '''Work-in-progress tensor-field-network style message-passing layer.

  Judging by the einsum labels in ``message_func``, it presumably maps type-k
  input features to type-l output features. Edge weights combine a learned
  radial network of the edge length with the angular basis returned by
  ``g.get_wj(l, k)``.

  NOTE(review): the tensor shapes that ``forward`` stores on the graph do not
  yet match what ``message_func``'s einsum expects (see TODOs below).
  '''
  def __init__(self, l, k, feats):
    super(TFNLayer, self).__init__()
    self.feats = feats
    self.l = l
    self.k = k
    # Span of the allowed coupled orders J in [|l-k|, l+k]
    # (there are j_range + 1 of them).
    self.j_range = l+k-abs(l-k)
    # NOTE(review): the original referenced an undefined ``self.c``. The radial
    # network's output width is assumed to be the channel count ``feats``.
    # Confirm.
    self.c = feats
    # Maps a scalar edge length to ``c`` radial weights.
    self.radial = nn.Sequential(nn.Linear(1,5), nn.ReLU(), nn.Linear(5,self.c))

  def forward(self, g, feat):
    '''Run one round of message passing and return the updated node features.

    ``g`` must provide ``get_wj(l, k)`` and an ``edata['dist']`` entry (as
    PointCloudGraph does). ``feat`` is stored as ``ndata['feat']``, presumably
    shaped (N, c, k) per the einsum in ``message_func``. TODO: confirm.
    '''
    with g.local_scope():
      g.ndata['feat'] = feat
      # TODO: edata entries need a leading edge axis; check get_wj's output.
      g.edata['wj'] = g.get_wj(self.l, self.k)
      # nn.Linear(1, ...) needs a trailing feature axis of size 1.
      dist = g.edata['dist'].reshape(-1, 1)
      g.edata['r'] = self.radial(dist)
      g.update_all(self.message_func, self.reduce_func)
      return g.ndata['feat']

  def message_func(self, edges):
    '''DGL message function: combine radial weights, angular basis and source features.'''
    # NOTE(review): the einsum expects r as (E, c, l, k) and wj as (E, j, k, l),
    # but forward's radial network produces r with shape (E, c). TODO reconcile.
    r    = edges.data['r']
    wj   = edges.data['wj']
    feat = edges.src['feat']
    return {'f': torch.einsum('eclk,ejkl,eck->el',r,wj,feat)}

  def reduce_func(self, nodes):
    '''DGL reduce function: sum the incoming messages into the new node feature.'''
    return {'feat': torch.sum(nodes.mailbox['f'], dim=1)}

In [None]:
class PointCloudGraph(dgl.DGLGraph):
  '''Represents a point cloud in R^3 as a DGL graph.

  Nodes are the points. Both directed edges (i, j) and (j, i) are added for
  every pair of distinct points closer than ``cutoff``. Per-edge vectors,
  distances, unit directions and spherical angles are stored in ``edata``.
  Spherical harmonics and the ``wj`` angular bases are computed lazily and
  cached.'''
  def __init__(self, pos, cutoff=5.0):
    '''
    Args:
      pos: point positions, indexed as (num_points, 3).
      cutoff: pairs of points strictly closer than this are connected.
    '''
    # Neighbour search builds the full pairwise distance matrix, so it is
    # still O(n^2) in time and memory.
    #
    # TODO: Allow user to provide own graph if there is a better method of
    # finding the graph structure (i.e. MD simulations sometimes use cell lists
    # or other structures which would be more efficient).
    #
    # NOTE(review): the node count is presumably inferred from the largest
    # index in (u, v). Trailing points with no neighbours would then be
    # missing and the 'pos' assignment below would fail. Confirm.
    self.cutoff = cutoff
    u,v = self._create_graph_(pos, cutoff)
    super(PointCloudGraph, self).__init__((u,v))
    self.ndata['pos'] = pos.clone().detach()
    self._calc_edge_info_()
    self.sph_harm = dict()
    self.wj = dict()

  def _create_graph_(self, pos, cutoff):
    '''Return (src, dst) index tensors for all pairs closer than ``cutoff``.

    Also caches ``vec_mat[i, j] = pos[i] - pos[j]`` and ``dist_mat`` for
    ``_calc_edge_info_``. The first half of the edges is every (i, j) with
    i < j, ordered by j then i; the second half is the same pairs reversed.'''
    self.vec_mat = pos[:,None,:]-pos[None,:,:]
    self.dist_mat = torch.sqrt(torch.sum((self.vec_mat)**2,dim=-1))
    idx = torch.arange(pos.shape[0], device=pos.device)
    # Strictly-lower triangle: row j, column i with i < j, so each pair is
    # visited once. dist_mat is symmetric, so testing [j, i] equals [i, j].
    # Row-major nonzero order gives "for j: for i < j" ordering.
    close = (self.dist_mat < cutoff) & (idx[:,None] > idx[None,:])
    j_idx, i_idx = torch.nonzero(close, as_tuple=True)
    # nonzero yields int64 even when empty (a bare torch.tensor([]) is float).
    u = torch.cat([i_idx, j_idx])
    v = torch.cat([j_idx, i_idx])
    return u,v

  def _calc_edge_info_(self):
    '''Store per-edge geometry in edata: 'vec', 'dist', 'dir', 'theta', 'phi'.

    For an edge u -> v, 'vec' is pos[u] - pos[v] (taken from ``vec_mat``).
    'theta' is the azimuthal angle in the x-y plane and 'phi' the polar angle
    from +z, matching scipy.special.sph_harm's (theta, phi) convention.'''
    # NOTE(review): two coincident points would give dist == 0 and NaN dirs.
    u, v = self.edges()
    vec = self.vec_mat[u,v]
    self.edata['vec'] = vec
    dist = self.dist_mat[u,v]
    self.edata['dist'] = dist
    direction = vec/dist[:,None]
    self.edata['dir'] = direction
    self.edata['theta'] = torch.atan2(direction[:,1], direction[:,0])
    # Clamp guards acos against |z| drifting just past 1 through rounding.
    self.edata['phi'] = torch.acos(direction[:,2].clamp(-1.0, 1.0))

  def get_sph_harm(self, J):
    '''Return the degree-J spherical harmonics of every edge direction.

    The result is a complex CPU tensor indexed (edge, m + J) for m = -J..J,
    cached per J.'''
    if J not in self.sph_harm:
      m = torch.arange(-J,J+1)
      theta = self.edata['theta'].detach().cpu().numpy()
      phi = self.edata['phi'].detach().cpu().numpy()
      # scipy's signature is sph_harm(m, n, theta, phi): order m first, degree
      # n second. NOTE: newer SciPy deprecates sph_harm in favour of
      # sph_harm_y, which uses a different argument/angle convention.
      values = sph_harm_func(m.numpy()[None,:], J, theta[:,None], phi[:,None])
      self.sph_harm[J] = torch.as_tensor(values)
    return self.sph_harm[J]

  def get_wj(self, l, k):
    '''Return the per-edge angular basis for orders (l, k), cached per (l, k).

    Shape (num_edges, l+k-|l-k|+1, 2l+1, 2k+1). Entry [e, x, a, b] is
    Y_{J,M}(dir_e) * <l m1; k m2 | J M> with J = |l-k| + x, m1 = -l + a,
    m2 = -k + b, M = m1 + m2.

    This is the intended contraction sum_M Y_{J,M} <l m1; k m2 | J M> (the
    original einsum 'jm,jmlk->jlk'), evaluated per edge. clebsch_gordan is
    zero unless M = m1 + m2, so the sum collapses to a single term and only
    the M = m1 + m2 slice from clebsch_gordan_mats is needed.'''
    if (l,k) not in self.wj:
      cg_mats = clebsch_gordan_mats(l,k)
      m1 = torch.arange(cg_mats.shape[1], dtype=torch.float64) - l
      m2 = torch.arange(cg_mats.shape[2], dtype=torch.float64) - k
      M = m1[:,None] + m2[None,:]
      J_min = abs(l-k)
      blocks = []
      for x in range(cg_mats.shape[0]):
        J = J_min + x
        sh = self.get_sph_harm(J)
        # Column of Y_J holding m = M. Entries with |M| > J are clamped to a
        # valid column; the CG factor there is expected to be zero.
        col = (M + J).round().long().clamp(0, int(round(2*J)))
        blocks.append(sh[:, col] * cg_mats[x])
      self.wj[(l,k)] = torch.stack(blocks, dim=1)
    return self.wj[(l,k)]

In [None]:
# Create a test graph (seeded so Restart & Run All is reproducible)
torch.manual_seed(0)
num_pts = 10
pos = torch.rand(num_pts,3)
graph = PointCloudGraph(pos)
print(graph)

# Check the degree-1 spherical harmonics against the edge directions. With
# scipy's convention (Condon–Shortley phase, theta azimuthal, phi polar):
#   Y_1^-1 =  sqrt(3/(8 pi)) * (x - i y)
#   Y_1^0  =  sqrt(3/(4 pi)) * z
#   Y_1^1  = -sqrt(3/(8 pi)) * (x + i y)
sh1 = torch.as_tensor(graph.get_sph_harm(1))   # columns ordered m = -1, 0, 1
dirs = graph.edata['dir'].double()
x, y, z = dirs[:,0], dirs[:,1], dirs[:,2]
c0 = math.sqrt(3/(4*math.pi))
c1 = math.sqrt(3/(8*math.pi))
assert torch.allclose(sh1[:,0], torch.complex(c1*x, -c1*y), atol=1e-5)
assert torch.allclose(sh1[:,1], torch.complex(c0*z, torch.zeros_like(z)), atol=1e-5)
assert torch.allclose(sh1[:,2], torch.complex(-c1*x, -c1*y), atol=1e-5)