In [17]:
import os
os.chdir('/home/zongchen/nest_bq')
import pandas as pd
import numpy as np
import jax
import matplotlib.pyplot as plt

from utils.kernel_means import *

rng_key = jax.random.PRNGKey(0)

In [18]:
def f(x):
    return x ** 2

# def g(x, theta):
#     return (1 + jnp.sqrt(3) * jnp.abs(x - theta)) * jnp.exp(- jnp.sqrt(3) * jnp.abs(x - theta))

def g(x, theta):
    return jnp.abs(x - theta) ** 1.5

def simulate_theta(T, rng_key):
    rng_key, _ = jax.random.split(rng_key)
    Theta = jax.random.uniform(rng_key, shape=(T, 1), minval=0., maxval=1.)
    return Theta


def simulate_x_theta(N, Theta, rng_key):
    def simulate_x_per_theta(N, theta, rng_key):
        rng_key, _ = jax.random.split(rng_key)
        x = jax.random.uniform(rng_key, shape=(N, ), minval=0., maxval=1.)
        # x = jax.random.normal(rng_key, shape=(N, ))
        return x
    vmap_func = jax.vmap(simulate_x_per_theta, in_axes=(None, 0, None))
    X = vmap_func(N, Theta, rng_key)
    return X

true_value = 4/25 *(1/3 + 5 * jnp.pi / 512)

print(f"True value: {true_value}")


True value: 0.05824207185456738


In [19]:
seed = 0
N = 16 + 8 + 4
T = 16 + 8 + 4


rng_key = jax.random.PRNGKey(seed)    
rng_key, _ = jax.random.split(rng_key)
Theta = simulate_theta(T, rng_key)
X = simulate_x_theta(N, Theta, rng_key)
g_X = g(X, Theta)

# This is nested Monte Carlo
I_theta_MC = g_X.mean(1)
I_NMC = f(I_theta_MC).mean(0)
print(f"Nested Monte Carlo: {I_NMC}")

Nested Monte Carlo: 0.05594543740153313


In [26]:
N_list = [4, 8, 16]
T_list = [16, 8, 4]

level = 3

I_NMLMC = 0

for l in range(level):
    if l == 0:
        rng_key, _ = jax.random.split(rng_key)
        Theta_l = simulate_theta(T_list[l], rng_key)
        X_l = simulate_x_theta(N_list[l], Theta_l, rng_key)
        g_X = g(X_l, Theta_l)
        Z_l = f(g_X.mean(1)).mean(0)
        I_NMLMC = Z_l
        print(f"Level {l}: {Z_l}")
    else:
        rng_key, _ = jax.random.split(rng_key)
        Theta_l = simulate_theta(T_list[l], rng_key)
        rng_key, _ = jax.random.split(rng_key)
        X_l = simulate_x_theta(N_list[l], Theta_l, rng_key)
        rng_key, _ = jax.random.split(rng_key)
        X_l_prev = simulate_x_theta(N_list[l-1], Theta_l, rng_key)
        g_X = f(g(X_l, Theta_l).mean(1)) - f(g(X_l_prev, Theta_l).mean(1))
        Z_l = g_X.mean(0)
        I_NMLMC += Z_l
        print(f"Level {l}: {Z_l}")
    
print(f"Nested Multilevel Monte Carlo: {I_NMLMC}")

Level 0: 0.046016134321689606
Level 1: 0.0012066980125382543
Level 2: 0.015920639038085938
Nested Multilevel Monte Carlo: 0.06314347684383392
