## Pricing a European Call Option via JAX

#### Written for the QuantEcon Paris Workshop (September 2022)
#### Author: [John Stachurski](http://johnstachurski.net/)

```python
import numpy as np
import matplotlib.pyplot as plt
```

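Recall that we want to compute the price of a European call option with strike price $K$ and expiry date $n$:

$$ P = \beta^n \mathbb E \max\{ S_n - K, 0 \} $$

Here $S_n$ is the price of the underlying asset at time $n$ and $\beta$ is the discount factor.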
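The Numba and JAX functions below estimate this expectation by Monte Carlo: they simulate $M$ independent paths and replace the expectation with a sample mean. Writing $S_n^{(m)}$ for the terminal price on path $m$, the estimator and its standard error are the usual ones (stated here for reference, not taken from the original notebook):

$$
\hat P_M = \beta^n \frac{1}{M} \sum_{m=1}^{M} \max\{ S_n^{(m)} - K, 0 \},
\qquad
\operatorname{se}(\hat P_M) \approx \frac{\hat \sigma_M}{\sqrt{M}}
$$

where $\hat \sigma_M$ is the sample standard deviation of the discounted payoffs $\beta^n \max\{ S_n^{(m)} - K, 0 \}$. By the law of large numbers $\hat P_M \to P$ as $M \to \infty$, and the error shrinks like $1/\sqrt{M}$.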

A common model for $\{S_t\}$ is

$$ \ln \frac{S_{t+1}}{S_t} = \mu + \sigma \xi_{t+1} $$

where $\{ \xi_t \}$ is IID and standard normal. However, some of its predictions are counterfactual. For example, it implies that volatility is constant, whereas observed volatility changes over time. Here is an improved version:

$$ \ln \frac{S_{t+1}}{S_t} = \mu + \sigma_t \xi_{t+1} $$

where 

$$ 
    \sigma_t = \exp(h_t), 
    \quad
        h_{t+1} = \rho h_t + \nu \eta_{t+1}
$$

Here $\{\eta_t\}$ is also IID and standard normal.
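Iterating the recursion for $h_t$ shows what this specification implies for volatility. Starting from a given $h_0$, and assuming $|\rho| < 1$, $h_t$ is a Gaussian AR(1) process with

$$
h_t = \rho^t h_0 + \nu \sum_{j=1}^{t} \rho^{t-j} \eta_j
\sim N\left( \rho^t h_0, \; \nu^2 \frac{1 - \rho^{2t}}{1 - \rho^2} \right)
$$

so $h_t$ settles into the stationary distribution $N(0, \nu^2 / (1 - \rho^2))$. With the defaults used below ($\rho = 0.1$, $\nu = 0.001$, $h_0 = 0$), the standard deviation of $h_t$ is about $0.001$, so $\sigma_t = \exp(h_t)$ stays very close to $1$.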

We want a function that simulates the sequence $S_0, \ldots, S_n$. With $s_t := \ln S_t$, the price dynamics become

$$ s_{t+1} = s_t + \mu + \exp(h_t) \xi_{t+1} $$
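One way to sanity-check a simulator of these dynamics: setting $\rho = 1$, $\nu = 0$ and $h_0 = \ln \sigma$ freezes volatility at $\sigma$ and recovers the first model. There $\ln S_n$ is normal with mean $\ln S_0 + n\mu$ and variance $n\sigma^2$, so the expectation in $P$ has a closed form (the standard lognormal call formula). The sketch below compares that formula with a Monte Carlo estimate built from the same recursion; the function names and the test values $\sigma = 0.1$, $K = 10$, $\beta = 0.95$ are illustrative choices, not values from this notebook.

```python
import math
import numpy as np

def normal_cdf(x):
    """Standard normal CDF via the error function."""
    return 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))

def call_price_constant_vol_exact(β, μ, σ, S0, K, n):
    """β^n E max{S_n - K, 0} when ln S_n ~ N(ln S0 + nμ, nσ²)."""
    m = math.log(S0) + n * μ          # mean of ln S_n
    v = n * σ**2                      # variance of ln S_n
    d1 = (m - math.log(K) + v) / math.sqrt(v)
    d2 = d1 - math.sqrt(v)
    payoff = math.exp(m + v / 2) * normal_cdf(d1) - K * normal_cdf(d2)
    return β**n * payoff

def call_price_constant_vol_mc(β, μ, σ, S0, K, n, M=200_000, seed=0):
    """Monte Carlo estimate using s_{t+1} = s_t + μ + exp(h_t) ξ_{t+1}
    with ρ = 1, ν = 0, h_0 = ln σ, so that exp(h_t) = σ for every t."""
    rng = np.random.default_rng(seed)
    ρ, ν = 1.0, 0.0
    s = np.full(M, np.log(S0))        # log prices, one per path
    h = np.full(M, np.log(σ))         # log volatility, frozen at ln σ
    for t in range(n):
        s = s + μ + np.exp(h) * rng.standard_normal(M)
        h = ρ * h + ν * rng.standard_normal(M)
    return β**n * np.mean(np.maximum(np.exp(s) - K, 0.0))

exact = call_price_constant_vol_exact(0.95, 0.0001, 0.1, 10.0, 10.0, 10)
mc = call_price_constant_vol_mc(0.95, 0.0001, 0.1, 10.0, 10.0, 10)
print(exact, mc)
```

If the two numbers disagree by much more than a few Monte Carlo standard errors, the simulation logic is suspect.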

We use the following defaults.

```python
μ  = 0.0001
ρ  = 0.1
ν  = 0.001
S0 = 10
h0 = 0
n  = 20
```

(Here `S0` is $S_0$ and `h0` is $h_0$.)

Here is a function to simulate a path using this equation:

```python
from numpy.random import randn

def simulate_asset_price_path(μ=μ, S0=S0, h0=h0, n=n, ρ=ρ, ν=ν):
    s = np.empty(n+1)      # log prices s_0, ..., s_n
    s[0] = np.log(S0)

    h = h0                 # log volatility h_t
    for t in range(n):
        s[t+1] = s[t] + μ + np.exp(h) * randn()
        h = ρ * h + ν * randn()

    return np.exp(s)       # prices S_0, ..., S_n
```
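The function above advances a single path with scalar draws. When many paths are needed, as in the Monte Carlo estimate of $P$, NumPy can advance all of them at once by updating arrays instead of scalars. A minimal sketch; the function name, the `M` and `seed` parameters and the use of `np.random.default_rng` are our choices, while the other defaults copy the values set above:

```python
import numpy as np

def simulate_asset_price_paths(M, μ=0.0001, S0=10, h0=0.0, n=20,
                               ρ=0.1, ν=0.001, seed=None):
    """Return an (M, n+1) array whose rows are simulated paths S_0, ..., S_n."""
    rng = np.random.default_rng(seed)
    s = np.empty((M, n + 1))
    s[:, 0] = np.log(S0)
    h = np.full(M, h0, dtype=float)    # one log volatility per path
    for t in range(n):
        ξ = rng.standard_normal(M)     # price shocks for all paths
        η = rng.standard_normal(M)     # volatility shocks for all paths
        s[:, t + 1] = s[:, t] + μ + np.exp(h) * ξ
        h = ρ * h + ν * η
    return np.exp(s)

paths = simulate_asset_price_paths(5, seed=42)
print(paths.shape)   # (5, 21)
```

Each row can be plotted directly, for example with `plt.plot(paths.T)`.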

Previously, we estimated the price by Monte Carlo, using Numba to compile the simulation loop and parallelize it across CPU cores.

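```python
M = 10_000_000   # number of Monte Carlo sample paths
K = 100          # strike price
n = 10           # expiry date (overrides n = 20 above)
β = 0.95         # discount factor; note that the functions below default to β=0.99
```

```python
from numba import njit, prange
```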

```python
@njit(parallel=True)
def compute_call_price_parallel(β=0.99,
                                μ=μ,
                                S0=S0,
                                h0=h0,
                                K=K,
                                n=n,
                                ρ=ρ,
                                ν=ν,
                                M=M):
    current_sum = 0.0
    # For each sample path (prange splits the paths across threads,
    # and Numba treats `current_sum +=` as a parallel reduction)
    for m in prange(M):
        s = np.log(S0)
        h = h0
        # Simulate forward in time
        for t in range(n):
            s = s + μ + np.exp(h) * randn()
            h = ρ * h + ν * randn()
        # And add the value max{S_n - K, 0} to current_sum
        current_sum += np.maximum(np.exp(s) - K, 0)

    # Discounted sample mean of the payoffs
    return β**n * current_sum / M
```

The first call also includes the time Numba spends compiling the function:

```python
%%time
compute_call_price_parallel()
```

```
CPU times: user 7.68 s, sys: 6.33 ms, total: 7.69 s
Wall time: 1.62 s

1459.1607003586987
```

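A second call runs the already compiled code. The estimate changes because each call draws fresh random numbers:

```python
%%time
compute_call_price_parallel()
```

```
CPU times: user 7.01 s, sys: 11.7 ms, total: 7.02 s
Wall time: 897 ms

1317.2487300856776
```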
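The two Numba estimates above differ noticeably, which is a reminder that a Monte Carlo price is more informative with a standard error attached. A minimal NumPy sketch that returns both, assuming the same model and default parameters as above; the function name, the `seed` argument and the smaller default `M` (all paths are held in memory as arrays) are our choices:

```python
import numpy as np

def call_price_with_std_error(β=0.99, μ=0.0001, S0=10, h0=0.0, K=100,
                              n=10, ρ=0.1, ν=0.001, M=1_000_000, seed=None):
    """Monte Carlo estimate of β^n E max{S_n - K, 0} and its standard error."""
    rng = np.random.default_rng(seed)
    s = np.full(M, np.log(S0))          # log prices, one per path
    h = np.full(M, h0, dtype=float)     # log volatilities, one per path
    for t in range(n):
        s = s + μ + np.exp(h) * rng.standard_normal(M)
        h = ρ * h + ν * rng.standard_normal(M)
    discounted = β**n * np.maximum(np.exp(s) - K, 0.0)
    estimate = discounted.mean()
    std_error = discounted.std(ddof=1) / np.sqrt(M)
    return estimate, std_error

estimate, std_error = call_price_with_std_error(M=100_000, seed=0)
print(f"{estimate:.1f} ± {1.96 * std_error:.1f} (approx. 95% interval)")
```

Since the standard error scales like $1/\sqrt{M}$, halving it requires four times as many paths.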

### Exercise

Shift the whole computation to the GPU using JAX and measure the speed gain.

### Solution

```python
!nvidia-smi
```

```
Tue Sep  6 07:51:25 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 515.65.01    Driver Version: 515.65.01    CUDA Version: 11.7     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA GeForce ...  Off  | 00000000:3B:00.0 Off |                  N/A |
| 30%   27C    P8    25W / 320W |      1MiB / 10240MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+---------------------------------------------------------------------------
```

```python
import jax
import jax.numpy as jnp
```
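Before timing anything, it is worth confirming that JAX itself sees the GPU, since `nvidia-smi` only reports what the driver sees. `jax.devices()` lists the devices JAX will use and `jax.default_backend()` names the platform (for example `'gpu'` or `'cpu'`); a JAX build without GPU support falls back to the CPU.

```python
import jax

print(jax.default_backend())   # platform that JAX computations run on by default
print(jax.devices())           # devices available on that platform
```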

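```python
from functools import partial

# n and M set the loop length and the array shapes, so they must be static
@partial(jax.jit, static_argnames=("n", "M"))
def compute_call_price_jax(β=0.99,
                           μ=μ,
                           S0=S0,
                           h0=h0,
                           K=K,
                           n=n,
                           ρ=ρ,
                           ν=ν,
                           M=M):

    s = jnp.full(M, jnp.log(S0))              # log prices, one per path
    h = jnp.full(M, h0, dtype=jnp.float32)    # log volatilities, one per path
    for t in range(n):                        # unrolled when the function is traced
        # A fixed seed per time step, so repeated calls give identical results
        key = jax.random.PRNGKey(t)
        Z = jax.random.normal(key, (2, M))
        s = s + μ + jnp.exp(h) * Z[0, :]
        h = ρ * h + ν * Z[1, :]
    expectation = jnp.mean(jnp.maximum(jnp.exp(s) - K, 0))

    return β**n * expectation
```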
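The function above seeds each time step with `PRNGKey(t)`, so every call reuses the same draws, which is why the two timed calls below return identical values. A variant that takes a key as an argument and derives fresh subkeys with `jax.random.split` gives a different estimate for each key; it also replaces the Python loop, which is unrolled `n` times at trace time, with `jax.lax.fori_loop`. This is a sketch: the function name is ours and the parameter defaults copy the notebook's values.

```python
from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnames=("n", "M"))
def compute_call_price_jax_keyed(key, β=0.99, μ=0.0001, S0=10.0, h0=0.0,
                                 K=100.0, n=10, ρ=0.1, ν=0.001, M=1_000_000):
    # Explicit float32 dtypes keep the loop carry types fixed across iterations
    s = jnp.full(M, jnp.log(S0), dtype=jnp.float32)   # log prices
    h = jnp.full(M, h0, dtype=jnp.float32)            # log volatilities

    def update(t, state):
        s, h, key = state
        key, subkey = jax.random.split(key)           # fresh key for this step
        Z = jax.random.normal(subkey, (2, M), dtype=jnp.float32)
        s = s + μ + jnp.exp(h) * Z[0]
        h = ρ * h + ν * Z[1]
        return s, h, key

    s, h, key = jax.lax.fori_loop(0, n, update, (s, h, key))
    return β**n * jnp.mean(jnp.maximum(jnp.exp(s) - K, 0.0))

price = compute_call_price_jax_keyed(jax.random.PRNGKey(0))
print(price)
```

Passing the same key reproduces the same estimate, while different keys give independent estimates whose spread reflects the Monte Carlo error.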

```python
%%time 
compute_call_price_jax()
```

```
CPU times: user 1.77 s, sys: 164 ms, total: 1.93 s
Wall time: 2.21 s

DeviceArray(1319.7174, dtype=float32)
```

```python
%%time 
compute_call_price_jax()
```

```
CPU times: user 703 µs, sys: 14 µs, total: 717 µs
Wall time: 443 µs

DeviceArray(1319.7174, dtype=float32)
```
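JAX dispatches work asynchronously: a jitted call can return before the GPU has finished computing, so `%%time` around a bare call may measure little more than the dispatch. Calling `.block_until_ready()` on the result waits until the computation is done. A small timing helper sketched on that basis; the helper name and the toy function standing in for `compute_call_price_jax` are ours:

```python
import time

import jax
import jax.numpy as jnp

def time_jax_call(f, *args, repeats=5, **kwargs):
    """Best wall-clock time (seconds) of f(*args, **kwargs), excluding compilation."""
    f(*args, **kwargs).block_until_ready()       # warm-up call triggers compilation
    times = []
    for _ in range(repeats):
        start = time.perf_counter()
        f(*args, **kwargs).block_until_ready()   # wait until the result is computed
        times.append(time.perf_counter() - start)
    return min(times)

@jax.jit
def toy_workload(x):
    # Stand-in for compute_call_price_jax: a mean over elementwise operations
    return jnp.mean(jnp.maximum(jnp.exp(x) - 1.0, 0.0))

x = jax.random.normal(jax.random.PRNGKey(0), (1_000_000,))
print(time_jax_call(toy_workload, x))
```

Inside the notebook, the same idea is `%timeit compute_call_price_jax().block_until_ready()`.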