In [None]:
# generate a random pair q0,v0 (probably sliding on a surface)
# solve for x_ipopt using ipopt using explicit contact forces
# we want loss = (x_lag - x_ipopt)^2 == 0
# compute the gradient of the loss with respect to the parameters
# update the parameters using SGD
# repeat

In [1]:
using Revise

using RigidBodyDynamics
using RigidBodyDynamics: Bounds

using RigidBodyTreeInspector
using DrakeVisualizer
using Plots

using BilevelTrajOpt

In [2]:
urdf = joinpath("..", "urdf", "ball.urdf")
mechanism = parse_urdf(Float64, urdf)
body = findbody(mechanism, "ball")
basejoint = joint_to_parent(body, mechanism)
floatingjoint = Joint(basejoint.name, frame_before(basejoint), frame_after(basejoint), QuaternionFloating{Float64}())
replace_joint!(mechanism, basejoint, floatingjoint)
position_bounds(floatingjoint) .= Bounds(-100, 100)
velocity_bounds(floatingjoint) .= Bounds(-100, 100)
env = parse_contacts(mechanism, urdf, .5);
x0 = MechanismState(mechanism)
Δt = 0.01
sim_data = get_sim_data(x0,env,Δt,true);
xnext = MechanismState(mechanism);

In [18]:

N = 10
α_vect0 = [.9^i for i in 1:N]
c_vect0 = [1.5^i for i in 1:N]
I_vect0 = [.01 for i in 1:N]
z = vcat(α_vect0,c_vect0,I_vect0)

num_iter = 50
batch_size = 20
α_sgd = .9

for iter = 1:num_iter
    loss_batch = 0.
    g_batch = zeros(length(z))
    
    for j = 1:batch_size
        q0 = [1., 0., 0., 0., 0., 0., .0]
        v0 = vcat(zeros(3),rand(3)*2.-1.)
        u0 = zeros(sim_data.num_v)
        set_configuration!(x0, q0)
        set_velocity!(x0, v0)
        setdirty!(x0)

        traj = BilevelTrajOpt.simulate(x0,env,sim_data.Δt,1,implicit_contact=false)
        qnext = traj[1:sim_data.num_q,2]
        vnext = traj[sim_data.num_q+1:sim_data.num_q+sim_data.num_v,2]
        set_configuration!(xnext, qnext)
        set_velocity!(xnext, vnext)
        setdirty!(xnext)

        H = mass_matrix(x0)
        ϕs = Vector{Float64}(sim_data.num_contacts)
        Dtv = Matrix{Float64}(sim_data.β_dim,sim_data.num_contacts)
        for i = 1:sim_data.num_contacts
            v = point_velocity(twist_wrt_world(xnext,sim_data.bodies[i]), transform_to_root(xnext, sim_data.contact_points[i].frame) * sim_data.contact_points[i])
            Dtv[:,i] = map(sim_data.contact_bases[i]) do d
                dot(transform_to_root(xnext, d.frame) * d, v)
            end
            sim_data.rel_transforms[i] = (relative_transform(xnext, sim_data.obstacles[i].contact_face.outward_normal.frame, sim_data.world_frame),
                                      relative_transform(xnext, sim_data.contact_points[i].frame, sim_data.world_frame))
            sim_data.geo_jacobians[i] = geometric_jacobian(xnext, sim_data.paths[i])
            ϕs[i] = separation(sim_data.obstacles[i], transform(xnext, sim_data.contact_points[i], sim_data.obstacles[i].contact_face.outward_normal.frame))
        end
        HΔv = H * (vnext - v0)
        bias = u0 .- dynamics_bias(xnext)
    
        τ_ip, x_sol_ip = solve_implicit_contact_τ(sim_data,ϕs,Dtv,HΔv,bias,ip_method=true);
        τ_auglag, x_sol_auglag = solve_implicit_contact_τ(sim_data,ϕs,Dtv,HΔv,bias,ip_method=false,α_vect=z[1:N],c_vect=z[N+1:2*N],I_vect=z[2*N+1:3*N]);    
        J = ForwardDiff.jacobian(z̃ -> solve_implicit_contact_τ(sim_data,ϕs,Dtv,HΔv,bias,ip_method=false,α_vect=z̃[1:N],c_vect=z̃[N+1:2*N],I_vect=z̃[2*N+1:3*N])[2], z)
    
        loss = .5 * dot(x_sol_ip - x_sol_auglag, x_sol_ip - x_sol_auglag)
        
        loss_batch += loss
        g_batch += J'*(x_sol_ip - x_sol_auglag)
    end
    loss_batch /= batch_size
    g_batch /= batch_size
    
    z = z + α_sgd^iter*g_batch
    z = min.(1000.,max.(.001, z))
    
    println(loss_batch)
    println(z[1:N])
    println(z[N+1:2*N])
    println("***")
end

5.24688194951384
[0.001, 0.001, 30.4643, 0.001, 7.68541, 1.46014, 0.001, 0.522184, 1.44174, 1.63251]
[0.001, 0.001, 4.05832, 5.48041, 7.71202, 11.4521, 17.0818, 25.6269, 38.4428, 57.6648]
***
387.83901008311534
[17.7787, 17.7672, 7.8518, 0.156293, 11.8521, 4.54026, 2.04035, 2.85993, 4.32844, 4.5445]
[0.001, 0.001, 1.72076, 5.47974, 7.70045, 11.4431, 17.076, 25.6216, 38.4393, 57.6633]
***
22.49711418170737
[18.7532, 15.9396, 8.64959, 0.855179, 12.3024, 4.66828, 3.88538, 5.21047, 1.72293, 2.13878]
[0.382807, 0.001, 0.97656, 4.73482, 7.5298, 11.5706, 17.0597, 25.5639, 38.489, 57.6546]
***
13.041188695119734
[20.6796, 12.5002, 9.43916, 1.83976, 12.864, 3.05874, 4.35585, 4.99257, 1.27562, 1.50073]
[0.50896, 0.00824261, 0.992651, 4.73514, 7.55404, 11.6108, 17.0506, 25.5257, 38.4894, 57.6549]
***
12.653162026436245
[17.435, 13.103, 11.2461, 3.16321, 15.2128, 3.25981, 4.54581, 5.22589, 1.15166, 1.34512]
[0.883109, 0.001, 0.962598, 4.72781, 7.50781, 11.5787, 17.031, 25.518, 38.4888, 57.6553]
**