In [1]:
"""Cast a multi-agent env as vec env to use SB3's PPO."""
import logging
from datetime import datetime
from typing import List, Union
import typer
import torch
import wandb
import numpy as np
import matplotlib.pyplot as plt
# Multi-agent as vectorized environment
from nocturne.envs.vec_env_ma import MultiAgentAsVecEnv
from utils.config import load_config
from utils.render import make_video
from box import Box
# Custom callback
from utils.sb3.callbacks import CustomMultiAgentCallback
import yaml
# Custom PPO class that supports multi-agent control
from utils.sb3.custom_ppo import MultiAgentPPO
from utils.string_utils import datetime_to_str

In [2]:
with open(f"/home/emerge/daphne/nocturne_lab/configs/env_config.yaml", "r") as stream:
    env_config = Box(yaml.safe_load(stream))
env_config.data_path = "/home/emerge/daphne/nocturne_lab/data_tmp/"

### Create environment

In [3]:
env = MultiAgentAsVecEnv(
    config=env_config, 
    num_envs=env_config.max_num_vehicles,
)

obs = env.reset()

In [4]:
obs.shape

(20, 6730)

In [5]:
obs[0, :].shape

(6730,)

In [6]:
# The ego state
obs[:10, :].shape

print(obs[:10, :])

[[0.28085133 0.49425942 0.00275741 ... 0.         0.         0.        ]
 [0.33037499 0.583      0.08777953 ... 0.         0.         0.        ]
 [       nan        nan        nan ...        nan        nan        nan]
 ...
 [       nan        nan        nan ...        nan        nan        nan]
 [       nan        nan        nan ...        nan        nan        nan]
 [       nan        nan        nan ...        nan        nan        nan]]


### Model

This is the base MLP in stable baselines

In [10]:
#CustomActorCriticPolicy, 
model = MultiAgentPPO(env=env, policy="MlpPolicy", verbose=1)
model.learn(50_000)

Using cuda device
-----------------------------
| time/              |      |
|    fps             | 1068 |
|    iterations      | 1    |
|    time_elapsed    | 3    |
|    total_timesteps | 3747 |
-----------------------------
-----------------------------------------
| time/                   |             |
|    fps                  | 956         |
|    iterations           | 2           |
|    time_elapsed         | 8           |
|    total_timesteps      | 8049        |
| train/                  |             |
|    approx_kl            | 0.011197504 |
|    clip_fraction        | 0.133       |
|    clip_range           | 0.2         |
|    entropy_loss         | -3.21       |
|    explained_variance   | 0.00433     |
|    learning_rate        | 0.0003      |
|    loss                 | 0.133       |
|    n_updates            | 10          |
|    policy_gradient_loss | -0.0108     |
|    value_loss           | 0.552       |
-----------------------------------------
----------------

<utils.sb3.custom_ppo.MultiAgentPPO at 0x7f96ec07a650>

In [8]:
from typing import Callable, Dict, List, Optional, Tuple, Type, Union

from gymnasium import spaces
import torch as th
from torch import nn

from stable_baselines3 import PPO
from stable_baselines3.common.policies import ActorCriticPolicy


class CustomNetwork(nn.Module):
    """
    Custom network for policy and value function.
    It receives as input the features extracted by the features extractor.

    :param feature_dim: dimension of the features extracted with the features_extractor (e.g. features from a CNN)
    :param last_layer_dim_pi: (int) number of units for the last layer of the policy network
    :param last_layer_dim_vf: (int) number of units for the last layer of the value network
    """

    def __init__(
        self,
        feature_dim: int,
        last_layer_dim_pi: int = 64,
        last_layer_dim_vf: int = 64,
    ):
        super().__init__()

        # IMPORTANT:
        # Save output dimensions, used to create the distributions
        self.latent_dim_pi = last_layer_dim_pi
        self.latent_dim_vf = last_layer_dim_vf

        # Policy network
        self.policy_net = nn.Sequential(
            nn.Linear(feature_dim, last_layer_dim_pi), nn.ReLU()
        )
        # Value network
        self.value_net = nn.Sequential(
            nn.Linear(feature_dim, last_layer_dim_vf), nn.ReLU()
        )

    def forward(self, features: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
        """
        :return: (th.Tensor, th.Tensor) latent_policy, latent_value of the specified network.
            If all layers are shared, then ``latent_policy == latent_value``
        """
        return self.forward_actor(features), self.forward_critic(features)

    def forward_actor(self, features: th.Tensor) -> th.Tensor:
        return self.policy_net(features)

    def forward_critic(self, features: th.Tensor) -> th.Tensor:
        return self.value_net(features)


class CustomActorCriticPolicy(ActorCriticPolicy):
    def __init__(
        self,
        observation_space: spaces.Space,
        action_space: spaces.Space,
        lr_schedule: Callable[[float], float],
        *args,
        **kwargs,
    ):
        # Disable orthogonal initialization
        kwargs["ortho_init"] = False
        super().__init__(
            observation_space,
            action_space,
            lr_schedule,
            # Pass remaining arguments to base class
            *args,
            **kwargs,
        )


    def _build_mlp_extractor(self) -> None:
        self.mlp_extractor = CustomNetwork(self.features_dim)

In [9]:
#CustomActorCriticPolicy, 
model = MultiAgentPPO(env=env, policy="MlpPolicy", verbose=1)
model.learn(5000)

Using cuda device
-----------------------------
| time/              |      |
|    fps             | 1183 |
|    iterations      | 1    |
|    time_elapsed    | 3    |
|    total_timesteps | 4048 |
-----------------------------
-----------------------------------------
| time/                   |             |
|    fps                  | 976         |
|    iterations           | 2           |
|    time_elapsed         | 8           |
|    total_timesteps      | 8147        |
| train/                  |             |
|    approx_kl            | 0.012126818 |
|    clip_fraction        | 0.122       |
|    clip_range           | 0.2         |
|    entropy_loss         | -3.21       |
|    explained_variance   | 0.0163      |
|    learning_rate        | 0.0003      |
|    loss                 | 0.0875      |
|    n_updates            | 10          |
|    policy_gradient_loss | -0.00909    |
|    value_loss           | 0.512       |
-----------------------------------------


<utils.sb3.custom_ppo.MultiAgentPPO at 0x7f96ec07a6e0>

### Making the network permutation equivariant


In [11]:
class CustomNetwork(nn.Module):
    """
    Custom network for policy and value function.
    It receives as input the features extracted by the features extractor.

    :param feature_dim: dimension of the features extracted with the features_extractor (e.g. features from a CNN)
    :param last_layer_dim_pi: (int) number of units for the last layer of the policy network
    :param last_layer_dim_vf: (int) number of units for the last layer of the value network
    """

    def __init__(
        self,
        feature_dim: int,
        last_layer_dim_pi: int = 64,
        last_layer_dim_vf: int = 64,
    ):
        super().__init__()

        # IMPORTANT:
        # Save output dimensions, used to create the distributions
        self.latent_dim_pi = last_layer_dim_pi
        self.latent_dim_vf = last_layer_dim_vf

        # Policy network
        self.policy_net = nn.Sequential(
            nn.Linear(feature_dim, last_layer_dim_pi), nn.ReLU()
        )
        # Value network
        self.value_net = nn.Sequential(
            nn.Linear(feature_dim, last_layer_dim_vf), nn.ReLU()
        )

    def forward(self, features: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
        """
        :return: (th.Tensor, th.Tensor) latent_policy, latent_value of the specified network.
            If all layers are shared, then ``latent_policy == latent_value``
        """
        return self.forward_actor(features), self.forward_critic(features)

    def forward_actor(self, features: th.Tensor) -> th.Tensor:
        return self.policy_net(features)

    def forward_critic(self, features: th.Tensor) -> th.Tensor:
        return self.value_net(features)


class CustomActorCriticPolicy(ActorCriticPolicy):
    def __init__(
        self,
        observation_space: spaces.Space,
        action_space: spaces.Space,
        lr_schedule: Callable[[float], float],
        *args,
        **kwargs,
    ):
        # Disable orthogonal initialization
        kwargs["ortho_init"] = False
        super().__init__(
            observation_space,
            action_space,
            lr_schedule,
            # Pass remaining arguments to base class
            *args,
            **kwargs,
        )


    def _build_mlp_extractor(self) -> None:
        self.mlp_extractor = CustomNetwork(self.features_dim)

In [21]:
import torch.nn as nn
import torch as th
from typing import Tuple

class PermEqNetwork(nn.Module):
    def __init__(self, latent_dim_pi: int = 64, latent_dim_vf: int = 64):
        super().__init__()
        self.state_dim = 10
        self.observation_dim = 6720

        self.latent_dim_pi = latent_dim_pi
        self.latent_dim_vf = latent_dim_vf

        # Policy network for state
        self.state_net = nn.Sequential(
            nn.Linear(self.state_dim, latent_dim_pi),
            nn.ReLU()
        )

        # Policy network for observation
        self.observation_net = nn.Sequential(
            nn.Linear(self.observation_dim, latent_dim_pi),
            nn.Linear(latent_dim_pi, latent_dim_pi),
            nn.Linear(latent_dim_pi, latent_dim_pi),
            nn.Linear(latent_dim_pi, latent_dim_pi),
            nn.Linear(latent_dim_pi, latent_dim_pi)
        )

        # Value network for state
        self.value_state_net = nn.Sequential(
            nn.Linear(self.state_dim, latent_dim_vf),
            nn.ReLU()
        )

        # Value network for observation
        self.value_observation_net = nn.Sequential(
            nn.Linear(self.observation_dim, latent_dim_vf),
            nn.ReLU()
        )

        # Attention layer for observation
        self.attention = nn.MultiheadAttention(latent_dim_pi, num_heads=1)

        # Final layer combining state and processed observation
        self.final_layer = nn.Linear(latent_dim_pi + latent_dim_vf, latent_dim_pi)

    def forward(self, features: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
        state, observation = features[:, :10], features[:, 10:]

        # Process state and observation separately through policy networks
        latent_policy_state = self.state_net(state)
        latent_policy_observation = self.observation_net(observation)

        # Process state and observation through value networks
        latent_value_state = self.value_state_net(state)
        latent_value_observation = self.value_observation_net(observation)

        # Attention mechanism on the observation
        observation = observation.permute(1, 0, 2)  # Change shape for MultiheadAttention
        output, _ = self.attention(latent_policy_observation.unsqueeze(0), observation, observation)
        latent_policy_observation = th.max(output.squeeze(0), dim=0).values

        # Concatenate policy and value outputs
        latent_policy = th.cat((latent_policy_state, latent_policy_observation), dim=1)
        latent_value = th.cat((latent_value_state, latent_value_observation), dim=1)

        # Combine policy and value outputs
        final_output = th.cat((latent_policy, latent_value), dim=1)
        final_output = self.final_layer(final_output)

        return final_output[:, :self.latent_dim_pi], final_output[:, self.latent_dim_pi:]


class PermEqActorCriticPolicy(ActorCriticPolicy):
    def __init__(
        self,
        observation_space: spaces.Space,
        action_space: spaces.Space,
        lr_schedule: Callable[[float], float],
        *args,
        **kwargs,
    ):
        # Disable orthogonal initialization
        kwargs["ortho_init"] = False
        super().__init__(
            observation_space,
            action_space,
            lr_schedule,
            # Pass remaining arguments to base class
            *args,
            **kwargs,
        )


    def _build_mlp_extractor(self) -> None:
        self.mlp_extractor = PermEqNetwork(self.features_dim)

In [22]:
#CustomActorCriticPolicy, 
model = MultiAgentPPO(env=env, policy=PermEqActorCriticPolicy, verbose=1)
model.learn(5000)

Using cuda device


RuntimeError: permute(sparse_coo): number of dimensions in the tensor input does not match the length of the desired ordering of dimensions i.e. input.dim() = 2 is not equal to len(dims) = 3