## Independent Q-Learning (IQL) for Pursuit


### Arguments

In [1]:
# --- Environment settings -------------------------------------------------
# Forwarded to the Pursuit environment as ``n_pursuers`` in main().
NUM_AGENTS: int = 10
# Forwarded to the Pursuit environment as ``max_cycles`` in main().
MAX_CYCLES: int = 200

# --- Training settings ----------------------------------------------------
# Forwarded to ``iql_trainer`` as ``num_episodes``.
NUM_EPISODES: int = 150
# Forwarded to ``iql_trainer`` as ``max_episode_lengths``.
MAX_EPISODES_LEN: int = 150
# Target DQN update/freeze frequency; forwarded as ``dqn_update_freq``.
DQN_UPDATE_FREQ: int = 25

# --- Display settings -----------------------------------------------------
# Forwarded to ``iql_trainer`` as ``show_plot``.
SHOW_PLOT: bool = False
# When True, main() requests ``render_mode="human"``; otherwise ``None``.
RENDER: bool = False

### Main Trainer

In [2]:
from a3marl.envs.utils import EnvConfig
from a3marl.agents import IqlAgent, IqlAgentConfig
from a3marl.trainer import iql_trainer
from pettingzoo.sisl import pursuit_v4


def main() -> None:
    """Build the Pursuit environment, create one IQL agent per agent key, and train.

    The environment is created through ``EnvConfig`` around
    ``pursuit_v4.parallel_env`` using the module-level settings
    (``NUM_AGENTS``, ``MAX_CYCLES``, ``RENDER``). One ``IqlAgent`` is built
    for every agent key returned by the initial ``env.reset()``, and the
    whole setup is handed to ``iql_trainer``.
    """
    pursuit_config: EnvConfig = EnvConfig(
        name_abbr="pu",
        env_creator=pursuit_v4.parallel_env,
        env_kwargs={
            "n_pursuers": NUM_AGENTS,
            "max_cycles": MAX_CYCLES,
            "n_catch": 1,
        },
    )

    env = pursuit_config.get_env(
        render_mode="human" if RENDER else None,
    )

    # Reset once up front to discover the agent keys and each agent's
    # observation size; the second return value is not needed here.
    states, _ = env.reset()
    # Network input size per agent is taken from the observation's ``.size``.
    obs_dims = {agent_key: state.size for agent_key, state in states.items()}
    # Network output size per agent is the ``n`` of that agent's action space.
    act_dims = {
        agent_key: int(env.action_space(agent_key).n)
        for agent_key in obs_dims
    }

    cur_agents = {
        agent_key: IqlAgent(
            sid=agent_key,
            config=IqlAgentConfig(
                obs_dim=obs_dims[agent_key],
                act_dim=act_dims[agent_key],
                hidden_dims=[100, 50, 25],
                batch_size=256,
                lr=1e-3,
                grad_clip_value=5,
                gamma=0.95,
                eps_start=0.9,
                eps_decay=0.95,
                eps_min=0.05,
                mem_size=10_000,
            ),
            # Each agent samples from its own action space.
            act_sampler=env.action_space(agent_key).sample,
        )
        for agent_key in states
    }

    iql_trainer(
        env=env,
        env_config=pursuit_config,
        cur_agents=cur_agents,
        num_episodes=NUM_EPISODES,
        max_episode_lengths=MAX_EPISODES_LEN,
        dqn_update_freq=DQN_UPDATE_FREQ,
        show_plot=SHOW_PLOT,
    )

### Run

In [3]:
# Start training only when this module is the entry point, so that importing
# it does not trigger a run.
if __name__ == "__main__":
    main()

Episode 0	Avg return = -4.6810;
-4.71033 vs best: -inf, update Target DQN
-4.68705 vs best: -4.71033, update Target DQN
-4.49730 vs best: -4.68705, update Target DQN
-4.36324 vs best: -4.49730, update Target DQN
-4.13972 vs best: -4.36324, update Target DQN
-4.12780 vs best: -4.13972, update Target DQN
-4.07740 vs best: -4.12780, update Target DQN
-3.60359 vs best: -4.07740, update Target DQN
Episode 10	Avg return = -4.4049;
Episode 20	Avg return = -4.4063;
-3.59309 vs best: -3.60359, update Target DQN
Episode 30	Avg return = -4.2745;
Episode 40	Avg return = -4.4725;
Episode 50	Avg return = -4.2831;
Episode 60	Avg return = -4.2028;
Episode 70	Avg return = -4.3326;
-3.58669 vs best: -3.59309, update Target DQN
-3.29506 vs best: -3.58669, update Target DQN
Episode 80	Avg return = -4.3639;
-2.99143 vs best: -3.29506, update Target DQN
Episode 90	Avg return = -4.0940;
Episode 100	Avg return = -4.2337;
Episode 110	Avg return = -4.2227;
Episode 120	Avg return = -4.2152;
Episode 130	Avg retur