In [6]:
# The `jax.config` submodule (`from jax.config import config`) was deprecated
# and subsequently removed from JAX. The same configuration object is exported
# by the top-level `jax` package, so it is imported from there instead. The
# name `config` is kept in scope in case other cells refer to it.
from jax import config

# Turn on 64-bit mode up front, before any JAX arrays are created, so that
# float64 values are not silently truncated to float32.
config.update("jax_enable_x64", True)
import numpy as np
import jax.numpy as jnp
import jax_quant_finance as jqf
import jax.scipy.stats as stats

In [4]:
# Floating-point type handed to `option_price` through its `dtype` argument in
# the spot/discount/dividend pricing cell further down.
dtype = jnp.float64

In [2]:
# Vanilla Black-Scholes: price five options directly from their forward
# prices. Forwards, strikes and volatilities are per-option arrays; the expiry
# is a single scalar.
forwards = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
volatilities = np.array([0.0001, 102.0, 2.0, 0.1, 0.4])
strikes = np.full_like(forwards, 3.0)  # every option is struck at 3.0
expiries = 1.0

computed_prices = jqf.black_scholes.option_price(
    forwards=forwards,
    strikes=strikes,
    volatilities=volatilities,
    expiries=expiries,
)
print(computed_prices)

[0.         2.         2.04806848 1.00020297 2.07303131]


In [3]:
# Normal-model pricing: same entry point as the vanilla example, but with
# `is_normal_volatility=True` so the volatilities are treated as normal vols.
forwards = np.array([0.01, 0.02, 0.03, 0.03, 0.05])
volatilities = np.array([0.0001, 0.001, 0.01, 0.005, 0.02])
strikes = np.full_like(forwards, 0.03)  # every option is struck at 0.03
expiries = 1.0

computed_prices = jqf.black_scholes.option_price(
    forwards=forwards,
    strikes=strikes,
    volatilities=volatilities,
    expiries=expiries,
    is_normal_volatility=True,
)

print(computed_prices)

[0.00000000e+00 7.47456025e-28 3.98942280e-03 1.99471140e-03
 2.16663094e-02]


In [5]:
# Pricing from spot prices with explicit discount and dividend rates.
# The same five spots appear twice: the first five entries are flagged as
# calls, the last five as non-calls (presumably puts).
spots = np.tile([80.0, 90.0, 100.0, 110.0, 120.0], 2)
strikes = np.full_like(spots, 100.0)  # all ten options are struck at 100.0
is_call_options = np.repeat([True, False], 5)

# Scalar inputs supplied once for all ten options.
volatilities = 0.2
expiries = 0.25
discount_rates = 0.08
dividend_rates = 0.12

computed_prices = jqf.black_scholes.option_price(
    spots=spots,
    strikes=strikes,
    volatilities=volatilities,
    expiries=expiries,
    discount_rates=discount_rates,
    dividend_rates=dividend_rates,
    is_call_options=is_call_options,
    dtype=dtype,
)

print(computed_prices)


[ 0.02909021  0.5699856   3.4211088   9.84695715 18.61802275 20.41331485
 11.24975491  4.39642278  1.11781579  0.18442606]
