<a href="https://colab.research.google.com/github/dasys-lab/comaze-python/blob/gym-env/CoMazeGym_Agent_Template.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Please run the following cell on its own first, then go directly to the subsection you are interested in:

In [None]:
import os
import requests
import time
import gym


class CoMazeGym(gym.Env):
  """Gym-style client for a CoMaze game served over an HTTP API.

  The observation handed back by `reset()` and `step()` is the JSON game
  object returned by the server. `step()` posts one move and then blocks
  until it is this player's turn again (or the game is over).
  """

  # A ".local" marker file in the working directory selects a local server.
  if os.path.isfile(".local"):
    API_URL = "http://localhost:16216"
    WEBAPP_URL = "http://localhost"
  else:
    API_URL = "http://teamwork.vs.uni-kassel.de:16216"
    WEBAPP_URL = "http://teamwork.vs.uni-kassel.de"
  LIB_VERSION = "1.3.0"

  def __init__(self):
    """Create an unconnected client; every field stays None until reset()."""
    self.game = None
    self.game_id = None
    self.player_id = None
    self.action_space = None

  def _fetch_game(self):
    """GET the current state of `self.game_id` and return the decoded JSON."""
    return requests.get(self.API_URL + "/game/" + self.game_id).json()

  def reset(self, options=None):
    """Create a game (unless `options` names one) and join it.

    Args:
      options: optional dict. Recognised keys: "level" (default "1"),
        "num_of_player_slots" (default "2"), "game_id" (join this game
        instead of creating one), plus the keys read by
        `play_existing_game()`. When a dict is passed, the id of a newly
        created game is written back into it under "game_id".

    Returns:
      The JSON game object once it is this player's turn.
    """
    # A fresh dict per call: a shared `{}` default would remember the
    # previous game's id and silently rejoin it on the next reset().
    if options is None:
      options = {}

    level = options.get("level", "1")
    num_of_player_slots = options.get("num_of_player_slots", "2")

    self.game_id = options.get("game_id", None)
    if self.game_id is None:
      # `params` URL-encodes the values and also accepts non-string values.
      created = requests.post(
        self.API_URL + "/game/create",
        params={"level": level, "numOfPlayerSlots": num_of_player_slots},
      ).json()
      self.game_id = created["uuid"]
      options["game_id"] = self.game_id

    return self.play_existing_game(options)

  def play_existing_game(self, options=None):
    """Join the game identified in `options` and wait for this player's turn.

    Args:
      options: optional dict. "game_id" (36-character id) is required unless
        "look_for_player_name" is given, in which case the id is looked up
        on the server and written back into `options`. "player_name"
        defaults to "Python".

    Returns:
      The JSON game object once it is this player's turn and at least two
      players are present.

    Raises:
      Exception: if no 36-character game id is available.
    """
    if options is None:
      options = {}

    if "look_for_player_name" in options:
      options["game_id"] = requests.get(
        self.API_URL + "/game/byPlayerName",
        params={"playerName": options["look_for_player_name"]},
      ).json()["uuid"]

    if "game_id" not in options or len(options["game_id"]) != 36:
      raise Exception("You must provide a game id when attending an existing game. Use reset() instead of play_existing_game() if you want to create a new game.")

    player_name = options.get("player_name", "Python")
    self.game_id = options["game_id"]
    print("Joined gameId: " + self.game_id)
    player = requests.post(
      self.API_URL + "/game/" + self.game_id + "/attend",
      params={"playerName": player_name},
    ).json()
    self.player_id = player["uuid"]
    # SKIP is always allowed in addition to the directions given to this player.
    self.action_space = player['directions'] + ['SKIP']
    print("Playing as playerId: " + self.player_id)
    self.game = self._fetch_game()
    print(f'Action Space is {self.action_space}')

    # Poll until a second player has joined and it is our turn.
    while self.game['currentPlayer']['uuid'] != self.player_id or len(self.game["players"]) < 2:
      if self.game['currentPlayer']['uuid'] != self.player_id:
        print(f'Waiting for other player to make first move')
      print("(Invite someone: " + self.WEBAPP_URL + "/?gameId=" + self.game_id + " )")
      time.sleep(1)
      self.game = self._fetch_game()

    return self.game

  def step(self, action, message=None):
    """Post one move, then wait until it is this player's turn again.

    Args:
      action: a direction string or "SKIP". An action that is not available
        to the current player is replaced by "SKIP" (with a warning).
      message: optional symbol message sent with the move; dropped on SKIP.

    Returns:
      Tuple `(game, reward, over, None)` where `reward` is 1 if the game is
      won, -1 if it is lost and 0 otherwise.
    """
    # Wait for the game to start before attempting the move.
    self.game = self._fetch_game()
    while not self.game["state"]["started"]:
      print("Waiting for players. (Invite someone: " + self.WEBAPP_URL + "/?gameId=" + self.game_id + " )")
      time.sleep(3)
      self.game = self._fetch_game()

    available_actions = self.game["currentPlayer"]["directions"]+["SKIP"]
    if action not in available_actions:
      print(f"WARNING: Action {action} is not available to the current player.")
      action = "SKIP"
    print("Moving " + action)
    if action == "SKIP":
      print(f'Wanted to send message {message}, but skipped.')
      message = None
    else:
      print(f'Sending message {message}.')
    print('---')

    params = {"playerId": self.player_id, "action": action}
    if message is not None:
      params["symbolMessage"] = message
    response = requests.post(
      self.API_URL + "/game/" + self.game_id + "/move", params=params)
    print(response.url)
    self.game = response.json()

    if self.game["state"]["won"]:
      print("Game won!")
      reward = 1
    elif self.game["state"]["lost"]:
      print("Game lost (" + self.game["state"]["lostMessage"] + ").")
      reward = -1
    else:
      reward = 0

    if not self.game["state"]["over"]:
      # wait for other player to make a move before sending back obs
      while self.game['currentPlayer']['uuid'] != self.player_id:
        print(f'Waiting for other player to make a move')
        time.sleep(1)
        self.game = self._fetch_game()

    return self.game, reward, self.game["state"]["over"], None
    

## Simple Agent and Gym-like env:

In [None]:
# Fresh client; no game is created or joined until env.reset() is called.
env = CoMazeGym()

In [None]:
# Random Agent
# Plays uniformly random actions from this player's action space
# (its directions plus SKIP) until the server reports the game as over.
import random 

obs = env.reset()
game_over = False
while not game_over:
  # step() blocks until it is this player's turn again (or the game ended).
  obs, reward, game_over, info = env.step(random.choice(env.action_space))

In [None]:
# Nearest Goal Agent
# Choose a nearest goal, see if one of your actions can get you there, if so take that action
obs = env.reset()
game_over = False
action_space = env.action_space

while not game_over:
  goals_pos = [goal['position'] for goal in obs['unreachedGoals']]
  if not goals_pos:
    # No goal left to chase while the game is still running: min() below
    # would raise on an empty list, so just pass the turn.
    obs, reward, game_over, info = env.step('SKIP')
    continue

  agent_pos = obs['agentPosition']

  # Signed offset from the agent to every goal, and its Manhattan length.
  goal_diffs = [(goal['x'] - agent_pos['x'], goal['y'] - agent_pos['y'])
                for goal in goals_pos]
  goal_dists = [abs(diff[0])+abs(diff[1]) for diff in goal_diffs]
  nearest_goal = goal_dists.index(min(goal_dists)) 

  print(f'Nearest goal is {obs["unreachedGoals"][nearest_goal]}')
  print(f'Nearest goal diff {goal_diffs[nearest_goal]}')

  move_x, move_y = goal_diffs[nearest_goal]

  # Take the first direction we own that reduces the offset; the checks
  # treat LEFT/UP as decreasing x/y. Otherwise pass the turn.
  if 'LEFT' in action_space and move_x < 0:
    action = 'LEFT'
  elif 'RIGHT' in action_space and move_x > 0:
    action = 'RIGHT'
  elif 'UP' in action_space and move_y < 0:
    action = 'UP'
  elif 'DOWN' in action_space and move_y > 0:
    action = 'DOWN'
  else:
    action = 'SKIP'

  obs, reward, game_over, info = env.step(action)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import distributions

In [None]:
# New client for the RL agent below; its action_space is None until reset().
env = CoMazeGym()

In [None]:
# Basic RL agent
# single-layer NN that takes in partial observation of the current state (no walls)
# and learns actions WITHOUT communication.

# The position in this list is the action id sampled from the policy output.
ACTION_SPACE = ['LEFT', 'RIGHT', 'UP', 'DOWN', 'SKIP']

class RLAgent(nn.Module):
  """Linear policy over a flattened arena plus an action-availability mask.

  Args:
    arena_size: pair `(x, y)`; the state input has `x * y` features.
    num_actions: size of the action mask input and of the logits output.
  """

  def __init__(self, arena_size, num_actions=5):
    super().__init__()
    arena_size_flat = arena_size[0] * arena_size[1]
    embed_state_size = 16
    self.embed_state = nn.Linear(arena_size_flat, embed_state_size)
    # Sized from num_actions (previously fixed at 5/21) so that any action
    # count works; the default of 5 yields the same layer shapes as before.
    self.embed_action_space = nn.Linear(num_actions, num_actions)
    self.policy = nn.Linear(embed_state_size + num_actions, num_actions)

  def forward(self, state, action_space):
    """Return unnormalised action logits of shape `(batch, num_actions)`.

    Args:
      state: float tensor `(batch, x * y)`.
      action_space: float tensor `(batch, num_actions)`.
    """
    state_emb = self.embed_state(state)
    action_emb = self.embed_action_space(action_space)
    state_action_emb = torch.cat((state_emb, action_emb), dim=1)
    return self.policy(state_action_emb)


def get_state_tensor(obs):
  """Encode an observation as a float grid: 0 empty, 1 agent, 2 unreached goal.

  The grid is indexed `[x, y]` and sized from `obs['config']['arenaSize']`.
  Goals are written after the agent, so a goal sharing the agent's cell wins.
  """
  size = obs['config']['arenaSize']
  grid = torch.zeros((size['x'], size['y'])).float()

  agent = obs['agentPosition']
  grid[agent['x'], agent['y']] = 1

  for cell in (goal['position'] for goal in obs['unreachedGoals']):
    grid[cell['x'], cell['y']] = 2

  return grid


def calculate_returns(rewards, discount_factor, normalize = True):
    """Return the discounted return for every step of one episode.

    Args:
        rewards: per-step rewards in chronological order.
        discount_factor: multiplier applied to the future return at each step.
        normalize: if True, standardise the returns (zero mean, unit sample
            standard deviation) when that is well defined.

    Returns:
        1-D float tensor with one return per reward. With fewer than two
        steps, or when all returns are equal, the standard deviation is
        NaN/zero, so the raw returns are handed back instead of NaN/inf.
    """
    returns = []
    R = 0

    # Accumulate back to front, then flip once (avoids O(n^2) insert(0, ...)).
    for r in reversed(rewards):
        R = r + R * discount_factor
        returns.append(R)
    returns.reverse()

    # Explicit float dtype: mean()/std() are not defined for integer tensors.
    returns = torch.tensor(returns, dtype=torch.get_default_dtype())

    if normalize and returns.numel() > 1:
        std = returns.std()
        if std > 0:
            returns = (returns - returns.mean()) / std

    return returns


discount_factor = 0.9
learning_rate = 1e-2
num_episodes = 1

# The arena size is hard-coded here; get_state_tensor() sizes the state from
# obs['config']['arenaSize'], so the two must agree.
arena_size = (7,7)
agent = RLAgent(arena_size)
optimizer = torch.optim.SGD(agent.parameters(), lr=learning_rate)

for ep in range(num_episodes):
  obs = env.reset()

  # env.action_space is None until reset() has joined a game, so the
  # availability mask can only be built here (1 = action usable by us).
  action_space_list = [1 if x in env.action_space else 0 for x in ACTION_SPACE]
  action_space_tensor = torch.FloatTensor(action_space_list)
  action_space_tensor_batch = action_space_tensor.unsqueeze(0)

  done = False
  log_prob_actions = []
  rewards = []
  episode_reward = 0

  while not done:
    state_tensor = get_state_tensor(obs)
    state_tensor_batch = torch.flatten(state_tensor).unsqueeze(0)
    action_pred = agent(state_tensor_batch, action_space_tensor_batch)

    # Zero out unavailable actions; Categorical renormalises the rest.
    action_prob = F.softmax(action_pred, dim = -1)
    avail_action_prob = action_prob * action_space_tensor
    dist = distributions.Categorical(avail_action_prob)
    action = dist.sample()
    log_prob_action = dist.log_prob(action)

    obs, reward, done, _ = env.step(ACTION_SPACE[action.item()])

    log_prob_actions.append(log_prob_action)
    rewards.append(reward)

    episode_reward += reward


  # REINFORCE update: -sum(return_t * log pi(a_t | s_t)) over the episode.
  log_prob_actions = torch.cat(log_prob_actions)
  returns = calculate_returns(rewards, discount_factor).detach()
  loss = - (returns * log_prob_actions).sum()

  optimizer.zero_grad()
  loss.backward()
  optimizer.step()

  print(f'Loss {loss} EP reward {episode_reward}')

## Discrete Action Space that combines direction/skip and message:

In [None]:
from gym.spaces import Box, Discrete, Dict, MultiBinary
import numpy as np

class CoMazeGymActionWrapper(gym.Wrapper):
  """Expose (direction, message) pairs plus SKIP as one Discrete action space.

  Action ids are laid out as `direction_id * nb_possible_sentences +
  sentence_id`; the last id (`nb_possible_actions - 1`) is SKIP.

  Args:
    env: the wrapped environment; its `step` must accept `action` and
      `message` keyword arguments.
    vocab_size: number of non-empty tokens. `id2token` only covers ids
      0..10, so larger values yield sentences that cannot be rendered.
    maximum_sentence_length: maximum number of tokens per message.
    options: accepted for signature compatibility; not used here.
  """

  def __init__(self, env, vocab_size=10, maximum_sentence_length=1, options=None):
    super(CoMazeGymActionWrapper, self).__init__(env)
    self.nb_directions = 4
    self.actionId2action =  ["LEFT", "RIGHT", "UP", "DOWN"]
    self.action2actionId =  {"LEFT":0, "RIGHT":1, "UP":2, "DOWN":3}

    self.vocab_size = vocab_size
    self.id2token = {
      0:"empty",
      1:"Q",
      2:"W",
      3:"E",
      4:"R",
      5:"T",
      6:"Y",
      7:"U",
      8:"I",
      9:"O",
      10:"P"
    }
    self.maximum_sentence_length = maximum_sentence_length
    self._build_sentenceId2sentence()

    self.nb_possible_actions = self.nb_directions*self.nb_possible_sentences+1
    # +1 accounts for the SKIP action...
    self.action_space = Discrete(self.nb_possible_actions)

  def _build_sentenceId2sentence(self):
    """Enumerate every sentence as a row of token ids (0 = empty slot).

    Sets `nb_possible_sentences` (the empty sentence plus every sentence of
    1..maximum_sentence_length tokens) and `sentenceId2sentence`, an array
    of shape `(nb_possible_sentences, maximum_sentence_length)`. Row 0 is
    the empty sentence; each further row is the previous one incremented
    like an odometer whose leftmost slot moves fastest.
    """
    self.nb_possible_sentences = 1 # account for the empty string:
    for pos in range(self.maximum_sentence_length):
      self.nb_possible_sentences += (self.vocab_size)**(pos+1)
    sentenceId2sentence = np.zeros( (self.nb_possible_sentences, self.maximum_sentence_length))
    idx = 1
    local_token_pointer = 0
    while idx != self.nb_possible_sentences:
      sentenceId2sentence[idx] = sentenceId2sentence[idx-1]
      sentenceId2sentence[idx][local_token_pointer] = (sentenceId2sentence[idx][local_token_pointer]+1)%(self.vocab_size+1)

      while sentenceId2sentence[idx][local_token_pointer] == 0:
        # remove the possibility of an empty symbol on the left of actual tokens:
        sentenceId2sentence[idx][local_token_pointer] += 1
        local_token_pointer += 1
        sentenceId2sentence[idx][local_token_pointer] = (sentenceId2sentence[idx][local_token_pointer]+1)%(self.vocab_size+1)
      idx += 1
      local_token_pointer = 0

    self.sentenceId2sentence = sentenceId2sentence

  def _get_message_from_sentence(self, sentence):
    """Render a row of token ids as a string; None for the empty sentence."""
    message = ''
    for pos, sidx in enumerate(sentence):
        # if empty symbol, then there is nothing on the right of it:
        if sidx == 0:
          # if empty sentence:
          if pos == 0:
            message = None
          break
        token = self.id2token[sidx]
        message += token

    return message

  def _decode_direction(self, action):
    """Validate a discrete action id and return its direction (or "SKIP")."""
    if not self.action_space.contains(action):
      raise ValueError('action {} is invalid for {}'.format(action, self.action_space))
    if action == self.nb_possible_actions-1:
      return "SKIP"
    return self.actionId2action[action // self.nb_possible_sentences]

  def step(self, action):
    """Decode a discrete action and forward it to the wrapped environment.

    Raises:
      ValueError: if `action` is not contained in `self.action_space`.
    """
    original_action_direction = self._decode_direction(action)

    if action != (self.nb_possible_actions-1):
      original_action_sentence = self.sentenceId2sentence[action % self.nb_possible_sentences]
      original_action_message = self._get_message_from_sentence(original_action_sentence)
    else:
      # SKIP never carries a message.
      original_action_message = None

    print(f'discrete action {action} -> original action: direction={original_action_direction} / message={original_action_message}')

    return self.env.step(action=original_action_direction, message=original_action_message)

  def is_action_available(self, action):
    """Tell whether the direction behind `action` is in the wrapped env's action space.

    Raises:
      ValueError: if `action` is not contained in `self.action_space`.
    """
    return self._decode_direction(action) in self.env.action_space

In [None]:
# Fresh client wrapped so that actions are single integers
# (direction + message, or SKIP).
env = CoMazeGym()
wrapped_env = CoMazeGymActionWrapper(env=env)

In [None]:
# Show the Discrete action space built by the wrapper.
print(wrapped_env.action_space)

In [None]:
# Random Agent with Discrete action wrapper
# Samples over the whole discrete space, including directions this player
# does not own; CoMazeGym.step() replaces those by SKIP with a warning.
obs = wrapped_env.reset()
game_over = False
while not game_over:
  obs, reward, game_over, info = wrapped_env.step(wrapped_env.action_space.sample())

## RL Agent with Discrete Action Space (Directions+Messages):

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import distributions

In [None]:
# Fresh client and discrete-action wrapper for the communicating RL agent below.
env = CoMazeGym()
wrapped_env = CoMazeGymActionWrapper(env=env)

In [None]:
# Basic RL agent
# single-layer NN that takes in partial observation of the current state (no walls)
# and learns actions WITH communication.

# Number of discrete actions exposed by the wrapper, and the array of their ids.
nb_possible_actions = wrapped_env.action_space.n 
ACTION_SPACE = np.arange(nb_possible_actions)

class CommRLAgent(nn.Module):
  """Linear policy over a flattened arena and an action-availability mask.

  Both inputs are embedded to 128 features each, concatenated, and mapped
  to `num_actions` logits.
  """

  def __init__(self, arena_size, num_actions=1+4*10):
    super().__init__()
    hidden = 128
    cells = arena_size[0] * arena_size[1]
    self.embed_state = nn.Linear(cells, hidden)
    self.embed_action_space = nn.Linear(num_actions, hidden)
    self.policy = nn.Linear(hidden + hidden, num_actions)

  def forward(self, state, action_space):
    """Return action logits for `state` (batch, cells) and `action_space` (batch, num_actions)."""
    features = (
      self.embed_state(state),
      self.embed_action_space(action_space),
    )
    return self.policy(torch.cat(features, dim=1))


def get_state_tensor(obs):
  """Build a float `[x, y]` grid from an observation.

  Cells hold 0 (empty), 1 (agent position) or 2 (unreached goal); goals are
  written last, so they overwrite the agent marker on a shared cell.
  """
  dims = obs['config']['arenaSize']
  board = torch.zeros((dims['x'], dims['y'])).float()

  here = obs['agentPosition']
  board[here['x'], here['y']] = 1    # agent

  for goal in obs['unreachedGoals']:
    where = goal['position']
    board[where['x'], where['y']] = 2

  return board


def calculate_returns(rewards, discount_factor, normalize = True):
    """Return the discounted return for every step of one episode.

    Args:
        rewards: per-step rewards in chronological order.
        discount_factor: multiplier applied to the future return at each step.
        normalize: if True, standardise the returns (zero mean, unit sample
            standard deviation) when that is well defined.

    Returns:
        1-D float tensor with one return per reward. With fewer than two
        steps, or when all returns are equal, the standard deviation is
        NaN/zero, so the raw returns are handed back instead of NaN/inf.
    """
    returns = []
    R = 0

    # Accumulate back to front, then flip once (avoids O(n^2) insert(0, ...)).
    for r in reversed(rewards):
        R = r + R * discount_factor
        returns.append(R)
    returns.reverse()

    # Explicit float dtype: mean()/std() are not defined for integer tensors.
    returns = torch.tensor(returns, dtype=torch.get_default_dtype())

    if normalize and returns.numel() > 1:
        std = returns.std()
        if std > 0:
            returns = (returns - returns.mean()) / std

    return returns


discount_factor = 0.9
learning_rate = 1e-2
num_episodes = 1

# The arena size is hard-coded here; get_state_tensor() sizes the state from
# obs['config']['arenaSize'], so the two must agree.
arena_size = (7,7)
agent = CommRLAgent(arena_size, num_actions=nb_possible_actions)
optimizer = torch.optim.SGD(agent.parameters(), lr=learning_rate)

for ep in range(num_episodes):
  obs = env.reset()

  # Expected mask size: SKIP plus every (owned direction, sentence) pair.
  # Derived from the joined player's directions and the wrapper's sentence
  # count so it also holds for maximum_sentence_length > 1.
  nb_owned_directions = sum(1 for direction in env.action_space if direction != 'SKIP')
  nb_available_actions = 1+nb_owned_directions*wrapped_env.nb_possible_sentences
  action_space_list = [1 if wrapped_env.is_action_available(action_id) else 0 for action_id in ACTION_SPACE]
  action_space_tensor = torch.FloatTensor(action_space_list)
  action_space_tensor_batch = action_space_tensor.unsqueeze(0)
  assert action_space_tensor_batch.sum() == nb_available_actions

  done = False
  log_prob_actions = []
  rewards = []
  episode_reward = 0

  while not done:
    state_tensor = get_state_tensor(obs)
    state_tensor_batch = torch.flatten(state_tensor).unsqueeze(0)
    action_pred = agent(state_tensor_batch, action_space_tensor_batch)

    # Zero out unavailable actions; Categorical renormalises the rest.
    action_prob = F.softmax(action_pred, dim = -1)
    avail_action_prob = action_prob * action_space_tensor
    dist = distributions.Categorical(avail_action_prob)
    action = dist.sample()
    log_prob_action = dist.log_prob(action)

    obs, reward, done, _ = wrapped_env.step(action.item())

    log_prob_actions.append(log_prob_action)
    rewards.append(reward)

    episode_reward += reward


  # REINFORCE update: -sum(return_t * log pi(a_t | s_t)) over the episode.
  log_prob_actions = torch.cat(log_prob_actions)
  returns = calculate_returns(rewards, discount_factor).detach()
  loss = - (returns * log_prob_actions).sum()

  optimizer.zero_grad()
  loss.backward()
  optimizer.step()

  print(f'Loss {loss} EP reward {episode_reward}')

# Dictionary Observation Space and Discrete Action Space:

The dictionary contains an 'encoded_pov' entry that encodes the state of the game similarly to the [BabyAI environment](https://github.com/mila-iqia/babyai).


In [None]:
from gym.spaces import Dict, MultiDiscrete, Box, Discrete, MultiBinary
import numpy as np

class CoMazeGymDictObsActionWrapper(gym.Wrapper):
    """Wrap CoMazeGym with a Discrete action space and a Dict observation.

    Actions are laid out as `direction_id * nb_possible_sentences +
    sentence_id`, with the last id reserved for SKIP. Observations are dicts
    with three entries:

    * 'encoded_pov': int64 array `(x, y, nb_different_tile)`, one channel per
      tile kind (see `tile2id`);
    * 'available_actions': array of length `nb_possible_actions`, 1 where the
      action is usable by the current player;
    * 'previous_message': int64 array of token ids of the other player's
      last message (0 = empty).

    Constructing the wrapper already calls `env.reset(options=options)`,
    because the arena size is needed to build the observation space.
    """
    def __init__(self, env, vocab_size=10, maximum_sentence_length=1, options=None):
        super(CoMazeGymDictObsActionWrapper, self).__init__(env)
        # Fresh dict per call: the wrapped reset() may write a game id into it.
        options = {} if options is None else options
        self.game = self.env.reset(options=options)

        self.nb_directions = 4
        self.actionId2action =  ["LEFT", "RIGHT", "UP", "DOWN"]
        self.action2actionId =  {"LEFT":0, "RIGHT":1, "UP":2, "DOWN":3}

        self.vocab_size = vocab_size
        assert self.vocab_size == 10
        self.token2id = {
            "empty":0,
            "Q":1,
            "W":2,
            "E":3,
            "R":4,
            "T":5,
            "Y":6,
            "U":7,
            "I":8,
            "O":9,
            "P":10
        }
        self.id2token = {
            0:"empty",
            1:"Q",
            2:"W",
            3:"E",
            4:"R",
            5:"T",
            6:"Y",
            7:"U",
            8:"I",
            9:"O",
            10:"P"
        }
        self.maximum_sentence_length = maximum_sentence_length
        assert self.maximum_sentence_length == 1
        self._build_sentenceId2sentence()

        # Action Space:
        self.nb_possible_actions = self.nb_directions*self.nb_possible_sentences+1
        # +1 accounts for the SKIP action...
        self.action_space = Discrete(self.nb_possible_actions)

        # Observation Space:
        ## previous_message_space
        previous_message_space = MultiDiscrete(
            [self.vocab_size+1 for _ in range(self.maximum_sentence_length)]
        )
        ## encoded_pov_space: the depth channel is a one-hot encoding of the tile nature:
        self.tile2id = {}
        self.nb_different_tile = 3       #background, time bonus, and agent
        self.tile2id["background"]= 0
        self.tile2id["agent"]= 1
        self.tile2id["time_bonus"]= 2
        self.nb_different_tile += 4+1    #4 unreached goals + 1 reached goal.
        self.goalEnum2id = {"RED":3, "BLUE":4, "GREEN":5, "YELLOW":6, 'reached_goal':7}
        self.tile2id["goal_1"]= 3
        self.tile2id["goal_2"]= 4
        self.tile2id["goal_3"]= 5
        self.tile2id["goal_4"]= 6
        self.tile2id["reached_goal"]= 7
        self.nb_different_tile += 4      # wall in any of the 4 directions.
        self.wallDirectionEnum2id = {"LEFT":8, "RIGHT":9, "UP":10, "DOWN":11}
        self.tile2id["wall_left"]= 8
        self.tile2id["wall_right"]= 9
        self.tile2id["wall_up"]= 10
        self.tile2id["wall_down"]= 11

        encoded_pov_space = Box(
            low=0,
            high=1,
            shape=(
                self.game["config"]["arenaSize"]["x"],
                self.game["config"]["arenaSize"]["y"],
                self.nb_different_tile
            ),
            dtype=np.int64,
        )

        ## available_action_space:
        available_actions_space = MultiBinary(n=self.nb_possible_actions)

        self.observation_space = Dict({
          'encoded_pov': encoded_pov_space,
          'available_actions': available_actions_space,
          'previous_message': previous_message_space,
        })

    def _build_sentenceId2sentence(self):
        """Enumerate every sentence as a row of token ids (0 = empty slot).

        Sets `nb_possible_sentences` (the empty sentence plus every sentence
        of 1..maximum_sentence_length tokens) and `sentenceId2sentence`, an
        array of shape `(nb_possible_sentences, maximum_sentence_length)`.
        Row 0 is the empty sentence; each further row is the previous one
        incremented like an odometer whose leftmost slot moves fastest.
        """
        self.nb_possible_sentences = 1 # account for the empty string:
        for pos in range(self.maximum_sentence_length):
            self.nb_possible_sentences += (self.vocab_size)**(pos+1)
        sentenceId2sentence = np.zeros( (self.nb_possible_sentences, self.maximum_sentence_length))
        idx = 1
        local_token_pointer = 0
        while idx != self.nb_possible_sentences:
            sentenceId2sentence[idx] = sentenceId2sentence[idx-1]
            sentenceId2sentence[idx][local_token_pointer] = (sentenceId2sentence[idx][local_token_pointer]+1)%(self.vocab_size+1)

            while sentenceId2sentence[idx][local_token_pointer] == 0:
                # remove the possibility of an empty symbol on the left of actual tokens:
                sentenceId2sentence[idx][local_token_pointer] += 1
                local_token_pointer += 1
                sentenceId2sentence[idx][local_token_pointer] = (sentenceId2sentence[idx][local_token_pointer]+1)%(self.vocab_size+1)
            idx += 1
            local_token_pointer = 0

        self.sentenceId2sentence = sentenceId2sentence

    def _get_message_from_sentence(self, sentence):
        """Render a row of token ids as a string; None for the empty sentence."""
        message = ''
        for pos, sidx in enumerate(sentence):
            # if empty symbol, then there is nothing on the right of it:
            if sidx == 0:
                # if empty sentence:
                if pos == 0:
                    message = None
                break
            token = self.id2token[sidx]
            message += token
        return message

    def _decode_direction(self, action):
        """Validate a discrete action id and return its direction (or "SKIP")."""
        if not self.action_space.contains(action):
            raise ValueError('action {} is invalid for {}'.format(action, self.action_space))
        if action == self.nb_possible_actions-1:
            return "SKIP"
        return self.actionId2action[action // self.nb_possible_sentences]

    def _make_obs(self, previous_message):
        """Build the observation dict for `self.game` with the given message ids."""
        return {
            "encoded_pov": self._encode_game(game=self.game),
            "available_actions": self._get_available_actions(game=self.game),
            "previous_message": previous_message,
        }

    def reset(self, options=None):
        """Create a new game on the server, join it and return the first observation.

        Args:
            options: optional dict. "level" (default "1") and
                "num_of_player_slots" (default "2") select the game to
                create; the new id is written back under "game_id" and the
                dict is forwarded to `play_existing_game()`.

        Returns:
            The observation dict; 'previous_message' is all zeros.
        """
        options = {} if options is None else options
        level = options.get("level", "1")
        num_of_player_slots = options.get("num_of_player_slots", "2")
        self.game_id = requests.post(self.API_URL + "/game/create?level=" + level + "&numOfPlayerSlots=" + num_of_player_slots).json()["uuid"]
        options["game_id"] = self.game_id

        self.game = self.play_existing_game(options)

        # No message has been exchanged yet at the start of a game.
        self.obs = self._make_obs(
            np.zeros(self.maximum_sentence_length, dtype=np.int64))

        return self.obs

    def step(self, action):
        """Decode a discrete action, play it, and return `(obs, reward, done, infos)`.

        Raises:
            ValueError: if `action` is not contained in `self.action_space`.
        """
        original_action_direction = self._decode_direction(action)

        if action != (self.nb_possible_actions-1):
            original_action_sentence = self.sentenceId2sentence[action % self.nb_possible_sentences]
            original_action_message = self._get_message_from_sentence(original_action_sentence)
        else:
            # SKIP never carries a message.
            original_action_message = None

        print(f'discrete action {action} -> original action: direction={original_action_direction} / message={original_action_message}')

        self.game, self.reward, self.done, self.infos = self.env.step(action=original_action_direction, message=original_action_message)

        self.obs = self._make_obs(self._get_previous_message(game=self.game))

        return self.obs, self.reward, self.done, self.infos

    def is_action_available(self, action):
        """Tell whether the direction behind `action` is in the wrapped env's action space.

        Raises:
            ValueError: if `action` is not contained in `self.action_space`.
        """
        return self._decode_direction(action) in self.env.action_space

    def _encode_game(self, game):
        """One-hot encode the arena as an int64 array `(x, y, nb_different_tile)`.

        Every cell starts as background (channel 0); cells holding the agent,
        a goal or a wall get channel 0 cleared and their own channel set.
        """
        grid = np.zeros(
            (game["config"]["arenaSize"]["x"], game["config"]["arenaSize"]["y"], self.nb_different_tile),
            dtype=np.int64,
        )
        grid[:, :, self.tile2id["background"]] = 1

        # Agent:
        agent_x = game["agentPosition"]["x"]
        agent_y = game["agentPosition"]["y"]
        grid[agent_x, agent_y, 0] = 0
        grid[agent_x, agent_y, self.tile2id["agent"]] = 1

        # Goals: a goal missing from the unreached list is drawn as reached.
        goals = game["config"]["goals"]
        unreached_goals = game["unreachedGoals"]
        for goal in goals:
            gx, gy = goal["position"]["x"], goal["position"]["y"]
            goal_id = self.goalEnum2id[goal["color"]]
            if goal not in unreached_goals:
                goal_id = self.goalEnum2id["reached_goal"]
            grid[gx, gy, 0] = 0
            grid[gx, gy, goal_id] = 1

        # Walls: both coordinates come from the wall itself (the y coordinate
        # used to be read from the last goal of the loop above).
        walls = game["config"]["walls"]
        for wall in walls:
            wx, wy = wall["position"]["x"], wall["position"]["y"]
            wall_id = self.wallDirectionEnum2id[wall["direction"]]
            grid[wx, wy, 0] = 0
            grid[wx, wy, wall_id] = 1

        return grid

    def _get_available_actions(self, game):
        """Return a 0/1 array over all discrete actions for the current player.

        SKIP is always available; every other listed action enables all of
        its (direction, sentence) combinations.
        """
        # NOTE(review): CoMazeGym reads the player's "directions" field;
        # confirm the server also provides "actions" on currentPlayer.
        current_player_available_actions = game["currentPlayer"]["actions"]
        a_actions = np.zeros(self.nb_possible_actions)
        # SKIP action:
        a_actions[-1] = 1
        for action in current_player_available_actions:
            if action == "SKIP":
                continue
            first = self.action2actionId[action]*self.nb_possible_sentences
            a_actions[first:first+self.nb_possible_sentences] = 1
        return a_actions

    def _get_previous_message(self, game):
        """Return the other player's last symbol message as int64 token ids.

        Unused slots (and a missing message) are encoded as 0. Exactly one
        player other than `game["currentPlayer"]` is expected.
        """
        players = game["players"]
        currentPlayer = game["currentPlayer"]
        otherPlayers = [player for player in players if player != currentPlayer]
        assert len(otherPlayers) == 1
        otherPlayer_message = otherPlayers[0]["lastSymbolMessage"]
        # int64 to match reset() and the MultiDiscrete observation space.
        otherPlayer_message_discrete = np.zeros(self.maximum_sentence_length, dtype=np.int64)
        if otherPlayer_message is not None:
            # zip truncates messages longer than maximum_sentence_length.
            for widx, token in zip(range(self.maximum_sentence_length), otherPlayer_message):
                otherPlayer_message_discrete[widx] = self.token2id[token]
        return otherPlayer_message_discrete
    

In [None]:
# Fresh client; no game is created or joined yet.
env = CoMazeGym()

In [None]:
# The wrapper's constructor calls env.reset(), so this line already
# creates and joins a game (and waits for this player's turn).
dictwrapped_env = CoMazeGymDictObsActionWrapper(env=env)

In [None]:
# Show the Discrete action space built by the wrapper.
print(dictwrapped_env.action_space)

In [None]:
# Show the Dict observation space (encoded_pov / available_actions / previous_message).
print(dictwrapped_env.observation_space)

In [None]:
# Random Agent with Dict Observation and Discrete action wrapper
# (The leftover `%debug` magic was removed: it opens an interactive
# post-mortem debugger and is not valid outside IPython.)
obs = dictwrapped_env.reset()
print(obs)
game_over = False
while not game_over:
  # Samples over the whole discrete space, not only the available actions.
  obs, reward, game_over, info = dictwrapped_env.step(dictwrapped_env.action_space.sample())

## RL Agent with Dictionary Observation Space and Discrete Action Space (Directions+Messages):

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import distributions

In [None]:
# Options passed to the wrapper's constructor, which forwards them to env.reset().
options = {}
# Optional: set a game id so that the constructor's env.reset() joins that game
# instead of creating one.
#options["game_id"] = '6539c0bb-f659-4cee-8887-c78915b5e285'
env = CoMazeGym()
# Constructing the wrapper already calls env.reset(options=options).
wrapped_env = CoMazeGymDictObsActionWrapper(env=env, options=options)

In [None]:
# CNN-based RL agent
# takes as input the current state as a dictionary of elements.
# and learns actions WITH communication.

# Number of discrete actions exposed by the wrapper, and the array of their ids.
nb_possible_actions = wrapped_env.action_space.n 
ACTION_SPACE = np.arange(nb_possible_actions)

class DictObsCommRLAgent(nn.Module):
  """CNN policy over a dict observation (pov grid, previous message, action mask).

  Args:
    num_actions: size of the action mask input and of the logits output.
    pov_shape: `[x, y, channels]` of the 'encoded_pov' observation.
    previous_message_length: number of token ids in 'previous_message'.
    vocab_size: number of non-empty tokens; ids range over 0..vocab_size.
  """

  def __init__(self, num_actions=1+4*10, pov_shape=[7,7,12], previous_message_length=1, vocab_size=10):
    super().__init__()
    self.num_actions = num_actions
    self.pov_shape = list(pov_shape)  # copy: never alias the shared default
    self.previous_message_length = previous_message_length
    self.embed_pov_size = 256

    conv_channels = 32
    # Only the stride-2 convolution (kernel 3, padding 1) changes the spatial
    # size: n -> (n - 1) // 2 + 1. A 7x7 arena gives 4x4, i.e. 32*4*4 = 512
    # features, which is the value that used to be hard-coded.
    out_x = (self.pov_shape[0] - 1) // 2 + 1
    out_y = (self.pov_shape[1] - 1) // 2 + 1
    flat_size = conv_channels * out_x * out_y

    self.embed_pov = nn.Sequential(
      nn.Conv2d(in_channels=self.pov_shape[-1], out_channels=conv_channels, kernel_size=3, stride=1, padding=1),
      nn.ReLU(),
      nn.Conv2d(in_channels=conv_channels, out_channels=conv_channels, kernel_size=3, stride=2, padding=1),
      nn.ReLU(),
      nn.Conv2d(in_channels=conv_channels, out_channels=conv_channels, kernel_size=3, stride=1, padding=1),
      nn.ReLU(),
      nn.Flatten(),
      nn.Linear(flat_size, self.embed_pov_size),
      nn.ReLU(),
    )

    self.embed_message_size = 64
    self.embed_previous_message = nn.Embedding(
      num_embeddings=vocab_size+1,  # +1 for the empty token (id 0)
      embedding_dim=self.embed_message_size,
    )

    self.embed_action_size = 128
    self.embed_action_space = nn.Linear(num_actions, self.embed_action_size)

    # Every message token contributes its own embedding to the policy input.
    policy_input_size = (
      self.embed_pov_size
      + self.embed_message_size * previous_message_length
      + self.embed_action_size
    )
    self.policy = nn.Linear(policy_input_size, num_actions)

  def get_formatted_inputs(self, obs):
    """Convert a dict of numpy arrays into batched float tensors.

    Entries whose key contains 'pov' must be 3-D `(x, y, channels)` and are
    transposed to channels-first; every entry gets a leading batch axis of 1.
    """
    nobs = {}
    for k,v in obs.items():
      if 'pov' in k:
        # move channels around:
        assert len(v.shape)==3
        v = np.transpose(v, (2,0,1))
      nv = torch.from_numpy(v).unsqueeze(0).float()
      nobs[k] = nv
    return nobs

  def forward(self, obs):
    """Return action logits `(batch, num_actions)` for a formatted observation dict."""
    pov_input = obs["encoded_pov"]
    message_input = obs["previous_message"].long()
    action_space = obs["available_actions"]

    pov_emb = self.embed_pov(pov_input)
    # Flatten all token embeddings of one sample into a single row.
    message_emb = self.embed_previous_message(message_input).reshape(message_input.shape[0], -1)
    action_emb = self.embed_action_space(action_space)

    pov_message_action_emb = torch.cat((pov_emb, message_emb, action_emb), dim=1)
    return self.policy(pov_message_action_emb)


def calculate_returns(rewards, discount_factor, normalize = True):
    """Return the discounted return for every step of one episode.

    Args:
        rewards: per-step rewards in chronological order.
        discount_factor: multiplier applied to the future return at each step.
        normalize: if True, standardise the returns (zero mean, unit sample
            standard deviation) when that is well defined.

    Returns:
        1-D float tensor with one return per reward. With fewer than two
        steps, or when all returns are equal, the standard deviation is
        NaN/zero, so the raw returns are handed back instead of NaN/inf.
    """
    returns = []
    R = 0

    # Accumulate back to front, then flip once (avoids O(n^2) insert(0, ...)).
    for r in reversed(rewards):
        R = r + R * discount_factor
        returns.append(R)
    returns.reverse()

    # Explicit float dtype: mean()/std() are not defined for integer tensors.
    returns = torch.tensor(returns, dtype=torch.get_default_dtype())

    if normalize and returns.numel() > 1:
        std = returns.std()
        if std > 0:
            returns = (returns - returns.mean()) / std

    return returns


# REINFORCE training of the CNN agent on the dict-observation wrapper.
discount_factor = 0.9
learning_rate = 1e-2
num_episodes = 1

# pov_shape is hard-coded to a 7x7 arena with 12 tile channels.
agent = DictObsCommRLAgent(num_actions=nb_possible_actions, pov_shape=[7,7,12], previous_message_length=1)
optimizer = torch.optim.SGD(agent.parameters(), lr=learning_rate)
print(agent)

for ep in range(num_episodes):
  # The wrapper's reset() creates a new game and returns a dict observation.
  obs = wrapped_env.reset(options=options)
  
  done = False
  log_prob_actions = []
  rewards = []
  episode_reward = 0

  while not done:
    # numpy dict -> batched float tensors (pov moved to channels-first).
    obs = agent.get_formatted_inputs(obs)
    action_pred = agent(obs)
    action_prob = F.softmax(action_pred, dim = -1)  
    # Zero out unavailable actions; Categorical renormalises the rest.
    avail_action_prob = action_prob * obs["available_actions"]
    dist = distributions.Categorical(avail_action_prob)
    action = dist.sample()
    log_prob_action = dist.log_prob(action)

    obs, reward, done, infos = wrapped_env.step(action.item())

    log_prob_actions.append(log_prob_action)
    rewards.append(reward)

    episode_reward += reward


  # Policy-gradient loss: -sum(return_t * log pi(a_t | s_t)) over the episode.
  log_prob_actions = torch.cat(log_prob_actions)
  returns = calculate_returns(rewards, discount_factor).detach()
  loss = - (returns * log_prob_actions).sum()

  optimizer.zero_grad()
  loss.backward()
  optimizer.step()

  print(f'Loss {loss} EP reward {episode_reward}')