In [None]:
using Pkg
pkg"activate ."

In [None]:
using PerlaTonettiWaugh, LinearAlgebra, Plots, BenchmarkTools, Interpolations, QuadGK, DifferentialEquations, BlackBoxOptim, Optim, Dierckx

Set up parameters and find the corresponding stationary solution:

In [None]:
# Grid of M evenly spaced points on [z_min, z_max]; only the grid itself is needed.
z_min = 0.0
z_max = 5.0
M = 1000
z_grid = range(z_min, stop = z_max, length = M)

# Baseline parameters (per Jesse), shared by both steady states.
params = (
    ρ = 0.02,
    σ = 4.2508,
    N = 10,
    θ = 5.1269,
    γ = 1.00,
    κ = 0.013,
    ζ = 1,
    η = 0,
    Theta = 1,
    χ = 1/(2.1868),
    υ = 0.0593,
    μ = 0,
    δ = 0.053,
)
δ = params.δ

# The t = 0 and t = T parameter sets differ only in `d`.
d_0 = 5
d_T = 2.3701
params_0 = merge(params, (d = d_0,))
params_T = merge(params, (d = d_T,))

# Stationary solution and its Ω at t = 0.
stationary_sol_0 = stationary_numerical(params_0, z_grid)
Ω_0 = stationary_sol_0.Ω

# Stationary solution and its Ω at t = T.
stationary_sol = stationary_numerical(params_T, z_grid)
Ω_T = stationary_sol.Ω

# Settings handed to `solve_dynamics` by the functions defined below.
settings = (z = z_grid, tstops = nothing, Δ_E = 1e-06)

println("g = $(stationary_sol.g), z_hat = $(stationary_sol.z_hat), Ω = $(stationary_sol.Ω)")

Define the objective function:

In [None]:
"""
    solve_with_candidate(candidate)

Build the `Ω(t)` and `E(t)` paths implied by `candidate` and return the detailed
output of `solve_dynamics` for them.

`candidate` is a vector or tuple whose last entry is the horizon `T`; the remaining
entries are `E_hat` node values. The nodes are sorted, a terminal node fixed at `0.0`
is appended, and the result is rescaled so that the first node maps to `-1` and the
last to `0` on an evenly spaced time grid over `[0, T]`.

The `E_hat` path is interpolated and integrated exactly as in `evaluate_candidate`
(cubic Dierckx `Spline1D` and `integrate`), so a candidate found by minimising that
objective is replayed here on the same path it was optimised against.

Returns `Inf` if the ODE for `Ω` cannot be solved. Errors raised by `solve_dynamics`
are not caught and propagate to the caller.
"""
function solve_with_candidate(candidate)
    candidate = [candidate...] # if candidate is a tuple, convert it to an array
    T = candidate[end]

    # Sorted E_hat nodes followed by the terminal node, which is pinned to zero.
    candidate = [sort(candidate[1:(end-1)]); 0.0]

    # Rescale the nodes so the first one is -1 and the last one is 0.
    E_hat_vec_range = candidate[end] - candidate[1]
    E_hat_vec_scaled = (candidate .- candidate[1]) ./ E_hat_vec_range .- 1.0
    ts = range(0.0, stop = T, length = length(candidate))
    E_hat = Spline1D(ts, E_hat_vec_scaled; k = 3) # cubic spline, as in evaluate_candidate

    # M makes the integral of M*E_hat over [0, T] equal log(Ω_T/Ω_0), so the ODE
    # below carries Ω from Ω_0 at t = 0 to Ω_T at t = T.
    M = log(Ω_T / Ω_0) / integrate(E_hat, 0, T)
    Ω_derivative(Ω, p, t) = M * E_hat(t) * Ω
    Ω_solution = try
        DifferentialEquations.solve(ODEProblem(Ω_derivative, Ω_0, (0.0, T)), reltol = 1e-15)
    catch
        return Inf
    end

    # Hold Ω at its terminal value beyond T. No type restriction on `t`, matching
    # the closures used by evaluate_candidate.
    Ω(t) = t <= T ? Ω_solution(t) : Ω_solution(T)
    E(t) = M * E_hat(t) + δ

    # Solve the dynamics and return the full (detailed) solution object.
    return solve_dynamics(params_T, stationary_sol, settings, T, Ω, E; detailed_solution = true)
end

"""
    evaluate_candidate(candidate)

Objective function for the optimizer: the weighted norm of the entry residuals
produced by the `Ω`/`E` paths that `candidate` implies.

`candidate` is a vector or tuple whose last entry is the horizon `T`; the other
entries are `E_hat` node values, which are sorted, extended with a terminal node
fixed at `0.0`, and rescaled so that the first node is `-1` and the last is `0`.

Returns `Inf` when either the ODE for `Ω` or `solve_dynamics` throws. Reads the
globals `WEIGHTS` and `ENTRY_RESIDUALS_NODES_COUNT`, which must be defined before
the first call.
"""
function evaluate_candidate(candidate)
    raw = [candidate...] # tuples are converted to an array
    T = raw[end]
    nodes = vcat(sort(raw[1:(end-1)]), 0.0) # terminal node is pinned to zero

    # Rescaled E_hat nodes on an evenly spaced time grid, joined by a cubic spline.
    first_node = nodes[1]
    node_span = nodes[end] - first_node
    scaled_nodes = (nodes .- first_node) ./ node_span .- 1.0
    time_grid = range(0.0, stop = T, length = length(nodes))
    E_hat = Spline1D(time_grid, scaled_nodes; k = 3)

    # Scale factor and the resulting ODE for Ω, started from Ω_0.
    M = log(Ω_T/Ω_0) / integrate(E_hat, 0, T)
    Ω_rhs(Ω, p, t) = M * E_hat(t) * Ω
    Ω_path = try
        DifferentialEquations.solve(ODEProblem(Ω_rhs, Ω_0, (0.0, T)), reltol = 1e-15)
    catch
        return Inf
    end
    Ω(t) = t <= T ? Ω_path(t) : Ω_path(T) # held at the terminal value past T
    E(t) = M * E_hat(t) + δ

    # Any failure inside solve_dynamics is scored as Inf.
    results = try
        solve_dynamics(params_T, stationary_sol, settings, T, Ω, E; detailed_solution = false).results
    catch
        return Inf
    end

    # Linearly interpolate the returned residuals and sample them at the interior
    # points of an evenly spaced grid on [0, T] (both endpoints are dropped).
    residual_at = LinearInterpolation(results.t, results.entry_residual)
    sample_grid = range(0, stop = T, length = ENTRY_RESIDUALS_NODES_COUNT + 2)
    residuals = residual_at.(sample_grid[2:(end-1)])

    return sqrt(sum(residuals .* WEIGHTS .* residuals))
end

Setup for optimizer:

In [None]:
# Number of interior points at which evaluate_candidate samples the entry residual.
ENTRY_RESIDUALS_NODES_COUNT = 15
# Not referenced elsewhere in this notebook.
E_NODE_COUNT = 15
# Passed as `time_limit` in the (commented-out) optimizer call.
MAX_TIME = 600
# Candidate vector: E_hat node values followed by the horizon T as the last entry.
SOLUTION = [
    -0.999796577020432,
    -0.8036317953452055,
    -0.645518803247465,
    -0.49798592837781547,
    -0.3129924179644726,
    -0.2924929980811199,
    -0.19484667032438624,
    -0.10940512894596714,
    -0.09952048306670645,
    -0.09391117400750042,
    -0.04912000885073606,
    -0.03274479455962143,
    -0.016461404893110952,
    -0.011579533358115129,
    34.61456648920588,
]

Find optimal `Ω`:

In [None]:
# WEIGHTS = [2; fill(1.5, 3); fill(1, ENTRY_RESIDUALS_NODES_COUNT-4)] # reducing residuals in the beginning helps overall behav.
# result2 = optimize(evaluate_candidate, SOLUTION, BFGS(), Optim.Options(time_limit = MAX_TIME, show_trace = true))
# SOLUTION = result2.minimizer

Find the corresponding solution

In [None]:
# The hard-coded SOLUTION defined in the optimizer setup cell gives a good solution:
# [-0.999796577020432; -0.8036317953452055; -0.645518803247465; -0.49798592837781547; -0.3129924179644726; -0.2924929980811199; -0.19484667032438624; -0.10940512894596714; -0.09952048306670645; -0.09391117400750042; -0.04912000885073606; -0.03274479455962143; -0.016461404893110952; -0.011579533358115129; 34.61456648920588]
solved = solve_with_candidate(SOLUTION)

# Keep v0 and build v_hat_t0 from the first M entries of the first saved ODE state
# (presumably the value function on z_grid at t = 0 -- confirm against solve_dynamics).
v_t0 = solved.sol.u[1][1:M]
v0 = solved.results[:v_0]
v_hat_t0 = exp.((params.σ - 1) .* z_grid) .* v_t0;

# From here on `solved` refers to the results table only.
solved = solved.results

## Plots for Ω and `entry_residual`

In [None]:
# Ω and the entry residual against time, stacked in two rows.
plot_Ω, plot_residual = map((
    (solved.Ω, "Omega"),
    (solved.entry_residual, "entry_residual"),
)) do (series, name)
    plot(solved.t, series, label = name, lw = 3)
end
plot(plot_Ω, plot_residual, layout = (2,1))

## Primary Plots

In [None]:
# g, z_hat, S and the entry residual against time, in a 2x2 grid.
plot1, plot2, plot3, plot4 = map((
    (solved.g, "g"),
    (solved.z_hat, "z_hat"),
    (solved.S, "S"),
    (solved.entry_residual, "entry_residual"),
)) do (series, name)
    plot(solved.t, series, label = name, lw = 3)
end
plot(plot1, plot2, plot3, plot4, layout=(2,2))

## Static Equations

In [None]:
# Time paths of the static-equation outputs.
plot1 = plot(solved.t, solved.L_tilde, label = "L_tilde", lw = 3)
plot2 = plot(solved.t, solved.z_bar, label = "z_bar", lw = 3)
plot3 = plot(solved.t, solved.π_min, label = "pi_min", lw = 3)
plot4 = plot(solved.t, solved.λ_ii, label = "lambda_ii", lw = 3)
plot5 = plot(solved.t, solved.c, label = "c", lw = 3)
plot6 = plot(solved.t, solved.E, label = "E", lw = 3)
plot7 = plot(solved.t, solved.log_M, label = "log_M", lw = 3)
plot8 = plot(solved.t, solved.U, label = "U", lw = 3)
# v_hat at t = 0 is plotted against the z grid rather than against time.
plot9 = plot(z_grid, v_hat_t0, label = "v_hat at t = 0", lw = 3)
# All nine panels are shown; previously plot9 was built but left out of the grid.
plot(plot1, plot2, plot3, plot4, plot5, plot6, plot7, plot8, plot9, layout=(3,3))

In [None]:
# The returned results (`solved`) can be explored interactively with DataVoyager
# and/or plotted with VegaLite.
using DataVoyager, VegaLite
# Pipe the results into a Voyager window.
solved |> Voyager()
# VegaLite alternative (line plot of g against t), left disabled:
#solved |> @vlplot(:line, x = :t, y = :g, width=400, height=400)