## CQL for Pursuit

### Arguments

In [None]:
# --- Environment settings (passed to pursuit_v4.parallel_env in main) ---
# Forwarded as ``n_pursuers``.
NUM_AGENTS: int = 2
# Forwarded as ``max_cycles``.
MAX_CYCLES: int = 50

# --- Training settings (passed to cql_trainer in main) ---
# Forwarded as ``num_episodes``.
NUM_EPISODES: int = 100
# Forwarded as ``max_episode_lengths``.
MAX_EPISODES_LEN: int = 25
# Forwarded as ``dqn_update_freq``.
DQN_UPDATE_FREQ: int = 5

### Main Trainer

In [None]:
from agent import CqlAgent, CqlAgentConfig
from trainer import cql_trainer
from pettingzoo.sisl import pursuit_v4


def main() -> None:
    """Build the pursuit environment and a ``CqlAgent``, then run ``cql_trainer``.

    Steps:
      1. Create a parallel ``pursuit_v4`` environment with ``NUM_AGENTS``
         pursuers and ``MAX_CYCLES`` max cycles.
      2. Reset it once and derive a per-agent observation size and action
         count from the initial observations and the action spaces.
      3. Construct a ``CqlAgent`` from a ``CqlAgentConfig`` holding those sizes
         and the hyper-parameters below, and print the resulting dimensions.
      4. Hand the environment and agent to ``cql_trainer``.

    The environment is closed when training finishes or raises.
    """
    env = pursuit_v4.parallel_env(
        n_pursuers=NUM_AGENTS,
        max_cycles=MAX_CYCLES,
    )
    try:
        # Only the initial observations are needed here; the reset info is unused.
        states, _ = env.reset()

        # NOTE(review): len(state) is the size of the observation's FIRST axis
        # only. If observations are multi-dimensional arrays this is not the
        # flattened size -- confirm what CqlAgentConfig.obs_dims expects.
        obs_dims = {agent_key: len(state) for agent_key, state in states.items()}
        act_dims = {
            agent_key: int(env.action_space(agent_key).n)
            for agent_key in obs_dims
        }

        def group_sample():
            """Return one sampled action (as ``int``) per agent, keyed by agent."""
            return {
                agent_key: int(env.action_space(agent_key).sample())
                for agent_key in act_dims
            }

        config = CqlAgentConfig(
            obs_dims=obs_dims,
            act_dims=act_dims,
            hidden_dim=128,
            batch_size=64,
            lr=1e-4,
            grad_clip_value=0.5,
            gamma=0.95,
            eps_start=0.9,
            eps_decay=0.95,
            eps_min=0.05,
            mem_size=1_000,
        ).infer_joint_space()
        central_agent = CqlAgent(
            sid='central_agent',
            config=config,
            act_sampler=group_sample,
        )

        # Label fixed: this line prints the per-agent dict, not the joint size.
        print(f'obs_dims: {central_agent.config.obs_dims}')
        print(f'joint_obs_dim: {central_agent.config.obs_dim}')
        print(f'joint_act_dim: {central_agent.config.act_dim}')

        cql_trainer(
            env=env,
            central_agent=central_agent,
            num_episodes=NUM_EPISODES,
            max_episode_lengths=MAX_EPISODES_LEN,
            dqn_update_freq=DQN_UPDATE_FREQ,
        )
    finally:
        # Release environment resources even if setup or training fails.
        env.close()

### Run

In [None]:
# Start training only when this file/cell is executed directly, not on import.
if __name__ == "__main__":
    main()