Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 40 additions & 10 deletions DojoEnvironments/src/environments.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,19 +30,36 @@ end
state: provided state
"""
function state_map(::Environment, state)
return x
return state
end

"""
input_map(environment, input)

maps the provided input to the environments internal input
maps the provided input to the environment's internal input

environment: environment
input: provided input
"""
function input_map(::Environment, u)
return u
function input_map(::Environment, input)
return input
end

function input_map(environment::Environment, ::Nothing)
return zeros(input_dimension(environment.mechanism))
end

"""
set_input!(environment, input)

sets the provided input to the environment's mechanism

environment: environment
input: provided input
"""
function Dojo.set_input!(environment::Environment, input)
set_input!(environment.mechanism, input_map(environment, input))
return
end

"""
Expand All @@ -57,10 +74,10 @@ end
record: record step in storage
opts: SolverOptions
"""
function Dojo.step!(environment::Environment, x, u=nothing; k=1, record=false, opts=SolverOptions())
x = state_map(environment, x)
u = input_map(environment, u)
Dojo.step_minimal_coordinates!(environment.mechanism, x, u; opts)
function Dojo.step!(environment::Environment, state, input=nothing; k=1, record=false, opts=SolverOptions())
state = state_map(environment, state)
input = input_map(environment, input)
Dojo.step_minimal_coordinates!(environment.mechanism, state, input; opts)
record && Dojo.save_to_storage!(environment.mechanism, environment.storage, k)

return
Expand All @@ -75,8 +92,9 @@ end
controller!: Control function
kwargs: same as for Dojo.simulate
"""
function Dojo.simulate!(environment::Environment{T,N}, controller! = (mechanism, k) -> nothing; kwargs...) where {T,N}
simulate!(environment.mechanism, 1:N, environment.storage, controller!; kwargs...)
function Dojo.simulate!(environment::Environment{T,N}, controller! = (environment, k) -> nothing; kwargs...) where {T,N}
controller_wrapper!(mechanism, k) = controller!(environment, k)
simulate!(environment.mechanism, 1:N, environment.storage, controller_wrapper!; kwargs...)
end

"""
Expand All @@ -90,6 +108,18 @@ function get_state(environment::Environment)
return get_minimal_state(environment.mechanism)
end

"""
initialize!(environment; kwargs...)

initializes the environment's mechanism

environment: Environment
kwargs: same as for DojoEnvironments' mechanisms
"""
function Dojo.initialize!(environment::Environment, model; kwargs...)
eval(Symbol(:initialize, :_, string_to_symbol(model), :!))(environment.mechanism; kwargs...)
end

"""
visualize(environment; kwargs...)

Expand Down
2 changes: 0 additions & 2 deletions DojoEnvironments/src/environments/ant_ars.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ function ant_ars(;
dampers=0,
parse_springs=true,
parse_dampers=true,
limits=true,
joint_limits=Dict([
(:hip_1, [-30,30] * π / 180),
(:ankle_1, [30,70] * π / 180),
Expand All @@ -38,7 +37,6 @@ function ant_ars(;
dampers,
parse_springs,
parse_dampers,
limits,
joint_limits,
keep_fixed_joints,
friction_coefficient,
Expand Down
8 changes: 3 additions & 5 deletions DojoEnvironments/src/environments/cartpole_dqn.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,13 @@ function cartpole_dqn(;
gravity=-9.81,
slider_mass=1,
pendulum_mass=1,
pendulum_length=1,
link_length=1,
radius=0.075,
color=RGBA(0.7, 0.7, 0.7, 1),
springs=0,
dampers=0,
limits=false,
joint_limits=Dict(),
keep_fixed_joints=false,
keep_fixed_joints=true,
T=Float64)

mechanism = get_cartpole(;
Expand All @@ -26,12 +25,11 @@ function cartpole_dqn(;
gravity,
slider_mass,
pendulum_mass,
pendulum_length,
link_length,
radius,
color,
springs,
dampers,
limits,
joint_limits,
keep_fixed_joints,
T
Expand Down
6 changes: 5 additions & 1 deletion DojoEnvironments/src/environments/include.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
include("ant_ars.jl")
include("cartpole_dqn.jl")
include("pendulum.jl")
include("quadruped_sampling.jl")
include("quadruped_waypoint.jl")
include("quadruped_sampling.jl")
include("quadrotor_waypoint.jl")
include("uuv_waypoint.jl")
include("youbot_waypoint.jl")
8 changes: 3 additions & 5 deletions DojoEnvironments/src/environments/pendulum.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,10 @@ function pendulum(;
input_scaling=timestep,
gravity=-9.81,
mass=1,
length=1,
link_length=1,
color=RGBA(1, 0, 0),
springs=0,
dampers=0,
limits=false,
joint_limits=Dict(),
spring_offset=szeros(1),
orientation_offset=one(Quaternion),
Expand All @@ -24,18 +23,17 @@ function pendulum(;
input_scaling,
gravity,
mass,
length,
link_length,
color,
springs,
dampers,
limits,
joint_limits,
spring_offset,
orientation_offset,
T
)

storage = Storage(horizon, Base.length(mechanism.bodies))
storage = Storage(horizon, length(mechanism.bodies))

return Pendulum{T,horizon}(mechanism, storage)
end
Expand Down
164 changes: 164 additions & 0 deletions DojoEnvironments/src/environments/quadrotor_waypoint.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
mutable struct QuadrotorWaypoint{T,N} <: Environment{T,N}
mechanism::Mechanism{T}
storage::Storage{T,N}

rpms::AbstractVector
end

function quadrotor_waypoint(;
horizon=100,
timestep=0.01,
input_scaling=timestep,
gravity=-9.81,
urdf=:pelican,
springs=0,
dampers=0,
parse_springs=true,
parse_dampers=true,
joint_limits=Dict(),
keep_fixed_joints=false,
friction_coefficient=0.5,
contact_rotors=true,
contact_body=true,
T=Float64)

mechanism = get_quadrotor(;
timestep,
input_scaling,
gravity,
urdf,
springs,
dampers,
parse_springs,
parse_dampers,
joint_limits,
keep_fixed_joints,
friction_coefficient,
contact_rotors,
contact_body,
T
)

storage = Storage(horizon, length(mechanism.bodies))

return QuadrotorWaypoint{T,horizon}(mechanism, storage, zeros(4))
end

function state_map(::QuadrotorWaypoint, state)
state = [state;zeros(8)]
return state
end

function input_map(environment::QuadrotorWaypoint, input)
# Input is rotor rpm directly
# Rotors are only visualized, dynamics are mapped here
environment.rpms = input

body = get_body(environment.mechanism, :base_link)
q = body.state.q2

force_torque = rpm_to_force_torque(environment, input, q)

input = [force_torque;zeros(4)]

return input
end

function input_map(::QuadrotorWaypoint, ::Nothing)
return zeros(10)
end

function Dojo.step!(environment::QuadrotorWaypoint, state, input=nothing; k=1, record=false, opts=SolverOptions())
state = state_map(environment, state)
input = input_map(environment, input)
Dojo.step_minimal_coordinates!(environment.mechanism, state, input; opts)
record && Dojo.save_to_storage!(environment.mechanism, environment.storage, k)

return
end

function Dojo.simulate!(environment::QuadrotorWaypoint{T,N}, controller! = (environment, k) -> nothing; kwargs...) where {T,N}
mechanism = environment.mechanism

joint_rotor_0 = get_joint(mechanism, :rotor_0_joint)
joint_rotor_1 = get_joint(mechanism, :rotor_1_joint)
joint_rotor_2 = get_joint(mechanism, :rotor_2_joint)
joint_rotor_3 = get_joint(mechanism, :rotor_3_joint)

function controller_wrapper!(mechanism, k)
rpms = environment.rpms
set_minimal_velocities!(mechanism, joint_rotor_0, [rpms[1]])
set_minimal_velocities!(mechanism, joint_rotor_1, [-rpms[2]])
set_minimal_velocities!(mechanism, joint_rotor_2, [rpms[3]])
set_minimal_velocities!(mechanism, joint_rotor_3, [-rpms[4]])

controller!(environment, k)
end

simulate!(environment.mechanism, 1:N, environment.storage, controller_wrapper!; kwargs...)
end

function get_state(environment::QuadrotorWaypoint)
state = get_minimal_state(environment.mechanism)[1:12]
return state
end

function Dojo.visualize(environment::QuadrotorWaypoint; return_animation=false, kwargs...)
vis, animation = visualize(environment.mechanism, environment.storage; return_animation=true, kwargs...)

waypoints = [
[1;1;0.3;pi/4],
[2;0;0.3;-pi/4],
[1;-1;0.3;-3*pi/4],
[0;0;0.3;-5*pi/4],
]
for i=1:4
waypoint_shape = Sphere(0.2;color=RGBA(0,0.25*i,0,0.3))
visshape = Dojo.convert_shape(waypoint_shape)
subvisshape = vis["waypoints"]["waypoint$i"]
Dojo.setobject!(subvisshape, visshape, waypoint_shape)
Dojo.atframe(animation, 1) do
Dojo.set_node!(waypoints[i][1:3], one(Quaternion), waypoint_shape, subvisshape, true)
end
end
Dojo.setanimation!(vis,animation)

return_animation ? (return vis, animation) : (return vis)
end

# ## physics functions

function rpm_to_force_torque(::QuadrotorWaypoint, rpm::Real, rotor_sign::Int64)
force_factor = 0.001
torque_factor = 0.0001

force = sign(rpm)*force_factor*rpm^2
torque = sign(rpm)*rotor_sign*torque_factor*rpm^2

return [force;0;0], [torque;0;0]
end
function rpm_to_force_torque(environment::QuadrotorWaypoint, rpms::AbstractVector, q::Quaternion)
qympi2 = Dojo.RotY(-pi/2)
orientations = [qympi2;qympi2;qympi2;qympi2]
directions = [1;-1;1;-1]
force_vertices = [
[0.21; 0; 0.05],
[0; 0.21; 0.05],
[-0.21; 0; 0.05],
[0; -0.21; 0.05],
]

forces_torques = [rpm_to_force_torque(environment, rpms[i], directions[i]) for i=1:4]
forces = getindex.(forces_torques,1)
torques = getindex.(forces_torques,2)

forces = Dojo.vector_rotate.(forces, orientations) # in local frame
torques = Dojo.vector_rotate.(torques, orientations) # in local frame

torques_from_forces = Dojo.cross.(force_vertices, forces)

force = Dojo.vector_rotate(sum(forces), q) # in minimal frame
torque = Dojo.vector_rotate(sum(torques .+ torques_from_forces), q) # in minimal frame

return [force; torque]
end
4 changes: 1 addition & 3 deletions DojoEnvironments/src/environments/quadruped_sampling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,12 @@ function quadruped_sampling(;
parse_springs=true,
parse_dampers=true,
spring_offset=true,
limits=true,
joint_limits=Dict(vcat([[
(Symbol(group,:_hip_joint), [-0.5,0.5]),
(Symbol(group,:_thigh_joint), [-0.5,1.5]),
(Symbol(group,:_calf_joint), [-2.5,-1])]
for group in [:FR, :FL, :RR, :RL]]...)),
keep_fixed_joints=true,
keep_fixed_joints=false,
friction_coefficient=0.8,
contact_feet=true,
contact_body=true,
Expand All @@ -36,7 +35,6 @@ function quadruped_sampling(;
parse_springs,
parse_dampers,
spring_offset,
limits,
joint_limits,
keep_fixed_joints,
friction_coefficient,
Expand Down
Loading