This notebook implements the Jungle scenario of Neural Message Passing Reinforcement Learning (NMP-RL).
The paper has been submitted to Nature Machine Intelligence and is currently under review.

Written by Kha Vo, Chin-Teng Lin, University of Technology Sydney

For more magical stuff on RL, please visit
https://voanhkha.github.io/2019/11/29/magic_rl_p1/

Email: kha.vo@uts.edu.au
www.kaggle.com/khahuras
www.github.com/voanhkha

Gentle request: you are welcome to re-use the materials here or the article.
Please kindly cite it, simply by copying the source link or my name.
Many thanks!
Kind regards,
Kha Vo.

In [None]:
import numpy as np, pandas as pd, os, cv2
from tqdm import tqdm
import random
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from matplotlib.animation import FuncAnimation
plt.style.use('seaborn-pastel')
import time
from IPython import display
%matplotlib notebook
%matplotlib notebook

from itertools import count
from collections import namedtuple
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical

import schnetpack
from schnetpack.nn import Dense, shifted_softplus
from torch_scatter import scatter_add, scatter_mean
from torch.utils.data import Dataset, DataLoader

# Pick the compute device: CUDA when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): this line unconditionally overrides the choice above and forces
# CPU execution; presumably deliberate — confirm before removing it.
device = 'cpu'

In [None]:
class GridWorld:
    """Grid "Jungle" environment populated with predator and prey agents.

    Occupancy is stored per agent type in ``self.map[agent_type][x][y]``: a list
    of the ids of the agents of that type standing on cell (x, y), with
    ``0 <= x < width`` and ``0 <= y < height``. Only predators move; prey cells
    are recorded in ``self.food_cells`` and block predator movement.
    """

    def plot(self, path='frames/sample.png', title='title', show_plot=True, save_fig=False):
        """Render the grid as an RGB image.

        Cells occupied by the first agent type are painted red, the second green
        (later types overwrite earlier ones on shared cells).

        Returns:
            ``(width, height, 3)`` uint8 image array.
        """
        colors = {'red': [230, 32, 32], 'green': [40, 255, 40]}
        M = np.zeros((self.width, self.height, 3), dtype=np.uint8)
        global_maps = [self.get_view(None, 0, t) for t in self.agent_types]
        for occupancy, color_name in zip(global_maps, colors):
            M[occupancy != 0] = colors[color_name]
        if show_plot:
            plt.imshow(M)
            plt.axis('off')
            plt.suptitle(title, fontsize=16)
        if save_fig:
            plt.savefig(path, bbox_inches='tight')
        return M

    def pad_width(self, vector, pad_width, iaxis, kwargs):
        """``np.pad`` callback: fill both pad regions with ``kwargs['padder']`` (default 10)."""
        pad_value = kwargs.get('padder', 10)
        vector[:pad_width[0]] = pad_value
        # Guard the zero case: ``vector[-0:]`` would select (and overwrite) the whole vector.
        if pad_width[1] > 0:
            vector[-pad_width[1]:] = pad_value

    def __init__(self, width=20, height=20, agent_types=('predator', 'prey'),
                 n_agents_each_type=(4, 4), terminal_frame=20):
        """Create the grid and place all agents on distinct random cells.

        Args:
            width, height: grid size (x in [0, width), y in [0, height)).
            agent_types: agent type names; 'predator' and 'prey' have special rules.
            n_agents_each_type: number of agents per entry of ``agent_types``.
            terminal_frame: the episode is marked done once ``terminal_frame + 1``
                transitions have been applied.
        """
        self.width, self.height = width, height
        # Copy so callers' sequences are never aliased by the environment.
        self.agent_types = list(agent_types)
        self.n_agent_types = len(self.agent_types)
        self.n_agents_each_type = list(n_agents_each_type)
        self.food_cells = []
        self.agents = {}
        # Outer index runs over x (width), inner over y (height), matching how
        # map[t][x][y] is indexed throughout the class.
        self.map = {t: [[[] for _y in range(height)] for _x in range(width)]
                    for t in self.agent_types}
        self.init_agents(init_pos='random')
        self.move_offset = [[0, 1], [0, -1], [1, 0], [-1, 0], [0, 0]]
        self.done = False
        self.terminal_frame = terminal_frame
        self.count_frame = 0
        self.global_reward = 0

    def sample_random_actions(self):
        """Return one uniformly random action index per agent, grouped by agent type."""
        acts = [random.choices(range(len(self.move_offset)), k=N)
                for N in self.n_agents_each_type]
        return acts

    def init_agents(self, init_pos):
        """Create all agents and register them on the map.

        With ``init_pos == 'random'`` every agent gets a distinct random cell;
        otherwise ``init_pos`` must be a sequence of [x, y] positions, one per
        agent, in agent-type order. Prey positions are also added to ``food_cells``.
        """
        if init_pos == 'random':
            all_cells = [[x, y] for x in range(self.width) for y in range(self.height)]
            init_pos = random.sample(all_cells, int(np.sum(self.n_agents_each_type)))
        cnt = 0
        for j, t in enumerate(self.agent_types):
            self.agents[t] = []
            for i in range(self.n_agents_each_type[j]):
                _p = np.array(init_pos[cnt])
                new_agent = Agent(agent_type=t, pos=_p, identity=i)
                self.agents[t].append(new_agent)
                self.map[t][_p[0]][_p[1]].append(i)
                if t == 'prey':
                    self.food_cells.append([_p[0], _p[1]])
                cnt += 1

    def get_view(self, pos, vr, agent_type):
        """Return per-cell agent counts of ``agent_type`` around ``pos``.

        Args:
            pos: [x, y] centre of the view (ignored when ``vr == 0``).
            vr: view range. ``vr == 0`` returns the whole ``(width, height)``
                count map without padding (used for rendering); ``vr > 0``
                returns a ``(2*vr+1, 2*vr+1)`` window centred on ``pos`` where
                cells outside the grid are -1.
            agent_type: key of ``self.map`` to count.
        """
        # Count occupants directly instead of converting the ragged
        # list-of-lists map with np.array (which fails on current NumPy, or
        # silently builds a 3-D array when every cell has the same length).
        counts = np.array([[len(cell) for cell in column]
                           for column in self.map[agent_type]], dtype=int)
        if vr > 0:
            stride = vr * 2 + 1
            padded = np.pad(counts, vr, mode='constant', constant_values=-1)
            # In padded coordinates the view centred on pos starts at pos itself.
            counts = padded[pos[0]:pos[0] + stride, pos[1]:pos[1] + stride]
        return counts

    def get_neighbours(self, pos):
        """Return ids of predators on ``pos`` and its 4 orthogonal in-grid neighbours."""
        offsets_to_check = [[0, 1], [0, -1], [1, 0], [-1, 0], [0, 0]]
        neighbours = []
        for offset in offsets_to_check:
            check_pos = np.asarray(pos) + offset
            if check_pos[0] < 0 or check_pos[0] >= self.width:
                continue
            if check_pos[1] < 0 or check_pos[1] >= self.height:
                continue
            neighbours.extend(self.map['predator'][check_pos[0]][check_pos[1]])
        return neighbours

    def get_connectivity_pairs(self, mat):
        """Cluster agents from a 0/1 adjacency matrix and list intra-cluster pairs.

        Returns every ordered pair [i, j] (i != j) of agents sharing a cluster.
        NOTE(review): clusters are merged in a single pass, so two clusters that
        only become linked through a later row may stay separate. build_connectivity
        uses the full graph instead of this helper.
        """
        clusters = []
        processed = []
        for cur_idx, m in enumerate(mat):
            if cur_idx in processed:
                continue
            neis = np.where(np.array(m) == 1)[0].tolist()
            if cur_idx == 0:
                processed.append(cur_idx)
                processed.extend(neis)
                clusters.append([cur_idx] + neis)
            else:
                flag = False
                for n in neis:
                    for c, clus in enumerate(clusters):
                        if n in clus:
                            flag = True
                            clusters[c].extend([cur_idx] + neis)
                            processed.append(cur_idx)
                            processed.extend(neis)
                            break
                    if flag:
                        break
                if not flag:
                    processed.append(cur_idx)
                    processed.extend(neis)
                    clusters.append([cur_idx] + neis)

        clusters = [np.unique(clus) for clus in clusters]
        pairs = []
        for clus in clusters:
            pairs.extend([[i, j] for i in clus for j in clus if i != j])
        return pairs

    def dist(self, p0, p1):
        """Euclidean distance between two [x, y] positions."""
        return np.sqrt((p0[0] - p1[0]) ** 2 + (p0[1] - p1[1]) ** 2)

    def get_distances(self, conns):
        """Distance between the two predators of every [i, j] pair in ``conns``."""
        predators = self.agents['predator']
        return [self.dist(predators[i].pos, predators[j].pos) for i, j in conns]

    def build_connectivity(self):
        """Fully connected predator graph.

        Returns:
            (connectivities, distances): every ordered pair [i, j] with i != j over
            the predators, and the Euclidean distance for each pair.
        """
        nb_agents = self.n_agents_each_type[0]
        connectivities = [[i, j] for i in range(nb_agents) for j in range(nb_agents) if i != j]
        distances = self.get_distances(connectivities)
        return connectivities, distances

    def transition(self, actions_input='random'):
        """Advance the environment by one frame.

        Args:
            actions_input: per-type action indices (``actions_input[type_idx][agent_idx]``
                into ``self.move_offset``), or the string 'random' to sample them.

        Order of effects: predators move; each active predator gains +1 reward /
        a kill when prey is orthogonally adjacent, and loses 1 hp when another
        predator is on or orthogonally next to its cell; agents with hp <= 0 are
        deactivated with reward -1 and removed from the map.
        """
        if isinstance(actions_input, str) and actions_input == 'random':
            actions_input = self.sample_random_actions()
        self.count_frame += 1

        # 1) Move each active, movable predator.
        for j, t in enumerate(self.agent_types):
            if t != 'predator':
                continue
            for k, agent in enumerate(self.agents[t]):
                if agent.properties['movable'] is False:
                    continue
                # NOTE(review): inactive predators are skipped before their reward
                # is reset below, so a dead predator keeps reporting its last
                # current_reward (-1) on later frames — confirm this is intended.
                if agent.properties['active'] is False:
                    continue
                current_pos = agent.pos
                next_pos = current_pos + self.move_offset[actions_input[j][k]]

                if [next_pos[0], next_pos[1]] in self.food_cells:
                    # Moving onto food is blocked: stay in place.
                    next_pos[0], next_pos[1] = current_pos[0], current_pos[1]
                elif next_pos[0] >= self.width or next_pos[0] < 0:
                    next_pos[0] = current_pos[0]
                elif next_pos[1] >= self.height or next_pos[1] < 0:
                    next_pos[1] = current_pos[1]

                agent.pos = next_pos
                self.map[t][current_pos[0]][current_pos[1]].remove(agent.id)
                self.map[t][next_pos[0]][next_pos[1]].append(agent.id)

                # Reset reward before interacting with the environment.
                agent.current_reward = agent.properties['default_reward']

        # 2) Predator interactions. The hard-coded indices below address the
        #    4 orthogonal neighbours and the centre of a 3x3 view (viewrange 1).
        for j, t in enumerate(self.agent_types):
            if t != 'predator':
                continue
            for agent in self.agents[t]:
                if agent.properties['active'] is False:
                    continue

                prey_view = self.get_view(agent.pos, agent.properties['viewrange'], 'prey')
                next_to_food = bool(prey_view[0, 1] >= 1 or prey_view[1, 0] >= 1
                                    or prey_view[1, 2] >= 1 or prey_view[2, 1] >= 1)
                if next_to_food:
                    agent.properties['kills'] += 1
                    agent.current_reward += 1

                pred_view = self.get_view(agent.pos, agent.properties['viewrange'], 'predator')
                # Out-of-grid cells are -1; treat them as empty.
                orthogonal = [pred_view[r, c] if pred_view[r, c] != -1 else 0
                              for r, c in ((0, 1), (1, 0), (1, 2), (2, 1))]
                # pred_view[1, 1] counts this predator itself, hence the -1.
                surround_count = sum(orthogonal) + pred_view[1, 1] - 1
                if surround_count >= 1:
                    agent.properties['hp'] -= 1

        # 3) Deactivate agents whose hp is exhausted.
        for t in self.agent_types:
            for agent in self.agents[t]:
                if agent.properties['active'] is False:
                    continue
                if agent.properties['hp'] <= 0:
                    agent.properties['active'] = False
                    agent.current_reward = -1
                    self.map[t][agent.pos[0]][agent.pos[1]].remove(agent.id)

        # 4) The episode ends once the frame count exceeds terminal_frame.
        if self.count_frame == self.terminal_frame + 1:
            self.done = True
                
class Agent:
    """A single grid-world agent plus its per-episode bookkeeping.

    For ``agent_type`` 'predator' or 'prey', the ``properties`` argument is
    ignored and replaced by that type's fixed defaults. Any other type uses the
    given ``properties``, or ``{'movable': True, 'active': True}`` when omitted.
    ``reward_rules`` is accepted for interface compatibility but not used.
    """

    def __init__(self, agent_type='predator', reward_rules=None, pos=None, identity=-1,
                 properties=None):
        self.agent_type = agent_type
        self.id = identity
        # Build fresh dicts/lists per instance: mutable default arguments would be
        # shared (and mutated, e.g. hp) across every Agent created without them.
        if agent_type == 'predator':
            properties = {'movable': True, 'active': True,
                          'viewrange': 1, 'kills': 0,
                          'hp': 3, 'default_reward': 0}
        elif agent_type == 'prey':
            properties = {'movable': False, 'active': True,
                          'viewrange': 1, 'hp': 1, 'default_reward': 0}
        elif properties is None:
            properties = {'movable': True, 'active': True}
        self.properties = properties
        self.pos = [0, 0] if pos is None else pos
        self.action_buffer = []
        self.view = None
        self.current_reward = 0
        self.reward_buffer = []
        self.replay_buffer = []
    
    
class Policy_Gradient(nn.Module):
    """Two-layer MLP policy mapping a state vector to action probabilities.

    Args:
        input_dim: length of the input state vector.
        output_dim: number of discrete actions.
    """

    def __init__(self, input_dim=25, output_dim=5):
        super().__init__()
        hidden_units = 128
        self.affine = nn.Linear(input_dim, hidden_units)
        self.action_head = nn.Linear(hidden_units, output_dim)

    def forward(self, x):
        """Return softmax probabilities over the last dimension of the action logits."""
        hidden = torch.relu(self.affine(x))
        logits = self.action_head(hidden)
        return logits.softmax(dim=-1)


In [None]:
# TRAIN GRADIENT POLICY (REINFORCE)
# Baseline: one shared Policy_Gradient MLP is applied independently to every
# predator. Each predator's input is its flattened prey view followed by its
# flattened predator view (two 3x3 windows with viewrange 1 -> 18 features).

WIDTH, HEIGHT = 5, 5
AGENT_TYPES = ['predator', 'prey']
N_AGENTS_EACHTYPE = [4, 1]
FRAMES_PER_EPISODE = 30
EPISODES_TO_TRAIN = 12000
VERBOSE = 100

model = Policy_Gradient(input_dim = 18)
model.to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)
eps = np.finfo(np.float32).eps.item()  # only used by the disabled return normalisation below

# Main

reward_log  = []  # one entry per episode: mean over predators of the discounted return from t=0

for i_episode in range(EPISODES_TO_TRAIN):
    # reset environment and episode reward
    W = GridWorld(width=WIDTH, height=HEIGHT, agent_types=AGENT_TYPES, 
                  n_agents_each_type=N_AGENTS_EACHTYPE,
                  terminal_frame=FRAMES_PER_EPISODE)
    
    ep_reward = 0
    reward_buffer = [[] for _ in range(W.n_agents_each_type[0])] # each predator has 1 separate buffer
    action_buffer = [[] for _ in range(W.n_agents_each_type[0])] # per-predator log-probs throughout the whole episode

    while True: # transition the environment until done
        actions = [[0 for i in range(N)] for N in W.n_agents_each_type] # init actions in right shape
        
        # Get action for each predator using the model
        for agent_id, agent in enumerate(W.agents['predator']):
            # TODO: only forward action of active predator
            # NOTE(review): dead (inactive) predators still sample an action here and
            # their log-prob enters the loss, although transition() ignores their move.
            
            view1 = W.get_view(agent.pos, agent.properties['viewrange'], 'prey').ravel()
            view2 = W.get_view(agent.pos, agent.properties['viewrange'], 'predator').ravel()
            state = torch.tensor(np.hstack([view1, view2])).float().to(device)
            #print(state.shape)

            probs = model(state) # forward pass
            m = Categorical(probs)
            action_sample = m.sample()
            action_buffer[agent_id].append(m.log_prob(action_sample))
            action = action_sample.item() # this "action" can be 0,1,2,3,4 (label encoded)

            # embed action to the current agent
            actions[0][agent_id] = action
        
        # transition
        W.transition(actions)
        
        # Get reward for each predator from the new transitioned environment
        for agent_id in range(W.n_agents_each_type[0]):
            reward = W.agents['predator'][agent_id].current_reward
            reward_buffer[agent_id].append(reward)
        
        if W.done: break

            
    # Discounted returns (gamma = 0.99). Iterating predators in reverse while
    # inserting at the front yields an agent-major list: all of predator 0's
    # steps first, then predator 1's, ... — the same order as the raveled
    # action_buffer below.
    returns = []
    for agent_id in list(range(W.n_agents_each_type[0]))[::-1]: # Reverse order
        R = 0
        for r in reward_buffer[agent_id][::-1]: # Reverse order
            R = r + 0.99 * R
            returns.insert(0, R)
            
    # Stride = episode length, so this picks each predator's return from t=0.
    ep_reward = np.mean(returns[::len(reward_buffer[0])]   )
    
    returns = torch.tensor(returns)
    #returns = (returns - returns.mean()) / (returns.std() + eps) # this one is suspicious

    action_buffer = [item for sublist in action_buffer for item in sublist] # ravel (agent-major)
    policy_losses = []
    for log_prob, R in zip(action_buffer, returns):
        policy_losses.append(-log_prob * R)


    # reset gradients
    optimizer.zero_grad()

    # Plain REINFORCE: sum of policy losses only (no value/baseline term).
    loss = torch.stack(policy_losses).sum() #+ torch.stack(value_losses).sum() # ?????

    # perform backprop
    loss.backward()
    optimizer.step()

    # log results
            
    
    reward_log.append(ep_reward)
    if (i_episode+1) % VERBOSE == 0: 
        print('Episode {}\tReward: {:.2f}'.format(i_episode+1, np.mean(reward_log[-VERBOSE:])))
    

In [None]:
# Save plot
# Training curve: mean episode reward over consecutive 100-episode windows.
# A trailing partial window is dropped so the reshape cannot fail when the
# number of logged episodes is not a multiple of the window size.
CURVE_WINDOW = 100
n_complete = len(reward_log) // CURVE_WINDOW * CURVE_WINDOW
train_curve_pg = np.reshape(reward_log[:n_complete], (-1, CURVE_WINDOW)).mean(axis=1)
plt.plot(train_curve_pg)
# plt.savefig('revision_0_materials/PG_Jungle.png')

# # Save numerical result
# np.save('revision_0_materials/PG_Jungle.npy', np.array(reward_log))

# # Save model
# save_path = "revision_0_materials/PG_Jungle_model.pth"
# torch.save(model.state_dict(), save_path)

In [None]:
def _compute_stacked_offsets(sizes, repeats):
    return np.repeat(np.cumsum(np.hstack([0, sizes[:-1]])), repeats)

def _concat(to_stack):
    """ function to stack (or concatentate) depending on dimensions """
    if np.asarray(to_stack[0]).ndim >= 2:
        return np.concatenate(to_stack)
    else:
        return np.hstack(to_stack)

def rbf_expansion(distances, mu=0, delta=0.1, kmax=150):
    """Gaussian radial-basis expansion of distances.

    Centres are ``delta * k - mu`` for k = 0..kmax-1; each output entry is
    ``exp(-(d - centre)**2 / delta)``. Returns an array of shape
    ``(n_distances, kmax)``.
    """
    centres = delta * np.arange(kmax) - mu
    column = np.atleast_2d(distances).T
    return np.exp(-((column - centres) ** 2) / delta)

class SchnetWithEdgeUpdate(nn.Module):
    """SchNet-style message passing over the agent graph, with an edge-update step.

    Node ("atom") features are the per-agent state vectors; edge ("bond")
    features start as an RBF expansion of pairwise distances projected to
    ``n_atom_basis`` channels, and are recomputed from both endpoint states on
    each interaction.

    Args:
        n_atom_basis: width of the projected initial edge features.
        max_z: unused; kept for interface compatibility.
        kmax: number of RBF centres produced by ``rbf_expansion``.
        n_interactions: number of message-passing rounds.
        activation: non-linearity of the hidden Dense layers.
        n_state: width of the per-agent state vectors (18 in this notebook:
            two flattened 3x3 views).
    """

    def __init__(self, n_atom_basis=128, max_z=100, kmax=150, n_interactions=1,
                 activation=shifted_softplus, n_state=18):
        super(SchnetWithEdgeUpdate, self).__init__()

        self.n_interactions = n_interactions

        # Input = n_state (source) + n_state (destination) + n_atom_basis
        # (projected edge features) = 164 with the defaults.
        # NOTE(review): the update outputs n_state channels, so a second
        # interaction only has a matching input width when n_atom_basis == n_state.
        self.edge_update_net = nn.Sequential(
            Dense(2 * n_state + n_atom_basis, 32, activation=activation),
            Dense(32, n_state))

        self.msg_edge_net = nn.Sequential(
            Dense(n_state, n_state, activation=activation),
            Dense(n_state, n_state, activation=activation))

        self.msg_atom_fc = Dense(n_state, n_state)

        self.state_trans_net = nn.Sequential(
            Dense(n_state, n_state, activation=activation),
            Dense(n_state, n_state))

        self.init_edge_fc = Dense(kmax, n_atom_basis, activation=activation)

    def forward(self, inputs):
        """Run the message-passing rounds.

        Args:
            inputs: dict with 'states' (node-feature tensor, one row per agent),
                'connectivity' (array of [src, dst] index pairs, one per edge) and
                'distance' (one distance per edge, aligned with 'connectivity').

        Returns:
            (x_atom, x_bond): updated node features and the last edge features.
        """
        x_atom = inputs['states']
        device = x_atom.device
        n_nodes = x_atom.size(0)

        rbf = rbf_expansion(inputs['distance']).astype(np.float32)
        # Keep all derived tensors on the same device as the node states.
        x_bond = self.init_edge_fc(torch.from_numpy(rbf).to(device))

        connectivity = np.asarray(inputs['connectivity'])
        src_idx = torch.as_tensor(connectivity[:, 0], dtype=torch.long, device=device)
        dst_idx = torch.as_tensor(connectivity[:, 1], dtype=torch.long, device=device)

        for _ in range(self.n_interactions):
            # Edge update from both endpoint states and the current edge features.
            x_src_atom = x_atom[src_idx]
            x_dst_atom = x_atom[dst_idx]
            x_bond = self.edge_update_net(torch.cat((x_src_atom, x_dst_atom, x_bond), dim=1))

            # Message = f(edge) * g(source state), summed per destination node.
            bond_msg = self.msg_edge_net(x_bond)
            src_atom_msg = self.msg_atom_fc(x_src_atom)
            # dim_size keeps one row per node even if trailing nodes receive no
            # edge; otherwise the sum below would mis-broadcast or fail.
            messages = scatter_add(torch.mul(bond_msg, src_atom_msg), dst_idx,
                                   dim=0, dim_size=n_nodes)

            # State transition (residual update).
            x_atom = x_atom + self.state_trans_net(messages)

        return x_atom, x_bond

    
class Net(nn.Module):
    """Message-passing policy: optional graph block followed by a per-agent MLP head.

    ``forward`` returns one probability vector per agent (softmax over the last
    dimension, 5 actions).
    """

    def __init__(self, schnet):
        super(Net, self).__init__()
        self.schnet = schnet
        self.agentwise = nn.Sequential(schnetpack.nn.blocks.MLP(
            n_in=18, n_out=5, n_layers=2, activation=shifted_softplus))
        # Not used by forward(); kept so existing attribute access keeps working.
        self.softmax = torch.nn.Softmax(dim=-1)

    def forward(self, inputs):
        """Return per-agent action probabilities for the batch dict ``inputs``."""
        connectivity = inputs['connectivity']
        # len() instead of ``!= []``: connectivity is usually a NumPy array, and
        # comparing an (n, 2) array against [] is not a valid emptiness test.
        if connectivity is not None and len(connectivity) > 0:
            x_atom, _ = self.schnet(inputs)
        else:
            x_atom = inputs['states']

        logits = self.agentwise(x_atom)
        return F.softmax(logits, dim=-1)

In [None]:
# TRAIN NMP-PG
# REINFORCE training of the message-passing policy (Net + SchnetWithEdgeUpdate).
# At every step all predators are encoded together as one fully connected graph
# and one action per predator is sampled from the resulting probabilities.

WIDTH, HEIGHT = 5, 5
AGENT_TYPES = ['predator', 'prey']
N_AGENTS_EACHTYPE = [4, 1]
FRAMES_PER_EPISODE = 30
EPISODES_TO_TRAIN = 12000
VERBOSE = 100
GAMMA = 0.99  # discount factor for the per-predator returns

model = Net(SchnetWithEdgeUpdate())
model.to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)
eps = np.finfo(np.float32).eps.item()  # only used by the disabled return normalisation below

# Main

reward_log = []

for i_episode in range(EPISODES_TO_TRAIN):
    # Fresh environment for every episode.
    W = GridWorld(width=WIDTH, height=HEIGHT, agent_types=AGENT_TYPES,
                  n_agents_each_type=N_AGENTS_EACHTYPE,
                  terminal_frame=FRAMES_PER_EPISODE)
    n_predators = W.n_agents_each_type[0]

    reward_buffer = [[] for _ in range(n_predators)]  # per-predator rewards over the episode
    action_buffer = []  # one log-prob tensor (one entry per predator) per time step

    while True:  # step the environment until it reports done
        actions = [[0 for _ in range(N)] for N in W.n_agents_each_type]  # one slot per agent of each type

        # Per-predator state: flattened prey view followed by flattened predator view.
        # TODO: only forward action of active predator
        states = []
        for agent in W.agents['predator']:
            view1 = W.get_view(agent.pos, agent.properties['viewrange'], 'prey').ravel()
            view2 = W.get_view(agent.pos, agent.properties['viewrange'], 'predator').ravel()
            states.append(np.hstack([view1, view2]))
        states = torch.tensor(np.array(states)).float().to(device)

        # Fully connected predator graph and pairwise distances.
        connectivity, distance = W.build_connectivity()
        batch_data = {'states': states,
                      'connectivity': np.array(connectivity),
                      'distance': np.array(distance)}

        probs = model(batch_data)  # one probability row per predator
        m = Categorical(probs)
        action_sample = m.sample()  # sampled actions of all predators at this time step
        action_buffer.append(m.log_prob(action_sample))

        actions[0] = action_sample.detach().cpu().numpy()  # index 0 = predator type
        W.transition(actions)

        # Rewards produced by the transition just applied.
        for agent_id, agent in enumerate(W.agents['predator']):
            reward_buffer[agent_id].append(agent.current_reward)

        if W.done:
            break

    # Discounted returns, agent-major: all of predator 0's steps first, then predator 1's, ...
    returns = []
    for agent_id in reversed(range(n_predators)):
        R = 0
        for r in reversed(reward_buffer[agent_id]):
            R = r + GAMMA * R
            returns.insert(0, R)

    # Stride = episode length, so this averages each predator's return from t=0.
    ep_reward = np.mean(returns[::len(reward_buffer[0])])

    returns = torch.tensor(returns)
    #returns = (returns - returns.mean()) / (returns.std() + eps) # this one is suspicious

    # action_buffer is time-major: stack to (T, n_predators), transpose to
    # (n_predators, T) and flatten so it lines up with the agent-major returns.
    # Using torch ops keeps the autograd graph (np.transpose would try to convert
    # the grad-requiring log-prob tensors to NumPy).
    log_probs = torch.stack(action_buffer).t().reshape(-1)
    policy_losses = -log_probs * returns.to(device=log_probs.device, dtype=log_probs.dtype)

    optimizer.zero_grad()
    loss = policy_losses.sum()  # plain REINFORCE: no value/baseline term
    loss.backward()
    optimizer.step()

    # log results
    reward_log.append(ep_reward)
    if (i_episode+1) % VERBOSE == 0:
        print('Episode {}\tReward: {:.2f}'.format(i_episode+1, np.mean(reward_log[-VERBOSE:])))
    