In [1]:
# to use cpu uncomment the following:
#import os
#os.environ["JAX_PLATFORM_NAME"] = "cpu"

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import roughpy as rp
import math
import jax
import jax.numpy as jnp
from diffrax import *
import time
from one_step_euler import *
# `import *` skips underscore-prefixed names (unless the module lists them in __all__),
# so the private helper called by one_step_euler_cf is imported explicitly.
from one_step_euler import _single_sum_euler

# Show which devices JAX will dispatch to. In the recorded run CUDA:0 was out of memory,
# so only devices 1-3 are listed.
print(jax.devices())

[CudaDevice(id=1), CudaDevice(id=2), CudaDevice(id=3)]


2024-08-20 00:13:29.650972: W external/xla/xla/service/platform_util.cc:206] unable to create StreamExecutor for CUDA:0: failed initializing StreamExecutor for CUDA device ordinal 0: RESOURCE_EXHAUSTED: : CUDA_ERROR_OUT_OF_MEMORY: out of memory


Consider the linear controlled differential equation (CDE)

$$ \mathrm{d}Y = AY \mathrm{d}X, \quad Y_0 = y_0, \qquad \text{with} \quad A \in \mathcal{L}(\mathbb R^e \otimes \mathbb R^d, \mathbb R^e), \quad X \in C^\infty([0,1], \mathbb R^d).$$

The solution is given explicitly by 

$$ Y_1 = \sum_{n = 0}^\infty A_{k_n \gamma_n} A_{k_{n-1} \gamma_{n-1}}^{k_n} \cdots A_{k_1 \gamma_1}^{k_2} y_0^{k_1} S(X)^{\gamma_1 \ldots \gamma_n}_{0,1}. $$

with implicit summation over repeated indices: the $k_i$ run over $1,\ldots,e$ and the $\gamma_i$ run over $1,\ldots,d$. We approximate $Y_1$ by truncating the sum above at some level $n$; this is the one-step level-$n$ Euler approximation.

The example code below does this for a Brownian path, but it can be adapted to an arbitrary path. The choice of path is not important for cubature, where the signatures will be replaced by the exponential of the cubature formula. What matters is that the functions can be vmapped over initial conditions, so that one-step Euler approximations can be concatenated over several intervals.

In [3]:
# Problem sizes: X is d-dimensional, signatures are truncated at level n, time runs over [s, t].
d = 3
n = 6
s, t = 0, 1
N = 100  # forwarded to make_brownian_sig as N — TODO confirm its meaning (number of samples/steps?)

k = jax.random.PRNGKey(0)
# NOTE(review): vbt is not used by any visible cell; make_brownian_sig receives the key directly.
vbt = VirtualBrownianTree(s, t, tol=0.01, shape=(d,), key=k)

# Signature of the Brownian path, then reshaped level by level (rs[m] is indexed by level below).
sig = make_brownian_sig(key=k, N=N, res=10, d=d, n=n)
rs = reshape_signature(sig, d, n)

In [5]:
# Summing the Euler terms level by level is not fully parallelised. Unfortunately, JAX cannot
# vectorise across the calls to _single_sum_euler for different levels, because those calls
# involve arrays of different shapes. Try thread-based parallelism from concurrent.futures instead.

import concurrent.futures
import jax.numpy as jnp

def one_step_euler_cf(n, y0, powers, S, indices_list):
    """One-step level-``n`` Euler approximation, evaluating each level's term in a thread pool.

    Returns ``y0`` plus the level-1 through level-``n`` terms of the truncated series, where
    the level-``k + 1`` term is ``_single_sum_euler(k, y0, powers[k], S[k + 1], indices_list[k])``.

    Args:
        n: truncation level; the terms for levels 1..n are summed.
        y0: initial condition.
        powers: per-level powers of A; entries 0..n-1 are read (e.g. ``powers_up_to(A, n)``).
        S: signature indexed by level; entries 1..n are read (``S[0]`` is not used here).
        indices_list: per-level einsum index specs; entries 0..n-1 are read
            (e.g. ``[make_indices(k) for k in range(1, n + 1)]``).

    Returns:
        ``y0`` plus the sum of the ``n`` level terms (just ``y0`` when ``n <= 0``).
    """
    # NOTE(review): the timing cell passes ``n`` as the first argument of single_sum_euler for the
    # level-n term, whereas here the first argument is k = level - 1 — confirm the intended meaning.
    with concurrent.futures.ThreadPoolExecutor() as executor:  # ProcessPoolExecutor never works
        futures = [
            executor.submit(_single_sum_euler, k, y0, powers[k], S[k + 1], indices_list[k])
            for k in range(n)  # k = level - 1, so this covers levels 1..n inclusive
        ]
        # Collect in submission order (not as_completed) so the floating-point sum is deterministic.
        level_terms = [future.result() for future in futures]
    return y0 + sum(level_terms)

In [4]:
# Parameters of the differential equation: the state dimension e, the initial condition y0 and the
# tensor A, with y0 and A both drawn from jax.random.uniform (default range [0, 1)).
e = 4
r = 1  # NOTE(review): not used by any visible cell

# One fixed seed per quantity.
# NOTE(review): PRNGKey(0) is also the key `k` used for the Brownian signature — confirm this reuse is intended.
key0, key1 = jax.random.PRNGKey(0), jax.random.PRNGKey(1)
y0 = jax.random.uniform(key0, (e,))
A = jax.random.uniform(key1, (e, e, d))  # shape (e, e, d), cf. A in L(R^e ⊗ R^d, R^e)

In [5]:
# Precompute, once, the per-level data that every Euler evaluation reuses.
ind_list = list(map(make_indices, range(1, n + 1)))  # einsum indices from make_indices(1..n)
pow = powers_up_to(A, n)  # stored powers of A (NB: this name shadows the builtin `pow`)

In [9]:
num_y0s = 10000

# Use a seed of its own: PRNGKey(1) already generated A, and reusing a JAX key
# produces correlated random draws.
key = jax.random.PRNGKey(2)
random_y0s = jax.random.uniform(key, (num_y0s, e))

# Benchmark the highest-order term: level `level` uses pow[level-1], rs[level], ind_list[level-1].
level = n

# JAX dispatches work asynchronously, so block until the results exist before stopping each timer;
# otherwise only dispatch (plus any tracing/compilation) is measured. Note that the first call of
# each variant may still include compilation time.
start = time.perf_counter()
vmap_out = jax.vmap(
    lambda y: single_sum_euler(n, y, pow[level - 1], rs[level], ind_list[level - 1])
)(random_y0s)
jax.block_until_ready(vmap_out)
print("vmap evaluation time: ", time.perf_counter() - start)

start = time.perf_counter()
seq_out = [single_sum_euler(n, y, pow[level - 1], rs[level], ind_list[level - 1]) for y in random_y0s]
jax.block_until_ready(seq_out)
print("Sequential evaluation time: ", time.perf_counter() - start)

vmap evaluation time:  0.3517310619354248
Sequential evaluation time:  3.875796318054199


Next steps:

- Sum the `single_sum_euler` terms over all levels, and try to parallelise that sum.
- Vectorise over different signatures as well as over different initial conditions.