# Real-Business Cycle Model with Search and Matching (Andolfatto, 1996)

# In [1]:
using Plots
using LinearAlgebra

# ---- Model parameters --------------------------------------------------------
#   β  discount factor
#   α  capital share of output
#   δ  depreciation rate
#   γ  coefficient of relative risk aversion (CRRA curvature)
#   ψ  labor supply elasticity: optimal labor is n = (w·λ)^ψ
const β, α, δ, γ, ψ = 0.99, 0.36, 0.1, 2.0, 1.5

# ---- Employment parameters ---------------------------------------------------
#   b           unemployment income as a fraction of the wage
#   p_employed  probability of remaining employed
#               NOTE(review): not referenced by any code shown here — confirm intent.
const b, p_employed = 0.5, 0.97

# ---- Grid parameters ---------------------------------------------------------
#   k_min, k_max  bounds of the capital grid (earlier values: 0.1 and 10.0)
#   N_k           number of capital grid points
#   N_a           number of aggregate productivity states
# NOTE(review): with β, α, δ above, 1 = β(1 + α k^(α-1) - δ) at a = 1 gives k ≈ 6.4,
# outside [k_min, k_max]; policies may pile up at the grid bounds — confirm the
# narrow grid is intentional.
const k_min, k_max = 2, 2.75
const N_k, N_a = 2000, 2

"""
    utility(c, n; γ=γ, ψ=ψ)

Period utility: CRRA in consumption minus a power disutility of labor,

    u(c, n) = c^(1-γ)/(1-γ) - n^(1+1/ψ)/(1+1/ψ),

where the `γ == 1` case uses its limit `log(c)`. Returns `-Inf` for infeasible
inputs (`c <= 0`, or `n` outside `[0, 1]`).
"""
function utility(c, n; γ=γ, ψ=ψ)
    if c <= 0 || n < 0 || n > 1
        return -Inf
    end
    # The CRRA formula divides by zero at γ = 1; its limit is log utility.
    consumption_term = γ == 1 ? log(c) : c^(1 - γ) / (1 - γ)
    labor_term = n^(1 + 1 / ψ) / (1 + 1 / ψ)
    return consumption_term - labor_term
end

"""
    optimal_labor_supply(w, λ; ψ=ψ)

Labor supply from the first-order condition of `utility` in `n`,
`n^(1/ψ) = w·λ`, i.e. `n = (w·λ)^ψ`, clamped to `[0, 1]`.
`λ` is the marginal utility of consumption. A non-positive `w·λ` yields `n = 0`.
"""
function optimal_labor_supply(w, λ; ψ=ψ)
    marginal_gain = w * λ
    # Clamp before exponentiating: a negative base with fractional ψ throws DomainError.
    marginal_gain <= 0 && return 0.0
    return min(1.0, marginal_gain^ψ)
end

"""
    create_grids(; kmin=k_min, kmax=k_max, nk=N_k, productivity=[0.9, 1.1])

Return `(k_grid, a_grid)`:
- `k_grid`: `nk` logarithmically spaced capital points from `kmin` to `kmax`
  (spacing grows with `k`, so the grid is denser near the lower bound);
- `a_grid`: a copy of the aggregate productivity states (default low/high `[0.9, 1.1]`).

With no arguments the file-level grid constants are used.
NOTE(review): keep `length(productivity)` equal to `N_a` if the solver sizes its
arrays by `N_a`.

Throws `ArgumentError` if `kmin <= 0` (logarithm undefined) or `kmax < kmin`.
"""
function create_grids(; kmin=k_min, kmax=k_max, nk=N_k, productivity=[0.9, 1.1])
    kmin > 0 ||
        throw(ArgumentError("kmin must be positive for a log-spaced grid, got $kmin"))
    kmax >= kmin ||
        throw(ArgumentError("kmax ($kmax) must be at least kmin ($kmin)"))

    k_grid = exp.(range(log(kmin), log(kmax); length=nk))
    # Copy so the returned grid never aliases a vector supplied by the caller.
    a_grid = collect(productivity)
    return k_grid, a_grid
end

# Value function iteration
function value_function_iteration(k_grid, a_grid)
    # Initialize arrays
    V = zeros(N_k, N_a)
    V_new = similar(V)
    policy_c = zeros(N_k, N_a)
    policy_n = zeros(N_k, N_a)
    policy_k_next = zeros(N_k, N_a)
    
    # Convergence parameters
    tol = 1e-6
    max_iter = 1000
    err = Inf
    iter = 0
    
    while err > tol && iter < max_iter
        iter += 1
        
        for i_a in 1:N_a, i_k in 1:N_k
            k = k_grid[i_k]
            a = a_grid[i_a]
            
            # Best value for this state
            max_value = -Inf
            
            # Compute income when employed
            w = a * (1 - α) * (k)^α
            r = a * α * (k)^(α-1) - δ
            
            # Iterate through possible next period capital choices
            for i_k_next in 1:N_k
                k_next = k_grid[i_k_next]
                
                # Employed state
                c_employed = w + (1 + r) * k - k_next
                
                if c_employed > 0
                    # Derive optimal labor supply
                    λ_employed = c_employed^(-γ)  # Marginal utility of consumption
                    n_employed = optimal_labor_supply(w, λ_employed)
                    
                    # Expected value when employed
                    EV_employed = 0.0
                    for j_a in 1:N_a
                        EV_employed += 0.5 * (utility(c_employed, n_employed) + 
                                              β * V[i_k_next, j_a])
                    end
                    
                    # Unemployed state
                    c_unemployed = b * w + (1 + r) * k - k_next
                    
                    if c_unemployed > 0
                        # Labor supply when unemployed
                        λ_unemployed = c_unemployed^(-γ)
                        n_unemployed = 0.0  # No labor supply when unemployed
                        
                        # Expected value when unemployed
                        EV_unemployed = 0.0
                        for j_a in 1:N_a
                            EV_unemployed += 0.5 * (utility(c_unemployed, n_unemployed) + 
                                                    β * V[i_k_next, j_a])
                        end
                        
                        # Choose the maximum value between employed and unemployed states
                        current_value = max(EV_employed, EV_unemployed)
                        
                        if current_value > max_value
                            max_value = current_value
                            
                            if EV_employed >= EV_unemployed
                                V_new[i_k, i_a] = EV_employed
                                policy_c[i_k, i_a] = c_employed
                                policy_n[i_k, i_a] = n_employed
                                policy_k_next[i_k, i_a] = k_next
                            else
                                V_new[i_k, i_a] = EV_unemployed
                                policy_c[i_k, i_a] = c_unemployed
                                policy_n[i_k, i_a] = n_unemployed
                                policy_k_next[i_k, i_a] = k_next
                            end
                        end
                    end
                end
            end
        end
        
        # Compute convergence error
        err = maximum(abs, V_new - V)
        
        # Update value function
        V .= V_new
    end
    
    return V, policy_c, policy_n, policy_k_next, iter
end

# Draw one policy panel: the low-productivity column (1) of `policy` as the first
# series and the high-productivity column (2) overlaid on top of it.
function _policy_panel(k_grid, policy, label_low, label_high, panel_title, y_label)
    panel = plot(k_grid, policy[:, 1]; label=label_low, title=panel_title,
                 xlabel="Current Capital", ylabel=y_label)
    plot!(panel, k_grid, policy[:, 2]; label=label_high)
    return panel
end

"""
    solve_and_plot_model()

Build the grids, run `value_function_iteration`, print the iteration count, and
save the consumption, labor-supply and capital policy functions (both
productivity states) as a stacked 3×1 figure in `andolfatto_policy_functions.png`.

Returns `(k_grid, a_grid, V, policy_c, policy_n, policy_k_next)`.
"""
function solve_and_plot_model()
    k_grid, a_grid = create_grids()
    V, policy_c, policy_n, policy_k_next, iter = value_function_iteration(k_grid, a_grid)

    println("Value Function Iteration:")
    println("Iterations: ", iter)

    # Panels are built in order (consumption, labor, capital); tuple elements
    # evaluate left to right.
    panels = (
        _policy_panel(k_grid, policy_c,
                      "Consumption (Low Productivity)",
                      "Consumption (High Productivity)",
                      "Consumption Policy Function", "Consumption"),
        _policy_panel(k_grid, policy_n,
                      "Labor Supply (Low Productivity)",
                      "Labor Supply (High Productivity)",
                      "Labor Supply Policy Function", "Labor Supply"),
        _policy_panel(k_grid, policy_k_next,
                      "Next Period Capital (Low Productivity)",
                      "Next Period Capital (High Productivity)",
                      "Capital Accumulation Policy Function", "Next Period Capital"),
    )

    combined_plot = plot(panels...; layout=(3, 1), size=(800, 1200))
    savefig(combined_plot, "andolfatto_policy_functions.png")

    return k_grid, a_grid, V, policy_c, policy_n, policy_k_next
end

# Solve and plot, binding every returned object to a global for interactive use.
(k_grid, a_grid, V, policy_c, policy_n, policy_k_next) = solve_and_plot_model()

#= Notebook output of the cell above (1000 equals max_iter, so the loop most
likely stopped on the iteration cap rather than the tolerance):
Value Function Iteration:
Iterations: 1000


([2.0, 2.0003186384175518, 2.0006373276003244, 2.0009560675564058, 2.0012748582938844, 2.001593699820851, 2.0019125921453975, 2.002231535275617, 2.002550529219603, 2.0028695739854525  …  2.7460599888467567, 2.746497488951431, 2.7469350587582766, 2.7473726982783964, 2.7478104075228993, 2.748248186502892, 2.748686035229486, 2.7491239537137924, 2.7495619419669253, 2.75], [0.9, 1.1], [-112.1563495362108 -111.59034497418871; -112.15535854461231 -111.58955931359183; … ; -110.38669318911688 -110.03449376085985; -110.3858560095761 -110.03393319761761], [0.9550833078066141 0.6840230771174625; 0.9554363287102734 0.6839881010285729; … ; 1.1040414979822395 1.3077757343120862; 1.1040853127324857 1.308260784493207], [0.7295686404938418 1.0; 0.728822936464917 1.0; … ; 0.560892749922569 0.455996446797052; 0.5608742242418713 0.4555286251089809], [2.0 2.5277454102017325; 2.0 2.528148128600473; … ; 2.665885147867692 2.75; 2.666309874580138 2.75])
=#