# Train a policy using RL, part 2: multi-agent environment

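In this notebook, we train multiple agents to perform the same task as in the `RL` notebook. As a starting point, we let all agents of a 10-agent scenario share the single-agent policy trained there.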

In [41]:
import gymnasium as gym
import numpy as np
from navground import sim
from navground_learning import ControlActionConfig, ObservationConfig
from navground_learning.reward import SocialReward
from navground_learning.evaluate import make_experiment
from navground_learning.config import WorldConfig, GroupConfig
from stable_baselines3 import SAC

import warnings
warnings.filterwarnings('ignore')

# Sensor (state estimation) attached to the learning agents' group and passed
# to GroupConfig below. It is a navground "Discs" state estimation configured
# with `number: 5` and `range: 5.0`, presumably the 5 nearest discs (other
# agents) within that range; confirm against the navground Discs docs.
# `max_speed` (0.12) equals the agents' kinematic max speed in the scenario.
# NOTE(review): `max_radius: 0.0`, while the agents have radius 0.1. Confirm
# this matches the sensor used when training the policy in the `RL` notebook.
sensor = sim.load_state_estimation("""
type: Discs
number: 5
range: 5.0
max_speed: 0.12
max_radius: 0.0
""")

# Navground "Cross" scenario with a single group of 10 Thymio agents
# (radius 0.1, 2WDiff kinematics, max speed 0.12), each configured with the
# HL behavior and a Bounded state estimation of range 5.0.
# The HL behavior is what the scenario declares natively. Presumably, in the
# experiment built below, it is replaced for the selected agents by the policy
# assigned through WorldConfig; confirm in navground_learning's docs.
scenario = sim.load_scenario("""
type: Cross
agent_margin: 0.1
side: 4
target_margin: 0.1
tolerance: 0.5
groups:
  -
    type: thymio
    number: 10
    radius: 0.1
    control_period: 0.1
    speed_tolerance: 0.02
    color: gray
    kinematics:
      type: 2WDiff
      wheel_axis: 0.094
      max_speed: 0.12
    behavior:
      type: HL
      optimal_speed: 0.12
      horizon: 5.0
      tau: 0.25
      eta: 0.5
      safety_margin: 0.1
    state_estimation:
      type: Bounded
      range: 5.0
""")

# Action space: the policy outputs accelerations (use_acceleration_action),
# bounded by a linear and an angular acceleration limit.
# NOTE(review): presumably these must match the configuration used to train
# the SAC policy in the `RL` notebook; verify.
action_config = ControlActionConfig(
    use_acceleration_action=True,
    max_acceleration=1.0,
    max_angular_acceleration=10.0,
)

# Observation space: flattened into a single vector (flat=True) that also
# includes the distance to the target, the agent's velocity and its angular
# speed.
observation_config = ObservationConfig(
    flat=True,
    include_target_distance=True,
    include_velocity=True,
    include_angular_speed=True,
)

# Single-agent SAC model trained in the `RL` notebook.
# The path is relative rather than an absolute path on one developer's
# machine, so the notebook runs from any checkout. It assumes the notebook is
# executed from docs/source/tutorials, next to the `policies` folder; adjust
# SA_POLICY_PATH if the notebook lives elsewhere.
SA_POLICY_PATH = 'policies/RL/SAC.zip'
sa_model = SAC.load(SA_POLICY_PATH)

# Every selected agent gets the same action/observation configuration and
# sensor, and is controlled by the single-agent SAC policy (`sa_model.policy`).
# One index slice is shared by the group and by the policy assignment so the
# two cannot drift apart. Its upper bound (20) exceeds the 10 agents declared
# in the scenario. Presumably the slice is clipped to the agent list, so all
# agents are selected; confirm in navground_learning's GroupConfig docs.
# NOTE(review): `gc` shadows the name of the stdlib `gc` module.
agent_indices = slice(0, 20, 1)
gc = GroupConfig(action=action_config, observation=observation_config,
                 sensor=sensor, indices=agent_indices)
wc = WorldConfig(groups=[gc], policies=[(agent_indices, sa_model.policy)])
# Each run lasts 3600 steps. As the flag name suggests, runs are not cut short
# when all agents are idle or stuck.
exp = make_experiment(scenario=scenario, config=wc, steps=3600,
                      terminate_when_all_idle_or_stuck=False)

In [42]:
# Run a single experiment run with seed 1. The trailing `;` suppresses the
# returned ExperimentalRun's repr, which only shows a volatile memory address.
# The run is read back from `exp.runs` in the next cell.
exp.run_once(1);

<navground.sim._navground_sim.ExperimentalRun at 0x3a87e06b0>

In [43]:
# Duration, in seconds, of the run executed with seed 1. `exp.runs` is
# presumably keyed by seed. The duration is presumably the wall-clock time the
# simulation took, not the simulated time; confirm in the navground docs.
seed_1_run = exp.runs[1]
seed_1_run.duration.total_seconds()

4.271785