In [7]:
import jax
import jax.numpy as jnp
import jax.random as random

Spot = 36   # stock price
σ = 0.2     # stock volatility
K = 40      # strike price
r = 0.06    # risk free rate
n = 100000  # Number of simualted paths
m = 50      # number of exercise dates
T = 1       # maturity
order = 12   # Polynmial order
Δt = T / m  # interval between two exercise dates

def scale(x):
    xmin = x.min()
    xmax = x.max()
    a = 2 / (xmax - xmin)
    b = 1 - a * xmax
    return a * x + b

def payoff_put(S):    
  return jnp.maximum(K - S, 0.)

def chebyshev_basis(X,k):
  def func(carry, x):
    temp, temp_1 = carry
    new = 2 * X * temp - temp_1
    return (new, temp), new
  return jax.lax.scan(func, (X, jnp.ones(len(X))) , xs=None, length=k)

def step(S, xs):
  dZ = xs * jnp.sqrt(Δt)
  dS = r * S  * Δt + σ  * S  * dZ
  S = S + dS
  return S, S

def compute_price():
  key = random.PRNGKey(10)
  S0 = Spot*jnp.ones(n)
  xs = jnp.array(random.normal(key,shape=(m,n)))
  S = jax.lax.scan(step, S0, xs)[1]
  discount = jnp.exp(-r * Δt)

  value_if_exercise = payoff_put(S[-1])
  discounted_future_cashflows = value_if_exercise * discount

  def func(init, x):
    X = chebyshev_basis(scale(S[-2 - x]), order)[1].T
    Y = init
    Θ = jnp.linalg.solve(X.T @ X, X.T @ Y)
    value_if_wait = X @ Θ
    value_if_exercise = payoff_put(S[-2 - x])
    exercise = value_if_exercise >= value_if_wait
    Y = discount * jnp.where(exercise, value_if_exercise, Y)
    return Y, Y

  return jax.lax.scan(func, discounted_future_cashflows, xs = jnp.array(range(m-1)))[0].mean()

In [8]:
print(compute_price())

4.462592
