In [3]:
# to use cpu uncomment the following:
#import os
#os.environ["JAX_PLATFORM_NAME"] = "cpu"

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import roughpy as rp
import math
import jax
import jax.numpy as jnp
from diffrax import *
import time
from one_step_euler import *

# Show which accelerator JAX picked up (set JAX_PLATFORM_NAME above to force the CPU).
available_devices = jax.devices()
print(available_devices)

[CudaDevice(id=0)]


Consider the linear controlled differential equation (CDE)

$$ \mathrm{d}Y = AY \mathrm{d}X, \quad Y_0 = y_0, \qquad \text{with} \quad A \in \mathcal{L}(\mathbb R^e \otimes \mathbb R^d, \mathbb R^e), \quad X \in C^\infty([0,1], \mathbb R^d).$$

The solution is given explicitly by 

$$ Y_1 = \sum_{n = 0}^\infty A_{k_n \gamma_n} A_{k_{n-1} \gamma_{n-1}}^{k_n} \cdots A_{k_1 \gamma_1}^{k_2} y_0^{k_1} S(X)^{\gamma_1 \ldots \gamma_n}_{0,1}, $$

with implicit summation over repeated indices: the $k_i$ run over $1,\ldots,e$ and the $\gamma_i$ over $1,\ldots,d$ (the $n = 0$ term is just $y_0$). We want to approximate $Y_1$ by truncating this series at level $n$, i.e. by the one-step level-$n$ Euler approximation.

The example code below does this for a Brownian path, but it works for an arbitrary path once its signature is available. For cubature the path itself does not matter: the signatures will be replaced by the exponentials of the cubature formula. What matters is that the functions can be vmapped over initial conditions, so that one-step Euler approximations can be chained over several intervals.

In [4]:
# Problem sizes.
d = 3           # dimension of the driving path X (passed to make_brownian_sig)
n = 6           # signature truncation level = order of the one-step Euler scheme
s, t = 0, 1     # presumably the interval [s, t]; not used in the visible cells
N = 100         # forwarded to make_brownian_sig as N (presumably its step count) — confirm
num_paths = 200

# Generating num_paths Brownian signatures takes a couple of minutes; they only stand in
# for real input here. When doing cubature, these will be the exponentials of the Lie
# cubature points.
sigs = []
for k in range(num_paths):
    key = jax.random.PRNGKey(k)                 # one seed per path
    sig = make_brownian_sig(key = key, N = N, res = 10, d = d, n = n)
    rs = reshape_signature(sig, d, n)           # split into one array per level
    sigs.append(rs)

2024-08-20 02:18:39.129681: W external/xla/xla/service/gpu/nvptx_compiler.cc:836] The NVIDIA driver's CUDA version is 12.4 which is older than the PTX compiler version (12.6.20). Because the driver is older than the PTX compiler version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


In [9]:
# Incidentally, let's check that it matches the expected signature of brownian motion:

for k in range(n+1):
    sigs_k = [s[k] for s in sigs]
    stacked_sigs = jnp.stack(sigs_k)
    print(k, ': ', jnp.mean(stacked_sigs, axis=0))
    
# looks ok

0 :  1.0
1 :  [-0.04812605 -0.11194067 -0.09954539]
2 :  [[ 0.5512239   0.00854901  0.10728024]
 [ 0.06953989  0.531183    0.03312418]
 [ 0.03078745 -0.03520167  0.49968967]]
3 :  [[[ 0.03773351  0.06910873 -0.0124695 ]
  [-0.01562674  0.03357104  0.03712659]
  [-0.01997475 -0.00657927 -0.03985787]]

 [[-0.04998645 -0.02910184  0.01246647]
  [-0.09339364 -0.03479901 -0.00718148]
  [ 0.00180046 -0.01630269  0.01299437]]

 [[-0.03678331 -0.0422317  -0.02207446]
  [-0.00490524 -0.00300318  0.02319371]
  [-0.07644357 -0.0274091  -0.06546504]]]
4 :  [[[[ 0.14592618  0.01984253  0.04161057]
   [ 0.02419238  0.1299959  -0.01077867]
   [ 0.00772096 -0.00876141  0.11305094]]

  [[-0.0118957  -0.00890883  0.0279956 ]
   [-0.01690654  0.02236339  0.00224523]
   [-0.01349891 -0.01648054 -0.02418018]]

  [[-0.00996255 -0.00487061  0.00020643]
   [-0.00383315 -0.00903263  0.0381149 ]
   [ 0.03042732 -0.01106206  0.04695561]]]


 [[[-0.00686811 -0.00314645  0.01466037]
   [ 0.02487585 -0.01738557  0.

In [5]:
# Parameters of the differential equation dY = A Y dX.
e = 4   # dimension of the solution Y
r = 1   # NOTE(review): not used in the visible cells

# y0 and every entry of A are drawn i.i.d. from the uniform distribution on [0, 1).
# NOTE(review): PRNGKey(0) and PRNGKey(1) are also the seeds of the first two Brownian
# paths; confirm make_brownian_sig does not consume them in a way that correlates with y0 / A.
key0, key1 = jax.random.PRNGKey(0), jax.random.PRNGKey(1)
y0 = jax.random.uniform(key0, shape=(e,))
A = jax.random.uniform(key1, shape=(e, e, d))

In [6]:
# Precompute everything that depends only on A and n once, outside the Euler loop:
# the einsum index specifications for levels 1..n (ind_list[j] is for level j + 1) ...
ind_list = list(map(make_indices, range(1, n + 1)))
# ... and the stored powers of A up to order n. (Note: this name shadows the builtin pow.)
pow = powers_up_to(A, n)

In [9]:
# sigs is a list (one entry per random key) of lists (one entry per signature level) of
# arrays. Transpose it into a list indexed by level, where each entry stacks all paths
# along a new leading axis so that it can be vmapped over.
stacked_sigs = [
    jnp.stack([path_sig[level] for path_sig in sigs])
    for level in range(n + 1)
]

In [18]:
# Make some random initial conditions. num_y0s = 200**2 presumably mimics the endpoints
# reached after two intervals of 200 paths each.

num_y0s = 200 ** 2
# Do not draw from PRNGKey(1): it already draws A, and reusing a JAX key makes the two
# draws statistically dependent. Seeds 0..num_paths-1 drive the Brownian paths and
# PRNGKey(0) draws y0, so PRNGKey(num_paths) is a seed no other draw here uses.
key = jax.random.PRNGKey(num_paths)
random_y0s = jax.random.uniform(key, (num_y0s, e))

In [None]:
# Now vmap over both the initial conditions and the signatures, adding the one-step
# Euler approximation one signature level at a time.

start = time.perf_counter()

# Level-0 term of the series: one copy of every initial condition per signature,
# shape (num_paths, num_y0s, e).
results = jnp.tile(random_y0s, (num_paths, 1, 1))

# ind_list[k] = make_indices(k + 1) is paired with stacked_sigs[k + 1], so k must run over
# 0..n-1 to cover levels 1..n; starting at 1 would silently drop the level-1 term.
# NOTE(review): assumes pow[k] is the power of A for level k + 1 (the same offset as
# ind_list) — confirm against powers_up_to in one_step_euler.
for k in range(n):
    def sse(y, s):
        return single_sum_euler(n, y, pow[k], s, ind_list[k])
    sse_vmap_y = jax.vmap(sse, in_axes=(0, None))                # over initial conditions
    sse_vmap_y_and_s = jax.vmap(sse_vmap_y, in_axes=(None, 0))   # over signatures
    results += sse_vmap_y_and_s(random_y0s, stacked_sigs[k + 1])

# JAX dispatches work asynchronously: wait for the result, otherwise the timer only
# measures how long it took to enqueue the computation.
results.block_until_ready()
print(time.perf_counter() - start)

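This runs in under 1 second on a GPU. Because JAX dispatches work asynchronously, a wall-clock timing is only meaningful if the result is blocked on (e.g. with `block_until_ready()`) before the clock is stopped. Memory is the bottleneck: every interval multiplies the number of trajectories by `num_paths`. After 3 intervals there are $200^3 = 8\times10^6$ trajectories, i.e. at least about 128 MB of float32 states for $e = 4$. After 4 intervals there would be $1.6\times10^9$, about 26 GB. So 3 intervals fit in memory without subsampling, but a 4th runs out of memory.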