In [5]:
from jax.scipy.special import betaln, gammainc, gammaln
import numpy as np
from sympy import *

## Table lookup for "binom factor"

The function called "binom factor" is one of those with a variable-sized loop, 
which is tricky to vectorize.  Its definition is
$$
B(i,j,a,b,s) = \sum_{n=s-i}^j \binom{i}{s-n} \binom{j}{n} a^{i-(s-n)} b^{j-n}
$$

Its arguments are three integers $(i,j,s)$ and two floats.  It can't be a lookup table because of the floats, and it can't be a simple table of functions because JAX would not be impressed.

But... For a given $i,j,s$ we know it will be a polynomial in $(a,b)$:
$$
B(i,j,a,b,s) = \sum_{p=1}^{max} \sum_{q=1}^{max} w^{i,j,s}_{p,q} a^p b^q
$$

This notebook computes those weights.  If max=5, that is a 5x5x5x25 array.

In [6]:
kilobytes = 5*5*5*25 * 4 / 1024
print(f'{kilobytes=}')

kilobytes=12.20703125


In fact, they are all small integers and it is sparse, so it could be less, but certainly not a huge memory burden.

In [7]:

def binom(x, y):
    approx = 1.0 / ((x + 1) * np.exp(betaln(x - y + 1, y + 1)))
    return int(np.rint(approx))

def binom_factor(i: int, j: int, a: float, b: float, s: int):
    out = 0
    for t in range(max(s - i, 0), j + 1):
        assert ((s - i) <= t) & (t <= j)
        val = binom(i, s - t) * binom(j, t) * a ** (i - (s - t)) * b ** (j - t)
        out += val
    return out

a,b = symbols("a b", real=True)
LMAX = 4
for i in range(LMAX):
  for j in range(LMAX):
    for s in range(LMAX):
      bf = binom_factor(i,j,a,b,s)
      print((i,j,s),bf)

bf

(0, 0, 0) 1
(0, 0, 1) 0
(0, 0, 2) 0
(0, 0, 3) 0
(0, 1, 0) b
(0, 1, 1) 1
(0, 1, 2) 0
(0, 1, 3) 0
(0, 2, 0) b**2
(0, 2, 1) 2*b
(0, 2, 2) 1
(0, 2, 3) 0
(0, 3, 0) b**3
(0, 3, 1) 3*b**2
(0, 3, 2) 3*b
(0, 3, 3) 1
(1, 0, 0) a
(1, 0, 1) 1
(1, 0, 2) 0
(1, 0, 3) 0
(1, 1, 0) a*b
(1, 1, 1) a + b
(1, 1, 2) 1
(1, 1, 3) 0
(1, 2, 0) a*b**2
(1, 2, 1) 2*a*b + b**2
(1, 2, 2) a + 2*b
(1, 2, 3) 1
(1, 3, 0) a*b**3
(1, 3, 1) 3*a*b**2 + b**3
(1, 3, 2) 3*a*b + 3*b**2
(1, 3, 3) a + 3*b
(2, 0, 0) a**2
(2, 0, 1) 2*a
(2, 0, 2) 1
(2, 0, 3) 0
(2, 1, 0) a**2*b
(2, 1, 1) a**2 + 2*a*b
(2, 1, 2) 2*a + b
(2, 1, 3) 1
(2, 2, 0) a**2*b**2
(2, 2, 1) 2*a**2*b + 2*a*b**2
(2, 2, 2) a**2 + 4*a*b + b**2
(2, 2, 3) 2*a + 2*b
(2, 3, 0) a**2*b**3
(2, 3, 1) 3*a**2*b**2 + 2*a*b**3
(2, 3, 2) 3*a**2*b + 6*a*b**2 + b**3
(2, 3, 3) a**2 + 6*a*b + 3*b**2
(3, 0, 0) a**3
(3, 0, 1) 3*a**2
(3, 0, 2) 3*a
(3, 0, 3) 1
(3, 1, 0) a**3*b
(3, 1, 1) a**3 + 3*a**2*b
(3, 1, 2) 3*a**2 + 3*a*b
(3, 1, 3) 3*a + b
(3, 2, 0) a**3*b**2
(3, 2, 1) 2*a**3*b + 3*a**

a**3 + 9*a**2*b + 9*a*b**2 + b**3

In [8]:
# https://stackoverflow.com/questions/74731353/is-there-any-all-coeffs-for-multivariable-polynomials-in-sympy
def all_coeffs(expr,*free):
    x = IndexedBase('x')
    expr = expr.expand()
    free = list(free) or list(expr.free_symbols)
    pows = [p.as_base_exp() for p in expr.atoms(Pow,Symbol)]
    P = {}
    for p,e in pows:
        if p not in free:
            continue
        elif p not in P:
            P[p]=e
        elif e>P[p]:
            P[p] = e
    reps = dict([(f, x[i]) for i,f in enumerate(free)])
    xzero = dict([(v,0) for k,v in reps.items()])
    e = expr.xreplace(reps); reps = {v:k for k,v in reps.items()}
    return dict([(m.xreplace(reps), e.coeff(m).xreplace(xzero) if m!=1 else e.xreplace(xzero)) for m in monoms(*[P[f] for f in free])])

def monoms(*o):
    x = IndexedBase('x')
    f = []
    for i,o in enumerate(o):
        f.append(Poly([1]*(o+1),x[i]).as_expr())
    return Mul(*f).expand().args

all_coeffs(S(bf))

{1: 0,
 b**2: 0,
 b**3: 1,
 a**2: 0,
 a**3: 1,
 a**2*b**2: 0,
 a**3*b**2: 0,
 a*b**2: 9,
 a**2*b**3: 0,
 a**3*b**3: 0,
 a*b**3: 0,
 a**2*b: 9,
 a**3*b: 0,
 a*b: 0,
 b: 0,
 a: 0}

In [9]:
LMAX = 5
all_monoms = set()
for i in range(LMAX):
  for j in range(LMAX):
    for s in range(LMAX):
      bf = binom_factor(i,j,a,b,s)
      coefs = all_coeffs(S(bf))
      all_monoms = all_monoms.union(set(coefs.keys()))
all_monoms = tuple(all_monoms)
n = len(all_monoms)
display(n, S(all_monoms))
#monom_map = {k:i for i,k in enumerate(all_monoms)}

weights = np.zeros((LMAX, LMAX, LMAX, n))
for i in range(LMAX):
  for j in range(LMAX):
    for s in range(LMAX):
      bf = binom_factor(i,j,a,b,s)
      coefs = all_coeffs(S(bf))
      #display(bf, S(coefs))
      val = [coefs.get(monom, 0) for monom in all_monoms]
      print((i,j,s), val, bf)
      weights[i,j,s,:] = val

25

(1, a*b**2, b**4, a*b, b, a**3*b**2, a**4*b**3, a**4, a**4*b**4, a**2*b**3, a**2*b, b**3, a**2*b**4, a**3, a**3*b, a, a**3*b**3, a*b**3, a**3*b**4, a*b**4, a**4*b**2, b**2, a**2, a**2*b**2, a**4*b)

(0, 0, 0) [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] 1
(0, 0, 1) [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] 0
(0, 0, 2) [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] 0
(0, 0, 3) [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] 0
(0, 0, 4) [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] 0
(0, 1, 0) [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] b
(0, 1, 1) [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] 1
(0, 1, 2) [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] 0
(0, 1, 3) [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] 0
(0, 1, 4) [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] 0
(0, 2, 0) [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0] b**2
(0, 2, 1) [0, 0, 0, 0, 2, 0, 

In [23]:
inds = np.nonzero(weights)
inds, weights[inds]

with open("../pyscf_ipu/experimental/binom_factor_table.py", "w") as f:
  print("# AUTOGENERATED from notebooks/binom_factor_table.ipynb", file=f)
  print("from numpy import array", file=f)
  print("binom_factor_table = ", repr((inds, weights[inds])), file=f)

import pyscf_ipu.experimental.binom_factor_table

pyscf_ipu.experimental.binom_factor_table.binom_factor_table

((array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2,
         2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
         2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
         3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
         3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
         4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
         4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4]),
  array([1, 2, 2, 3, 3, 3, 4, 4, 4, 4, 0, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3,
         3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 1, 1, 1, 1, 1, 2, 2,
         2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4,
         4, 4, 4, 4, 4, 4, 4, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2,
         2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4,
         4, 4, 4, 4, 4, 4,