In [None]:
] activate .

In [None]:
] st

In [None]:
using StaticArrays, DifferentialEquations, ForwardDiff, Optimization, OptimizationOptimJL, Zygote, Plots

In [None]:
"""
    f(t, x, u)

System dynamics: the state derivative is `(x[2], u[1])`, i.e. `x[1]` is driven
by `x[2]` and `x[2]` by the first control component. The time `t` is accepted
for a uniform `(t, x, u)` signature but is not used.

Returns a length-2 `SVector`.
"""
function f(t, x, u)
    velocity = x[2]
    acceleration = u[1]
    return SVector(velocity, acceleration)
end

In [None]:
"""
    l(x)

Safe-set function: the safe set is `{x : l(x) ≥ 0}`, i.e. the states with
`x[1] ≤ 1.5`. Only the first component of `x` is read.
"""
l(x) = 1.5 - x[1]

In [None]:
"""
    H(t, x, V, DxV; γ = 1.0, u = SVector(-1.0))

Hamiltonian evaluated with the fixed control `u`:

    H = min(0, DxV ⋅ f(t, x, u) + γ V)

`V` is the value and `DxV` the costate (the vector dotted with the dynamics
`f`). Because of the `min(0, ·)` clamp the result is never positive.

# Keywords
- `γ`: coefficient multiplying `V` (default `1.0`).
- `u`: control passed to `f` (default `SVector(-1.0)`).

The defaults reproduce the original hard-coded values, so existing 4-argument
calls (including the ForwardDiff closures in `DsH`/`DxH`/`DzH`) are unchanged.
"""
function H(t, x, V, DxV; γ = 1.0, u = SVector(-1.0))
    H2 = DxV' * f(t, x, u)
    # Clamp at 0. The unclamped variant of this Hamiltonian is `H2 + γ * V`.
    return min(0, H2 + γ * V)
end

"""
    DsH(t, x, V, DxV)

Gradient of `H` with respect to its fourth argument (the costate `DxV`),
computed with ForwardDiff while `t`, `x` and `V` are held fixed.
"""
function DsH(t, x, V, DxV)
    return ForwardDiff.gradient(DxV) do costate
        H(t, x, V, costate)
    end
end



"""
    DxH(t, x, V, DxV)

Gradient of `H` with respect to the state `x`, computed with ForwardDiff while
`t`, `V` and `DxV` are held fixed.
"""
function DxH(t, x, V, DxV)
    return ForwardDiff.gradient(x) do state
        H(t, state, V, DxV)
    end
end



"""
    DzH(t, x, V, DxV)

Scalar derivative of `H` with respect to the value `V`, computed with
ForwardDiff while `t`, `x` and `DxV` are held fixed.
"""
function DzH(t, x, V, DxV)
    return ForwardDiff.derivative(V) do value
        H(t, x, value, DxV)
    end
end

In [None]:
# Sanity check: ∂H/∂(costate) at a sample point (t = 0, x = [1, 1], V = 1, DxV = [1, 1]).
DsH(0.0, [1.0, 1.0], 1.0, [1.0, 1.0])

In [None]:
# Sanity check: ∂H/∂x at the same sample point.
DxH(0.0, [1.0, 1.0], 1.0, [1.0, 1.0])

In [None]:
# Sanity check: ∂H/∂V at the same sample point.
DzH(0.0, [1.0, 1.0], 1.0, [1.0, 1.0])

In [None]:
"""
    characteristics_system!(D, X, params, t)

In-place right-hand side of the characteristic ODEs for the Hamiltonian `H`,
in the `f!(du, u, p, t)` form accepted by `ODEProblem`.

The augmented state is laid out as `X = [x; s; z]`: state `x` and costate `s`
of length `n` each, followed by the scalar value `z`, so `length(X) == 2n + 1`.
Writes into `D` (same length as `X`):

    ẋ = ∂H/∂s
    ṡ = -∂H/∂x - (∂H/∂z) s
    ż = s ⋅ ∂H/∂s - H

`params` is unused. Returns `nothing`.

# Throws
- `DimensionMismatch` if `length(X)` is not of the form `2n + 1`.
"""
function characteristics_system!(D, X, params, t)
    N = length(X)
    isodd(N) || throw(DimensionMismatch("length(X) must be 2n + 1, got $N"))
    n = N ÷ 2
    xinds = 1:n
    sinds = (n + 1):(2n)
    zind = 2n + 1

    x = X[xinds]
    s = X[sinds]
    z = X[zind]

    # ∂H/∂s is needed for both ẋ and ż; evaluate the ForwardDiff gradient once.
    Hs = DsH(t, x, z, s)

    D[xinds] .= Hs
    D[sinds] .= -DxH(t, x, z, s) .- DzH(t, x, z, s) .* s
    D[zind] = s' * Hs - H(t, x, z, s)

    return nothing
end
    

In [None]:
# Initial state, initial costate and initial value for one test trajectory.
x0 = [1.0, 3.0]
s0 = [1.0, 1.0]
z0 = 0.0
# Augmented state [x; s; z] (vertical concatenation) and a same-shaped,
# uninitialized buffer for the right-hand side.
X0 = [x0; s0; z0]
D0 = similar(X0)

In [None]:
# Evaluate the characteristic right-hand side once at X0 (t = 0; `params` is
# unused, so `nothing` is passed) to fill D0.
characteristics_system!(D0, X0, nothing, 0.0)

In [None]:
# Inspect the derivatives [ẋ; ṡ; ż] written into D0 by the previous cell.
D0

In [None]:
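# Earlier version of the objective, kept for reference only; `loss(p)` below performs the same computation.
# function objective(x0, s0; tspan=(0.0, 2.0))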
#     X0 = vcat(x0, s0, 0.0)
#     prob = ODEProblem(characteristics_system!, X0, tspan)
#     sol = solve(prob)
#     # @show sol
#     xend = sol.u[end][1:2]
#     zend = sol.u[end][5]

#     return sol, l(xend) + zend

# end

In [None]:
# ODE problem for the characteristic system on t ∈ [0, 2]. X0 is only a template
# initial condition here: `loss` re-solves this problem with its own `u0`.
prob = ODEProblem(characteristics_system!, X0, (0.0, 2.0))

In [None]:
"""
    loss(p; x0 = [1.0, 3.0], odeprob = prob, solver = Tsit5())

Shooting objective for the characteristic system.

Solves `odeprob` from the augmented initial condition `[x0; p; 0.0]` (state
`x0`, initial costate `p`, initial value `z = 0`). Returns `(J, sol)`, where
`J = l(x_end) + z_end` and `x_end`, `z_end` are read from the last saved state
of `sol`.

# Keywords
- `x0`: initial state; its length `n` fixes the layout `[x; s; z]` (length `2n + 1`).
- `odeprob`: ODE problem to re-solve (defaults to the global `prob`).
- `solver`: ODE algorithm (defaults to `Tsit5()`).

# Throws
- `DimensionMismatch` if `length(p) != length(x0)`.

NOTE(review): the second return value is presumably forwarded to `callback` as
`pred` (older Optimization.jl convention); confirm with the installed version.
"""
function loss(p; x0 = [1.0, 3.0], odeprob = prob, solver = Tsit5())
    n = length(x0)
    length(p) == n ||
        throw(DimensionMismatch("costate length $(length(p)) must match state length $n"))
    X0 = vcat(x0, p, 0.0)

    # Re-solve the shared problem, overriding only the initial condition.
    sol = solve(odeprob, solver; u0 = X0)

    # Read the final augmented state once and split it as [x; s; z].
    Xend = sol.u[end]
    xend = Xend[1:n]
    zend = Xend[2n + 1]

    J = l(xend) + zend
    return J, sol
end

# Progress callback for Optimization.solve: receives the current decision
# variable, the objective value and the extra value returned by `loss`.
callback = function (costate, loss_value, trajectory)
    display(loss_value)
    display(plot(trajectory))
    # Returning `true` would make Optimization.solve stop; keep iterating.
    return false
end

  

In [None]:
# Objective wrapper: `x` is the decision variable (the initial costate handed to
# `loss`); the Optimization.jl hyperparameter `p` is ignored.
# NOTE(review): NelderMead (used below) is derivative-free, so this AD backend is
# presumably never invoked. A gradient-based solver would need Zygote to
# differentiate through the in-place ODE solve; confirm that works (e.g. via
# SciMLSensitivity) before switching solvers.
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, p) -> loss(x), adtype)


In [None]:
# Start the search from a zero initial costate.
optprob = Optimization.OptimizationProblem(optf, [0.0, 0.0])

In [None]:
# Solve the optimization problem with Optim's Nelder–Mead (the solver package, OptimizationOptimJL, is already imported in the first code cell).


In [None]:
# Nelder–Mead search over the initial costate, capped at 100 iterations;
# `callback` displays the loss value and a trajectory plot when invoked.
result_ode = Optimization.solve(optprob, NelderMead();
    callback = callback,
    maxiters = 100)  

In [None]:
# `objective` is commented out above, so calling it raised UndefVarError.
# Evaluate the same trajectory through `loss` instead. It starts from
# x0 = [1.0, 3.0] (matching the global `x0`) and returns `(J, sol)`; note the
# order. Pass `result_ode.u` instead of `s0` to inspect the optimized costate.
J, sol = loss(s0)

In [None]:
# Objective value of the trajectory computed in the previous cell.
J

In [None]:
# Full ODE solution of the characteristic system from the evaluation cell above.
sol