In [1]:
import numpy as np
import tensorflow as tf

In [2]:
import scipy.special

In [3]:
maxdeg = 3
dim = 2

In [4]:
smd = maxdeg + 1
dof = int(scipy.special.binom(smd + dim - 1, dim))
digits = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"

In [5]:
def padder(x):
    if len(x) < dim:
        return '0' * (dim - len(x)) + x
    else:
        return x

def int2str(x):
    if x < 0:
        return "-" + int2str(-x)
    return ("" if x < smd else int2str(x // smd)) + digits[x % smd]

def padint2str(x):
    return padder(int2str(x))


In [6]:
def index_mapping():
    index = 0
    index_map = {}
    allpos = list(map(padint2str, range(smd ** dim)))
    for d in range(smd):
        for s in allpos:
            y = list(map(int, s))[::-1]
            if sum(y) == d:
                index_map[tuple(y)] = index
                index += 1

    return index_map


In [7]:
indmap = index_mapping()

In [8]:
ns = np.zeros(9)
for ii in range(9):
    ns[ii] = 1.0/np.sqrt(np.sqrt(2 * np.pi) * scipy.special.factorial(ii))

In [9]:
h2omat = np.zeros((9,9))
for ii in range(9):
    h2omat[ii, ii] = 1.0 #ns[ii]

h2omat[0, 2] = -h2omat[2, 2]
h2omat[1, 3] = -3 * h2omat[3, 3]
h2omat[2, 4] = -6 * h2omat[4, 4]
h2omat[0, 4] = 3 * h2omat[4, 4]
h2omat[3, 5] = -10 * h2omat[5, 5]
h2omat[1, 5] = 15 * h2omat[5, 5]
h2omat[4, 6] = -15 * h2omat[6, 6]
h2omat[2, 6] = 45 * h2omat[6, 6]
h2omat[0, 6] = -15 * h2omat[6, 6]
h2omat[5, 7] = -21 * h2omat[7, 7]
h2omat[3, 7] = 105 * h2omat[7, 7]
h2omat[1, 7] = -105 * h2omat[7, 7]
h2omat[6, 8] = -28 * h2omat[8, 8]
h2omat[4, 8] = 210 * h2omat[8, 8]
h2omat[2, 8] = -420 * h2omat[8, 8]
h2omat[0, 8] = 105 * h2omat[8, 8]
h2omat = h2omat[0:smd, 0:smd]

In [10]:
transmat = np.full((dof, dof), 1.)
for row_index in indmap:
    for col_index in indmap:
        for d in range(dim):
            transmat[indmap[row_index], indmap[col_index]] *= h2omat[row_index[d], col_index[d]]

In [11]:
indlist = []
for d in range(dim):
    for index in indmap:
        indlist.append([d, index[d]])

In [12]:
derivindlists = [[]]*dim
# selfindlists = [[]]*dim
for s in range(dim):
    derivindlists[s] = []
    # selfindlists[s] = []
    for d in range(dim):
        for index in indmap:
            if d == s:
                if index[d] >= 1:
                    derivindlists[s].append([0, d, index[d]-1])
                else:
                    derivindlists[s].append([1, 0, 0])
            else:
                derivindlists[s].append([0, d, index[d]])


In [13]:
print(len(indlist))
print(len(derivindlists[1]))

20
20


In [43]:
prefaclist = [[]]*dim
for d in range(dim):
    prefaclist[d] = np.ones(dof)
    for index in indmap:
        nf = ns[index[d]]/ns[index[d]-1]
        prefaclist[d][indmap[index]] *= nf*index[d]
    
    prefaclist[d] = tf.expand_dims(tf.constant(prefaclist[d], dtype=tf.float32),1)

In [15]:
prefaclist[0]
prefaclist[1]

array([0.        , 0.        , 1.        , 0.        , 1.        ,
       1.41421356, 0.        , 1.        , 1.41421356, 1.73205081])

In [16]:
tfh2omat = tf.constant(np.flip(h2omat,axis=0).T,dtype=tf.float32)

In [17]:
x = tf.placeholder(dtype=tf.float32, shape=[None, dim])

In [18]:
def tfherm(n):
    return tf.math.polyval(tf.unstack(tfh2omat[n]), x)

In [75]:
def tfbasis(npts):
    Hcached = tf.map_fn(fn=tfherm,
                        elems=np.arange(smd,dtype=np.int32),
                        back_prop=False,
                        dtype=tf.float32)

    Hcached = tf.transpose(Hcached,[2,0,1])  
    # Hcached has dimensions (dim, smd, npts)
    hermout = tf.reduce_prod(tf.reshape(tf.gather_nd(Hcached, indlist),[dim,dof,npts]),axis=0)
    hermout = tf.transpose(hermout,[1,0])
    return hermout

In [76]:
def tfgradient(npts):
    Hcached = tf.map_fn(fn=tfherm,
                        elems=np.arange(smd,dtype=np.int32),
                        back_prop=False,
                        dtype=tf.float32)

    Hcached = tf.transpose(Hcached,[2,0,1])  
    derivhermout = []
    bigten = tf.stack([Hcached, tf.zeros([dim,smd,npts])])
    for derivdim in range(dim):
        part = tf.gather_nd(bigten, derivindlists[derivdim])
        # part2 = tf.gather_nd(bigten, selfindlists[derivdim])
        derivhermout.append(
            tf.reduce_prod(tf.reshape(part,[dim,dof,npts]),axis=0)
        )
        derivhermout[derivdim] *= prefaclist[derivdim]
    
    derivhermout = tf.transpose(tf.stack(derivhermout),[0,2,1])
    return derivhermout
        

In [77]:
with tf.Session() as sess:
    test = sess.run(tfbasis(3), feed_dict = {x : np.array([[3.,-4.],[-1.13,2.08],[-7.13,9.08]])})


In [66]:
Hcached = np.transpose(test3, [2, 0, 1])
hermout = np.full((dim, 3, dof), 1.)

basisout = np.full((3, dof), 1.)
for index in indmap:
    for d in range(dim):
        basisout[:, indmap[index]] *= Hcached[:, d, index[d]]

for derivdim in range(dim):
    for index in indmap:
        for d in range(dim):
            if d==derivdim:
                nf = ns[index[d]]/ns[index[d]-1]
                hermout[derivdim, :, indmap[index]] *= Hcached[:, d, index[d]-1]*index[d]*nf
            if d != derivdim:
                hermout[derivdim, :, indmap[index]] *= Hcached[:, d, index[d]]


In [79]:
hermout

array([[[  0.        ,   1.        ,  -0.        ,   4.2426405 ,
          -4.        ,   0.        ,  13.85640621, -16.97056198,
          15.        ,  -0.        ],
        [  0.        ,   1.        ,   0.        ,  -1.59806132,
           2.07999992,   0.        ,   0.47960475,  -3.32396743,
           3.3263998 ,   0.        ],
        [ -0.        ,   1.        ,  -0.        , -10.08334255,
           9.07999992,  -0.        ,  86.32004547, -91.5567496 ,
          81.44639587,  -0.        ]],

       [[ -0.        ,  -0.        ,   1.        ,  -0.        ,
           3.        ,  -5.65685415,  -0.        ,   8.        ,
         -16.97056246,  25.98076248],
        [  0.        ,  -0.        ,   1.        ,   0.        ,
          -1.13      ,   2.94156408,   0.        ,   0.27689993,
          -3.3239674 ,   5.76149321],
        [  0.        ,  -0.        ,   1.        ,   0.        ,
          -7.13000011,  12.84105873,  -0.        ,  49.83690262,
         -91.55675022, 141.0

In [80]:
test.shape

(3, 10)

In [81]:
basisout.shape

(3, 10)

In [82]:
np.sum(np.abs(test - basisout))

2.104320594753517e-05