In [13]:
using CairoMakie, Dierckx, Optim, LinearAlgebra, QuantEcon

In [14]:
pars = (;α = 0.33, # Capital share
        β = 0.9, # Discount factor
        A = 10.0, # TFP
        γ = 1.0, # Risk aversion
        δ = 0.1, # Depreciation rate
        nk = 31, # Number of capital gridpoints
        θ = 4, # Grid expansion parameter
        lb = 0.00, # Lower bound of capital grid
        ub = 1000.0, # Upper bound of capital grid
        nz = 19, # Number of shock gridpoints
        ρ = 0.9, # Persistence of AR(1) process
        μ = 0.0, # Mean of AR(1) process
        σ = 0.007, # Var of AR(1) process
        toler = 1e-6, # Tolerance
        maxiter = 10000) # Maximum number of iterations

(α = 0.33, β = 0.9, A = 10.0, γ = 1.0, δ = 0.1, nk = 31, θ = 4, lb = 0.0, ub = 1000.0, nz = 19, ρ = 0.9, μ = 0.0, σ = 0.007, toler = 1.0e-6, maxiter = 10000)

In [15]:
function utility(c)
    if pars.γ == 1.0
        return log(c)
    else
        return (c^(1-pars.γ)) / (1-pars.γ)
    end
end

function ar1(pars)
    (; μ, ρ, σ, nz) = pars
    mc = rouwenhorst(nz, μ, ρ, σ)
    Π, Zvals = mc.p, exp.(mc.state_values)
    return Π, Zvals
end

function exp_grid(pars)
    (; nk, θ, lb, ub) = pars
    grid = LinRange(1e-4, 1.0, nk)
    expgrid = lb .+ (ub .- lb) .* grid.^θ
    return expgrid
end

function interpolate(grid, vals, pars)
    spline = Spline1D(grid, vals, k = 1, bc = "extrapolate")
    return spline
end

function production(k, z ,pars)
    (; α, δ, A) = pars
    return A * z * (k ^ α) + (1.0 - δ) * k
end

function initial_guess(grid, pars)
    (; nk, nz) = pars
    v = ones(nk, nz)
    v_out = zeros(nk, nz)
    for j in 1:nz
        v_out[:,j] = utility.(grid)
    end
    return v_out
end

function resource_grid(kgrid, zgrid, pars)
    (; nk, nz) = pars
    Ygrid = zeros(nk, nz)
    for i in 1:nk
        for j in 1:nz
            Ygrid[i,j] = production(kgrid[i], zgrid[j], pars)
        end
    end
    return Ygrid
end

resource_grid (generic function with 1 method)

In [29]:
function egm(pars)
    (; γ, nk, nz, toler, maxiter, β) = pars

    Π, Zvals = ar1(pars)
    Kvals = exp_grid(pars)

    v1 = initial_guess(exp_grid(pars), pars)
    v2 = zeros(nk, nz)
    v3 = zeros(nk, nz)
    v4 = zeros(nk, nz)

    c = zeros(nk, nz)

    derivatives = zeros(nk, nz)

    Yvals = resource_grid(Kvals, Zvals, pars)
    Ystar = zeros(nk, nz)

    Kinterpolators = Vector{Spline1D}(undef, nz)
    Yinterpolators = Vector{Spline1D}(undef, nz)

    error = toler + 1
    iter = 0
    if iter == 0
        println("Iterating...")
    end
    while ((error > toler) && (iter < maxiter))
        Kinterpolators = [interpolate(Kvals, v1[:,j], pars) for j in 1:nz]
        for (j, interpolator) in enumerate(Kinterpolators)
            derivatives[:,j] = Dierckx.derivative(interpolator, Kvals)
        end
        c = (derivatives).^(-1/γ)
        Ystar = c .+ Kvals
        v2 = utility.(c) .+ v1
        Yinterpolators = [interpolate(Ystar[:,j], v2[:,j], pars) for j in 1:nz]
        for (j, interpolator) in enumerate(Yinterpolators)
            v3[:,j] = interpolator(Yvals[:,j])
        end
        v4 = β * (v3 * Π')
        error = maximum(abs.(v4 - v1) ./ (1.0 .+ abs.(v1)))
        if iter % 10 == 0
            println("--------------------")
            println("Iteration: $iter, Error: $error")
        end
        v1 .= v4
        iter += 1
    end
    println("--------------------")
    println("Converged in $iter iterations")
    println("--------------------")
    return v1
end
test = egm(pars)  

Iterating...
--------------------
Iteration: 0, Error: 2.4950651087663642
--------------------
Iteration: 10, Error: 0.06487028005195125
--------------------
Iteration: 20, Error: 0.016215586365355816
--------------------
Iteration: 30, Error: 0.0051302026572092665
--------------------
Iteration: 40, Error: 0.0017318089522014716
--------------------
Iteration: 50, Error: 0.000597154420308118
--------------------
Iteration: 60, Error: 0.00020741056332557623
--------------------
Iteration: 70, Error: 7.222215213126745e-5
--------------------
Iteration: 80, Error: 2.5170473844820828e-5
--------------------
Iteration: 90, Error: 8.774963334097296e-6
--------------------
Iteration: 100, Error: 3.0594656874694094e-6
--------------------
Iteration: 110, Error: 1.0667484674043117e-6
--------------------
Converged in 112 iterations
--------------------


31×19 Matrix{Float64}:
 -56.0665  -56.0665  -56.0665  -56.0665  …  -56.0665  -56.0665  -56.0665
  25.4719   25.4719   25.4719   25.4719      25.4719   25.4719   25.4719
  26.9597   26.9597   26.9597   26.9597      26.9597   26.9597   26.9597
  27.892    27.892    27.892    27.892       27.892    27.892    27.892
  28.6088   28.6088   28.6088   28.6088      28.6088   28.6088   28.6088
  29.206    29.206    29.206    29.206   …   29.206    29.206    29.206
  29.7487   29.7487   29.7487   29.7487      29.7487   29.7487   29.7487
  30.2751   30.2751   30.2751   30.2751      30.2751   30.2751   30.2751
  30.7688   30.7688   30.7688   30.7688      30.7688   30.7688   30.7688
  31.2535   31.2535   31.2535   31.2535      31.2535   31.2535   31.2535
   ⋮                                     ⋱                      
  37.7055   37.7055   37.7055   37.7055      37.7055   37.7055   37.7055
  38.2182   38.2182   38.2182   38.2182      38.2182   38.2182   38.2182
  38.7313   38.7313   38.7313   38.731