In [5]:
import numpy as np
import jax.numpy as jnp
from jax import jit, jvp, default_backend, jacfwd

# Import `config` from the top-level `jax` package. The `jax.config`
# submodule has been removed from recent JAX releases, so
# `from jax.config import config` raises ImportError there; the top-level
# name refers to the same configuration object.
from jax import config

# Enable 64-bit floats; without this flag JAX computes in 32-bit precision.
config.update("jax_enable_x64", True)
# JAX selects the default backend itself (the recorded run below used a GPU).
# If you want to use cpu as backend, uncomment the following line.
# config.update("jax_platform_name", "cpu")

from jax_quant_finance.black_scholes.vanilla_prices import option_price

Check backend (cpu/gpu/tpu):

In [3]:
# Look up the platform JAX selected for this session, then report it.
backend_name = default_backend()
print(f"we are running on {backend_name}")

we are running on gpu


In [6]:
dtype = jnp.float64

# Inputs for the pricer, all stored as float64 arrays.
spot = jnp.array(700, dtype=dtype)
sigma = jnp.array(0.1, dtype=dtype)
rate = jnp.array(0.03, dtype=dtype)
strikes = jnp.array([600, 650, 680], dtype=dtype)
expiries = jnp.array([1.0], dtype=dtype)

# Compile the pricer once; both sensitivity functions below reuse it.
jit_option_price_fn = jit(option_price)


def true_delta_fn(spot):
    """Price the options as a function of spot alone.

    Every other input is taken from the surrounding cell; the `spot`
    parameter intentionally shadows the module-level value so that `jvp`
    can differentiate with respect to it.
    """
    return jit_option_price_fn(
        volatilities=sigma,
        strikes=strikes,
        spots=spot,
        expiries=expiries,
        discount_rates=rate,
    )


def true_vega_fn(sigma):
    """Price the options as a function of volatility alone.

    Every other input is taken from the surrounding cell; the `sigma`
    parameter intentionally shadows the module-level value so that `jvp`
    can differentiate with respect to it.
    """
    return jit_option_price_fn(
        volatilities=sigma,
        strikes=strikes,
        spots=spot,
        expiries=expiries,
        discount_rates=rate,
    )


# A unit tangent on a scalar input makes the JVP output the derivative of
# each price with respect to that input. `prices` is returned by both calls;
# the second assignment overwrites the first.
prices, delta = jvp(
    true_delta_fn,
    (spot,),
    (jnp.ones_like(spot),),
)
prices, vega = jvp(
    true_vega_fn,
    (sigma,),
    (jnp.ones_like(sigma),),
)

print(f"prices: \n{prices} \n")
print(f"true delta: \n {delta} \n")
print(f"true vega: \n {vega}")

prices: 
[118.55230522  74.32285307  51.74306127] 

true delta: 
 [0.97072164 0.8623811  0.73887319] 

true vega: 
 [ 46.67659297 153.994105   227.5617336 ]


In [None]:

# Vega and delta for every strike from a single forward-mode Jacobian call.
# jacfwd differentiates positional arguments only, and the original call
# requested argnums (0, 2) while passing a single argument, so it could not
# run. Wrap the keyword-based pricer in a function of (sigma, spot) and
# differentiate both positions; the result is a tuple
# (d prices / d sigma, d prices / d spot).
jacfwd(
    lambda sigma, spot: jit_option_price_fn(volatilities=sigma,
                                            strikes=strikes,
                                            spots=spot,
                                            expiries=expiries,
                                            discount_rates=rate),
    argnums=(0, 1),
)(sigma, spot)