In [1]:
import argparse
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F

from gym import spaces


from onpolicy.algorithms.r_mappo.algorithm.intention_sharing import IntentionSharingModel
from onpolicy.envs.env_wrappers import SubprocVecEnv, DummyVecEnv
from onpolicy.envs.mpe.MPE_env import MPEEnv
from onpolicy.config import get_config

In [2]:
def preprocess_actions(actions):
    """Pair every agent's own action with the flattened actions of all other agents.

    The agent axis is the second-to-last dimension. Supported layouts are
    (Rollout, Sequence, Batch, Agent, Features) and (Seq, Batch, Agent, Features).

    For agent ``i`` the resulting feature vector is
    ``[a_i, a_0, ..., a_{i-1}, a_{i+1}, ..., a_{n-1}]``, i.e. its own action
    followed by the other agents' actions in agent order.

    Args:
        actions: tensor with 4 or 5 dimensions, agent axis at dim -2.

    Returns:
        Tensor with the same leading dims and agent count, whose last dim is
        ``n_agents * Features``.

    Raises:
        ValueError: if ``actions`` does not have 4 or 5 dimensions.
    """
    if actions.dim() not in (4, 5):
        # Previously this exception was built but never raised, which surfaced
        # later as an UnboundLocalError on the return value.
        raise ValueError(
            f"Input expected to have 4 or 5 dimensions, got {actions.dim()} instead."
        )
    *lead_shape, n_agents, n_features = actions.shape
    # Explicit width instead of -1 so a single agent (zero "other" values) works.
    others_width = (n_agents - 1) * n_features
    other_actions = torch.cat(
        [
            torch.cat((actions[..., :idx, :], actions[..., idx + 1 :, :]), dim=-2)
            .reshape(*lead_shape, 1, others_width)
            for idx in range(n_agents)
        ],
        dim=-2,
    )
    # Stacking each agent's own slice back along the agent axis is `actions` itself.
    return torch.cat((actions, other_actions), dim=-1)


In [3]:
def _collect_random_rollouts(env, samples_per_env_step, n_samples):
    """Play uniformly random one-hot actions until at least `n_samples` transitions exist.

    Returns:
        (obs_episodes, action_episodes) as numpy arrays stacked over episodes;
        every episode holds one more observation than it holds actions.
    """
    n_actions = env.action_space[0].n
    obs_episodes, action_episodes = [], []
    collected = 0
    # env.reset only gives back the first observation of an episode.
    obs = env.reset()
    obs_rollout, action_rollout = [obs], []
    while True:
        # One random discrete action per entry of obs.shape[:-1], one-hot
        # encoded along a new last axis (replaces the removed np.product path).
        actions = F.one_hot(torch.randint(n_actions, tuple(obs.shape[:-1])), n_actions)
        obs, _, dones, _ = env.step(actions)
        obs_rollout.append(obs)
        action_rollout.append(actions.numpy())
        if np.any(dones):
            # NOTE(review): if the vec env auto-resets finished envs inside
            # step(), the obs appended above is already a reset observation and
            # the last transition of each episode is mislabeled — confirm.
            obs_episodes.append(obs_rollout)
            action_episodes.append(action_rollout)
            collected += len(action_rollout) * samples_per_env_step
            if collected >= n_samples:
                break
            obs = env.reset()
            obs_rollout, action_rollout = [obs], []
    return np.array(obs_episodes), np.array(action_episodes)


def pretrain_world_model(
    policy, env, n_agents, n_samples, batch_size, n_episodes
):
    """Pretrain the actor's observation predictor on random-action transitions.

    Random one-hot actions are played in ``env`` until at least ``n_samples``
    transitions (counted per agent and per parallel env) are collected. The
    first ``n_samples`` of them are used to fit
    ``policy.actor.observation_predictor`` with an MSE loss on shuffled
    mini-batches. The predictor outputs the change in observation, so the
    current observation is added back before comparing with the next one.

    Args:
        policy: object exposing ``actor.observation_predictor``,
            ``actor.tpdv`` (kwargs for ``Tensor.to``) and
            ``obs_predictor_optimizer``.
        env: vectorised env with ``num_envs``, ``action_space``, ``reset()``
            and ``step()``.
        n_agents: number of agents in each env.
        n_samples: number of transitions used for training.
        batch_size: mini-batch size; an incomplete final batch is dropped.
        n_episodes: number of passes (epochs) over the training data.

    Raises:
        ValueError: if ``batch_size`` is not positive or ``n_samples`` is
            smaller than ``batch_size`` (no full batch could ever be trained).
    """
    if batch_size <= 0 or n_samples < batch_size:
        raise ValueError(
            f"Need batch_size > 0 and n_samples >= batch_size, got "
            f"n_samples={n_samples}, batch_size={batch_size}."
        )
    # Each env step yields one transition per agent in every parallel env.
    samples_per_env_step = n_agents * env.num_envs
    obs_np, action_np = _collect_random_rollouts(env, samples_per_env_step, n_samples)

    obs_batch = torch.Tensor(obs_np)
    action_batch = preprocess_actions(torch.Tensor(action_np))
    # Assumed layout (episode, time, env, agent, feat) -> (episode, env, agent,
    # time, feat), so that after flattening the timesteps of each trajectory
    # stay consecutive and inputs/labels line up row by row.
    perm_seq = (0, 2, 3, 1, 4)
    obs_batch = obs_batch.permute(perm_seq)
    action_batch = action_batch.permute(perm_seq)

    tpdv = policy.actor.tpdv
    obs_dim = obs_batch.size(-1)
    initial_obs = obs_batch[..., :-1, :]  # every obs except the last of each episode
    next_obs = obs_batch[..., 1:, :]  # the obs that follows each input obs
    inputs = torch.cat((initial_obs, action_batch), -1)
    x_train = inputs.reshape(-1, inputs.size(-1))[:n_samples].to(**tpdv)
    y_train = next_obs.reshape(-1, obs_dim)[:n_samples].to(**tpdv)
    initial_obs_train = initial_obs.reshape(-1, obs_dim)[:n_samples].to(**tpdv)

    obs_predictor = policy.actor.observation_predictor
    optim = policy.obs_predictor_optimizer
    loss_fn = nn.MSELoss()
    n_batches = n_samples // batch_size
    steps = 0  # counts samples seen, not optimizer steps
    for _ in range(n_episodes):
        order = torch.randperm(n_samples)
        for n_batch in range(n_batches):
            idx = order[batch_size * n_batch : batch_size * (n_batch + 1)]
            optim.zero_grad()
            # The predictor infers the difference between current and next
            # observation, so the current observation is added to its output.
            y_pred = obs_predictor(x_train[idx]) + initial_obs_train[idx]
            loss = loss_fn(y_pred, y_train[idx])
            loss.backward()
            optim.step()
            print(f"WM obs. predictor loss at step {steps}: {loss.item():.4f}")
            steps += batch_size

In [4]:
def make_train_env(all_args):
    """Build the vectorised training environment described by ``all_args``.

    A single rollout thread gets a ``DummyVecEnv``; otherwise a
    ``SubprocVecEnv`` with one env per thread is created. The env for rank
    ``r`` is seeded with ``all_args.seed + r * 1000``.

    Raises:
        NotImplementedError: if ``all_args.env_name`` is not ``"MPE"``.
    """
    if all_args.env_name != "MPE":
        # Reject up front in the calling process instead of inside each env
        # factory (which SubprocVecEnv presumably runs in a worker process).
        raise NotImplementedError(
            f"Can not support the {all_args.env_name} environment."
        )

    def get_env_fn(rank):
        def init_env():
            env = MPEEnv(all_args)
            env.seed(all_args.seed + rank * 1000)
            return env

        return init_env

    if all_args.n_rollout_threads == 1:
        return DummyVecEnv([get_env_fn(0)])
    return SubprocVecEnv([get_env_fn(i) for i in range(all_args.n_rollout_threads)])

In [5]:
def make_args(args_dict):
    """Return an ``argparse.Namespace`` whose attributes are the items of ``args_dict``.

    Lets the notebook hand a plain dict to code that expects parsed
    command-line arguments (attribute access such as ``args.num_agents``).
    """
    return argparse.Namespace(**args_dict)

In [10]:
class Policy(nn.Module):
    """Minimal policy container for pretraining the intention-sharing world model.

    Holds an ``IntentionSharingModel`` actor plus an Adam optimizer over the
    parameters of its ``observation_predictor`` only.

    Args:
        args: hyperparameter namespace passed to the actor; ``args.num_agents``
            is read here.
        envs: vectorised env; the first agent's observation and action spaces
            are passed to the actor.
        intention_dim: size of the unbounded Box space given to the actor as
            its fourth argument (presumably the intention space — confirm
            against IntentionSharingModel). Defaults to the previous value 32.
        device: device (anything ``torch.device`` accepts) passed to the actor.
        lr: learning rate for the observation-predictor optimizer.
    """

    def __init__(self, args, envs, intention_dim=32, device="cpu", lr=1e-3):
        super().__init__()

        self.actor = IntentionSharingModel(
            args,
            envs.observation_space[0],
            envs.action_space[0],
            spaces.Box(-np.inf, np.inf, (intention_dim,)),
            args.num_agents,
            torch.device(device),
        )
        # Only the observation predictor is optimised by this container.
        self.obs_predictor_optimizer = torch.optim.Adam(
            self.actor.observation_predictor.parameters(), lr
        )
        

In [11]:
# Hyperparameters for an MPE "simple_spread" run with 3 agents and 3 landmarks,
# consumed via make_args() as an argparse-style namespace.
args_dict = dict(
    env_name="MPE",
    use_obs_instead_of_state=False,
    scenario_name="simple_spread",
    num_agents=3,
    num_landmarks=3,
    n_training_threads=1,
    n_rollout_threads=1,
    episode_length=25,
    use_local_obs=False,
    seed=1,
    hidden_size=64,
    recurrent_N=2,
    gain=0.01,
    use_orthogonal=True,
    use_policy_active_masks=False,
    use_naive_recurrent_policy=False,
    use_recurrent_policy=False,
    intention_aggregation="mean",
    imagined_traj_len=4,
    communication_interval=4,
    use_feature_normalization=True,
    use_ReLU=True,
    stacked_frames=1,
    layer_N=1,
)

In [15]:
# Build the run configuration, environments and policy. The torch and numpy
# RNGs are seeded from args.seed so the random pretraining rollouts and the
# network initialisation are repeatable across Restart & Run All; the envs
# themselves are seeded from args.seed inside make_train_env.
args = make_args(args_dict)
torch.manual_seed(args.seed)
np.random.seed(args.seed)
envs = make_train_env(args)
policy = Policy(args, envs)

In [16]:
# Pretrain the observation predictor on 1000 random-action transitions,
# mini-batches of 100, for 10 passes over the data.
pretrain_world_model(
    policy,
    envs,
    n_agents=args.num_agents,
    n_samples=1000,
    batch_size=100,
    n_episodes=10,
)

WM obs. predictor loss at step 0: 0.4095
WM obs. predictor loss at step 100: 0.3127
WM obs. predictor loss at step 200: 0.3328
WM obs. predictor loss at step 300: 0.2353
WM obs. predictor loss at step 400: 0.2380
WM obs. predictor loss at step 500: 0.2468
WM obs. predictor loss at step 600: 0.1781
WM obs. predictor loss at step 700: 0.1631
WM obs. predictor loss at step 800: 0.1965
WM obs. predictor loss at step 900: 0.1878
WM obs. predictor loss at step 1000: 0.1896
WM obs. predictor loss at step 1100: 0.1534
WM obs. predictor loss at step 1200: 0.0988
WM obs. predictor loss at step 1300: 0.1247
WM obs. predictor loss at step 1400: 0.0802
WM obs. predictor loss at step 1500: 0.0786
WM obs. predictor loss at step 1600: 0.1176
WM obs. predictor loss at step 1700: 0.1213
WM obs. predictor loss at step 1800: 0.1578
WM obs. predictor loss at step 1900: 0.0858
WM obs. predictor loss at step 2000: 0.1187
WM obs. predictor loss at step 2100: 0.0855
WM obs. predictor loss at step 2200: 0.0537
