<a href="https://colab.research.google.com/github/mloyorev/Theory/blob/main/10_InvestmentAdjustmentCostsJax.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [14]:
!pip install quantecon



In [15]:
import quantecon as qe
import numpy as np
import jax
import jax.numpy as jnp
from collections import namedtuple
import matplotlib.pyplot as plt

As in `9_OptimalInvestmentJax.ipynb`, we use 64-bit floats.

In [16]:
jax.config.update("jax_enable_x64", True)
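
A quick way to check that the setting took effect is to inspect the default dtype of a freshly created array. This is only a minimal sketch; the variable name `x` is illustrative.

```python
import jax
import jax.numpy as jnp

# Enable 64-bit floats before any arrays are created
jax.config.update("jax_enable_x64", True)

x = jnp.ones(3)
print(x.dtype)  # default float dtype once x64 is enabled
```

Without this flag JAX falls back to 32-bit floats, which can make tight tolerances in fixed-point iteration harder to reach.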

As in the previous notebook, we define the `successive_approx` function:

In [17]:
def successive_approx(T,                     # Operator (callable)
                      x_0,                   # Initial condition
                      tolerance=1e-6,        # Error tolerance
                      max_iter=10000,        # Max iteration bound
                      print_step=25,         # Print at multiples
                      verbose=False):

    x = x_0
    error = tolerance + 1
    k = 1

    while error > tolerance and k <= max_iter:
        x_new = T(x)
        error = np.max(np.abs(x_new - x))
        if verbose and k % print_step == 0:
            print(f"Completed iteration {k} with error {error}.")
        x = x_new
        k += 1

    if error > tolerance:
        print(f"Warning: Iteration hit upper bound {max_iter}.")

    elif verbose:
        print(f"Terminated successfully in {k - 1} iterations.")

    return x
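
To see the stopping rule at work, the following sketch applies the same kind of loop to a scalar contraction whose fixed point is known in closed form. The helper `fixed_point` and the map `x -> 0.5 * x + 1` are illustrative assumptions, not part of the notebook; the fixed point of that map is 2.

```python
import numpy as np

def fixed_point(T, x_0, tolerance=1e-6, max_iter=10_000):
    """Iterate x <- T(x) until successive iterates differ by at most `tolerance`."""
    x = x_0
    for k in range(1, max_iter + 1):
        x_new = T(x)
        error = np.max(np.abs(x_new - x))
        x = x_new
        if error <= tolerance:
            break
    return x, k

# T(x) = 0.5 x + 1 is a contraction with modulus 0.5 and fixed point x* = 2
x_star, n_iter = fixed_point(lambda x: 0.5 * x + 1.0, np.array([0.0]))
print(x_star, n_iter)
```

Because the modulus is 0.5, the distance to the fixed point halves at every step, which is the same geometric convergence that value function iteration inherits from the discount factor.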

# **Investment with Adjustment Costs with Google Jax**

As we have already mentioned, the **Bellman equation** of this model is given by

$$V(y,z)=\max_{y'}\left\{r(y,z,y')+\beta\sum_{z'}V(y',z')Q(z,z')\right\}$$

**Details** on the assumptions of the model and the derivation of the Bellman equation are given in `8_InvestmentAdjustmentCostNumba.ipynb`.

The **main purpose of this notebook** is to show how `jax` improves the performance of the model solvers created in `8_InvestmentAdjustmentCostNumba.ipynb`.

In [5]:
Model = namedtuple("Model", ("beta", "a_0", "a_1", "gamma", "c","y_size", "z_size", "y_grid", "z_grid", "Q"))

def create_investment_model(
        r=0.01,                              # Interest rate
        a_0=10.0, a_1=1.0,                   # Demand parameters
        gamma=25.0, c=1.0,                   # Adjustment and unit cost
        y_min=0.0, y_max=20.0, y_size=100,   # Grid for output
        rho=0.9, nu=1.0,                     # AR(1) parameters
        z_size=150):                         # Grid size for shock

    beta = 1/(1+r)
    y_grid = np.linspace(y_min, y_max, y_size)
    mc = qe.tauchen(rho=rho, sigma=nu, n=z_size)
    z_grid, Q = mc.state_values, mc.P

    model = Model(beta=beta, a_0=a_0, a_1=a_1, gamma=gamma, c=c, y_size=y_size, z_size=z_size, y_grid=y_grid, z_grid=z_grid, Q=Q)
    return model

Then we modify the model to make it easier to pass to `jax` functions.

In [6]:
def create_investment_model_jax():
    model = create_investment_model()
    beta, a_0, a_1, gamma, c, y_size, z_size, y_grid, z_grid, Q = model

    # Break up parameters into static and nonstatic components
    constants = beta, a_0, a_1, gamma, c
    sizes = y_size, z_size
    arrays = y_grid, z_grid, Q

    # Shift arrays to the device (e.g., GPU)
    arrays = tuple(map(jax.device_put, arrays))
    return constants, sizes, arrays
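
As a small illustration of the last step, the sketch below moves a NumPy array to the default device and checks that its values are unchanged. The names `host_array` and `device_array` are illustrative; on a machine without an accelerator the default device is simply the CPU.

```python
import numpy as np
import jax

host_array = np.linspace(0.0, 1.0, 5)        # ordinary NumPy array in host memory
device_array = jax.device_put(host_array)    # copy onto JAX's default device

print(jax.devices())                         # devices visible to JAX
print(type(device_array).__name__)
```

Transferring the grids once, up front, avoids repeating the host-to-device copy on every call to the jitted operators.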

Then we create a vectorized version of the RHS of the **Bellman equation** (before maximization), which is a 3D array represented by

$$B(y,z,y')=r(y,z,y')+\beta\sum_{z'}v(y',z')Q(z,z')$$

In [7]:
def B(v, constants, sizes, arrays):
    # Unpack
    beta, a_0, a_1, gamma, c = constants
    y_size, z_size = sizes
    y_grid, z_grid, Q = arrays

    # Compute current rewards r(y, z, yp) as array r[i, j, ip]
    y  = jnp.reshape(y_grid, (y_size, 1, 1))           # y[i]   ->  y[i, j, ip]
    z  = jnp.reshape(z_grid, (1, z_size, 1))           # z[j]   ->  z[i, j, ip]
    yp = jnp.reshape(y_grid, (1, 1, y_size))           # yp[ip] -> yp[i, j, ip]
    r = (a_0 - a_1 * y + z - c) * y - gamma * (yp - y)**2

    # Calculate continuation rewards at all combinations of (y, z, yp)
    v = jnp.reshape(v, (1, 1, y_size, z_size))  # v[ip, jp] -> v[i, j, ip, jp]
    Q = jnp.reshape(Q, (1, z_size, 1, z_size))  # Q[j, jp]  -> Q[i, j, ip, jp]
    EV = jnp.sum(v * Q, axis=3)                 # sum over last index jp

    # Compute the right-hand side of the Bellman equation
    return r + beta * EV
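
The reshape-and-sum pattern used for `EV` can be checked on a tiny example. The sketch below, with made-up values for `v` and `Q`, confirms that the broadcast sum equals the matrix product `Q @ v.T`, since `EV[j, ip] = Σ_jp Q[j, jp] v[ip, jp]`.

```python
import jax.numpy as jnp

y_size, z_size = 3, 2
v = jnp.arange(y_size * z_size, dtype=float).reshape(y_size, z_size)  # v[ip, jp]
Q = jnp.array([[0.9, 0.1],
               [0.3, 0.7]])                                            # Q[j, jp]

# Broadcasting version, as in B: builds a 4D intermediate, then sums over jp
v4 = jnp.reshape(v, (1, 1, y_size, z_size))   # v[i, j, ip, jp]
Q4 = jnp.reshape(Q, (1, z_size, 1, z_size))   # Q[i, j, ip, jp]
EV_broadcast = jnp.sum(v4 * Q4, axis=3)       # shape (1, z_size, y_size)

# Equivalent matrix product, shape (z_size, y_size)
EV_matmul = Q @ v.T

print(EV_broadcast.shape, EV_matmul.shape)
```

The leading axis of length 1 in `EV_broadcast` is what lets it broadcast against `r[i, j, ip]`, because the continuation value does not depend on the current output index `i`.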

From the proofs contained in the `8_InvestmentAdjustmentCosts.ipynb` notebook, we know that the Bellman operator of this model satisfies Blackwell's sufficient conditions for a contraction mapping and, therefore, the Bellman equation **has a unique solution**.

As in the previous notebook, we are going to use the **following algorithms** to solve the model:


*   Value Function Iteration (VFI).
*   Howard Policy Iteration (HPI).
*   Optimistic Policy Iteration (OPI).

Now we define the operators needed by each algorithm:



In [8]:
# -----COMPUTE CURRENT REWARD-----
def compute_r_sigma(sigma, constants, sizes, arrays):
    # Unpack model
    beta, a_0, a_1, gamma, c = constants
    y_size, z_size = sizes
    y_grid, z_grid, Q = arrays

    # Compute r_σ[i, j]
    y = jnp.reshape(y_grid, (y_size, 1))  # y[i]   ->  y[i, j]
    z = jnp.reshape(z_grid, (1, z_size))  # z[j]   ->  z[i, j]
    yp = y_grid[sigma]                    # yp[i, j] = y_grid[sigma[i, j]]: next-period output chosen by the policy

    r_sigma = (a_0 - a_1 * y + z - c) * y - gamma * (yp - y) ** 2      # Compute current reward

    return r_sigma

# -----BELLMAN OPERATOR-----
def T(v, constants, sizes, arrays):
    # Maximize B over the third axis (axis=2), which corresponds to the 'yp' axis.
    # The result is a 2D array holding the maximized value for each state (y, z).
    return jnp.max(B(v, constants, sizes, arrays), axis=2)

# -----GET GREEDY-----
def get_greedy(v, constants, sizes, arrays):
    # Take the argmax of B over the third axis (axis=2), which corresponds to the 'yp' axis.
    # The result is a 2D array holding, for each state (y, z), the index of the action
    # that maximizes the RHS of the Bellman equation.
    return jnp.argmax(B(v, constants, sizes, arrays), axis=2)

# -----POLICY OPERATOR-----
def T_sigma(v, sigma, constants, sizes, arrays):
    # Unpack model
    beta, a_0, a_1, gamma, c = constants
    y_size, z_size = sizes
    y_grid, z_grid, Q = arrays

    r_sigma = compute_r_sigma(sigma, constants, sizes, arrays)  # Compute current reward

    # Compute the array v[σ[i, j], jp]
    zp_idx = jnp.arange(z_size)
    zp_idx = jnp.reshape(zp_idx, (1, 1, z_size))
    sigma = jnp.reshape(sigma, (y_size, z_size, 1))
    V = v[sigma, zp_idx]

    # Convert Q[j, jp] to Q[i, j, jp]
    Q = jnp.reshape(Q, (1, z_size, z_size))

    return r_sigma + beta * jnp.sum(V * Q, axis=2)
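
The indexing step `v[sigma, zp_idx]` in `T_sigma` relies on broadcasting two integer index arrays against each other. The sketch below, using made-up values for `v` and `sigma`, shows that the result satisfies `V[i, j, jp] = v[sigma[i, j], jp]`.

```python
import jax.numpy as jnp

y_size, z_size = 3, 2
v = jnp.array([[0.0, 1.0],
               [10.0, 11.0],
               [20.0, 21.0]])          # v[ip, jp]
sigma = jnp.array([[2, 0],
                   [1, 1],
                   [0, 2]])            # sigma[i, j]: index of the chosen y'

zp_idx = jnp.reshape(jnp.arange(z_size), (1, 1, z_size))   # jp index, shape (1, 1, z_size)
sigma3 = jnp.reshape(sigma, (y_size, z_size, 1))           # shape (y_size, z_size, 1)

# The two index arrays broadcast to shape (y_size, z_size, z_size)
V = v[sigma3, zp_idx]

print(V.shape)
print(V[0, 0])   # row of v selected by sigma[0, 0]
```

In this particular case the plain expression `v[sigma]` gives the same array; the explicit `zp_idx` form makes the role of the `jp` axis visible.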

Once again, we define the **functions needed to compute the value** $v_{\sigma}$ of following a particular policy $\sigma$.

In [9]:
def L_sigma(v, sigma, constants, sizes, arrays):
    # Unpack model
    beta, a_0, a_1, gamma, c = constants
    y_size, z_size = sizes
    y_grid, z_grid, Q = arrays

    # Set up the array v[σ[i, j], jp]
    zp_idx = jnp.arange(z_size)                     # One-dimensional array of indices 0, ..., z_size - 1
    zp_idx = jnp.reshape(zp_idx, (1, 1, z_size))    # Reshape zp_idx to three dimensions (i, j, jp)
    sigma = jnp.reshape(sigma, (y_size, z_size, 1)) # Reshape sigma to three dimensions (i, j, jp)
    V = v[sigma, zp_idx]                            # V[i, j, jp] = v[σ[i, j], jp]

    # Expand Q[j, jp] to Q[i, j, jp]
    Q = jnp.reshape(Q, (1, z_size, z_size))

    # Compute and return v[i, j] - β Σ_jp v[σ[i, j], jp] * Q[j, jp]
    return v - beta * jnp.sum(V * Q, axis=2)

def get_value(sigma, constants, sizes, arrays):
    # Unpack
    beta, a_0, a_1, gamma, c = constants
    y_size, z_size = sizes
    y_grid, z_grid, Q = arrays

    r_sigma = compute_r_sigma(sigma, constants, sizes, arrays)               # Computes current reward

    # Reduce L_σ to a function of v alone, holding sigma and the model fixed
    partial_R_sigma = lambda v: L_sigma(v, sigma, constants, sizes, arrays)  # Linear operator passed to the iterative solver

    return jax.scipy.sparse.linalg.bicgstab(partial_R_sigma, r_sigma)[0]
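
`get_value` passes a function rather than a matrix to `bicgstab`, so the linear system is solved without ever forming the operator explicitly. The sketch below illustrates the same idea on a two-state example; the values of `beta`, `P` and `r` are made up, and the result is compared with a direct dense solve.

```python
import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import bicgstab

jax.config.update("jax_enable_x64", True)

beta = 0.95
P = jnp.array([[0.9, 0.1],
               [0.2, 0.8]])     # transition matrix under a fixed policy (illustrative)
r = jnp.array([1.0, 2.0])       # reward under that policy (illustrative)

def L(v):
    # Matrix-free version of the map v -> (I - beta * P) v
    return v - beta * (P @ v)

# bicgstab returns (solution, info); only the solution is used here
v_iterative, _ = bicgstab(L, r, tol=1e-10)

# Reference: form the matrix and solve directly
v_direct = jnp.linalg.solve(jnp.eye(2) - beta * P, r)

print(v_iterative, v_direct)
```

For the investment model the implied matrix would have `(y_size * z_size)**2` entries, which is why a matrix-free solver is attractive there.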

Now we build the `JIT`-compiled versions of the previous functions. For a more detailed explanation of how `static_argnums` works, see `9_OptimalInvestment.ipynb`.

In [10]:
B = jax.jit(B, static_argnums=(2,))
compute_r_sigma = jax.jit(compute_r_sigma, static_argnums=(2,))
T = jax.jit(T, static_argnums=(2,))
get_greedy = jax.jit(get_greedy, static_argnums=(2,))
get_value = jax.jit(get_value, static_argnums=(2,))

T_sigma = jax.jit(T_sigma, static_argnums=(3,))
L_sigma = jax.jit(L_sigma, static_argnums=(3,))
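
The `sizes` argument is marked static because its entries are used as array shapes, and shapes must be concrete Python integers when JAX traces a function. The following sketch shows the pattern on a hypothetical helper `outer_grid`, which is not part of the notebook.

```python
import jax
import jax.numpy as jnp

def outer_grid(x, sizes):
    n, m = sizes                          # plain Python ints, used as shapes
    col = jnp.reshape(x, (n, 1))          # x[i] -> x[i, j]
    return col * jnp.ones((1, m))         # broadcast to shape (n, m)

# Argument 1 (sizes) is treated as a compile-time constant
outer_grid_jit = jax.jit(outer_grid, static_argnums=(1,))

out = outer_grid_jit(jnp.arange(3.0), (3, 2))
print(out.shape)
```

Static arguments must be hashable, which a tuple of ints is, and calling the function with a different `sizes` tuple triggers a new compilation.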

Then we introduce functions for each algorithm (VFI, OPI and HPI).

In [20]:
# Value Function Iteration
def value_iteration(model, tol=1e-5):
    constants, sizes, arrays = model
    vz = jnp.zeros(sizes)

    v_star = successive_approx(lambda v: T(v, constants, sizes, arrays), vz, tolerance=tol)
    return get_greedy(v_star, constants, sizes, arrays)

# Howard Policy Iteration
def policy_iteration(model, maxiter=250):
    constants, sizes, arrays = model
    sigma = jnp.zeros(sizes, dtype=int)
    i, error = 0, 1.0
    while error > 0 and i < maxiter:
        v_sigma = get_value(sigma, constants, sizes, arrays)
        sigma_new = get_greedy(v_sigma, constants, sizes, arrays)
        error = jnp.max(jnp.abs(sigma_new - sigma))
        sigma = sigma_new
        i = i + 1
        print(f"Concluded loop {i} with error {error}.")
    return sigma

# Optimistic Policy Iteration
def optimistic_policy_iteration(model, tol=1e-5, m=10):
    constants, sizes, arrays = model
    v = jnp.zeros(sizes)
    error = tol + 1
    while error > tol:
        last_v = v
        sigma = get_greedy(v, constants, sizes, arrays)
        for _ in range(m):
            v = T_sigma(v, sigma, constants, sizes, arrays)
        error = jnp.max(jnp.abs(v - last_v))
    return get_greedy(v, constants, sizes, arrays)

Finally, here's a **test** of each solver.

In [12]:
model = create_investment_model_jax()

In [18]:
print("Starting VFI.")
qe.tic()
out = value_iteration(model)
elapsed = qe.toc()
print(out)
print(f"VFI completed in {elapsed} seconds.")

Starting VFI.
TOC: Elapsed: 0:00:17.90
[[ 2  2  2 ...  6  6  6]
 [ 3  3  3 ...  7  7  7]
 [ 4  4  4 ...  7  7  7]
 ...
 [82 82 82 ... 86 86 86]
 [83 83 83 ... 86 86 86]
 [84 84 84 ... 87 87 87]]
VFI completed in 17.905871152877808 seconds.


In [19]:
print("Starting OPI.")
qe.tic()
out = optimistic_policy_iteration(model, m=100)
elapsed = qe.toc()
print(out)
print(f"OPI completed in {elapsed} seconds.")

Starting OPI.
TOC: Elapsed: 0:00:8.15
[[ 2  2  2 ...  6  6  6]
 [ 3  3  3 ...  7  7  7]
 [ 4  4  4 ...  7  7  7]
 ...
 [82 82 82 ... 86 86 86]
 [83 83 83 ... 86 86 86]
 [84 84 84 ... 87 87 87]]
OPI completed in 8.158338069915771 seconds.
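
Each solver is timed on its first call, so the reported times presumably include JIT compilation. A common way to separate compilation from execution is to time a second call and to wait for the result with `block_until_ready()`, since JAX dispatches work asynchronously. The sketch below shows the pattern on a hypothetical function `f`; it does not reproduce the timings above.

```python
import time
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    return jnp.sum(jnp.sin(x) ** 2)

x = jnp.linspace(0.0, 1.0, 1_000_000)

# First call: traces and compiles f, then runs it
start = time.perf_counter()
f(x).block_until_ready()
first = time.perf_counter() - start

# Second call: reuses the compiled code
start = time.perf_counter()
result = f(x).block_until_ready()
second = time.perf_counter() - start

print(f"first call: {first:.4f} s, second call: {second:.4f} s")
```

Applied to this notebook, the same idea would mean calling each solver twice and reporting the second timing.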
