In [None]:
pip install torch

Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-curand-cu12==10.3.5.147 (from torch)
  Downloading nvidia_curand_cu12-10.3.5

In [1]:
import torch
import torchvision.models as models

#config.py

In [2]:
############################################################
####################    environment     ####################
############################################################

obs_radius = 4
reward_fn = dict(move=-0.075,
                stay_on_goal=0,
                stay_off_goal=-0.075,
                collision=-0.5,
                finish=3)

obs_shape = (6, 2*obs_radius+1, 2*obs_radius+1)
action_dim = 5

############################################################
####################         DQN        ####################
############################################################

# basic training setting
num_actors = 16
log_interval = 10
training_steps = 150000
save_interval = 1000
gamma = 0.99
batch_size = 128
learning_starts = 50000
target_network_update_freq = 1750
save_path='./saved_models'
max_episode_length = 256
buffer_capacity = 262144
chunk_capacity = 64
burn_in_steps = 20

actor_update_steps = 200

# gradient norm clipping
grad_norm_dqn=40

# n-step forward
forward_steps = 2

# prioritized replay
prioritized_replay_alpha=0.6
prioritized_replay_beta=0.4

# curriculum learning
init_env_settings = (1, 10)
max_num_agents = 16
max_map_lenght = 40
pass_rate = 0.9

# dqn network setting
cnn_channel = 128
hidden_dim = 256

# same as DHC if set to false
selective_comm = True
# only works if selective_comm set to false
max_comm_agents = 3

# curriculum learning
cl_history_size = 100

test_seed = 0
num_test_cases = 200
test_env_settings = (
                    (40, 4, 0.3), (40, 8, 0.3),
                    (40, 16, 0.3),
                    (40, 32, 0.3),
                    (40, 64, 0.3),
                    (80, 4, 0.3), (80, 8, 0.3), (80, 16, 0.3), (80, 32, 0.3), (80, 64, 0.3),
                    ) # map length, number of agents, density

#environment.py

In [3]:
import random
from typing import List, Union
import numpy as np
#import config

ACTION_LIST = np.array([[-1, 0], [1, 0], [0, -1], [0, 1], [0, 0]], dtype=np.int32)

class Environment:
    def __init__(self, num_agents: int = init_env_settings[0], map_length: int = init_env_settings[1],
                obs_radius: int = obs_radius, reward_fn: dict = reward_fn, fix_density=None,
                curriculum=False, init_env_settings_set=init_env_settings):

        self.curriculum = curriculum
        if curriculum:
            self.env_set = [init_env_settings_set]
            self.num_agents = init_env_settings_set[0]
            self.map_size = (init_env_settings_set[1], init_env_settings_set[1])
        else:
            self.num_agents = num_agents
            self.map_size = (map_length, map_length)

        # set as same as in PRIMAL
        if fix_density is None:
            self.fix_density = False
            self.obstacle_density = np.random.triangular(0, 0.33, 0.5)
        else:
            self.fix_density = True
            self.obstacle_density = fix_density

        self.map = np.random.choice(2, self.map_size, p=[1-self.obstacle_density, self.obstacle_density]).astype(int)

        partition_list = self._map_partition()

        while len(partition_list) == 0:
            self.map = np.random.choice(2, self.map_size, p=[1-self.obstacle_density, self.obstacle_density]).astype(int)
            partition_list = self._map_partition()

        self.agents_pos = np.empty((self.num_agents, 2), dtype=np.int32)
        self.goals_pos = np.empty((self.num_agents, 2), dtype=np.int32)

        pos_num = sum([len(partition) for partition in partition_list])

        # loop to assign agent original position and goal position for each agent
        for i in range(self.num_agents):

            pos_idx = random.randint(0, pos_num-1)
            partition_idx = 0
            for partition in partition_list:
                if pos_idx >= len(partition):
                    pos_idx -= len(partition)
                    partition_idx += 1
                else:
                    break

            pos = random.choice(partition_list[partition_idx])
            partition_list[partition_idx].remove(pos)
            self.agents_pos[i] = np.asarray(pos, dtype=np.int32)

            pos = random.choice(partition_list[partition_idx])
            partition_list[partition_idx].remove(pos)
            self.goals_pos[i] = np.asarray(pos, dtype=np.int32)

            partition_list = [partition for partition in partition_list if len(partition) >= 2]
            pos_num = sum([len(partition) for partition in partition_list])

        self.obs_radius = obs_radius

        self.reward_fn = reward_fn
        self._get_heuri_map()
        self.steps = 0

        self.last_actions = np.zeros((self.num_agents, 5), dtype=bool)


    def update_env_settings_set(self, new_env_settings_set):
        self.env_set = new_env_settings_set

    def reset(self, num_agents=None, map_length=None):

        if self.curriculum:
            rand = random.choice(self.env_set)
            self.num_agents = rand[0]
            self.map_size = (rand[1], rand[1])

        elif num_agents is not None and map_length is not None:
            self.num_agents = num_agents
            self.map_size = (map_length, map_length)

        if not self.fix_density:
            self.obstacle_density = np.random.triangular(0, 0.33, 0.5)

        self.map = np.random.choice(2, self.map_size, p=[1-self.obstacle_density, self.obstacle_density]).astype(np.float32)

        partition_list = self._map_partition()

        while len(partition_list) == 0:
            self.map = np.random.choice(2, self.map_size, p=[1-self.obstacle_density, self.obstacle_density]).astype(np.float32)
            partition_list = self._map_partition()

        self.agents_pos = np.empty((self.num_agents, 2), dtype=np.int)
        self.goals_pos = np.empty((self.num_agents, 2), dtype=np.int)

        pos_num = sum([len(partition) for partition in partition_list])

        for i in range(self.num_agents):

            pos_idx = random.randint(0, pos_num-1)
            partition_idx = 0
            for partition in partition_list:
                if pos_idx >= len(partition):
                    pos_idx -= len(partition)
                    partition_idx += 1
                else:
                    break

            pos = random.choice(partition_list[partition_idx])
            partition_list[partition_idx].remove(pos)
            self.agents_pos[i] = np.asarray(pos, dtype=np.int)

            pos = random.choice(partition_list[partition_idx])
            partition_list[partition_idx].remove(pos)
            self.goals_pos[i] = np.asarray(pos, dtype=np.int)

            partition_list = [partition for partition in partition_list if len(partition) >= 2]
            pos_num = sum([len(partition) for partition in partition_list])

        self.steps = 0
        self._get_heuri_map()

        self.last_actions = np.zeros((self.num_agents, 5), dtype=np.bool)

        return self.observe()

    def load(self, map:np.ndarray, agents_pos:np.ndarray, goals_pos:np.ndarray):

        self.map = np.copy(map)
        self.agents_pos = np.copy(agents_pos)
        self.goals_pos = np.copy(goals_pos)

        self.num_agents = agents_pos.shape[0]
        self.map_size = (self.map.shape[0], self.map.shape[1])

        self.steps = 0

        self._get_heuri_map()

        self.last_actions = np.zeros((self.num_agents, 5), dtype=bool)

    def _get_heuri_map(self):
        dist_map = np.ones((self.num_agents, *self.map_size), dtype=np.int32) * np.iinfo(np.int32).max

        empty_pos = np.argwhere(self.map==0).tolist()
        empty_pos = set([tuple(pos) for pos in empty_pos])

        for i in range(self.num_agents):
            open_list = set()
            x, y = tuple(self.goals_pos[i])
            open_list.add((x, y))
            dist_map[i, x, y] = 0

            while open_list:
                x, y = open_list.pop()
                dist = dist_map[i, x, y]

                up = x-1, y
                if up in empty_pos and dist_map[i, x-1, y] > dist+1:
                    dist_map[i, x-1, y] = dist+1
                    open_list.add(up)

                down = x+1, y
                if down in empty_pos and dist_map[i, x+1, y] > dist+1:
                    dist_map[i, x+1, y] = dist+1
                    open_list.add(down)

                left = x, y-1
                if left in empty_pos and dist_map[i, x, y-1] > dist+1:
                    dist_map[i, x, y-1] = dist+1
                    open_list.add(left)

                right = x, y+1
                if right in empty_pos and dist_map[i, x, y+1] > dist+1:
                    dist_map[i, x, y+1] = dist+1
                    open_list.add(right)

        self.heuri_map = np.zeros((self.num_agents, 4, *self.map_size), dtype=bool)
        #AAAself.heuri_map = np.zeros((self.num_agents, 4, *self.map_size), dtype=np.bool)

        for x, y in empty_pos:
            for i in range(self.num_agents):

                if x > 0 and dist_map[i, x-1, y] < dist_map[i, x, y]:
                    self.heuri_map[i, 0, x, y] = 1

                if x < self.map_size[0]-1 and dist_map[i, x+1, y] < dist_map[i, x, y]:
                    self.heuri_map[i, 1, x, y] = 1

                if y > 0 and dist_map[i, x, y-1] < dist_map[i, x, y]:
                    self.heuri_map[i, 2, x, y] = 1

                if y < self.map_size[1]-1 and dist_map[i, x, y+1] < dist_map[i, x, y]:
                    self.heuri_map[i, 3, x, y] = 1

        self.heuri_map = np.pad(self.heuri_map, ((0, 0), (0, 0), (self.obs_radius, self.obs_radius), (self.obs_radius, self.obs_radius)))

    def _map_partition(self):
        '''
        partitioning map into independent partitions
        '''
        empty_list = np.argwhere(self.map==0).tolist()

        empty_pos = set([tuple(pos) for pos in empty_list])

        if not empty_pos:
            raise RuntimeError('no empty position')

        partition_list = list()
        while empty_pos:

            start_pos = empty_pos.pop()

            open_list = list()
            open_list.append(start_pos)
            close_list = list()

            while open_list:
                x, y = open_list.pop(0)

                up = x-1, y
                if up in empty_pos:
                    empty_pos.remove(up)
                    open_list.append(up)

                down = x+1, y
                if down in empty_pos:
                    empty_pos.remove(down)
                    open_list.append(down)

                left = x, y-1
                if left in empty_pos:
                    empty_pos.remove(left)
                    open_list.append(left)

                right = x, y+1
                if right in empty_pos:
                    empty_pos.remove(right)
                    open_list.append(right)

                close_list.append((x, y))

            if len(close_list) >= 2:
                partition_list.append(close_list)

        return partition_list

    def step(self, actions: List[int]):
        '''
        actions:
            list of indices
                0 up
                1 down
                2 left
                3 right
                4 stay
        '''

        assert len(actions) == self.num_agents, 'only {} actions as input while {} agents in environment'.format(len(actions), self.num_agents)
        assert all([action_idx<5 and action_idx>=0 for action_idx in actions]), 'action index out of range'

        checking_list = [i for i in range(self.num_agents)]

        rewards = []
        next_pos = np.copy(self.agents_pos)

        # remove unmoving agent id
        for agent_id in checking_list.copy():
            if actions[agent_id] == 4:
                # unmoving
                if np.array_equal(self.agents_pos[agent_id], self.goals_pos[agent_id]):
                    rewards.append(self.reward_fn['stay_on_goal'])
                else:
                    rewards.append(self.reward_fn['stay_off_goal'])
                checking_list.remove(agent_id)
            else:
                # move
                next_pos[agent_id] += ACTION_LIST[actions[agent_id]]
                rewards.append(self.reward_fn['move'])

        # first round check, these two conflicts have the highest priority
        for agent_id in checking_list.copy():

            if np.any(next_pos[agent_id]<0) or np.any(next_pos[agent_id]>=self.map_size[0]):
                # agent out of map range
                rewards[agent_id] = self.reward_fn['collision']
                next_pos[agent_id] = self.agents_pos[agent_id]
                checking_list.remove(agent_id)

            elif self.map[tuple(next_pos[agent_id])] == 1:
                # collide obstacle
                rewards[agent_id] = self.reward_fn['collision']
                next_pos[agent_id] = self.agents_pos[agent_id]
                checking_list.remove(agent_id)

        # second round check, agent swapping conflict
        no_conflict = False
        while not no_conflict:

            no_conflict = True
            for agent_id in checking_list:

                target_agent_id = np.where(np.all(next_pos[agent_id]==self.agents_pos, axis=1))[0]

                #if target_agent_id:
                if target_agent_id is not None and len(target_agent_id) > 0:

                    target_agent_id = target_agent_id.item()

                    if np.array_equal(next_pos[target_agent_id], self.agents_pos[agent_id]):
                        assert target_agent_id in checking_list, 'target_agent_id should be in checking list'

                        next_pos[agent_id] = self.agents_pos[agent_id]
                        rewards[agent_id] = self.reward_fn['collision']

                        next_pos[target_agent_id] = self.agents_pos[target_agent_id]
                        rewards[target_agent_id] = self.reward_fn['collision']

                        checking_list.remove(agent_id)
                        checking_list.remove(target_agent_id)

                        no_conflict = False
                        break

        # third round check, agent collision conflict
        no_conflict = False
        while not no_conflict:
            no_conflict = True
            for agent_id in checking_list:

                collide_agent_id = np.where(np.all(next_pos==next_pos[agent_id], axis=1))[0].tolist()
                if len(collide_agent_id) > 1:
                    # collide agent

                    # if all agents in collide agent are in checking list
                    all_in_checking = True
                    for id in collide_agent_id.copy():
                        if id not in checking_list:
                            all_in_checking = False
                            collide_agent_id.remove(id)

                    if all_in_checking:

                        collide_agent_pos = next_pos[collide_agent_id].tolist()
                        for pos, id in zip(collide_agent_pos, collide_agent_id):
                            pos.append(id)
                        collide_agent_pos.sort(key=lambda x: x[0]*self.map_size[0]+x[1])

                        collide_agent_id.remove(collide_agent_pos[0][2])

                    next_pos[collide_agent_id] = self.agents_pos[collide_agent_id]
                    for id in collide_agent_id:
                        rewards[id] = self.reward_fn['collision']

                    for id in collide_agent_id:
                        checking_list.remove(id)

                    no_conflict = False
                    break

        self.agents_pos = np.copy(next_pos)

        self.steps += 1

        # check done
        if np.array_equal(self.agents_pos, self.goals_pos):
            done = True
            rewards = [self.reward_fn['finish'] for _ in range(self.num_agents)]
        else:
            done = False

        info = {'step': self.steps-1}

        # make sure no overlapping agents
        assert np.unique(self.agents_pos, axis=0).shape[0] == self.num_agents

        # update last actions
        self.last_actions = np.zeros((self.num_agents, 5), dtype=bool)
        #AAAself.last_actions = np.zeros((self.num_agents, 5), dtype=np.bool)
        self.last_actions[np.arange(self.num_agents), np.array(actions)] = 1

        return self.observe(), rewards, done, info


    def observe(self):
        '''
        return observation and position for each agent

        obs: shape (num_agents, 6, 2*obs_radius+1, 2*obs_radius+1)
            layer 1: agent map
            layer 2: obstacle map
            layer 3-6: heuristic map

        last_act: agents' last step action

        pos: current position of each agent, used for caculating communication mask

        '''
        obs = np.zeros((self.num_agents, 6, 2*self.obs_radius+1, 2*self.obs_radius+1), dtype=bool)

        obstacle_map = np.pad(self.map, self.obs_radius, 'constant', constant_values=0)

        agent_map = np.zeros((self.map_size), dtype=bool)
        agent_map[self.agents_pos[:,0], self.agents_pos[:,1]] = 1
        agent_map = np.pad(agent_map, self.obs_radius, 'constant', constant_values=0)

        for i, agent_pos in enumerate(self.agents_pos):
            x, y = agent_pos

            obs[i, 0] = agent_map[x:x+2*self.obs_radius+1, y:y+2*self.obs_radius+1]
            obs[i, 0, self.obs_radius, self.obs_radius] = 0
            obs[i, 1] = obstacle_map[x:x+2*self.obs_radius+1, y:y+2*self.obs_radius+1]
            obs[i, 2:] = self.heuri_map[i, :, x:x+2*self.obs_radius+1, y:y+2*self.obs_radius+1]

        return obs, np.copy(self.last_actions), np.copy(self.agents_pos)

#buffer.py

In [4]:
import math
from dataclasses import dataclass
import numpy as np

@dataclass
class EpisodeData:
    __slots__ = ('actor_id', 'num_agents', 'map_len', 'obs', 'last_act', 'actions', 'rewards',
                'hiddens', 'relative_pos', 'comm_mask', 'gammas', 'td_errors', 'sizes', 'done')
    actor_id: int
    num_agents: int
    map_len: int
    obs: np.ndarray
    last_act: np.ndarray
    actions: np.ndarray
    rewards: np.ndarray
    hiddens: np.ndarray
    relative_pos: np.ndarray
    comm_mask: np.ndarray
    gammas: np.ndarray
    td_errors: np.ndarray
    sizes: np.ndarray
    done: bool


class SumTree:
    '''used for prioritized experience replay'''
    def __init__(self, capacity: int):
        layer = 1
        while 2**(layer-1) < capacity:
            layer += 1
        assert 2**(layer-1) == capacity, 'capacity only allow n**2 size'
        self.layer = layer
        self.tree = np.zeros(2**layer-1, dtype=np.float64)
        self.capacity = capacity
        self.size = 0

    def sum(self):
        assert np.sum(self.tree[-self.capacity:])-self.tree[0] < 0.1, 'sum is {} but root is {}'.format(np.sum(self.tree[-self.capacity:]), self.tree[0])
        return self.tree[0]

    def __getitem__(self, idx: int):
        assert 0 <= idx < self.capacity

        return self.tree[self.capacity-1+idx]

    def batch_sample(self, batch_size: int):
        p_sum = self.tree[0]
        interval = p_sum/batch_size

        prefixsums = np.arange(0, p_sum, interval, dtype=np.float64) + np.random.uniform(0, interval, batch_size)

        idxes = np.zeros(batch_size, dtype=np.int)
        for _ in range(self.layer-1):
            nodes = self.tree[idxes*2+1]
            idxes = np.where(prefixsums<nodes, idxes*2+1, idxes*2+2)
            prefixsums = np.where(idxes%2==0, prefixsums-self.tree[idxes-1], prefixsums)

        priorities = self.tree[idxes]
        idxes -= self.capacity-1

        assert np.all(priorities>0), 'idx: {}, priority: {}'.format(idxes, priorities)
        assert np.all(idxes>=0) and np.all(idxes<self.capacity)

        return idxes, priorities

    def batch_update(self, idxes: np.ndarray, priorities: np.ndarray):
        assert idxes.shape[0] == priorities.shape[0]
        idxes += self.capacity-1
        self.tree[idxes] = priorities

        for _ in range(self.layer-1):
            idxes = (idxes-1) // 2
            idxes = np.unique(idxes)
            self.tree[idxes] = self.tree[2*idxes+1] + self.tree[2*idxes+2]

        # check
        assert np.sum(self.tree[-self.capacity:])-self.tree[0] < 0.1, 'sum is {} but root is {}'.format(np.sum(self.tree[-self.capacity:]), self.tree[0])


class LocalBuffer:
    __slots__ = ('actor_id', 'map_len', 'num_agents', 'obs_buf', 'act_buf', 'rew_buf', 'hidden_buf', 'forward_steps',
                'relative_pos_buf', 'q_buf', 'capacity', 'size', 'done', 'burn_in_steps', 'chunk_capacity', 'last_act_buf', 'comm_mask_buf')
    def __init__(self, actor_id: int, num_agents: int, map_len: int, init_obs: np.ndarray, forward_steps=forward_steps,
                capacity: int = max_episode_length, burn_in_steps=burn_in_steps,
                obs_shape=obs_shape, hidden_dim=hidden_dim, action_dim=action_dim):
        """
        buffer for each episode
        """
        self.actor_id = actor_id
        self.num_agents = num_agents
        self.map_len = map_len

        self.burn_in_steps = burn_in_steps
        self.forward_steps = forward_steps

        self.chunk_capacity = chunk_capacity

        self.obs_buf = np.zeros((burn_in_steps+capacity+1, num_agents, *obs_shape), dtype=np.bool)
        self.last_act_buf = np.zeros((burn_in_steps+capacity+1, num_agents, 5), dtype=np.bool)
        self.act_buf = np.zeros((capacity), dtype=np.uint8)
        self.rew_buf = np.zeros((capacity+forward_steps-1), dtype=np.float16)
        self.hidden_buf = np.zeros((burn_in_steps+capacity+1, num_agents, hidden_dim), dtype=np.float16)
        self.relative_pos_buf = np.zeros((burn_in_steps+capacity+1, num_agents, num_agents, 2), dtype=np.int8)
        self.comm_mask_buf = np.zeros((burn_in_steps+capacity+1, num_agents, num_agents), dtype=np.bool)
        self.q_buf = np.zeros((capacity+1, action_dim), dtype=np.float32)

        self.capacity = capacity
        self.size = 0

        self.obs_buf[:burn_in_steps+1] = init_obs

    def add(self, q_val, action: int, last_act, reward: float, next_obs, hidden, relative_pos, comm_mask):
        assert self.size < self.capacity

        self.act_buf[self.size] = action
        self.rew_buf[self.size] = reward
        self.obs_buf[self.burn_in_steps+self.size+1] = next_obs
        self.last_act_buf[self.burn_in_steps+self.size+1] = last_act
        self.q_buf[self.size] = q_val
        self.hidden_buf[self.burn_in_steps+self.size+1] = hidden
        self.relative_pos_buf[self.burn_in_steps+self.size] = relative_pos
        self.comm_mask_buf[self.burn_in_steps+self.size] = comm_mask

        self.size += 1

    def finish(self, last_q_val=None, last_relative_pos=None, last_comm_mask=None):
        forward_steps = min(self.size, self.forward_steps)
        cumulated_gamma = [gamma**forward_steps for _ in range(self.size-forward_steps)]

        # last q value is None if done
        if last_q_val is None:
            done = True
            cumulated_gamma.extend([0 for _ in range(forward_steps)])

        else:
            done = False
            self.q_buf[self.size] = last_q_val
            self.relative_pos_buf[self.burn_in_steps+self.size] = last_relative_pos
            self.comm_mask_buf[self.burn_in_steps+self.size] = last_comm_mask
            cumulated_gamma.extend([gamma**i for i in reversed(range(1, forward_steps+1))])


        num_chunks = math.ceil(self.size/chunk_capacity)

        cumulated_gamma = np.array(cumulated_gamma, dtype=np.float16)
        self.obs_buf = self.obs_buf[:self.burn_in_steps+self.size+1]
        self.last_act_buf = self.last_act_buf[:self.burn_in_steps+self.size+1]
        self.act_buf = self.act_buf[:self.size]
        self.rew_buf = self.rew_buf[:self.size+self.forward_steps-1]
        self.hidden_buf = self.hidden_buf[:self.size]
        self.relative_pos_buf = self.relative_pos_buf[:self.burn_in_steps+self.size+1]
        self.comm_mask_buf = self.comm_mask_buf[:self.burn_in_steps+self.size+1]

        self.rew_buf = np.convolve(self.rew_buf,
                                [gamma**(self.forward_steps-1-i) for i in range(self.forward_steps)],
                                'valid').astype(np.float16)

        # caculate td errors for prioritized experience replay

        max_q = np.max(self.q_buf[forward_steps:self.size+1], axis=1)
        max_q = np.concatenate((max_q, np.array([max_q[-1] for _ in range(forward_steps-1)])))

        target_q = self.q_buf[np.arange(self.size), self.act_buf]
        td_errors = np.zeros(num_chunks*self.chunk_capacity, dtype=np.float32)
        td_errors[:self.size] = np.abs(self.rew_buf+max_q*cumulated_gamma-target_q).clip(1e-6)
        sizes = np.array([min(self.chunk_capacity, self.size-i*self.chunk_capacity) for i in range(num_chunks)], dtype=np.uint8)

        data = EpisodeData(self.actor_id, self.num_agents, self.map_len, self.obs_buf, self.last_act_buf, self.act_buf,
                    self.rew_buf, self.hidden_buf, self.relative_pos_buf, self.comm_mask_buf, cumulated_gamma, td_errors, sizes, done)

        return data

#model.py k means

In [5]:
from sklearn.cluster import KMeans

In [6]:
from typing import Dict
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast

class CommLayer(nn.Module):
    def __init__(self, input_dim=hidden_dim, message_dim=32, pos_embed_dim=16, num_heads=4):
        super().__init__()
        self.input_dim = input_dim
        self.message_dim = message_dim
        self.pos_embed_dim = pos_embed_dim
        self.num_heads = num_heads

        self.norm = nn.LayerNorm(input_dim)

        self.position_embeddings = nn.Linear((2*obs_radius+1)**2, pos_embed_dim)

        self.message_key = nn.Linear(input_dim+pos_embed_dim, message_dim * num_heads)
        self.message_value = nn.Linear(input_dim+pos_embed_dim, message_dim * num_heads)
        self.hidden_query = nn.Linear(input_dim, message_dim * num_heads)

        self.head_agg = nn.Linear(message_dim * num_heads, message_dim * num_heads)

        self.update = nn.GRUCell(num_heads*message_dim, input_dim)

    def position_embed(self, relative_pos, dtype, device):

        batch_size, num_agents, _, _ = relative_pos.size()
        # mask out out of FOV agent
        relative_pos[(relative_pos.abs() > obs_radius).any(3)] = 0

        one_hot_position = torch.zeros((batch_size*num_agents*num_agents, 9*9), dtype=dtype, device=device)
        relative_pos += obs_radius
        relative_pos = relative_pos.reshape(batch_size*num_agents*num_agents, 2)
        relative_pos_idx = relative_pos[:, 0] + relative_pos[:, 1]*9

        one_hot_position[torch.arange(batch_size*num_agents*num_agents), relative_pos_idx.long()] = 1
        position_embedding = self.position_embeddings(one_hot_position)

        return position_embedding

    def forward(self, hidden, relative_pos, comm_mask):
        batch_size, num_agents, hidden_dim = hidden.size()
        attn_mask = (comm_mask==False).unsqueeze(3).unsqueeze(4)
        relative_pos = relative_pos.clone()

        position_embedding = self.position_embed(relative_pos, hidden.dtype, hidden.device)

        input = hidden

        hidden = self.norm(hidden)

        hidden_q = self.hidden_query(hidden).view(batch_size, 1, num_agents, self.num_heads, self.message_dim) # batch_size x num_agents x message_dim*num_heads

        message_input = hidden.repeat_interleave(num_agents, dim=1).view(batch_size*num_agents*num_agents, hidden_dim)
        message_input = torch.cat((message_input, position_embedding), dim=1)
        message_input = message_input.view(batch_size, num_agents, num_agents, self.input_dim+self.pos_embed_dim)
        message_k = self.message_key(message_input).view(batch_size, num_agents, num_agents, self.num_heads, self.message_dim)
        message_v = self.message_value(message_input).view(batch_size, num_agents, num_agents, self.num_heads, self.message_dim)

        # attention
        attn_score = (hidden_q * message_k).sum(4, keepdim=True) / self.message_dim**0.5 # batch_size x num_agents x num_agents x self.num_heads x 1
        attn_score.masked_fill_(attn_mask, torch.finfo(attn_score.dtype).min)
        attn_weights = F.softmax(attn_score, dim=1)

        # agg
        agg_message = (message_v * attn_weights).sum(1).view(batch_size, num_agents, self.num_heads*self.message_dim)
        agg_message = self.head_agg(agg_message)

        # update hidden with request message
        input = input.view(-1, hidden_dim)
        agg_message = agg_message.view(batch_size*num_agents, self.num_heads*self.message_dim)
        updated_hidden = self.update(agg_message, input)

        # some agents may not receive message, keep it as original
        update_mask = comm_mask.any(1).view(-1, 1)
        hidden = torch.where(update_mask, updated_hidden, input)
        hidden = hidden.view(batch_size, num_agents, hidden_dim)

        return hidden



class CommBlock(nn.Module):
    def __init__(self, hidden_dim=hidden_dim, message_dim=128, pos_embed_dim=16):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.message_dim = message_dim
        self.pos_embed_dim = pos_embed_dim

        self.request_comm = CommLayer()
        self.reply_comm = CommLayer()


    def forward(self, latent, relative_pos, comm_mask):
        '''
        latent shape: batch_size x num_agents x latent_dim
        relative_pos shape: batch_size x num_agents x num_agents x 2
        comm_mask shape: batch_size x num_agents x num_agents
        '''

        batch_size, num_agents, latent_dim = latent.size()

        assert relative_pos.size() == (batch_size, num_agents, num_agents, 2), relative_pos.size()
        assert comm_mask.size() == (batch_size, num_agents, num_agents), comm_mask.size()

        if torch.sum(comm_mask).item() == 0:
            return latent

        hidden = self.request_comm(latent, relative_pos, comm_mask)

        comm_mask = torch.transpose(comm_mask, 1, 2)

        hidden = self.reply_comm(hidden, relative_pos, comm_mask)

        return hidden

class Network(nn.Module):
    def __init__(self, input_shape=obs_shape, selective_comm=selective_comm):
        super().__init__()

        self.hidden_dim = hidden_dim
        self.latent_dim = self.hidden_dim + 5
        self.obs_shape = input_shape
        self.selective_comm = selective_comm

        self.obs_encoder = nn.Sequential(
            nn.Conv2d(input_shape[0], 64, 3, 1),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(64, 128, 3, 1),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(128, 192, 3, 1),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(192, 256, 3, 1),
            nn.LeakyReLU(0.2, True),
            nn.Flatten(),
        )

        self.recurrent = nn.GRUCell(self.latent_dim, self.hidden_dim)
        self.comm = CommBlock(self.hidden_dim)

        self.hidden = None

        # dueling q structure
        self.adv = nn.Linear(self.hidden_dim, 5)
        self.state = nn.Linear(self.hidden_dim, 1)

    @torch.no_grad()

    def step(self, obs, last_act, pos):
      num_agents = obs.size(0)
      agent_indexing = torch.arange(num_agents)
      relative_pos = pos.unsqueeze(0) - pos.unsqueeze(1)

      in_obs_mask = (relative_pos.abs() <= obs_radius).all(2)
      in_obs_mask[agent_indexing, agent_indexing] = 0

      group_size = 4
      num_groups = max(1, num_agents // group_size)

      pos_np = pos.cpu().numpy()
      kmeans = KMeans(n_clusters=num_groups, n_init='auto').fit(pos_np)
      labels = kmeans.labels_

      comm_mask = torch.zeros((num_agents, num_agents), dtype=bool)
      for i in range(num_agents):
          for j in range(num_agents):
              if i != j and labels[i] == labels[j]:
                  comm_mask[i, j] = True

      comm_mask = comm_mask & in_obs_mask

      latent = self.obs_encoder(obs)
      latent = torch.cat((latent, last_act), dim=1)

      if self.hidden is None:
          self.hidden = self.recurrent(latent)
      else:
          self.hidden = self.recurrent(latent, self.hidden)

      self.hidden = self.comm(self.hidden.unsqueeze(0), relative_pos.unsqueeze(0), comm_mask.unsqueeze(0))
      self.hidden = self.hidden.squeeze(0)

      adv_val = self.adv(self.hidden)
      state_val = self.state(self.hidden)
      q_val = state_val + adv_val - adv_val.mean(1, keepdim=True)

      actions = torch.argmax(q_val, 1).tolist()

      return actions, q_val.numpy(), self.hidden.numpy(), relative_pos.numpy(), comm_mask.numpy()


    def reset(self):
        self.hidden = None

    @autocast()
    def forward(self, obs, last_act, steps, hidden, relative_pos, comm_mask):
        '''
        used for training
        '''
        # obs shape: seq_len, batch_size, num_agents, obs_shape
        # relative_pos shape: batch_size, seq_len, num_agents, num_agents, 2
        seq_len, batch_size, num_agents, *_ = obs.size()

        obs = obs.view(seq_len*batch_size*num_agents, *self.obs_shape)
        last_act = last_act.view(seq_len*batch_size*num_agents, action_dim)

        latent = self.obs_encoder(obs)
        latent = torch.cat((latent, last_act), dim=1)
        latent = latent.view(seq_len, batch_size*num_agents, self.latent_dim)

        hidden_buffer = []
        for i in range(seq_len):
            # hidden size: batch_size*num_agents x self.hidden_dim
            hidden = self.recurrent(latent[i], hidden)
            hidden = hidden.view(batch_size, num_agents, self.hidden_dim)
            hidden = self.comm(hidden, relative_pos[:, i], comm_mask[:, i])
            # only hidden from agent 0
            hidden_buffer.append(hidden[:, 0])
            hidden = hidden.view(batch_size*num_agents, self.hidden_dim)

        # hidden buffer size: batch_size x seq_len x self.hidden_dim
        hidden_buffer = torch.stack(hidden_buffer).transpose(0, 1)

        # hidden size: batch_size x self.hidden_dim
        hidden = hidden_buffer[torch.arange(batch_size), steps-1]

        adv_val = self.adv(hidden)
        state_val = self.state(hidden)

        q_val = state_val + adv_val - adv_val.mean(1, keepdim=True)

        return q_val

  @autocast()


#test.py

In [9]:
'''create test set and test model'''
import os
import random
import pickle
from typing import Tuple, Union
import warnings
warnings.simplefilter("ignore", UserWarning)
from tqdm import tqdm
import numpy as np
import torch
import torch.multiprocessing as mp

torch.manual_seed(test_seed)
np.random.seed(test_seed)
random.seed(test_seed)
DEVICE = torch.device('cpu')
torch.set_num_threads(1)

def create_test(test_env_settings: Tuple = test_env_settings, num_test_cases: int = num_test_cases):
    '''
    create test set
    '''

    for map_length, num_agents, density in test_env_settings:

        name = f'./test_set/{map_length}length_{num_agents}agents_{density}density.pth'
        print(f'-----{map_length}length {num_agents}agents {density}density-----')

        tests = []

        env = Environment(fix_density=density, num_agents=num_agents, map_length=map_length)

        for _ in tqdm(range(num_test_cases)):
            tests.append((np.copy(env.map), np.copy(env.agents_pos), np.copy(env.goals_pos)))
            env.reset(num_agents=num_agents, map_length=map_length)
        print()

        with open(name, 'wb') as f:
            pickle.dump(tests, f)



def test_model(model_range: Union[int, tuple], test_set: Tuple = test_env_settings):
    '''
    test model in 'saved_models' folder
    '''
    network = Network()
    network.eval()
    network.to(DEVICE)
    pool = mp.Pool(mp.cpu_count()//2)
    print(model_range)

    if isinstance(model_range, int):
        state_dict = torch.load(os.path.join(save_path, f'{model_range}.pth'), map_location=DEVICE)
        network.load_state_dict(state_dict)
        network.eval()
        network.share_memory()


        print(f'----------test model {model_range}----------')

        for case in test_set:
            print(f"test set: {case[0]} length {case[1]} agents {case[2]} density")
            with open('./test_set/{}length_{}agents_{}density.pth'.format(case[0], case[1], case[2]), 'rb') as f:
                tests = pickle.load(f)

            tests = [(test, network) for test in tests]
            ret = pool.map(test_one_case, tests)

            success, steps, num_comm = zip(*ret)


            print("success rate: {:.2f}%".format(sum(success)/len(success)*100))
            print("average step: {}".format(sum(steps)/len(steps)))
            print("communication times: {}".format(sum(num_comm)/len(num_comm)))
            print()

    elif isinstance(model_range, tuple):

        for model_name in range(model_range[0], model_range[1]+1, save_interval):
            state_dict = torch.load(os.path.join(save_path, f'{model_name}.pth'), map_location=DEVICE)
            network.load_state_dict(state_dict)
            network.eval()
            network.share_memory()


            print(f'----------test model {model_name}----------')

            for case in test_set:
                print(f"test set: {case[0]} length {case[1]} agents {case[2]} density")
                with open(f'./test_set/{case[0]}length_{case[1]}agents_{case[2]}density.pth', 'rb') as f:
                    tests = pickle.load(f)

                tests = [(test, network) for test in tests]
                ret = pool.map(test_one_case, tests)


                success, steps, num_comm = zip(*ret)

                print("success rate: {:.2f}%".format(sum(success)/len(success)*100))
                print("average step: {}".format(sum(steps)/len(steps)))
                print("communication times: {}".format(sum(num_comm)/len(num_comm)))
                print()

            print('\n')

def test_one_case(args):

    env_set, network = args

    env = Environment()
    env.load(env_set[0], env_set[1], env_set[2])
    obs, last_act, pos = env.observe()

    done = False
    network.reset()

    step = 0
    num_comm = 0
    while not done and env.steps < max_episode_length:
        actions, _, _, _, comm_mask = network.step(torch.as_tensor(obs.astype(np.float32)).to(DEVICE),
                                                    torch.as_tensor(last_act.astype(np.float32)).to(DEVICE),
                                                    torch.as_tensor(pos.astype(int)))
        (obs, last_act, pos), _, done, _ = env.step(actions)
        step += 1
        num_comm += np.sum(comm_mask)

    return np.array_equal(env.agents_pos, env.goals_pos), step, num_comm




def code_test():
    env = Environment()
    network = Network()
    network.eval()
    obs, last_act, pos = env.observe()
    network.step(torch.as_tensor(obs.astype(np.float32)).to(DEVICE),
                                                    torch.as_tensor(last_act.astype(np.float32)).to(DEVICE),
                                                    torch.as_tensor(pos.astype(int)))

if __name__ == '__main__':

    # load trained model and reproduce results in paper
    test_model(128000)



128000
----------test model 128000----------
test set: 40 length 4 agents 0.3 density
success rate: 99.50%
average step: 49.54
communication times: 35.5

test set: 40 length 8 agents 0.3 density
success rate: 100.00%
average step: 58.35
communication times: 203.55

test set: 40 length 16 agents 0.3 density
success rate: 99.00%
average step: 72.4
communication times: 1017.06

test set: 40 length 32 agents 0.3 density
success rate: 80.00%
average step: 133.235
communication times: 6575.38

test set: 40 length 64 agents 0.3 density
success rate: 19.50%
average step: 242.655
communication times: 38113.65

test set: 80 length 4 agents 0.3 density
success rate: 100.00%
average step: 92.895
communication times: 18.92

test set: 80 length 8 agents 0.3 density
success rate: 99.00%
average step: 108.83
communication times: 111.66

test set: 80 length 16 agents 0.3 density
success rate: 97.50%
average step: 119.81
communication times: 463.6

test set: 80 length 32 agents 0.3 density
success rate: