## CQL for PettingZoo Pursuit (one central agent over the joint action space)

### Hyperparameters

In [1]:
# Environment / training hyperparameters used by `main()`.
NUM_AGENTS: int = 6  # number of pursuers, passed to pursuit as `n_pursuers`
MAX_CYCLES: int = 150  # env-side episode cap, passed to pursuit as `max_cycles`
NUM_EPISODES: int = 200  # passed to the trainer as `num_episodes`
# Trainer-side episode length cap. Tied to MAX_CYCLES so the env limit and the
# trainer limit cannot silently drift apart when one of them is edited.
MAX_EPISODES_LEN: int = MAX_CYCLES
DQN_UPDATE_FREQ: int = 10  # passed to the trainer as `dqn_update_freq`

RENDER: bool = False  # when True the env is created with render_mode="human"
SHOW_PLOTS: bool = False  # passed to the trainer as `show_plot`

### Main Trainer

In [2]:
from pettingzoo.sisl import pursuit_v4

from agents import CqlAgent, CqlAgentConfig
from trainer import cql_trainer
from envs.utils import EnvConfig

def main() -> None:
    """Train one centralized CQL agent on the PettingZoo Pursuit environment.

    Builds a 12x12 Pursuit environment with ``NUM_AGENTS`` pursuers. It then
    resets the environment once to read each agent's observation size and
    number of discrete actions. Those per-agent sizes are wrapped in a
    ``CqlAgent`` whose config is collapsed into a joint observation/action
    space via ``infer_joint_space()``. Everything is then handed to
    ``cql_trainer``. Hyperparameters come from the module-level constants.

    The environment is always closed on exit, even if training raises.
    """
    pursuit_cql_env: EnvConfig = EnvConfig(
        name_abbr="pu",
        env_creator=pursuit_v4.parallel_env,
        env_kwargs={
            "n_pursuers": NUM_AGENTS,
            "max_cycles": MAX_CYCLES,
            "x_size": 12,
            "y_size": 12,
            "n_catch": 1,
        },
    )

    env = pursuit_cql_env.get_env(
        render_mode="human" if RENDER else None,
    )
    try:
        # Only the observations are needed to size the networks; infos are unused.
        states, _ = env.reset()
        # Per-agent observation size (total element count of each observation).
        obs_dims = {agent_key: state.size for agent_key, state in states.items()}
        # Per-agent number of discrete actions.
        act_dims = {
            agent_key: int(env.action_space(agent_key).n) for agent_key in obs_dims
        }

        def group_sample():
            """Draw one uniformly random action per agent, as ``{agent_key: action}``.

            Passed to the agent as ``act_sampler``; presumably used for
            epsilon-greedy exploration (see the ``eps_*`` config values).
            """
            return {
                agent_key: int(env.action_space(agent_key).sample())
                for agent_key in act_dims
            }

        central_agent = CqlAgent(
            sid='central_agent',
            config=CqlAgentConfig(
                obs_dims=obs_dims,
                act_dims=act_dims,
                hidden_dims=[2000, 5000],
                batch_size=128,
                lr=1e-3,
                grad_clip_value=5,
                gamma=0.95,
                eps_start=0.9,
                eps_decay=0.95,
                eps_min=0.05,
                mem_size=10_000,
            ).infer_joint_space(),
            act_sampler=group_sample,
        )
        # The joint action space grows exponentially with the number of agents,
        # so print the inferred sizes as a sanity check.
        print(
            f"joint_obs_dim: {central_agent.config.obs_dim}\n"
            f"joint_act_dim: {central_agent.config.act_dim}"
        )
        cql_trainer(
            env=env,
            env_config=pursuit_cql_env,
            central_agent=central_agent,
            num_episodes=NUM_EPISODES,
            max_episode_lengths=MAX_EPISODES_LEN,
            dqn_update_freq=DQN_UPDATE_FREQ,
            show_plot=SHOW_PLOTS,
        )
    finally:
        # Release the render window / environment resources even on failure.
        env.close()


### Run

In [3]:
if __name__ == "__main__":
    # Entry point: run training when executed as a script (in a notebook kernel
    # __name__ is also "__main__", so executing this cell starts training too).
    main()

joint_obs_dim: 882
joint_act_dim: 15625
Episode 0: Avg return = -4.4823;
-3.6143 vs. best: -inf, update TarNet
-2.8177 vs. best: -3.6143, update TarNet
-2.5104 vs. best: -2.8177, update TarNet
Episode 10: Avg return = -4.4912;
Episode 20: Avg return = -4.1001;
-2.0688 vs. best: -2.5104, update TarNet
Episode 30: Avg return = -3.5644;
Episode 40: Avg return = -4.2704;
Episode 50: Avg return = -3.0071;
Episode 60: Avg return = -4.2786;
Episode 70: Avg return = -3.7967;
Episode 80: Avg return = -3.2081;
-1.6056 vs. best: -2.0688, update TarNet
Episode 90: Avg return = -3.8204;
Episode 100: Avg return = -4.1419;
Episode 110: Avg return = -4.2278;
Episode 120: Avg return = -3.5583;
Episode 130: Avg return = -2.7707;
Episode 140: Avg return = -4.3173;
Episode 150: Avg return = -2.7839;
Episode 160: Avg return = -3.0593;
Episode 170: Avg return = -3.7625;
Episode 180: Avg return = -3.4412;
Episode 190: Avg return = -3.8730;
Episode 199: Avg return = -3.5462;
