In [None]:
using Revise
using Pkg

# Point PyCall at the Python interpreter found on PATH, both for the build
# (`PYTHON`) and at runtime (`PYCALL_JL_RUNTIME_PYTHON`), then rebuild PyCall.
# `Sys.which` returns `nothing` when the executable is missing, and assigning
# that to `ENV` would store the text "nothing" instead of a path. Fall back to
# `python3` and stop with a clear message if neither executable is available.
let python_exe = Sys.which("python")
    if python_exe === nothing
        python_exe = Sys.which("python3")
    end
    python_exe === nothing &&
        error("No `python` or `python3` executable found on PATH; cannot configure PyCall.")
    ENV["PYTHON"] = python_exe
    ENV["PYCALL_JL_RUNTIME_PYTHON"] = python_exe
end
Pkg.build("PyCall")
using FileIO
using JLD2
include("../src/RiskSensitiveSAC.jl")
using .RiskSensitiveSAC

In [None]:
# Load the default parameter script before applying the overrides below.
include("$(@__DIR__)/../scripts/default_params/params_data_gaussian.jl")

# Timing settings, all in seconds.
# replanning time interval
dtr = 0.4
# dtc = 0.4
# NOTE(review): `dtc` is disabled here but is still passed to controller_setup
# in a later cell, so it presumably comes from an included script -- confirm.
# pre-allocated control computation time (< dtr)
tcalc = 0.2
# simulation horizon
sim_horizon = 16.0

# Directory of the trained policy.
# Alternative kept for reference:
# model_dir = normpath(joinpath(@__DIR__, "../CrowdNav/crowd_nav/data/output_om_sarl_radius_0.4"))
model_dir = normpath(joinpath(@__DIR__, "../CrowdNav/crowd_nav/data/output"))
# environment config file name
env_config = "env.config"
# policy config file name
policy_config = "policy.config"
# policy name
policy_name = "sarl"

# Run the crowd-nav setup script after the values above are defined.
include("$(@__DIR__)/../scripts/parameter_setup_crowd_nav.jl");

In [None]:
# Call controller_setup with the scene/controller parameters and unpack its
# seven return values. `target_speed` is both passed in as a keyword and
# rebound from the returned tuple.
(scene_loader, controller, w_init, ado_inputs,
 measurement_schedule, target_trajectory, target_speed) = controller_setup(
    scene_param, cnt_param;
    cost_param=cost_param,
    dtc=dtc,
    prediction_steps=prediction_steps,
    ego_pos_init_vec=ego_pos_init_vec,
    ego_pos_goal_vec=ego_pos_goal_vec,
    target_speed=target_speed,
    sim_horizon=sim_horizon,
    verbose=true,
);

In [None]:
# Call evaluate and keep only the first of its three return values.
# `_` is Julia's write-only placeholder for unwanted values; using an operator
# name such as `~` as an assignment target would instead create a global
# binding that shadows that operator for the rest of the session.
result, _, _ = evaluate(
    scene_loader, controller, w_init, ego_pos_goal_vec, target_speed,
    measurement_schedule, target_trajectory, pos_error_replan;
    ado_inputs_init=ado_inputs,
);

In [None]:
# Pass the log stored on `result` to display_log (presumably renders it in the
# notebook -- confirm against the RiskSensitiveSAC implementation).
display_log(result.log)

In [None]:
# Show the `total_cnt_cost` field of the result (presumably the accumulated
# control cost -- confirm against `evaluate`).
result.total_cnt_cost

In [None]:
# Show the `total_pos_cost` field of the result (presumably the accumulated
# position cost -- confirm against `evaluate`).
result.total_pos_cost

In [None]:
# Show the `total_col_cost` field of the result (presumably the accumulated
# collision cost -- confirm against `evaluate`).
result.total_col_cost

In [None]:
# Sum of the three cost fields reported individually in the cells above.
result.total_cnt_cost + result.total_pos_cost + result.total_col_cost

In [None]:
# Smallest distance, over every entry of `result.w_history`, between
# `get_position(w.e_state)` and any value stored in `w.ap_dict`.
# Appending Inf keeps the per-entry minimum defined when `ap_dict` is empty.
minimum(result.w_history) do w
    dists = [norm(get_position(w.e_state) - ap) for ap in values(w.ap_dict)]
    minimum(vcat(dists, Inf))
end

In [None]:
# Call make_gif on the result, with the output file given by `filename`.
# Each axis limit is a base window value shifted by a fixed offset.
make_gif(
    result;
    dtplot=0.4,
    fps=2,
    xlim=(-3.0 - 5.263534, 13.0 - 5.314636),
    ylim=(0.0 - 5.263534, 10.0 - 5.314636),
    figsize=(600, 400),
    legendfontsize=7,
    legend=:bottomright,
    markersize=5.0,
    filename="8_crowd_nav_data.gif",
)

In [None]:
# Write `result` to a JLD2 file under the key "result" using `save`
# (FileIO and JLD2 are loaded in the first cell).
save("8_crowd_nav_data.jld2", "result", result)