In [1]:
from tqdm import tqdm
import matplotlib.pyplot as plt
import time
import seaborn as sns
import jax.numpy as jnp
import jax

import rbergomi
from utils import *

import warnings
warnings.filterwarnings('ignore')

jax.config.update("jax_enable_x64", True)

In [5]:
# rough Bergomi parameters
T            = 1.0
xi           = 0.055
eta          = 1.9
rho          = - 0.9
x_var        = 3.0
eps_paths    = 1e-2
dtype        = jnp.float64
path_measure = 'brownian'
n_increments = 10
t_grid       = jnp.linspace(0, T, 1 + n_increments)
n_eval       = 10

# sigkernel hyperparameters
sig_samples_in          = [10, 50, 250] 
interior_boundary_ratio = 0.75
sig_samples_b           = [int(interior_boundary_ratio * k) for k in sig_samples_in]
t_scale                 = 1e0
x_scale                 = 1e1
sig_scales              = 1e-2 * jax.random.exponential(getkey(), shape=(1,), dtype=jnp.float64)
refinement_factor       = 1
static_kernel_kind      = 'linear'
eps_derivatives         = 1e-2

# MC sample paths
mc_samples  = [int((1.0 + interior_boundary_ratio) * k) for k in sig_samples_in]

In [8]:
import importlib
importlib.reload(rbergomi)

<module 'rbergomi' from '/mnt/batch/tasks/shared/LS_root/mounts/clusters/csalvi1/code/Users/csalvi/sigppde/rbergomi.py'>

In [9]:
x_var

3.0

In [16]:
for a in tqdm([-0.4, -0.2]):
        
    for strike in tqdm([0.1, 1.0, 1.5]):
        
        log_strike = jnp.log(strike)
                    
        # evaluation points
        t_inds_eval  = jax.random.choice(getkey(), a=jnp.arange(n_increments), shape=(n_eval,)) 
        # t_inds_eval  = jax.random.choice(getkey(), a=jnp.array([n_increments]), shape=(n_eval,))
        t_eval = jnp.array([t_grid[t] for t in t_inds_eval], dtype=jnp.float64)
        xs_eval      = jnp.array(generate_xs(xi, x_var, t_eval), dtype=jnp.float64)[:,0]
        # paths_eval = generate_brownian_paths(T, n_increments, n_eval)
        paths_eval   = generate_theta_paths(t_inds_eval, n_increments, T, a, eps=eps_paths)

        # payoff
        payoff = lambda x: max(jnp.exp(x) - jnp.exp(log_strike), 0.)

        # true prices
        true_prices = rbergomi.rBergomi_MC_pricer(n_increments, 5000, T, a, xi, eta, rho).fit_predict(t_inds_eval, xs_eval, paths_eval, payoff)
        
        # MC prices
        mses_mc = []
        maes_mc = []
        # fig, ax = plt.subplots(1, len(mc_samples), figsize=(14, 4))
        for i, n_mc in enumerate(mc_samples):
            mc_pricer = rbergomi.rBergomi_MC_pricer(n_increments, n_mc, T, a, xi, eta, rho)
            mc_prices = mc_pricer.fit_predict(t_inds_eval, xs_eval, paths_eval, payoff)
            mses_mc.append(mse(true_prices, mc_prices))
            maes_mc.append(mae(true_prices, mc_prices)) 
            # sns.regplot(x=true_prices, y=mc_prices, ax=ax[i])
        # plt.tight_layout()
        # plt.show()
                
        # sigkernel prices
        mses_sig = []
        maes_sig = []
        # fig, ax = plt.subplots(1, len(sig_samples_in), figsize=(14, 4))
        for i, (m, n) in enumerate(zip(sig_samples_in, sig_samples_b)):
            sig_pricer = rbergomi.rBergomi_sigkernel_pricer(n_increments=n_increments, x_var=x_var, m=m, n=n, T=T, a=a, xi=xi, eta=eta, rho=rho, 
                                                            t_scale=t_scale, x_scale=x_scale, sig_scales=sig_scales, 
                                                            path_measure=path_measure, 
                                                            refinement_factor=refinement_factor, 
                                                            static_kernel_kind=static_kernel_kind, 
                                                            dtype=dtype, 
                                                            eps_paths=eps_paths, 
                                                            eps_derivatives=eps_derivatives)
            sig_pricer.fit(payoff)

            # t_inds_eval = sig_pricer.t_inds_boundary
            # xs_eval = sig_pricer.xs_boundary
            # paths_eval = sig_pricer.paths_boundary
            # true_prices = jnp.array([payoff(x) for x in xs_eval])

            # t_inds_eval = sig_pricer.t_inds
            # xs_eval = sig_pricer.xs
            # paths_eval = sig_pricer.paths
            # true_prices = rbergomi.rBergomi_MC_pricer(n_increments, 100, T, a, xi, eta, rho).fit_predict(t_inds_eval, xs_eval, paths_eval, payoff)

            # true_prices = jnp.exp(xs_eval)

            sig_prices = sig_pricer.predict(t_inds_eval, xs_eval, paths_eval) 
            sig_prices = jnp.round(sig_prices, 3)

            mses_sig.append(mse(true_prices, sig_prices))
            maes_sig.append(mae(true_prices, sig_prices))
            
            # sns.regplot(x=true_prices, y=sig_prices, ax=ax[i])

        # plt.tight_layout()
        # plt.show()
                        
        print(f'a = {a}, log strike = {log_strike}')
        print(f'MC: MSE = {mses_mc}')
        print(f'MC: MAE = {maes_mc}')
        print(f'Sig: MSE = {mses_sig}')
        print(f'Sig: MAE = {maes_sig}', '\n')
    
        # fig, ax = plt.subplots(1, 2, figsize=(10,4))
        # ax[0].plot(mc_samples, mses_mc, label='MSE MC')
        # ax[0].plot(mc_samples, mses_sig, label='MSE Sig')
        # ax[1].plot(mc_samples, maes_mc, label='MAE MC')
        # ax[1].plot(mc_samples, maes_sig, label='MAE Sig')
        # ax[0].legend()
        # ax[1].legend()
        # ax[0].set_title(f'a: {a}, log-strike: {log_strike}')
        # plt.show()

  0%|          | 0/2 [00:00<?, ?it/s]
  0%|          | 0/3 [00:00<?, ?it/s][A
 33%|███▎      | 1/3 [01:42<03:25, 102.66s/it][A

a = -0.4, log strike = -2.3025850929940455
MC: MSE = [Array(0.2070686, dtype=float64), Array(0.0079591, dtype=float64), Array(0.00803032, dtype=float64)]
MC: MAE = [Array(1.31938771, dtype=float64), Array(0.26832033, dtype=float64), Array(0.23002447, dtype=float64)]
Sig: MSE = [Array(0.20841061, dtype=float64), Array(0.00169831, dtype=float64), Array(0.00168679, dtype=float64)]
Sig: MAE = [Array(0.63751202, dtype=float64), Array(0.07144307, dtype=float64), Array(0.06721668, dtype=float64)] 




 67%|██████▋   | 2/3 [03:24<01:42, 102.35s/it][A

a = -0.4, log strike = 0.0
MC: MSE = [Array(0.00047504, dtype=float64), Array(0.00031074, dtype=float64), Array(1.73924033e-05, dtype=float64)]
MC: MAE = [Array(0.06798544, dtype=float64), Array(0.05571345, dtype=float64), Array(0.01314468, dtype=float64)]
Sig: MSE = [Array(0.15769603, dtype=float64), Array(0.00921622, dtype=float64), Array(0.00387793, dtype=float64)]
Sig: MAE = [Array(0.79987951, dtype=float64), Array(0.26712049, dtype=float64), Array(0.14381548, dtype=float64)] 




100%|██████████| 3/3 [05:07<00:00, 102.60s/it][A
 50%|█████     | 1/2 [05:07<05:07, 307.79s/it]

a = -0.4, log strike = 0.4054651081081644
MC: MSE = [Array(0.0044637, dtype=float64), Array(0.00843129, dtype=float64), Array(0.0005036, dtype=float64)]
MC: MAE = [Array(0.1552092, dtype=float64), Array(0.21004152, dtype=float64), Array(0.04841373, dtype=float64)]
Sig: MSE = [Array(0.63052452, dtype=float64), Array(0.01221787, dtype=float64), Array(0.00735065, dtype=float64)]
Sig: MAE = [Array(1.835, dtype=float64), Array(0.151, dtype=float64), Array(0.17217488, dtype=float64)] 




  0%|          | 0/3 [00:00<?, ?it/s][A
 33%|███▎      | 1/3 [01:41<03:22, 101.39s/it][A

a = -0.2, log strike = -2.3025850929940455
MC: MSE = [Array(0.01982215, dtype=float64), Array(0.00059435, dtype=float64), Array(0.00135325, dtype=float64)]
MC: MAE = [Array(0.3330306, dtype=float64), Array(0.06248508, dtype=float64), Array(0.07182088, dtype=float64)]
Sig: MSE = [Array(0.11989827, dtype=float64), Array(0.00027182, dtype=float64), Array(0.00269041, dtype=float64)]
Sig: MAE = [Array(0.63449488, dtype=float64), Array(0.02849488, dtype=float64), Array(0.0973111, dtype=float64)] 




 67%|██████▋   | 2/3 [03:16<01:37, 97.95s/it] [A

a = -0.2, log strike = 0.0
MC: MSE = [Array(0.01133104, dtype=float64), Array(0.00153786, dtype=float64), Array(0.00061837, dtype=float64)]
MC: MAE = [Array(0.23417776, dtype=float64), Array(0.0945221, dtype=float64), Array(0.06798987, dtype=float64)]
Sig: MSE = [Array(0.42017456, dtype=float64), Array(0.00349338, dtype=float64), Array(0.00467211, dtype=float64)]
Sig: MAE = [Array(1.1832388, dtype=float64), Array(0.105, dtype=float64), Array(0.14306578, dtype=float64)] 




100%|██████████| 3/3 [04:50<00:00, 96.77s/it][A
100%|██████████| 2/2 [09:58<00:00, 299.05s/it]

a = -0.2, log strike = 0.4054651081081644
MC: MSE = [Array(0.00481117, dtype=float64), Array(0.04727701, dtype=float64), Array(0.00410989, dtype=float64)]
MC: MAE = [Array(0.2161007, dtype=float64), Array(0.44946254, dtype=float64), Array(0.1225046, dtype=float64)]
Sig: MSE = [Array(1.06280338, dtype=float64), Array(0.02118085, dtype=float64), Array(0.00350427, dtype=float64)]
Sig: MAE = [Array(1.66057655, dtype=float64), Array(0.33391094, dtype=float64), Array(0.1007492, dtype=float64)] 




