In [None]:
import numpy as np
from numba import njit

In [None]:
# Model parameters -- each enters the HJB written out in the markdown cell below.
median = 1.75/1000
γ = 0.018
τ = γ*median        # τ = median · γ
δ = 0.01
η = 0.032
ξ_m = 0.00256       # weight on the ξ_m h_2²/2 penalty term

# z_2 has drift -ρ (z_2 - μ_2) and volatility sqrt(z_2) σ_2 (see the HJB below).
μ_2 = 1.0
ρ = 0.5
# Moment matching with the 100-year std: choose σ_2 so that the stationary
# std of z_2, sqrt(μ_2 σ_2² / (2ρ)), equals 0.21.
σ_2 = np.sqrt(0.21**2 * 2 * ρ / μ_2)

HJB equation for the value function $\psi(b, z_2; \ell)$:

\begin{align*}
0 &= \max_{e}\min_{h_2} b\left[\delta \eta \log e - \tau z_2 e + \xi_m \frac{(h_2)^2}{2}\right] - \ell e - \frac{\partial \psi}{\partial b}(b,z_2;\ell) \delta b\\
&+\left[\frac{\partial \psi}{\partial z_2}(b,z_2;\ell)\right]\left[-\rho(z_2-\mu_2)+\sqrt{z_2}\sigma_2 h_2\right] + \left[\frac{\partial^2 \psi}{\partial(z_2)^2}(b,z_2;\ell)\right]\left(\frac{z_2|\sigma_2|^2}{2}\right)
\end{align*}

FOC for $h_2$ gives:
$$
h_2^* = -\frac{\frac{\partial \psi}{\partial z_2}(b,z_2;\ell)\sqrt{z_2}\sigma_2}{b\xi_m}
$$

FOC for $e$ gives:
$$
e^* = \frac{b\delta \eta}{b\tau z_2 + \ell}
$$

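Solve with a false-transient (pseudo-time) iteration with step size $\epsilon$. Terms in red are evaluated at the new iterate $\psi_{i+1}$ and treated implicitly. Terms in blue use the current iterate $\psi_i$, including the controls $e^*$ and $h_2^*$: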

\begin{align*}
\frac{\color{red}{\psi_{i+1}(b,z_2;\ell)}-\color{blue}{\psi_{i}(b,z_2;\ell)}}{\epsilon} &= b\left[\delta \eta \log \color{blue}{e^*} - \tau z_2 \color{blue}{e^*} + \xi_m \frac{(\color{blue}{h_2^*})^2}{2}\right] - \ell \color{blue}{e^*} - \color{red}{\frac{\partial \psi}{\partial b}(b,z_2;\ell)} \delta b\\
&+\left[\color{red}{\frac{\partial \psi}{\partial z_2}(b,z_2;\ell)}\right]\left[-\rho(z_2-\mu_2)+\sqrt{z_2}\sigma_2 \color{blue}{h_2^*}\right] + \left[\color{red}{\frac{\partial^2 \psi}{\partial(z_2)^2}(b,z_2;\ell)}\right]\left(\frac{z_2|\sigma_2|^2}{2}\right)
\end{align*}

In [None]:
@njit(parallel=True, cache=True)  # NOTE(review): no prange loops here, so parallel=True likely has little effect
def solver(ψ_grid, b_grid, z_grid, ℓ, ϵ, τ, δ, η, ξ_m, μ_2, σ_2, ρ):
    """One implicit false-transient step for ψ(b, z_2; ℓ).

    Builds the dense linear system LHS @ ψ_new = RHS for the discretised HJB
    and returns ψ_new. The controls e* and h_2* and the flow term are evaluated
    at the previous iterate ``ψ_grid`` (explicit); the derivative terms act on
    ψ_new (implicit).

    Parameters
    ----------
    ψ_grid : 1-D array of length len(b_grid)*len(z_grid)
        Previous iterate, flattened with b varying fastest: the value at
        (b_grid[i], z_grid[j]) sits at flat index j*n_b + i.
    b_grid, z_grid : 1-D grids. Only the first spacing is used, so both are
        assumed uniformly spaced; b_grid is assumed > 0 (b divides h_2* and
        enters log e*).
    ℓ, τ, δ, η, ξ_m, μ_2, σ_2, ρ : model parameters.
    ϵ : false-transient (pseudo-time) step size.

    Returns
    -------
    1-D array with the same layout as ``ψ_grid``.

    Raises
    ------
    ValueError
        If n_b < 2 or n_z < 3 (the one-sided stencils would index out of range
        or silently wrap via negative indices), or if ψ_grid has the wrong length.

    Note: the system is stored densely, (n_b*n_z)**2 floats.
    """
    n_b = len(b_grid)
    n_z = len(z_grid)
    if n_b < 2 or n_z < 3:
        raise ValueError("solver needs at least 2 b points and 3 z points")
    if len(ψ_grid) != n_b*n_z:
        raise ValueError("psi_grid must have length len(b_grid)*len(z_grid)")
    Δ_b = b_grid[1] - b_grid[0]
    Δ_z = z_grid[1] - z_grid[0]
    LHS = np.zeros((n_b*n_z, n_b*n_z))
    RHS = np.zeros(n_b*n_z)
    for j in range(n_z):
        z = z_grid[j]
        temp_2 = z*σ_2**2/2  # diffusion coefficient z_2 σ_2² / 2
        for i in range(n_b):
            idx = j*n_b + i
            ψ = ψ_grid[idx]
            b = b_grid[i]
            # Optimal emissions from the FOC; independent of ψ.
            e = b*δ*η/(b*τ*z+ℓ)

            LHS[idx, idx] += - 1./ϵ

            # b-direction: drift -δ b (<= 0 for b, δ >= 0), so use the upwind
            # backward difference, except at i == 0 where only a forward
            # difference is available.
            if i == 0:
                LHS[idx, idx] += δ*b/Δ_b
                LHS[idx, idx + 1] += -δ*b/Δ_b
            else:
                LHS[idx, idx] += -δ*b/Δ_b
                LHS[idx, idx - 1] += δ*b/Δ_b

            # z-direction: the neighbours in z are n_b flat positions apart.
            # dψ/dz from the previous iterate determines the worst-case h_2*.
            if j == 0:
                dψdz = (ψ_grid[idx + n_b] - ψ)/Δ_z
            elif j == n_z-1:
                dψdz = (ψ - ψ_grid[idx - n_b])/Δ_z
            else:
                dψdz = (ψ_grid[idx + n_b] - ψ_grid[idx - n_b])/(2*Δ_z)
            h = -dψdz*np.sqrt(z)*σ_2/(b*ξ_m)
            temp_1 = -ρ*(z-μ_2) + np.sqrt(z)*σ_2*h  # z-drift under h_2*

            if j == 0:
                # Forward first difference; one-sided second difference on j, j+1, j+2.
                LHS[idx, idx] += -temp_1/Δ_z + temp_2/(Δ_z**2)
                LHS[idx, idx + n_b] += temp_1/Δ_z - temp_2*2/(Δ_z**2)
                LHS[idx, idx + 2*n_b] += temp_2/(Δ_z**2)
            elif j == n_z-1:
                # Backward first difference; one-sided second difference on j, j-1, j-2.
                LHS[idx, idx] += temp_1/Δ_z + temp_2/(Δ_z**2)
                LHS[idx, idx - n_b] += -temp_1/Δ_z - temp_2*2/(Δ_z**2)
                LHS[idx, idx - 2*n_b] += temp_2/(Δ_z**2)
            else:
                # Upwind first difference on the sign of the drift; central second difference.
                LHS[idx, idx] += temp_1/Δ_z*(-1.*(temp_1>0)+(temp_1<0)) - temp_2*2/(Δ_z**2)
                LHS[idx, idx + n_b] += temp_1/Δ_z*(temp_1>0) + temp_2/(Δ_z**2)
                LHS[idx, idx - n_b] += -temp_1/Δ_z*(temp_1<0) + temp_2/(Δ_z**2)

            RHS[idx] = -(1./ϵ*ψ + b*(δ*η*np.log(e) - τ*z*e + ξ_m*h**2/2) - ℓ*e)
    return np.linalg.solve(LHS, RHS)


@njit
def false_transient(ψ_grid, b_grid, z_grid, ℓ, ϵ, τ, δ, η, ξ_m, μ_2, σ_2, ρ, max_iter=10_000, tol=1e-9):
    """Iterate ``solver`` from ``ψ_grid`` until successive iterates agree.

    Convergence is measured as max|ψ_new - ψ_old| / ϵ, i.e. the largest
    pseudo-time derivative; iteration stops when it is <= ``tol`` or after
    ``max_iter`` steps. Progress is printed every iteration, and a warning is
    printed if the loop ends without meeting ``tol``.

    Returns the last iterate (same flat layout as ``ψ_grid``).
    """
    # Start at +inf so at least one step is always taken, even when tol >= 1.
    error = np.inf
    count = 0
    while error > tol and count < max_iter:
        # solver returns a fresh array and never writes to its input, so a
        # reference to the previous iterate is enough (no copy needed).
        ψ_prev = ψ_grid
        ψ_grid = solver(ψ_grid, b_grid, z_grid, ℓ, ϵ, τ, δ, η, ξ_m, μ_2, σ_2, ρ)
        error = np.max(np.abs(ψ_prev-ψ_grid))/ϵ
        count += 1
        print('Iteration:', count, ', error:', error)
    if error > tol:
        print('Warning: false_transient reached max_iter without converging; last error:', error)
    return ψ_grid

In [None]:
# Numerical settings: false-transient step size and the (b, z_2) grids.
ϵ = 0.5            # false-transient (pseudo-time) step
b_min = 1e-2       # keep b > 0: b divides h_2* and enters log e*
b_max = 1.0
z_min = 1e-5
z_max = 2.0
n_b = 200
n_z = 20
b_grid = np.linspace(b_min, b_max, n_b)
z_grid = np.linspace(z_min, z_max, n_z)
ψ_grid = np.zeros(n_b*n_z) # initial guess, flat with b varying fastest

In [None]:
# Solve the HJB for a log-spaced grid of ℓ values, keeping every solution in
# ψ_grid_list and saving each one to res_<i>.npy. After the loop, ψ_grid and ℓ
# hold the values for the last ℓ.
log_ell_grid = np.linspace(-20, -5, 20)
ψ_grid_list = []
for i, log_ell in enumerate(log_ell_grid):
    ψ_grid = np.zeros(n_b*n_z) # initial guess: cold start for every ℓ
    ℓ = np.exp(log_ell)
    print(i)
    ψ_grid = false_transient(ψ_grid, b_grid, z_grid, ℓ, ϵ, τ, δ, η, ξ_m, μ_2, σ_2, ρ, max_iter=10_000, tol=1e-7)
    ψ_grid_list.append(ψ_grid)
    np.save(f'res_{i}', ψ_grid)

In [None]:
# Un-flatten the last solution so that ϕ[i, j] = ψ(b_grid[i], z_grid[j]; ℓ).
# solver stores point (i, j) at flat index j*n_b + i, i.e. Fortran order.
ϕ = ψ_grid.reshape((n_b, n_z), order='F')

In [None]:
# Optimal emissions e* along the b grid, holding z_2 fixed at z_grid[z_loc].
# From the FOC, e* = b δ η / (b τ z_2 + ℓ): it depends on (b, z_2, ℓ) only and
# not on derivatives of ψ, matching the `e` computed inside `solver`.
# ℓ is the value left over from the last iteration of the ℓ loop above.
z_loc = 10
z = z_grid[z_loc]
e_grid = b_grid*δ*η/(b_grid*τ*z + ℓ)

In [None]:
import matplotlib.pyplot as plt
# plt.plot(e_grid)
# plt.ylim(0, 20)

In [None]:
# ψ as a function of b at the lowest z_2 grid point.
fig, ax = plt.subplots()
ax.plot(b_grid, ϕ[:, 0])
ax.set(xlabel='$b$', ylabel=r'$\psi(b, z_2)$', title=f'$z_2$ = {z_grid[0]:.2g}', ylim=(0, 0.05));

In [None]:
# Value function along b at z_2 grid index 10 (row 10 of the transpose).
ϕ.T[10]

In [None]:
import pickle

In [None]:
# Load the pickled reference solution.
# NOTE(review): unpickling can execute arbitrary code -- only load files from a trusted source.
with open('solu_modified_40200_0900', 'rb') as handle:
    b = pickle.loads(handle.read())

In [None]:
# Reference value function stored under the 'phi' key of the unpickled object
# (presumably a 2-D array -- it is indexed as ϕ_suri[0, :] and transposed below; confirm).
ϕ_suri = b['phi']

In [None]:
# Reference solution along its second axis, at index 0 of its first axis.
# NOTE(review): presumably ϕ_suri is laid out as (z, b), making this the
# lowest-z_2 slice comparable to ϕ[:, 0] -- confirm.
fig, ax = plt.subplots()
ax.plot(ϕ_suri[0, :])
ax.set(xlabel='grid index (second axis)', ylabel='ϕ_suri[0, :]', ylim=(0, 0.05));

In [None]:
# Flatten so that the last axis varies fastest. Assuming ϕ_suri is (z, b), this
# gives solver's flat index j*n_b + i. A plain C-order reshape is the same as
# ϕ_suri.T.reshape(-1, order='F').
ϕ_suri_new = ϕ_suri.reshape(-1)

In [None]:
# Apply one false-transient step to the reference solution. If it already solves
# the discretised HJB, the step should leave it (nearly) unchanged.
# NOTE(review): the reference may come from a different grid or formulation
# (e.g. without ℓ); the assert below at least catches a grid-size mismatch.
assert ϕ_suri_new.size == len(b_grid)*len(z_grid), "reference solution does not match the (b, z) grid"
res = solver(ϕ_suri_new, b_grid, z_grid, ℓ, ϵ, τ, δ, η, ξ_m, μ_2, σ_2, ρ)

In [None]:
# Largest absolute pseudo-time derivative -- the same error metric as false_transient.
# (Taking abs of the max would understate deviations that are all negative.)
np.max(np.abs(res - ϕ_suri_new))/ϵ