Is it worth pushing JAX to use multiple CPU cores at this point? tl;dr: on a laptop, no; on larger machines, maybe — still to be tested.

In [1]:
import os
# Tell XLA to expose 4 host (CPU) devices so that pmap can spread work over
# several cores. This must take effect before JAX initialises its CPU backend,
# hence this first cell. Append to any XLA_FLAGS already set in the
# environment instead of overwriting them, and avoid duplicating the flag when
# the cell is re-run.
_device_count_flag = '--xla_force_host_platform_device_count=4'
_existing_xla_flags = os.environ.get("XLA_FLAGS", "")
if _device_count_flag not in _existing_xla_flags.split():
    os.environ["XLA_FLAGS"] = f"{_existing_xla_flags} {_device_count_flag}".strip()

In [2]:
%load_ext autoreload
%load_ext memory_profiler
%autoreload 2
%pylab inline
import numpy as np
import tqdm
import shtns
import jax
import jax.numpy as jnp
import nfjax.shtlc as lc

%pylab is deprecated, use %matplotlib inline and import the required libraries.
Populating the interactive namespace from numpy and matplotlib


In [3]:
# Spectral truncation degree and spatial grid size (latitudes x longitudes).
lmax, nlat, nlon = 31, 32, 64

# shtns transform object for degree lmax; set_grid fixes the spatial grid.
sht = shtns.sht(lmax)
sht.set_grid(nlat=nlat, nphi=nlon)

[32, 64]

In [4]:
# Helpers from nfjax.shtlc (not shown here). `lm` presumably holds the
# degree/order indices up to lmax; `phi` is used below as the polar-angle
# argument of sph_harm, and `gw` presumably holds the matching quadrature
# weights -- TODO confirm against shtlc.
lm = lc.make_lm(lmax)
phi, _, gw = lc.make_grid(nlat, nlon)

In [5]:
from scipy.special import sph_harm
# NOTE(review): scipy.special.sph_harm is deprecated in recent SciPy releases
# in favour of sph_harm_y, which takes its arguments in a different order.

# forward & inverse Legendre transform matrices
LT = []
iLT = []
# Per-mode operators iLT @ diag(D*l*(l+1)) @ LT, built here in NumPy. The
# name `L` is rebound later in the notebook to the JAX-computed operator.
L = []
# Scale factor on the eigenvalues l*(l+1) (presumably a diffusion coefficient).
D = 0.00047108

# each longitudinal frequency mode `m` needs its own forward & inverse matrices
# NOTE(review): range(lmax) builds m = 0..lmax-1 only, so the m = lmax mode is
# omitted; the later Fourier truncation also keeps only lmax modes -- confirm
# this is intended.
for m in range(lmax):
    # degrees used for this mode; presumably l = m..lmax -- TODO confirm make_lm
    l = lm[0, lm[0]>=m]
    # sph_harm(m, l, theta=0, phi): the azimuth is fixed at 0, so only the
    # phi-dependent (associated Legendre) factor remains. The longitude
    # dependence is handled separately by the FFT in apply_L.
    # Forward matrix: rows indexed by l, columns by grid point, weighted by gw.
    LT.append( gw[None, :] * sph_harm(m, l[:, None], 0, phi[None, :]).conjugate() )
    # Inverse matrix: rows indexed by grid point, columns by l.
    iLT.append(              sph_harm(m, l[None, :], 0, phi[:, None])             )
    dll = D * l * (l + 1)
    L.append(iLT[-1].dot(dll[:, None] * LT[-1]))

In [6]:
# Down-cast the Legendre matrices to single-precision complex on the host,
# then move each one onto the JAX device.
fLT = [mat.astype(np.complex64) for mat in LT]
fiLT = [mat.astype(np.complex64) for mat in iLT]
jLT = list(map(jnp.array, fLT))
jiLT = list(map(jnp.array, fiLT))

In [7]:
def compute_L(jLT, jiLT, jdll):
    """Assemble one real operator ``iLT[m] @ diag(dll_m) @ LT[m]`` per mode ``m``.

    Parameters
    ----------
    jLT, jiLT : sequences of arrays
        Forward and inverse Legendre matrices, one pair per longitudinal mode
        ``m``. ``jLT[m]`` is ``(n_m, n_grid)`` and ``jiLT[m]`` is
        ``(n_grid, n_m)``.
    jdll : array
        Per-degree factors such as ``D * l * (l + 1)``, typically a column
        vector so that each entry scales one row of ``jLT[m]``. Only the first
        ``jLT[0].shape[0]`` entries are read. They are assumed to be the
        degrees ``l = 0, 1, ..., lmax`` in ascending order, as in the m = 0
        block of the shtns coefficient layout.

    Returns
    -------
    Array of shape ``(len(jLT), n_grid, n_grid)`` holding the real part of
    each mode's operator. An empty mode list gives an empty array.
    """
    if not jLT:
        return jnp.array([])

    # Number of degrees in the m = 0 block (l = 0..lmax).
    n0 = jLT[0].shape[0]

    def _one_mode(lt, ilt):
        # The rows of mode m are assumed to be l = m..lmax, i.e. the *last*
        # n_m degrees of the m = 0 block, not the first n_m (l = 0..n_m-1).
        dll = jdll[n0 - lt.shape[0]:n0]
        return jnp.dot(ilt, dll * lt).real

    return jnp.stack([_one_mode(lt, ilt) for lt, ilt in zip(jLT, jiLT)])

def close_L_for_D(jLT, jiLT, l):
    """Build a jitted function ``D -> compute_L(jLT, jiLT, D * l * (l + 1))``.

    The Legendre matrices and the degree array ``l`` are captured by closure,
    so the returned callable takes only the coefficient ``D``.
    """
    def L_for_D(D):
        scaled_eigs = D * l * (l + 1)
        return compute_L(jLT, jiLT, scaled_eigs)

    return jax.jit(L_for_D)

@jax.jit
def apply_L(L, x):
    """Apply a per-longitudinal-mode operator to a real 2-D field.

    ``x`` is Fourier-transformed along axis 1. For each of the first
    ``L.shape[0]`` frequency modes ``m``, the column of coefficients is
    multiplied by ``L[m]`` (acting along axis 0). Higher frequencies are
    zeroed, and the result is transformed back.

    Parameters
    ----------
    L : array, shape ``(n_modes, n0, n0)``
        One matrix per retained Fourier mode, where ``n0 == x.shape[0]``.
    x : real array, shape ``(n0, n1)``

    Returns
    -------
    Real array with the same shape as ``x``.
    """
    n_modes = L.shape[0]  # derived from L rather than a global truncation degree
    X = jnp.fft.rfft(x, axis=1)
    # 'abc,ca->ba': for every mode a, Y[:, a] = L[a] @ X[:, a].
    Y = jnp.einsum('abc,ca->ba', L, X[:, :n_modes])
    # Zero-fill the truncated high-frequency modes (keeps X's complex dtype).
    Y = jnp.pad(Y, ((0, 0), (0, X.shape[1] - Y.shape[1])))
    # Pass n explicitly so that odd-length axes round-trip to the right width.
    return jnp.fft.irfft(Y, n=x.shape[1], axis=1)

# Degree l of every shtns coefficient, as a column vector so that D*l*(l+1)
# scales the rows of each Legendre matrix inside compute_L.
jl = jnp.array(sht.l)[:,None]
# Jitted map D -> stacked per-mode operators, evaluated at the D set earlier.
l4d = close_L_for_D(jLT, jiLT, jl)
L = l4d(D)

In [8]:
# Apply the operator to one field; L is the global operator built above.
apply_1 = lambda x: apply_L(L, x)
# vmap batches over a leading axis; pmap additionally shards a second leading
# axis across the XLA devices requested in the first cell.
apply_batch = jax.vmap(apply_1)
apply_batch_batch = jax.pmap(apply_batch)

# Compare single-device vmap with pmap(vmap) for several total batch sizes.
# NOTE(review): apply_batch is not wrapped in jax.jit, whereas pmap compiles
# its function, so the vmap timings may include per-call tracing overhead.
# Consider also timing jax.jit(apply_batch).
dc = jax.device_count()
for n in (dc, 32, 512):
    # n fields laid out as (devices, n // devices, nlat, nlon); the flattened
    # copy feeds the single-device path.
    xs = jnp.zeros((dc, n//dc, nlat, nlon))
    xs_ = xs.reshape((-1, nlat, nlon))
    %timeit apply_batch(xs_).block_until_ready()
    %timeit apply_batch_batch(xs).block_until_ready()

188 µs ± 842 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
121 µs ± 813 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
509 µs ± 4 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
226 µs ± 999 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
5.71 ms ± 66.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
2.92 ms ± 59.5 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In this case, the benefit is modest because the ops within `apply_L` already run efficiently. Going from 4 to 8 cores on an M1 is not beneficial (which makes sense), but it is worth trying on larger CPUs. Still, with parallelism,

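- 325 µs for 32 fields, ~10 µs per field
- 2610 µs for 512 fields, ~5 µs per field

(These figures do not match the output above, which shows 226 µs for 32 and 2.92 ms for 512 with `pmap`; they are presumably from a different run, but the per-field trend is similar.) This suggests that a windowed approach will be efficient.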

In [9]:
@jax.jit
def filter(ts, x0, k):
    """Run ``len(ts)`` forward-Euler steps from ``x0`` and return the final state.

    Each step applies ``x <- x + dt * (-x + k * (3.14 * x + apply_L(L, x)))``
    with ``dt = ts[1] - ts[0]``. Only the length and the first two entries of
    ``ts`` are used, so a uniform time grid is assumed.

    Notes
    -----
    * ``L`` and ``apply_L`` are module globals. ``L`` is captured as a
      constant when this jitted function is first traced.
    * The name shadows the builtin ``filter``. It is kept because other cells
      call it by this name.
    * NOTE(review): ``3.14`` may be intended as pi -- confirm.
    """
    dt = ts[1] - ts[0]

    def step(x, _t):
        lx = 3.14 * x + apply_L(L, x)
        x = x + dt * (-x + k * lx)
        # Emit nothing per step: only the final carry is needed, so scan does
        # not have to stack a (len(ts), ...) trajectory.
        return x, None

    x_final, _ = jax.lax.scan(step, x0, ts)
    return x_final

x0 = jnp.zeros((nlat, nlon))
x0 = x0.at[23:29,45:50].set(1.0)

ts = jnp.r_[:10]*0.1
k = 0.2
x1 = filter(ts, x0, k)

run_1 = lambda x: filter(ts, x0, k)
%timeit run_1(x0)

248 µs ± 467 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [10]:
# Sanity check of the batched output shape. NOTE: vmap broadcasts outputs that
# do not depend on the mapped input, so a correct shape alone does not prove
# that each batch element was simulated from its own initial state.
jax.vmap(run_1)(jnp.zeros((8, nlat, nlon))).shape

(8, 32, 64)

In [11]:
# Same vmap vs pmap(vmap) comparison as for apply_L, now for the full
# time-stepping filter.
# NOTE(review): these timings are only meaningful if run_1 uses its argument
# as the initial state -- check its definition above. run_batch is also not
# wrapped in jax.jit, whereas pmap compiles its function.
run_batch = jax.vmap(run_1)
run_batch_batch = jax.pmap(run_batch)

dc = jax.device_count()
for n in (dc, 32, 512):
    xs = jnp.zeros((dc, n//dc, nlat, nlon))
    xs_ = xs.reshape((-1, nlat, nlon))
    %timeit run_batch(xs_).block_until_ready()
    %timeit run_batch_batch(xs).block_until_ready()

339 µs ± 2.32 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
388 µs ± 1.27 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
347 µs ± 4.96 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
408 µs ± 3.91 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
672 µs ± 19.2 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
959 µs ± 17.7 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


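Running the time stepping here is not faster with `pmap`; maybe it's the `scan` inside the mapped function, or maybe just memory bandwidth? Caveat: these timings were recorded with `run_1 = lambda x: filter(ts, x0, k)`, which ignores its argument. The batched calls therefore performed a single simulation whose result `vmap` broadcast to the batch shape. Re-run with a `run_1` that uses its input before drawing conclusions.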

In [12]:
# xs_ is left over from the last iteration of the loop above (n = 512 fields).
%timeit run_batch(xs_).block_until_ready()

678 µs ± 6.45 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


Neither manually rewriting with batch-sized arrays nor using alternative constructs like `jax.lax.while_loop` was faster, so for now the simple approach is best.

## Inside a scan

This is close to the form we might use for inversion.  A simple ground truth,

In [67]:
# Ground-truth initial condition: a unit-valued rectangular patch on the grid.
x0 = jnp.zeros((nlat, nlon))
x0 = x0.at[23:29,45:50].set(1.0)
# 500 time points spaced 0.1 apart, and the gain k on the coupling term.
ts = jnp.r_[:500]*0.1
k = 0.2

@jax.jit
def simulate(ts, x0, k):
    """Return the forward-Euler trajectory of ``dx/dt = -x + k * lx`` from ``x0``.

    The step is ``dt = ts[1] - ts[0]``, and one state is produced per entry of
    ``ts``, so the result has shape ``(len(ts),) + x0.shape``. The coupling
    term ``lx`` is currently fixed at 0 (see the note in ``advance``).
    """
    dt = ts[1] - ts[0]

    def advance(state, _t):
        # gradients currently wrong, so the coupling term
        # 3.14*x + apply_L(L, x) is dropped for now and lx is 0.
        lx = 0
        new_state = state + dt * (-state + k * lx)
        return new_state, new_state

    _, trajectory = jax.lax.scan(advance, x0, ts)
    return trajectory

# Reference trajectory that the loss compares against.
x1 = simulate(ts, x0, k)
x1.shape

(500, 32, 64)

with a loss function,

In [62]:
@jax.jit
def loss(x0):
    """Sum of squared differences between ``x1`` and ``simulate(ts, x0, k)``.

    ``ts``, ``k`` and the reference trajectory ``x1`` are module globals.
    Because this function is jitted, their values are baked in as constants
    when it is first traced; rebinding those globals later is not seen until
    the function is retraced.
    """
    residual = x1 - simulate(ts, x0, k)
    return jnp.sum(jnp.square(residual))

# Sanity checks: the loss is zero at the true x0, small for a tiny
# perturbation, and large for an all-zero initial guess.
x0h = jnp.zeros_like(x0)
loss(x0), loss(x0+1e-4), loss(x0h)

(Array(0., dtype=float32),
 Array(8.731306e-05, dtype=float32),
 Array(127.89471, dtype=float32))

and some gradients,

In [63]:
g = jax.jit(jax.grad(loss))
g(x0h)
%timeit simulate(ts, x0, k)
%timeit g(x0h)

670 µs ± 7.27 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
1.25 ms ± 3.18 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


Note that gradient evaluation costs a little less than 2× the forward simulation (1.25 ms vs 670 µs, ≈1.9×).

And optimize,

In [66]:
# A few steps of plain gradient descent (step size 0.01), starting from the
# true initial condition shifted by 0.1 and printing the loss at each step.
x0h = x0 + 1e-1  # JAX arrays are immutable, so no explicit copy is needed
for i in range(5):
    print(i, loss(x0h))
    x0h = x0h - 0.01 * g(x0h)

0 87.30969
1 73.05178
2 61.12939
3 51.147156
4 42.79853
