In [11]:
# !pip install git+https://github.com/crispitagorico/sigkernel.git
# !git clone 'https://github.com/ryanmccrickerd/rough_bergomi.git'

In [5]:
import numpy as np
import torch
import time
import pickle
import seaborn as sns

from rbergomi import rBergomi_MC_pricer, rBergomi_sigkernel_pricer, grid_search_sigkernel_rBergomi
from utils import *

import warnings
warnings.filterwarnings('ignore')

In [6]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [8]:
with open('data.p', 'rb') as fp:
    data = pickle.load(fp)

sigma_t = data['sigma_t']
sigma_x = data['sigma_x']
sigma_sig = data['sigma_sig']
lambda_ = data['lambda_']

In [9]:
print(sigma_t, sigma_x, sigma_sig, lambda_)

100.0 1.0 1000.0 0.01


In [18]:
# model parameters
T          = 1.
xi         = 0.055
eta        = 1.9
rho        = -0.9
n_incs     = 10
x_var      = 1.
n_mc_exact = 10000
n_eval     = 50

# sigkernel PDE computation params
dyadic_order, max_batch = 0, 200 

In [21]:
for a in [-0.4999, -0.4, -0.2, 0.3]:
    
    # evaluation points
    t_inds_eval = np.random.choice(n_incs, n_eval)
    xs_eval     = generate_xs(xi, x_var, t_inds_eval)
    paths_eval  = generate_theta_paths(t_inds_eval, n_incs, T, a)
    
    for strike in [0.5, 1., 1.5]:
        
        # call payoff
        payoff = lambda x: max(np.exp(x) - strike, 0.)
        
        # ground truth prices
        print('Computing true prices...')
        mc_pricer_exact = rBergomi_MC_pricer(n_incs, n_mc_exact, T, a, xi, eta, rho)
        mc_prices_exact = mc_pricer_exact.fit_predict(t_inds_eval, xs_eval, paths_eval, payoff)
        print('Finshed computing true prices.')
                
#         plt.figure(figsize=(8,3))
#         plt.plot(mc_prices_exact)
#         plt.title('H: %2.4f, Strike: %2.2f' % (0.5+a, strike))    
#         plt.show() 
                
        # for error_fn, error_type in zip([mse, mae], ['MSE', 'MAE']):
        for error_fn, error_type in zip([mse,], ['MSE',]):

            if error_type == 'MSE':
                precisions = [1e-2, 1e-3, 1e-4, 1e-5, 1e-6]
            else:
                precisions = [1e-1]

            for precision in precisions:

                # MC prices
                n_mc, error_mc = 500, 1e9
                timeout = time.time() + 60*0.5   # 5 minutes from now
                while error_mc > precision:
                    mc_pricer = rBergomi_MC_pricer(n_incs, n_mc, T, a, xi, eta, rho)
                    t0 = time.time()
                    mc_prices = mc_pricer.fit_predict(t_inds_eval, xs_eval, paths_eval, payoff)
                    t1 = time.time()
                    error_mc  = error_fn(mc_prices, mc_prices_exact)
                    n_mc += 500
                    if time.time() > timeout:
                        print('Runtime exceeded. Samples: {}. Error: {}'.format(n_mc, error_mc))
                        flag = False
                        break

                print('MC | H: %2.4f | Strike: %2.2f | Error type: %r | Precision: %2.4f | Time: %2.3f sec | Samples: %r' % (0.5+a, strike, error_type, precision, t1-t0, n_mc))

#                 # PPDE prices            
#                 m, n, error_sig = 200, 150, 1e9
#                 timeout = time.time() + 60*5   # 5 minutes from now
#                 flag = False
#                 while error_sig > precision:
#                     sig_pricer = rBergomi_sigkernel_pricer(n_incs, x_var, m, n, T, a, xi, eta, rho, sigma_t, sigma_x, sigma_sig, dyadic_order, max_batch, device, lambda_)
#                     sig_pricer.fit(payoff)
#                     t0 = time.time()
#                     sig_prices = sig_pricer.predict(t_inds_eval, xs_eval, paths_eval) 
#                     t1 = time.time()
#                     error_sig = error_fn(sig_prices, mc_prices_exact)
#                     torch.cuda.empty_cache()
#                     m += 200
#                     n += 150
#                     flag = True
#                     if time.time() > timeout:
#                         print('Runtime exceeded. Collocation points: ({}, {}). Error: {}'.format(m, n, error_sig))
#                         flag = False
#                         break
                
#                 if flag:
#                     print('SK | H: %2.4f | Strike: %2.1f | Error type: %r | Precision: %2.8f| Time: %2.4f sec | Cpoints: (%r,%r)' % (0.5+a, strike, error_type, precision, t1-t0, m, n))

Computing true prices...
Finshed computing true prices.
MC | H: 0.0001 | Strike: 0.50 | Error type: 'MSE' | Precision: 0.0100 | Time: 0.147 sec | Samples: 1000
MC | H: 0.0001 | Strike: 0.50 | Error type: 'MSE' | Precision: 0.0010 | Time: 0.146 sec | Samples: 1000
MC | H: 0.0001 | Strike: 0.50 | Error type: 'MSE' | Precision: 0.0001 | Time: 0.277 sec | Samples: 1500
MC | H: 0.0001 | Strike: 0.50 | Error type: 'MSE' | Precision: 0.0000 | Time: 0.400 sec | Samples: 2000
Runtime exceeded. Samples: 11500. Error: 5.823837515852443e-06
MC | H: 0.0001 | Strike: 0.50 | Error type: 'MSE' | Precision: 0.0000 | Time: 2.788 sec | Samples: 11500
Computing true prices...
Finshed computing true prices.
MC | H: 0.0001 | Strike: 1.00 | Error type: 'MSE' | Precision: 0.0100 | Time: 0.149 sec | Samples: 1000
MC | H: 0.0001 | Strike: 1.00 | Error type: 'MSE' | Precision: 0.0010 | Time: 0.147 sec | Samples: 1000
MC | H: 0.0001 | Strike: 1.00 | Error type: 'MSE' | Precision: 0.0001 | Time: 0.273 sec | Sample

KeyboardInterrupt: 