In [1]:
import argparse
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F

from gym import spaces


from onpolicy.algorithms.r_mappo.algorithm.intention_sharing import IntentionSharingModel
from onpolicy.envs.env_wrappers import SubprocVecEnv, DummyVecEnv
from onpolicy.envs.mpe.MPE_env import MPEEnv
from onpolicy.config import get_config

In [2]:
def make_train_env(all_args):
    """Build the vectorized training environment described by ``all_args``.

    Args:
        all_args: Namespace-like object. This function reads ``env_name``,
            ``seed`` and ``n_rollout_threads`` from it; the whole object is
            also handed to ``MPEEnv``.

    Returns:
        ``DummyVecEnv`` wrapping a single environment factory when
        ``all_args.n_rollout_threads == 1``; otherwise ``SubprocVecEnv`` with
        one factory per rollout thread (ranks ``0 .. n_rollout_threads - 1``).

    Raises:
        NotImplementedError: raised by the environment factory, whenever the
            wrapper invokes it, if ``all_args.env_name`` is not ``"MPE"``.
    """

    def get_env_fn(rank):
        def init_env():
            if all_args.env_name != "MPE":
                # Carry the message in the exception itself (instead of a
                # separate print) so it is not lost when only the traceback
                # is reported.
                raise NotImplementedError(
                    "Can not support the " + all_args.env_name + " environment."
                )
            env = MPEEnv(all_args)
            # Each rank gets its own seed, offset by 1000 from the previous rank.
            env.seed(all_args.seed + rank * 1000)
            return env

        return init_env

    if all_args.n_rollout_threads == 1:
        return DummyVecEnv([get_env_fn(0)])
    return SubprocVecEnv([get_env_fn(i) for i in range(all_args.n_rollout_threads)])

In [3]:
def make_args(args_dict):
    """Turn a mapping of option names to values into an ``argparse.Namespace``.

    Args:
        args_dict: Object exposing ``.items()`` that yields
            ``(attribute_name, value)`` pairs.

    Returns:
        A fresh ``argparse.Namespace`` carrying one attribute per pair.
    """
    namespace = argparse.Namespace()
    for attr_name, attr_value in args_dict.items():
        setattr(namespace, attr_name, attr_value)
    return namespace

In [4]:
# Hand-written configuration for this notebook; it is converted to a Namespace
# by make_args and passed to the environment constructors.
# NOTE(review): the section labels below are inferred from key names only —
# confirm against the project's config definitions.
args_dict = dict(
    # --- environment / scenario ---
    env_name="MPE",
    use_obs_instead_of_state=False,
    scenario_name="simple_spread",
    num_agents=3,
    num_landmarks=3,
    # --- rollout ---
    n_training_threads=1,
    n_rollout_threads=1,
    episode_length=25,
    use_local_obs=False,
    seed=1,
    # --- network ---
    hidden_size=64,
    recurrent_N=2,
    gain=0.01,
    use_orthogonal=True,
    use_policy_active_masks=False,
    use_naive_recurrent_policy=False,
    use_recurrent_policy=False,
    # --- intention sharing ---
    intention_aggregation="mean",
    imagined_traj_len=4,
    communication_interval=4,
    # --- feature processing ---
    use_feature_normalization=True,
    use_ReLU=True,
    stacked_frames=1,
    layer_N=1,
)

In [7]:
# Wrap the config dict in a Namespace so its entries are read as attributes.
args = make_args(args_dict)
# Vectorized environment; with n_rollout_threads == 1 make_train_env returns
# a DummyVecEnv, whose single environment is seeded inside make_train_env.
envs = make_train_env(args)
# Separate, un-vectorized MPEEnv built from the same args. Unlike the
# environment inside `envs`, no env.seed(...) call is made for it here.
env = MPEEnv(args)


In [13]:
# Display the action spaces of the standalone environment; the recorded
# output is a list of three Discrete(5) spaces.
env.action_space

[Discrete(5), Discrete(5), Discrete(5)]

In [27]:
# One one-hot vector per agent, each selecting index 4 of the 5 discrete actions.
one_hot_actions = [np.array([0, 0, 0, 0, 1]) for _ in range(args.num_agents)]
# Index-form actions; not used by the step below.
actions = np.array([[1], [1], [1]])

# Reset the standalone environment (kept so its state matches earlier runs).
env.reset()
# The step below goes through the vectorized `envs`, so that is the object
# that must be reset before stepping.
envs.reset()
# The extra outer list adds a leading axis to the actions array: (1, 3, 5).
obs, rews, dones, infos = envs.step(np.array([one_hot_actions]))

In [31]:
# Shape of the `infos` array returned by envs.step; the recorded output is
# (1, 3) — presumably (n_rollout_threads, num_agents), confirm against the
# env wrapper.
infos.shape

(1, 3)