In [1]:
using Plots
using LinearAlgebra
using Revise
using NeuralVerification
using NeuralVerification:Network, Layer, ReLU, Id, read_nnet, compute_output
using LazySets
using Random
using BlackBoxOptim
using Statistics
include("unicycle_env.jl")
include("controller.jl")
include("problem.jl")
include("safe_set.jl")

┌ Info: Precompiling NeuralVerification [146f25fa-00e7-11e9-3ae5-fdbac6e12fa7]
└ @ Base loading.jl:1278
│ - If you have NeuralVerification checked out for development and have
│   added CPLEX as a dependency but haven't updated your primary
│   environment's manifest file, try `Pkg.resolve()`.
│ - Otherwise you may need to report an issue with NeuralVerification


grad_phi (generic function with 2 methods)

In [2]:
# Location of the serialized network (.nnet), relative to the notebook directory.
net_path = "../nnet/unicycle-FC3-100-rk4/epoch_1000.nnet"
# Loaded once as a global; later cells (e.g. `find_max` and the final
# `find_max_v_a` call) read this `net`.
net = read_nnet(net_path);
# Passed as the third argument of every `Obstacle` constructed in this notebook
# (presumably the obstacle radius — confirm against the `Obstacle` definition).
obs_radius = 0.5

0.5

In [3]:
"""
    generate_moving_target(; fps=10, tf=2, v=nothing, v_lim=0.5)

Return one `Obstacle` per frame for `tf*fps` frames. The obstacle starts at
`[0, 1.5]` and is displaced by `v * (frame / fps)` on each frame.

When `v` is `nothing`, each of the two velocity components is drawn uniformly
from `[-v_lim, v_lim)`. Every obstacle is built with the global `obs_radius`
as its third constructor argument.
"""
function generate_moving_target(;fps=10, tf=2, v=nothing, v_lim=0.5)
    if isnothing(v)
        # Two scalar draws (not `rand(2)`) so the random stream is consumed
        # exactly as before.
        v = [rand(), rand()]*v_lim*2 .- v_lim
    end
    start = [0, 1.5]
    nframes = tf*fps
    return map(0:nframes-1) do frame
        Obstacle(start+v*(frame/fps), v, obs_radius)
    end
end

"""
    get_Xref(x0, xg, T, dt)

Build a straight-line reference trajectory of `T` states from `x0` toward `xg`.

Only the first two components of `x0` and `xg` are read. Each reference state
is `[x, y, v, a]`, where `a` is the heading from `x0` to `xg` and `v` is the
distance divided by the horizon `T*dt`, clamped to `[-1, 1]`. The third
component of the final state is set to `0`.

Returns a vector of `T` four-element vectors; it is empty when `T < 1`.
"""
function get_Xref(x0, xg, T, dt)
    tf = T*dt
    dp = [xg[1]-x0[1], xg[2]-x0[2]]
    a = atan(dp[2], dp[1])
    v = clamp(norm(dp)/tf, -1, 1)
    vx = v * cos(a)
    vy = v * sin(a)
    Xref = [[x0[1]+vx*k*dt, x0[2]+vy*k*dt, v, a] for k = 1:T]
    # With an empty horizon there is no final state to patch.
    isempty(Xref) || (Xref[end][3] = 0)
    return Xref
end

"""
    tracking(rp::RP, ctrl; fps=10, tf=2, obstacles=nothing, safety_index=nothing, verbose=false)

Roll the controller `ctrl` forward from `rp.x0` toward `rp.xg` for at most
`ceil(fps*tf)` steps of length `1/fps`.

At every step a fresh reference is built with `get_Xref`, its first state is
handed to `get_control`, and the state is advanced with `forward`. When
`get_control` returns `nothing` for the control, the previous control (or
`[0., 0.]` on the first step) is reused and the rollout is flagged infeasible.
The rollout stops early once the first two state components are within `0.1`
of those of `rp.xg`.

With `verbose=true` the state, reference and control are printed each step;
the safety index value is printed as well when both `obstacles` and
`safety_index` are supplied.

Returns `(X, U, safe_sets, Xrefs, infeas)`.
"""
function tracking(rp::RP, ctrl; fps=10, tf=2, obstacles=nothing, safety_index=nothing, verbose=false)
    T = ceil(Int, fps*tf)
    dt = 1.0/fps
    x = rp.x0
    X = [copy(rp.x0) for k = 1:T]
    U = [zeros(2) for k = 1:T-1]
    safe_sets = []
    Xrefs = []
    infeas = false
    for i in 1:T-1
        Xref = get_Xref(x, rp.xg, fps, dt)
        xref = Xref[1]
        push!(Xrefs, Xref)
        u, safe_set = get_control(ctrl, xref, x, rp.net, rp.obj_cost, dt, obstacles=obstacles, safety_index=safety_index)
        if isnothing(u)
            # No control found: keep going with the last applied control
            # instead of aborting the rollout, and remember that it happened.
            u = i == 1 ? [0.,0.] : U[i-1]
            infeas = true
        end
        push!(safe_sets, safe_set)
        if verbose
            @show x
            @show xref
            @show u
            # Same `phi(index, state, obstacle)` call form as in
            # `ctrl_collision_stat`; skipped when there is nothing to evaluate.
            if !isnothing(obstacles) && !isnothing(safety_index) && !isempty(obstacles)
                p = maximum(phi(safety_index, x, obs) for obs in obstacles)
                @show p
            end
        end
        x = forward(rp.net, x, u, dt)
        X[i+1] = x
        U[i] = u
        if norm(x[1:2] - rp.xg[1:2]) < 0.1
            return X[1:i+1], U[1:i], safe_sets[1:i], Xrefs[1:i], infeas
        end
    end
    # Repeat the last reference so `Xrefs` has one entry per state in `X`;
    # nothing to repeat when the horizon was too short for the loop to run.
    isempty(Xrefs) || push!(Xrefs, Xrefs[end])
    return X, U, safe_sets, Xrefs, infeas
end

tracking (generic function with 1 method)

In [4]:
"""
    collision_samples()

Return a 20×20×10×10 array of `(state, obstacles)` tuples. The state
`[x, y, v, θ]` runs over a regular grid (`x`, `y` in `[0, 5]`, `v` in
`[-2, 2]`, `θ` in `[-π, π]`), and each tuple carries its own one-element
obstacle list holding a stationary `Obstacle` at the origin.
"""
function collision_samples()
    grid = Iterators.product(
        range(0, stop=5, length=20),
        range(0, stop=5, length=20),
        range(-2, stop=2, length=10),
        range(-π, stop=π, length=10),
    )
    # `map` over a product iterator keeps the 4-D shape, which
    # `find_infeas_states` relies on when indexing by the first two axes.
    return map(grid) do (px, py, speed, heading)
        ([px, py, speed, heading], [Obstacle([0.0, 0.0], [0, 0], obs_radius)])
    end
end
col_samples = collision_samples();

In [5]:
"""
    exists_valid_control(safety_index, ctrl::ShootingController, x, obs, net, dt)

Randomly sample up to `ctrl.num_sample` controls inside `±ctrl.u_lim` and
return `true` as soon as the network output for `[x; u]` lies in the set
returned by `phi_safe_set(safety_index, x, obs, dt)`. Return `false` if no
sample does.
"""
function exists_valid_control(safety_index, ctrl::ShootingController, x, obs, net, dt)
    admissible = phi_safe_set(safety_index, x, obs, dt)
    # `any` stops at the first hit, so no further random draws are made.
    return any(1:ctrl.num_sample) do _
        u_try = rand(2) .* ctrl.u_lim * 2 - ctrl.u_lim
        compute_output(net, [x; u_try]) ∈ admissible
    end
end

"""
    eval_collision_index(coes; net_path="../nnet/unicycle-FC3-100-rk4/epoch_1000.nnet",
                         dt=0.1, num_sample=1000)

Score the safety-index coefficients `coes = (margin, gamma, phi_power,
dot_phi_coe)` on the global sample grid `col_samples`.

A sample counts as valid when `exists_valid_control` finds a control for it
using a `ShootingController(num_sample)`. Samples whose position is within
`1e-8` of the origin (where the obstacle sits) are counted as valid without
being checked.

Returns the number of samples that are *not* valid, as a `Float64`.

# Keywords
- `net_path`: file passed to `read_nnet`.
- `dt`: time step forwarded to `exists_valid_control`.
- `num_sample`: argument of the `ShootingController` constructor.
"""
function eval_collision_index(coes; net_path="../nnet/unicycle-FC3-100-rk4/epoch_1000.nnet", dt=0.1, num_sample=1000)
    margin, gamma, phi_power, dot_phi_coe = coes
    index = CollisionIndex(margin, gamma, phi_power, dot_phi_coe)
    net = read_nnet(net_path)
    # The controller does not depend on the sample, so build it once.
    ctrl = ShootingController(num_sample)
    valid = 0
    for (x, obs) in col_samples
        if norm(x[1:2]) < 1e-8 # overlapped with the obstacle
            valid += 1
            continue
        end
        valid += exists_valid_control(index, ctrl, x, obs, net, dt)
    end
    return Float64(length(col_samples)-valid)
end

eval_collision_index (generic function with 1 method)

In [6]:
"""
    draw_heat_plot(coes)

Return the number of samples in `col_samples` without a valid control for the
coefficients `coes`, as a `Float64`.

NOTE(review): despite the name, this function has never produced a plot; its
body was a verbatim copy of `eval_collision_index`, so it now simply delegates
to it.
"""
function draw_heat_plot(coes)
    return eval_collision_index(coes)
end

"""
    find_infeas_states(coes)

Like `eval_collision_index`, but also record *which* samples of `col_samples`
have no valid control for `coes = (margin, gamma, phi_power, dot_phi_coe)`.

Returns `(count, infeas_states, infeas_map)`:
- `count`: number of samples without a valid control, as a `Float64`;
- `infeas_states`: `Dict` from the first two grid indices `(i, j)` of a sample
  to the list of failing samples at that grid cell;
- `infeas_map`: matrix over the first two grid axes holding the number of
  failing samples per cell.
"""
function find_infeas_states(coes)
    margin, gamma, phi_power, dot_phi_coe = coes
    index = CollisionIndex(margin, gamma, phi_power, dot_phi_coe)
    net_path = "../nnet/unicycle-FC3-100-rk4/epoch_1000.nnet"
    net = read_nnet(net_path)
    dt = 0.1
    # The controller does not depend on the sample, so build it once.
    ctrl = ShootingController(1000)
    valid = 0
    infeas_states = Dict()
    infeas_map = zeros(size(col_samples)[1:2])
    for (idx, sample) in pairs(col_samples)
        x, obs = sample
        if norm(x[1:2]) < 1e-8 # overlapped with the obstacle
            valid += 1
            continue
        end
        if exists_valid_control(index, ctrl, x, obs, net, dt)
            valid += 1
            continue
        end
        cell = (idx[1], idx[2])
        push!(get!(() -> [], infeas_states, cell), sample)
        infeas_map[idx[1], idx[2]] += 1
    end
    return Float64(length(col_samples)-valid), infeas_states, infeas_map
end

find_infeas_states (generic function with 1 method)

In [8]:
"""
    ctrl_collision_stat(num, ci; ctrl=nothing, verbose=false, min_invas=false)

Benchmark the controller `ctrl` on `num` randomly generated one-step scenarios
in which the initial state has a positive safety index `phi(ci, x0, obs)`.

Scenarios are generated from seeds `1, 2, ...`; a seed whose initial state has
a non-positive index is skipped. The loop therefore only terminates once
`num` positive-index scenarios have been found. For each kept scenario a
random reference control `u_ref` defines the reference next state, the
controller is queried through `get_control`, and the resulting next state is
scored. If `get_control` returns `nothing` for the control, `u_ref` is applied.

Prints the mean tracking cost, the mean clipped safety index after and before
the step, and the number of scenarios where the index dropped by more than
`ci.gamma * dt`; that last count is also the return value.

# Keywords
- `ctrl`: controller to evaluate (required; must expose `u_lim`).
- `verbose`: accepted for compatibility; currently has no effect.
- `min_invas`: when `true`, `u_ref` is forwarded to `get_control`, otherwise
  `nothing` is passed as `u_ref`.

# Throws
- `ArgumentError` if `ctrl` is not supplied.
"""
function ctrl_collision_stat(num, ci; ctrl=nothing, verbose=false, min_invas=false)
    isnothing(ctrl) &&
        throw(ArgumentError("ctrl_collision_stat requires a controller: pass `ctrl=...`"))

    dt = 0.1
    obj_cost = [1,1,1,0.1]

    net_path = "../nnet/unicycle-FC3-100-rk4/epoch_1000.nnet"
    net = read_nnet(net_path)

    objs = Float64[]
    safeties = Float64[]
    x0_safeties = Float64[]
    unsafe_cnt = 0
    next_safe_cnt = 0

    seed = 0
    while unsafe_cnt < num
        # Reseeding per scenario makes every controller see the same scenarios.
        seed += 1
        Random.seed!(seed)

        # Draw order (x0, u_ref, obstacle) must stay fixed to keep scenarios
        # reproducible across runs.
        x0 = [0,-1.5,1+rand(),π/2+rand()*π/2-π/4]
        u_ref = rand(2) .* ctrl.u_lim * 2 - ctrl.u_lim
        x_ref = forward(x0, u_ref, dt)

        obstacles = [Obstacle([0, rand()-0.5], [0, 0], obs_radius)]

        x0_safety = max(0, maximum(phi(ci, x0, obs) for obs in obstacles))
        # Only scenarios that start with a positive index are of interest.
        x0_safety > 0 || continue
        unsafe_cnt += 1

        u, safe_set = get_control(ctrl, x_ref, x0, net, obj_cost, dt; obstacles=obstacles, safety_index=ci, u_ref=min_invas ? u_ref : nothing)
        if isnothing(u)
            u = u_ref
        end
        x = forward(x0, u, dt)
        safety = maximum(phi(ci, x, obs) for obs in obstacles)

        push!(objs, quad_cost(x, x_ref, obj_cost))
        push!(x0_safeties, x0_safety)
        push!(safeties, max(0, safety))

        # Expect: safety < x0_safety - ci.gamma * dt
        if safety < x0_safety - ci.gamma * dt
            next_safe_cnt += 1
        end
    end
    @show mean(objs)
    @show mean(safeties)
    @show mean(x0_safeties)
    @show next_safe_cnt
    return next_safe_cnt
end
# Safety index used by the benchmark cells that follow. The arguments are in
# the (margin, gamma, phi_power, dot_phi_coe) order used by
# `eval_collision_index` when it builds a `CollisionIndex`.
ci = CollisionIndex(0.1, 0.1, 2, 0.1)
# Alternative coefficient set, kept for reference (origin not shown here —
# TODO confirm where these values come from):
# ci = CollisionIndex(0.728854, 0.00915763, 0.277308, 1.05736)
# Re-run the includes (presumably to pick up edits made to these files since
# the first cell was executed).
include("controller.jl")
include("safe_set.jl")

grad_phi (generic function with 2 methods)

In [9]:
ci.gamma

0.1

In [10]:
# println("AdamBA 1")
# t = @timed ctrl_collision_stat(100, ci, ctrl = AdamBAController(1), min_invas=true);
# @show t.time
println("ShootingController 100")
t = @timed ctrl_collision_stat(100, ci, ctrl = ShootingController(100), min_invas=true);
@show t.time
println("ShootingController 1000")
t = @timed ctrl_collision_stat(100, ci, ctrl = ShootingController(1000), min_invas=true);
@show t.time
println("ShootingController 10000")
t = @timed ctrl_collision_stat(100, ci, ctrl = ShootingController(10000), min_invas=true);
@show t.time
println("ShootingController 100000")
t = @timed ctrl_collision_stat(100, ci, ctrl = ShootingController(100000), min_invas=true);
@show t.time

ShootingController 100
mean(objs) = 0.03334411155613559
mean(safeties) = 1.1701775101609142
mean(x0_safeties) = 1.3850537013150683
next_safe_cnt = 100
t.time = 3.561472619
ShootingController 1000
mean(objs) = 0.027326816053423023
mean(safeties) = 1.1845910940582036
mean(x0_safeties) = 1.3850537013150683
next_safe_cnt = 100
t.time = 1.339102973
ShootingController 10000
mean(objs) = 0.02630976058295553
mean(safeties) = 1.1849330222060155
mean(x0_safeties) = 1.3850537013150683
next_safe_cnt = 100
t.time = 13.673118192
ShootingController 100000
mean(objs) = 0.026045426006856803
mean(safeties) = 1.1854599788954665
mean(x0_safeties) = 1.3850537013150683
next_safe_cnt = 100
t.time = 138.085298487


138.085298487

In [45]:
# println("AdamBA 1")
# t = @timed ctrl_collision_stat(100, ci, ctrl = AdamBAController(1), min_invas=true);
# @show t.time
println("AdamBA 10")
t = @timed ctrl_collision_stat(100, ci, ctrl = AdamBAController(10), min_invas=true);
@show t.time
println("AdamBA 100")
t = @timed ctrl_collision_stat(100, ci, ctrl = AdamBAController(100), min_invas=true);
@show t.time
println("AdamBA 1000")
t = @timed ctrl_collision_stat(100, ci, ctrl = AdamBAController(1000), min_invas=true);
@show t.time

AdamBA 10
mean(objs) = 0.03408957567400331
mean(safeties) = 1.6687598766227736
mean(x0_safeties) = 1.7831469515407297
next_safe_cnt = 1
objs = Any[0.03408957567400331]
safeties = Any[1.6687598766227736]
t.time = 0.550524278
AdamBA 100
mean(objs) = 0.03222553108966779
mean(safeties) = 1.6731558938396147
mean(x0_safeties) = 1.7831469515407297
next_safe_cnt = 1
objs = Any[0.03222553108966779]
safeties = Any[1.6731558938396147]
t.time = 0.022222433
AdamBA 1000
mean(objs) = 0.0324896506381907
mean(safeties) = 1.672538556688628
mean(x0_safeties) = 1.7831469515407297
next_safe_cnt = 1
objs = Any[0.0324896506381907]
safeties = Any[1.672538556688628]
t.time = 0.223682788


0.223682788

In [46]:
println("AdamBA 10000")
t = @timed ctrl_collision_stat(100, ci, ctrl = AdamBAController(10000), min_invas=true);
@show t.time

AdamBA 10000
mean(objs) = 0.03244178420787575
mean(safeties) = 1.67265058518851
mean(x0_safeties) = 1.7831469515407297
next_safe_cnt = 1
objs = Any[0.03244178420787575]
safeties = Any[1.67265058518851]
t.time = 2.522140925


2.522140925

In [9]:
# a = [0.9240933406713393, 1.1285852758920145, 0.8150766026065783, 1.1016071140982289, 0.9882358314347126, 0.9434793868219085, 1.6725906561770425, 0.6113902337221013, 1.2306118525765102, 0.9593176997918553, 1.3930823866083786, 0.7847814372365992, 0.7885262427228066, 0.8899898196912508, 1.1194212187150119, 1.5175075937045106, 0.8971530299315629, 1.34175751737654, 1.4638389283868207, 1.0820244956487364, 0.7754372657459034, 0.8114512502690528, 1.2895380433196393, 1.0988915200958749, 1.1702414723438923, 1.4616126101850966, 1.2125007576638158, 0.872865063255638, 1.9386641044227118, 0.9226131040466722, 0.7129903892540643, 1.6654108199133946, 1.5549876563056646, 1.3873588806901178, 1.4765811650973717, 0.9197246016147433, 1.1921712994871199, 1.610362010277241, 0.665582346161891, 0.9574725754060804, 1.2270736211333884, 1.039323539836434, 0.9335841640153117, 0.8912076626761317, 1.178613869187875, 1.1476460300871296, 1.3456559837339817, 1.5147620377255173, 1.0851100331088968, 2.025706615382685, 1.465786120555888, 1.5758437213836045, 1.3747507925451112, 1.514696114782711, 1.0787511849648712, 1.2778081004129165, 1.4812062607454481, 1.1940242193520108, 1.2671272964139297, 0.5528866938629824, 1.750558686933631, 1.5193468895804731, 0.8524663199861438, 1.7599165303325204, 1.3515136815318292, 1.3837036311579065, 1.5880831919948328, 1.3435958101024181, 1.2759912066889614, 0.7381259545578985, 0.73296600146397, 0.9471806825751471, 1.1131543244039788, 0.8536413717288162, 1.6115270460675597, 1.1668625527470282, 0.7722461217080746, 0.6786116331417779, 1.014032704176449, 1.358596122815872, 1.387103672062142, 1.834677994712349, 1.6149555741863828, 1.5688623233518488, 1.1785995383086505, 1.1853505150484296, 1.4665387869862792, 1.0558289795362412, 0.9108658242363759, 1.1966668778130802, 1.1988699311187603, 1.1327034142171188, 0.9207035331288589, 1.2220422733947698, 0.9239849968612689, 1.0395731859140256, 1.2074618398920438, 0.9939465429304586, 1.2783957592802166, 0.8737639118069738]
# b = [0.9240933406713393, 1.1285852758920147, 0.8150766026065783, 1.2538944662649214, 0.9882358314347126, 1.018550753942142, 1.5821593636172415, 0.6294387898555892, 1.2306118525765102, 0.9593176997918553, 1.3288404145194104, 0.6136038541537142, 0.7885262427228066, 0.6686305915461584, 1.1194212187150119, 1.4530081392812304, 0.8971530299315629, 1.3366592038519216, 1.3957806660884235, 0.8461386480601483, 0.7754372657459035, 0.8114512502690528, 1.2895380433196393, 1.1356924946190035, 1.1702414723438923, 1.4616126101850966, 1.2824183770159372, 0.872865063255638, 1.9386641044227118, 0.9226131040466722, 0.717859369063162, 1.6654108199133946, 1.5467202978573822, 1.3873588806901178, 1.4738694729623387, 0.9535186076236677, 1.2026415959425336, 1.4440071078873638, 0.665582346161891, 0.93087211656529, 1.2778761762935165, 1.039323539836434, 0.9763474133173818, 0.8912076626761317, 1.178613869187875, 1.1476460300871296, 1.4127982217968929, 1.5147620377255173, 1.0108678654409176, 1.8370854448484004, 1.5136983710511311, 1.6934062475985836, 1.2991077529099246, 1.407705445072834, 1.0787511849648712, 1.2306197824770488, 1.4904629233928928, 1.1467840229865878, 1.2175393372545789, 0.5528866938629824, 1.750558686933631, 1.2754502148534816, 0.8524663199861438, 1.639877483262146, 1.2973048836444097, 1.3739165236067754, 1.5880831919948328, 1.2589529384552096, 1.2972915065003567, 0.7381259545578986, 0.73296600146397, 0.8942214656905456, 1.0725116854576517, 0.8536413717288162, 1.4882117957171392, 1.1476285458753828, 0.7722461217080746, 0.6786116331417779, 0.9230991404713407, 1.3251349345120993, 1.387103672062142, 1.7437013833073567, 1.5145398365241969, 1.4739654306020127, 1.1785995383086505, 1.156654024864681, 1.496170083394904, 0.9615470920650548, 0.9108658242363759, 1.279129009637253, 1.1988699311187603, 1.1252824415068956, 0.9207035331288589, 1.2220422733947698, 0.9239849968612689, 1.0395731859140256, 1.2074618398920438, 0.9939465429304586, 1.3471068598830225, 0.9251774394916732]
# findall(a .> b)
println("MIND adam 0")
t = @timed ctrl_collision_stat(100, ci, ctrl = AdamNvController(0), min_invas=true);
@show t.time
println("AdamBA 1000")
t = @timed ctrl_collision_stat(100, ci, ctrl = AdamBAController(1000), min_invas=true);
@show t.time

MIND adam 0
mean(objs) = 0.025856311489413423
mean(safeties) = 1.1852410759975889
mean(x0_safeties) = 1.3850537013150683
next_safe_cnt = 100
t.time = 60.25398101
AdamBA 1000
mean(objs) = 0.02588940064394062
mean(safeties) = 1.1853327552229165
mean(x0_safeties) = 1.3850537013150683
next_safe_cnt = 100
t.time = 15.70886455


15.70886455

In [40]:
# ci = CollisionIndex(0.1, 1e-1, 2, 1)
ci = CollisionIndex(0.728854, 0.00915763, 0.277308, 1.05736)
include("controller.jl")
include("safe_set.jl")

grad_phi (generic function with 2 methods)

In [11]:
ctrl_collision_stat(100, ci, ctrl = NvController(), min_invas=true);

mean(objs) = 0.010735292224368746
mean(safeties) = 0.618376953392255
mean(x0_safeties) = 0.3974052085097087
unsafe_cnt = 100


In [13]:
include("controller.jl")

get_control (generic function with 5 methods)

In [25]:
println("MIND")
t = @timed ctrl_collision_stat(100, ci, ctrl = NvController(), min_invas=true);
@show t.time

MIND


LoadError: [91mInterruptException:[39m

In [14]:
println("MIND adam 0")
t = @timed ctrl_collision_stat(100, ci, ctrl = AdamNvController(0), min_invas=true);
@show t.time

MIND adam 0
mean(objs) = 0.025856311489413423
mean(safeties) = 1.1852410759975889
mean(x0_safeties) = 1.3850537013150683
next_safe_cnt = 100
t.time = 36.207394484


36.207394484

In [10]:
println("MIND adam 0")
t = @timed ctrl_collision_stat(100, ci, ctrl = AdamNvController(0), min_invas=true);
@show t.time

MIND adam 0
mean(objs) = 0.025856311489413423
mean(safeties) = 1.1852410759975889
mean(x0_safeties) = 1.3850537013150683
next_safe_cnt = 100
t.time = 35.599022046


In [11]:
println("MIND adam 100")
t = @timed ctrl_collision_stat(100, ci, ctrl = AdamNvController(100), min_invas=true);
@show t.time
println("MIND adam 200")
t = @timed ctrl_collision_stat(100, ci, ctrl = AdamNvController(200), min_invas=true);
@show t.time

MIND adam 100
mean(objs) = 0.025856311489413423
mean(safeties) = 1.1852410759975889
mean(x0_safeties) = 1.3850537013150683
next_safe_cnt = 100
t.time = 37.168160258
MIND adam 200
mean(objs) = 0.025856311489413423
mean(safeties) = 1.1852410759975889
mean(x0_safeties) = 1.3850537013150683
next_safe_cnt = 100
t.time = 38.560330322


38.560330322

In [12]:
println("AdamBA 10")
t = @timed ctrl_collision_stat(100, ci, ctrl = AdamBAController(10), min_invas=true);
@show t.time
println("AdamBA 100")
t = @timed ctrl_collision_stat(100, ci, ctrl = AdamBAController(100), min_invas=true);
@show t.time
println("AdamBA 1000")
t = @timed ctrl_collision_stat(100, ci, ctrl = AdamBAController(1000), min_invas=true);
@show t.time

AdamBA 10
mean(objs) = 0.02846522423901491
mean(safeties) = 1.1887120385683168
mean(x0_safeties) = 1.3850537013150683
next_safe_cnt = 100
t.time = 0.127758681
AdamBA 100
mean(objs) = 0.026154209072604778
mean(safeties) = 1.186429603323412
mean(x0_safeties) = 1.3850537013150683
next_safe_cnt = 100
t.time = 1.514230697
AdamBA 1000
mean(objs) = 0.02588940064394062
mean(safeties) = 1.1853327552229165
mean(x0_safeties) = 1.3850537013150683
next_safe_cnt = 100
t.time = 15.012741526


15.012741526

In [15]:
println("AdamBA 2000")
t = @timed ctrl_collision_stat(100, ci, ctrl = AdamBAController(2000), min_invas=true);
@show t.time
println("AdamBA 10000")
t = @timed ctrl_collision_stat(100, ci, ctrl = AdamBAController(10000), min_invas=true);
@show t.time

AdamBA 2000
mean(objs) = 0.02585874900476202
mean(safeties) = 1.1852488944525295
mean(x0_safeties) = 1.3850537013150683
next_safe_cnt = 100
t.time = 29.197176705
AdamBA 10000
mean(objs) = 0.02585715713599576
mean(safeties) = 1.1852461068699995
mean(x0_safeties) = 1.3850537013150683
next_safe_cnt = 100
t.time = 146.795303865


146.795303865

In [2]:
println("shooting 1000")
t = @timed ctrl_collision_stat(100, ci, ctrl = ShootingController(1000), min_invas=true);
@show t.time
println("shooting 10000")
t = @timed ctrl_collision_stat(100, ci, ctrl = ShootingController(10000), min_invas=true);
@show t.time

shooting 1000


LoadError: [91mUndefVarError: ShootingController not defined[39m

In [67]:
println("MIND")
t = @timed ctrl_collision_stat(100, ci, ctrl = NvController());
@show t.time

MIND
mean(objs) = 0.010329527241514121
mean(safeties) = 0.6181879794934849
mean(x0_safeties) = 0.3974052085097087
unsafe_cnt = 100
t.time = 42.340758727


42.340758727

In [64]:
include("controller.jl")
println("AdamBA 100")
t = @timed ctrl_collision_stat(100, ci, ctrl = AdamBAController(100));
@show t.time
println("AdamBA 200")
t = @timed ctrl_collision_stat(100, ci, ctrl = AdamBAController(200));
@show t.time
println("AdamBA 1000")
t = @timed ctrl_collision_stat(100, ci, ctrl = AdamBAController(1000));
@show t.time

AdamBA 100
mean(objs) = 0.00570127908050178
mean(safeties) = 0.6232056469212774
mean(x0_safeties) = 0.3974052085097087
unsafe_cnt = 100
t.time = 72.229913719
AdamBA 200
mean(objs) = 0.0065632473538424115
mean(safeties) = 0.6183959154306374
mean(x0_safeties) = 0.3974052085097087
unsafe_cnt = 100
t.time = 73.667533575
AdamBA 1000
mean(objs) = 0.010510589984142282
mean(safeties) = 0.618204532353517
mean(x0_safeties) = 0.3974052085097087
unsafe_cnt = 100
t.time = 90.217513618


90.217513618

MIND adam 0
mean(objs) = 0.025856311489413423
mean(safeties) = 1.1852410759975889
mean(x0_safeties) = 1.3850537013150683
next_safe_cnt = 100
objs = Any[0.0, 4.935195482492292e-32, 0.0, 0.0016126192021159282, 0.0, 0.00010530920551155175, 0.03246739933414052, 0.02443982791597609, 0.0, 0.0, 0.13982993918017148, 0.0009026637579596908, 0.0, 0.031611629459301765, 0.0, 0.014523527541330386, 1.232595164407831e-32, 0.00010403724347986067, 0.0642559555041486, 0.00022251532230335868, 1.232595164407831e-32, 0.0, 0.0, 0.006923130826817709, 0.0, 0.0, 0.024547211336191775, 0.0, 0.0, 0.0, 0.06743851898282306, 5.89816045468591e-34, 0.07353951477761797, 0.0, 0.20855967088408217, 0.0015022116376389948, 0.00519266629226399, 0.02537906285403969, 0.0, 0.02581463583345446, 0.05082108250841586, 0.0, 0.04062193859890905, 0.0, 4.942417719783744e-33, 0.0, 0.03984959442481309, 0.0, 0.11322472303373288, 0.03309426313738939, 0.11982988737907241, 0.0498312964115266, 0.12014053520324029, 0.09164735400024132, 0.0, 0.0

35.741975151

In [51]:
println("MIND adam 0")
t = @timed ctrl_collision_stat(100, ci, ctrl = AdamNvController(0), min_invas=true);
@show t.time
println("MIND adam 100")
t = @timed ctrl_collision_stat(100, ci, ctrl = AdamNvController(100), min_invas=true);
@show t.time
println("MIND adam 200")
t = @timed ctrl_collision_stat(100, ci, ctrl = AdamNvController(200), min_invas=true);
@show t.time

MIND adam 0


LoadError: [91mInterruptException:[39m

In [50]:
ci = CollisionIndex(0.728854, 0.00915763, 0.277308, 1.05736)
include("controller.jl")
include("safe_set.jl")
println("MIND")
t = @timed ctrl_collision_stat(100, ci, ctrl = NvController(), min_invas=true);
@show t.time
println("MIND adam 0")
t = @timed ctrl_collision_stat(100, ci, ctrl = AdamNvController(0), min_invas=true);
@show t.time
println("MIND adam 100")
t = @timed ctrl_collision_stat(100, ci, ctrl = AdamNvController(100), min_invas=true);
@show t.time
println("MIND adam 200")
t = @timed ctrl_collision_stat(100, ci, ctrl = AdamNvController(200), min_invas=true);
@show t.time
println("AdamBA 10")
t = @timed ctrl_collision_stat(100, ci, ctrl = AdamBAController(10), min_invas=true);
@show t.time
println("AdamBA 100")
t = @timed ctrl_collision_stat(100, ci, ctrl = AdamBAController(100), min_invas=true);
@show t.time
println("AdamBA 1000")
t = @timed ctrl_collision_stat(100, ci, ctrl = AdamBAController(1000), min_invas=true);
@show t.time
println("AdamBA 10000")
t = @timed ctrl_collision_stat(100, ci, ctrl = AdamBAController(10000), min_invas=true);
@show t.time


MIND
mean(objs) = 0.025856311470707768
mean(safeties) = 1.1852410759037943
mean(x0_safeties) = 1.3850537013150683
next_safe_cnt = 100
t.time = 170.822197971
MIND adam 0
mean(objs) = 0.025856311489413423
mean(safeties) = 1.1852410759975889
mean(x0_safeties) = 1.3850537013150683
next_safe_cnt = 100
t.time = 568.808211742
MIND adam 100
mean(objs) = 0.025856311489413423
mean(safeties) = 1.1852410759975889
mean(x0_safeties) = 1.3850537013150683
next_safe_cnt = 100
t.time = 367.453760616
MIND adam 200
mean(objs) = 0.025856311489413423
mean(safeties) = 1.1852410759975889
mean(x0_safeties) = 1.3850537013150683
next_safe_cnt = 100
t.time = 548.003841854
AdamBA 10
mean(objs) = 0.02846522423901491
mean(safeties) = 1.1887120385683168
mean(x0_safeties) = 1.3850537013150683
next_safe_cnt = 100
t.time = 0.375107059
AdamBA 100
mean(objs) = 0.026154209072604778
mean(safeties) = 1.186429603323412
mean(x0_safeties) = 1.3850537013150683
next_safe_cnt = 100
t.time = 2.069699151
AdamBA 1000
mean(objs) = 0.0

181.852283202

In [45]:
include("controller.jl")
ctrl_collision_stat(100, ci, ctrl = NvController());

mean(objs) = 0.05937712629586693
mean(safeties) = 0.6105917568928835
mean(x0_safeties) = 0.3974052085097087
unsafe_cnt = 100
mean(objs) = 0.05503180988366449
mean(safeties) = 0.6531459480595868
mean(x0_safeties) = 0.3974052085097087
unsafe_cnt = 100


In [73]:
ctrl_collision_stat(100, ci, ctrl = ShootingController(1));

mean(objs) = 0.04996880286758243
mean(safeties) = 0.6898609257924996
mean(x0_safeties) = 0.3974052085097087
unsafe_cnt = 100


In [68]:
ctrl_collision_stat(100, ci, ctrl = ShootingController(1000));

mean(objs) = 0.06261668094248928
mean(safeties) = 0.6185657938551002
mean(x0_safeties) = 0.3974052085097087
unsafe_cnt = 100


In [71]:
ctrl_collision_stat(100, ci, ctrl = ShootingController(10000));

mean(objs) = 0.06121764319235339
mean(safeties) = 0.6161597976289677
mean(x0_safeties) = 0.3974052085097087
unsafe_cnt = 100


In [72]:
ctrl_collision_stat(100, ci, ctrl = ShootingController(100000));

mean(objs) = 0.059603538107113635
mean(safeties) = 0.6106403641514753
mean(x0_safeties) = 0.3974052085097087
unsafe_cnt = 100


mean(objs) = 0.05937712629586693
mean(safeties) = 0.6105917568928835
mean(x0_safeties) = 0.3974052085097087
unsafe_cnt = 100


In [25]:
ctrl_collision_stat(100, ci, ctrl = NvController());

mean(objs) = 0.05937712629586693
mean(safeties) = 0.6105917568928835
mean(x0_safeties) = 0.3974052085097087
unsafe_cnt = 100


In [54]:
include("controller.jl")
ctrl_collision_stat(100, ci, ctrl = AdamBAController(1000)); #debug, directly return nv results

mean(objs) = 4.7075676803460393e-7
mean(safeties) = 0.6400621345162333
mean(x0_safeties) = 0.3974052085097087
unsafe_cnt = 100


In [58]:
include("controller.jl")
ctrl_collision_stat(100, ci, ctrl = AdamBAController(10));

mean(objs) = 0.049090845068017946
mean(safeties) = 0.6559487935929158
mean(x0_safeties) = 0.3974052085097087
unsafe_cnt = 100


In [56]:
include("controller.jl")
ctrl_collision_stat(100, ci, ctrl = AdamBAController(1000));

mean(objs) = 0.06072854172087785
mean(safeties) = 0.6162420449369568
mean(x0_safeties) = 0.3974052085097087
unsafe_cnt = 100


In [53]:
include("controller.jl")
ctrl_collision_stat(100, ci, ctrl = AdamBAController(1000));

mean(objs) = 0.05503180988366449
mean(safeties) = 0.6531459480595868
mean(x0_safeties) = 0.3974052085097087
unsafe_cnt = 100


In [46]:
include("controller.jl")
ctrl_collision_stat(100, ci, ctrl = AdamBAController(1));

mean(objs) = 0.04996880286758243
mean(safeties) = 0.6898609257924996
mean(x0_safeties) = 0.3974052085097087
unsafe_cnt = 100


In [50]:
include("controller.jl")
ctrl_collision_stat(100, ci, ctrl = AdamBAController(100,10));

mean(objs) = 0.061268216095466954
mean(safeties) = 0.616701011775147
mean(x0_safeties) = 0.3974052085097087
unsafe_cnt = 100


In [51]:
include("controller.jl")
ctrl_collision_stat(100, ci, ctrl = AdamBAController(100,100));

mean(objs) = 0.05975939399769758
mean(safeties) = 0.610766868592338
mean(x0_safeties) = 0.3974052085097087
unsafe_cnt = 100


In [52]:
include("controller.jl")
ctrl_collision_stat(100, ci, ctrl = AdamBAController(1000,10));

mean(objs) = 0.06077597991123765
mean(safeties) = 0.6163473517261459
mean(x0_safeties) = 0.3974052085097087
unsafe_cnt = 100


In [74]:
ctrl_collision_stat(100, ci, ctrl = AdamBAController(1000, 10));

mean(objs) = 0.06252044997892069
mean(safeties) = 0.6189743179365318
mean(x0_safeties) = 0.3974052085097087
unsafe_cnt = 100


In [75]:
ctrl_collision_stat(100, ci, ctrl = AdamBAController(10, 1000));

mean(objs) = 0.0542412457276126
mean(safeties) = 0.6815785075337627
mean(x0_safeties) = 0.3974052085097087
unsafe_cnt = 100


In [76]:
ctrl_collision_stat(100, ci, ctrl = AdamBAController(100, 100));

mean(objs) = 0.06423595835454991
mean(safeties) = 0.6309906128520527
mean(x0_safeties) = 0.3974052085097087
unsafe_cnt = 100


In [13]:
"""
    find_max(dot_x_ref, obj_cost, network=net)

Solve an `NNDynTrack` tracking problem on `network` that steers the network
output toward `dot_x_ref` under the weights `obj_cost`, and return the network
output at the optimizer's input.

The input is restricted to the box `±[0.5, 0.5, 2, π]` for the first four
entries and `±[4, π]` for the last two; the output constraint is a box of
radius `1000` around `dot_x_ref`. `network` defaults to the global `net`.
"""
function find_max(dot_x_ref, obj_cost, network=net)
    x_lim = [0.5, 0.5, 2, π]
    # A tighter box, x_lim = [0.1, 0.1, 2, π], was also tried.
    u_lim = [4, π]
    input = Hyperrectangle(low=[-x_lim; -u_lim], high=[x_lim; u_lim])
    output = Hyperrectangle(dot_x_ref, fill(1000.0, 4))
    problem = TrackingProblem(network, input, output, dot_x_ref, obj_cost)
    result, _ = NeuralVerification.solve(NNDynTrack(), problem)
    return compute_output(network, result.input)
end

"""
    find_max_v_a(ctrl::NvController, net)

Probe the extreme outputs of `net` by calling `find_max` with large positive
and negative targets, first on output components 1–2 (reported as `max_v`,
the norm of those two components) and then on component 3 (reported as
`max_a`). Each result is printed with `@show`.

`ctrl` is only used for dispatch. Returns the last `max_a` (negative target).
"""
function find_max_v_a(ctrl::NvController, net)
    v_dot_x_ref = [1000.,1000.,0.,0.]
    v_obj_cost = [1,1,0,0]
    a_dot_x_ref = [0.,0.,1000.,0.]
    a_obj_cost = [0,0,1,0]

    # max positive v
    dot_x = find_max(v_dot_x_ref, v_obj_cost, net)
    max_v = norm(dot_x[1:2])
    @show dot_x
    @show max_v

    # max positive a
    dot_x = find_max(a_dot_x_ref, a_obj_cost, net)
    max_a = dot_x[3]
    @show dot_x
    @show max_a

    # max negative v
    dot_x = find_max(-v_dot_x_ref, v_obj_cost, net)
    max_v = norm(dot_x[1:2])
    @show dot_x
    @show max_v

    # max negative a
    dot_x = find_max(-a_dot_x_ref, a_obj_cost, net)
    max_a = dot_x[3]
    @show dot_x
    @show max_a
    return max_a
end
include("controller.jl")
find_max_v_a(NvController(), net)

In [None]:
# x_lim = [0.1, 0.1, 2, π]
# u_lim = [4, π]
# dot_x = [1.3754507922154657, 1.5349397117512167, 3.7547234479671623, -2.871642356616455]
# max_v = 2.0610445896479437
# dot_x = [1.9573293306380013, 0.2828251752514743, 4.008726879784431, -3.0297639968196197]
# max_a = 0.2828251752514743
# dot_x = [-1.5369360276450639, -1.3430358739067063, 3.969877843788106, -3.119788462845145]
# max_v = 2.0410579883172693
# dot_x = [-2.041141947924824, 0.20699559707716228, -4.010629566411886, 2.4005379725253957]
# max_a = 0.20699559707716228