In [1]:
from jax.scipy.special import betaln, gammainc, gammaln
import numpy as np
from sympy import *

## Table lookup for "binom factor"

The function called "binom factor" is one of those with a variable-sized loop, 
which is tricky to vectorize.  Its definition is
$$
B(i,j,a,b,s) = \sum_{n=s-i}^j \binom{i}{s-n} \binom{j}{n} a^{i-(s-n)} b^{j-n}
$$

Its arguments are three integers $(i,j,s)$ and two floats.  It can't be a lookup table because of the floats, and it can't be a simple table of functions because JAX would not be impressed.

But... For a given $i,j,s$ we know it will be a polynomial in $(a,b)$:
$$
B(i,j,a,b,s) = \sum_{p=0}^{\mathrm{max}-1} \sum_{q=0}^{\mathrm{max}-1} w^{i,j,s}_{p,q} a^p b^q
$$

This notebook computes those weights.  If max=8, there are $8^5$ weights, which at 4 bytes each is a 128KB array, so it might need to be stored sparsely.

In [2]:
# Upper bound (exclusive) on i, j, s and on the exponents of a and b.
LMAX = 8
# The dense weight table has LMAX**5 entries; at 4 bytes per entry
# (presumably float32 -- confirm against the consumer) report its size in KB.
n_weights = LMAX ** 5
kilobytes = n_weights * 4 / 1024
print(f"{kilobytes=}")

kilobytes=12.20703125


Let's now compute the coefficients W.  First, we just define the functions, then pass them through sympy.

In [3]:
def binom(x, y):
    """Return the binomial coefficient "x choose y" as an exact Python int.

    The previous implementation evaluated 1 / ((x + 1) * B(x - y + 1, y + 1))
    in floating point and rounded the result, which is only exact while the
    coefficient fits in the float mantissa.  Integer arithmetic is exact for
    every size and needs no array library for a scalar computation.

    Args:
        x: Number of items to choose from (integer).
        y: Number of items chosen (integer).

    Returns:
        The exact coefficient as an int, or 0 when ``y`` lies outside
        ``0 <= y <= x`` (callers rely on out-of-range terms vanishing).
    """
    # Function-scope import keeps this cell independent of the import cell.
    import math

    # math.comb rejects negative arguments, so handle the vanishing cases here.
    if y < 0 or y > x:
        return 0
    return math.comb(x, y)


def binom_factor(i: int, j: int, a: float, b: float, s: int):
    """Evaluate B(i, j, a, b, s) by direct summation.

    Adds up binom(i, s - n) * binom(j, n) * a**(i - (s - n)) * b**(j - n)
    for n running from max(s - i, 0) to j inclusive.

    Args:
        i, j, s: Integer indices.
        a, b: Values raised to integer powers; the notebook passes both
            floats and sympy symbols.

    Returns:
        The accumulated sum; the integer 0 when the summation range is empty.
    """
    total = 0
    first = max(s - i, 0)
    for n in range(first, j + 1):
        assert s - i <= n <= j
        k = s - n  # lower index of the first binomial
        coeff = binom(i, k) * binom(j, n)
        total += coeff * a ** (i - k) * b ** (j - n)
    return total

def binom_factor_sym(i: int, j: int, a: Symbol, b: Symbol, s: int):
    """Return binom_factor(i, j, a, b, s) as a sympy Poly in the generators (a, b)."""
    expression = binom_factor(i, j, a, b, s)
    return Poly(expression, a, b)


a, b = symbols("a b", real=True)

# Print the polynomial for a coarse sample of (i, j, s) index triples.
sample_indices = [
    (i, j, s)
    for i in range(0, 7, 3)
    for j in range(0, 7, 2)
    for s in range(0, 7, 2)
]
for i, j, s in sample_indices:
    bf = binom_factor_sym(i, j, a, b, s)
    print((i, j, s), bf)

# bf is still the polynomial from the last triple, (6, 6, 6);
# show it together with its a^3 b^3 coefficient.
display(bf, bf.coeff_monomial(a**3*b**3))



(0, 0, 0) Poly(1, a, b, domain='ZZ')
(0, 0, 2) Poly(0, a, b, domain='ZZ')
(0, 0, 4) Poly(0, a, b, domain='ZZ')
(0, 0, 6) Poly(0, a, b, domain='ZZ')
(0, 2, 0) Poly(b**2, a, b, domain='ZZ')
(0, 2, 2) Poly(1, a, b, domain='ZZ')
(0, 2, 4) Poly(0, a, b, domain='ZZ')
(0, 2, 6) Poly(0, a, b, domain='ZZ')
(0, 4, 0) Poly(b**4, a, b, domain='ZZ')
(0, 4, 2) Poly(6*b**2, a, b, domain='ZZ')
(0, 4, 4) Poly(1, a, b, domain='ZZ')
(0, 4, 6) Poly(0, a, b, domain='ZZ')
(0, 6, 0) Poly(b**6, a, b, domain='ZZ')
(0, 6, 2) Poly(15*b**4, a, b, domain='ZZ')
(0, 6, 4) Poly(15*b**2, a, b, domain='ZZ')
(0, 6, 6) Poly(1, a, b, domain='ZZ')
(3, 0, 0) Poly(a**3, a, b, domain='ZZ')
(3, 0, 2) Poly(3*a, a, b, domain='ZZ')
(3, 0, 4) Poly(0, a, b, domain='ZZ')
(3, 0, 6) Poly(0, a, b, domain='ZZ')
(3, 2, 0) Poly(a**3*b**2, a, b, domain='ZZ')
(3, 2, 2) Poly(a**3 + 6*a**2*b + 3*a*b**2, a, b, domain='ZZ')
(3, 2, 4) Poly(3*a + 2*b, a, b, domain='ZZ')
(3, 2, 6) Poly(0, a, b, domain='ZZ')
(3, 4, 0) Poly(a**3*b**4, a, b, domain='

Poly(a**6 + 36*a**5*b + 225*a**4*b**2 + 400*a**3*b**3 + 225*a**2*b**4 + 36*a*b**5 + b**6, a, b, domain='ZZ')

400

In [4]:
# Column vectors of the powers a^0 .. a^(LMAX-1) and b^0 .. b^(LMAX-1).
monomials_a = Matrix([a ** p for p in range(LMAX)])
monomials_b = Matrix([b ** q for q in range(LMAX)])

# Outer product: entry (p, q) is a^p * b^q.
monomial_grid = monomials_a * monomials_b.T
display(monomial_grid)

# Flatten row-major into a single row of LMAX**2 monomials.
all_monomials = monomial_grid.reshape(1, LMAX ** 2)
display(all_monomials)


Matrix([
[   1,      b,      b**2,      b**3,      b**4],
[   a,    a*b,    a*b**2,    a*b**3,    a*b**4],
[a**2, a**2*b, a**2*b**2, a**2*b**3, a**2*b**4],
[a**3, a**3*b, a**3*b**2, a**3*b**3, a**3*b**4],
[a**4, a**4*b, a**4*b**2, a**4*b**3, a**4*b**4]])

Matrix([[1, b, b**2, b**3, b**4, a, a*b, a*b**2, a*b**3, a*b**4, a**2, a**2*b, a**2*b**2, a**2*b**3, a**2*b**4, a**3, a**3*b, a**3*b**2, a**3*b**3, a**3*b**4, a**4, a**4*b, a**4*b**2, a**4*b**3, a**4*b**4]])

In [5]:
def get_coeffs(p):
    """Return the coefficients of polynomial ``p`` on each entry of ``all_monomials``.

    The result is a tuple with one coefficient per monomial, in the iteration
    order of the module-level ``all_monomials`` matrix.
    """
    coefficients = []
    for monomial in all_monomials:
        coefficients.append(p.coeff_monomial(monomial))
    return tuple(coefficients)

# Example: coefficients of B(6, 6, a, b, 6), laid out as a (power of a, power of b) grid.
example_coeffs = get_coeffs(binom_factor_sym(6, 6, a, b, 6))
np.array(example_coeffs).reshape(LMAX, LMAX)

array([[0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0],
       [0, 0, 0, 0, 225],
       [0, 0, 0, 400, 0],
       [0, 0, 225, 0, 0]], dtype=object)

In [6]:
# weights[i, j, s, :] holds the coefficients of B(i, j, a, b, s) on the flattened monomials.
weights = np.zeros((LMAX, LMAX, LMAX, LMAX * LMAX))
for i, j, s in np.ndindex(LMAX, LMAX, LMAX):
    bf = binom_factor_sym(i, j, a, b, s)
    val = get_coeffs(bf)
    # Print only a small random sample of entries to keep the output short.
    show_sample = np.random.rand() ** LMAX > 0.7
    if show_sample:
        print((i, j, s), val, bf)
    weights[i, j, s, :] = val

(0, 1, 2) (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) Poly(0, a, b, domain='ZZ')
(0, 1, 4) (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) Poly(0, a, b, domain='ZZ')
(0, 3, 3) (1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) Poly(1, a, b, domain='ZZ')
(1, 0, 4) (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) Poly(0, a, b, domain='ZZ')
(2, 1, 3) (1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) Poly(1, a, b, domain='ZZ')
(3, 0, 4) (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) Poly(0, a, b, domain='ZZ')
(3, 1, 1) (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0) Poly(a**3 + 3*a**2*b, a, b, domain='ZZ')
(3, 4, 0) (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0) Poly(a**3*b**4, a, b, domain='ZZ')
(4, 2, 0) (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 

In [7]:
import jax.numpy as jnp
def get_monomials(a, b):
    """Return the LMAX*LMAX products a**p * b**q as a flat array.

    Entry p * LMAX + q holds a**p * b**q, for p and q in 0..LMAX-1.
    """
    exponents = jnp.arange(LMAX)
    a_column = (a ** exponents)[:, None]
    b_row = (b ** exponents)[None, :]
    # (LMAX, 1) @ (1, LMAX) is the outer product of the two power vectors.
    return (a_column @ b_row).reshape(-1)

def f(x):
    """Wrap ``x`` in a float32 numpy array."""
    return np.array(x, dtype=np.float32)


fa, fb = f(1.1), f(2.33)
got = get_monomials(fa, fb)
print(got)

# Reference values: evaluate the symbolic monomial row numerically with numpy.
monomials_numeric = lambdify((a, b), all_monomials, "numpy")
expect = monomials_numeric(fa, fb).reshape(LMAX * LMAX)
print(expect)

np.testing.assert_allclose(got, expect)


[ 1.         2.33       5.4289    12.649336  29.472952   1.1
  2.563      5.97179   13.914269  32.420246   1.21       2.8193
  6.568969  15.3056965 35.662273   1.3310001  3.1012301  7.225866
 16.836267  39.2285     1.4641001  3.411353   7.948453  18.519894
 43.151352 ]
[ 1.          2.32999992  5.42889977 12.64933576 29.47295135  1.10000002
  2.56299996  5.97178984 13.91426963 32.42024719  1.21000004  2.81929994
  6.56896877 15.30569675 35.66227226  1.33100009  3.1012301   7.22586606
 16.83626699 39.2285008   1.46410013  3.41135318  7.94845284 18.51989409
 43.15135181]


In [8]:
# Sparse view of the weight table: index arrays of the nonzero entries.
inds = np.nonzero(weights)
print(f'nnz={len(inds[0])}')

import inspect

# Write a standalone Python module containing LMAX, the source of
# get_monomials, and the nonzero weights embedded as array literals.
# threshold=np.inf makes repr() print every element instead of eliding with
# "...", and the float formatter prints values with up to 10 significant digits.
# The generated module defines `array = np.array` so the embedded repr text
# (which uses the name `array`) evaluates when the module is imported.
with np.printoptions(threshold=np.inf, formatter={'float':lambda x:f'{x:.10g}'}):
  with open("../pyscf_ipu/experimental/binom_factor_table.py", "w") as file:
    print(
        f"""# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
# AUTOGENERATED from notebooks/binom_factor_table.ipynb

# fmt: off
# flake8: noqa
# isort: skip_file

import jax.numpy as jnp
import numpy as np
array = np.array

LMAX = {LMAX}
{inspect.getsource(get_monomials)}

def build_binom_factor_table(sparse=False):
    inds,values = {repr((inds, weights[inds]))}
    if sparse:
      return inds,values
    else:
      W = np.zeros((LMAX,LMAX,LMAX,LMAX*LMAX))
      W[inds] = values
      return W
""",
        file=file,
        end=''
    )

nnz=190


## Test

Import the file we just generated, and check it works

In [18]:
from pyscf_ipu.experimental import binom_factor_table

aval, bval = f(1.1), f(2.2)

monomials = binom_factor_table.get_monomials(aval, bval)
print(f"{monomials=}")

# Contract the dense weight table with the monomial values: table_ab[i, j, s]
# is then the tabulated B(i, j, aval, bval, s), with a trailing axis of length 1.
dense_table = binom_factor_table.build_binom_factor_table()
W = jnp.array(dense_table, dtype=jnp.float32)
monomial_column = monomials.reshape(LMAX * LMAX, 1)
table_ab = W @ monomial_column

# Compare every table entry against the direct summation.
for i, j, s in np.ndindex(LMAX, LMAX, LMAX):
    bf = binom_factor(i, j, aval, bval, s)
    tabulated = table_ab[i, j, s]
    # Print only a small random sample of the comparisons.
    if np.random.rand() ** LMAX > 0.6:
        print((i, j, s), bf, tabulated)
    np.testing.assert_allclose(bf, tabulated, rtol=1e-6)

monomials=DeviceArray([ 1.       ,  2.2      ,  4.84     , 10.648001 , 23.425602 ,
              1.1      ,  2.42     ,  5.3240004, 11.712801 , 25.768162 ,
              1.21     ,  2.6620002,  5.8564005, 12.884081 , 28.344978 ,
              1.3310001,  2.9282002,  6.4420404, 14.17249  , 31.179478 ,
              1.4641001,  3.2210202,  7.0862446, 15.589739 , 34.29743  ],            dtype=float32)
(0, 3, 3) 1.0 [1.]
(0, 4, 2) 29.040000915527344 [29.04]
(1, 0, 2) 0 [0.]
(1, 0, 3) 0 [0.]
(2, 0, 1) 2.200000047683716 [2.2]
(2, 1, 4) 0 [0.]
(2, 3, 0) 12.88408124395375 [12.884081]
(2, 3, 2) 50.578002816677134 [50.578003]
(2, 4, 0) 28.344979351059116 [28.344978]
(3, 1, 2) 10.890000429153446 [10.89]
(3, 1, 3) 5.5000001192092896 [5.5]
(4, 2, 2) 60.02810437345515 [60.028107]
(4, 2, 4) 31.460001220703134 [31.460001]
(4, 3, 2) 164.27203597465098 [164.27203]
(4, 3, 4) 127.77600698661814 [127.776]
