# Experiment 3: two agents selecting non-overlapping image segments in a MARL environment (Fig 4)

In [None]:
# Project root for this notebook. Edit the string before running.
# It is appended to sys.path so local modules (e.g. vision_transformer_mlpcritic)
# can be imported. Data and output paths are also built from it.
dir_path = '/path/to/my/directory'

# Google Colab only: uncomment the two lines below to mount Google Drive.
# from google.colab import drive
# drive.mount('/content/drive', force_remount=True)
import sys

sys.path.extend([dir_path])

# ViT model based on https://github.com/facebookresearch/dino/blob/main/README.md
import sys
sys.path.append(dir_path)
import time
from torchvision.utils import save_image
import importlib
import torch
import torchvision
from torchvision import datasets, transforms, utils
from torch.utils.data import Dataset, TensorDataset, DataLoader
import torch.utils.data as data_utils
import numpy as np
from PIL import Image
import vision_transformer_mlpcritic
import fnmatch
import os
import matplotlib.pyplot as plt
device = torch.device("cuda")
import torch.random
from tensordict.nn import TensorDictModule
from tensordict.tensordict import TensorDict, TensorDictBase
from tensordict.nn.distributions import NormalParamExtractor

from torchrl.collectors import SyncDataCollector
from torchrl.data.replay_buffers import ReplayBuffer
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
from torchrl.data.replay_buffers.storages import LazyTensorStorage

from torchrl.envs import RewardSum, TransformedEnv
from torchrl.envs.utils import check_env_specs

from torchrl.modules import ProbabilisticActor, TanhNormal
from torch.distributions import Categorical
from torchrl.objectives import ClipPPOLoss, ValueEstimators

from torch import nn

from torchrl.data import (
    BoundedTensorSpec, CompositeSpec, UnboundedContinuousTensorSpec,
    DiscreteTensorSpec
)
from torchrl.envs import (
    CatTensors,
    EnvBase,
    Transform,
    TransformedEnv,
    UnsqueezeTransform,
    StepCounter
)
from torchrl.envs.transforms.transforms import _apply_to_composite
from torchrl.envs.utils import check_env_specs, step_mdp
import torchvision.transforms.functional as TF
from torchrl.modules import  ValueOperator, ActorCriticOperator
from tensordict.nn import (ProbabilisticTensorDictModule, ProbabilisticTensorDictSequential, TensorDictModule, TensorDictParams, TensorDictSequential)

# Training images live under <dir_path>/data/train/marl_full. Each image is
# resized to 16x16 and converted to a tensor before it reaches the env.
traindir = f"{dir_path}/data/train/marl_full"
train_transforms = transforms.Compose(
    [
        transforms.Resize((16, 16)),
        transforms.ToTensor(),
    ]
)

train_data = datasets.ImageFolder(traindir, transform=train_transforms)
def cycle(iterable):
    """Yield items from *iterable* forever, re-iterating it after every pass.

    This differs from ``itertools.cycle``, which saves the items of the first
    pass and replays them. Here the iterable is iterated afresh on every pass.
    For example, a ``DataLoader(..., shuffle=True)`` gets reshuffled each time
    the dataset is exhausted.

    Raises:
        ValueError: if a complete pass over *iterable* yields no items. This
            happens for an empty iterable, or for a one-shot iterator after its
            first pass. Without the check the generator would busy-loop
            forever without yielding.
    """
    while True:
        produced_any = False
        for item in iterable:
            produced_any = True
            yield item
        if not produced_any:
            raise ValueError("cycle() got an iterable that yielded no items")
# Endless stream of single-image batches (batch_size=1) from the training set.
image_iter = iter(cycle(torch.utils.data.DataLoader(train_data, shuffle = True, batch_size=1)))

device = "cuda"
vmas_device = device

# --- Sampling ---
frames_per_batch = 6000  # number of team frames collected per training iteration
n_iters = 50  # number of sampling and training iterations
total_frames = frames_per_batch * n_iters

# --- Optimisation ---
# num_epochs is the number of passes over each collected batch per training
# iteration. Each pass takes frames_per_batch // minibatch_size optimization steps.
num_epochs = 30
minibatch_size = 300  # size of the mini-batches in each optimization step
lr = 2e-4
# NOTE(review): max_grad_norm is not referenced by the training cell (no
# gradient clipping is applied there). Confirm whether clipping was intended.
max_grad_norm = 1.0  # maximum norm for the gradients

# --- PPO ---
clip_epsilon = 0.2  # clip value for PPO loss
gamma = 0.9  # discount factor
lmbda = 0.9  # lambda for generalised advantage estimation
entropy_eps = 1e-5  # coefficient of the entropy term in the PPO loss
prediction_coef = 1e-3  # weight of the auxiliary MSE prediction loss (schema runs)

# Initial image with the DataLoader batch dim squeezed out.
# Expected shape: (3, 16, 16).
FIELD = next(image_iter)[0].squeeze().to(device)

# Reward shaping used by the env:
#   reward = DISCOVER_WT * new_cells / (1 + OVERLAP_WT * overlap)
DISCOVER_WT = 0.5
OVERLAP_WT = 1.7

# Seeds for the independent trials (15 entries).
seeds = [89982, 44686, 50255, 95253, 81551,
         56170, 93726, 62698, 73980, 18263,
         61224, 24087, 13041, 99425, 70874]


In [None]:
class SegmentationEnv(EnvBase):
  """Multi-agent env in which agents claim cells of a small image.

  Each step, every agent emits a binary selection over every cell of the field.
  The action has shape (n_envs, n_agents, H, W) with values in {0, 1}.
  Selections accumulate in ``self.player_segs``.

  Every agent receives
  DISCOVER_WT * (team-wide count of newly claimed cells) / (1 + OVERLAP_WT * overlap),
  where overlap is that agent's own overlap with the other agents this step.

  Reads the module-level globals FIELD (initial image), image_iter (source of
  a fresh image on every reset), DISCOVER_WT and OVERLAP_WT.
  """
  def __init__(self, n_agents=2, n_envs=1, name="", device="cuda"):
    self.name = name
    self.step_count = 0  # incremented by every _step call, never reset
    self.device=device
    super(SegmentationEnv, self).__init__()
    # (H, W) of the field. FIELD is expected to be (3, H, W), matching the expand below.
    self.field_size = FIELD.shape[1:]
    self.n_agents = n_agents
    self.n_envs = n_envs
    # RGB image broadcast (as a view) to every env: (n_envs, 3, H, W)
    self.field = FIELD.expand(self.n_envs, 3, *self.field_size)
    # Per-agent claimed-cell maps: (n_envs, n_agents, H, W). They are created
    # on the default device here and re-created on self.device in _reset.
    self.player_segs = torch.zeros(self.n_envs, self.n_agents, *self.field_size)

    # specs: the expected shapes of variables the environment must keep track of
    # action: one binary (n=2) choice per agent per field cell
    self.full_action_spec = CompositeSpec(
        agents=CompositeSpec(
          action=DiscreteTensorSpec(
          n=2,
          shape=torch.Size([self.n_envs, self.n_agents, *self.field_size],),
           device=self.device),
        device=self.device
    ),device=self.device)

    self.full_observation_spec = CompositeSpec(
        agents=CompositeSpec(
          observation=UnboundedContinuousTensorSpec(
              # observation (shared by all agents): the 3 image channels followed
              # by one claimed-cell map per agent -> n_agents + 3 channels
              shape=torch.Size([self.n_envs, self.n_agents+3, *self.field_size],),
         device=self.device),
       device=self.device,
    ),device=self.device)

    # reward: per-agent reward.
    # agent0 / agent1: cells newly claimed by agents 0 and 1 this step (logging only).
    # overlap: agent 0's overlap count this step (logging only).
    self.full_reward_spec = CompositeSpec(
        agents=CompositeSpec(
          reward=UnboundedContinuousTensorSpec(shape=torch.Size([self.n_envs, self.n_agents, 1],),
          device=self.device),
          agent0=UnboundedContinuousTensorSpec(shape=torch.Size([self.n_envs, 1],),
          device=self.device),
          agent1=UnboundedContinuousTensorSpec(shape=torch.Size([self.n_envs, 1],),
          device=self.device),
          overlap=UnboundedContinuousTensorSpec(shape=torch.Size([self.n_envs, 1],),
          device=self.device),
        device=self.device
    ),device=self.device)

    # state: per-agent episode_reward (n_envs, n_agents, 1)
    self.full_state_spec = CompositeSpec(
        agents=CompositeSpec(
            episode_reward=UnboundedContinuousTensorSpec(
                    shape=torch.Size([self.n_envs, self.n_agents, 1],),
          device=self.device),
        device=self.device
    ),device=self.device)



  def _reset(self, tensordict=None, **kwargs):
          """Load the next image from image_iter and clear every agent's selections.

          The incoming ``tensordict`` is ignored. Returns a TensorDict holding
          ("agents", "observation") of shape (n_envs, 3 + n_agents, H, W): the
          image channels followed by the all-zero per-agent claim maps.
          """
          out_tensordict = TensorDict({}, batch_size=torch.Size(), device=self.device)
          # Local name: this does not update the module-level FIELD.
          FIELD = next(image_iter)[0].squeeze().to("cuda")
          # NOTE(review): self.field_size is not recomputed here. This assumes every
          # image has the same size (train_transforms resizes to 16x16).
          self.field = FIELD.expand(self.n_envs, 3, *self.field_size)
          self.player_segs = torch.zeros(self.n_envs, self.n_agents, *self.field_size).to(self.device)
          # out: [n_envs, n_agents+3, H, W]
          # NOTE(review): batch_size/device passed to .set() look unintended.
          # Confirm this tensordict version ignores them.
          out_tensordict.set(("agents","observation"), torch.cat(
              (self.field, self.player_segs), dim=1).to(self.device), batch_size=torch.Size(),
                             device=self.device)
          return out_tensordict

  def _step(self, tensordict):
          """Apply all agents' binary selections and compute per-agent rewards.

          For each agent the step counts two things:
          (a) cells it selected that no agent had claimed before this step;
          (b) (cell, other agent) pairs where both it and that other agent
              selected the cell this step.
          Each agent is rewarded
          DISCOVER_WT * (sum of (a) over agents) / (1 + OVERLAP_WT * (b)).
          "done" is True when no unclaimed cell remained before this step's
          selections were applied.
          """
          self.step_count +=1
          action = tensordict[("agents","action")]
          # print("Board:\n"+str(torch.sum(self.player_segs, dim=1)[0,...]))

          # player_overlap: number of times agent's selection intersected others'
          player_overlap = []
          # new_segs: segments any agent selected that were untouched (value 0) before this round
          new_segs = 0
          new_segs_agent = []
          for agent in range(self.n_agents):
            self_action = action[:, agent, ...].unsqueeze(1)
            # print("Agent "+str(agent)+" action:\n"+str(self_action[0,...]))
            # Cells no agent has claimed in any previous step. player_segs is only
            # updated after this loop, so this is identical for every agent.
            unselected_mask = ~(torch.sum(self.player_segs, dim=1) > 0).unsqueeze(1)
            # Per env: True once no unclaimed cell is left
            dones = torch.count_nonzero(unselected_mask,
                                          dim=tuple(range(1,len(self_action.shape)))) == 0
            # print(dones)
            # NOTE(review): count_nonzero without `dim` sums over all envs.
            # The count is only per-env when n_envs == 1. A cell newly picked by
            # several agents is counted once per agent.
            new_segs += torch.count_nonzero(self_action*unselected_mask)
            new_segs_agent.append(torch.count_nonzero(self_action*unselected_mask))
            # Every other agent's selection: (n_envs, n_agents - 1, H, W)
            others_action = torch.cat((action[:, :agent, ...],
                                       action[:, agent+1:, ...]), dim=1)
            self_action = self_action.expand(*others_action.shape)
            # Equality also holds where both are 0. Multiplying by self_action
            # keeps only cells both agents selected.
            overlap = self_action == others_action
            nonzero_overlap = self_action * overlap
            nonzero_overlap = torch.count_nonzero(nonzero_overlap,
                                          dim=tuple(range(1,len(nonzero_overlap.shape))))
            # print("Num new: "+str(num_new_segs[0].item()) +", Overlap: "+str(nonzero_overlap[0].item()))
            player_overlap.append(nonzero_overlap.unsqueeze(-1))

          # (n_envs, n_agents, 1)
          player_overlap = torch.stack(player_overlap, dim=1)
          # The team-wide new-cell count is broadcast to every agent.
          # NOTE(review): torch.Tensor(<0-d tensor>) - confirm the resulting
          # dtype/device are as intended.
          new_segs = torch.Tensor(new_segs).expand(*player_overlap.shape)

          # Accumulate claims. A cell picked by several agents/steps can exceed 1.
          self.player_segs += action

          reward = (DISCOVER_WT*new_segs)/(1+(OVERLAP_WT*player_overlap))

          # NOTE(review): agent0/agent1 hard-code two agents, and "overlap" reports
          # only agent 0's value with shape (n_envs,). Confirm against the
          # (n_envs, 1) shapes declared in full_reward_spec.
          out_tensordict = TensorDict(
              {
              "agents": {
              "observation": torch.cat((self.field, self.player_segs), dim=1),
              "reward": reward,
              "agent0": new_segs_agent[0],
              "agent1": new_segs_agent[1],
              "overlap": player_overlap[:,0].squeeze(-1)
              },
              "done": dones
              }, batch_size=torch.Size(), device=self.device)

          return out_tensordict

  def _set_seed(self, seed):
          # No-op: this env makes no random calls of its own. New images come
          # from image_iter's shuffled DataLoader.
          pass

In [None]:
senv = SegmentationEnv(n_envs=1, n_agents=2)
print(senv.reset())

# One RewardSum wrapper per logged quantity. Each accumulates the per-step
# ("agents", <key>) value into ("agents", "episode_<key>"). The wrap order
# matches the original: reward, agent0, agent1, overlap.
for _step_key, _episode_key in (
    ("reward", "episode_reward"),
    ("agent0", "episode_agent0"),
    ("agent1", "episode_agent1"),
    ("overlap", "episode_overlap"),
):
    senv = TransformedEnv(
        senv,
        RewardSum(in_keys=[("agents", _step_key)], out_keys=[("agents", _episode_key)]),
    )
del _step_key, _episode_key

# Outermost wrapper: count steps and cap episodes at 256 steps.
senv = TransformedEnv(senv, StepCounter(max_steps=256))

print(senv.reset())

In [None]:
# Train one PPO run per (seed, mode) pair. Each run's per-iteration mean
# episode reward is written to <dir_path>/reward_<run name>.txt.
#   "schema":  MultiAgentVit(schema=True), plus an auxiliary MSE loss between
#              the ("agents", "pred_attn") and ("agents", "h1m") outputs,
#              weighted by prediction_coef.
#   "mixed":   MultiAgentMixed. It also emits "pred_attn"/"h1m", but
#              pred_loss is applied only when mode == "schema".
#              NOTE(review): confirm mixed runs should get no prediction loss.
#   "control": MultiAgentVit(schema=False), PPO loss only.
schema_count = 0
control_count = 0
mixed_count = 0

# One trial per seed. Iterating the list directly avoids the IndexError that
# range(16) over the 15-entry `seeds` list raised on the final trial.
for seed in seeds:
    # NOTE(review): torch is seeded once per trial, so the "mixed" and "control"
    # runs start from the RNG state left by the preceding run(s).
    torch.manual_seed(seed)
    for mode in ["schema", "mixed", "control"]:
        schema = (mode == "schema")
        if mode == "mixed":
            mixed_count += 1
            name = "chan5_mixed_" + str(mixed_count)
            out = [("agents", "h1m"), ("agents", "pred_attn"), ("agents", "logits")]
            policy_net = torch.nn.Sequential(
                vision_transformer_mlpcritic.MultiAgentMixed(
                    n_agents=senv.n_agents,
                    img_size=[16],
                    in_chans=5,
                    n_agent_outputs=(senv.action_spec.shape[-1] ** 2) * 2,
                ),
            )
        else:
            if schema:
                schema_count += 1
                name = "chan5_schema_" + str(schema_count)
                out = [("agents", "h1m"), ("agents", "pred_attn"), ("agents", "logits")]
            else:
                control_count += 1
                name = "chan5_control_" + str(control_count)
                out = [("agents", "logits")]

            policy_net = torch.nn.Sequential(
                vision_transformer_mlpcritic.MultiAgentVit(
                    n_agents=senv.n_agents,
                    img_size=[16],
                    in_chans=5,
                    # each square of the field is a discrete binary action
                    n_agent_outputs=(senv.action_spec.shape[-1] ** 2) * 2,
                    schema=schema,
                ),  # reshaped to [1, 2, 16, 16, 2]
            )

        policy_module = TensorDictModule(
            policy_net,
            in_keys=[("agents", "observation")],
            out_keys=out,
        )

        policy = ProbabilisticActor(
            module=policy_module,
            spec=senv.action_spec,
            in_keys=[("agents", "logits")],
            out_keys=[senv.action_key],
            distribution_class=Categorical,
            return_log_prob=True,
            log_prob_key=("agents", "sample_log_prob"),
        ).to("cuda")

        centralised = False
        _, channels, field_dim, _ = senv.observation_spec["agents", "observation"].shape

        critic_net = vision_transformer_mlpcritic.MultiAgentMlp(
            n_agents=senv.n_agents,
            n_chans=5,
            n_agent_outputs=1,
            centralised=centralised,
            field_dim=field_dim,
        )

        critic = TensorDictModule(
            module=critic_net,
            in_keys=[("agents", "observation")],
            out_keys=[("agents", "state_value")],
        ).to(device)

        collector = SyncDataCollector(
            senv,
            policy,
            device="cuda",
            storing_device="cuda",
            frames_per_batch=frames_per_batch,
            total_frames=total_frames,
        )

        replay_buffer = ReplayBuffer(
            # store the frames_per_batch collected at each iteration
            storage=LazyTensorStorage(frames_per_batch, device="cuda"),
            sampler=SamplerWithoutReplacement(),
            batch_size=minibatch_size,
        )

        pred_loss_module = torch.nn.MSELoss()

        loss_module = ClipPPOLoss(
            actor=policy,
            critic=critic,
            clip_epsilon=clip_epsilon,
            entropy_coef=entropy_eps,
            normalize_advantage=False,  # avoid normalizing across the agent dimension
        )
        loss_module.set_keys(  # tell the loss where to find the keys
            reward=("agents", "reward"),
            action=senv.action_key,
            sample_log_prob=("agents", "sample_log_prob"),
            value=("agents", "state_value"),
            # last 2 keys will be expanded to match the reward shape
            done="done",
            terminated=("agents", "terminated"),
        )

        loss_module.make_value_estimator(
            ValueEstimators.GAE, gamma=gamma, lmbda=lmbda
        )
        GAE = loss_module.value_estimator

        optim = torch.optim.Adam(loss_module.parameters(), lr)

        episode_reward_mean_list = []
        episode_agent0_mean_list = []
        episode_agent1_mean_list = []
        episode_overlap_mean_list = []
        episode_steps_list = []

        td_counter = 0
        for tensordict_data in collector:
            td_counter += 1
            # Expand the reward, done and terminated tensors to the ndims of the
            # action space (expected by the value estimator).
            tensordict_data.set(
                ("next", "agents", "reward"),
                tensordict_data.get(("next", "agents", "reward"))
                .view(*tensordict_data.get_item_shape(("next", "agents", "reward")), 1, 1)
                .squeeze(1)
            )

            tensordict_data.set(
                ("next", "done"),
                tensordict_data.get(("next", "done"))
                .view(*tensordict_data.get_item_shape(("next", "done")), 1, 1, 1)
                .expand(tensordict_data.get_item_shape(("next", "agents", "reward")))
            )

            tensordict_data.set(
                ("next", "terminated"),
                tensordict_data.get(("next", "terminated"))
                .view(*tensordict_data.get_item_shape(("next", "terminated")), 1, 1, 1)
                .expand(tensordict_data.get_item_shape(("next", "agents", "reward")))
            )

            tensordict_data.set(
                "done",
                tensordict_data.get("done")
                .view(*tensordict_data.get_item_shape("done"), 1, 1, 1)
                .expand(tensordict_data.get_item_shape(("next", "agents", "reward")))
            )

            tensordict_data.set(
                "terminated",
                tensordict_data.get("terminated")
                .view(*tensordict_data.get_item_shape("terminated"), 1, 1, 1)
                .expand(tensordict_data.get_item_shape(("next", "agents", "reward")))
            )

            with torch.no_grad():
                GAE(
                    tensordict_data,
                    params=loss_module.critic_network_params,
                    target_params=loss_module.target_critic_network_params,
                )  # get advantages

            data_view = tensordict_data.reshape(-1)  # flatten the batch size to shuffle data
            replay_buffer.extend(data_view)  # refill the buffer

            # NOTE(review): max_grad_norm is defined but no gradient clipping is
            # applied before optim.step().
            for _epoch in range(num_epochs):
                for _ in range(frames_per_batch // minibatch_size):
                    subdata = replay_buffer.sample()
                    # combine subdata batch dim with n_envs
                    for key in subdata.keys(include_nested=True):
                        if len(subdata.get(key).shape) > 4:
                            subdata.set(key, subdata.get(key).squeeze(1))
                    loss_vals = loss_module(subdata)
                    if schema:
                        pred_loss = prediction_coef * pred_loss_module(
                            subdata.get(("agents", "pred_attn")),
                            subdata.get(("agents", "h1m")))
                    else:
                        pred_loss = 0

                    loss_value = (
                        loss_vals["loss_objective"]
                        + loss_vals["loss_critic"]
                        + loss_vals["loss_entropy"]
                        + pred_loss
                    )

                    loss_value.backward()
                    optim.step()
                    optim.zero_grad()

            collector.update_policy_weights_()
            episode_reward_mean = (
                tensordict_data.get(("next", "agents", "episode_reward")).mean().item()
            )
            episode_agent0_mean = (
                tensordict_data.get(("next", "agents", "episode_agent0")).mean().item()
            )
            episode_agent1_mean = (
                tensordict_data.get(("next", "agents", "episode_agent1")).mean().item()
            )
            # NOTE(review): the other metrics read RewardSum's "episode_*" keys,
            # while this one reads the per-step "overlap" key. Confirm that
            # "episode_overlap" was not intended.
            episode_overlap_mean = (
                tensordict_data.get(("next", "agents", "overlap")).mean().item()
            )
            n_eps = torch.count_nonzero(tensordict_data.get(("next", "done")) == True).item() + 1
            episode_steps_list.append(frames_per_batch / n_eps)
            print(str(td_counter) + ": " + str(episode_reward_mean))
            episode_reward_mean_list.append(episode_reward_mean)
            episode_agent0_mean_list.append(episode_agent0_mean)
            episode_agent1_mean_list.append(episode_agent1_mean)
            episode_overlap_mean_list.append(episode_overlap_mean)

        # Join with a path separator so the file lands inside dir_path (plain
        # concatenation wrote "<dir_path>reward_..." next to the directory).
        reward_path = os.path.join(dir_path, "reward_" + name + ".txt")
        with open(reward_path, "w") as reward_file:
            for item in episode_reward_mean_list:
                reward_file.write(str(item) + "\n")


# Plots only the most recent run's per-iteration mean episode reward.
plt.plot(episode_reward_mean_list)