<a href="https://colab.research.google.com/github/mjauza/jax_example/blob/main/derivatives.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [12]:
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
from scipy.stats import norm
import numpy as np

# Black Scholes model

In [13]:
def sim_gbm(n_timestamp, N_paths, mu, sigma, dt = 1/365, S0 = 1):
  key = random.PRNGKey(0)
  Z = random.normal(key, (n_timestamp,N_paths))
  arr = jnp.exp((mu - sigma**2/2)*dt + sigma*(dt)**0.5*Z)
  S0_arr = jnp.array([S0]*N_paths).reshape(1,-1)
  S = jnp.concatenate((S0_arr, arr), axis=0)
  S_arr = jnp.cumprod(S, axis = 0)
  return S_arr

def call_option_mc(T, N_paths, r, sigma, strike, S0 = 1):
  # simualate S at maturity
  S_path = sim_gbm(1, N_paths, mu=r, sigma=sigma, dt = 1, S0 = S0)
  ST = S_path[-1, :]
  payout = jnp.maximum(0, ST - strike)
  price = jnp.exp(-r*T)*jnp.mean(payout)
  return jnp.asarray(price)

def call_option_bs(T, r, sigma, strike, S0 = 1):
  d1 = 1/(sigma*np.sqrt(T))*(np.log(S0/strike) + (r - sigma**2/2)*T)
  d2 = d1 - sigma*np.sqrt(T)
  price = norm.cdf(d1)*S0 - norm.cdf(d2)*strike*np.exp(-r*T)
  return price



In [16]:
T = 1
r = 0.05
sigma = 0.2
S0 = 100
strike = 90
call_mc = call_option_mc(T = T, N_paths=100_000, r=r, sigma=sigma, strike=strike, S0 = S0)
call_bs = call_option_bs(T=T, r=r, sigma=sigma, strike=strike, S0=S0)

In [17]:
print(f"BS price: {call_bs} ; MC price: {call_mc}")

BS price: 16.580079495561144 ; MC price: 16.778467178344727


# Heston Model

In [24]:

def sim_heston(n_timestamp, n_paths,dt, rho, r, kappa, theta, sigma,S0, V0):
  key = random.PRNGKey(0)
  Z1 = random.normal(key, (n_timestamp,n_paths))
  Z2 = random.normal(key, (n_timestamp,n_paths))
  Zv = Z1
  Zs = rho*Z1 + (1 - rho**2)**0.5 * Z2
  del Z1, Z2
  S_list = []
  S0_arr = jnp.array([S0]*n_paths).reshape(1,-1)
  S_list.append(S0_arr)
  Vt = jnp.array([V0]*n_paths)
  for i in range(n_timestamp):
    St = S_list[i]
    dV = kappa*(theta - Vt)*dt + sigma*jnp.sqrt(Vt)*(dt**0.5)*Zv[i, :]
    V_t1 = Vt + dV
    dS = r*St*dt + jnp.sqrt(Vt)*St*(dt**0.5)*Zs[i,:]
    S_t1 = St + dS
    Vt = V_t1
    S_list.append(S_t1.reshape(1,-1))

  S_arr = jnp.concatenate(S_list, axis = 0)
  return S_arr

def call_option_heston_mc(T, n_paths, rho, r, kappa, theta, sigma,S0, V0, strike):
  # simualate S at maturity
  S_path = sim_heston(1, n_paths,1, rho, r, kappa, theta, sigma,S0, V0)
  ST = S_path[-1, :]
  payout = jnp.maximum(0, ST - strike)
  price = jnp.exp(-r*T)*jnp.mean(payout)
  return jnp.asarray(price)

In [25]:
call_heston_mc = call_option_heston_mc(T=1, n_paths=10_000, rho=0.5, r=0.05, kappa=1, theta=1, sigma=0.2,S0=100, V0=0.1, strike=90)

In [26]:
print(call_heston_mc)

24.94186
