In [89]:
using CairoMakie, Dierckx, Optim, LinearAlgebra, QuantEcon

In [117]:
pars = (;α = 0.33, # Capital share
        β = 0.96, # Discount factor
        A = 10.0, # TFP
        γ = 2.0, # Risk aversion
        δ = 1.0, # Depreciation rate
        nk = 11, # Number of capital gridpoints
        θ = 2, # Grid expansion parameter
        lb = 0, # Lower bound of capital grid
        ub = 100.0, # Upper bound of capital grid
        nz = 15, # Number of shock gridpoints
        ρ = 0.98, # Persistence of AR(1) process
        μ = 0.5, # Mean of AR(1) process
        σ = 0.01, # Var of AR(1) process
        toler = 1e-6, # Tolerance
        maxiter = 1, # Maximum number of iterations
        print_skip = 10) # Print every x iterations

(α = 0.33, β = 0.96, A = 10.0, γ = 2.0, δ = 1.0, nk = 11, θ = 2, lb = 0, ub = 100.0, nz = 15, ρ = 0.98, μ = 0.5, σ = 0.01, toler = 1.0e-6, maxiter = 1, print_skip = 10)

In [121]:
function utility(c)
    if pars.γ == 1.0
        return log(c)
    else
        return (c^(1-pars.γ)) / (1-pars.γ)
    end
end 

function ar1(pars)
    (; μ, ρ, σ, nz) = pars
    mc = rouwenhorst(nz, μ, ρ, σ)
    @show mc.state_values
    Π, Zvals = mc.p, exp.(mc.state_values)
    return Π, Zvals
end

function exp_grid(pars)
    (; nk, θ, lb, ub) = pars
    grid = LinRange(1e-10, 1.0, nk)
    expgrid = lb .+ (ub .- lb) .* grid.^θ
    return expgrid
end

function interpolate(grid, vals, pars)
    spline = Spline1D(grid, vals, k = 1, bc = "extrapolate")
    return spline
end

function production(k, z ,pars)
    (; α, δ, A) = pars
    return A * z * (k ^ α) + (1.0 - δ) * k
end

function initial_guess(kgrid, zgrid, pars)
    (; nk, nz) = pars
    v_out = zeros(nk, nz)
    for i in 1:nk
        for j in 1:nz
            v_out[i,j] = utility(production(kgrid[i], zgrid[j], pars))
        end
    end
    return v_out
end

function resource_grid(kgrid, zgrid, pars)
    (; nk, nz) = pars
    Ygrid = zeros(nk, nz)
    for i in 1:nk
        for j in 1:nz
            Ygrid[i,j] = production(kgrid[i], zgrid[j], pars)
        end
    end
    return Ygrid
end

function capital(x, zgrid, res, pars)
    (; nk, nz, lb) = pars
    for i in 1:nk
        for j in 1:nz
            objective_function = k -> production(k, zgrid[j], pars) - res[i,j]
            ub = production(k, zgrid[j], pars)
            result = optimize(objective_function, lb, ub, Brent())
            x[i,j] = result.minimizer
        end
    end
    return x
end

function policy(pol, X, kgrid, pars)
    (; nz) = pars
    for j in 1:nz
        spl = x -> Spline1D(X[:,j], kgrid, k = 1, bc = "extrapolate")(x)
        pol[:,j] = spl.(kgrid)
    end
    return pol
end

function value(val, v, X, ystar, ygrid, kgrid, pars)
    (; nz) = pars
    for j in 1:nz
        spl = x -> Spline1D(X[:,j], ystar[:,j], k = 1, bc = "extrapolate")(x)
        spl2 = x -> Spline1D(ygrid[:,j], v[:,j], k = 1, bc = "extrapolate")(x)
        val[:,j] = spl2.(spl.(kgrid))
    end
    return val
end    

value (generic function with 1 method)

In [122]:
function egm(pars)
    (; γ, nk, nz, toler, maxiter, β, print_skip) = pars

    Π, Zvals = ar1(pars)
    Kvals = exp_grid(pars)

    v1 = initial_guess(Kvals, Zvals, pars)
    v2 = zeros(nk, nz)
    v3 = zeros(nk, nz)
    v4 = zeros(nk, nz)
    c = zeros(nk, nz)

    derivatives = zeros(nk, nz)

    Yvals = resource_grid(Kvals, Zvals, pars)
    Ystar = zeros(nk, nz)

    Kinterpolators = Vector{Spline1D}(undef, nz)
    Yinterpolators = Vector{Spline1D}(undef, nz)

    X = zeros(nk, nz)
    Value = zeros(nk, nz)
    Policy = zeros(nk, nz)

    error = toler + 1
    iter = 0
    if iter == 0
        println("Iterating...")
    end
    
    while ((error > toler) && (iter < maxiter))
        Kinterpolators = [interpolate(Kvals, v1[:,j], pars) for j in 1:nz]
        for (j, interpolation) in enumerate(Kinterpolators)
            derivatives[:,j] = Dierckx.derivative(interpolation, Kvals)
        end
        c = (derivatives).^(-1/γ)
        Ystar = c .+ Kvals
        #@show Ystar
        #@show Yvals
        v2 = utility.(c) .+ v1
        Yinterpolators = [interpolate(Ystar[:,j], v2[:,j], pars) for j in 1:nz]
        for (j, interpolation) in enumerate(Yinterpolators)
            v3[:,j] = interpolation(Yvals[:,j])
        end
        for i in 1:nk
            for j in 1:nz
                v4[i,j] = β * dot(v3[i,:], Π[j,:])
            end
        end
        #v4 = β * (v3 * Π')
        error = maximum(abs.(v4 - v1) ./ (1.0 .+ abs.(v1)))
        if iter % print_skip == 0
            println("--------------------")
            println("Iteration: $iter, Error: $error")
        end
        #fig2 = Figure(size = (800, 600))
        #ax2 = Axis(fig2[1, 1], xlabel = "Capital", ylabel = "Value")
        #for j in 1:3
        #    lines!(ax2, Kvals, v3[:,j], label = "Value $j, v3")
        #    lines!(ax2, Kvals, v4[:,j], label = "Value $j, v4")
        #end
        #legend = Legend(fig2[1, 2], ax2, halign = :right)
        #display(fig2)
        #v1 = copy(v4)
        iter += 1
    end

    println("--------------------")
    println("Converged in $iter iterations")
    println("--------------------")

    #X = capital(X, Zvals, Ystar, pars)
    #Value = value(Value, v1, X, Ystar, Yvals, Kvals, pars)
    #Policy = policy(Policy, X, Kvals, pars)

    return c, v1, Kvals, Ystar, Yvals, v2, v3, v4
end

egm (generic function with 1 method)

In [123]:
@time begin
    #C1, V1 = egm(pars)
    C1, V1, Kvals, Ystar, Yvals, V2, V3, V4 = egm(pars)
end

@show Yvals
#@show V1

#fig1 = Figure(size = (800, 600))
#ax1 = Axis(fig1[1, 1], title = "Value Functions", xlabel = "Assets", ylabel = "Value")
#for j in 1:pars.nz
#    lines!(ax1, Kvals, V4[:,j], label = "Shock $j")
#end
#legend = Legend(fig1[1,2], ax1, "Legend", orientation = :vertical, fontsize = 4)
#fig1

11×15 Matrix{Float64}:
 1.69764e-7  3.10841e-7  5.69153e-7  …     0.000441402     0.000808212
 0.147859    0.270731    0.495712        384.445         703.923
 0.233629    0.427777    0.783266        607.454        1112.26
 0.305314    0.559034    1.0236          793.842        1453.53
 0.369153    0.675923    1.23762         959.828        1757.46
 0.427727    0.783174    1.434       …  1112.13         2036.32
 0.482421    0.883319    1.61737        1254.34         2296.7
 0.534086    0.977918    1.79058        1388.67         2542.67
 0.583292    1.06801     1.95555        1516.61         2776.93
 0.630444    1.15435     2.11363        1639.21         3001.41
 0.675844    1.23748     2.26584     …  1757.25         3217.55

In [124]:
@show Ystar

11×15 Matrix{Float64}:
   0.000412025    0.000557531    0.000754423  …    0.0210096     0.0284291
   2.09921        2.48739        3.01266          57.0496       76.8435
   6.23052        7.01823        8.08412         117.737       157.903
  12.5155        13.7569        15.4369          188.257       251.562
  20.9255        22.665         25.0187          267.158       355.855
  31.442         33.717         36.7954       …  353.486       469.49
  44.0518        46.8953        50.7429          446.569       591.561
  58.7451        62.1866        66.8434          545.912       721.396
  75.5144        79.5806        85.0829          651.129       858.473
  94.3535        99.0693       105.45            761.909      1002.37
 113.354        118.069        124.45         …  780.909      1021.37

In [120]:
@show C1

11×15 Matrix{Float64}:
 1.0      0.707107  0.57735  0.5       …  0.27735   0.267261  0.258199
 1.73205  1.22474   1.0      0.866025     0.480384  0.46291   0.447214
 2.23607  1.58114   1.29099  1.11803      0.620174  0.597614  0.57735
 2.64575  1.87083   1.52753  1.32288      0.733799  0.707107  0.68313
 3.0      2.12132   1.73205  1.5          0.83205   0.801784  0.774597
 3.31662  2.34521   1.91485  1.65831   …  0.919866  0.886405  0.856349
 3.60555  2.54951   2.08167  1.80278      1.0       0.963624  0.930949
 3.87298  2.73861   2.23607  1.93649      1.07417   1.0351    1.0
 4.12311  2.91548   2.38048  2.06155      1.14354   1.10195   1.06458
 4.3589   3.08221   2.51661  2.17945      1.20894   1.16496   1.12546
 4.3589   3.08221   2.51661  2.17945   …  1.20894   1.16496   1.12546

In [116]:
@show V1

11×15 Matrix{Float64}:
  1.0   2.0   3.0   4.0   5.0   6.0  …   11.0   12.0   13.0   14.0   15.0
  2.0   4.0   6.0   8.0  10.0  12.0      22.0   24.0   26.0   28.0   30.0
  3.0   6.0   9.0  12.0  15.0  18.0      33.0   36.0   39.0   42.0   45.0
  4.0   8.0  12.0  16.0  20.0  24.0      44.0   48.0   52.0   56.0   60.0
  5.0  10.0  15.0  20.0  25.0  30.0      55.0   60.0   65.0   70.0   75.0
  6.0  12.0  18.0  24.0  30.0  36.0  …   66.0   72.0   78.0   84.0   90.0
  7.0  14.0  21.0  28.0  35.0  42.0      77.0   84.0   91.0   98.0  105.0
  8.0  16.0  24.0  32.0  40.0  48.0      88.0   96.0  104.0  112.0  120.0
  9.0  18.0  27.0  36.0  45.0  54.0      99.0  108.0  117.0  126.0  135.0
 10.0  20.0  30.0  40.0  50.0  60.0     110.0  120.0  130.0  140.0  150.0
 11.0  22.0  33.0  44.0  55.0  66.0  …  121.0  132.0  143.0  154.0  165.0

In [None]:
fig1 = Figure(size = (800, 600))
ax1 = Axis(fig1[1, 1], title = "Value Functions", xlabel = "Assets", ylabel = "Value")
for j in 1:pars.nz
    lines!(ax1, Kvals, Value[:,j], label = "Shock $j")
end
legend = Legend(fig1[1,2], ax1, "Legend", orientation = :vertical, fontsize = 4)
display(fig1)

fig2 = Figure(size = (800, 600))
ax2 = Axis(fig2[1, 1], title = "Policy Functions", xlabel = "Assets Today", ylabel = "Assets Tomorrow")
for j in 1:pars.nz
    lines!(ax2, Kvals, Policy[:,j], label = "Shock $j")
end
lines!(ax2, Kvals, Kvals, label = "45 Deg Line", color = :black, linestyle = :dash)
legend = Legend(fig2[1,2], ax2, "Legend", orientation = :vertical, fontsize = 4)
display(fig2)