In [1]:
import argparse
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F

import gym
from gym import spaces


from onpolicy.algorithms.r_mappo.algorithm.intention_sharing import IntentionSharingModel
from onpolicy.envs.env_wrappers import SubprocVecEnv, DummyVecEnv
from onpolicy.config import get_config

from pettingzoo.sisl import pursuit_v4


In [2]:
def make_args(args_dict):
    """Return an ``argparse.Namespace`` exposing each entry of ``args_dict`` as an attribute.

    Args:
        args_dict: mapping of attribute name (str) to value.

    Returns:
        A new ``argparse.Namespace`` with one attribute per key of ``args_dict``.
    """
    # Namespace's constructor assigns every keyword argument as an attribute.
    return argparse.Namespace(**args_dict)

In [3]:
# Settings consumed by make_args / make_train_env in this notebook.
args_dict = dict(
    # Environment selection: env_name and scenario_name pick the environment
    # in make_train_env / SISLEnv.
    env_name="SISL",
    use_obs_instead_of_state=False,
    scenario_name='pursuit',
    # num_agents is forwarded to the pursuit environment as n_pursuers.
    num_agents=3,
    num_landmarks=3,
    n_training_threads=1,
    # n_rollout_threads == 1 selects DummyVecEnv in make_train_env.
    n_rollout_threads=1,
    # episode_length is forwarded to the pursuit environment as max_cycles.
    episode_length=25,
    use_local_obs=False,
    # Base seed; make_train_env offsets it by rank * 1000 per rollout thread.
    seed=1,
    # The remaining entries are not read by the cells in this notebook;
    # presumably they configure the onpolicy model code — TODO confirm.
    hidden_size=64,
    recurrent_N=2,
    gain=0.01,
    use_orthogonal=True,
    use_policy_active_masks=False,
    use_naive_recurrent_policy=False,
    use_recurrent_policy=False,
    intention_aggregation='mean',
    imagined_traj_len=4,
    communication_interval=4,
    use_feature_normalization=True,
    use_ReLU=True,
    stacked_frames=1,
    layer_N=1,
)


In [8]:
class PettingZooToOnPolicyWrapper(gym.Env):
    """Adapt a PettingZoo parallel environment to a list-based multi-agent interface.

    The wrapped environment exchanges per-agent dictionaries. This wrapper
    converts them to and from plain lists ordered like ``env.possible_agents``
    and merges the ``terminated`` / ``truncated`` flags into one done flag per
    agent, so ``step`` returns the 4-tuple ``(obs, rewards, dones, infos)``.

    Attributes:
        env: the wrapped PettingZoo parallel environment.
        seed: seed handed to ``env.reset`` on every ``reset`` call.
        n: number of agents (``env.max_num_agents``).
        observation_space: list holding one observation space per agent.
        action_space: list holding one action space per agent.
        share_observation_space: unbounded ``Box`` of shape ``(n, *obs_shape)``
            where ``obs_shape`` is the first agent's observation shape.

    Any attribute not found on the wrapper is looked up on ``env``.
    """

    def __init__(self, env, seed):
        """Store ``env`` and ``seed`` and derive the per-agent space lists.

        Args:
            env: PettingZoo parallel environment to wrap.
            seed: seed passed to ``env.reset`` each time ``reset`` is called.
        """
        self.env = env
        # NOTE(review): this instance attribute hides any `seed` method the
        # gym.Env base class may define — confirm no caller relies on env.seed().
        self.seed = seed
        self.n = self.env.max_num_agents
        agents = self.env.possible_agents
        obs_shape = self.env.observation_space(agents[0]).shape
        self.observation_space = [self.env.observation_space(agent) for agent in agents]
        self.action_space = [self.env.action_space(agent) for agent in agents]
        self.share_observation_space = spaces.Box(-np.inf, np.inf, (self.n, *obs_shape))

    def reset(self):
        """Reset the wrapped environment with the stored seed.

        Returns:
            List of per-agent observations ordered like ``env.possible_agents``.
        """
        # The same seed is supplied on every reset; the second return value
        # of env.reset is discarded.
        obs, _ = self.env.reset(seed=self.seed)
        return self._dict_to_array(obs)

    def step(self, actions):
        """Advance the wrapped environment by one step.

        Args:
            actions: iterable with one entry per agent; each entry is reduced
                to an index with ``np.argmax`` (see ``_convert_actions``).

        Returns:
            Tuple ``(obs, rewards, dones, infos)`` of per-agent lists. Each
            reward is wrapped in a one-element list, and each done flag is
            ``terminated or truncated`` for that agent.
        """
        action_dict = self._array_to_dict(self._convert_actions(actions))
        obs, rewards, terminated, truncated, info = self.env.step(action_dict)

        obs = self._dict_to_array(obs)
        rewards = [[reward] for reward in self._dict_to_array(rewards)]
        dones = [
            term or trunc
            for term, trunc in zip(
                self._dict_to_array(terminated), self._dict_to_array(truncated)
            )
        ]
        info = self._dict_to_array(info)
        return obs, rewards, dones, info

    def _convert_actions(self, actions):
        """Map each per-agent action vector to the index of its largest entry."""
        return [np.argmax(action) for action in actions]

    def _dict_to_array(self, d):
        """Return the values of ``d`` as a list ordered like ``env.possible_agents``."""
        return [d[agent] for agent in self.env.possible_agents]

    def _array_to_dict(self, a):
        """Pair the entries of ``a`` with ``env.possible_agents`` in order."""
        return dict(zip(self.env.possible_agents, a))

    def __getattr__(self, attr):
        """Delegate attributes missing on the wrapper to the wrapped environment."""
        # __getattr__ only runs after normal lookup has failed. If `env` itself
        # is missing (an instance created without __init__, e.g. by copy or
        # unpickling), reading self.env below would re-enter this method
        # endlessly, so fail with AttributeError instead.
        if attr == 'env':
            raise AttributeError(attr)
        return getattr(self.env, attr)
    
def SISLEnv(args, seed):
    """Create the wrapped SISL environment selected by ``args.scenario_name``.

    Args:
        args: namespace providing ``scenario_name`` and, for ``'pursuit'``,
            ``episode_length`` (used as ``max_cycles``) and ``num_agents``
            (used as ``n_pursuers``).
        seed: seed stored on the wrapper and used when it resets the environment.

    Returns:
        A ``PettingZooToOnPolicyWrapper`` around the pursuit parallel environment.

    Raises:
        NotImplementedError: if ``args.scenario_name`` is not ``'pursuit'``.
    """
    if args.scenario_name == 'pursuit':
        return PettingZooToOnPolicyWrapper(
            pursuit_v4.parallel_env(max_cycles=args.episode_length, n_pursuers=args.num_agents),
            seed,
        )
    # The exception must be raised, not merely constructed; otherwise an
    # unknown scenario would silently yield None.
    raise NotImplementedError(f'{args.scenario_name} is not implemented in SISL.')

In [5]:
def make_train_env(all_args):
    """Build the vectorised training environment described by ``all_args``.

    Args:
        all_args: namespace providing ``env_name``, ``seed`` and
            ``n_rollout_threads`` (plus whatever ``SISLEnv`` reads).

    Returns:
        ``DummyVecEnv`` around a single environment factory when
        ``n_rollout_threads == 1``; otherwise ``SubprocVecEnv`` around one
        factory per rollout thread.

    Raises:
        NotImplementedError: from the environment factory, when it is invoked
            with an ``env_name`` other than ``"SISL"``.
    """
    def get_env_fn(rank):
        """Return a zero-argument factory for the environment of thread ``rank``."""
        def init_env():
            if all_args.env_name != "SISL":
                # Carry the explanation in the exception itself so it appears
                # in the traceback rather than only on stdout.
                raise NotImplementedError(
                    "Can not support the " + all_args.env_name + " environment."
                )
            # Each rank gets its own seed, offset by 1000 per rank.
            return SISLEnv(all_args, seed=all_args.seed + rank * 1000)

        return init_env

    if all_args.n_rollout_threads == 1:
        return DummyVecEnv([get_env_fn(0)])
    return SubprocVecEnv([get_env_fn(i) for i in range(all_args.n_rollout_threads)])

In [9]:
# Expose the settings dict as attribute-style arguments.
args = make_args(args_dict)
# Build the training environment; with n_rollout_threads == 1 in args_dict,
# make_train_env takes its DummyVecEnv branch.
envs = make_train_env(args)

In [10]:
# Smoke test: step the vectorised env with a fixed action until any agent
# reports done, for at most one episode's worth of steps.
envs.reset()
# One one-hot vector per agent, selecting action index 4.
one_hot_actions = [np.array([0, 0, 0, 0, 1]) for _ in range(args.num_agents)]
# NOTE(review): assumes envs.step hands actions[i] to rollout thread i, as the
# vec-env wrappers in onpolicy.envs.env_wrappers appear to — confirm. Under
# that assumption the per-agent list needs a leading per-thread dimension;
# passing it unbatched would give the wrapper a single one-hot vector, whose
# scalar entries np.argmax maps to action 0 for every agent.
batched_actions = [one_hot_actions for _ in range(args.n_rollout_threads)]
for step in range(1, args.episode_length + 1):
    if step == args.episode_length:
        # Marks the last step the environment allows (max_cycles).
        print("stop")
    obs, rews, dones, infos = envs.step(batched_actions)
    if np.any(dones):
        print('done')
        break


stop
done


In [None]:
# Inspect the shape of `infos` from the most recent envs.step call; requires
# the rollout cell to have been run first.
infos.shape

In [None]:
def get_action_samples(action_spaces):
    """Draw one sample from every space in ``action_spaces``.

    Args:
        action_spaces: iterable of objects exposing a ``sample()`` method.

    Returns:
        List with the result of ``sample()`` for each space, in input order.
    """
    return [space.sample() for space in action_spaces]


In [None]:
# Exercise the wrapper directly (no vectorised env) with a fixed seed of 4.
penv = SISLEnv(args, 4)
penv.reset()
# max_num_agents is not defined on the wrapper; it resolves through
# __getattr__ to the wrapped PettingZoo environment.
one_hot_actions = [np.array([0,0,0,0,1]) for _ in range(penv.max_num_agents)]
# The wrapper's step takes one action vector per agent and argmaxes each one.
penv.step(one_hot_actions)
# NOTE(review): the disabled call below passes the wrapper itself rather than
# its list of action spaces (penv.action_space). Sampled scalar actions would
# also all map to 0, because step applies np.argmax to each action.
# penv.step(get_action_samples(penv))