In [1]:
from jax.scipy.special import betaln, gammainc, gammaln
import numpy as np
from sympy import *

## Table lookup for "binom factor"

The function called "binom factor" is one of those with a variable-sized loop, 
which is tricky to vectorize.  Its definition is
$$
B(i,j,a,b,s) = \sum_{n=s-i}^j \binom{i}{s-n} \binom{j}{n} a^{i-(s-n)} b^{j-n}
$$

Its arguments are three integers $(i,j,s)$ and two floats.  It can't be a lookup table because of the floats, and it can't be a simple table of functions because JAX would not be impressed.

But... For a given $i,j,s$ we know it will be a polynomial in $(a,b)$:
$$
B(i,j,a,b,s) = \sum_{p=1}^{max} \sum_{q=1}^{max} w^{i,j,s}_{p,q} a^p b^q
$$

This notebook computes those weights.  If max=8, that is a 128KB array so might need to be sparsely stored.

In [2]:
LMAX = 4
kilobytes = LMAX ** 5 * 4 / 1024
print(f"{kilobytes=}")

kilobytes=12.20703125


Let's now compute the coefficients W.  First, we just define the functions, then pass them through sympy.

In [3]:
def binom(x, y):
    approx = 1.0 / ((x + 1) * np.exp(betaln(x - y + 1, y + 1)))
    return int(np.rint(approx))


def binom_factor(i: int, j: int, a: float, b: float, s: int):
    out = 0
    for t in range(max(s - i, 0), j + 1):
        assert ((s - i) <= t) & (t <= j)
        val = binom(i, s - t) * binom(j, t) * a ** (i - (s - t)) * b ** (j - t)
        out += val
    return out


a, b = symbols("a b", real=True)
for i in range(3):
    for j in range(3):
        for s in range(5):
            bf = binom_factor(i, j, a, b, s)
            print((i, j, s), bf)

bf



(0, 0, 0) 1
(0, 0, 1) 0
(0, 0, 2) 0
(0, 0, 3) 0
(0, 0, 4) 0
(0, 1, 0) b
(0, 1, 1) 1
(0, 1, 2) 0
(0, 1, 3) 0
(0, 1, 4) 0
(0, 2, 0) b**2
(0, 2, 1) 2*b
(0, 2, 2) 1
(0, 2, 3) 0
(0, 2, 4) 0
(1, 0, 0) a
(1, 0, 1) 1
(1, 0, 2) 0
(1, 0, 3) 0
(1, 0, 4) 0
(1, 1, 0) a*b
(1, 1, 1) a + b
(1, 1, 2) 1
(1, 1, 3) 0
(1, 1, 4) 0
(1, 2, 0) a*b**2
(1, 2, 1) 2*a*b + b**2
(1, 2, 2) a + 2*b
(1, 2, 3) 1
(1, 2, 4) 0
(2, 0, 0) a**2
(2, 0, 1) 2*a
(2, 0, 2) 1
(2, 0, 3) 0
(2, 0, 4) 0
(2, 1, 0) a**2*b
(2, 1, 1) a**2 + 2*a*b
(2, 1, 2) 2*a + b
(2, 1, 3) 1
(2, 1, 4) 0
(2, 2, 0) a**2*b**2
(2, 2, 1) 2*a**2*b + 2*a*b**2
(2, 2, 2) a**2 + 4*a*b + b**2
(2, 2, 3) 2*a + 2*b
(2, 2, 4) 1


1

In [4]:
# https://stackoverflow.com/questions/74731353/is-there-any-all-coeffs-for-multivariable-polynomials-in-sympy
def all_coeffs(expr, *free):
    if isinstance(expr, (int, Number)):
        return {S(1): N(expr)}

    x = IndexedBase("x")
    expr = expr.expand()
    free = list(free) or list(expr.free_symbols)
    pows = [p.as_base_exp() for p in expr.atoms(Pow, Symbol)]
    P = {}
    for p, e in pows:
        if p not in free:
            continue
        elif p not in P:
            P[p] = e
        elif e > P[p]:
            P[p] = e
    reps = dict([(f, x[i]) for i, f in enumerate(free)])
    xzero = dict([(v, 0) for k, v in reps.items()])
    e = expr.xreplace(reps)
    reps = {v: k for k, v in reps.items()}
    ans = dict(
        [
            (
                m.xreplace(reps),
                e.coeff(m).xreplace(xzero) if m != 1 else e.xreplace(xzero),
            )
            for m in monomials(*[P[f] for f in free])
        ]
    )
    return {m: w for m, w in ans.items() if w != 0}


def monomials(*o):
    x = IndexedBase("x")
    f = []
    for i, o in enumerate(o):
        f.append(Poly([1] * (o + 1), x[i]).as_expr())
    return Mul(*f).expand().args


assert all_coeffs(4 + a + 3 * b * a) == {1: 4, a: 1, a * b: 3}
display(S(all_coeffs(4 + a + 3 * b * a)))

bf = binom_factor(0, 2, a, b, 0)
display(bf, all_coeffs(bf))
assert all_coeffs(b * b) == {b * b: 1}

bf = binom_factor(0, 0, a, b, 0)
assert bf == 1
display(bf, all_coeffs(bf))
assert all_coeffs(bf) == {1: 1}

bf = binom_factor(1, 2, a, b, 1)
display(bf)


{1: 4, a: 1, a*b: 3}

b**2

{b**2: 1}

1

{1: 1.00000000000000}

2*a*b + b**2

In [5]:
all_monomials = set()
for i in range(LMAX):
    for j in range(LMAX):
        for s in range(LMAX):
            bf = binom_factor(i, j, a, b, s)
            if bf:
                coefs = all_coeffs(bf)
                if LMAX < 5 or i%2 and j%2:
                    display(S(((i, j, s), bf, coefs)))
                all_monomials = all_monomials.union(set(coefs.keys()))
all_monomials = tuple(all_monomials)
n = len(all_monomials)
print(f"{len(all_monomials)=}")
display(S(all_monomials))
# monom_map = {k:i for i,k in enumerate(all_monomials)}

weights = np.zeros((LMAX, LMAX, LMAX, n))
for i in range(LMAX):
    for j in range(LMAX):
        for s in range(LMAX):
            bf = binom_factor(i, j, a, b, s)
            coefs = all_coeffs(bf)
            val = [coefs.get(monom, 0) for monom in all_monomials]
            if LMAX < 5:
              print((i, j, s), val, bf)
            weights[i, j, s, :] = val

((1, 1, 0), a*b, {a*b: 1})

((1, 1, 1), a + b, {a: 1, b: 1})

((1, 1, 2), 1, {1: 1.0})

((1, 3, 0), a*b**3, {a*b**3: 1})

((1, 3, 1), 3*a*b**2 + b**3, {b**3: 1, a*b**2: 3})

((1, 3, 2), 3*a*b + 3*b**2, {b**2: 3, a*b: 3})

((1, 3, 3), a + 3*b, {a: 1, b: 3})

((1, 3, 4), 1, {1: 1.0})

((3, 1, 0), a**3*b, {a**3*b: 1})

((3, 1, 1), a**3 + 3*a**2*b, {a**3: 1, a**2*b: 3})

((3, 1, 2), 3*a**2 + 3*a*b, {a**2: 3, a*b: 3})

((3, 1, 3), 3*a + b, {a: 3, b: 1})

((3, 1, 4), 1, {1: 1.0})

((3, 3, 0), a**3*b**3, {a**3*b**3: 1})

((3, 3, 1), 3*a**3*b**2 + 3*a**2*b**3, {a**2*b**3: 3, a**3*b**2: 3})

((3, 3, 2), 3*a**3*b + 9*a**2*b**2 + 3*a*b**3, {a*b**3: 3, a**2*b**2: 9, a**3*b: 3})

((3, 3, 3), a**3 + 9*a**2*b + 9*a*b**2 + b**3, {a**3: 1, b**3: 1, a*b**2: 9, a**2*b: 9})

((3, 3, 4), 3*a**2 + 9*a*b + 3*b**2, {a**2: 3, b**2: 3, a*b: 9})

len(all_monomials)=25


(1, a**2*b, a**4, a**2*b**4, b**4, a**3*b**2, a**4*b**2, a*b**4, a**3*b, a**3*b**3, a**2, a*b, a**4*b**3, b**2, a**4*b, a, a**3*b**4, a**4*b**4, a**2*b**2, a*b**2, b, b**3, a**3, a**2*b**3, a*b**3)

In [6]:
inds = np.nonzero(weights)
print(f'nnz={len(inds[0])}')

with np.printoptions(threshold=np.inf, formatter={'float':lambda x:f'{x:.10g}'}):
  with open("../pyscf_ipu/experimental/binom_factor_table.py", "w") as f:
    print(
        f"""# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
# AUTOGENERATED from notebooks/binom_factor_table.ipynb

# fmt: off
# flake8: noqa
# isort: skip_file

from numpy import array,zeros

LMAX = {LMAX}
def get_monomials(a,b):
    return {all_monomials}

def build_binom_factor_table(sparse=False):
    inds,values = {repr((inds, weights[inds]))}
    if sparse:
      return inds,values
    else:
      W = zeros((LMAX,LMAX,LMAX,LMAX*LMAX))
      W[inds] = values
      return W
""",
        file=f,
        end=''
    )

nnz=190


## Test

Import the file we just generated, and check it works

In [8]:
from pyscf_ipu.experimental import binom_factor_table

aval, bval = 1.1, 2.2

monomials = binom_factor_table.get_monomials(aval, bval)
print(f"{monomials=}")

W = binom_factor_table.build_binom_factor_table()

table_ab = W @ monomials

for i in range(LMAX):
    for j in range(LMAX):
        for s in range(LMAX):
            bf = binom_factor(i, j, aval, bval, s)
            if LMAX < 5:
                print((i, j, s), bf, table_ab[i, j, s])
            np.testing.assert_allclose(bf, table_ab[i, j, s])

monomials=(1, 2.662000000000001, 1.4641000000000004, 28.344976000000013, 23.425600000000006, 6.442040000000003, 7.086244000000003, 25.76816000000001, 2.9282000000000012, 14.172488000000008, 1.2100000000000002, 2.4200000000000004, 15.58973680000001, 4.840000000000001, 3.221020000000001, 1.1, 31.17947360000002, 34.29742096000002, 5.856400000000002, 5.324000000000002, 2.2, 10.648000000000003, 1.3310000000000004, 12.884080000000006, 11.712800000000005)
