# In [1]:
using Revise
using LazySets
using DifferentialEquations
using LazySets
using ProgressMeter
using ProgressBars
using JLD2
using Flux
using LinearAlgebra
using Zygote
using ReverseDiff
using Plots
using Statistics
using Optimisers, ParameterSchedulers
using RobotDynamics
using RobotZoo
using Random
using Rotations

# In [3]:
using RobotZoo
import RobotDynamics as RD

"""
    random_point_in_hyperrectangle(hyperrectangle, non_admissible_area=nothing; q=false)

Draw a point uniformly at random from the axis-aligned box `hyperrectangle`.

When `q` is `true`, entries 4:7 of the sample are replaced by entries 4:7 of the first
value returned by `rand(RobotZoo.Quadrotor())` (presumably the quadrotor attitude
quaternion — TODO confirm against RobotZoo). This requires the box to have at least
7 dimensions.

Returns `(point, ok)`. `ok` is `true` when no `non_admissible_area` is given or when
`point ∉ non_admissible_area`, and `false` otherwise.
"""
function random_point_in_hyperrectangle(hyperrectangle::Hyperrectangle, non_admissible_area=nothing;q=false)
    lo = low(hyperrectangle)
    hi = high(hyperrectangle)
    # One scalar rand() per coordinate, drawn in coordinate order.
    sample = [lo[k] + rand() * (hi[k] - lo[k]) for k in eachindex(lo)]

    if q
        quad_state, _ = rand(RobotZoo.Quadrotor())
        sample[4:7] .= quad_state[4:7]
    end

    # Short-circuit: the membership test only runs when an area was supplied.
    admissible = isnothing(non_admissible_area) || sample ∉ non_admissible_area
    return sample, admissible
end

"""
    generate_Xref(dmodel, x_0, dt, T, X, X_unsafe, U; max_u=10000, euler=false)

Roll out one random reference trajectory that stays in `X` and out of `X_unsafe`.

Starting from `x_0`, `floor(T / dt)` steps are attempted. In each step up to `max_u`
controls are sampled uniformly from `U` and propagated over one interval of length `dt`.
The first successor `x′` with `x′ ∈ X` and `x′ ∉ X_unsafe` is appended together with its
control.

If no sampled control works, the last state/control pair is discarded and the search
resumes from the previous state (one step of backtracking). Failed attempts also consume
one of the `floor(T / dt)` steps, so trajectories can be shorter than `T / dt`.

Propagation uses `RD.discrete_dynamics(dmodel, x, u, 0.0, dt)` unless `euler=true`. In
that case the ODE `ẋ = quadrotor_dynamics_euler!(x, u)` is integrated over `[0, dt]`
with `Tsit5`. `quadrotor_dynamics_euler!` is not defined in this section — presumably
elsewhere in the notebook.

Returns `(Xref, Uref)` with `length(Xref) == length(Uref) + 1`. Generation stops early
when a dead end is hit while the trajectory holds at most one control.
"""
function generate_Xref(dmodel, x_0, dt, T, X, X_unsafe, U; max_u=10000,euler=false)
    n_steps = floor(Int, T / dt)
    Uref = []
    Xref = Any[x_0]
    for _ in 1:n_steps
        x = Xref[end]
        u = nothing
        x′ = nothing
        feasible = false
        for _ in 1:max_u
            u, _ = random_point_in_hyperrectangle(U)
            x′ = euler ? _integrate_euler_step(x, u, dt) :
                         RD.discrete_dynamics(dmodel, x, u, 0.0, dt)
            if (x′ ∉ X_unsafe) && (x′ ∈ X)
                feasible = true
                break
            end
        end
        if !feasible
            # Dead end. With at most one control on the trajectory there is nothing
            # useful to backtrack to, so stop. Otherwise drop the last step and retry.
            length(Uref) <= 1 && return Xref, Uref
            pop!(Xref)
            pop!(Uref)
            continue
        end
        push!(Xref, x′)
        push!(Uref, u)
    end
    return Xref, Uref
end

# Integrate ẋ = quadrotor_dynamics_euler!(x, u) over [0, dt] with `u` held constant; return the final state.
function _integrate_euler_step(x, u, dt)
    # `u` is an argument here, so the closure captures it without boxing.
    rhs(state, p, t) = quadrotor_dynamics_euler!(state, u)
    prob = ODEProblem(rhs, x, (0.0, dt))
    # Only the end state is used; skipping intermediate saves does not change the steps taken.
    sol = solve(prob, Tsit5(), reltol = 1e-8, abstol = 1e-8, save_everystep = false)
    # `sol.u[end]` is the final state on every SciMLBase version. Since
    # RecursiveArrayTools v3, `sol[end]` is linear (scalar) indexing.
    return sol.u[end]
end

"""
    generate_random_traj(dmodel, num, dt, T, X, X_unsafe, U; q=false, euler=false, max_u=10000)

Generate `num` random reference trajectories with `generate_Xref`, showing a progress bar.

Each start state is drawn from `X` with `random_point_in_hyperrectangle`, passing `q`
through. It is redrawn until it lies outside `X_unsafe`, which loops forever if every
point of `X` is unsafe.

`dt`, `T`, `X`, `X_unsafe`, `U`, `euler` and `max_u` are forwarded to `generate_Xref`;
`max_u` defaults to the same value `generate_Xref` uses.

Returns `(Xrefs, Urefs)`, with one state trajectory and one control sequence per sample.
"""
function generate_random_traj(dmodel, num, dt, T,X, X_unsafe, U;q=false,euler=false, max_u=10000)
    Xrefs = []
    Urefs = []
    @showprogress for i = 1:num
        # Rejection sampling of the initial state.
        x_0, safe_flag = random_point_in_hyperrectangle(X, X_unsafe; q=q)
        while !safe_flag
            x_0, safe_flag = random_point_in_hyperrectangle(X, X_unsafe; q=q)
        end

        Xref, Uref = generate_Xref(dmodel, x_0, dt, T, X, X_unsafe, U; max_u=max_u, euler=euler)
        push!(Xrefs, Xref)
        push!(Urefs, Uref)
    end
    return Xrefs, Urefs
end



# Out: generate_random_traj (generic function with 1 method)

# In [5]:
"""
    plot_function(Xrefs; n_ignore=50, q=false, domain=X, unsafe=X_unsafe)

Plot the (x, y) projection of reference trajectories over the state box and the unsafe box.

`domain` and `unsafe` default to the notebook globals `X` and `X_unsafe`. Only their
first two coordinates are drawn, and `unsafe` is filled red.

Each trajectory is assumed to satisfy `length(Xref) == number of controls + 1`, which
`generate_Xref` guarantees. Trajectories with fewer than `n_ignore + 1` controls are
skipped. For the others, states `1:(number of controls - n_ignore)` are drawn. With
`q=true` the third coordinate is passed as a z series.

Displays the plot and returns `valid_num`, the total number of plotted states.
"""
function plot_function(Xrefs; n_ignore=50,q=false, domain=X, unsafe=X_unsafe)
    plt1 = plot(Hyperrectangle(low=low(domain)[1:2], high=high(domain)[1:2]))
    plot!(plt1, Hyperrectangle(low=low(unsafe)[1:2], high=high(unsafe)[1:2]), fillcolor=:red)
    valid_num = 0
    for Xref in Xrefs
        # Number of controls, derived from the state count so no global `Urefs` is needed.
        n_controls = length(Xref) - 1
        n_controls < n_ignore + 1 && continue
        n_keep = n_controls - n_ignore

        xs = [Xref[i][1] for i in 1:n_keep]
        ys = [Xref[i][2] for i in 1:n_keep]
        if q
            zs = [Xref[i][3] for i in 1:n_keep]
            plot!(plt1, xs, ys, zs, legend = false)
        else
            plot!(plt1, xs, ys, legend = false)
        end
        valid_num += n_keep
    end
    display(plt1)
    @show valid_num
end

# Out: plot_function (generic function with 1 method)

# In [6]:
"""
    build_dataset(name, Xrefs, Urefs, X, X_unsafe, U; n_ignore=50, q=false, n_test=10000)

Build a labelled (state, control, label) dataset, shuffle it, and save a train/test split.

- **Safe samples (label `[true]`):** every pair `(Xrefs[k][i], Urefs[k][i])` for
  `i in 1:length(Urefs[k]) - n_ignore`. Trajectories with fewer than `n_ignore + 1`
  controls are skipped.
- **Unsafe samples (label `[false]`):** `floor(0.8 * n_safe)` states drawn from
  `X_unsafe` (with `q` passed to the sampler), each paired with a random control from `U`.

The samples are stacked column-wise into a 3×N matrix and shuffled. The last
`min(n_test, N)` columns become the test set and the remaining columns the training set;
the two sets are disjoint. They are written to `name*"_training_data.jld2"` and
`name*"_test_data.jld2"` with `save_object`.

`X` is accepted for signature compatibility but is not used.

Throws `ArgumentError` if no trajectory is long enough to yield a sample.

NOTE(review): with `q=true`, entries 4:7 of an unsafe sample are overwritten after
sampling. If they fall outside `X_unsafe`, the `@assert` below fires.
"""
function build_dataset(name, Xrefs, Urefs, X, X_unsafe, U; n_ignore=50,q=false, n_test=10000)
    data = []
    for k in eachindex(Xrefs, Urefs)
        length(Urefs[k]) < n_ignore + 1 && continue
        for i in 1:length(Urefs[k])-n_ignore
            push!(data, [Xrefs[k][i], Urefs[k][i], [true]]) # safe and persistently feasible
        end
    end
    isempty(data) && throw(ArgumentError(
        "no trajectory has at least n_ignore + 1 = $(n_ignore + 1) controls; nothing to save"))

    # Add unsafe samples amounting to 80% of the number of safe samples.
    n_unsafe = floor(Int, length(data) * 0.8)
    for _ in 1:n_unsafe
        random_x0, safe_flag = random_point_in_hyperrectangle(X_unsafe, X_unsafe; q=q)
        random_u0, _ = random_point_in_hyperrectangle(U)
        @assert safe_flag == false
        push!(data, [random_x0, random_u0, [safe_flag]])
    end

    data = reduce(hcat, data)
    data = data[:, shuffle(1:size(data, 2))]

    # Disjoint split. The original ranges (1:end-10000, end-10000:end) shared one column
    # and failed with fewer than 10000 columns.
    n_cols = size(data, 2)
    n_test_cols = min(n_test, n_cols)
    n_test_cols < n_test && @warn "only $n_cols samples; all of them go to the test set"
    training_data = data[:, 1:n_cols-n_test_cols]
    test_data = data[:, n_cols-n_test_cols+1:n_cols]

    save_object(name*"_training_data.jld2", training_data)
    save_object(name*"_test_data.jld2", test_data)
end

# Out: build_dataset (generic function with 1 method)

# In [None]:

# Planar quadrotor: 6-D state box, 2-D control box.
# The unsafe box spans x ∈ [1.5, 2.5], y ∈ [0, 2]; other coordinates match X.
# Globals X, X_unsafe, Xrefs and Urefs are read by plot_function, so keep these names.
X = Hyperrectangle(low = [0, 0, -0.1, -1, -1, -1], high = [4, 4, 0.1, 1, 1, 1])
X_unsafe = Hyperrectangle(low = [1.5, 0, -0.1, -1, -1, -1], high = [2.5, 2, 0.1, 1, 1, 1])
U = Hyperrectangle(low = [4, 4], high = [6, 6])

dyn_model = RobotZoo.PlanarQuadrotor()
n, m = RD.dims(dyn_model)
dmodel = RD.DiscretizedDynamics{RD.RK4}(dyn_model)

# 500_000 rollouts, dt = 0.1, horizon T = 10.
Xrefs, Urefs = generate_random_traj(dmodel, 500_000, 0.1, 10, X, X_unsafe, U);
plot_function(Xrefs)
build_dataset("planarquad", Xrefs, Urefs, X, X_unsafe, U)

# In [None]:

# Dubins car: 3-D state box, 2-D control box in [-1, 1]².
# The unsafe box spans x ∈ [1.5, 2.5], y ∈ [0, 2] for every value of the third coordinate.
X = Hyperrectangle(low = [0, 0, 0], high = [4, 4, π])
X_unsafe = Hyperrectangle(low = [1.5, 0, 0], high = [2.5, 2, π])
U = Hyperrectangle(low = [-1, -1], high = [1, 1])

dyn_model = RobotZoo.DubinsCar()
n, m = RD.dims(dyn_model)
dmodel = RD.DiscretizedDynamics{RD.RK4}(dyn_model)

# 50_000 rollouts, dt = 0.1, horizon T = 10.
Xrefs, Urefs = generate_random_traj(dmodel, 50_000, 0.1, 10, X, X_unsafe, U);
plot_function(Xrefs)
build_dataset("car", Xrefs, Urefs, X, X_unsafe, U)

# In [None]:

# 2-D double integrator: 4-D state box, 2-D control box in [-1, 1]².
# The unsafe box spans x ∈ [1.5, 2.5], y ∈ [0, 2]; the remaining coordinates match X.
X = Hyperrectangle(low = [0, 0, -1, -1], high = [4, 4, 1, 1])
X_unsafe = Hyperrectangle(low = [1.5, 0, -1, -1], high = [2.5, 2, 1, 1])
U = Hyperrectangle(low = [-1, -1], high = [1, 1])

dyn_model = RobotZoo.DoubleIntegrator(2)
n, m = RD.dims(dyn_model)
dmodel = RD.DiscretizedDynamics{RD.RK4}(dyn_model)

# 500_000 rollouts, dt = 0.1, horizon T = 10. Plotting is disabled for this run:
# plot_function(Xrefs)
Xrefs, Urefs = generate_random_traj(dmodel, 500_000, 0.1, 10, X, X_unsafe, U);
build_dataset("point", Xrefs, Urefs, X, X_unsafe, U)