In [2]:
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import roughpy as rp
import math
import jax
import jax.numpy as jnp
from diffrax import *
import time

# Enable 64-bit precision in JAX so that the arrays computed below are float64.
jax.config.update("jax_enable_x64", True)

Consider the linear controlled differential equation (CDE)

$$ \mathrm{d}Y = AY \mathrm{d}X, \quad Y_0 = y_0, \qquad \text{with} \quad A \in \mathcal{L}(\mathbb R^e \otimes \mathbb R^d, \mathbb R^e), \quad X \in C^\infty([0,1], \mathbb R^d).$$

The solution is given explicitly by 

$$ Y_1^{j} = \sum_{n = 0}^\infty A_{k_n \gamma_n}^{j} A_{k_{n-1} \gamma_{n-1}}^{k_n} \cdots A_{k_1 \gamma_1}^{k_2} y_0^{k_1} S(X)^{\gamma_1 \ldots \gamma_n}_{0,1}, \qquad j = 1, \ldots, e, $$

with an implicit sum over the repeated indices: the $k_i$'s run over $1,\ldots,e$ and the $\gamma_i$'s run over $1,\ldots,d$ (the $n = 0$ term is just $y_0$). We want to approximate $Y_1$ by computing the partial sum above up to some $n$, i.e. the one-step level-$n$ Euler approximation.

The following example code will do this for a specific example of a $2$-dimensional circular path ($2$ is arbitrary and the rest of the code accommodates general $d$). This won't be important for cubature: the signatures will be swapped out for the exponential of the cubature formula (in fact, you can save time by suppressing the trivial computations at odd degrees). What matters is that the functions are vmappable over initial conditions, which means the one-step Euler approximations can be concatenated over several intervals.

In [3]:
# The following code obtains the signature of a 2-dimensional circular path, as a list of jnp arrays.
# It can be easily modified to obtain the signature of any d-dimensional path.

def circle(t, a, b, c):
    """Evaluate a circular path in the complex plane at time ``t``.

    Returns ``a * exp(2*pi*i * b * (t + c))``: a point on the circle of
    radius ``|a|`` centred at the origin. ``b`` is the number of turns made
    per unit of ``t`` and ``c`` shifts the starting time.
    """
    phase = 2 * b * np.pi * 1j * (t + c)
    return a * np.exp(phase)

def _make_path(x):
    def lie_path(t, ctx):
        return rp.Lie([x(t).real, x(t).imag], ctx=ctx)
    return lie_path

def make_signature(x, d, n, res, s, t):
    """Compute the signature of the path ``x`` over the interval ``[s, t]``.

    Parameters
    ----------
    x : callable
        Complex-valued function of time, wrapped by ``_make_path``.
    d : int
        Width passed to the RoughPy context.
    n : int
        Depth passed to the RoughPy context.
    res : int
        Resolution passed to ``rp.FunctionStream.from_function``.
    s, t : float
        Endpoints of the ``rp.RealInterval`` the signature is taken over.

    Returns
    -------
    The object returned by ``FunctionStream.signature``.
    """
    ctx = rp.get_context(width=d, depth=n, coeffs=rp.DPReal)
    stream = rp.FunctionStream.from_function(_make_path(x), ctx=ctx, resolution=res)
    interval = rp.RealInterval(s, t)
    return stream.signature(interval)

def _sig_degrees(sig, d, n):
    expected_length = d ** (n + 1) - 1
    assert len(sig) == expected_length, f"Array length must be {expected_length}, but got {len(sig)}"
    result = []
    start = 0
    for i in range(n + 1):
        length = d ** i
        subarray = sig[start:start + length]
        result.append(subarray)
        start += length
    return result

def _reshape_level(arr, d, n):
    expected_length = d ** n
    assert len(arr) == expected_length, f"Array length must be {expected_length}, but got {len(arr)}"
    
    new_shape = (d,) * n
    reshaped_array = arr.reshape(new_shape)
    return reshaped_array

def reshape_signature(sig, d, n):
    """Convert a flat signature into a list of per-level jnp tensors.

    ``sig`` is converted to a NumPy array, split into levels 0..n with
    ``_sig_degrees``, and level ``k`` is reshaped to ``(d,) * k`` with
    ``_reshape_level`` before being wrapped as a ``jnp`` array.

    Returns a list of ``n + 1`` arrays; entry ``k`` has ``k`` axes of size ``d``.
    """
    levels = _sig_degrees(np.array(sig), d, n)
    return [
        jnp.array(_reshape_level(level, d, degree))
        for degree, level in enumerate(levels)
    ]

In [4]:
# Computes the powers of the tensor A up to order n and stores them in a list of jnp arrays.

def powers_up_to(A, n):
    """Build the iterated contractions of ``A`` with itself.

    ``A`` has axes ``(a, b, path)``. Entry ``i`` of the returned list has
    ``i + 1`` trailing path axes and is obtained by contracting the second
    axis of entry ``i - 1`` with the first axis of ``A``; the path axis of
    the newly contracted ``A`` is appended last.

    Returns a list starting with ``A`` itself; it has ``n`` entries for
    ``n >= 1`` (and still the single entry ``A`` for ``n < 1``).
    """
    first_path_letter = ord('d')
    powers = [A]
    for order in range(1, n):
        # Letters for the path axes already carried by the previous power.
        carried = ''.join(chr(first_path_letter + j) for j in range(order))
        # Letter for the path axis contributed by the extra factor of A.
        added = chr(first_path_letter + order)
        spec = f'ab{carried},bc{added}->ac{carried}{added}'
        powers.append(jnp.einsum(spec, powers[-1], A))
    return powers

In [5]:
# Computes the one-step Euler approximation of the linear CDE with given signature.

def _make_indices_j(n):
    indices_A = 'ab' + ''.join(chr(99 + i) for i in range(n-1, -1, -1))
    indices_S = ''.join(chr(99 + i) for i in range(n))
    output_indices = 'ab'
    return indices_A, indices_S, output_indices

def _single_sum_euler(n, y0, An, Sn, indices_list):
    R = jnp.einsum(f'{indices_list[0]},{indices_list[1]}->{indices_list[2]}', An, Sn)
    return jnp.dot(R, y0)

def one_step_euler(n, y0, powers, S, indices_list):
    """One-step Euler approximation of the linear CDE started at ``y0``.

    Adds to ``y0`` the terms ``_single_sum_euler(k, y0, powers[k], S[k + 1],
    indices_list[k])`` for ``k = 0, ..., n - 2``. In other words, signature
    levels ``1`` to ``n - 1`` are used, so ``n`` must be one more than the
    highest signature level that should contribute.
    """
    corrections = (
        _single_sum_euler(k, y0, powers[k], S[k + 1], indices_list[k])
        for k in range(n - 1)
    )
    return y0 + sum(corrections)

In [6]:
# Make the signature (all that matters here is the output).
# Parameters of the circle a * exp(2*pi*i * b * (t + c)) defined by `circle`.
k = 5
a = 1/k        # radius
b = k ** 2     # number of turns made as t runs over an interval of length 1
c = 0          # time shift

# Total arc length (b turns of circumference 2*pi*a) and b times the disc area pi*a^2.
# These are not used by any other cell visible in this notebook.
length = b * 2 * np.pi * a
area = b * np.pi * (a ** 2)

d = 2          # width of the path, passed to the RoughPy context
n = 5          # signature truncation depth, passed to the RoughPy context
s, t = 0, 1    # endpoints of the interval the signature is computed over

# res is forwarded as the `resolution` of rp.FunctionStream.from_function.
sig = make_signature(lambda t: circle(t,a,b,c), d = d, n = n, res = 15, s = s, t = t)

# List of n + 1 jnp arrays; entry i is signature level i with shape (d,) * i.
rs = reshape_signature(sig, d, n)

2024-08-18 23:02:15.796320: W external/xla/xla/service/gpu/nvptx_compiler.cc:836] The NVIDIA driver's CUDA version is 12.4 which is older than the PTX compiler version (12.6.20). Because the driver is older than the PTX compiler version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


In [7]:
# Parameters of the differential equation. A is drawn from a uniform distribution.

e = 4  # dimension of the solution Y: y0 has shape (e,) and A has shape (e, e, d)
r = 1  # NOTE(review): not used by any cell visible in this notebook


# Fixed seeds make y0 and A reproducible across runs.
key0 = jax.random.PRNGKey(0)
y0 = jax.random.uniform(key0, (e,))

key1 = jax.random.PRNGKey(1)
A = jax.random.uniform(key1, (e, e, d))

# n is rebound here (the signature above was computed with depth 5).
# one_step_euler sums k over range(n - 1) and reads S[k + 1], so n = 6
# makes it consume signature levels 1..5.
n = 6

In [9]:
# ind_list[k] holds the einsum subscripts pairing pow[k] with signature level k + 1.
ind_list = [_make_indices_j(k) for k in range(1, n+1)] #make the indices for the einsums
# NOTE(review): `pow` shadows the Python builtin of the same name; the timing cell below reads it.
pow = powers_up_to(A, n) #compute and store the powers of A
one_step_euler(n, y0, pow, rs, ind_list) 

Array([5.9543521 , 5.06384921, 1.39503282, 3.45248092], dtype=float64)

In [10]:
# Now let's test how much time vmap is saving us. We compare the sequential evaluation of the above
# function over many y0's, with the evaluation of the same function using vmap.

num_y0s = 10000
# Use a fresh key for the benchmark inputs: key1 was already consumed to draw A,
# and JAX PRNG keys should not be reused.
keys = jax.random.split(jax.random.PRNGKey(2), num_y0s)
random_arrays = [jax.random.uniform(k, (e,)) for k in keys]
# Build the batched input before starting the clock, so that converting the
# Python list into one array is not charged to the vmap timing.
y0_batch = jnp.array(random_arrays)

# JAX dispatches work asynchronously: without block_until_ready the timer can
# stop before the computation has actually finished.
start = time.perf_counter()
sequential_results = [one_step_euler(n, y, pow, rs, ind_list) for y in random_arrays]
jax.block_until_ready(sequential_results)
print("Sequential evaluation time: ", time.perf_counter() - start)

batched_euler = jax.vmap(lambda y: one_step_euler(n, y, pow, rs, ind_list))
start = time.perf_counter()
vmap_results = jax.block_until_ready(batched_euler(y0_batch))
print("vmap evaluation time: ", time.perf_counter() - start)

# Sanity check (outside the timed regions): both strategies must agree.
assert jnp.allclose(jnp.array(sequential_results), vmap_results)

# about a 10-fold speedup on Havok, 20-fold on Sauron
# the first time vmap is run it is slower because of compilation; subsequent runs are faster

Sequential evaluation time:  22.0222806930542
vmap evaluation time:  2.2496044635772705


It's very possible the code above isn't fully optimised. One obstacle to removing some of the lists and for loops is that the arrays for the signatures of various levels are of different sizes, and therefore not stackable. This makes it impossible to vectorise operations over them, even though these operations can in principle be done in parallel.