In [1]:
from jax.scipy.special import betaln, gammainc, gammaln
import numpy as np
from sympy import *

## Table lookup for "binom factor"

The function called "binom factor" is one of those with a variable-sized loop, 
which is tricky to vectorize.  Its definition is
$$
B(i,j,a,b,s) = \sum_{n=\max(s-i,\,0)}^{j} \binom{i}{s-n} \binom{j}{n} a^{i-(s-n)} b^{j-n}
$$

Its arguments are three integers $(i,j,s)$ and two floats $(a,b)$.  It can't be a plain lookup table because of the float arguments, and it can't be a simple table of functions because JAX would not be impressed.

But... For a given $i,j,s$ we know it will be a polynomial in $(a,b)$:
$$
B(i,j,a,b,s) = \sum_{p=0}^{\text{max}-1} \sum_{q=0}^{\text{max}-1} w^{i,j,s}_{p,q} a^p b^q
$$

This notebook computes those weights.  With max=8 and 4-byte floats, the dense weight array has $8^5$ entries and takes 128KB, so it might need to be stored sparsely.

In [2]:
LMAX = 8
# One weight per (i, j, s, p, q) combination, each index ranging over LMAX
# values; the size estimate assumes 4 bytes per weight.
num_weights = LMAX**5
kilobytes = num_weights * 4 / 1024
print(f"{kilobytes=}")

kilobytes=128.0


Let's now compute the coefficients W.  First, we just define the functions, then pass them through sympy.

In [3]:
def binom(x, y):
    """Return the exact binomial coefficient C(x, y) as a Python int.

    The coefficient is computed with integer arithmetic, so the result is
    exact for any size of ``x`` (a floating-point beta-function evaluation
    followed by rounding loses digits once the coefficient exceeds the
    float mantissa).

    Args:
        x: non-negative integer, the "n" of "n choose k".
        y: integer, the "k" of "n choose k"; may lie outside ``0..x``.

    Returns:
        ``C(x, y)`` when ``0 <= y <= x``, otherwise ``0``.

    Raises:
        ValueError: if ``x`` is negative while ``y`` is non-negative
            (raised by ``math.comb``).
    """
    # Local import so this block does not depend on the notebook's import cell.
    import math

    # binom_factor's summation range can request a negative lower index;
    # such terms must contribute zero rather than raise.
    if y < 0:
        return 0
    # math.comb already returns 0 when y > x.
    return math.comb(x, y)


def binom_factor(i: int, j: int, a: float, b: float, s: int):
    """Evaluate the "binom factor" B(i, j, a, b, s).

    Sums ``binom(i, s - t) * binom(j, t) * a**(i - (s - t)) * b**(j - t)``
    over ``t`` from ``max(s - i, 0)`` to ``j`` inclusive.

    ``a`` and ``b`` only take part in ``**``, ``*`` and ``+``, so they may be
    floats or symbolic values (``binom_factor_sym`` passes sympy symbols).
    Returns the integer ``0`` when the summation range is empty.
    """
    total = 0
    for t in range(max(s - i, 0), j + 1):
        from_i = s - t  # index drawn from the "i" binomial
        coeff = binom(i, from_i) * binom(j, t)
        total = total + coeff * a ** (i - from_i) * b ** (j - t)
    return total


def binom_factor_sym(i: int, j: int, a: Symbol, b: Symbol, s: int):
    """Return ``binom_factor(i, j, a, b, s)`` as a sympy ``Poly`` in ``(a, b)``."""
    expression = binom_factor(i, j, a, b, s)
    return Poly(expression, a, b)


a, b = symbols("a b", real=True)

# Print the symbolic polynomial on a coarse grid of (i, j, s) triples.
sample_triples = [
    (i, j, s)
    for i in range(0, 7, 3)
    for j in range(0, 7, 2)
    for s in range(0, 7, 2)
]
for i, j, s in sample_triples:
    bf = binom_factor_sym(i, j, a, b, s)
    print((i, j, s), bf)

# Show the last polynomial, (i, j, s) = (6, 6, 6), and one of its coefficients.
display(bf, bf.coeff_monomial(a**3 * b**3))



(0, 0, 0) Poly(1, a, b, domain='ZZ')
(0, 0, 2) Poly(0, a, b, domain='ZZ')
(0, 0, 4) Poly(0, a, b, domain='ZZ')
(0, 0, 6) Poly(0, a, b, domain='ZZ')
(0, 2, 0) Poly(b**2, a, b, domain='ZZ')
(0, 2, 2) Poly(1, a, b, domain='ZZ')
(0, 2, 4) Poly(0, a, b, domain='ZZ')
(0, 2, 6) Poly(0, a, b, domain='ZZ')
(0, 4, 0) Poly(b**4, a, b, domain='ZZ')
(0, 4, 2) Poly(6*b**2, a, b, domain='ZZ')
(0, 4, 4) Poly(1, a, b, domain='ZZ')
(0, 4, 6) Poly(0, a, b, domain='ZZ')
(0, 6, 0) Poly(b**6, a, b, domain='ZZ')
(0, 6, 2) Poly(15*b**4, a, b, domain='ZZ')
(0, 6, 4) Poly(15*b**2, a, b, domain='ZZ')
(0, 6, 6) Poly(1, a, b, domain='ZZ')
(3, 0, 0) Poly(a**3, a, b, domain='ZZ')
(3, 0, 2) Poly(3*a, a, b, domain='ZZ')
(3, 0, 4) Poly(0, a, b, domain='ZZ')
(3, 0, 6) Poly(0, a, b, domain='ZZ')
(3, 2, 0) Poly(a**3*b**2, a, b, domain='ZZ')
(3, 2, 2) Poly(a**3 + 6*a**2*b + 3*a*b**2, a, b, domain='ZZ')
(3, 2, 4) Poly(3*a + 2*b, a, b, domain='ZZ')
(3, 2, 6) Poly(0, a, b, domain='ZZ')
(3, 4, 0) Poly(a**3*b**4, a, b, domain='

Poly(a**6 + 36*a**5*b + 225*a**4*b**2 + 400*a**3*b**3 + 225*a**2*b**4 + 36*a*b**5 + b**6, a, b, domain='ZZ')

400

In [4]:
powers = range(LMAX)
monomials_a = Matrix([a ** k for k in powers])
monomials_b = Matrix([b ** k for k in powers])

# Outer product: entry (p, q) of the grid is a**p * b**q.
monomial_grid = monomials_a * monomials_b.T
display(monomial_grid)

# Flatten row by row into a single 1 x LMAX**2 row; this fixes the
# monomial ordering used for the coefficient vectors below.
all_monomials = monomial_grid.reshape(1, LMAX**2)
display(all_monomials)


Matrix([
[   1,      b,      b**2,      b**3,      b**4,      b**5,      b**6,      b**7],
[   a,    a*b,    a*b**2,    a*b**3,    a*b**4,    a*b**5,    a*b**6,    a*b**7],
[a**2, a**2*b, a**2*b**2, a**2*b**3, a**2*b**4, a**2*b**5, a**2*b**6, a**2*b**7],
[a**3, a**3*b, a**3*b**2, a**3*b**3, a**3*b**4, a**3*b**5, a**3*b**6, a**3*b**7],
[a**4, a**4*b, a**4*b**2, a**4*b**3, a**4*b**4, a**4*b**5, a**4*b**6, a**4*b**7],
[a**5, a**5*b, a**5*b**2, a**5*b**3, a**5*b**4, a**5*b**5, a**5*b**6, a**5*b**7],
[a**6, a**6*b, a**6*b**2, a**6*b**3, a**6*b**4, a**6*b**5, a**6*b**6, a**6*b**7],
[a**7, a**7*b, a**7*b**2, a**7*b**3, a**7*b**4, a**7*b**5, a**7*b**6, a**7*b**7]])

Matrix([[1, b, b**2, b**3, b**4, b**5, b**6, b**7, a, a*b, a*b**2, a*b**3, a*b**4, a*b**5, a*b**6, a*b**7, a**2, a**2*b, a**2*b**2, a**2*b**3, a**2*b**4, a**2*b**5, a**2*b**6, a**2*b**7, a**3, a**3*b, a**3*b**2, a**3*b**3, a**3*b**4, a**3*b**5, a**3*b**6, a**3*b**7, a**4, a**4*b, a**4*b**2, a**4*b**3, a**4*b**4, a**4*b**5, a**4*b**6, a**4*b**7, a**5, a**5*b, a**5*b**2, a**5*b**3, a**5*b**4, a**5*b**5, a**5*b**6, a**5*b**7, a**6, a**6*b, a**6*b**2, a**6*b**3, a**6*b**4, a**6*b**5, a**6*b**6, a**6*b**7, a**7, a**7*b, a**7*b**2, a**7*b**3, a**7*b**4, a**7*b**5, a**7*b**6, a**7*b**7]])

## get_coeffs

A function to get the coefficients from a polynomial, in the order defined by `all_monomials`.

In [5]:
def get_coeffs(p):
    """Return the coefficients of polynomial ``p`` as a tuple.

    There is one entry per monomial in ``all_monomials``, in that order,
    each obtained with ``p.coeff_monomial``.
    """
    coefficients = [p.coeff_monomial(monomial) for monomial in all_monomials]
    return tuple(coefficients)

# Example: coefficients of B(6, 6, a, b, 6), laid out as an LMAX x LMAX grid.
coeffs_666 = get_coeffs(binom_factor_sym(6, 6, a, b, 6))
np.array(coeffs_666).reshape(LMAX, LMAX)

array([[0, 0, 0, 0, 0, 0, 1, 0],
       [0, 0, 0, 0, 0, 36, 0, 0],
       [0, 0, 0, 0, 225, 0, 0, 0],
       [0, 0, 0, 400, 0, 0, 0, 0],
       [0, 0, 225, 0, 0, 0, 0, 0],
       [0, 36, 0, 0, 0, 0, 0, 0],
       [1, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0]], dtype=object)

## Build the weight matrix

W[i,j,s] = the polynomial coefficients for binom_factor(i,j,a,b,s)

In [6]:
# weights[i, j, s] holds the LMAX*LMAX polynomial coefficients of
# binom_factor(i, j, a, b, s), ordered as in all_monomials.
weights = np.zeros((LMAX, LMAX, LMAX, LMAX * LMAX))
for i, j, s in np.ndindex(LMAX, LMAX, LMAX):
    bf = binom_factor_sym(i, j, a, b, s)
    val = get_coeffs(bf)
    # Print a small random sample of entries for eyeballing.
    show_entry = np.random.rand() ** LMAX > .7
    if show_entry:
        print((i, j, s), val, bf)
    weights[i, j, s] = val

(0, 3, 2) (0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) Poly(3*b, a, b, domain='ZZ')
(0, 4, 1) (0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) Poly(4*b**3, a, b, domain='ZZ')
(0, 6, 4) (0, 0, 15, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) Poly(15*b**2, a, b, domain='ZZ')
(1, 1, 7) (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) Poly(0, a, b, domain='ZZ')
(1, 2, 1) (0, 0, 1, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0

## How sparse?

In [7]:
# Number of non-zero coefficients in the whole table.
nnz = int(np.count_nonzero(weights))
print(f'nnz={nnz}')

sparsity = 1 - nnz / LMAX**5
print(f'elementwise sparsity = {sparsity*100:.0f}%')

# An (i, j, s) triple counts as "all-zero" when every one of its
# coefficients is zero, i.e. the sum of absolute values vanishes.
abs_row_sums = np.abs(weights).sum(axis=3)
nonzero_triples = int(np.count_nonzero(abs_row_sums))
all_zero_fraction = 1 - nonzero_triples / LMAX**3
print(f'all-zero = {all_zero_fraction*100:.0f}%')

nnz=1086
elementwise sparsity = 97%
all-zero = 16%


## get_monomials(a,b)

Given float values a,b compute the monomials as above.  This function will be exported to `binom_factor_table.py`.  

We test it here by running on a sample input

In [8]:
import jax.numpy as jnp
def get_monomials(a, b):
    """Return the flat vector of monomials a**p * b**q for p, q in range(LMAX).

    Entry ``p * LMAX + q`` of the result is ``a**p * b**q``.
    """
    # NOTE: the source of this function is copied into the generated module,
    # so it only uses names available there (jnp, LMAX).
    exponents = jnp.arange(LMAX)
    a_column = (a ** exponents).reshape(LMAX, 1)
    b_row = (b ** exponents).reshape(1, LMAX)
    grid = a_column @ b_row
    return grid.reshape(LMAX * LMAX)

def f(x):
    """Wrap ``x`` as a float32 NumPy array."""
    return np.array(x, dtype=np.float32)


fa = f(1.1)
fb = f(2.33)
got = get_monomials(fa, fb)
print(got)

# Reference values: evaluate the symbolic monomial row at the same point.
evaluate_monomials = lambdify((a, b), all_monomials, "numpy")
expect = evaluate_monomials(fa, fb).reshape(LMAX * LMAX)
print(expect)

np.testing.assert_allclose(got, expect)


[  1.          2.33        5.4289     12.649336   29.472952   68.671974
 160.00569   372.81326     1.1         2.563       5.97179    13.914269
  32.420246   75.53918   176.00627   410.0946      1.21        2.8193
   6.568969   15.3056965  35.662273   83.093094  193.60689   451.10406
   1.3310001   3.1012301   7.225866   16.836267   39.2285     91.402405
 212.96759   496.21448     1.4641001   3.411353    7.948453   18.519894
  43.151352  100.54265   234.26436   545.83594     1.6105102   3.7524886
   8.743299   20.371885   47.46649   110.59692   257.6908    600.41956
   1.7715613   4.1277375   9.617628   22.409073   52.21314   121.65661
 283.4599    660.46155     1.9487174   4.540511   10.5793915  24.64998
  57.434452  133.82227   311.80588   726.5077   ]
[  1.           2.32999992   5.42889977  12.64933576  29.47295135
  68.6719744  160.0056951  372.81325738   1.10000002   2.56299996
   5.97178984  13.91426963  32.42024719  75.53917347 176.00626843
 410.09459201   1.21000004   2.819299

In [9]:
# Index tuple of the non-zero entries of `weights`; together with
# `weights[inds]` this is the sparse form embedded in the generated module.
inds = np.nonzero(weights)
import inspect

# Write the generated Python module.
# threshold=np.inf requests that arrays are printed in full rather than
# summarised; floats are formatted with up to 10 significant digits.
with np.printoptions(threshold=np.inf, formatter={'float':lambda x:f'{x:.10g}'}):
  with open("../pyscf_ipu/experimental/binom_factor_table.py", "w") as file:
    # The template below is emitted as the file's contents: LMAX, the source
    # text of get_monomials (via inspect.getsource) and the repr of the
    # (inds, values) pair are substituted in.  The template defines
    # `array = np.array` so that the substituted repr text can be evaluated.
    print(
        f"""# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
# AUTOGENERATED from notebooks/binom_factor_table.ipynb

# fmt: off
# flake8: noqa
# isort: skip_file

import jax.numpy as jnp
import numpy as np
array = np.array

LMAX = {LMAX}
{inspect.getsource(get_monomials)}

def build_binom_factor_table(sparse=False):
    inds,values = {repr((inds, weights[inds]))}
    if sparse:
      return inds,values
    else:
      W = np.zeros((LMAX,LMAX,LMAX,LMAX*LMAX))
      W[inds] = values
      return W
""",
        file=file,
        end=''  # the template already ends with a newline
    )

## Test

Import the file we just generated, and check it works

In [10]:
from pyscf_ipu.experimental import binom_factor_table

# Sample point at which the generated table is checked.
aval = f(1.1)
bval = f(2.2)

monomials = binom_factor_table.get_monomials(aval, bval)
print(f"{monomials=}")

# Contract the dense weight table with the monomial vector: one value per (i, j, s).
dense_table = binom_factor_table.build_binom_factor_table()
W = jnp.array(dense_table, dtype=jnp.float32)
table_ab = W @ monomials.reshape(LMAX * LMAX, 1)

# Compare every table entry against the direct evaluation, printing a
# random sample along the way.
for i, j, s in np.ndindex(LMAX, LMAX, LMAX):
    bf = binom_factor(i, j, aval, bval, s)
    show_entry = np.random.rand() ** LMAX > 0.6
    if show_entry:
        print((i, j, s), bf, table_ab[i, j, s])
    np.testing.assert_allclose(bf, table_ab[i, j, s], rtol=1e-6)

monomials=DeviceArray([  1.       ,   2.2      ,   4.84     ,  10.648001 ,
              23.425602 ,  51.536327 , 113.37992  , 249.43582  ,
               1.1      ,   2.42     ,   5.3240004,  11.712801 ,
              25.768162 ,  56.68996  , 124.71792  , 274.3794   ,
               1.21     ,   2.6620002,   5.8564005,  12.884081 ,
              28.344978 ,  62.35896  , 137.18971  , 301.81735  ,
               1.3310001,   2.9282002,   6.4420404,  14.17249  ,
              31.179478 ,  68.59486  , 150.90869  , 331.9991   ,
               1.4641001,   3.2210202,   7.0862446,  15.589739 ,
              34.29743  ,  75.454346 , 165.99956  , 365.199    ,
               1.6105102,   3.5431225,   7.79487  ,  17.148714 ,
              37.727173 ,  82.99978  , 182.59952  , 401.71893  ,
               1.7715613,   3.897435 ,   8.574357 ,  18.863586 ,
              41.49989  ,  91.29976  , 200.85948  , 441.89084  ,
               1.9487174,   4.287178 ,   9.431792 ,  20.749945 ,
              4

(0, 2, 2) 1.0 [1.]
(0, 5, 0) 51.53632558509851 [51.536327]
(1, 1, 6) 0 [0.]
(2, 1, 7) 0 [0.]
(2, 2, 5) 0 [0.]
(2, 6, 7) 15.40000033378601 [15.400001]
(2, 7, 4) 3336.977076303898 [3336.977]
(3, 1, 4) 1.0 [1.]
(3, 1, 6) 0 [0.]
(3, 1, 7) 0 [0.]
(3, 3, 0) 14.172489843082529 [14.17249]
(3, 4, 6) 12.100000262260437 [12.1]
(3, 5, 0) 68.59485381402618 [68.59486]
(3, 5, 1) 342.9742594246647 [342.97427]
(3, 5, 3) 889.0016110117186 [889.0016]
(4, 2, 2) 60.02810437345515 [60.028107]
(4, 5, 5) 997.0521781336806 [997.0522]
(4, 6, 3) 4926.357549691019 [4926.358]
(5, 1, 6) 1.0 [1.]
(5, 2, 4) 93.17000511407859 [93.170006]
(5, 4, 0) 37.72717041542877 [37.727173]
(5, 6, 2) 4338.624597774309 [4338.625]
(6, 3, 6) 252.89001389455817 [252.89001]
(6, 4, 5) 1903.623008021071 [1903.6229]
(6, 7, 0) 441.8908399527361 [441.89084]
(7, 0, 2) 33.8207136652209 [33.820713]
(7, 2, 7) 64.13000242233278 [64.130005]
(7, 5, 4) 12625.740273048325 [12625.74]
(7, 6, 2) 8353.927707221395 [8353.928]
