In [1]:
pip install -q optax dm-haiku

[K     |████████████████████████████████| 154 kB 31.1 MB/s 
[K     |████████████████████████████████| 352 kB 62.3 MB/s 
[K     |████████████████████████████████| 85 kB 4.7 MB/s 
[?25h

Due Date: November 30
# Problem statement

The code below is similar to the Cake Eating problem code we implemented in class. The differences are:
- Each time interval corresponds to one year (instead of one month)
- The consumption policy function is written as a simple sigle-layer neural network, with tanh activation (instead of the usual relu)

We will interpret the size of the cake as being total wealth, and cake consumption as general consumption. The fraction of wealth not consumed today are the *savings* (line 51). The dynamics of wealth are described by line 54. That line is equivalent to assuming that your savings are invested in a risk-free savings account that pays 0 interest, and therefore has a gross return of 1, denoted by *R* (line 53).




In [None]:
import jax
import jax.numpy as jnp
import optax
import haiku as hk


γ = 2.
β = 0.95


def U(c):
    return c**(1 - γ) / (1 - γ)


optimizer = optax.adam
lr = 1e-3
T = 50


def nnet(x):
  X = jnp.column_stack([x])
  X = hk.Linear(32)(X)
  X = jnp.tanh(X)
  X = hk.Linear(1)(X)
  X = jnp.squeeze(X)
  return X


init, nnet = hk.without_apply_rng(hk.transform(nnet))
rng = jax.random.PRNGKey(0)
Θ = init(rng, jnp.array(1.))


opt_state = optimizer(lr).init(Θ)


def L(Θ):

  x = 1.
  G = 0.

  state = x
  inputs = jnp.arange(T)

  def core(state, inputs):
    t = inputs
    xt = state

    ct = jax.nn.sigmoid(nnet(Θ, xt) - 4.) * xt
    ut = U(ct)
    savings = xt - ct

    R = 1.
    x_tp1 = R * savings # X at next time step is worth R * savings

    discounted_utility = β**t * ut
    return x_tp1, discounted_utility

  x, discounted_utility = jax.lax.scan(core, state, inputs)
  G = discounted_utility.sum()
  return -G


@jax.jit
def evaluation(Θ):
  return -L(Θ)


@jax.jit
def update_gradient_descent(Θ, opt_state):
  grad = jax.grad(L)(Θ)
  updates, opt_state = optimizer(lr).update(grad, opt_state)
  Θ = optax.apply_updates(Θ, updates)
  return Θ, opt_state


for iteration in range(100000):
  Θ, opt_state = update_gradient_descent(Θ, opt_state)

  if iteration % 10000 == 0:
    print(evaluation(Θ))

-1301.5056
-823.7884


KeyboardInterrupt: ignored

Suppose now that your savings are fully invested in the stock market, so the evolution of wealth is now stochastic. The stock market gross return is modeled by the function below:

In [None]:
rng = jax.random.PRNGKey(1)

def stock_return(rng):
  μs = 0.06
  σs = 0.2
  ε = jax.random.normal(rng, ())
  log_return = μs + σs * ε
  # print(jnp.exp(log_return))
  return jnp.exp(log_return)

print(stock_return(rng))

0.8378997



First, let's try and use jax.lax.scan to generate random stock returns.

In [None]:
def stock_return(rng):
  
  μs = 0.06
  σs = 0.2
  ε = jax.random.normal(rng, ())
  log_return = μs + σs * ε
  return jnp.exp(log_return)

In [None]:
seed = jax.random.PRNGKey(1)
T = 50

def core(state, input):
  W0 = state
  rng = input
  stock_r = stock_return(rng) * W0
  return state, stock_r

rng_vector = jax.random.split(seed, T) #m is the length of vector
state = 1 #starts with 100%
input = rng_vector

state, out = jax.lax.scan(core, state, input)

In [None]:
out
# output when seed = 0

DeviceArray([0.84260744, 1.1804781 , 1.2274637 , 0.95285577, 0.78522426,
             1.1539401 , 1.0861521 , 1.3128853 , 0.9442702 , 1.0119264 ,
             1.0058765 , 0.96876556, 1.309265  , 1.139243  , 0.9535717 ,
             0.84595954, 0.8710616 , 1.0380025 , 1.3171368 , 0.81579727,
             0.9257629 , 0.9616835 , 1.2427455 , 1.3873675 , 1.0009146 ,
             1.3997176 , 0.7696073 , 1.1823578 , 0.96337336, 1.3092004 ,
             1.125602  , 1.1189346 , 1.0482037 , 0.8010677 , 1.2646732 ,
             0.9692891 , 0.7821357 , 1.128498  , 1.3679446 , 1.002268  ,
             1.3832148 , 1.2221909 , 0.7562105 , 1.1158998 , 1.1951793 ,
             1.0257734 , 1.1197132 , 1.091551  , 1.4884138 , 1.1023614 ],            dtype=float32)

In [None]:
out
# output when seed = 1

# we have observed that we need to change the seed everytime we run a trial to get stochastic returns for all trials

DeviceArray([0.84260744, 1.1804781 , 1.2274637 , 0.95285577, 0.78522426,
             1.1539401 , 1.0861521 , 1.3128853 , 0.9442702 , 1.0119264 ,
             1.0058765 , 0.96876556, 1.309265  , 1.139243  , 0.9535717 ,
             0.84595954, 0.8710616 , 1.0380025 , 1.3171368 , 0.81579727,
             0.9257629 , 0.9616835 , 1.2427455 , 1.3873675 , 1.0009146 ,
             1.3997176 , 0.7696073 , 1.1823578 , 0.96337336, 1.3092004 ,
             1.125602  , 1.1189346 , 1.0482037 , 0.8010677 , 1.2646732 ,
             0.9692891 , 0.7821357 , 1.128498  , 1.3679446 , 1.002268  ,
             1.3832148 , 1.2221909 , 0.7562105 , 1.1158998 , 1.1951793 ,
             1.0257734 , 1.1197132 , 1.091551  , 1.4884138 , 1.1023614 ],            dtype=float32)

Write a code to solve for the optimal consumption policy in this environment. 
What is the expceted sum of discounted rewards (value function) resulting from that policy? Use at least 1 million sample paths to estimate that number.

Solution below

In [4]:
import jax
import jax.numpy as jnp
import optax
import haiku as hk


γ = 2.
β = 0.95


def U(c):
    return c**(1 - γ) / (1 - γ)


optimizer = optax.adam
lr = 1e-3
T = 50


def nnet(x):
  X = jnp.column_stack([x])
  X = hk.Linear(32)(X)
  X = jnp.tanh(X)
  X = hk.Linear(1)(X)
  X = jnp.squeeze(X)
  return X

init, nnet = hk.without_apply_rng(hk.transform(nnet))

seed = jax.random.PRNGKey(0)
Θ = init(seed, jnp.array(1.)) # use initialisation function, then give random seed and sample input

opt_state = optimizer(lr).init(Θ)

seed_vector = jax.random.split(seed, 1000000) # create a random vector for lax.scan input for stock_return function 

# we need to pass different rng everytime we call it.
def stock_return(rng):
  
  μs = 0.06
  σs = 0.2
  ε = jax.random.normal(rng, ())
  log_return = μs + σs * ε
  return jnp.exp(log_return)

def L(Θ, seed): # now pass seed as input because it changes everytime we run a trial 

  x = 1.
  G = 0.
  rng_vector = jax.random.split(seed, T) # create a random vector for lax.scan input for stock_return function 

  state = x
  inputs = jnp.arange(T), rng_vector #pass rng vector as input

  def core(state, inputs):
    t, rng_vector = inputs #read rng vector[i]
    xt = state
    ct = jax.nn.sigmoid(nnet(Θ, xt) - 5.) * xt
    ut = U(ct)
    savings = xt - ct

    # R = 1.

    # x_tp1 = R * savings
    x_tp1 = stock_return(rng_vector) * savings # Wealth at next time step = (random stock return) * savings

    discounted_utility = β**t * ut
    return x_tp1, discounted_utility

  x, discounted_utility = jax.lax.scan(core, state, inputs)
  G = discounted_utility.sum()
  return -G


@jax.jit
def evaluation(Θ, seed): # pass seed as input because it changes from trial to trial
  return -L(Θ, seed)


@jax.jit
def update_gradient_descent(Θ, opt_state, seed): #passing the seed as input because we need to use it in the Loss function
  seed, _ = jax.random.split(seed) #splitting the seed and returning it to the next gradient descent
  grad = jax.grad(L)(Θ, seed)
  updates, opt_state = optimizer(lr).update(grad, opt_state)
  Θ = optax.apply_updates(Θ, updates)
  return Θ, opt_state


In [7]:
seed = jax.random.PRNGKey(0)
i = 0;
running_sum = 0;

for iteration in range(1000000):
  seed, _ = jax.random.split(seed) #changing seed from trial to trial
  Θ, opt_state = update_gradient_descent(Θ, opt_state, seed) #passiing seed into GD function
  running_sum = running_sum = evaluation(Θ, seed)
  
  if iteration % 1000 == 0:
    x = evaluation(Θ, seed)
    print(x)

  i = i+1 #incrementing i so that we pass next seed to the next iteration

-773.8599
-294.99646
-244.39142
-339.10553
-403.46866
-248.97012
-322.73706


KeyboardInterrupt: ignored

In [None]:
# jax.nn.sigmoid(θ)

In [11]:
ε = jax.random.normal(seed, (2, 100))
ε.shape

(2, 100)