In [1]:
%pylab inline
import numpy as np
import jax
import jax.numpy as jnp
import shtns

%pylab is deprecated, use %matplotlib inline and import the required libraries.
Populating the interactive namespace from numpy and matplotlib


In [4]:
def gen_ml(N):
    return np.array([[m, l] for m in range(N+1) for l in range(N+1) if m<=l])

def all_amm(N):
    amm = []
    for m in range(N+1):
        els = []
        for k in range(1,m+1):
            els.append((2*k+1)/(2*k))
        el = np.prod(els)/(4*np.pi)
        amm.append(np.sqrt(el))
    amm = np.array(amm)
    return amm

# @jax.jit
def amn(m, n):
    nom = 4*n*n - 1
    den = n*n - m*m
    out = jnp.sqrt(nom/den)
    return out

# @jax.jit
def bmn(m, n):
    l = (2*n + 1)/(2*n - 3)
    r = ((n - 1)*(n - 1) - m*m)/(n*n - m*m)
    return - jnp.sqrt(l*r)

def lat_grid(nlat):
    x = np.r_[:nlat] / nlat * np.pi
    cx = np.cos(x)
    return cx

In [None]:
# @jax.jit
def LT(amm: np.ndarray[1,np.float32],
       cx: np.ndarray[1,np.float32],
       fx: np.ndarray[2,np.complex64]) -> np.ndarray[1,np.complex64]:
    
    N = amm.size - 1  # lmax
    nlat = cx.size
    i = 0
    
    # since cx is a vector, this is vectorized over latitudes
    # TODO convert to scan, it won't really work otherwise
    acc = jnp.zeros_like(cx)
    for m in range(N+1):
        # eq 13
        n = m
        p0 = amm[m]*(1 - cx*cx)**(m/2)*(-1)**m
        print(i, m, n, p0[0]); i+=1
        if n == N:
            break
        # eq 14
        n += 1
        p1 = amn(m, n) * cx * p0
        print(i, m, n, p1[0]); i+=1
        if n == N:
            continue
        # eq 15 base case
        n += 1
        p2 = amn(m, n)*cx*p1 + bmn(m, n)*p0
        print(i, m, n, p2[0]); i+=1
        if n == N:
            continue
        # eq 15 iterate
        while n < N:
            p0, p1 = p1, p2
            n += 1
            p2 = amn(m, n)*cx*p1 + bmn(m, n)*p0
            print(i, m, n, p2[0]); i+=1
    return p2

The above doesn't actually do the transform, just a complex loop to generate the coefficients.  Since Jax wants a more functional approach, and that loop above could result in many extra nodes in the graph, we want to convert that in to a scan over `m,n` sequence, where we carry whatever required, and the result will be just the `m,n` vector which is the result.

When including the `fx` Fourier transformed data, it may be helpful to transpose the array for locality.

First we need to break the loop above into cases and a function to map a given step to the specific case,

In [32]:
# TODO may be useful/necessary to do all this as a closure over some constant data arrays?
# TODO may be useful to have a pytree to collect arrays?

from collections import namedtuple

# instances of this instead of many args, jax will flatten automatically
# LTRecurState = namedtuple('LTRecurState', 'stop m n p0 p1 p2'.split(' '))

from scipy.special import sph_harm
check_Pmn = lambda p, m, n, cx: \
    np.testing.assert_allclose(p, 
        sph_harm(m, n, 0, np.arccos(cx)).real, rtol=1e-5, atol=1e-6)


def make_lt_recur(lmax, nlat):
    
    amm = all_amm(lmax)
    cx = lat_grid(nlat)
    N = amm.size - 1
    ml = np.array([[m, l] for m in range(N+1) for l in range(N+1) if m<=l])

    def eq13(i, amm_m, m, n, _1, _2, _3):
        n = m
        p0 = amm_m*(1 - cx*cx)**(m/2)*(-1)**m
        check_Pmn(p0, m, n, cx)
        stop = n == N
        return p0, (stop, m, n, p0, p1, p2)

    def eq14(i, _1, m, n, p0, _2, _3):
        n += 1
        p1 = amn(m, n) * cx * p0
        check_Pmn(p1, m, n, cx)
        # print(i, m, n, p1[0]); i+=1
        stop = n < N
        return p1, (stop, m, n, p0, p1, p2)

    def eq15_base(i, _1, m, n, p0, p1, _2):
        n += 1
        p2 = amn(m, n)*cx*p1 + bmn(m, n)*p0
        check_Pmn(p2, m, n, cx)
        # print(i, m, n, p2[0]); i+=1
        stop = n < N
        return p2, (stop, m, n, p0, p1, p2)

    def eq15_iter(i, _1, m, n, p0, p1, p2):
        p0, p1 = p1, p2
        n += 1
        p2 = amn(m, n)*cx*p1 + bmn(m, n)*p0
        check_Pmn(p2, m, n, cx)
        # print(i, m, n, p2[0]); i+=1
        stop = n < N
        return p2, (stop, m, n, p0, p1, p2)
    
    def switch(i,fs,*args): return fs[i](*args)
    
    def choose_eqn(i, m, n, p0, p1, p2):
        return switch(
            min(3, n - m),
            [eq13, eq14, eq15_base, eq15_iter],
            i, amm[m], m, n, p0, p1, p2)

    p0, p1, p2 = jnp.zeros((3,nlat), jnp.float32)
    for i, (m, n) in enumerate(ml):
        pmn, (_, _, _, p0, p1, p2) = choose_eqn(i, m, n, p0, p1, p2)
        print(i, m, n, pmn[nlat//2])

make_lt_recur(6, 32)
# what we really want is a scan that accumulates results for m,n array
# TODO how to factor forward vs inverse transforms?
# TODO test custom gradients vs auto gradient? probably identical esp with scan-based graph

0 0 0 0.28209479177387814


AssertionError: 
Not equal to tolerance rtol=1e-05, atol=1e-06

Mismatched elements: 32 / 32 (100%)
Max absolute difference: 1.16533666
Max relative difference: 13.99931088
 x: array([ 5.462742e-01,  5.436438e-01,  5.357777e-01,  5.227519e-01,
        5.046916e-01,  4.817709e-01,  4.542104e-01,  4.222757e-01,
        3.862742e-01,  3.465527e-01,  3.034937e-01,  2.575119e-01,...
 y: array([ 0.630783,  0.621693,  0.594772,  0.551054,  0.492219,  0.420529,
        0.338738,  0.249991,  0.157696,  0.065401, -0.023347, -0.105137,
       -0.176827, -0.235662, -0.27938 , -0.306301, -0.315392, -0.306301,...

While we're debugging, let's just check we've got the same ordering as SHTns, namely, `m` is the outer stride:

In [30]:
sht = shtns.sht(lmax=6)
sht.set_grid(nlat=32, nphi=64)
np.c_[np.r_[:sht.m.size], sht.m, sht.l]

array([[ 0,  0,  0],
       [ 1,  0,  1],
       [ 2,  0,  2],
       [ 3,  0,  3],
       [ 4,  0,  4],
       [ 5,  0,  5],
       [ 6,  0,  6],
       [ 7,  1,  1],
       [ 8,  1,  2],
       [ 9,  1,  3],
       [10,  1,  4],
       [11,  1,  5],
       [12,  1,  6],
       [13,  2,  2],
       [14,  2,  3],
       [15,  2,  4],
       [16,  2,  5],
       [17,  2,  6],
       [18,  3,  3],
       [19,  3,  4],
       [20,  3,  5],
       [21,  3,  6],
       [22,  4,  4],
       [23,  4,  5],
       [24,  4,  6],
       [25,  5,  5],
       [26,  5,  6],
       [27,  6,  6]])

And here from our function:

In [31]:
amm = jnp.array(all_amm(6))
cx = jnp.array(lat_grid(32))

Pmn(amm, cx)

0 0 0 0.2820948
1 0 1 0.48860252
2 0 2 0.63078314
3 0 3 0.74635273
4 0 4 0.8462845
5 0 5 0.9356028
6 0 6 1.0171077
7 1 1 -0.0
8 1 2 -0.0
9 1 3 0.0
10 1 4 0.0
11 1 5 0.0
12 1 6 0.0
13 2 2 0.0
14 2 3 0.0
15 2 4 0.0
16 2 5 0.0
17 2 6 0.0
18 3 3 -0.0
19 3 4 -0.0
20 3 5 0.0
21 3 6 0.0
22 4 4 0.0
23 4 5 0.0
24 4 6 0.0
25 5 5 -0.0
26 5 6 -0.0
27 6 6 0.0


Array([ 0.0000000e+00,  3.2583479e-04,  4.9519106e-03,  2.2985460e-02,
        6.4191215e-02,  1.3311261e-01,  2.2449929e-01,  3.2205141e-01,
        4.0137896e-01,  4.3657413e-01,  4.0844810e-01,  3.1174392e-01,
        1.5879709e-01, -2.1865398e-02, -1.9192256e-01, -3.1297475e-01,
       -3.5678127e-01, -3.1297475e-01, -1.9192256e-01, -2.1865398e-02,
        1.5879709e-01,  3.1174392e-01,  4.0844810e-01,  4.3657413e-01,
        4.0137896e-01,  3.2205141e-01,  2.2449929e-01,  1.3311261e-01,
        6.4191215e-02,  2.2985460e-02,  4.9519106e-03,  3.2583479e-04],      dtype=float32)

In [17]:
%timeit Pmn(amm, cx)

1.43 µs ± 11.4 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)
