In [None]:
using DifferentialEquations, Flux, Zygote, LinearAlgebra, ForwardDiff, FiniteDiff, Distributions, Plots, Random


In [None]:
using LNNProject
# Report how many Julia threads are available; the training loop below spreads the
# per-sample gradient computations over them with `Threads.@threads`.
Threads.nthreads()

## Demonstration of the framework

## Analytical double-pendulum model

In [None]:
"""
    analytical_RHS(du, u, temp, t=0)

In-place right-hand side of the double-pendulum ODE with unit masses and lengths
and `g = 9.81`. The state is `u = (θ1, θ2, ω1, ω2)`; the derivative
`(ω1, ω2, ω̇1, ω̇2)` is written into `du`, which is also returned.
`temp` and `t` are accepted to fit the ODE-function signature but are unused.
"""
function analytical_RHS(du, u, temp, t=0)
    θ1, θ2, ω1, ω2 = u[1], u[2], u[3], u[4]
    m1, m2, l1, l2, g = (1, 1, 1, 1, 9.81)

    Δ = θ1 - θ2
    sinΔ = sin(Δ)
    cosΔ = cos(Δ)
    λ21 = l2 / l1
    λ12 = l1 / l2
    μ = m2 / (m1 + m2)

    # Coupling coefficients and forcing terms of the system
    #   ω̇1 + a1 ω̇2 = f1,   a2 ω̇1 + ω̇2 = f2
    a1 = λ21 * μ * cosΔ
    a2 = λ12 * cosΔ
    f1 = -λ21 * μ * ω2^2 * sinΔ - (g / l1) * sin(θ1)
    f2 = λ12 * ω1^2 * sinΔ - (g / l2) * sin(θ2)

    # Solve the 2×2 system by Cramer's rule.
    denom = 1 - a1 * a2
    du[1] = ω1
    du[2] = ω2
    du[3] = (f1 - a1 * f2) / denom
    du[4] = (f2 - a2 * f1) / denom
    return du
end
"""
    analytical_Lagrangian(x)

Lagrangian `L = T - V` of the double pendulum with unit masses and lengths and
`g = 9.81`, evaluated at `x = (θ1, θ2, θ̇1, θ̇2)`. The potential terms enter with
a `+cos` sign, so `V` is minimal at `θ = 0` (both bobs hanging straight down).
"""
function analytical_Lagrangian(x)
    m1, m2, l1, l2, g = (1, 1, 1, 1, 9.81)
    θ1, θ2, ω1, ω2 = x[1], x[2], x[3], x[4]
    M = m1 + m2

    kinetic = 0.5 * M * l1^2 * ω1^2 + 0.5 * m2 * l2^2 * ω2^2 +
              m2 * l1 * l2 * ω1 * ω2 * cos(θ1 - θ2)

    # Adding these terms is subtracting V = -(m1 + m2) g l1 cos θ1 - m2 g l2 cos θ2.
    return kinetic + M * g * l1 * cos(θ1) + m2 * g * l2 * cos(θ2)
end

In [None]:
"""
    energy(q, q_dot, m1, m2, l1, l2, g)

Total mechanical energy `T + V` of a double pendulum.

# Arguments
- `q`: angles `(θ1, θ2)`, measured from the downward vertical (a bob at `θ = 0`
  sits at height `-l` below its pivot).
- `q_dot`: angular velocities `(ω1, ω2)`.
- `m1`, `m2`: bob masses; `l1`, `l2`: rod lengths; `g`: gravitational acceleration.

# Returns
The scalar energy `T + V`, with `V` measured relative to the upper pivot.
"""
function energy(q, q_dot, m1, m2, l1, l2, g)
    t1, t2 = q      # θ1, θ2
    w1, w2 = q_dot  # ω1, ω2

    # Kinetic energy T. The second bob's speed includes the cross term from both rods.
    # (Was `jnp.cos`, a leftover from a JAX port; `jnp` is undefined in Julia.)
    T1 = 0.5 * m1 * (l1 * w1)^2
    T2 = 0.5 * m2 * ((l1 * w1)^2 + (l2 * w2)^2 +
                     2 * l1 * l2 * w1 * w2 * cos(t1 - t2))
    T = T1 + T2

    # Potential energy V; heights are negative below the pivot.
    y1 = -l1 * cos(t1)
    y2 = y1 - l2 * cos(t2)
    V = m1 * g * y1 + m2 * g * y2

    return T + V
end

$$
\begin{align}
L = & \frac{1}{2}(m_1 + m_2) l_1^2 \dot{\theta}_1^2 +
	\frac{1}{2}m_2 l_2^2 \dot{\theta}_2^2 + m_2l_1l_2\dot{\theta}_1\dot{\theta}_2
	\cos(\theta_1 - \theta_2)\nonumber\\[3pt]
     &+ (m_1 + m_2) g l_1 \cos\theta_1 + m_2 g l_2\cos\theta_2
\end{align}
$$

In [None]:
"""
    analytical_sol(x_0, saveat)

Integrate `analytical_RHS` from the initial state `x_0` over `(0.0, saveat[end])`
with `Tsit5()`, saving at the times in `saveat`.

Returns `(saveat, states)`, where `states` is the adjoint (a 1×N row) of the vector
of saved state vectors. This is the same layout as `sol.u'` on an LNN solution.
"""
function analytical_sol(x_0, saveat)
    tspan = (0.0, saveat[end])
    # The parameter argument is a placeholder; `analytical_RHS` ignores it.
    problem = ODEProblem(analytical_RHS, x_0, tspan, [0.]; saveat = saveat)
    states = adjoint(solve(problem, Tsit5()).u)
    return saveat, states
end



## Demonstration of our code
### First, generate the LNN

In [None]:
# Small MLP with 4 inputs and 1 scalar output. Presumably it takes the state
# (θ1, θ2, ω1, ω2) and NeuralLagrangian treats its output as the learned
# Lagrangian — TODO confirm against LNNProject.
model = Chain(
    Dense(4, 16, softplus),
    Dense(16, 4, sigmoid),
    Dense(4,1)
)
# Flatten the model parameters into the vector `p`; `res` rebuilds a model from such a vector.
p, res = Flux.destructure(model)

# NOTE(review): argument meanings are inferred from names. `[0.0, 0.1]` looks like the
# integration time span. `saveat`/`dt`/`adaptive` look like solver keywords kept in
# `LNN.kwargs` (`LNN.kwargs[:saveat]` is read later). Confirm in LNNProject.
LNN = NeuralLagrangian(model, [0.0, 0.1], saveat=0:0.01:0.10, dt = 0.001, adaptive = false)


In [None]:
# Warm-up call: the first evaluation compiles the derivative code, which is slow.
# TODO: consider running this once when the NeuralLagrangian is constructed.
LNN(Float32[0.1, 0.1, 0.0, 0.0], Euler(), p)


### Training


In [None]:
"""
    cost(p, x_0)

Training loss for a single initial condition `x_0`. It is the 2-norm of the gap
between the LNN trajectory (integrated with `Heun()` using parameters `p`) and the
analytical trajectory at `LNN.kwargs[:saveat]`, plus the penalty `0.1 * norm(p, 2)`.
Reads the global `LNN`.
"""
function cost(p, x_0)
    learned = LNN(x_0, Heun(), p).u'
    _, reference = analytical_sol(x_0, LNN.kwargs[:saveat])
    trajectory_error = norm(learned .- reference, 2)
    return trajectory_error + 0.1 * norm(p, 2)
end

In [None]:
# Hyperparameters and bookkeeping for the training run.
opt = ADAM(0.001)
Epochs = 10
lossvec = zeros(Epochs)   # loss on the first sample of each epoch's batch
threshold = 4             # gradient-norm clipping threshold

batch_size = 11
grads = Vector(undef, batch_size)

# Fix the global RNG so the sampled initial conditions are reproducible.
Random.seed!(1234)

for i in 1:Epochs
    # Random initial conditions: every component (θ1, θ2, ω1, ω2) is drawn from (0, π/2].
    # Built in Float32 so they share the element type of the Float32 warm-up input
    # (`0.5*pi .- ...` would silently promote each state to Float64).
    x_0s = [Float32(pi / 2) .- rand(Float32, 4) .* Float32(pi / 2) for _ in 1:batch_size]

    # One finite-difference gradient per sample, computed in parallel.
    # NOTE(review): assumes `LNN` is safe to call concurrently from several threads — confirm.
    Threads.@threads for j in 1:batch_size
        x0 = x_0s[j]
        grads[j] = FiniteDiff.finite_difference_gradient(q -> cost(q, x0), p)
    end

    # Average over the batch and clip the gradient norm to `threshold`.
    avg_grad = sum(grads) / batch_size
    norm_grad = norm(avg_grad, 2)
    if norm_grad > threshold
        @show norm_grad
        avg_grad = threshold .* avg_grad ./ norm_grad
    end

    # Apply one ADAM step to `p` in place.
    Flux.update!(opt, p, avg_grad)

    # Evaluate the loss once and reuse it for logging and the history.
    epoch_loss = cost(p, x_0s[1])
    @show epoch_loss
    lossvec[i] = epoch_loss
end
# Compare the trained LNN with the analytical solution from a fixed initial condition.
# `x_0` must be defined here because the sampled `x_0s` only exist inside the training
# loop. Without this line the calls below fail with `UndefVarError: x_0`.
x_0 = Float32[0.1, 0.1, 0.0, 0.0]
ts, true_dat = analytical_sol(x_0, LNN.kwargs[:saveat])

lastsol = LNN(x_0, RK4(), p)

# Analytical reference on the same save grid as the LNN. The span is taken from that
# grid rather than hard-coded, so both trajectories always cover the same interval.
prob = ODEProblem(analytical_RHS, x_0, (0.0, LNN.kwargs[:saveat][end]), 0,
                  saveat=LNN.kwargs[:saveat])
sol_analytical = solve(prob, Tsit5())
data = sol_analytical.u
plot(sol_analytical, label="Analytical", color="black")
plot!(lastsol, label="Learned", color = "red")

# ts, dat = anal_sol(x_0)
#display(plot(ts, L_dat, label=label=[raw"$\theta_1$" raw"$\theta_2$" raw"$\dot{\theta_1}$" raw"$\dot{\theta_2}$" ], 
 #       xlabel=raw"$t$", title="After training"))
#plot(ts, dat, label=[raw"$\theta_1^a$" raw"$\theta_2^a$" raw"$\dot{\theta_1^a}$" raw"$\dot{\theta_2^a}$" ])
#pltt = plot!(ts, L_dat, title="Comparison of learned and analytical solution", label=[raw"$\theta_1^L$" raw"$\theta_2^L$" raw"$\dot{\theta_1^L}$" raw"$\dot{\theta_2^L}$" ],
#xlabel=raw"$t$")

