In [None]:
using DifferentialEquations
using Plots
include("NN_solver.jl")

In [None]:
# Global Plots.jl defaults applied to every figure below: a 400×300 canvas,
# thick lines (width 3), size-5 markers and no marker outline.
default(size=(400, 300), linewidth=3, markersize=5, markerstrokewidth=0)

In [None]:
# In-place right-hand side of the Lotka–Volterra system
#     dy1/dt =  1.5*y1 - y1*y2
#     dy2/dt = -3*y2   + y1*y2
# The time argument `t` is not used (the system is autonomous). The derivative
# is written into `du`, and the function returns the value written to du[2].
# NOTE(review): arguments follow the (t, u, du) order this notebook passes to
# ODEProblem/init_nn; newer DifferentialEquations releases expect (du, u, p, t)
# — confirm against NN_solver.jl before changing.
function lotka_volterra!(t, u, du)
    y1 = u[1]
    y2 = u[2]
    du[1] = 1.5 * y1 - 1.0 * y1 * y2
    rate2 = -3 * y2 + y1 * y2
    du[2] = rate2
    return rate2
end

# Initial state (y1, y2) and integration window shared by the reference solver
# and the neural-network solver.
y0_list = [1.0, 1.0]
tspan = (0.0, 5.0)

# 51 equally spaced training points covering tspan, stored as a 1×51 row matrix.
t = reshape(collect(linspace(tspan[1], tspan[2], 51)), 1, :)

In [None]:
# Reference solution of the same ODE. No algorithm is passed, so solve() picks one
# automatically; it is sampled every 0.05 time units at 1e-6 tolerances.
prob = ODEProblem(lotka_volterra!, y0_list, tspan)
sol = solve(prob; reltol=1e-6, abstol=1e-6, saveat=0.05);

In [None]:
# Build the neural-network ODE solver on the training grid `t`.
# n_hidden=20 is presumably the hidden-layer width — see NN_solver.jl.
nn = init_nn(lotka_volterra!, t, y0_list; n_hidden=20);
show(nn)  # one-off summary of the network

In [None]:
# Saved optimisation histories for the BFGS and GD runs. quickplot() reads the
# weight files row by row (p[i, :]), presumably one flattened parameter vector
# per iteration; the loss files hold the matching loss values.
p_BFGS, p_GD = readdlm("weights_BFGS.txt"), readdlm("weights_GD.txt")
l_BFGS, l_GD = readdlm("loss_BFGS.txt"), readdlm("loss_GD.txt")

# Cell output: check that the BFGS weight and loss histories line up.
(size(p_BFGS), size(l_BFGS))

In [None]:
"""
    quickplot(i, p)

Plot the network's prediction for the parameter snapshot stored in row `i` of
`p`, overlaid on the reference ODE solution `sol`.

Side effect: overwrites `nn.params_list` of the global network `nn` with that
snapshot. Uses the globals `nn` and `sol`. Returns the current plot (the value
of the final `xlabel!` call), so callers can chain `title!`/`savefig`.
"""
function quickplot(i, p)
    # Load snapshot `i` into the global network; get_unflat presumably reshapes
    # the flat row vector into the per-layer parameter list.
    nn.params_list = get_unflat(p[i, :], nn)
    # Only the first returned value (one prediction per state component) is used.
    y_pred_list, _ = predict(nn)

    # Network predictions as markers on the training grid.
    plot(nn.t[:], y_pred_list[1][:], label="y1 NN", lw=0, marker=:circle,
         legend=:topleft)
    plot!(nn.t[:], y_pred_list[2][:], label="y2 NN", lw=0, marker=:circle)

    # Reference solution as lines.
    plot!(sol.t, sol[1, :], label="y1 true")
    plot!(sol.t, sol[2, :], label="y2 true")

    # Fixed axis range so frames of an animation are comparable.
    ylims!(0, 8)
    return xlabel!("t")
end

In [None]:
# Single snapshot of the BFGS run. The plotted row and the title both come from
# one index, so they cannot drift apart. Row `i` is labelled as iteration `i-1`,
# the same convention the time-series loops use.
i_snap = 350
quickplot(i_snap, p_BFGS)
title!(@sprintf("BFGS; iter=%d; loss=%.2e", i_snap - 1, l_BFGS[i_snap]))

## Time series: one figure per optimizer iteration

In [None]:
# Save one PDF per stored BFGS iteration (rows 1:n_frames ↔ iterations 0:n_frames-1).
# Frame numbers are zero-padded to the width of the largest index. With that
# padding, lexicographic order of the file names matches iteration order
# (a 3-digit pad would put "BFGS_1000.pdf" before "BFGS_999.pdf").
n_frames = 1001
mkpath("./figures")  # make sure the output directory exists before savefig
for i = 1:n_frames
    i % 50 == 0 && print(i, " ")  # progress indicator
    quickplot(i, p_BFGS)
    title!(@sprintf("BFGS; iter=%d; loss=%.2e", i - 1, l_BFGS[i]))
    savefig("./figures/BFGS_" * lpad(i, ndigits(n_frames), "0") * ".pdf")
end

In [None]:
# Save one PDF per stored GD iteration (rows 1:n_frames ↔ iterations 0:n_frames-1).
# Frame numbers are zero-padded to the width of the largest index. With that
# padding, lexicographic order of the file names matches iteration order
# (a 3-digit pad would put "GD_1000.pdf" before "GD_999.pdf").
n_frames = 1001
mkpath("./figures")  # make sure the output directory exists before savefig
for i = 1:n_frames
    i % 50 == 0 && print(i, " ")  # progress indicator
    quickplot(i, p_GD)
    title!(@sprintf("GD; iter=%d; loss=%.2e", i - 1, l_GD[i]))
    savefig("./figures/GD_" * lpad(i, ndigits(n_frames), "0") * ".pdf")
end