In [1]:
import os
import torch
from torch import Tensor
import numpy as np
from jax import random
import jax.numpy as jnp
from jax import jit

# Resolve the parent of the notebook's launch directory only once and remember it,
# so re-running this cell does not keep climbing the directory tree
# (presumably the parent is the project root — TODO confirm).
if "ROOT_DIR" not in globals():
    ROOT_DIR = os.path.abspath(os.path.join(os.getcwd(), ".."))
os.chdir(ROOT_DIR)

# Output directory for generated paths; makedirs also creates a missing
# "data" parent and is a no-op when the directory already exists.
path_dir = os.path.join(ROOT_DIR, "data", "paths")
os.makedirs(path_dir, exist_ok=True)

torch.cuda.is_available()

True

In [2]:
N = 1000
delta = 1e-3
rho = 0.99
t = 1.0

# Seed every RNG the notebook draws from so Restart & Run All is reproducible.
np.random.seed(0)
torch.manual_seed(0)
key = random.PRNGKey(0)

# Number of time steps as a plain Python int (usable as an array dimension).
# The tiny epsilon keeps floor() from losing a step when t / delta evaluates
# to something like 2.9999999999999996.
k = int(np.floor(t / delta + 1e-9))

In [3]:
def generate_BM_paths(N: int, delta: float, rho: float, t: float = 1) -> np.ndarray:
    """
    Return N samples of 2D brownian motion with step-size delta and sample correlation rho
    """
    k = int(np.floor(t / delta))
    sigma = np.sqrt(delta)
    x = np.repeat(np.random.normal(scale=sigma, size=(1, k, 2)), N, axis=0)
    y = np.random.normal(scale=sigma, size=(N, k, 2))
    z = rho**2 * x + (1 - rho**2) * y
    z = np.cumsum(z, axis=1)
    z = np.concatenate([np.zeros((N, 1, 2)), z], axis=1)
    return z


def generate_BM_paths_torch(
    N: int, delta: float, rho: float, t: float = 1, device: str = "cuda"
) -> Tensor:
    """
    Return N samples of 2D brownian motion with step-size delta and sample correlation rho
    """
    k = int(np.floor(t / delta))
    sigma = np.sqrt(delta)
    x = sigma * torch.randn((1, k, 2)).repeat(N, 1, 1)
    y = sigma * torch.randn((N, k, 2))
    z = rho**2 * x + (1 - rho**2) * y
    z = torch.cumsum(z, axis=1)
    z = torch.cat([torch.zeros((N, 1, 2)), z], axis=1)
    return z


def _generate_BM_paths_jax(
    rng,
    N: int = 1000,
    delta: float = 1e-3,
    rho: float = 0.99,
    t: float = 1.0,
) -> jnp.ndarray:
    """
    Return N samples of 2D Brownian motion on [0, t] with step-size delta.

    Every sample mixes one shared latent path x with its own independent path y.
    Its increments are ``rho * dx + sqrt(1 - rho**2) * dy``. Each step therefore
    has variance ``delta``, each sample has correlation ``rho`` with the shared
    path, and two distinct samples have correlation ``rho**2``.

    Args:
        rng: JAX PRNG key.
        N: number of sample paths (static under jit).
        delta: time step (static under jit); number of steps is ``floor(t / delta)``.
        rho: weight of the shared path, expected in [-1, 1] (traced, not checked).
        t: time horizon (static under jit).

    The defaults mirror the notebook's configuration cell. N, delta and t fix
    the output shape, so they must be hashable Python scalars.

    Returns:
        Array of shape (N, k + 1, 2) whose first time index is all zeros.
    """
    # Epsilon guards against float round-off such as 0.3 / 0.1 == 2.9999999999999996.
    k = int(np.floor(t / delta + 1e-9))
    sigma = delta**0.5

    # JAX keys must not be reused: split so the shared and per-sample draws
    # come from independent keys.
    key_shared, key_own = random.split(rng)
    shared = random.normal(key_shared, shape=(1, k, 2)) * sigma
    own = random.normal(key_own, shape=(N, k, 2)) * sigma

    # The (1, k, 2) shared draw broadcasts across all N samples.
    increments = rho * shared + jnp.sqrt(1 - rho**2) * own

    # Cumulative sum turns increments into paths; prepend the zero start point.
    paths = jnp.cumsum(increments, axis=1)
    start = jnp.zeros((N, 1, 2), dtype=paths.dtype)
    return jnp.concatenate([start, paths], axis=1)


# N, delta and t determine array shapes, so they are compile-time constants.
generate_BM_paths_jax = jit(_generate_BM_paths_jax, static_argnames=("N", "delta", "t"))

In [4]:
%%timeit
x = generate_BM_paths(N, delta, rho)

57.7 ms ± 2.49 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [5]:
%%timeit
x = generate_BM_paths_torch(N, delta, rho)

18.2 ms ± 460 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [6]:
%%timeit
key, sub = random.split(key)
x = generate_BM_paths_jax(sub)

UnboundLocalError: cannot access local variable 'key' where it is not associated with a value