In [None]:
# Install libraries
!pip install --pre dgl-cu101

# Import libraries
import torch
from torch import nn
import torch.nn.functional as F

import dgl
import dgl.function as fn
from dgl import DGLGraph

import numpy as np

from scipy.special import sph_harm as sph_harm_func

# The `spherical` package (commented out below) provides Clebsch–Gordan
# coefficients and Wigner D-matrices, but using it crashed the session, so
# both are implemented by hand in the next cells.

#!pip install spherical
#import spherical
#import quaternionic

In [None]:
from numpy import sqrt
from scipy.special import factorial


def clebsch_gordan(j1,j2,m1,m2,J,M=None):
  '''Clebsch-Gordan coefficient <j1 m1; j2 m2 | J M> via the explicit formula from
  https://en.wikipedia.org/wiki/Table_of_Clebsch–Gordan_coefficients.

  M defaults to m1 + m2, and the result is 0 whenever M != m1 + m2. Negative M
  and j1 < j2 are first mapped onto M >= 0, j1 >= j2 using the symmetry
  relations, each carrying the phase (-1)**(J-j1-j2).

  The formula isn't numerically stable and not fast (speed isn't too important).
  TODO: Possibly find a more stable implementation.'''
  # TODO: Replace with spherical.clebsch_gordan
  if M is None:
    M = m1 + m2
  phase = (-1)**(J-j1-j2)
  if M < 0:
    return phase*clebsch_gordan(j1,j2,-m1,-m2,J,-M)
  if j1 < j2:
    return phase*clebsch_gordan(j2,j1,m2,m1,J,M)
  if M != m1 + m2:
    return 0
  norm = sqrt((2*J+1)*factorial(J+j1-j2)*factorial(J-j1+j2)*factorial(j1+j2-J)
              /factorial(j1+j2+J+1))
  mfac = sqrt(factorial(J+M)*factorial(J-M)*factorial(j1-m1)*factorial(j1+m1)
              *factorial(j2-m2)*factorial(j2+m2))
  # Summation bounds keep every factorial argument in the series non-negative.
  lo = max(0, -(J-j2+m1), -(J-j1-m2))
  hi = min(j1+j2-J, j1-m1, j2+m2)
  series = sum((-1)**k/(factorial(k)*factorial(j1+j2-J-k)*factorial(j1-m1-k)
                        *factorial(j2+m2-k)*factorial(J-j2+m1+k)*factorial(J-j1-m2+k))
               for k in range(int(lo), int(hi)+1))
  return norm*mfac*series

def clebsch_gordan_mats(j1,j2,J):
  '''Tensor of Clebsch-Gordan coefficients coupling (j1, j2) -> J.

  mats[x, y, z] = <j1 m1; j2 m2 | J m> with m = -J + x, m1 = -j1 + y and
  m2 = -j2 + z, so the shape is (2J+1, 2j1+1, 2j2+1). Entries with
  m != m1 + m2 are zero. Integer and half-integer j are both accepted.
  '''
  def _ms(j):
    # Magnetic quantum numbers -j, -j+1, ..., j as plain Python numbers (ints
    # for integer j), so clebsch_gordan/scipy receive scalars, not tensors.
    return [-j + i for i in range(int(round(2*j + 1)))]

  ms, m1s, m2s = _ms(J), _ms(j1), _ms(j2)
  mats = torch.zeros(len(ms), len(m1s), len(m2s))
  for x, m in enumerate(ms):
    for y, m1 in enumerate(m1s):
      for z, m2 in enumerate(m2s):
        # The row's m must be passed as M. Without it, M defaults to m1 + m2
        # and every row x would hold the same values. Only m == m1 + m2 can
        # be non-zero.
        if m == m1 + m2:
          mats[x, y, z] = clebsch_gordan(j1, j2, m1, m2, J, m)
  return mats

# Sanity check: <1/2 -1/2; 1/2 1/2 | 1 0> = 1/sqrt(2) (Condon-Shortley convention).
cg_check = clebsch_gordan(1/2, 1/2, -1/2, 1/2, 1, 0)
assert abs(cg_check - 1/sqrt(2)) < 1e-12
cg_check

In [None]:
from numpy import sqrt, cos, sin
from scipy.special import factorial

def wigner_small(j,b,mp,m):
  '''Wigner small-d matrix element d^j_{mp,m}(b) from the explicit sum formula.

  j, mp, m must be integers here (the sum index is built with range()).
  '''
  prefactor = sqrt(factorial(j+mp)*factorial(j-mp)*factorial(j+m)*factorial(j-m))
  half_c, half_s = cos(b/2), sin(b/2)
  total = 0
  # s runs over the values for which every factorial in the denominator is >= 0.
  for s in range(max(0, m-mp), min(j+m, j-mp)+1):
    numer = (-1)**(mp-m+s)*(half_c**(2*j+m-mp-2*s))*(half_s**(mp-m+2*s))
    denom = factorial(j+m-s)*factorial(s)*factorial(mp-m+s)*factorial(j-mp-s)
    total += numer/denom
  return prefactor*total

def wigner_small_mat(l, b):
  '''(2l+1, 2l+1) float tensor with mat[r, c] = wigner_small(l, b, c-l, r-l).

  Rows index m = r-l and columns index mp = c-l.
  NOTE(review): this is the transpose of the layout with mp as the row index;
  confirm that is intended.'''
  size = 2*l+1
  mat = torch.zeros(size, size)
  for row in range(size):
    for col in range(size):
      mat[row, col] = wigner_small(l, b, col-l, row-l)
  return mat

# Sanity checks: d^1_{11}(pi/2) = (1 + cos(pi/2))/2 = 1/2, and small-d matrices
# are real orthogonal (float32 storage, hence the loose tolerance).
d1 = wigner_small_mat(1, np.pi/2)
assert abs(wigner_small(1, np.pi/2, 1, 1) - 0.5) < 1e-12
assert torch.allclose(d1 @ d1.T, torch.eye(3), atol=1e-6)
d1

In [None]:
class PointCloud:
  '''A point cloud in R^3 together with its radius graph.

  Every pair of points closer than `cutoff` is joined by an edge in both
  directions, and per-edge geometry (relative vector, length, unit direction
  and its spherical angles) is stored in `self.graph.edata`. Spherical
  harmonics and kernel bases are computed lazily and cached per degree.
  '''
  def __init__(self, pos, cutoff=8.0):
    edges = self._find_edges_(pos, cutoff)
    # num_nodes keeps points without neighbours (e.g. the highest index) in
    # the graph. Otherwise ndata['pos'] would not match the node count.
    self.graph = dgl.graph(edges, num_nodes=len(pos))
    self.graph.ndata['pos'] = pos
    self._calc_edge_info_()
    self.sph_harm = dict()
    self.wj = dict()

  def _find_edges_(self, pos, cutoff):
    '''Return (src, dst) index tensors of all pairs with distance < cutoff.

    Also caches self.vec_mat (N, N, 3) with vec_mat[i, j] = pos[i] - pos[j]
    and self.dist_mat (N, N). Each close pair i < j yields the edge i -> j
    (ordered by j, then i), followed by all reversed edges j -> i. The index
    tensors are int64 even when there are no edges.
    '''
    # Still O(N^2) memory for the dense distance matrix, but no Python loop.
    self.vec_mat = pos[:,None,:]-pos[None,:,:]
    self.dist_mat = torch.sqrt(torch.sum((self.vec_mat)**2, dim=-1))
    idx = torch.arange(len(pos))
    # close[j, i] is True for i < j with dist(i, j) < cutoff. Row-major
    # nonzero() then lists pairs ordered by j, then i.
    close = (self.dist_mat.t() < cutoff) & (idx[:, None] > idx[None, :])
    j, i = torch.nonzero(close, as_tuple=True)
    u, v = torch.cat([i, j]), torch.cat([j, i])
    return (u, v)

  def _calc_edge_info_(self):
    '''Store per-edge geometry in self.graph.edata.

    Keys: 'vec' (pos[src] - pos[dst]), 'dist', 'dir' (unit vector), plus the
    angles of 'dir': 'theta' = atan2(y, x) (azimuth) and 'phi' = arccos(z)
    (polar). This matches scipy's sph_harm(m, n, theta, phi) argument order.
    '''
    u, v = self.graph.edges()
    vec = self.vec_mat[u, v]
    dist = self.dist_mat[u, v]
    direction = vec / dist[:, None]  # NaN if two points coincide (dist == 0)
    self.graph.edata['vec'] = vec
    self.graph.edata['dist'] = dist
    self.graph.edata['dir'] = direction
    self.graph.edata['theta'] = torch.atan2(direction[:, 1], direction[:, 0])
    self.graph.edata['phi'] = torch.arccos(direction[:, 2])

  def get_sph_harm(self, J):
    '''Real part of the degree-J spherical harmonics of every edge direction.

    Returns a float64 tensor of shape (E, 2J+1), with columns m = -J..J.
    The result is cached per J.
    '''
    if J not in self.sph_harm:
      m = torch.arange(-J, J+1)
      theta = self.graph.edata['theta']
      phi = self.graph.edata['phi']
      self.sph_harm[J] = torch.real(sph_harm_func(m[None,:], J, theta[:,None], phi[:,None])).double()
    return self.sph_harm[J]

  def get_wj(self, l, k):
    '''Kernel basis for output degree l and input degree k.

    Returns a tensor of shape (E, nJ, 2l+1, 2k+1) with nJ = k+l-|k-l|+1. Slice
    [:, i] is sum_m Y_J^m(edge) * CG(l, k -> J)[m] for J = |k-l| + i. Storage
    uses torch's default dtype. The result is cached per (l, k).
    '''
    # This needs to be improved
    if (l, k) not in self.wj:
      Js = range(abs(k-l), k+l+1)
      wj = torch.zeros(len(Js), self.graph.number_of_edges(), 2*l+1, 2*k+1)
      for i, J in enumerate(Js):
        sh = self.get_sph_harm(J)
        cg = clebsch_gordan_mats(l,k,J).double()
        wj[i] = torch.einsum("em,mlk->elk", sh, cg)
      self.wj[(l,k)] = wj.transpose(0,1).clone()
    return self.wj[(l,k)]

In [None]:
# TODO: Self interaction
# TODO: Ability to change head size

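# Index legend for the einsum strings below
# e: edges
# i: input channels (c_in)
# o: output channels (c_out)
# l: components of the output degree-l representation
# k: components of the input degree-k representation
# j: intermediate degrees J = |k-l|, ..., k+l of the kernel basis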

class WLayer(nn.Module):
  '''Learned kernel for one (input degree k, output degree l) pair.

  A radial MLP maps each edge length to one weight per (J, c_out, c_in). These
  weights are contracted with the per-edge basis returned by pc.get_wj(l, k).
  '''
  def __init__(self, k, l, c_in=1, c_out=1, radial_hidden=32):
    super(WLayer, self).__init__()
    self.k = k
    self.l = l
    self.c_in = c_in
    self.c_out = c_out
    # J runs over |k-l|, ..., k+l.
    self.num_j = k+l-abs(k-l)+1

    r_size = self.num_j * c_out * c_in

    # NOTE: BatchNorm1d's second positional argument is `eps`, so only the
    # feature count is passed (BatchNorm1d(32, 32) would set eps=32).
    self.radial = nn.Sequential(nn.Linear(1, radial_hidden),
                                nn.BatchNorm1d(radial_hidden),
                                nn.ReLU(),
                                nn.Linear(radial_hidden, radial_hidden),
                                nn.BatchNorm1d(radial_hidden),
                                nn.ReLU(),
                                nn.Linear(radial_hidden, r_size))

  def forward(self, pc):
    '''Return per-edge kernels of shape (E, c_out, c_in, 2l+1, 2k+1).

    Assumes pc.get_wj(l, k) is (E, num_j, 2l+1, 2k+1) and
    pc.graph.edata['dist'] is (E,).
    '''
    wj = pc.get_wj(self.l, self.k)
    R = self.radial(pc.graph.edata['dist'][:,None])
    R = R.reshape((-1, self.num_j, self.c_out, self.c_in))
    return torch.einsum('ejoi,ejlk->eoilk', R, wj)

In [None]:
from torch.autograd import Variable

class AttnBlock(nn.Module):
  '''Per-edge attention weights for the equivariant attention layer.

  For each edge src -> dst, the logit is the dot product of dst's query
  (per-degree linear map wq of dst's features) with the edge's key (WLayer
  kernels applied to src's features). Logits are exponentiated and
  normalised over each node's incoming edges.
  '''
  def __init__(self, d_in, d_out, c_in=1, c_out=1):
    super(AttnBlock, self).__init__()
    self.d_in = d_in
    self.d_out = d_out

    self.c_in = c_in
    self.c_out = c_out

    # nn.Parameter (not a bare Variable), so the query weights are registered
    # with the module and visible to optimisers and .to(device).
    self.wq = nn.Parameter(torch.randn(d_in+1, c_out, c_in))

    # Nested ModuleLists register every kernel as a submodule.
    # Indexed as wk_layers[l][k].
    self.wk_layers = nn.ModuleList([
        nn.ModuleList([WLayer(k, l, c_in, c_out) for k in range(d_in+1)])
        for l in range(d_out+1)])

  def forward(self, pc, feat):
    '''Return attention weights a of shape (E,).

    The weights of the edges entering each node sum to 1. Side effect: feat,
    'q', kernels and 'exp' are written into pc.graph.
    '''
    g = pc.graph
    for key, value in feat.items():
      g.ndata[key] = value
    g.ndata['q'] = self.calc_q(pc)

    for l in range(self.d_out+1):
      for k in range(self.d_in+1):
        g.edata[(k,l)] = self.wk_layers[l][k](pc)

    # apply_edges stores the returned {'exp': ...} in g.edata. This avoids
    # relying on writes to edges.data inside a message UDF.
    g.apply_edges(self.attn_msg)
    g.update_all(self._copy_exp, self.attn_rdc)
    dst = g.edges()[1]
    return g.edata['exp'] / g.ndata['sum'][dst]

  def attn_msg(self, edges):
    '''Unnormalised attention exp(q_dst . k_edge) for a batch of edges.'''
    k = self.calc_k(edges)
    q = edges.dst['q']
    exp = torch.exp(torch.einsum('eol,eol->e', q, k))
    return {'exp': exp}

  def _copy_exp(self, edges):
    # Message function: send each edge's stored 'exp' to its destination.
    return {'exp': edges.data['exp']}

  def attn_rdc(self, nodes):
    # Sum over j' (all incoming edges) to get the softmax denominator.
    exp = nodes.mailbox['exp']
    return {'sum': torch.sum(exp, dim=1)}

  def calc_q(self, pc, feat=None):
    '''Queries of shape (N, c_out, sum_k (2k+1)) for k = 0..min(d_in, d_out).

    The degree-k block is wq[k] (c_out x c_in) applied to the degree-k node
    features, and blocks are concatenated along the last axis. `feat`
    defaults to pc.graph.ndata, where forward() stores the features.
    Assumes feat[k] is (N, c_in, 2k+1).
    '''
    if feat is None:
      feat = pc.graph.ndata
    ql = [torch.einsum('oi,nik->nok', self.wq[k], feat[k])
          for k in range(min(self.d_in, self.d_out)+1)]
    return torch.cat(ql, dim=2)

  def calc_k(self, edges):
    '''Keys of shape (E, c_out, sum_l (2l+1)) for l = 0..min(d_in, d_out).

    The degree-l block sums, over input degrees k, the kernel
    edges.data[(k, l)] (E, c_out, c_in, 2l+1, 2k+1) applied to the *source*
    node's degree-k features. Blocks are concatenated along the last axis.
    '''
    kl = []
    for l in range(min(self.d_in,self.d_out)+1):
      wks = [torch.einsum('eoilk,eik->eol', edges.data[(k,l)], edges.src[k])
             for k in range(self.d_in+1)]
      kl.append(torch.stack(wks, dim=3).sum(dim=3))
    return torch.cat(kl, dim=2)

class TransLayer(nn.Module):
  '''Equivariant attention layer: attention-weighted sum of value messages.

  Values are WLayer kernels (input degree k -> output degree l) applied to the
  source node's features. They are scaled by the AttnBlock weights and summed
  over incoming edges.
  '''
  def __init__(self, d_in, d_out, c_in=1, c_out=1):
    super(TransLayer, self).__init__()
    self.d_in = d_in
    self.d_out = d_out
    self.c_in = c_in
    self.c_out = c_out

    # Nested ModuleLists register every value kernel as a submodule.
    # Indexed as wv_layers[k][l].
    self.wv_layers = nn.ModuleList([
        nn.ModuleList([WLayer(k, l, c_in=c_in, c_out=c_out) for l in range(d_out+1)])
        for k in range(d_in+1)])

    self.block = AttnBlock(d_in, d_out, c_in=c_in, c_out=c_out)

  def forward(self, pc, feat):
    '''Return {l: output features of shape (N, c_out, 2l+1)} for l = 0..d_out.

    Side effect: writes attention 'a', kernels and node features into pc.graph.
    The outputs overwrite ndata[l].
    '''
    pc.graph.edata['a'] = self.block(pc, feat)
    for k in range(self.d_in+1):
      pc.graph.ndata[k] = feat[k]
      for l in range(self.d_out+1):
        pc.graph.edata[(k,l)] = self.wv_layers[k][l](pc)
    pc.graph.update_all(self.msg_func, self.rdc_func)
    return {l: pc.graph.ndata[l] for l in range(self.d_out+1)}

  def msg_func(self, edges):
    # One value message per output degree l.
    return {l: self.calc_vl(edges, l) for l in range(self.d_out+1)}

  def rdc_func(self, nodes):
    # Sum every mailbox field over the incoming edges (dim=1).
    return {key: torch.sum(value, dim=1) for key, value in nodes.mailbox.items()}

  def calc_vl(self, edges, l):
    '''Degree-l values of shape (E, c_out, 2l+1).

    Computed as the attention-weighted sum over k of the kernel (k, l)
    applied to the source node's degree-k features.
    '''
    a = edges.data['a']
    vlks = [torch.einsum('e,eoilk,eik->eol', a, edges.data[(k,l)], edges.src[k])
            for k in range(self.d_in+1)]
    return torch.sum(torch.stack(vlks, dim=3), dim=3)

# Split each feature tensor into `heads` chunks along the channel axis (dim=1)
def split_feat(heads, feat):
  '''Split every tensor in `feat` into `heads` chunks along dim=1 (channels).

  Returns a list of `heads` independent dicts with the same keys as `feat`.
  Entry i holds the i-th chunk of every tensor.
  NOTE: torch.chunk may return fewer than `heads` chunks when the channel
  count is not divisible by `heads`; indexing then raises IndexError.
  '''
  # One fresh dict per head. `[dict()] * heads` would alias a single dict.
  feats = [dict() for _ in range(heads)]
  for key, value in feat.items():
    chunks = torch.chunk(value, heads, dim=1)
    for i in range(heads):
      feats[i][key] = chunks[i]
  return feats

In [None]:
# Build a small random point cloud and run the layers end to end.
torch.manual_seed(0)
num_pts = 10
pos = torch.rand(num_pts, 3)
pc = PointCloud(pos)

in_channels = 2
out_channels = 3
# Node features keyed by degree d, each of shape (num_pts, in_channels, 2d+1).
feat = {0: torch.rand(num_pts, in_channels, 1),
        1: torch.rand(num_pts, in_channels, 3)}

# AttnBlock: the attention weights entering each node must sum to 1.
block = AttnBlock(0, 1, in_channels, out_channels)
attn = block(pc, feat)
attn_sums = torch.zeros(num_pts).index_add_(0, pc.graph.edges()[1], attn.detach())
assert torch.allclose(attn_sums, torch.ones(num_pts))

# Equivariant attention layer: degree <= 1 in, degree <= 1 out.
layer = TransLayer(1, 1, in_channels, out_channels)
out = layer(pc, feat)
assert tuple(out[1].shape) == (num_pts, out_channels, 3)
out[1].size()