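## Independent Q-Learning (IQL) for Foraging

Trains one independent `IqlAgent` per forager in the multi-agent foraging environment and logs the average episode return during training.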

### Arguments

In [4]:
# --- Environment layout ---------------------------------------------------
# One entry per forager: main() passes len(FORAGER_LEVELS) as `n_foragers`
# and the list itself as `forager_levels`.
FORAGER_LEVELS: list[int] = [1] + [2] * 2 + [3] * 2
# One entry per crop: main() passes len(CROP_LEVELS) as `n_crops` and the
# list itself as `crop_levels`.
CROP_LEVELS: list[int] = [1] * 2 + [2] * 4 + [3] * 4 + [4] * 2
# Forwarded to the env as `max_cycles`.
MAX_CYCLES: int = 100

# --- Training schedule ----------------------------------------------------
NUM_EPISODES: int = 100
# Forwarded to iql_trainer as `max_episode_lengths`.
MAX_EPISODES_LEN: int = 100
# Target DQN update/freeze frequency, forwarded to iql_trainer as
# `dqn_update_freq` (units are defined by the trainer — TODO confirm).
DQN_UPDATE_FREQ: int = 10

# --- Display --------------------------------------------------------------
RENDER: bool = False  # True -> env is created with render_mode="human"
SHOW_PLOTS: bool = False  # not read in the cells shown here; presumably used by plotting code elsewhere

### Main Trainer

In [5]:
from envs import foraging
from envs.utils import EnvConfig
from agents import IqlAgent, IqlAgentConfig
from trainer import iql_trainer


def main() -> None:
    """Build the foraging environment, create one IQL agent per forager, and train them.

    Reads the module-level configuration constants (FORAGER_LEVELS, CROP_LEVELS,
    MAX_CYCLES, NUM_EPISODES, MAX_EPISODES_LEN, DQN_UPDATE_FREQ, RENDER), so the
    configuration cell must be executed before this one.

    Side effects:
        * prints the observation/action sizes inferred from the first reset;
        * runs the training loop (and its logging) via ``iql_trainer``;
        * closes the environment afterwards, even if training raises, so a
          ``render_mode="human"`` window does not linger in the kernel.

    Raises:
        RuntimeError: if ``env.reset()`` returns no agent observations.
    """
    foraging_config: EnvConfig = EnvConfig(
        name_abbr="fo",
        env_creator=foraging.parallel_env,
        env_kwargs={
            "n_foragers": len(FORAGER_LEVELS),
            "forager_levels": FORAGER_LEVELS,
            "n_crops": len(CROP_LEVELS),
            "crop_levels": CROP_LEVELS,
            "max_cycles": MAX_CYCLES,
            "reward_idx": 0,
        },
    )
    env = foraging_config.get_env(
        render_mode="human" if RENDER else None,
    )

    try:
        # reset() returns (observations, infos); only the per-agent observations
        # are needed here, to size each agent's network input.
        states, _ = env.reset()
        if not states:
            raise RuntimeError("env.reset() returned no agent observations")

        obs_dims = {agent_key: state.size for agent_key, state in states.items()}
        act_dims = {
            agent_key: int(env.action_space(agent_key).n) for agent_key in obs_dims
        }

        # Print a single value when every agent agrees (the common case); otherwise
        # print the full per-agent mapping instead of silently showing one agent's.
        obs_sizes = set(obs_dims.values())
        act_sizes = set(act_dims.values())
        if len(obs_sizes) == 1 and len(act_sizes) == 1:
            print(
                f"Observation dimensions: {next(iter(obs_sizes))}\n"
                f"Action dimensions: {next(iter(act_sizes))}"
            )
        else:
            print(
                f"Observation dimensions: {obs_dims}\n"
                f"Action dimensions: {act_dims}"
            )

        # One independent agent per environment agent key; each gets its own
        # config object (and its own hidden_dims list).
        cur_agents = {
            agent_key: IqlAgent(
                sid=agent_key,
                config=IqlAgentConfig(
                    obs_dim=obs_dims[agent_key],
                    act_dim=act_dims[agent_key],
                    hidden_dims=[100, 50, 25],
                    batch_size=128,
                    lr=1e-3,
                    grad_clip_value=5,
                    gamma=0.95,
                    eps_start=0.9,
                    eps_decay=0.95,
                    eps_min=0.05,
                    mem_size=10_000,
                ),
                act_sampler=env.action_space(agent_key).sample,
            )
            for agent_key in states
        }
        iql_trainer(
            env=env,
            env_config=foraging_config,
            cur_agents=cur_agents,
            num_episodes=NUM_EPISODES,
            max_episode_lengths=MAX_EPISODES_LEN,
            dqn_update_freq=DQN_UPDATE_FREQ,
        )
    finally:
        # NOTE(review): if iql_trainer also closes `env`, confirm close() is
        # idempotent for this env (it is for the Gymnasium/PettingZoo contract).
        env.close()

### Run

In [6]:
# Entry point. Notebook cells execute with __name__ == "__main__", so running
# this cell starts training (see the log below); if the notebook is exported to
# a .py module and imported, the guard keeps training from starting on import.
if __name__ == "__main__":
    main()

Observation dimensions: 98
Action dimensions: 5
Episode 0: Avg return = 1.2300;
4.1720 vs best: -inf, update TarNet
4.9940 vs best: 4.1720, update TarNet
10.6560 vs best: 4.9940, update TarNet
12.8147 vs best: 10.6560, update TarNet
13.0100 vs best: 12.8147, update TarNet
15.9973 vs best: 13.0100, update TarNet
17.0473 vs best: 15.9973, update TarNet
22.5620 vs best: 17.0473, update TarNet
23.1734 vs best: 22.5620, update TarNet
Episode 10: Avg return = 13.3689;
23.6627 vs best: 23.1734, update TarNet
25.2613 vs best: 23.6627, update TarNet
Episode 20: Avg return = 14.1400;
25.2867 vs best: 25.2613, update TarNet
28.9620 vs best: 25.2867, update TarNet
29.8347 vs best: 28.9620, update TarNet
Episode 30: Avg return = 27.2747;
Episode 40: Avg return = 10.5520;
29.8373 vs best: 29.8347, update TarNet
Episode 50: Avg return = 29.1347;
34.5040 vs best: 29.8373, update TarNet
Episode 60: Avg return = 16.4893;
Episode 70: Avg return = 9.9740;
Episode 80: Avg return = 17.0100;
Episode 90: Avg 