In [1]:
import time

import matplotlib.pyplot as plt
import numpy as np
from scipy import interpolate
from scipy.optimize import minimize

In [None]:
# Parameters - households and firms
beta = 0.95  # discount factor
eta = 2.0    # coefficient of relative risk aversion
gamma = 0.5  # disutility from working (weight on leisure 1 - l in utility)
alpha = 0.33 # capital share in production
delta = 0.1  # depreciation rate
omega = 0.1  # parent education time factor
phi_p = 0.5  # share of parental input in education
psi_e = 0.5  # CES parameter for education inputs
theta_I = 0.7  # human capital investment effectiveness
tfp = 1          # total factor productivity (scales output)
small = 0.00001  # offset added to consumption so log/power terms stay finite at c = 0
neg = -1e10      # penalty value returned for infeasible choices

# Parameters - government
tau = 0.15       # tax rate on labour income, capital income and consumption
Ie_gdp = 0.02    # public education investment as a share of GDP
Iz_gdp = 0.04    # public infrastructure investment as a share of GDP
sigma_bar = 0.2  # public investment inefficiency
G = 0.1          # grants (as a fraction of GDP)
psi = 0.1        # public infrastructure elasticity
nu = 0.6         # debt-to-GDP ratio limit

# Initial guesses
K_init = 1
H_init = 1
L_init = 0.2
Z_init = 1
tau_init = 0.15  # initial guess for tax rate
T_init = 0.1     # initial guess for transfers
D_init = 0

# Grid spec
nK = 20  # number of points in capital grid
nH = 20  # number of points in human capital grid
nZ = 20  # number of points in public infrastructure grid
Kmin = 0
Kmax = 10
Hmin = 0
Hmax = 10
Zmin = 0
Zmax = 10

# Number of periods.
# NOTE: `T` is the number of model periods. Transfers are also called `T`
# in several function signatures, so do not confuse the two.
T = 60
working_periods = 36
parent_periods = 18

In [None]:
# Evenly spaced state grids for capital, human capital and public infrastructure.
# NOTE(review): every grid starts at 0 (Kmin = Hmin = Zmin = 0), where
# interest_rate's K_prev**(alpha - 1) is undefined — confirm intended.
k_grid, h_grid, Z_grid = (
    np.linspace(lower, upper, n_points)
    for lower, upper, n_points in (
        (Kmin, Kmax, nK),
        (Hmin, Hmax, nH),
        (Zmin, Zmax, nZ),
    )
)

In [None]:
def production(K_prev, H, L, Z_prev):
    """Aggregate output from capital, effective labour and public infrastructure.

    Y = tfp * Z_prev**psi * K_prev**alpha * (H * L)**(1 - alpha)
    """
    effective_labour = H * L
    scale = tfp * Z_prev**psi
    return scale * K_prev**alpha * effective_labour**(1 - alpha)

def wage_rate(K_prev, H, L, Z_prev):
    """Wage per efficiency unit of labour.

    This is the marginal product of effective labour H * L, which equals
    (1 - alpha) * Y / (H * L). government_budget multiplies this wage by H * L
    to obtain the labour-income tax base. The wage must therefore not already
    contain a factor of H.

    With that extra factor, labour income would be (1 - alpha) * Y * H. Factor
    payments (r + delta) * K + w * H * L would then fail to add up to Y whenever
    H != 1.
    """
    return tfp * (1 - alpha) * Z_prev**psi * K_prev**alpha * (H * L)**(-alpha)

def interest_rate(K_prev, H, L, Z_prev):
    """Net return on capital: marginal product of K_prev minus depreciation."""
    labour_term = (H * L)**(1 - alpha)
    capital_term = K_prev**(alpha - 1)
    marginal_product = tfp * alpha * Z_prev**psi * labour_term * capital_term
    return marginal_product - delta

In [None]:
def utility(c, l):
    """Period utility over consumption ``c`` and labour supply ``l`` (scalars).

    Utility is CRRA with curvature ``eta``, applied to the composite
    ``x = (c + small) * (1 - l)**gamma``:

        u = (x**(1 - eta) - 1) / (1 - eta)

    When ``eta == 1`` this formula is 0/0, so its limit is used instead:

        u = log(c + small) + gamma * log(1 - l)

    Some choices return the ``neg`` penalty, so they are never selected:
      - labour above 1 (negative leisure);
      - zero leisure with ``eta >= 1``, where utility diverges to -inf.
    """
    leisure = 1 - l
    if leisure < 0 or (leisure == 0 and eta >= 1):
        return neg
    if eta == 1:
        # Log limit of the CRRA form; avoids dividing by 1 - eta = 0.
        return np.log(c + small) + gamma * np.log(leisure)
    return (((c + small) * leisure ** gamma) ** (1 - eta) - 1) / (1 - eta)

def human_capital_investment(e, I_e, h):
    """Next-period human capital from parental input, public spending and current h.

    Education input is a CES aggregate of the parental input ``e`` and public
    spending ``I_e / 18``. The weights are ``phi_p`` and ``1 - phi_p``, and the
    exponent is ``psi_e``. That input is then combined Cobb-Douglas with the
    existing human capital ``h``, using exponents ``theta_I`` and ``1 - theta_I``.
    """
    # NOTE(review): 18 equals parent_periods. Presumably it spreads I_e over
    # the parenting years — confirm, and consider using the named constant.
    parental_term = phi_p * e**psi_e
    public_term = (1 - phi_p) * (I_e / 18)**psi_e
    education_input = (parental_term + public_term)**(1 / psi_e)
    return education_input**theta_I * h**(1 - theta_I)

def value_function(k_prev, k_next, k, h, l, w, r, tau, T):
    """Current utility plus discounted continuation value for one candidate choice.

    Resources are built from:
      - after-tax labour income ``(1 - tau) * w * h * l``;
      - gross capital income ``(1 + r) * k``;
      - transfers ``T``;
      - minus a non-negative capital-income tax ``tau * r * k_prev``;
      - minus ``k``.

    Consumption is these resources divided by ``1 + tau``, consistent with a
    consumption tax at rate ``tau``. Non-positive consumption is infeasible and
    scores ``neg``.

    NOTE(review): ``k`` both earns ``(1 + r) * k`` and is subtracted, so the net
    contribution is only ``r * k``. The continuation value uses ``k_next``.
    Confirm which of k_prev / k / k_next is the savings choice. ``vr_polate``
    is not defined in this part of the notebook; presumably it interpolates
    next period's value function.
    """
    labour_income = (1 - tau) * w * h * l
    capital_tax = max(0, tau * r * k_prev)
    resources = labour_income + (1 + r) * k - capital_tax + T - k
    consumption = resources / (1 + tau)
    if consumption <= 0:
        return neg
    return utility(consumption, l) + beta * vr_polate(k_next)

In [None]:
def government_budget(Y, K_prev, H, L, Z_prev, C, I_e, I_z, tau, T, D_prev, r_prev):
    """Government deficit for one period (the change in debt).

    Outlays are interest on debt ``r_prev * D_prev``, public investment
    ``I_e + I_z`` and transfers ``T``.

    Inflows are:
      - grants ``G * Y``;
      - tax revenue at rate ``tau``, levied on labour income
        ``wage_rate(...) * H * L``, capital income ``r_prev * K_prev`` and
        consumption ``C``.

    Returns outlays minus inflows.
    """
    labour_income = wage_rate(K_prev, H, L, Z_prev) * H * L
    capital_income = r_prev * K_prev
    income_tax = tau * (labour_income + capital_income)
    consumption_tax = tau * C
    revenue = income_tax + consumption_tax
    outlays = r_prev * D_prev + I_e + I_z + T
    return outlays - G * Y - revenue

In [None]:
def solve_steady_state(max_iter=100, tol=1e-4):
    """Fixed-point iteration on aggregates to find an approximate steady state.

    Each iteration:
      1. prices output and the return on capital at the current aggregates;
      2. solves the household problem with ``solve_household_problem``
         (defined elsewhere in the notebook);
      3. rebuilds aggregate capital, human capital and labour from household
         choices;
      4. updates government debt and re-chooses the tax rate and transfers to
         balance the budget;
      5. advances public infrastructure.

    Iteration stops when K, H, L, the tax rate and transfers each move by less
    than ``tol``, or after ``max_iter`` iterations.

    Args:
        max_iter: maximum number of iterations.
        tol: absolute convergence tolerance on each tracked aggregate.

    Returns:
        ``(K, H, L, Z, I_e, I_z, tau, T, D, Y)``, where ``T`` is the transfer
        level (not the number of periods) and ``D`` is the debt level.
    """
    # Initial guesses. Transfers live in `transfers` rather than `T`, so the
    # module-level period count `T` stays visible inside this function.
    # A local `T = T_init` would turn `range(T)` into `range(0.1)`.
    K = K_init
    H = H_init
    L = L_init
    Z = Z_init
    tau = tau_init
    transfers = T_init
    D = D_init

    for _ in range(max_iter):
        Y = production(K, H, L, Z)
        I_e = Ie_gdp * Y
        I_z = Iz_gdp * Y
        r = interest_rate(K, H, L, Z)

        V = solve_household_problem(K, H, L, Z, I_e, tau, transfers)

        # Compute new aggregates from household choices in each working period.
        K_new = 0
        H_new = 0
        L_new = 0
        for s in range(T):  # T: module-level number of periods
            if s >= working_periods:
                continue
            # NOTE(review): `s <= parent_periods` covers parent_periods + 1
            # periods, unlike `s < working_periods` above — confirm intent.
            has_parent_choice = s <= parent_periods
            policies = []
            for ik, k in enumerate(k_grid):
                for ih, h in enumerate(h_grid):
                    value = V[ik, ih, s]
                    x0 = [k / 2, 0.5, 0.1] if has_parent_choice else [k / 2, 0.5]
                    # NOTE(review): this objective does not depend on x, so
                    # minimize() just returns x0. The household's actual choice
                    # problem still has to be expressed here.
                    policies.append(
                        minimize(lambda x, v=value: -v, x0, method='L-BFGS-B').x)
            K_new += np.mean([p[0] for p in policies])
            H_new += np.mean([h for _ in k_grid for h in h_grid])
            L_new += np.mean([p[1] for p in policies])

        K_new /= T
        H_new /= T
        L_new /= working_periods

        # NOTE(review): aggregate consumption is not produced by the household
        # block. It is backed out of the goods-market identity
        # Y = C + delta * K + I_e + I_z. Replace it with the household
        # aggregate once available.
        C = Y - delta * K - I_e - I_z

        # government_budget returns the deficit (change in debt), so it is
        # accumulated into the debt level.
        D_new = D + government_budget(Y, K, H, L, Z, C, I_e, I_z,
                                      tau, transfers, D, r)

        # Adjust tax rate and transfers to balance budget
        def budget_balance(x):
            tau_try, transfers_try = x
            return abs(government_budget(Y, K_new, H_new, L_new, Z, C, I_e, I_z,
                                         tau_try, transfers_try, D_new, r))

        result = minimize(budget_balance, [tau, transfers], method='Nelder-Mead')
        tau_new, transfers_new = result.x

        # Check convergence
        if (abs(K_new - K) < tol and abs(H_new - H) < tol and abs(L_new - L) < tol
                and abs(tau_new - tau) < tol and abs(transfers_new - transfers) < tol):
            break

        # Update variables
        K, H, L, tau, transfers, D = K_new, H_new, L_new, tau_new, transfers_new, D_new

        # Update public infrastructure
        Z = (1 - delta) * Z + (1 - sigma_bar) * I_z

    return K, H, L, Z, I_e, I_z, tau, transfers, D, Y
