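Notes on the TorchRL multi-agent PPO tutorial (<https://pytorch.org/rl/stable/tutorials/multiagent_ppo.html>), adapted to run on the Unity ML-Agents "3DBall" environment instead of VMAS.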

In [1]:
# Torch
import torch

# Tensordict modules
from tensordict.nn import TensorDictModule
from tensordict.nn.distributions import NormalParamExtractor
from torch import multiprocessing

import torchrl

# Data collection
from torchrl.collectors import SyncDataCollector
from torchrl.data.replay_buffers import ReplayBuffer
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
from torchrl.data.replay_buffers.storages import LazyTensorStorage

# Env
from torchrl.envs import RewardSum, TransformedEnv
from torchrl.envs.libs.vmas import VmasEnv
from torchrl.envs.utils import check_env_specs

# Multi-agent network
from torchrl.modules import MultiAgentMLP, ProbabilisticActor, TanhNormal

# Loss
from torchrl.objectives import ClipPPOLoss, ValueEstimators

# Utils
from matplotlib import pyplot as plt
from tqdm import tqdm

# Fix torch's seed once all imports are done, so the import block stays contiguous.
torch.manual_seed(0)

In [2]:
# Devices
# Use GPU 0 when CUDA is available; otherwise fall back to the CPU so the
# notebook still runs end-to-end on machines without a GPU.
device = torch.device(0) if torch.cuda.is_available() else torch.device("cpu")
vmas_device = device  # The device where the simulator is run (VMAS can run on GPU)

# Sampling
frames_per_batch = 6_000  # Number of team frames collected per training iteration
n_iters = 10  # Number of sampling and training iterations
total_frames = frames_per_batch * n_iters  # Total frames collected over the whole run

# Training
num_epochs = 30  # Number of optimization steps per training iteration
minibatch_size = 400  # Size of the mini-batches in each optimization step
lr = 3e-4  # Learning rate
max_grad_norm = 1.0  # Maximum norm for the gradients

# PPO
clip_epsilon = 0.2  # clip value for PPO loss
gamma = 0.99  # discount factor
lmbda = 0.9  # lambda for generalised advantage estimation
entropy_eps = 1e-4  # coefficient of the entropy term in the PPO loss

In [3]:
# Settings carried over from the tutorial's VMAS environment. They are kept so
# that anything still referring to them keeps working; the Unity environment
# created below is not configured from them.
max_steps = 100  # Episode steps before done
num_vmas_envs = (
    frames_per_batch // max_steps
)  # Number of vectorized envs. frames_per_batch should be divisible by this number
scenario_name = "navigation"
# NOTE(review): this is the VMAS agent count; the 3DBall env below reports 12
# agents in its action keys -- confirm whether later cells rely on this value.
n_agents = 3

# The tutorial builds a `VmasEnv` here; these notes use the Unity ML-Agents
# "3DBall" registry environment instead.
base_env = torchrl.envs.UnityMLAgentsEnv(registered_name="3DBall")

# Each action key has the form (group, agent, leaf); dropping the leaf leaves
# one (group, agent) entry per agent, so the agent count is read from the env
# rather than hard-coded. Sorting on the numeric suffix keeps the
# agent_0 ... agent_11 order (plain string order would put agent_10 before
# agent_2). Assumes agent names end in "_<integer>".
agent_group_keys = sorted(
    {action_key[:-1] for action_key in base_env.action_keys},
    key=lambda agent_key: (
        agent_key[:-1],
        int(agent_key[-1].rpartition("_")[2]),
    ),
)

# Stack the per-agent entries under a single "agents" entry.
t = torchrl.envs.Stack(in_keys=agent_group_keys, out_key="agents")
env = TransformedEnv(base_env, t)


  Downloading 3DBall : |████████████████████| 100% 
  Extracting  3DBall : |████████████████████| 100% 
  Cleaning up 3DBall : |████████████████████| 100% 
[UnityMemory] Configuration Parameters - Can be set up in boot.config
    "memorysetup-bucket-allocator-granularity=16"
    "memorysetup-bucket-allocator-bucket-count=8"
    "memorysetup-bucket-allocator-block-size=4194304"
    "memorysetup-bucket-allocator-block-count=1"
    "memorysetup-main-allocator-block-size=16777216"
    "memorysetup-thread-allocator-block-size=16777216"
    "memorysetup-gfx-main-allocator-block-size=16777216"
    "memorysetup-gfx-thread-allocator-block-size=16777216"
    "memorysetup-cache-allocator-block-size=4194304"
    "memorysetup-typetree-allocator-block-size=2097152"
    "memorysetup-profiler-bucket-allocator-granularity=16"
    "memorysetup-profiler-bucket-allocator-bucket-count=8"
    "memorysetup-profiler-bucket-allocator-block-size=4194304"
    "memorysetup-profiler-bucket-allocator-block-count=1"

  unity_communicator_version = StrictVersion(unity_com_ver)


In [4]:
# Show the specs of the untransformed Unity env, one labelled entry per line.
print(
    *(
        f"{label} {spec!s}"
        for label, spec in (
            ("action_spec:", base_env.full_action_spec),
            ("reward_spec:", base_env.full_reward_spec),
            ("done_spec:", base_env.full_done_spec),
            ("observation_spec:", base_env.observation_spec),
        )
    ),
    sep="\n",
)

action_spec: Composite(
    group_0: Composite(
        agent_0: Composite(
            continuous_action: BoundedContinuous(
                shape=torch.Size([2]),
                space=ContinuousBox(
                    low=Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, contiguous=True),
                    high=Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, contiguous=True)),
                device=cpu,
                dtype=torch.float32,
                domain=continuous),
            device=None,
            shape=torch.Size([])),
        agent_1: Composite(
            continuous_action: BoundedContinuous(
                shape=torch.Size([2]),
                space=ContinuousBox(
                    low=Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, contiguous=True),
                    high=Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, contiguous=True)),
                device=cpu,
                dtype=torch.flo

In [5]:
# List the action / reward / done keys exposed by the untransformed Unity env.
print(
    *(
        f"{label} {keys!s}"
        for label, keys in (
            ("action_keys:", base_env.action_keys),
            ("reward_keys:", base_env.reward_keys),
            ("done_keys:", base_env.done_keys),
        )
    ),
    sep="\n",
)

action_keys: [('group_0', 'agent_0', 'continuous_action'), ('group_0', 'agent_1', 'continuous_action'), ('group_0', 'agent_10', 'continuous_action'), ('group_0', 'agent_11', 'continuous_action'), ('group_0', 'agent_2', 'continuous_action'), ('group_0', 'agent_3', 'continuous_action'), ('group_0', 'agent_4', 'continuous_action'), ('group_0', 'agent_5', 'continuous_action'), ('group_0', 'agent_6', 'continuous_action'), ('group_0', 'agent_7', 'continuous_action'), ('group_0', 'agent_8', 'continuous_action'), ('group_0', 'agent_9', 'continuous_action')]
reward_keys: [('group_0', 'agent_0', 'group_reward'), ('group_0', 'agent_0', 'reward'), ('group_0', 'agent_1', 'group_reward'), ('group_0', 'agent_1', 'reward'), ('group_0', 'agent_10', 'group_reward'), ('group_0', 'agent_10', 'reward'), ('group_0', 'agent_11', 'group_reward'), ('group_0', 'agent_11', 'reward'), ('group_0', 'agent_2', 'group_reward'), ('group_0', 'agent_2', 'reward'), ('group_0', 'agent_3', 'group_reward'), ('group_0', 'age

In [6]:
# Strip the last element from every action key, leaving the (group, agent)
# part. The result follows the order of `base_env.action_keys`.
agent_keys = [action_key[:-1] for action_key in base_env.action_keys]
print(agent_keys)

[('group_0', 'agent_0'), ('group_0', 'agent_1'), ('group_0', 'agent_10'), ('group_0', 'agent_11'), ('group_0', 'agent_2'), ('group_0', 'agent_3'), ('group_0', 'agent_4'), ('group_0', 'agent_5'), ('group_0', 'agent_6'), ('group_0', 'agent_7'), ('group_0', 'agent_8'), ('group_0', 'agent_9')]


In [8]:
# Stack the twelve per-agent entries of group_0 under one "agents" entry and
# wrap the base env with that transform. This repeats the wrapping already
# done in the cell that creates `base_env`.
t = torchrl.envs.Stack(in_keys=[('group_0', f'agent_{idx}') for idx in range(12)], out_key='agents')

env = TransformedEnv(base_env, t)

In [9]:
# Reset the stacked env; the returned TensorDict is displayed as the cell output.
env.reset()

TensorDict(
    fields={
        agents: TensorDict(
            fields={
                VectorSensor_size8: Tensor(shape=torch.Size([12, 8]), device=cpu, dtype=torch.float32, is_shared=False),
                done: Tensor(shape=torch.Size([12, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                terminated: Tensor(shape=torch.Size([12, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                truncated: Tensor(shape=torch.Size([12, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([12]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)