## Transition Dynamics for the PBMA Aiyagari Model

In this notebook, we discuss the transition dynamics from one equilibrium to another.

I use the methodology in Section 4 of [Transition Dynamics in the Aiyagari Model, with an application to Wealth Tax by Toshihiko Mukoyama](https://toshimukoyama.github.io/MyWebsite/Aiyagari.pdf) and apply it to the current setup.

The algorithm proceeds as follows:

#### Step 1 Compute the initial and final steady states

We have already completed this computation and obtained:

*   $K_s$, the steady-state capital for the short-run policy equilibrium (SRPE),
    - the optimal short-run policy at the SRPE,
    - and its corresponding stationary asset distribution;
*   $K_c$, the steady-state capital for the continuation policy equilibrium (CPE),
    - the optimal continuation policy at the CPE,
    - and its corresponding stationary asset distribution.


#### Step 2 Guess the time series and apply backward induction

- Guess the time series $K_t$ for $t=1,2,\cdots, T$, where $t=1$ is the period of the unexpected change in the discount factor.

- $T$ is chosen large enough that the economy has reached the new equilibrium by period $T$ after the one-time unexpected change in the discount factor.

- Given the guessed path, we have a finite-horizon DP problem.

- We use backward induction, starting from the terminal steady state, to get the value functions and decision rules for $t=1,2,\cdots, T$.
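The backward-induction step can be sketched on a toy problem. The example below is only an illustration under stated assumptions: it uses a deterministic savings problem with log utility and standard exponential discounting (`beta`), not the notebook's `get_greedy` and `get_value` or its discounting structure. The names `backward_induction`, `r_path`, `w_path`, and `v_term` are hypothetical.

```python
import numpy as np

def backward_induction(r_path, w_path, v_term, a_grid, beta=0.95):
    """Solve a finite-horizon savings problem backward from v_term.

    Toy assumptions (not the notebook's model): no income shocks,
    log utility, and budget constraint c = (1 + r_t) * a + w_t - a'.
    """
    T, n = len(r_path), len(a_grid)
    policy_path = np.zeros((T, n), dtype=int)   # index of a' on a_grid
    value_path = np.zeros((T, n))
    v_next = v_term
    for t in range(T - 1, -1, -1):
        # c[i, j]: consumption when holding a_grid[i] and choosing a_grid[j]
        c = (1 + r_path[t]) * a_grid[:, None] + w_path[t] - a_grid[None, :]
        # Infeasible choices (c <= 0) get utility -inf
        u = np.where(c > 0, np.log(np.maximum(c, 1e-12)), -np.inf)
        q = u + beta * v_next[None, :]
        policy_path[t] = np.argmax(q, axis=1)   # greedy policy given v_next
        value_path[t] = np.max(q, axis=1)       # value of that policy
        v_next = value_path[t]                  # becomes next step's continuation value
    return policy_path, value_path

a_grid = np.linspace(0.0, 10.0, 50)
T = 20
r_path = np.full(T, 0.03)       # hypothetical guessed interest-rate path
w_path = np.full(T, 1.0)        # hypothetical guessed wage path
v_term = np.zeros(len(a_grid))  # hypothetical terminal value function
policy_path, value_path = backward_induction(r_path, w_path, v_term, a_grid)
```

In the transition problem, `v_term` would be the value function of the final steady state, and the price paths would come from the guessed $K_t$.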

#### Step 3 Forward simulation

- Using the decision rules above, simulate the economy forward from the initial stationary distribution.

- Calculate the simulated aggregate capital $K_t$ at $t=1,2,\cdots, T$.
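As a rough illustration of the forward step, the sketch below pushes a joint distribution over (asset, shock) pairs forward on a tiny grid and records aggregate capital each period. The policy `sigma`, the transition matrix `Pi`, and the function name `push_forward` are made up for the example. The policy is held fixed over time for brevity, whereas the transition computation uses a different policy in each period.

```python
import numpy as np

def push_forward(psi, sigma, Pi):
    """One-step update of the joint distribution psi[i, j] over (asset, shock)."""
    a_size, z_size = psi.shape
    psi_new = np.zeros_like(psi)
    for i in range(a_size):
        for j in range(z_size):
            # Mass at (i, j) moves to asset index sigma[i, j] and is spread
            # over next-period shocks according to row j of Pi
            psi_new[sigma[i, j], :] += psi[i, j] * Pi[j, :]
    return psi_new

a_grid = np.array([0.0, 1.0, 2.0])
Pi = np.array([[0.9, 0.1],
               [0.2, 0.8]])
# Hypothetical policy: dissave one grid step in the low state,
# save one grid step in the high state (clipped at the grid bounds)
sigma = np.array([[0, 1],
                  [0, 2],
                  [1, 2]])

psi = np.full((3, 2), 1 / 6)   # start from a uniform joint distribution
T = 5
K_path = np.zeros(T)
for t in range(T):
    # Aggregate capital: asset marginal of psi weighted by the asset grid
    K_path[t] = np.sum(psi.sum(axis=1) * a_grid)
    psi = push_forward(psi, sigma, Pi)
```

Summing over the shock dimension before multiplying by `a_grid` keeps the shapes aligned when the asset and shock grids differ in size.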

#### Step 4 Compare to the guess and iterate until convergence.
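The compare-and-update loop is a damped fixed-point iteration on the capital path. A minimal sketch, assuming a hypothetical stand-in map `g` in place of the full "guess, solve backward, simulate forward" mapping:

```python
import numpy as np

def damped_fixed_point(g, x0, gamma=0.9, tol=1e-8, max_iter=10_000):
    """Iterate x <- gamma * x + (1 - gamma) * g(x) until max|g(x) - x| <= tol."""
    x = np.asarray(x0, dtype=float)
    for it in range(1, max_iter + 1):
        x_new = g(x)
        diff = np.max(np.abs(x_new - x))
        if diff <= tol:
            return x, it, True
        # Damping: move only a fraction (1 - gamma) of the way toward g(x)
        x = gamma * x + (1 - gamma) * x_new
    return x, max_iter, False

# Hypothetical stand-in for the map from a guessed path to the simulated path;
# its fixed point is 2.0 in every period
g = lambda K: 0.5 * K + 1.0

K_star, n_iter, converged = damped_fixed_point(g, np.zeros(5))
```

A larger `gamma` puts more weight on the previous guess, which slows convergence but reduces the risk of oscillation when the map is not a contraction.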

The code below implements this iteration.

In [None]:
# First, we import the functions from PBMA_Ayagari.ipynb

%run "https://github.com/longye-tian/anu-phd/blob/main/QHD/PBMA_Ayagari.ipynb"



We need a helper function to update the distribution.
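With $i$ indexing assets and $j$ indexing shocks, the one-step update can be written as follows. This restates the array construction `P_σ[i, j, ip, jp]` used in the code; it assumes the policy $\sigma_t(i, j)$ returns the index of next-period assets:

$$
\psi_{t+1}(i', j') = \sum_{i} \sum_{j} \psi_t(i, j)\, \mathbb{1}\{\sigma_t(i, j) = i'\}\, \Pi(j, j')
$$

Aggregate capital then follows from the asset marginal, $K_t = \sum_i a_i \sum_j \psi_t(i, j)$.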

In [None]:
# We modify our compute_asset_distribution slightly to compute the joint distribution

@jax.jit
def compute_joint_stationary(σ, household):
    # Unpack
    β, δ, a_grid, z_grid, Π = household
    a_size, z_size = len(a_grid), len(z_grid)

    # Construct P_σ as an array of the form P_σ[i, j, ip, jp]
    ap_idx = jnp.arange(a_size)
    ap_idx = jnp.reshape(ap_idx, (1, 1, a_size, 1))
    σ = jnp.reshape(σ, (a_size, z_size, 1, 1))
    A = jnp.where(σ == ap_idx, 1, 0)
    Π = jnp.reshape(Π, (1, z_size, 1, z_size))
    P_σ = A * Π

    # Reshape P_σ into a matrix
    n = a_size * z_size
    P_σ = jnp.reshape(P_σ, (n, n))

    # Get stationary distribution and reshape back onto [i, j] grid
    ψ = compute_stationary(P_σ)
    ψ = jnp.reshape(ψ, (a_size, z_size))

    return ψ

# We also build a helper function to update the joint distribution.

@jax.jit
def update_joint_stationary(ψ, σ, household):
    # Unpack
    β, δ, a_grid, z_grid, Π = household
    a_size, z_size = len(a_grid), len(z_grid)

    # Construct P_σ as an array of the form P_σ[i, j, ip, jp]
    ap_idx = jnp.arange(a_size)
    ap_idx = jnp.reshape(ap_idx, (1, 1, a_size, 1))
    σ = jnp.reshape(σ, (a_size, z_size, 1, 1))
    A = jnp.where(σ == ap_idx, 1, 0)
    Π = jnp.reshape(Π, (1, z_size, 1, z_size))
    P_σ = A * Π

    # Reshape P_σ into a matrix
    n = a_size * z_size
    P_σ = jnp.reshape(P_σ, (n, n))

    # Reshape the input joint distribution into a vector
    ψ = jnp.reshape(ψ, (n, ))

    # Update the joint distribution
    ψ_new = jnp.dot(ψ, P_σ) / jnp.sum(ψ)

    # Reshape the updated joint distribution back to [i, j] grid
    ψ_new = jnp.reshape(ψ_new, (a_size, z_size))


    return ψ_new


In [None]:
# Compute transition dynamics from SRPE to CPE

def transit_SRPE_to_CPE(household,       # household parameters
                        firm,            # firm parameters
                        T=100,           # Max transition period
                        γ=0.9,           # damping parameter
                        max_iter=10000,  # maximum iteration
                        tol=1e-4,        # tolerance
                        verbose=False
                        ):
  # Step 1.1: Compute SRPE
  print("Computing initial steady state.....")

  K_init, _ = compute_equilibrium_short(firm, household)        # Compute SRPE capital as initial capital
  r_init = r_given_k(K_init, firm)                              # Compute SRPE interest rate as initial interest rate
  w_init = r_to_w(r_init, firm)                                 # Compute SRPE wage as initial wage
  price_init = create_price(r=r_init, w=w_init)                 # Compute SRPE price as initial price

  σ_init, v_init = compute_lifetime(household, price_init)      # compute the SRPE optimal policy and value function

  # Step 1.2: Compute CPE
  print("Computing terminal steady state....")

  K_term, _ = compute_equilibrium_continuation(firm, household)     # Compute CPE capital as terminal capital
  r_term = r_given_k(K_term, firm)                                  # Compute CPE interest rate as terminal
  w_term = r_to_w(r_term, firm)                                     # Compute CPE wage as terminal
  price_term = create_price(r=r_term, w=w_term)                     # Compute CPE price as terminal

  σ_term, v_term = howard_policy_iteration(household, price_term)   # compute the CPE optimal policy and value function


  # Step 1.3: Get initial stationary asset distribution
  print("Computing Initial stationary asset distribution......")

  ψ_init = compute_joint_stationary(σ_init, household)       # Compute initial stationary joint distribution

  # Step 2.1: Set up the Initial Guess
  K_path = np.linspace(K_init, K_term, T)                    # Set up the initial guess

  # Unpack the household parameters
  β, δ, a_grid, z_grid, Π = household
  a_size, z_size = len(a_grid), len(z_grid)

  # Initialize arrays to store policies and value functions for each time period
  σ_path = np.zeros((T, a_size, z_size), dtype=int)   # policies are asset-grid indices
  v_path = np.zeros((T, a_size, z_size))

  # For tracking convergence
  iter_count = 0
  max_diff = tol + 1

  # Iteration Loop
  while max_diff > tol and iter_count < max_iter:
    # Compute interest rate and wages for the current capital path.
    r_path = np.array([r_given_k(K, firm) for K in K_path])
    w_path = np.array([r_to_w(r, firm) for r in r_path])

    # Step 2.2 Backward Induction
    # Start with final steady state value function
    v_next  = v_term

    # Backward induction
    for t in range(T-1, -1, -1):

      # Create prices for this period
      price_t = create_price(r=r_path[t], w=w_path[t])

      # Solve a two-period DP problem under discount factor δ
      σ_t = get_greedy(v_next, household, price_t)    # v_next-greedy policy
      v_t = get_value(σ_t, household, price_t)        # σ_t-value function

      # Store policies and value functions
      σ_path[t] = σ_t
      v_path[t] = v_t

      # Update v_next for the next iteration
      v_next = v_t

    # Step 3: Forward Simulation
    # Start with initial stationary joint distribution
    ψ_t = ψ_init

    # Initialize the forward simulation path
    K_path_new = np.zeros(T)

    # Forward simulation
    for t in range(T):

      # Compute aggregate capital implied by the current distribution:
      # sum out the shock dimension first so the marginal aligns with a_grid
      K_path_new[t] = float(jnp.sum(jnp.sum(ψ_t, axis=1) * a_grid))

      # Update distribution using policy function for period t
      ψ_t = update_joint_stationary(ψ_t, σ_path[t], household)

    # Step 4: Check convergence and update capital path
    max_diff = np.max(np.abs(K_path_new - K_path))

    # Update capital path with damping
    K_path = γ * K_path + (1 - γ) * K_path_new

    # Update iteration count
    iter_count += 1

    if verbose and iter_count % 10 == 0:
      print(f"Iteration {iter_count}, max difference: {max_diff:.6f}")

  if max_diff <= tol:
    print(f"Convergence achieved after {iter_count} iterations.")
  else:
    print(f"Maximum iterations ({iter_count}) reached without convergence.")


  # Compute transition path

  transition_path = {
      'K': K_path,
      'σ': σ_path,
      'v': v_path
  }

  return transition_path





