Write a series of functions that estimate the value function with a Chebyshev polynomial approximation (collocation). The idea is to guess the coefficients on the basis functions; in each iteration, solving the model at the grid points implies a new set of coefficients, and we repeat until the coefficients (equivalently, the value function on the grid) converge.
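As a minimal sketch of that outer loop, assuming a generic update map `update` (a hypothetical stand-in for "solve the model at the grid points and refit the coefficients"), the iteration stops once successive coefficient vectors stop changing. The toy update below is a simple contraction, chosen only so the fixed point is known in advance.

```julia
using LinearAlgebra

# Generic fixed-point loop on a coefficient vector.
# `update` is a hypothetical stand-in for one model step plus a refit of the coefficients.
function iterate_coefficients(update, c0; tol = 1e-4, maxiter = 10_000)
    c = c0
    for iter in 1:maxiter
        c_new = update(c)
        if norm(c_new - c, Inf) < tol   # largest change in any single coefficient
            return c_new, iter
        end
        c = c_new
    end
    error("coefficients did not converge in $maxiter iterations")
end

# Toy contraction: c -> 0.5c + b has the fixed point 2b.
b = [1.0, -2.0, 0.5]
c_star, iters = iterate_coefficients(c -> 0.5 .* c .+ b, zeros(3); tol = 1e-10)
```

Because the basis matrix at the collocation nodes is invertible, checking the coefficients and checking the values on the grid are interchangeable stopping rules.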


# Step 1: Specify the parameters and the interval of capital over which we approximate the value function

In [5]:
 using LinearAlgebra
 using Optim
 using Plots
 params = (alpha = 0.75, # capital share
           beta = 0.95, # discount factor
           eta = 2, # elasticity of marginal utility of consumption
           steady_state = (0.75*0.95)^(1/(1 - 0.75)), # (alpha*beta)^(1/(1 - alpha))
           k_0 = (0.75*0.95)^(1/(1 - 0.75))/2, # initial state
           capital_upper = (0.75*0.95)^(1/(1 - 0.75))*1.01, # upper bound
           capital_lower = (0.75*0.95)^(1/(1 - 0.75))/2, # lower bound
           num_points = 7, # number of grid points
           tolerance = 0.0001) # convergence tolerance

(alpha = 0.75, beta = 0.95, eta = 2, steady_state = 0.25771486816406236, k_0 = 0.12885743408203118, capital_upper = 0.26029201684570297, capital_lower = 0.12885743408203118, num_points = 7, tolerance = 0.0001)
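The literal `(0.75*0.95)^(1/(1 - 0.75))` is repeated four times above, so changing `alpha` or `beta` alone would leave the bounds out of sync. A sketch of a hypothetical helper `make_params` that derives them from `alpha` and `beta`; the steady-state formula assumes output `k^alpha` with full depreciation, where the Euler equation reduces to 1 = beta * alpha * k^(alpha - 1).

```julia
# Hypothetical helper: build the parameter tuple from alpha and beta so the
# steady state and the capital bounds stay consistent.
function make_params(; alpha = 0.75, beta = 0.95, eta = 2, num_points = 7, tolerance = 0.0001)
    # Steady state from 1 = beta * alpha * k^(alpha - 1),
    # assuming f(k) = k^alpha and full depreciation.
    k_ss = (alpha * beta)^(1 / (1 - alpha))
    return (alpha = alpha, beta = beta, eta = eta,
            steady_state = k_ss,
            k_0 = k_ss / 2,               # initial state
            capital_upper = k_ss * 1.01,  # upper bound
            capital_lower = k_ss / 2,     # lower bound
            num_points = num_points,
            tolerance = tolerance)
end

params = make_params()
```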

# Step 2: Make an initial guess for the coefficients on the basis functions

In [6]:
 coefficients = zeros(params.num_points) # collocation: number of coefficients = number of grid points

7-element Vector{Float64}:
 0.0
 0.0
 0.0
 0.0
 0.0
 0.0
 0.0

# Step 3: Select the convergence rule:

Stop iterating when the maximum change in the value function across the grid points between two successive iterations is below 0.0001%.
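A minimal sketch of this stopping rule. Read literally, 0.0001% is a relative tolerance of 1e-6, while `params.tolerance` in Step 1 is 0.0001, so the sketch offers both an absolute and a percentage version; which one is intended is an assumption. The function names are hypothetical.

```julia
# Absolute rule: largest change in V at any grid point.
max_abs_change(v_new, v_old) = maximum(abs.(v_new .- v_old))

# Relative rule, in percent of |V| at the same grid point.
# From a zero initial guess this gives Inf or NaN, which correctly reads as "not converged".
max_pct_change(v_new, v_old) = 100 * maximum(abs.(v_new .- v_old) ./ abs.(v_old))

converged_abs(v_new, v_old; tol = 1e-4) = max_abs_change(v_new, v_old) < tol
converged_pct(v_new, v_old; tol_pct = 1e-4) = max_pct_change(v_new, v_old) < tol_pct

v_old = [-100.0, -95.0, -90.0]
v_new = [-100.00005, -95.00002, -90.0]
```

The two rules can disagree when the value function is close to zero: a change of 5e-5 passes the absolute test but is a 0.5% change when V is about -0.01.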

# Step 4: Construct grid points

In [None]:
# Chebyshev polynomial of the first kind T_n(x) for x in [-1, 1] (scalar or array).
function cheb_polys(x, n)
    if n == 0
        return one.(x)               # T_0(x) = 1 (x ./ x would give NaN at x = 0)
    elseif n == 1
        return x                     # T_1(x) = x
    else
        # T_n(x) = 2x T_{n-1}(x) - T_{n-2}(x)
        return 2 .* x .* cheb_polys(x, n - 1) .- cheb_polys(x, n - 2)
    end
end;
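A quick check of the recurrence, as a sketch: on [-1, 1] the Chebyshev polynomials also satisfy T_n(x) = cos(n arccos x), so the two definitions should agree. `cheb_rec` and `cheb_closed` are illustrative names, not part of the notebook.

```julia
# Three-term recurrence T_n(x) = 2x T_{n-1}(x) - T_{n-2}(x), computed iteratively.
function cheb_rec(x, n)
    n == 0 && return one(x)
    t_prev, t = one(x), x          # T_0, T_1
    for _ in 2:n
        t_prev, t = t, 2x * t - t_prev
    end
    return t
end

# Closed form, valid on [-1, 1].
cheb_closed(x, n) = cos(n * acos(x))

xs = range(-1, 1; length = 21)
maxerr = maximum(abs(cheb_rec(x, n) - cheb_closed(x, n)) for x in xs, n in 0:6)
```

The recursive `cheb_polys` calls itself twice per degree, so its cost grows exponentially in `n`; that is harmless for degrees up to 6, but the loop above is linear in `n` if more grid points are needed.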

In [7]:
cheb_nodes(n) = cos.(pi * (2*(1:n) .- 1)./(2n));
grid = cheb_nodes(params.num_points) # [-1, 1] grid with n points

7-element Vector{Float64}:
  0.9749279121818236
  0.7818314824680298
  0.4338837391175582
  6.123233995736766e-17
 -0.43388373911755806
 -0.7818314824680297
 -0.9749279121818236
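The nodes produced by `cheb_nodes` are the n zeros of T_n, x_i = cos((2i - 1)π / (2n)), listed from largest to smallest. A small sketch confirming that with the closed form T_n(x) = cos(n arccos x); `chebT` is an illustrative name.

```julia
cheb_nodes(n) = cos.(pi * (2 * (1:n) .- 1) ./ (2n))   # same definition as the cell above
chebT(x, n) = cos(n * acos(x))                        # closed form on [-1, 1]

n = 7
nodes = cheb_nodes(n)
residuals = [chebT(x, n) for x in nodes]              # T_7 at each node, should be ≈ 0
```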

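The nodes lie on [-1, 1], where the Chebyshev polynomials are defined, so we need a map that expands them to the capital interval [`capital_lower`, `capital_upper`] and an inverse map that shrinks capital levels back to [-1, 1].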
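Written out, with $\underline{k}$ = `capital_lower` and $\bar{k}$ = `capital_upper`, the expanding map (`expand_grid`) and the shrinking map (`shrink_grid`) are inverses of each other:

$$
k = \underline{k} + \frac{(1 + x)(\bar{k} - \underline{k})}{2},
\qquad
x = \frac{2(k - \underline{k})}{\bar{k} - \underline{k}} - 1
$$

Substituting the first expression into the second returns $x$, which is why shrinking `capital_grid` recovers the original nodes up to rounding.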

In [8]:
expand_grid(grid, params) = # maps points in [-1, 1] to [capital_lower, capital_upper]
    (1 .+ grid) * (params.capital_upper - params.capital_lower) / 2 .+ params.capital_lower
capital_grid = expand_grid(grid, params)

7-element Vector{Float64}:
 0.2586443471450049
 0.2459545728087113
 0.2230883895732961
 0.19457472546386706
 0.16606106135443804
 0.14319487811902284
 0.13050510378272925

In [9]:
 shrink_grid(capital) = 
   2*(capital - params.capital_lower)/(params.capital_upper - params.capital_lower) - 1;
 shrink_grid.(capital_grid)

7-element Vector{Float64}:
  0.9749279121818237
  0.7818314824680297
  0.43388373911755806
 -2.220446049250313e-16
 -0.43388373911755806
 -0.7818314824680297
 -0.9749279121818236

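For each grid point, evaluate the Chebyshev polynomials of degree 0 through `num_points - 1` at the point mapped to [-1, 1]. Row i, column j of the resulting basis matrix is T_{j-1}(x_i).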

In [11]:
 construct_basis_matrix(grid, params) = # one column per polynomial degree
     hcat([cheb_polys.(shrink_grid.(grid), n) for n = 0:params.num_points - 1]...);
 basis_matrix = construct_basis_matrix(capital_grid, params)
 basis_inverse = basis_matrix \ I # pre-compute the inverse: coefficients = basis_inverse * values


7×7 Matrix{Float64}:
 0.142857    0.142857    0.142857   …   0.142857    0.142857    0.142857
 0.278551    0.22338     0.123967      -0.123967   -0.22338    -0.278551
 0.25742     0.0635774  -0.17814       -0.17814     0.0635774   0.25742
 0.22338    -0.123967   -0.278551       0.278551    0.123967   -0.22338
 0.17814    -0.25742    -0.0635774     -0.0635774  -0.25742     0.17814
 0.123967   -0.278551    0.22338    …  -0.22338     0.278551   -0.123967
 0.0635774  -0.17814     0.25742        0.25742    -0.17814     0.0635774
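The inverse printed above has a closed form: at the Chebyshev nodes the columns of the basis matrix B are orthogonal, with B'B = diag(n, n/2, ..., n/2), so B⁻¹ = diag(1/n, 2/n, ..., 2/n) B'. That is why the first row is 1/7 ≈ 0.142857 and the other rows are 2/7 times the polynomial values (for example 2/7 × 0.97493 ≈ 0.278551). A self-contained sketch checking this, building B directly on [-1, 1] with the closed form of T_j:

```julia
using LinearAlgebra

n = 7
nodes = cos.(pi .* (2 .* (1:n) .- 1) ./ (2n))          # Chebyshev nodes on [-1, 1]
B = [cos(j * acos(x)) for x in nodes, j in 0:n-1]      # B[i, j+1] = T_j(x_i)

D = Diagonal([n; fill(n / 2, n - 1)])                  # expected value of B'B
Binv_closed = inv(D) * B'                              # B^{-1} = D^{-1} B'
Binv_solve = B \ I                                     # what the notebook computes
```

For seven points the numerical inverse is cheap either way; the closed form mainly explains the pattern in the printed matrix.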

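Multiplying the basis matrix by the coefficient vector gives the approximated value function at every grid point. The points passed in must be capital levels (such as `capital_grid`), because `construct_basis_matrix` maps them to [-1, 1] with `shrink_grid`; passing the already-normalized `grid` maps the nodes a second time, far outside [-1, 1], and produces meaningless values.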

In [13]:
 # `capital` must be in capital units; construct_basis_matrix applies shrink_grid itself.
 eval_value_function(coefficients, capital, params) = construct_basis_matrix(capital, params) * coefficients;
 eval_value_function(ones(params.num_points), capital_grid, params)


In [19]:
for (iteration, capital) in enumerate(grid)
    println("$iteration + $capital")
end


1 + 0.9749279121818236
2 + 0.7818314824680298
3 + 0.4338837391175582
4 + 6.123233995736766e-17
5 + -0.43388373911755806
6 + -0.7818314824680297
7 + -0.9749279121818236
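Looping over the grid points like this is where the maximization step of each iteration goes. Putting the pieces together, here is a self-contained sketch of the full collocation loop under assumptions the notebook does not state: output is `k^alpha` with full depreciation (consistent with the steady-state formula in Step 1), utility is CRRA with `eta`, and the maximization over next-period capital uses a hand-written golden-section search instead of Optim.jl so the block runs without extra packages. All names here are hypothetical.

```julia
using LinearAlgebra

# Assumed model: V(k) = max_{k'} u(k^ALPHA - k') + BETA * V(k'), full depreciation, CRRA u.
const ALPHA = 0.75
const BETA = 0.95
const ETA = 2.0
const K_SS = (ALPHA * BETA)^(1 / (1 - ALPHA))
const K_LO = K_SS / 2
const K_HI = K_SS * 1.01
const N = 7

chebT(x, j) = cos(j * acos(clamp(x, -1.0, 1.0)))    # T_j(x) on [-1, 1]
shrink(k) = 2 * (k - K_LO) / (K_HI - K_LO) - 1        # capital -> [-1, 1]
expand(x) = K_LO + (1 + x) * (K_HI - K_LO) / 2        # [-1, 1] -> capital

const NODES = cos.(pi .* (2 .* (1:N) .- 1) ./ (2N))   # Chebyshev nodes
const KGRID = expand.(NODES)                           # nodes in capital units
const BINV = [chebT(x, j) for x in NODES, j in 0:N-1] \ I

u(c) = c^(1 - ETA) / (1 - ETA)                         # CRRA utility, ETA != 1
vhat(k, coeffs) = sum(coeffs[j + 1] * chebT(shrink(k), j) for j in 0:N-1)

# Golden-section search for the maximizer of a unimodal f on [a, b].
function golden_max(f, a, b; tol = 1e-10)
    g = (sqrt(5) - 1) / 2
    c, d = b - g * (b - a), a + g * (b - a)
    fc, fd = f(c), f(d)
    while b - a > tol
        if fc > fd                 # maximum lies in [a, d]
            b, d, fd = d, c, fc
            c = b - g * (b - a)
            fc = f(c)
        else                       # maximum lies in [c, b]
            a, c, fc = c, d, fd
            d = a + g * (b - a)
            fd = f(d)
        end
    end
    x = (a + b) / 2
    return x, f(x)
end

# Solve the Bellman max at each node, refit the coefficients, repeat until V settles.
function solve_vfi(; tol = 1e-4, maxiter = 5_000)
    coeffs = zeros(N)              # initial guess, as in Step 2
    v = zeros(N)
    for iter in 1:maxiter
        v_new = similar(v)
        for (i, k) in enumerate(KGRID)
            y = k^ALPHA                                # output
            kp_max = min(K_HI, y - 1e-8)               # keep c > 0 and k' inside the interval
            obj = kp -> u(y - kp) + BETA * vhat(kp, coeffs)
            v_new[i] = golden_max(obj, K_LO, kp_max)[2]
        end
        coeffs = BINV * v_new                          # implied coefficients
        gap = maximum(abs.(v_new .- v))                # absolute stopping rule
        v = v_new
        gap < tol && return coeffs, v, iter
    end
    error("value function iteration did not converge")
end

coeffs, v, iters = solve_vfi()
```

Because the coefficients are refit by interpolation, the combined step is not guaranteed to be a contraction, but for a smooth, concave value function on a narrow interval like this one the iteration typically settles at roughly the rate set by `BETA`.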
