In [None]:
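# In this notebook we learn how to write loops with jax.lax.scan.
# Why scan: under jax.jit a Python for loop is unrolled while tracing, so compile time grows with the
# number of iterations; jax.lax.scan traces the loop body only once, so it compiles much faster and works well with jax.jit.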

In [None]:
import jax
import jax.numpy as jnp

In [None]:
#Example 1
@jax.jit
def loop():
  """Example 1: add 1 to a float accumulator twenty times with a plain Python loop.

  The Python loop runs while jax.jit traces the function. Every value here is a
  Python constant, so the compiled function simply returns 21.
  """
  total = 1.
  for _ in range(20):
    total += 1
  return total

In [None]:
# Run the for-loop version of Example 1 defined above (1 + 20 increments -> 21.)
loop()

DeviceArray(21., dtype=float32, weak_type=True)

In [None]:
# With scan, whatever is inside the for loop must be written as a function core(state, inputs) -> (new_state, per_step_output)

@jax.jit
def core(state, inputs):
  """Scan body for Example 1: increment the carry and emit its square and cube.

  Args:
    state: the carry, i.e. the value that changes from iteration to iteration
      (the loop variable ``x`` of the for-loop version).
    inputs: this iteration's slice of the ``xs`` argument passed to
      ``jax.lax.scan``. It is unused here, but a real loop body can use it.

  Returns:
    A pair ``(new_state, per_step_output)``.

    - ``new_state`` is carried into the next iteration. After the last
      iteration it is the final answer of the loop.
    - ``per_step_output`` is produced at every iteration and scan stacks it
      across iterations. It can be ``None`` (nothing collected), a single value
      such as ``x**2``, an array such as ``jnp.array([x**2, x**3])``, or a dict
      as here. (Collecting per-step values like this is the pattern referenced
      for HW 2, American put option LSMC.)
  """
  x = state + 1
  x_square = x**2
  return x, {'x_squared': x_square, 'x_cubed': x**3}

def loop(n_steps=20, x0=1):
  """Example 1 rewritten with jax.lax.scan, using ``core`` as the loop body.

  Args:
    n_steps: number of iterations, i.e. the length of the scanned axis. The
      default of 20 reproduces the original example.
    x0: initial carry. The default integer ``1`` keeps the carry an integer
      (int32 in the recorded output). With an integer carry, ``x**3`` inside
      ``core`` overflows once ``x`` passes about 1290, so pass ``x0=1.`` for
      float arithmetic that matches the for-loop version (which starts at ``1.``).

  Returns:
    ``(state, out)``:

    - ``state`` is the final carry (``x0 + n_steps``).
    - ``out`` is the per-step dict returned by ``core``, with each value
      stacked along a leading axis of length ``n_steps``.
  """
  # scan's third argument ``xs`` is sliced along its first axis, one slice per
  # iteration, and each slice is passed to ``core`` as ``inputs``. Its leading
  # dimension therefore sets the number of iterations; it could be any array or
  # pytree. ``core`` ignores it in this example, so a vector of ones serves only
  # as a length marker.
  xs = jnp.ones(n_steps)
  state, out = jax.lax.scan(core, x0, xs)
  return state, out


In [None]:
# Run the scan version of Example 1: returns (final carry, dict of per-step 'x_squared' / 'x_cubed' arrays)
loop()

(DeviceArray(21, dtype=int32, weak_type=True),
 {'x_cubed': DeviceArray([   8,   27,   64,  125,  216,  343,  512,  729, 1000, 1331,
               1728, 2197, 2744, 3375, 4096, 4913, 5832, 6859, 8000, 9261],            dtype=int32, weak_type=True),
  'x_squared': DeviceArray([  4,   9,  16,  25,  36,  49,  64,  81, 100, 121, 144, 169,
               196, 225, 256, 289, 324, 361, 400, 441],            dtype=int32, weak_type=True)})

In [None]:
#example 2

# Fixed seed key so the random examples below are reproducible
rng = jax.random.PRNGKey(0)

@jax.jit
def loop(rng):
  """Example 2: accumulate 20 random increments with a plain Python loop.

  At step ``i`` the increment is the standard deviation of 100 standard-normal
  draws scaled by ``i``.

  Args:
    rng: a JAX PRNG key.

  Returns:
    The final value of ``x`` as a 0-d float array.
  """
  x = 1.
  for i in range(20):
    # JAX random numbers are pseudo-random: the same key always yields the same
    # draws, so a key must not be reused. Split off a fresh subkey for this
    # step's draw and carry the other key forward to the next iteration.
    rng, subkey = jax.random.split(rng)
    x = x + (jax.random.normal(subkey, shape=[100]) * i).std()
  return x
   

In [None]:
# Run the for-loop version of Example 2 with the fixed seed key
loop(rng)

DeviceArray(1018202.75, dtype=float32)

In [None]:
@jax.jit
def core(state, input):
  """Scan body for Example 2: one iteration of the for loop above.

  Args:
    state: the carry ``x``, i.e. the value that changes from iteration to
      iteration.
    input: this iteration's slice of scan's ``xs``, a ``(key, i)`` pair holding
      a PRNG key and the current loop index. These are needed to perform the
      step but are not part of the state.

  Returns:
    ``(x, x)``: the new carry, plus the same value as the per-step output.
    Returning ``x`` instead of ``None`` makes scan collect the whole path of
    ``x``; this is how a full time series (e.g. simulated stock prices in HW2)
    is kept.
  """
  rng, i = input
  # Same increment as the for-loop body: the noise is scaled by the loop index
  # ``i``, not by the carry.
  x = state + (jax.random.normal(rng, shape=[100]) * i).std()
  return x, x

def loop(rng, n_steps=20):
  """Example 2 rewritten with jax.lax.scan, using ``core`` as the loop body.

  Args:
    rng: a JAX PRNG key. It is split into one key per iteration.
    n_steps: number of iterations. The default of 20 reproduces the original
      example.

  Returns:
    ``(state, out)``: the final ``x``, and the value of ``x`` after every
    iteration stacked into an array of length ``n_steps``.
  """
  x = 1.
  # Derive the per-step keys from the argument instead of a global defined in a
  # later cell, so the function works on a fresh kernel and honours ``rng``.
  rng_vector = jax.random.split(rng, n_steps)
  # ``xs`` can be any pytree (here a tuple of arrays). Every leaf must have
  # leading dimension ``n_steps``: scan slices them together, one slice per step.
  inputs = rng_vector, jnp.arange(n_steps)
  state, out = jax.lax.scan(core, x, inputs)
  return state, out

In [None]:
# The loop indices for 20 iterations: i goes from 0 to 19, one value per scan step
jnp.arange(20)

DeviceArray([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14,
             15, 16, 17, 18, 19], dtype=int32)

In [None]:
# Build a vector of 20 PRNG keys, one per loop iteration, split from the seed key `rng`
rng_vector = jax.random.split(rng, 20)


In [None]:
# Run the scan version of Example 2. The result differs from the for-loop version at least because the
# per-step keys come from jax.random.split(rng, 20) rather than from repeated sequential splitting.
loop(rng)

(DeviceArray(769098.06, dtype=float32),
 DeviceArray([2.05442095e+00, 4.27148390e+00, 8.54490471e+00,
              1.60934753e+01, 3.27579803e+01, 6.48170319e+01,
              1.30230499e+02, 2.60836548e+02, 5.14093262e+02,
              1.01622534e+03, 1.92730469e+03, 3.69253613e+03,
              7.19437305e+03, 1.34246914e+04, 2.58675879e+04,
              4.76994883e+04, 9.47989922e+04, 1.86214312e+05,
              3.81152062e+05, 7.69098062e+05], dtype=float32))