<a href="https://colab.research.google.com/github/mjauza/jax_example/blob/main/derivatives.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
from scipy.stats import norm
import numpy as np

# Black Scholes model

In [3]:
def sim_gbm(n_timestamp, N_paths, mu, sigma, dt = 1/365, S0 = 1):
  key = random.PRNGKey(0)
  Z = random.normal(key, (n_timestamp,N_paths))
  arr = jnp.exp((mu - sigma**2/2)*dt + sigma*(dt)**0.5*Z)
  S0_arr = jnp.array([S0]*N_paths).reshape(1,-1)
  S = jnp.concatenate((S0_arr, arr), axis=0)
  S_arr = jnp.cumprod(S, axis = 0)
  return S_arr

def call_option_mc(T, N_paths, r, sigma, strike, S0 = 1):
  # simualate S at maturity
  S_path = sim_gbm(1, N_paths, mu=r, sigma=sigma, dt = 1, S0 = S0)
  ST = S_path[-1, :]
  payout = jnp.maximum(0, ST - strike)
  price = jnp.exp(-r*T)*jnp.mean(payout)
  return jnp.asarray(price)

def call_option_bs(T, r, sigma, strike, S0 = 1):
  d1 = 1/(sigma*np.sqrt(T))*(np.log(S0/strike) + (r - sigma**2/2)*T)
  d2 = d1 - sigma*np.sqrt(T)
  price = norm.cdf(d1)*S0 - norm.cdf(d2)*strike*np.exp(-r*T)
  return price



In [4]:
T = 1
r = 0.05
sigma = 0.2
S0 = 100
strike = 90
call_mc = call_option_mc(T = T, N_paths=100_000, r=r, sigma=sigma, strike=strike, S0 = S0)
call_bs = call_option_bs(T=T, r=r, sigma=sigma, strike=strike, S0=S0)

In [5]:
print(f"BS price: {call_bs} ; MC price: {call_mc}")

BS price: 16.580079495561144 ; MC price: 16.778467178344727


# Heston Model

In [36]:

def sim_heston(n_timestamp, n_paths,dt, rho, r, kappa, theta, sigma,S0, V0):
  key = random.PRNGKey(0)
  Z1 = random.normal(key, (n_timestamp,n_paths))
  Z2 = random.normal(key, (n_timestamp,n_paths))
  Zv = Z1
  Zs = rho*Z1 + (1 - rho**2)**0.5 * Z2
  del Z1, Z2
  S_list = []
  S0_arr = jnp.array([S0]*n_paths).reshape(1,-1)
  S_list.append(S0_arr)
  Vt = jnp.array([V0]*n_paths)
  for i in range(n_timestamp):
    St = S_list[i]
    dV = kappa*(theta - Vt)*dt + sigma*jnp.sqrt(Vt)*(dt**0.5)*Zv[i, :]
    V_t1 = Vt + dV
    dS = r*St*dt + jnp.sqrt(Vt)*St*(dt**0.5)*Zs[i,:]
    S_t1 = St + dS
    Vt = V_t1
    S_list.append(S_t1.reshape(1,-1))

  S_arr = jnp.concatenate(S_list, axis = 0)
  return S_arr

def call_option_heston_mc(T, n_paths, rho, r, kappa, theta, sigma,S0, V0, strike):
  # simualate S at maturity
  S_path = sim_heston(1, n_paths,1, rho, r, kappa, theta, sigma,S0, V0)
  ST = S_path[-1, :]
  payout = jnp.maximum(0, ST - strike)
  price = jnp.exp(-r*T)*jnp.mean(payout)
  return jnp.asarray(price)


def get_heston_cfun(j, kappa, theta, rho, lam, T, q = 0):
  a = kappa*theta
  if j == 1:
    uj = 1/2
    bj = kappa + lam - rho*sigma
  else:
    uj = -1/2
    bj = kappa + lam

  def cfun(psi, t, xt, vt):
    tau = T - t
    dj = np.sqrt((bj - rho*sigma*1j*psi)**2 - sigma**2*(2*1j*uj*psi - psi**2))
    gj = (bj - rho*sigma*1j*psi + dj)/(bj - rho*sigma*1j*psi - dj)

    C = (r-q)*1j*psi*tau + a/sigma**2*((bj - rho*sigma*1j*psi + dj)*tau - 2*np.log((1-gj*np.exp(dj*tau))/(1-gj)))
    D = (bj - rho*sigma*1j*psi + dj)/sigma**2*((1-np.exp(dj*tau))/(1-gj*np.exp(dj*tau)))
    f =  np.exp(C + D*vt + 1j*psi*xt)
    return f

  return cfun

def simpson13(x0,xn,n, f):
    # calculating step size
    h = (xn - x0) / n

    # Finding sum
    integration = f(x0) + f(xn)

    for i in range(1,n):
        k = x0 + i*h

        if i%2 == 0:
            integration = integration + 2 * f(k)
        else:
            integration = integration + 4 * f(k)

    # Finding final integration value
    integration = integration * h/3

    return integration

def get_Pj(xt, vt,t, log_K, j, kappa, theta, rho, lam, max_x):
  fj = get_heston_cfun(j, kappa, theta, rho, lam, T, q = 0)
  fj_fun = lambda psi,  : fj(psi, t, xt, vt)
  def fun(psi):
    return np.real((np.exp(-1j*psi*log_K)*fj_fun(psi))/(1j*psi))

  int_part = simpson13(0.001,max_x,n=100_000, f=fun)
  Pj = 1/2 + 1/np.pi*int_part
  return Pj

def get_heston_call_an(st, vt, t,T, K, kappa, theta, rho, lam=0.5):
  tau = T - t

  xt = np.log(st)
  max_x = 100*xt
  P1 = get_Pj(xt=xt, vt=vt,t=t, log_K=np.log(K), j=1, kappa=kappa, theta=theta, rho=rho, lam=lam, max_x = max_x)
  P2 = get_Pj(xt=xt, vt=vt,t=t, log_K=np.log(K), j=2, kappa=kappa, theta=theta, rho=rho, lam=lam, max_x = max_x)
  price = st*P1 - K*np.exp(-r*tau)*P2
  return price



In [43]:
lam = 1
r = 0.05
kappa = 1
kappa_star = kappa + lam
theta = 1
theta_star = kappa*theta/(kappa + lam)
sigma = 0.02
S0 = 100
V0 = 0.1
strike = 90
rho = 0.5
call_heston_mc = call_option_heston_mc(T=1, n_paths=10_000, rho=rho, r=r, kappa=kappa_star, theta=theta_star, sigma=sigma,S0=S0, V0=V0, strike=strike)
call_heston_an = get_heston_call_an(st=S0, vt=V0, t=0,T=1, K=strike, kappa=kappa, theta=theta, rho=0.5, lam=lam)

In [44]:
print(call_heston_mc)

24.94186


In [41]:
print(call_heston_an)

28.817079086783593
