In [1]:
from jax.config import config
config.update("jax_enable_x64", True)
import numpy as np
import jax.numpy as jnp
import jax_quant_finance as jqf
import jax.scipy.stats as stats

In [2]:
dtype= jnp.float64

In [3]:
#  BS binary option prices
forwards = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
strikes = np.array([3.0, 3.0, 3.0, 3.0, 3.0])
volatilities = np.array([0.0001, 102.0, 2.0, 0.1, 0.4])
expiries = 1.0
computed_prices = jqf.black_scholes.binary_price(
        volatilities=volatilities,
        strikes=strikes,
        expiries=expiries,
        forwards=forwards)

print(computed_prices)



[0.         0.         0.15865525 0.99764937 0.85927418]


In [6]:
# binary prices bulk
np.random.seed(321)
num_examples = 1000
forwards = np.exp(np.random.normal(size=num_examples))
strikes = np.exp(np.random.normal(size=num_examples))
volatilities = np.exp(np.random.normal(size=num_examples))
expiries = np.random.gamma(shape=1.0, scale=1.0, size=num_examples)
log_scale = np.sqrt(expiries) * volatilities
log_loc = np.log(forwards) - 0.5 * log_scale**2
call_options = np.random.binomial(n=1, p=0.5, size=num_examples)
discount_factors = np.random.beta(a=1.0, b=1.0, size=num_examples)

cdf_values = stats.norm.cdf(x=np.log(strikes), loc=log_loc, scale=log_scale)

expected_prices = discount_factors * (
    call_options + ((-1.0)**call_options) * cdf_values)

is_call_options = np.array(call_options, dtype=np.bool)
computed_prices = jqf.black_scholes.binary_price(
        volatilities=volatilities,
        strikes=strikes,
        expiries=expiries,
        forwards=forwards,
        is_call_options=is_call_options,
        discount_factors=discount_factors)

print(computed_prices)

Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations


[3.58545596e-002 9.69207118e-001 1.34239390e-004 2.23953148e-002
 1.62040879e-003 2.88733980e-002 4.70586148e-002 9.83563773e-002
 1.13260899e-002 3.30920386e-002 2.95348097e-002 6.51747359e-003
 3.75060485e-003 8.45758322e-001 5.72902644e-001 4.40049340e-002
 4.66335325e-002 6.33137574e-004 6.89176996e-019 3.37375324e-002
 4.07608463e-001 1.49692875e-001 6.94546499e-003 5.72520725e-002
 7.36813543e-001 0.00000000e+000 1.18630748e-004 7.46170433e-001
 1.14143550e-003 2.36988656e-001 1.75137082e-004 8.22553506e-002
 1.94425715e-001 1.78120028e-001 7.00162552e-002 4.66477854e-002
 5.70145310e-001 4.88917265e-001 1.21004863e-003 7.30883096e-002
 2.83024360e-001 7.49798717e-002 1.82826745e-002 9.80980326e-001
 9.11462857e-001 8.94344631e-001 1.40011884e-001 1.73672375e-001
 9.27605846e-001 4.36884625e-004 8.42508163e-004 6.95085825e-003
 0.00000000e+000 1.81776196e-005 4.43363310e-002 8.17280089e-001
 1.17474036e-001 1.96439740e-005 9.14414902e-001 1.21957236e-002
 1.95320770e-002 1.128159

In [8]:
# asset_or_nothing_prices
forwards = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
strikes = np.array([3.0, 3.0, 3.0, 3.0, 3.0])
volatilities = np.array([0.0001, 102.0, 2.0, 0.1, 0.4])
expiries = 1.0
computed_prices = jqf.black_scholes.asset_or_nothing_price(
        volatilities=volatilities,
        strikes=strikes,
        expiries=expiries,
        forwards=forwards)
print(computed_prices)

[0.         2.         2.52403424 3.99315108 4.65085383]
