In [1]:
# to use cpu uncomment the following:
#import os
#os.environ["JAX_PLATFORM_NAME"] = "cpu"

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import roughpy as rp
import math
import jax
import jax.numpy as jnp
from diffrax import *
import time

print(jax.devices())

[CudaDevice(id=0)]


Consider the linear controlled differential equation (CDE)

$$ \mathrm{d}Y = AY \mathrm{d}X, \quad Y_0 = y_0, \qquad \text{with} \quad A \in \mathcal{L}(\mathbb R^e \otimes \mathbb R^d, \mathbb R^e), \quad X \in C^\infty([0,1], \mathbb R^d).$$

The solution is given explicitly by 

$$ Y_1 = \sum_{n = 0}^\infty A_{k_n \gamma_n} A_{k_{n-1} \gamma_{n-1}}^{k_n} \cdots A_{k_1 \gamma_1}^{k_2} y_0^{k_1} S(X)^{\gamma_1 \ldots \gamma_n}_{0,1}. $$

with an implicit sum over the $k_i$'s, the indices that run over $1,\ldots,e$. We want to approximate $Y_1$ by computing the partial sum above up to some $n$, i.e. the one-step level-$n$ Euler approximation.

The following example code will do this for a specific example of a $2$-dimensional circular path ($2$ is arbitrary and the rest of the code accommodates general $d$). This won't be important for cubature: the signatures will be swapped out for the exponential of the cubature formula. What matters is that the functions are vmappable over initial conditions, which means the one-step Euler approximations can be concatenated over several intervals.

In [2]:
# The following code obtains the signature of a 2-dimensional circlular path, as a list of jnp arrays.
# It can be easily modified to obtain the signature of any d-dimensional path.

def circle(t, a, b, c):
    return a * np.exp(2 * b * np.pi * 1j * (t + c))

def _make_path(x):
    def lie_path(t, ctx):
        return rp.Lie([x(t).real, x(t).imag], ctx=ctx)
    return lie_path

def make_signature(x, d, n, res, s, t):
    context = rp.get_context(width = d, depth = n, coeffs=rp.DPReal)
    function_stream = rp.FunctionStream.from_function(_make_path(x), ctx = context, resolution = res)
    return function_stream.signature(rp.RealInterval(s,t))

def _sig_degrees(sig, d, n):
    expected_length = d ** (n + 1) - 1
    assert len(sig) == expected_length, f"Array length must be {expected_length}, but got {len(sig)}"
    result = []
    start = 0
    for i in range(n + 1):
        length = d ** i
        subarray = sig[start:start + length]
        result.append(subarray)
        start += length
    return result

def _reshape_level(arr, d, n):
    expected_length = d ** n
    assert len(arr) == expected_length, f"Array length must be {expected_length}, but got {len(arr)}"
    
    new_shape = (d,) * n
    reshaped_array = arr.reshape(new_shape)
    return reshaped_array

def reshape_signature(sig, d, n): 
    npsig = np.array(sig)
    result = []
    k = 0
    for arr in _sig_degrees(npsig, d, n):
        result.append(jnp.array(_reshape_level(arr, d, k)))
        k += 1
    return result

In [3]:
# Computes the powers of the tensor A up to order n and stores them in a list of jnp arrays.

def powers_up_to(A, n):
    matrix = A
    result = [matrix]
    for i in range(1, n):
        subscripts = 'ab' + ''.join(chr(100 + k) for k in range(i)) + ',bc' + chr(100 + i) + '->ac' + ''.join(chr(100 + k) for k in range(i + 1))
        matrix = jnp.einsum(subscripts, matrix, A)
        result.append(matrix)
    return result

In [4]:
# Computes the one-step Euler approximation of the linear CDE with given signature.

def _make_indices_j(n):
    indices_A = 'ab' + ''.join(chr(99 + i) for i in range(n-1, -1, -1))
    indices_S = ''.join(chr(99 + i) for i in range(n))
    output_indices = 'ab'
    return indices_A, indices_S, output_indices

def _single_sum_euler(n, y0, An, Sn, indices_list):
    R = jnp.einsum(f'{indices_list[0]},{indices_list[1]}->{indices_list[2]}', An, Sn)
    return jnp.dot(R, y0)

def one_step_euler(n, y0, powers, S, indices_list):
    return y0 + sum(_single_sum_euler(k, y0, powers[k], S[k+1], indices_list[k]) for k in range(n-1))

In [5]:
# The above code is not fully parallelised. Unfortunately, JAX does not support vectorising the call
# to _single_sum_euler, as it involves arrays of different shapes. Let's try to use parallelisation
# in concurrent.futures instead.

import concurrent.futures
import jax.numpy as jnp

def one_step_euler_cf(n, y0, powers, S, indices_list):
    with concurrent.futures.ThreadPoolExecutor() as executor: # ProcessPoolExecutor never works
        futures = [
            executor.submit(_single_sum_euler, k, y0, powers[k], S[k+1], indices_list[k])
            for k in range(n-1)
        ]
        results = [future.result() for future in concurrent.futures.as_completed(futures)]
    return y0 + sum(results)

In [6]:
# Make the signature (all that matters here is the output).
k = 5
a = 1/k
b = k ** 2
c = 0

length = b * 2 * np.pi * a
area = b * np.pi * (a ** 2)

d = 2
n = 6
s, t = 0, 1

sig = make_signature(lambda t: circle(t,a,b,c), d = d, n = n, res = 15, s = s, t = t)

rs = reshape_signature(sig, d, n)

2024-08-19 01:15:04.607376: W external/xla/xla/service/gpu/nvptx_compiler.cc:836] The NVIDIA driver's CUDA version is 12.4 which is older than the PTX compiler version (12.6.20). Because the driver is older than the PTX compiler version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


In [7]:
# Parameters of the differential equation. A is drawn from a uniform distribution.

e = 4
r = 1


key0 = jax.random.PRNGKey(0)
y0 = jax.random.uniform(key0, (e,))

key1 = jax.random.PRNGKey(1)
A = jax.random.uniform(key1, (e, e, d))

In [9]:
ind_list = [_make_indices_j(k) for k in range(1, n+1)] #make the indices for the einsums
pow = powers_up_to(A, n) #compute and store the powers of A

# comparison of one_step_euler vs one_step_euler_cf:

start = time.time()
one_step_euler(n, y0, pow, rs, ind_list)
print(time.time() - start)

start = time.time()
one_step_euler_cf(n, y0, pow, rs, ind_list)
print(time.time() - start)

0.004661083221435547
0.009574174880981445


In [11]:
# Now let's test how much time vmap is saving us. We compare the sequential evaluation of the above
# function over many y0's, with the evaluation of the same function using vmap.

num_y0s = 10000
keys = jax.random.split(key1, num_y0s)
random_arrays = [jax.random.uniform(k, (e,)) for k in keys]

start = time.time()
for y in random_arrays:
    one_step_euler(n, y, pow, rs, ind_list)
print("Sequential evaluation time: ", time.time() - start)

start = time.time()
jax.vmap(lambda y: one_step_euler(n, y, pow, rs, ind_list))(jnp.array(random_arrays))
print("vmap evaluation time: ", time.time() - start)

start = time.time()
jax.vmap(lambda y: one_step_euler_cf(n, y, pow, rs, ind_list))(jnp.array(random_arrays))
print("vmap + concurrent futures evaluation time: ", time.time() - start)

# About a 10-fold speedup on Havok, 20-fold on Sauron of using vmap over sequential evaluation.
# Lots of issues and few benefits from using one_step_euler_cf. It only seems to work on GPU.

Sequential evaluation time:  22.72610902786255
vmap evaluation time:  2.423283576965332
vmap + concurrent futures evaluation time:  1.7909224033355713


In [12]:
# the first time vmap is run it is slower becuase of compilation, subsequent runs are faster:

start = time.time()
jax.vmap(lambda y: one_step_euler(n, y, pow, rs, ind_list))(jnp.array(random_arrays))
print("vmap evaluation time: ", time.time() - start)

start = time.time()
jax.vmap(lambda y: one_step_euler_cf(n, y, pow, rs, ind_list))(jnp.array(random_arrays))
print("vmap + concurrent futures evaluation time: ", time.time() - start)

# in fact, after compilation the benefit of using concurrent futures is almost non-existent

vmap evaluation time:  1.89339280128479
vmap + concurrent futures evaluation time:  1.8314363956451416


It's very possible the code above isn't fully optimised. One obstacle to removing some of the lists and for loops is that the arrays for the signatures of various levels are of different sizes, and therefore not stackables. This makes it impossible to vectorise operations over them, even though these operations can in principle be done in parallel.

The instance of this that I have in mind are the calls to ```_single_sum_euler``` in ```one_step_euler```. I tried to parallelise these calls using ```concurrent.futures```, but this turns out to have negligible benefits and to give lots of errors. It makes sense that classical parallelisation wouldn't work well in JAX, especially when used at the very bottom of a call stack which involves vectorisation (a much more basic form of parallelisation). Going forward it makes sense to use ```one_step_euler```, not ```one_step_euler_cf```.

A possible last resort would be to pad the arrays with 0s and to use vmap, which would require some careful rewriting of the code. Note that unless JAX is able to detect sparse arrays it will be doing lots of useless 0 products.

In [15]:
# Here's what goes wrong. I would like to parallelise the call to _single_sum_euler
# in the function one_step_euler. However, JAX is unable to handle this.

jax.vmap(lambda k:_single_sum_euler(k, y0, pow[k], rs[k+1], ind_list[k]))(jnp.arange(n-1))

TracerIntegerConversionError: The __index__() method was called on traced array with shape int64[]
This BatchTracer with object id 140637968653488 was created on line:
  /tmp/ipykernel_3822391/2737179216.py:4 (<module>)
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerIntegerConversionError

In [19]:
# Making an array from the list won't work because the arrays are not all of the same shape.
p = jnp.array(pow)

ValueError: All input arrays must have the same shape.

In [17]:
# pmap doesn't work either
jax.pmap(lambda k:_single_sum_euler(k, y0, pow[k], rs[k+1], ind_list[k]))(jnp.arange(n-1))

TracerIntegerConversionError: The __index__() method was called on traced array with shape int64[]
The error occurred while tracing the function <lambda> at /tmp/ipykernel_3822391/1602965186.py:2 for pmap. This concrete value was not available in Python because it depends on the value of the argument k.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerIntegerConversionError

More tests

In [14]:
num_y0s = 100000
keys = jax.random.split(key1, num_y0s)
random_arrays = jnp.array([jax.random.uniform(k, (e,)) for k in keys]) # vectorise this

In [16]:
start = time.time()
jax.vmap(lambda y: _single_sum_euler(n, y, pow[5], rs[6], ind_list[5]))(random_arrays)
print("vmap evaluation time: ", time.time() - start)

vmap evaluation time:  0.004851341247558594


In [17]:
start = time.time()
for y in random_arrays:
    _single_sum_euler(n, y, pow[5], rs[6], ind_list[5])
print("Sequential evaluation time: ", time.time() - start)

Sequential evaluation time:  36.801605224609375


Ok I know what to do! Vectorise each level separately, and then if necessary parallelise (outside of JAX) across levels. This is going to be super fast.