In [1]:
# OPTIONAL: Load the "autoreload" extension so that code can change
%load_ext autoreload

# OPTIONAL: always reload modules so that as you change code in src, it gets loaded
%autoreload 2

# Imports

In [2]:
import ray

import os

from ray.tune.registry import register_env
from ray.tune.logger import pretty_print

# from ray.rllib.algorithms.apex_ddpg import ApexDDPGConfig
from ray.rllib.algorithms.dqn import DQNConfig, DQNTFPolicy, DQNTorchPolicy
from ray.rllib.env.wrappers.pettingzoo_env import ParallelPettingZooEnv

  from .autonotebook import tqdm as notebook_tqdm


# Env

In [3]:
def env_creator(render_mode="rgb_array", cycles=200):
    from src.world import world_utils
    env = world_utils.env(render_mode=render_mode, max_cycles=cycles)
    return env

register_env("tagworld", lambda config: ParallelPettingZooEnv(env_creator()))

### Fixing non-identical observation spaces

In [4]:
# from supersuit import pad_action_space_v0

In [5]:
env = env_creator()
# env = pad_action_space_v0(env)

In [6]:
env.observation_spaces

{'adversary_0': Box(-inf, inf, (34,), float32),
 'adversary_1': Box(-inf, inf, (34,), float32),
 'adversary_2': Box(-inf, inf, (34,), float32),
 'agent_0': Box(-inf, inf, (34,), float32)}

In [7]:
obs_space = env.observation_space
act_space = env.action_space

In [8]:
env.observation_space

<bound method SimpleEnv.observation_space of <src.world.world_utils.raw_env object at 0x7ff98efc7fa0>>

# Parameters

In [9]:
stop_iters = 20
stop_timesteps = 100000
stop_reward = 50.0

# Ray config

In [10]:
ray.init()

2023-03-10 23:24:10,841	INFO worker.py:1553 -- Started a local Ray instance.


0,1
Python version:,3.9.16
Ray version:,2.3.0


In [11]:
policies = {
    "dqn_policy": (
    DQNTorchPolicy,
    obs_space,
    act_space,
    {},
    )
}

In [12]:
def policy_mapping_fn(agent_id, episode, worker, **kwargs):
    return "dqn_policy"
        # if agent_id % 2 == 0:
            # return "ppo_policy"
        # else:

In [13]:
dqn_config = (
        DQNConfig()
        .environment("tagworld")
        .framework("torch")
        # disable filters, otherwise we would need to synchronize those
        # as well to the DQN agent
        .rollouts(observation_filter="MeanStdFilter")
        .training(
            model={"vf_share_layers": True},
            n_step=3,
            gamma=0.95,
        )
        .multi_agent(
            policies=policies,
            policy_mapping_fn=policy_mapping_fn,
            policies_to_train=["dqn_policy"],
        )
        # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
        .resources(num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", "0")))
    )

dqn = dqn_config.build()

2023-03-10 23:24:32,850	INFO algorithm.py:506 -- Current log_level is WARN. For more information, set 'log_level': 'INFO' / 'DEBUG' or use the -v and -vv flags.


AttributeError: 'function' object has no attribute 'shape'

# Train

In [None]:
for i in range(stop_iters):
    print("== Iteration", i, "==")

    # improve the DQN policy
    print("-- DQN --")
    result_dqn = dqn.train()
    print(pretty_print(result_dqn))