In [2]:
import json

# Read json file
file = './ggl-data-original.json'
data = json.load(open(file))

In [1]:
def verify(key, fn):
    for d in data:
        r = d['r']
        if not jnp.allclose(jnp.array(d[key]), fn(r)):
            jax.debug.print(
                "Failed for {}, {} - {} = {}",
                r,
                jnp.array(d[key]),
                fn(r),
                jnp.array(d[key]) - fn(r)
            )
            return False

In [3]:
data[0]

{'r': 0, 'x': [-1.0, 1.0], 'w': [1.0, 1.0], 'd': [[-0.5, 0.5], [-0.5, 0.5]]}

In [3]:
from jax import Array, jit, grad

import jax.lax
import jax.numpy as jnp
import scipy.special as sps

In [5]:
r = 0
c = sps.legendre(r + 1).c
xs = jnp.real(jnp.roots(jnp.polyder(c)))
xs

Array([], dtype=float32)

In [6]:
jnp.poly(jnp.array([-1, 1]))

Array([ 1.,  0., -1.], dtype=float32)

In [7]:
jnp.roots(jnp.polymul(jnp.array([1, 0, -1]), jnp.polyder(c)))

Array([-1.+0.j,  1.+0.j], dtype=complex64)

In [5]:
def xs(r: int) -> Array:
    c = sps.legendre(r + 1).c
    return jnp.sort(jnp.roots(jnp.polymul(jnp.array([1, 0, -1]), jnp.polyder(c))))

In [9]:
jnp.array(data[3]['x']) - jnp.real(xs(3))

Array([ 3.5762787e-07, -5.9604645e-08,  0.0000000e+00, -1.1920929e-07,
        1.1920929e-07], dtype=float32)

In [24]:
data[3]['x']

[-1.0, -0.6546536707079772, 0.0, 0.6546536707079772, 1.0]

In [33]:
(jnp.array(data[3]['x']) - jnp.sort(jnp.real(xs(3))))

Array([ 3.5762787e-07, -5.9604645e-08,  0.0000000e+00, -1.1920929e-07,
        1.1920929e-07], dtype=float32)

In [6]:
verify('x', xs)

In [73]:
data[9]['x']

[-1.0,
 -0.9340014304080592,
 -0.7844834736631444,
 -0.565235326996205,
 -0.2957581355869394,
 0.0,
 0.2957581355869394,
 0.565235326996205,
 0.7844834736631444,
 0.9340014304080592,
 1.0]

In [74]:
jnp.real(xs(9))

Array([-0.99999464, -0.9340094 , -0.7844821 , -0.56523544, -0.2957581 ,
        0.        ,  0.29575828,  0.56523496,  0.7844805 ,  0.9340178 ,
        0.9999873 ], dtype=float32)

In [8]:
def ws(r: int) -> Array:
    legendre = sps.legendre(r + 1).c
    poly = jnp.polymul(
        jnp.array([1, 0, -1]),
        jnp.polyder(legendre)
    )

    legendre_at_xs = jnp.polyval(legendre, xs(r))
    ws = 2 / ((r + 1) * (r + 2) * legendre_at_xs ** 2)
    
    return ws

In [37]:
(jnp.array(data[3]['w']) - jnp.sort(jnp.real(ws(3))))

Array([ 8.5401542e+16, -2.2448361e+12, -8.9793453e+12, -2.3058436e+18,
       -3.3908877e+31], dtype=float32)

In [10]:
jnp.real(ws(1))

Array([0.33333333, 1.33333333, 0.33333333], dtype=float64)

In [9]:
verify('w', ws)

In [4]:
from jax import random
x = random.uniform(random.PRNGKey(0), (1000,), dtype=jnp.float64)

In [66]:
def dij(r: int):
    dij_a = -1 / 4 * (r + 1) * (r + 2)
    dij_b = -dij_a
    legendre = sps.legendre(r + 1).c
    xss = jnp.real(xs(r))
    legendre_at_xs = jnp.polyval(legendre, xss)

    # TODO: Is there a nicer way to do this? Maybe jax.lax.switch?
    @jit
    def derivative_matrix_element(i_, j_):
        # For some reason we get floats here, convert to ints for indexing.
        i = i_.astype('int')
        j = j_.astype('int')

        return jax.lax.cond(
            jax.lax.eq(i, j),
            lambda: jax.lax.cond(
                jax.lax.eq(i, 0),
                lambda: dij_a,
                lambda: jax.lax.cond(
                    jax.lax.eq(i, r + 1),
                    lambda: dij_b,
                    lambda: 0.0,
                )
            ),
            lambda: legendre_at_xs[i] / (legendre_at_xs[j] * (xss[i] - xss[j]))
        )

    derivative_matrix = jnp.fromfunction(
        derivative_matrix_element,
        shape=(r + 2, r + 2),
    )
    
    return derivative_matrix


In [67]:
dij(1) 

Array([[-1.5,  2. , -0.5],
       [-0.5,  0. ,  0.5],
       [ 0.5, -2. ,  1.5]], dtype=float64)

In [68]:
dij(1) - jnp.array(data[1]['d'])

Array([[0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.]], dtype=float64)

In [69]:
verify('d', dij)