In [1]:
pip install optax dm-haiku

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


Conversion Steps
1. Numpy
2. Jax normal eqn (chebyshev polynomials)
<!-- 3. (2.5)Jax (normal polynomials) -->
3. Jax (Normal Polynomials, gradient descent): we revert back to normal polynomaisl because we don't need chebyshev in neural networks
4. Jax stochastic gradient *descent*
5. Jax (Neueral Network, stochastic gradient descent)
6. Price Exotic options

Due date: October 21 2022 

# Description
  In this problem we will use apply the LSMC method to price American put options. Specifically, we will replicate the result in the first row, 6th column of Table 1 in [Longstaff and Schwartz 2001](https://www.anderson.ucla.edu/documents/areas/fac/finance/least_squares.pdf)

  

*  Read the introduction of the [paper](https://www.anderson.ucla.edu/documents/areas/fac/finance/least_squares.pdf).
*   We will price an american put option as described in page 126 of the aforementioned article. Read paragraphs 1 and 2 of page 126
* As we saw in class, one of the ways we can use linear regression to fit nonlinear functions is to use polynomial features. A common choice in many applications is to use the so called ``Chebyshev polynomials''. Chebyshev polynomials are defined recursively by:

\begin{equation}
T_0(x) = 1\\
T_1(x) = x\\
T_{n + 1}(x)  =  2 x T_n(x) - T_{n - 1}(x)\\
\end{equation}


# Part 1
The code below simulates the evolution of a stock price that follows a geometric brownian motion. Write a JAX version of that code. You are not allowed to use functions from other libraries. For this part, the "simulate"
function does not need to be jit compiled. As we will see, jit compiling a funciton with for loops may introduce some complications
 

In [None]:
import numpy as np

# Data
σ = 0.04
r = 0.01
K = 35

# Design choice
dt = 0.01
m = 100


def simulate():
  np.random.seed(0)

  def step(S):
    dZ = np.random.normal(size=S.size) * np.sqrt(dt)
    dS = r * S  * dt + σ  * S  * dZ
    S = S + dS
    return S

  S0 = np.ones(20000)
  S = S0
  S_list = []

  for t in range(m):
    S = step(S)
    S_list.append(S)

  S_array = np.stack(S_list)
  return S_array


In [None]:
import jax
import jax.numpy as jnp

# Data
σ = 0.04
r = 0.01
K = 35

# Design choice
dt = 0.01
m = 100


def simulate():
  key = jnp.randomPRNGKey(0)

  def step(S):
    rng, _ = jax.random.split(rng) #psuedo random number generator. not really random but it makes it looks random
    dZ = jnp.random.normal(rng, (S.size)) * np.sqrt(dt)
    dS = r * S  * dt + σ  * S  * dZ
    S = S + dS
    return S

  S0 = jnp.ones(20000)
  S = S0
  S_list = []

  for t in range(m):
    S = step(S)
    S_list.append(S)

  S_array = np.stack(S_list)
  return S_array


# Part 2
Write a jit compiled version of the simulate function. You may want to check out the function jax.lax.scan.


In [None]:
import jax
import jax.numpy as jnp

# Data
σ = 0.04
r = 0.01
K = 35

# Design choice
dt = 0.01
m = 100

@jax.jit
def simulate():
  seed = jax.random.PRNGKey(0)

  def core(state, input):
    S = state
    rng = input    
    dZ = jax.random.normal(rng, (S.size, )) * np.sqrt(dt)
    dS = r * S  * dt + σ  * S  * dZ
    S = S + dS
    return S, S #first S is final state, 2nd S is the vector or matrix of all states found

  S0 = jnp.ones(int(m/dt))
  S = S0
  S_list = []

  rng_vector = jax.random.split(seed, m) #m is the length of vector
  state = S
  input = rng_vector

  state, out = jax.lax.scan(core, state, input)

  return state, out

In [None]:
simulate()

(DeviceArray([0.9991813, 1.066116 , 0.9695461, ..., 0.9957512, 1.0195965,
              0.9298693], dtype=float32),
 DeviceArray([[0.99995965, 1.0011096 , 0.99595296, ..., 0.9970569 ,
               1.0003599 , 0.99666893],
              [1.0031309 , 1.0017953 , 0.9969345 , ..., 0.9964643 ,
               1.0038731 , 1.0018477 ],
              [1.0038399 , 1.0019187 , 0.99650645, ..., 0.99268675,
               1.005771  , 1.006995  ],
              ...,
              [1.0062256 , 1.066818  , 0.9617757 , ..., 0.9942317 ,
               1.0212085 , 0.93702674],
              [1.003928  , 1.067885  , 0.96726656, ..., 0.99592817,
               1.0197272 , 0.9276689 ],
              [0.9991813 , 1.066116  , 0.9695461 , ..., 0.9957512 ,
               1.0195965 , 0.9298693 ]], dtype=float32))

# Part 3
The code below is computes the price of an American Put option using Least Squares Monte Carlo (LSMC). Write a JAX version of that code. You are not allowed to use functions from other libraries. Your "compute_price" function must be jit compiled.

1. Numpy version of the code

In [None]:
import numpy as np

Spot = 36   # stock price
σ = 0.2     # stock volatility
K = 40      # strike price
r = 0.06    # risk free rate
n = 100000  # Number of simualted paths
m = 100      # number of exercise dates
T = 1       # maturity
order = 12   # Polynmial order
Δt = T / m  # interval between two exercise dates


# Construct polynomial features of order up to k using the
# recursive formulation
def chebyshev_basis(x, k):
    B = [np.ones(len(x)), x]
    for n in range(2, k):
        Bn = 2 * x * B[n - 1] - B[n - 2]
        B.append(Bn)

    return np.column_stack(B)


# scales x to be in the interval(-1, 1)
def scale(x):
    xmin = x.min()
    xmax = x.max()
    a = 2 / (xmax - xmin)
    b = 1 - a * xmax
    return a * x + b


# simulates one step of the stock price evolution
def step(S):
    dB = np.sqrt(Δt) * np.random.normal(size=S.size)
    S_tp1 = S + r * S * Δt + σ * S * dB
    return S_tp1


def payoff_put(S):
    return np.maximum(K - S, 0.)


# LSMC algorithm
def compute_price():
    np.random.seed(0)
    S0 = Spot * np.ones(n)
    S = [S0]

    for t in range(m):
        S_tp1 = step(S[t])
        S.append(S_tp1)

    discount = np.exp(-r * Δt)

    # Very last date
    value_if_exercise = payoff_put(S[-1])
    discounted_future_cashflows = value_if_exercise * discount

    # Proceed recursively
    for i in range(m - 1):
        X = chebyshev_basis(scale(S[-2 - i]), order)
        Y = discounted_future_cashflows

        Θ = np.linalg.solve(X.T @ X, X.T @ Y)
        value_if_wait = X @ Θ
        value_if_exercise = payoff_put(S[-2 - i])
        exercise = value_if_exercise >= value_if_wait
        discounted_future_cashflows = discount * np.where(
            exercise,
            value_if_exercise,
            discounted_future_cashflows)

    return discounted_future_cashflows.mean()


print(compute_price())
# test = compute_price(order, Spot, σ, K, r)


4.460566940166749


2. Jax version of the code using chebyshev polynomials

In [None]:
import jax
import jax.numpy as jnp


Spot = 36   # stock price
σ = 0.2     # stock volatility
K = 40      # strike price
r = 0.06    # risk free rate
n = 100000  # Number of simualted paths
m = 100      # number of exercise dates
T = 1       # maturity
order = 12   # Polynmial order
Δt = T / m  # interval between two exercise dates


# Construct polynomial features of order up to k using the
# recursive formulation
def chebyshev_basis(x, k):
    B = [jnp.ones(len(x)), x]
    for n in range(2, k):
        Bn = 2 * x * B[n - 1] - B[n - 2]
        B.append(Bn)

    return jnp.column_stack(B)


# scales x to be in the interval(-1, 1)
def scale(x):
    xmin = x.min()
    xmax = x.max()
    a = 2 / (xmax - xmin)
    b = 1 - a * xmax
    return a * x + b


# simulates one step of the stock price evolution
def step(S, rng):
    e = jax.random.normal(rng, shape=S.shape)
    dB = jnp.sqrt(Δt) * e
    S = S + r * S * Δt + σ * S * dB
    return S, S

def payoff_put(S):
    return jnp.maximum(K - S, 0.)

@jax.jit
# LSMC algorithm
def compute_price():
    seed = jax.random.PRNGKey(0) 
    S = Spot * jnp.ones(int(n))
    
    rng_vector = jax.random.split(seed, int(m))
    state = S
    input = rng_vector #1st dimension of input needs to have same amount of steps to loop in the function
    _, S = jax.lax.scan(step, state, input) #the S = 100 x 100000 matrix

    discount = jnp.exp(-r * Δt)

    # Very last date
    value_if_exercise = payoff_put(S[-1])
    discounted_future_cashflows = value_if_exercise * discount
    
    def core(state, input):
      Si = input
      discounted_future_cashflows = state

      X = chebyshev_basis(scale(Si), order)
      Y = discounted_future_cashflows
      Θ = jnp.linalg.solve(X.T @ X, X.T @ Y)  # this line sets up the loss function and solve it by minimizing it
      value_if_wait = X @ Θ
      value_if_exercise = payoff_put(Si)
      exercise = value_if_exercise >= value_if_wait
      discounted_future_cashflows = discount * jnp.where(
          exercise,
          value_if_exercise,
          discounted_future_cashflows)

      return discounted_future_cashflows, None

    # we are reversing the order because we need to go backwards in time
    input = jnp.flip(S, 0)[1:] #pass theta as an input to core function because we need to pass diff Θ for every time step
    state = discounted_future_cashflows
    discounted_future_cashflows, _ = jax.lax.scan(core, state, input)

    # for i in range(m - 1):
    #   discounted_future_cashflows, _ = core(discounted_future_cashflows, input[i])  
    
    return discounted_future_cashflows.mean()

print(compute_price())
# test = compute_price(order, Spot, σ, K, r)

4.469381


3. Jax Gradient descent (normal polynomials)

In [None]:
# Step 1 of grad descent is to initialize Θ with 99 x 12(order of polynomial terms) columns
# Θ = Θ- learning_rate * delL
# Stock simulation stays the same in grad desc
#regress futurre cahsflow on future S[98]

In [None]:
import jax
import jax.numpy as jnp
import optax

Spot = 36   # stock price
σ = 0.2     # stock volatility
K = 40      # strike price
r = 0.06    # risk free rate
n = 100000  # Number of simualted paths # we dont need so many paths if we are using stochastic gradient descent
m = 50      # number of exercise dates
T = 1       # maturity
order = 5   # Polynmial order. Increase value and see what happens
Δt = T / m  # interval between two exercise dates

optimizer = optax.adam #we/re using optax to optimize gradient descent
α=1e-3 # learning rate

# # Construct polynomial features of order up to k using the
# # recursive formulation
# def chebyshev_basis(x, k):
#     B = [jnp.ones(len(x)), x]
#     for n in range(2, k):
#         Bn = 2 * x * B[n - 1] - B[n - 2]
#         B.append(Bn)

#     return jnp.column_stack(B)


# scales x to be in the interval(-1, 1)
def scale(x):
  xmin = x.min()
  xmax = x.max()
  a = 2 / (xmax - xmin)
  b = 1 - a * xmax
  return a * x + b


# simulates one step of the stock price evolution
def step(S, rng):
  e = jax.random.normal(rng, shape=S.shape)
  dB = jnp.sqrt(Δt) * e
  S = S + r * S * Δt + σ * S * dB
  return S, S

def payoff_put(S):
  return jnp.maximum(K - S, 0.)

#Step 1 of grad descent
Θ = jnp.zeros((m-1, order)) # Theta is 99x order filled with zeroes. This is our initial guess for theta(wrong values. used just to see what happens)
opt_state = optimizer(1.).init(Θ) # initializing optimizer state

def mse(predictions, label):
  return          

@jax.jit
# LSMC algorithm
def compute_price(Θ): # Step 2: pass theta as input to function because it is an iterative solution dependent on Θ
  seed = jax.random.PRNGKey(0)
  S = Spot * jnp.ones(int(n))
  
  rng_vector = jax.random.split(seed, int(m))
  state = S
  input = rng_vector #1st dimension of input needs to have same amount of steps to loop in the function
  _, S = jax.lax.scan(step, state, input) #the S = 100 x 100000 matrix

  discount = jnp.exp(-r * Δt)

  # Very last date
  value_if_exercise = payoff_put(S[-1])
  discounted_future_cashflows = value_if_exercise * discount

  def core(state, input):
    Si, Θi = input #we take in Θ as input because it keeps changing for every time  when computing mse
    discounted_future_cashflows = state

    X = jnp.column_stack([Si**k for k in range(order)])
    # X = chebyshev_basis(scale(Si), order)
    Y = discounted_future_cashflows
    
    
    # Θ = jnp.linalg.solve(X.T @ X, X.T @ Y)  
    # above line sets up the loss function and solve it by minimizing it but we only want the mse
    
    #Below line says continuation value is features matrix * the initial guessed Θ values
    value_if_wait = X @ Θi # Theta becomes Θi compared to previous step. This doesnt change from step to step but each step needs a different Θ
    
    # here we are finding the mse but not solving(minimizing it)
    mse_i = jnp.mean((value_if_wait-discounted_future_cashflows)**2) 
    
  
    value_if_exercise = payoff_put(Si)
    exercise = value_if_exercise >= value_if_wait
    discounted_future_cashflows = discount * jnp.where(exercise, value_if_exercise, discounted_future_cashflows)

    return discounted_future_cashflows, mse_i #second output is a time series of mean square errors
    #we need to keep track of mse because we need minimize the sum and find mean of mse in our Loss function

  input = jnp.flip(S, 0)[1:], Θ #pass theta as an input to core function because we need to pass diff Θ for every time step
  state = discounted_future_cashflows
  discounted_future_cashflows, mse_i = jax.lax.scan(core, state, input) #second output would be a vector of mse

  # for i in range(m - 1):
  #   discounted_future_cashflows, _ = core(discounted_future_cashflows, input[i])  
  
  return discounted_future_cashflows.mean(), mse_i # mse_i is a vector containing 99 values of the mse for each time step except the last time stamp

print(compute_price(Θ)) #calculates the price using the guessed(wrong) 
# Θ

#each value in the 99 dimensional vector is the mse of the time stamp
# test = compute_price(order, Spot, σ, K, r)

# observe the price fell to 3.97 from 4.46 because we removed chebyshev polynomials

(DeviceArray(3.9540339, dtype=float32), DeviceArray([37.68034 , 37.37557 , 37.100048, 36.81555 , 36.519394,
             36.196728, 35.936237, 35.66297 , 35.36962 , 35.0462  ,
             34.7042  , 34.42286 , 34.073692, 33.78233 , 33.42189 ,
             33.036335, 32.716248, 32.342255, 31.92985 , 31.629984,
             31.293911, 31.001877, 30.602314, 30.206573, 29.85009 ,
             29.468176, 29.110407, 28.67713 , 28.275558, 27.83037 ,
             27.404428, 26.93171 , 26.48799 , 26.007463, 25.574371,
             25.080164, 24.610832, 24.116224, 23.58388 , 23.024042,
             22.501335, 21.925869, 21.309639, 20.723457, 20.11183 ,
             19.439226, 18.764383, 18.062632, 17.374859], dtype=float32))


In [None]:
@jax.jit #because we want it to run quick
def update_gradient_descent(Θ, opt_state): 

  def L(Θ): #setting up the Loss function after preparing our above code to get mse vector is good practice
    _, mse_i = compute_price(Θ)
    return mse_i.sum()

  grad = jax.grad(L)(Θ) #compute gradient of L at previous parameter Θ
  updates, opt_state = optimizer(α).update(grad, opt_state) # call optimizer at learning rate and give in the gradient and the previous optimizer state. 
                                                            # that function returns updates with new optimizer state
  
  Θ = optax.apply_updates(Θ, updates) #compute new parameter Θ
  return Θ, opt_state 

@jax.jit
def evaluate(Θ):
  p, _ = compute_price(Θ) #we can observe the jump is option price because the Theta values are more optimized
  return p

for iteration in range(1000): #change range to 10000 and see what happens
  Θ, opt_state = update_gradient_descent(Θ, opt_state)

  if iteration % 100 == 0:
    print(evaluate(Θ))

#Why this specific pattern in the output?

3.8474746
4.44126
4.1777544
4.2251863
4.2377887
4.242069
4.243798
4.2425575
4.2410126
4.2380733


Very slow becauae at each iteration we are computing 100000 price paths and we simulate 100 peers. and we waste alot of computational power to perform 1 step of gradient descent.

Instead of 100000 simulations, we can simulate 500 bu using stochastic gradient descent. Also we know in principal stochastic gradient descent works like gradient descent but much much faster

Lets see why chebyshev polunomials wont work in stochastic gradient descent

In [None]:
import jax
import jax.numpy as jnp
import optax

Spot = 36   # stock price
σ = 0.2     # stock volatility
K = 40      # strike price
r = 0.06    # risk free rate
n = 100000  # Number of simualted paths # we dont need so many paths if we are using stochastic gradient descent
m = 50      # number of exercise dates
batch_size = 512 #batch size
T = 1       # maturity
order = 5   # Polynmial order. Increase value and see what happens
Δt = T / m  # interval between two exercise dates

optimizer = optax.adam #we/re using optax to optimize gradient descent
α=1e-3 # learning rate

# Construct polynomial features of order up to k using the
# recursive formulation
def chebyshev_basis(x, k):
    B = [jnp.ones(len(x)), x]
    for n in range(2, k):
        Bn = 2 * x * B[n - 1] - B[n - 2]
        B.append(Bn)

    return jnp.column_stack(B)


# scales x to be in the interval(-1, 1)
def scale(x):
  xmin = 10 #x.min()
  xmax = 70 #x.max()
  a = 2 / (xmax - xmin)
  b = 1 - a * xmax
  return a * x + b


# simulates one step of the stock price evolution
def step(S, rng):
  e = jax.random.normal(rng, shape=S.shape)
  dB = jnp.sqrt(Δt) * e
  S = S + r * S * Δt + σ * S * dB
  return S, S

def payoff_put(S):
  return jnp.maximum(K - S, 0.)

#Step 1 of grad descent
Θ = jnp.zeros((m-1, order)) 
opt_state = optimizer(1.).init(Θ)

def model(Θi, Si):
  X = chebyshev_basis(scale(Si), order)
  value_if_wait = X @ Θi
  return value_if_wait


# @jax.jit
# LSMC algorithm
def compute_price(Θ, n, seed): #pass seed because it is SGD
  S = Spot * jnp.ones(n)
  
  rng_vector = jax.random.split(seed, int(m))
  state = S
  input = rng_vector 
  _, S = jax.lax.scan(step, state, input) 

  discount = jnp.exp(-r * Δt)

  value_if_exercise = payoff_put(S[-1])
  discounted_future_cashflows = value_if_exercise * discount

  def core(state, input):
    Si, Θi = input
    discounted_future_cashflows = state
    Y = discounted_future_cashflows

    #Below line says continuation value is features matrix * the initial guessed Θ values
    # value_if_wait = X @ Θi
    value_if_wait = model(Θi, Si)
    
    mse_i = jnp.mean((value_if_wait-discounted_future_cashflows)**2) 
    
    value_if_exercise = payoff_put(Si)
    exercise = value_if_exercise >= value_if_wait
    discounted_future_cashflows = discount * jnp.where(exercise, value_if_exercise, discounted_future_cashflows)

    return discounted_future_cashflows, mse_i
  
  input = jnp.flip(S, 0)[1:], Θ
  state = discounted_future_cashflows
  discounted_future_cashflows, mse_i = jax.lax.scan(core, state, input)

  return discounted_future_cashflows.mean(), mse_i

In [None]:
@jax.jit
def update_gradient_descent(Θ, opt_state, seed): 

  seed, _ = jax.random.split(seed)

  def L(Θ):
    _, mse_i = compute_price(Θ, batch_size, seed)
    return mse_i.sum()

  grad = jax.grad(L)(Θ)
  updates, opt_state = optimizer(α).update(grad, opt_state) 
                                                            
  
  Θ = optax.apply_updates(Θ, updates)
  return Θ, opt_state, seed 

@jax.jit
def evaluate(Θ):
  seed = jax.random.PRNGKey(0) #doesnt matter if we use the same seed for evaluation
  p, _ = compute_price(Θ, 200000, seed)
  return p

seed = jax.random.PRNGKey(0)

for iteration in range(1000000):
  Θ, opt_state, seed = update_gradient_descent(Θ, opt_state, seed)

  if iteration % 100 == 0:
    print(evaluate(Θ))

# We can observe the predicted values gets better

3.9511049
3.951169
3.9513664
3.952066
3.9534402
3.9558983
3.9602818
3.9675856
3.9780893
3.9932992
4.0120234
4.034157
4.05979
4.086995
4.115958
4.14497
4.170677
4.1953187
4.217897
4.238947
4.257911
4.2746468
4.2887616
4.3000746
4.310746
4.318579
4.3243866
4.3289695
4.3338013
4.3364935
4.3383493
4.3393617
4.341795
4.342681
4.3453813
4.3457966
4.3456273
4.347289
4.348111
4.3473597
4.348216
4.347362
4.347083
4.344932
4.3432746
4.343918
4.343767
4.3406205
4.3389845
4.336547
4.335172
4.3335967
4.3317304
4.328073
4.3243384
4.322855
4.3195524
4.3152437
4.311239
4.3100977
4.3060117
4.3020735
4.295375
4.293712
4.29027
4.2847056
4.279872
4.2785397
4.270176
4.2672606
4.2636366
4.2595825
4.255797
4.2507925
4.245539
4.242451
4.239733
4.2311397
4.234604
4.226414
4.2255254
4.2251186
4.224458
4.217593
4.22411
4.2217035
4.223073
4.2219205
4.2216644
4.2205257
4.220975
4.220631
4.221236
4.220102
4.2212844
4.2224836
4.2217298
4.222287
4.222148
4.221369
4.226428
4.227549
4.2298293
4.2306657
4.2279997
4.2315

KeyboardInterrupt: ignored

Skip this step for now, we figured chebyshev works with higher number of iterations

4. Jax Stochastic gradient descent (normal polynomials)

We use regular polynomialls because .
But in stochastic gradient descent minimum and maximum are gonna keep changing at each iteration

Each time we call compute_price, we are gonna be sampling different shocks. If we use the code the way it is
Different simulations will have different Si. At each step chebyshev scales the features at the maximum or minimum and this needs to be done in order to use chebyshev. If this scaling depends on our sample, it means the scaling is changing from iteration to iteration. But the scaling should be the same throughout trading. So it wont work. 

In [None]:
import jax
import jax.numpy as jnp
import optax

Spot = 36   # stock price
σ = 0.2     # stock volatility
K = 40      # strike price
r = 0.06    # risk free rate
n = 100000  # Number of simualted paths # we dont need so many paths if we are using stochastic gradient descent
m = 100      # number of exercise 
batch_size = 512
T = 1       # maturity
order = 5   # Polynmial order
Δt = T / m  # interval between two exercise dates
optimizer = optax.adam #we/re using optax to optimize gradient descent


# scales x to be in the interval(-1, 1)
def scale(x):
    xmin = x.min()
    xmax = x.max()
    a = 2 / (xmax - xmin)
    b = 1 - a * xmax
    return a * x + b


# simulates one step of the stock price evolution
def step(S, rng):
  e = jax.random.normal(rng, shape=S.shape)
  dB = jnp.sqrt(Δt) * e
  S = S + r * S * Δt + σ * S * dB
  return S, S

def payoff_put(S):
  return jnp.maximum(K - S, 0.)


#Step 1 of grad descent
Θ = jnp.zeros((m-1, order)) # Theta is 99x5 filled with zeroes. This is our initial guess for theta
opt_state = optimizer(1.).init(Θ)

def mse(predictions, label):
  return jnp.mean((predictions-label)**2)         


seed = jax.random.PRNGKey(0) #we remove it outside from the compute_price function and pass it as input


# @jax.jit
# LSMC algorithm
def compute_price(Θ, n, seed): # give size of simulation as input, because we're using batch_size 512 to compute loss function but we're using 100000 simulations to evaluate the quality of the output
  S = Spot * jnp.ones(int(n))
  rng_vector = jax.random.split(seed, int(m))
  state = S
  input = rng_vector #1st dimension of input needs to have same amount of steps to loop in the function
  _, S = jax.lax.scan(step, state, input) #the S = 100 x 100000 matrix

  discount = jnp.exp(-r * Δt)

  # Very last date
  value_if_exercise = payoff_put(S[-1])
  discounted_future_cashflows = value_if_exercise * discount  


  def core(state, input):
    Si, Θi = input #we take in Θ as input because it keeps changing for every time step
    discounted_future_cashflows = state

    X = jnp.column_stack([Si**k for k in range(order)])
    # X = chebyshev_basis(scale(Si), order)
    Y = discounted_future_cashflows
    # Θ = jnp.linalg.solve(X.T @ X, X.T @ Y) # not gonna compute regression
    value_if_wait = X @ Θi # Theta becomes Theta i

    mse_i = mse(value_if_wait, discounted_future_cashflows)

    value_if_exercise = payoff_put(Si)
    exercise = value_if_exercise >= value_if_wait
    discounted_future_cashflows = discount * jnp.where(exercise, value_if_exercise, discounted_future_cashflows)

    return discounted_future_cashflows, mse_i #we need to keep track of mse because we need minimize the sum and find mean of mse

  input = jnp.flip(S[:-1], 0), Θ #pass theta as an input to core function because we need to pass diff Θ for every time step
  state = discounted_future_cashflows
  discounted_future_cashflows, mse_i = jax.lax.scan(core, state, input) #second output would be a vector of mse

  # for i in range(m - 1):
  #   discounted_future_cashflows, _ = core(discounted_future_cashflows, input[i])  
  
  return discounted_future_cashflows.mean(), mse_i

# print(compute_price(Θ, n, seed)) #calculates the price using the guessed 
#each value in the 99 dimensional vector is the mse of the time stamp
# test = compute_price(order, Spot, σ, K, r)


(DeviceArray(3.9780037, dtype=float32), DeviceArray([37.56345 , 37.446575, 37.298405, 37.1356  , 36.979633,
             36.826267, 36.686455, 36.54746 , 36.394558, 36.257782,
             36.110565, 35.97117 , 35.787395, 35.635506, 35.44384 ,
             35.30973 , 35.17018 , 35.002453, 34.83765 , 34.706955,
             34.56372 , 34.382797, 34.222466, 34.06043 , 33.905365,
             33.803722, 33.628613, 33.461033, 33.312878, 33.162132,
             33.00243 , 32.852543, 32.698105, 32.52248 , 32.31892 ,
             32.16418 , 31.971344, 31.765043, 31.56413 , 31.387058,
             31.221035, 31.062067, 30.920769, 30.766369, 30.558311,
             30.365704, 30.179184, 30.008032, 29.82569 , 29.63478 ,
             29.436325, 29.24955 , 29.032795, 28.83644 , 28.643064,
             28.453009, 28.260712, 28.02322 , 27.795649, 27.562605,
             27.350769, 27.146765, 26.931189, 26.725199, 26.524939,
             26.291368, 26.039965, 25.834959, 25.605137, 25.388474,
        

In [None]:
import jax

def L(Θ, seed):
  price, mse_i = compute_price(Θ, batch_size, seed) #we're passing seed as input because each time we simulate the stock price we want a different random seed
  return mse_i.sum()

α=1e-5 # learning rate

@jax.jit
def update(Θ, opt_state, seed):
  seed, _ = jax.random.split(seed) #split to make sure because we're using different seed when we're using it
  grad = jax.grad(L)(Θ, seed)
  updates, opt_state = optimizer(α).update(grad, opt_state, seed) # provide computed gradient and the previous optimizer state
  Θ = optax.apply_updates(Θ, updates)
  return Θ, opt_state, seed

In [None]:
for iteration in range(1000): #change range to 10000 and see what happens
  Θ, opt_state, seed = update(Θ, opt_state, seed)

In [None]:
compute_price(Θ, n, seed) #we can observe the jump is option price because the Theta values are more optimized

(DeviceArray(3.9047234, dtype=float32),
 DeviceArray([25.9571    , 26.03179   , 28.30797   , 26.487484  ,
              33.98196   , 23.261839  , 18.357145  , 20.911139  ,
              18.861992  , 23.744764  , 37.110355  , 33.16365   ,
              34.124043  , 32.42057   , 34.113705  , 32.352974  ,
              23.001568  , 22.787683  , 18.880028  , 17.8579    ,
              17.86499   , 23.699694  , 32.010426  , 26.327065  ,
              28.527948  , 21.600445  , 23.341635  , 24.439177  ,
              31.058313  , 31.720592  , 30.588894  , 27.225357  ,
              25.545197  , 17.32303   , 15.062975  , 14.587344  ,
              14.089913  , 14.731817  , 14.511595  , 14.353942  ,
              13.672367  , 13.670407  , 14.524605  , 13.771536  ,
              13.72872   , 14.093289  , 11.748952  , 13.548732  ,
              11.669214  , 11.593536  , 13.253572  , 12.508336  ,
              11.367655  , 11.479922  , 11.700112  , 11.503025  ,
              10.911019  , 10.710645

5. Jax stochastic gradient descent neural network

Why do we need neural networks? Because payoff can depend on a basket of stocks rather than only a single stock price

we move from a 1D problem to a 3D problem

First we make the NN model for 1 stock

In [23]:
# pip install optax dm-haiku

import jax
import jax.numpy as jnp
import optax
import haiku as hk

Spot = 36   # stock price
σ = 0.2     # stock volatility
K = 40      # strike price
r = 0.06    # risk free rate
n = 100000  # Number of simualted paths # we dont need so many paths if we are using stochastic gradient descent
m = 100      # number of exercise dates
batch_size = 512 #batch size 
#At each gd update we are sampling 512 data points to estimate the grade. In theory we need to use an infinite amount of data points

T = 1       # maturity
order = 5   # Polynmial order. Increase value and see what happens
Δt = T / m  # interval between two exercise dates

optimizer = optax.adam #we/re using optax to optimize gradient descent
α=1e-3 # learning rate

# We remove chebyshev because we arent using it

# simulates one step of the stock price evolution
def step(S, rng):
  e = jax.random.normal(rng, shape=S.shape)
  dB = jnp.sqrt(Δt) * e
  S = S + r * S * Δt + σ * S * dB
  return S, S

def payoff_put(S): # S is a vector with different simulations for 1 time step
                   # 100000 elements if gd or 512 if step
  return jnp.maximum(K - S, 0.)


def model(Si):
  # out = jnp.column_stack([Si])
  out = (jnp.column_stack([Si]) - 37)/5  # S is vector and haiku expects to be working on matrices. Sο S gets transformed to a matrix
                                         # 37 is the mean and 5 is the std deviation
                                         # we are normalizing the data
                                         # see what happens to the predictions in the end if we dont normalize the data

  out = hk.Linear(32)(out)
  out = jax.nn.relu(out)
  # out = jax.nn.silu(out) #this activation function doesnt have a problem when u dont normalize the data

  out = hk.Linear(32)(out)
  out = jax.nn.relu(out)
  # out = jax.nn.silu(out)

  out = hk.Linear(1)(out) # 512 x 1 matrix

  return jnp.squeeze(out) # squeeze coverts to a vector

init, model = hk.without_apply_rng(hk.transform(model))

In [24]:
seed = jax.random.PRNGKey(0)
Θ = init(seed, jnp.array(1.))
opt_state = optimizer(α).init(Θ)


def stack(Θ):
  return jnp.stack([Θ] * (m-1))  #stacking theta m - 1 times

Θ = stack(Θ)


# Code wont work because stack only works on numpy arrays. Theta is not a Numpy Array. 
# What is Theta?

TypeError: ignored

In [73]:
type(Θ)

dict

In [74]:
Θ.keys()
# Below is because we are doing 3 transformations in corresponding order

dict_keys(['linear', 'linear_1', 'linear_2'])

In [75]:
Θ['linear']
# 2 keys: w for weight; b for bias
# w is a 1 x 32 matrix. (32 columns)
# b is a vector with 32 elements

{'w': DeviceArray([[ 0.3043298 , -0.20807579,  1.5253608 , -0.33471823,
               -0.819684  ,  0.07262449,  0.75095516,  0.39394712,
                1.1008915 , -0.08661517,  0.37275982,  0.62971044,
               -1.2111528 ,  0.37929872,  0.16987085, -0.01042713,
               -0.02237888, -0.25954646,  0.37912896,  0.35421863,
                1.0106694 ,  0.49792686,  0.9558353 ,  1.2206774 ,
               -1.1474493 ,  0.19136466, -0.3144515 ,  0.77004355,
               -0.05240701,  0.6409038 ,  1.7368565 ,  0.2765501 ]],            dtype=float32),
 'b': DeviceArray([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
              0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
              0., 0.], dtype=float32)}

Correct way to stack below

In [25]:
# Ideally we want to create a new theta that is a new dict of dict, 
# where every leaf will be a copy of the corresponding leaf of the original theta

seed = jax.random.PRNGKey(0)
Θ = init(seed, jnp.array(1.))
opt_state = optimizer(α).init(Θ)


def stack(Θ):
  return jnp.stack([Θ] * (m-1))  #stacking theta m - 1 times

# Θ = stack(Θ)
Θ = jax.tree_map(stack, Θ)
opt_state = optimizer(α).init(Θ)

In [77]:
type(Θ)

dict

In [78]:
Θ.keys()

dict_keys(['linear', 'linear_1', 'linear_2'])

In [79]:
Θ['linear'].keys()

dict_keys(['b', 'w'])

In [80]:
Θ['linear']['w'].shape
# Now theta represents a collection(row stacked 99 times) of no_of_steps neural networks

(99, 1, 32)

In [81]:
Θ['linear']['b'].shape
# We can observe the first dimension to be 99

(99, 32)

In [26]:
#Step 1 of grad descent
# Θ = jnp.zeros((m-1, order)) 
# opt_state = optimizer(α).init(Θ)


# @jax.jit
# LSMC algorithm
def compute_price(Θ, n, seed): #pass seed because it is SGD
  S = Spot * jnp.ones(n)
  
  rng_vector = jax.random.split(seed, int(m))
  state = S
  input = rng_vector 
  _, S = jax.lax.scan(step, state, input) 

  discount = jnp.exp(-r * Δt)

  value_if_exercise = payoff_put(S[-1])
  discounted_future_cashflows = value_if_exercise * discount

  def core(state, input):
    Si, Θi = input
    discounted_future_cashflows = state
    Y = discounted_future_cashflows

    value_if_wait = model(Θi, Si) # using theta to calculate continouation value, 
                                  # ie how much dcf we expect to get given how much cashflow we'll get in the future
                                  # if we find wrong value_if_wait, we will excercise when arent supposed to or vice versa
                                  # actions are not optimal
    
    mse_i = jnp.mean((value_if_wait-discounted_future_cashflows)**2) 
    
    value_if_exercise = payoff_put(Si)
    exercise = value_if_exercise >= value_if_wait # decide if we want to excercise or not
    
    discounted_future_cashflows = discount * jnp.where(exercise, value_if_exercise, discounted_future_cashflows)

    return discounted_future_cashflows, mse_i
  
  input = jnp.flip(S, 0)[1:], Θ
  state = discounted_future_cashflows
  discounted_future_cashflows, mse_i = jax.lax.scan(core, state, input)

  return discounted_future_cashflows.mean(), mse_i

In [27]:
@jax.jit
def update_gradient_descent(Θ, opt_state, seed): 

  seed, _ = jax.random.split(seed)

  def L(Θ):
    _, mse_i = compute_price(Θ, batch_size, seed)
    return mse_i.sum()

  grad = jax.grad(L)(Θ)
  updates, opt_state = optimizer(α).update(grad, opt_state) 
                                                            
  
  Θ = optax.apply_updates(Θ, updates)
  return Θ, opt_state, seed 

@jax.jit
def evaluate(Θ):
  seed = jax.random.PRNGKey(0) #doesnt matter if we use the same seed for evaluation
  p, _ = compute_price(Θ, 200000, seed)
  return p

In [28]:
seed = jax.random.PRNGKey(0)

for iteration in range(1000000):
  Θ, opt_state, seed = update_gradient_descent(Θ, opt_state, seed)

  if iteration % 100 == 0:
    print(evaluate(Θ))

# We improved found the model by normalizing the data

3.9746199
4.220788
4.2116103
4.236565
4.2453923
4.227541
4.262583
4.2285404
4.2582145
4.2047067
4.219553
4.2235665
4.215366
4.1404123
4.159133
4.204374
4.1307473
4.1960907
4.145729
4.2015853
4.1845417
4.1999936
4.1738873
4.127784
4.055139
4.16718
4.1879697
4.1982393
4.2164335
4.1763687
4.1564307
4.134822
4.1399903
4.1595397
4.2170157
4.2179804
4.1733956
4.188957
4.0413756
4.1784067
4.0953817
4.192585
4.1882777
4.232761
4.2434406
4.1993184
4.2331095
4.165975
4.164168
4.1043935
4.1734076
4.24499
4.1056356
4.1809583
4.1289825
4.254113
4.1966076
4.1711965
4.204183
4.077337
4.112798
4.164596
4.0749545
4.1415453
4.191174
4.236075
4.168531


KeyboardInterrupt: ignored

Now we need to modify our step function and payoff function to handle more than 1 stock

In [10]:
# pip install optax dm-haiku

import jax
import jax.numpy as jnp
import optax
import haiku as hk

Spot = jnp.array([38, 36, 35])   # stock , no longer scalar, it is a vector
σ = jnp.array([.2, .25, 0.3])    # we can give volatility for each stock     
K = 40      
r = 0.06    
n = 512  
m = 100     

T = 1       
order = 5   
Δt = T / m  

optimizer = optax.adam
α=1e-3

# no we need to change this function to take in multiple stock values
def step(S, rng):
  e = jax.random.normal(rng, shape=S.shape) # we dont need to change this line because we wrote it in an agnostic way.
                                            # If S is a matrix, e will be a matrix
  dB = jnp.sqrt(Δt) * e
  S = S + r * S * Δt + σ * S * dB
  return S, S

# we need to change the payoff
def payoff_put(S):
  return jnp.maximum(K - jnp.max(S, axis=1), 0.) #axis is 1 because we ned the sum of column. reason is in lecture notes written in the end

def model(Si):
  out = (Si - 37)/5.  # we remove column stack because S is a matrix 
                      # we can scale each row with a respective mean and td deviation?
  out = hk.Linear(32)(out)
  out = jax.nn.relu(out)

  out = hk.Linear(32)(out)
  out = jax.nn.relu(out)

  out = hk.Linear(1)(out) #512 x 1 matrix

  return jnp.squeeze(out)

init, model = hk.without_apply_rng(hk.transform(model))

In [13]:
seed = jax.random.PRNGKey(0)
Θ = init(seed, jnp.zeros((512, 3))) #we need to give a sample of your inputs 
#why do we change it like this?

opt_state = optimizer(α).init(Θ)

def stack(Θ):
  return jnp.stack([Θ] * (m-1))  #stacking theta m - 1 times

Θ = jax.tree_map(stack, Θ)
opt_state = optimizer(α).init(Θ)

In [14]:
# n = 512
S = jnp.column_stack([Spot[i] * jnp.ones(n) for i in range(3)])
S.shape

(512, 3)

In [15]:
# @jax.jit
# LSMC algorithm
def compute_price(Θ, n, seed): #pass seed because it is SGD
  # S = Spot * jnp.ones(n, 3) # S is a matrix with different simulations in 1 time step for 3 stocks 
  S = jnp.column_stack([Spot[i] * jnp.ones(n) for i in range(3)])
  
  rng_vector = jax.random.split(seed, int(m))
  state = S
  input = rng_vector 
  _, S = jax.lax.scan(step, state, input) 

  discount = jnp.exp(-r * Δt)

  value_if_exercise = payoff_put(S[-1])
  discounted_future_cashflows = value_if_exercise * discount

  def core(state, input):
    Si, Θi = input
    discounted_future_cashflows = state
    Y = discounted_future_cashflows

    value_if_wait = model(Θi, Si)
    
    mse_i = jnp.mean((value_if_wait-discounted_future_cashflows)**2) 
    
    value_if_exercise = payoff_put(Si)
    exercise = value_if_exercise >= value_if_wait
    
    discounted_future_cashflows = discount * jnp.where(exercise, value_if_exercise, discounted_future_cashflows)

    return discounted_future_cashflows, mse_i
  
  input = jnp.flip(S, 0)[1:], Θ
  state = discounted_future_cashflows
  discounted_future_cashflows, mse_i = jax.lax.scan(core, state, input)

  return discounted_future_cashflows.mean(), mse_i

In [16]:
@jax.jit
def update_gradient_descent(Θ, opt_state, seed): 

  seed, _ = jax.random.split(seed)

  def L(Θ):
    _, mse_i = compute_price(Θ, batch_size, seed)
    return mse_i.sum()

  grad = jax.grad(L)(Θ)
  updates, opt_state = optimizer(α).update(grad, opt_state) 
                                                            
  
  Θ = optax.apply_updates(Θ, updates)
  return Θ, opt_state, seed 

@jax.jit
def evaluate(Θ):
  seed = jax.random.PRNGKey(0) #doesnt matter if we use the same seed for evaluation
  p, _ = compute_price(Θ, 200000, seed)
  return p

In [17]:
seed = jax.random.PRNGKey(0)

for iteration in range(1000000):
  Θ, opt_state, seed = update_gradient_descent(Θ, opt_state, seed)

  if iteration % 100 == 0:
    print(evaluate(Θ))

# Correct answer is 2.25

1.9517
2.1202378
2.1510992
2.206948
2.2358725
2.2463207
2.2508252
2.2458668
2.2445405
2.2517066
2.2541823
2.2480638
2.2572012
2.2468977
2.2545292
2.2505789
2.2537115
2.2479827
2.2452836
2.2520442
2.2537427
2.2495224
2.258139
2.2514815
2.2507145
2.2499056
2.2546976
2.2589822
2.2518132
2.2558422
2.256078
2.2534692
2.2586975
2.2520983
2.2599974
2.2616544
2.2565858
2.2609622
2.2493684
2.25171
2.257472
2.2449715
2.2563703
2.2483518
2.2533047
2.2505329
2.2511134
2.257806
2.258041
2.2498991
2.2564812
2.256799
2.2604113
2.2494233
2.2538314
2.2359686
2.2526343
2.2567616
2.2539997
2.2525725
2.2586255
2.257685
2.251334
2.2598798
2.2527988
2.2535276
2.2615912
2.261707
2.2576766
2.2545316
2.2591352
2.2526622
2.2577872
2.2570925
2.253807
2.250126
2.2581263
2.2602859
2.2625966
2.2501795
2.2544734
2.2573736
2.2579818
2.253785
2.2643669
2.2611625
2.2545593
2.2496173
2.2473187
2.2593272


KeyboardInterrupt: ignored

In normal equations, we scale because each varaible will have its own scale.

In NN the 