In [6]:
import matplotlib.pyplot as plt
import jax
import jax.numpy as jnp
import time

from IPython.display import set_matplotlib_formats
set_matplotlib_formats("pdf", "png")
# plt.tight_layout()
plt.rcParams["figure.figsize"] = (6, 4)
plt.rcParams["axes.titlesize"] = 28
plt.rcParams["font.size"] = 28
plt.rcParams["lines.linewidth"] = 1.5
plt.rcParams["lines.markersize"] = 7
plt.rcParams["grid.linestyle"] = "--"
plt.rcParams["grid.linewidth"] = 1.0
plt.rcParams["legend.fontsize"] = 16
plt.rcParams["legend.facecolor"] = "white"
plt.rcParams["axes.labelsize"] = 22
plt.rcParams["xtick.labelsize"] = 18
plt.rcParams["ytick.labelsize"] = 18
plt.rcParams["xtick.direction"] = "in"
plt.rcParams["ytick.direction"] = "in"
plt.rcParams['xtick.major.pad'] = 8
plt.rcParams['ytick.major.pad'] = 8
plt.rcParams['axes.grid'] = True
plt.rcParams['font.family'] = 'DeJavu Serif'
plt.rcParams['font.serif'] = ['Times New Roman']
# plt.rcParams["font.family"] = "Times New Roman"
plt.rcParams['text.usetex'] = True
plt.rcParams['text.latex.preamble'] = [r'\usepackage{amsmath, amssymb}']


  set_matplotlib_formats("pdf", "png")
  plt.rcParams['text.latex.preamble'] = [r'\usepackage{amsmath, amssymb}']


In [211]:

def price(St, N, rng_key, K1=50, K2=150, s=-0.2, sigma=0.3, T=2, t=1):
    """
    :param St: the price St at time t
    :return: The function returns the price ST at time T sampled from the conditional
    distribution p(ST|St), and the loss \psi(ST) - \psi((1+s)ST) due to the shock. Their shape is Nx * Ny
    """
    output_shape = (N, St.shape[0])
    rng_key, _ = jax.random.split(rng_key)
    epsilon = jax.random.normal(rng_key, shape=output_shape)
    ST = St * jnp.exp(sigma * jnp.sqrt(T - t) * epsilon - 0.5 * (sigma ** 2) * (T - t))
    psi_ST_1 = jnp.maximum(ST - K1, 0) + jnp.maximum(ST - K2, 0) - 2 * jnp.maximum(ST - (K1 + K2) / 2, 0)
    psi_ST_2 = jnp.maximum((1 + s) * ST - K1, 0) + jnp.maximum((1 + s) * ST - K2, 0) - 2 * jnp.maximum(
        (1 + s) * ST - (K1 + K2) / 2, 0)

    return ST



In [212]:

def jax_dist(x, y):
    return jnp.abs(x - y).squeeze()

distance = jax.vmap(jax_dist, in_axes=(None, 1), out_axes=1)
sign_func = jax.vmap(jnp.greater, in_axes=(None, 1), out_axes=1)


def my_laplace(x, y, l):
    r = distance(x, y).squeeze()
    return jnp.exp(- r / l)


def dx_laplace(x, y, l):
    sign = sign_func(x, y).squeeze().astype(float) * 2 - 1
    r = distance(x, y).squeeze()
    part1 = jnp.exp(- r / l) * (-sign)
    return part1


def dy_laplace(x, y, l):
    sign = sign_func(x, y).squeeze().astype(float) * 2 - 1
    r = distance(x, y).squeeze()
    part1 = jnp.exp(- r / l) * sign
    return part1


def dxdy_laplace(x, y, l):
    r = distance(x, y).squeeze()
    part1 = jnp.exp(- r / l) * (-1)
    return part1

def dx_log_px(x, sigma, T, t, St):
    # dx log p(x) for log normal distribution with mu=-\sigma^2 / 2 * (T - t) and sigma = \sigma^2 (T - y)
    part1 = (jnp.log(x) + sigma ** 2 * (T - t) / 2 - jnp.log(St)) / x / (sigma ** 2 * (T - t))
    return -1. / x - part1


def dx_log_px_debug(x):
    # dx log p(x) for log normal distribution with mu=-\sigma^2 / 2 * (T - t) and sigma = \sigma^2 (T - y)
    return - 1. / x - jnp.log(x) / x


def stein_Laplace(x, y, l, sigma, T, t, St):
    d_log_px = dx_log_px(x, sigma, T, t, St)
    d_log_py = dx_log_px(y, sigma, T, t, St)
    
    K = my_laplace(x, y, l)
    dx_K = dx_laplace(x, y, l)
    dy_K = dy_laplace(x, y, l)
    dxdy_K = dxdy_laplace(x, y, l)
    part1 = d_log_px @ d_log_py.T * K
    part2 = d_log_py.T * dx_K
    part3 = d_log_px * dy_K
    part4 = dxdy_K
    return part1 + part2 + part3 + part4

In [236]:

def jax_dist(x, y):
    return jnp.abs(x - y).squeeze()

distance = jax.vmap(jax_dist, in_axes=(None, 1), out_axes=1)
sign_func = jax.vmap(jnp.greater, in_axes=(None, 1), out_axes=1)


# @jax.jit
def my_Matern(x, y, l):
    r = distance(x, y).squeeze()
    part1 = 1 + math.sqrt(3) * r / l
    part2 = jnp.exp(-math.sqrt(3) * r / l)
    return part1 * part2

# @jax.jit
def dx_Matern(x, y, l):
    sign = sign_func(x, y).squeeze().astype(float) * 2 - 1
    r = distance(x, y).squeeze()
    part1 = jnp.exp(-math.sqrt(3) / l * r) * (math.sqrt(3) / l * sign)
    part2 = (-math.sqrt(3) / l * sign) * jnp.exp(-math.sqrt(3) / l * r) * (1 + math.sqrt(3) / l * r)
    return part1 + part2


# @jax.jit
def dy_Matern(x, y, l):
    sign = -(sign_func(x, y).squeeze().astype(float) * 2 - 1)
    r = distance(x, y).squeeze()
    part1 = jnp.exp(-math.sqrt(3) / l * r) * (math.sqrt(3) / l * sign)
    part2 = (-math.sqrt(3) / l * sign) * jnp.exp(-math.sqrt(3) / l * r) * (1 + math.sqrt(3) / l * r)
    return part1 + part2


# @jax.jit
def dxdy_Matern(x, y, l):
    r = distance(x, y).squeeze()
    const = math.sqrt(3) / l
    part1 = const * const * jnp.exp(-const * r)
    part2 = -const * const * jnp.exp(-const * r) * (1 + const * r)
    part3 = const * jnp.exp(-const * r) * const
    return part1 + part2 + part3


def dx_log_px(x, sigma, T, t, St):
    # dx log p(x) for log normal distribution with mu=-\sigma^2 / 2 * (T - t) and sigma = \sigma^2 (T - y)
    part1 = (jnp.log(x) + sigma ** 2 * (T - t) / 2 - jnp.log(St)) / x / (sigma ** 2 * (T - t))
    return -1. / x - part1

def stein_Matern(x, y, l, sigma, T, t, St):
    d_log_px = dx_log_px(x, sigma, T, t, St)
    d_log_py = dx_log_px(y, sigma, T, t, St)
    
    K = my_Matern(x, y, l)
    dx_K = dx_Matern(x, y, l)
    dy_K = dy_Matern(x, y, l)
    dxdy_K = dxdy_Matern(x, y, l)
    part1 = d_log_px @ d_log_py.T * K
    part2 = d_log_py.T * dx_K
    part3 = d_log_px * dy_K
    part4 = dxdy_K
    return part1 + part2 + part3 + part4

In [237]:
seed = int(time.time())
# seed = 0
rng_key = jax.random.PRNGKey(seed)

rng_key, _ = jax.random.split(rng_key)
epsilon = jax.random.normal(rng_key, shape=(1, 1))

S0 = 50
K1 = 50
K2 = 150
s = -0.2
t = 1
T = 2
sigma = 0.3
    
St = S0 * jnp.exp(sigma * jnp.sqrt(t) * epsilon - 0.5 * (sigma ** 2) * t)

In [238]:
St[0][0]

Array(59.290325, dtype=float32)

In [259]:
y1

Array([[67.20589 ],
       [84.199646],
       [60.908493],
       ...,
       [39.765858],
       [44.904034],
       [69.98509 ]], dtype=float32)

In [240]:
y2.shape

(2, 1)

In [258]:
l = 0.2
sigma = 0.3
T = 2
t = 1

for _ in range(3):
    rng_key, _ = jax.random.split(rng_key)
    y1 = price(St, 100000, rng_key)
    rng_key, _ = jax.random.split(rng_key)
    y2 = price(St, 2, rng_key)
    
    K = stein_Laplace(y1, y2, l, sigma, T, t, St[0][0])
    print(K[:3, :3])
    

[[-8.5060405e-05 -1.0048301e-05]
 [-7.1182367e-13 -6.1113097e-22]
 [-2.4759015e-06 -2.1256673e-15]]
[[-0.0000000e+00 -0.0000000e+00]
 [-0.0000000e+00 -1.3722648e-33]
 [-4.2688006e-28 -3.3105731e-05]]
[[-0.0000000e+00 -1.8727757e-20]
 [-0.0000000e+00  0.0000000e+00]
 [-0.0000000e+00 -8.7386735e-07]]


In [248]:
l = 1.0
sigma = 0.3
T = 2
t = 1

for _ in range(3):
    rng_key, _ = jax.random.split(rng_key)
    y1 = price(St, 1000, rng_key)
    rng_key, _ = jax.random.split(rng_key)
    y2 = price(St, 2, rng_key)
    
    K = stein_Matern(y1, y2, l, sigma, T, t, St[0][0])
    print(K.mean(0)[0])
    

-0.014033108
0.008280542
-0.009189025


In [230]:
dx_laplace(y1, y2, l)[:3, :]

Array([[ 2.5079748e-15,  1.9564406e-16],
       [-3.0240935e-11, -3.8766063e-10],
       [-8.0577744e-04, -1.0329316e-02]], dtype=float32)

In [204]:
my_laplace(y1, y2, l)

Array([[0.94137335, 0.8306925 ],
       [0.87190187, 0.8968805 ],
       [0.7773773 , 0.9940991 ],
       [0.76439756, 0.97750074],
       [0.4532355 , 0.57959116],
       [0.9395032 , 0.7346838 ],
       [0.98705244, 0.77186686],
       [0.90508544, 0.86399776],
       [0.7723893 , 0.98772043],
       [0.7543249 , 0.96462005]], dtype=float32)

In [234]:
dxdy_laplace(y1, y2, l).shape

(100, 2)

In [233]:
dx_log_px(y1, sigma, T, t, St[0][0]).shape

(100, 1)

In [235]:
dx_log_px(y2, sigma, T, t, St[0][0]).shape

(2, 1)