In [2]:
import math
import random
from collections import namedtuple

import dgl
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn.pytorch import GATConv
from torch.optim import Adam

Using backend: pytorch


## Replay Buffer

In [5]:
Transition = namedtuple('Transition', ('graph', 'action', 'reward', 'next_graph', 'done'))


class ReplayBuffer(object):
    """Fixed-capacity replay memory of ``Transition`` tuples.

    Entries are stored in a ring: once ``capacity`` transitions are held,
    each new push overwrites the oldest one.
    """

    def __init__(self, capacity):
        # ``capacity`` is used as a modulus for list indexing, so it must be an
        # int; the agents' default ``buffer_size`` is the float ``2*1e5``.
        self.capacity = int(capacity)
        if self.capacity <= 0:
            raise ValueError('capacity must be a positive integer, got %r' % (capacity,))
        self.memory = []
        self.position = 0

    def push(self, *args):
        """Save a transition ``(graph, action, reward, next_graph, done)``."""
        if len(self.memory) < self.capacity:
            self.memory.append(None)
        self.memory[self.position] = Transition(*args)
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        """Draw ``batch_size`` distinct transitions uniformly and collate them.

        Returns:
            (batched graph, actions reshaped to (-1, 1), rewards flattened,
             batched next graph, done flags flattened).

        Raises:
            ValueError: if ``batch_size`` exceeds the number of stored transitions
                (raised by ``random.sample``).
        """
        samples = random.sample(self.memory, batch_size)
        graphs = [s.graph for s in samples]
        actions = [s.action for s in samples]
        rewards = [s.reward for s in samples]
        next_graphs = [s.next_graph for s in samples]
        dones = [s.done for s in samples]

        ret_graph = dgl.batch(graphs)
        ret_action = torch.stack(actions).reshape(-1, 1)
        ret_reward = torch.Tensor(rewards).reshape(-1)
        ret_next_graph = dgl.batch(next_graphs)
        ret_dones = torch.Tensor(dones).reshape(-1)

        return ret_graph, ret_action, ret_reward, ret_next_graph, ret_dones

    def __len__(self):
        return len(self.memory)

## Observation Encoder

\begin{equation}
h_{i}=\mathrm{MLP}(o_{i})
\end{equation}

In [3]:
class ObsEncoder(nn.Module):
    """Two-layer MLP mapping a raw observation o_i to an embedding h_i.

    Args:
        in_dim: size of the last axis of the observation tensor.
        o_dim: size of the output embedding.
        h_dim: width of the hidden layer.
    """
    def __init__(self, in_dim, o_dim=128, h_dim=512):
        super(ObsEncoder, self).__init__()
        self.fc1 = nn.Linear(in_dim, h_dim)
        self.fc2 = nn.Linear(h_dim, o_dim)

    def forward(self, o):
        """Apply fc1 -> ReLU -> fc2 -> ReLU; the last axis becomes ``o_dim``."""
        o = F.relu(self.fc1(o))
        o = F.relu(self.fc2(o))
        return o

## Relational Kernel

\begin{equation}
\alpha_{i,j}^{m}=\frac{\exp(\tau\cdot \mathbf{W}_{Q}^{m}h_{i}\cdot(\mathbf{W}_{K}^{m}h_{j})^\top)}{\sum_{k\in\mathbb{B}_{+i}}\exp(\tau\cdot\mathbf{W}_{Q}^{m}h_{i}\cdot(\mathbf{W}_{K}^{m}h_{k})^{\top})}
\end{equation}

\begin{equation}
h_{i}^{'}=\sigma\left( \mathrm{concat}_{m\in M}\left[ \sum_{j\in\mathbb{B}_{+i}}\alpha_{i,j}^{m}\mathbf{W}_{v}^{m}h_{j} \right] \right)
\end{equation}



In [25]:
class DotGATLayer(nn.Module):
    """One attention head of the relational kernel.

    For an edge j -> i the score is ``tau * <W_K h_j, W_Q h_i>`` with
    ``tau = 1/sqrt(out_dim)``. Scores are softmax-normalised over the incoming
    edges of i and used to weight ``W_V h_j``.
    """
    def __init__(self, in_dim, out_dim):
        super(DotGATLayer, self).__init__()
        self.fc_q = nn.Linear(in_dim, out_dim)
        self.fc_k = nn.Linear(in_dim, out_dim)
        self.fc_v = nn.Linear(in_dim, out_dim)
        self.tau = 1/math.sqrt(out_dim)

    def edge_attention(self, edges):
        """Unnormalised attention score per edge, shape (num_edges, 1)."""
        k = self.fc_k(edges.src['z'])
        q = self.fc_q(edges.dst['z'])
        a = (k * q).sum(-1, keepdim=True) * self.tau
        return {'e': a}

    def message_func(self, edges):
        return {'z': edges.src['z'], 'e': edges.data['e']}

    def reduce_func(self, nodes):
        # DGL mailboxes are laid out as (nodes_in_bucket, in_degree, *feature).
        s = nodes.mailbox['e']
        alpha = F.softmax(s, dim=1)
        v = self.fc_v(nodes.mailbox['z'])
        h = torch.sum(alpha * v, dim=1)
        # Drop only the trailing score axis. A bare squeeze() would also drop the
        # node axis for a one-node bucket (or the degree axis when degree == 1).
        return {'h': h, 'alpha': alpha.squeeze(-1)}

    def forward(self, g, z):
        """Return the aggregated features ``h`` and per-node attention weights."""
        g.ndata['z'] = z
        g.apply_edges(self.edge_attention)
        g.update_all(self.message_func, self.reduce_func)
        h = g.ndata.pop('h')
        alpha = g.ndata.pop('alpha')
        g.ndata.pop('z')  # remove the temporary input feature from the graph
        return h, alpha

class MultiHeadDotGATLayer(nn.Module):
    """Multi-head relational kernel built from ``DotGATLayer`` heads.

    Each head outputs ``out_dim // num_heads`` features. The heads are
    concatenated and passed through ReLU (the sigma in the equation). The
    returned attention is the mean of the heads' attention weights.
    """
    def __init__(self, in_dim, out_dim, num_heads):
        super(MultiHeadDotGATLayer, self).__init__()
        self.heads = nn.ModuleList()
        h_dim = out_dim // num_heads
        assert (h_dim * num_heads) == out_dim, (
            'out_dim (%d) must be divisible by num_heads (%d)' % (out_dim, num_heads))
        for _ in range(num_heads):
            self.heads.append(DotGATLayer(in_dim, h_dim))

    def forward(self, g, h):
        outputs = [head(g, h) for head in self.heads]
        hs = [out[0] for out in outputs]
        alphas = [out[1] for out in outputs]
        alpha = torch.stack(alphas).mean(0)
        h = F.relu(torch.cat(hs, dim=1))
        return h, alpha

## Relational Kernel with Bayesian Attention

In [26]:
class BayesGATLayer(nn.Module):
    """Attention head with stochastic (Bayesian) attention weights.

    Scores are ``tau * <W_K h_j, W_Q h_i>``, as in the deterministic head. A small
    key-only network (se_fc1 -> ReLU -> se_fc2) gives per-edge logits; their
    softmax defines a prior over the attention weights.

    In training mode, the attention weights are a softmax of Gaussian-perturbed
    log-weights (noise scale ``sigma``). A KL term against the prior (prior scale
    ``sigma_0``) is computed. In eval mode, the deterministic softmax is used and
    the KL is zero.

    After ``forward``, ``KL_backward`` holds the mean KL over the nodes of ``g``.
    """
    def __init__(self, in_dim, out_dim,
                 se_dim=1, sigma=1e-15, sigma_0=1e15):
        super(BayesGATLayer, self).__init__()
        self.fc_q = nn.Linear(in_dim, out_dim)
        self.fc_k = nn.Linear(in_dim, out_dim)
        self.fc_v = nn.Linear(in_dim, out_dim)
        self.tau = 1/math.sqrt(out_dim)

        self.sigma = torch.tensor(sigma).type(torch.float32)
        self.se_fc1 = nn.Linear(in_dim, se_dim)
        self.se_fc2 = nn.Linear(se_dim, 1)
        self.se_act = nn.ReLU()
        self.sigma_0 = torch.tensor(sigma_0).type(torch.float32)
        self.KL_backward = 0.

    def edge_attention(self, edges):
        """Per-edge score ``e`` and prior logit ``p``, each of shape (num_edges, 1)."""
        k = edges.src['z']
        k2 = self.se_fc2(self.se_act(self.se_fc1(k)))
        k = self.fc_k(k)
        q = self.fc_q(edges.dst['z'])
        a = (k * q).sum(-1, keepdim=True) * self.tau
        return {'e': a, 'p': k2}

    def message_func(self, edges):
        return {'z': edges.src['z'], 'e': edges.data['e'], 'p': edges.data['p']}

    def reduce_func(self, nodes):
        # DGL mailboxes are laid out as (nodes_in_bucket, in_degree, *feature).
        s = nodes.mailbox['e']
        p = F.softmax(nodes.mailbox['p'], dim=1)
        mean_prior = torch.log(p + 1e-20)
        alpha = F.softmax(s, dim=1)
        logprobs = torch.log(alpha + 1e-20)
        if self.training:
            mean_posterior = logprobs - self.sigma**2 / 2
            out_weight = F.softmax(mean_posterior + self.sigma*torch.randn_like(logprobs), dim=1)
            # KL between Gaussians on the log-weights: posterior (mean_posterior, sigma)
            # vs prior (mean_prior, sigma_0).
            KL = torch.log(self.sigma_0 / self.sigma + 1e-20) + (
                    self.sigma**2 + (mean_posterior - mean_prior)**2) / (2 * self.sigma_0**2) - 0.5
        else:
            out_weight = alpha
            KL = torch.zeros_like(out_weight)
        v = self.fc_v(nodes.mailbox['z'])
        h = torch.sum(out_weight * v, dim=1)
        # squeeze(-1) keeps the node/degree axes even for one-node or degree-1 buckets.
        return {'h': h, 'alpha': alpha.squeeze(-1), 'kl': KL.mean(dim=1)}

    def forward(self, g, z):
        """Return aggregated features and deterministic attention; record ``KL_backward``."""
        g.ndata['z'] = z
        g.apply_edges(self.edge_attention)
        g.update_all(self.message_func, self.reduce_func)
        self.KL_backward = g.ndata.pop('kl').mean()
        h = g.ndata.pop('h')
        alpha = g.ndata.pop('alpha')
        g.ndata.pop('z')  # remove the temporary input feature from the graph
        return h, alpha

class BayesMultiHeadGATLayer(nn.Module):
    """Multi-head relational kernel whose heads are ``BayesGATLayer``s.

    Each head outputs ``out_dim // num_heads`` features. The concatenation is
    passed through ReLU. After ``forward``, ``KL_backward`` holds the mean of the
    heads' ``KL_backward`` values from that pass.
    """
    def __init__(self, in_dim, out_dim, num_heads,
                 se_dim=1, sigma=1e-15, sigma_0=1e15):
        super(BayesMultiHeadGATLayer, self).__init__()
        self.heads = nn.ModuleList()
        h_dim = out_dim // num_heads
        assert (h_dim * num_heads) == out_dim, (
            'out_dim (%d) must be divisible by num_heads (%d)' % (out_dim, num_heads))
        for _ in range(num_heads):
            self.heads.append(BayesGATLayer(in_dim, h_dim,
                                            se_dim, sigma, sigma_0))
        self.KL_backward = 0.

    def forward(self, g, h):
        outputs = [head(g, h) for head in self.heads]
        hs = [out[0] for out in outputs]
        alphas = [out[1] for out in outputs]
        alpha = torch.stack(alphas).mean(0)
        # Each head sets its KL_backward during its own forward call above.
        self.KL_backward = torch.stack([head.KL_backward for head in self.heads]).mean()
        h = F.relu(torch.cat(hs, dim=1))
        return h, alpha

## DGN-R Agent

\begin{equation}
Q(o_{i}, \cdot)=\mathrm{Linear}\left(\mathrm{concat}\left[ h_{i}, h_{i}^{'}, h_{i}^{''} \right]\right)
\end{equation}

\begin{equation}
\mathcal{L}_{\mathrm{reg}}(\theta)=\frac{1}{M}\sum_{m=1}^{M}D_{\mathrm{KL}}\left( \mathcal{G}_{m}(O_{i,\mathcal{C}};\theta) || \mathcal{G}_{m}(O_{i,\mathcal{C}}';\theta) \right)
\end{equation}

In [23]:
class DGN_Conv(nn.Module):
    """Observation encoder followed by two multi-head attention layers.

    ``graph.ndata['obs']`` is encoded to ``h_dim`` features and passed through
    two relational layers. The three representations are concatenated along
    the feature axis. Non-target networks also return the last layer's attention.
    """

    def __init__(self, obs_dim, h_dim=128, num_heads=8,
                 target=False):
        super().__init__()
        # Registration order (encoder, conv1, conv2) fixes the state_dict layout
        # that is copied between the online and target networks.
        self.encoder = ObsEncoder(in_dim=obs_dim, o_dim=h_dim)
        self.conv1 = MultiHeadDotGATLayer(h_dim, h_dim, num_heads)
        self.conv2 = MultiHeadDotGATLayer(h_dim, h_dim, num_heads)
        self.target = target

    def forward(self, graph):
        """Return concatenated features, plus attention weights unless ``target``."""
        encoded = self.encoder(graph.ndata['obs'])
        first_hop, _ = self.conv1(graph, encoded)
        second_hop, attention = self.conv2(graph, first_hop)
        features = torch.cat([encoded, first_hop, second_hop], dim=1)
        return features if self.target else (features, attention)

class DGNAgent(nn.Module):
    """DGN-R agent: relational Q-network with a soft-updated target network.

    ``Q(o_i, .) = Linear(concat[h_i, h_i', h_i''])`` on top of ``DGN_Conv``.
    Training adds the temporal-relation regulariser
    ``lamb * KL(attention(s) || attention(s'))``.
    """
    def __init__(self, n_agents, obs_dim, act_dim, h_dim=128,
                 num_heads=8, gamma=0.96, batch_size=10,
                 buffer_size=2*1e5, epsilon=0.6, episilon_min=0.01,
                 decay_rate=0.996, lr=1e-4, neighbors=3,
                 lamb=0.03, beta=0.01, *args, **kwargs):
        super(DGNAgent, self).__init__()
        self.conv_net = DGN_Conv(obs_dim, h_dim, num_heads)
        self.target_conv = DGN_Conv(obs_dim, h_dim, num_heads, target=True)
        self.q_net = nn.Linear(3*h_dim, act_dim)
        self.target_q = nn.Linear(3*h_dim, act_dim)
        self.target_conv.load_state_dict(self.conv_net.state_dict())
        self.target_q.load_state_dict(self.q_net.state_dict())
        # parameters() returns generators, which cannot be joined with '+'.
        self.optimizer = Adam(list(self.conv_net.parameters()) + list(self.q_net.parameters()),
                              lr=lr)
        self.beta = beta
        self.gamma = gamma
        # The default buffer_size is a float; the buffer needs an int capacity.
        self.buffer = ReplayBuffer(int(buffer_size))
        self.batch_size = batch_size

        self.epsilon = epsilon
        # The signature spells this keyword ``episilon_min``; the correctly
        # spelled ``epsilon_min`` is also accepted through **kwargs.
        self.epsilon_min = kwargs.pop('epsilon_min', episilon_min)
        self.decay_rate = decay_rate

        self.n_agents = n_agents
        self.n_act = act_dim
        self.n_neighbor = neighbors
        self.lamb = lamb

    def get_action(self, graph):
        """Pick one action per agent epsilon-greedily, then decay epsilon."""
        if random.random() < self.epsilon:
            action = torch.randint(0, self.n_act, size=(self.n_agents,))
        else:
            with torch.no_grad():
                q_value, _ = self.get_q(graph)
            action = q_value.argmax(dim=-1)
        self.epsilon = max(self.epsilon*self.decay_rate, self.epsilon_min)

        return action

    def get_q(self, graph):
        """Online Q-values and last-layer attention weights for ``graph``."""
        z, weight = self.conv_net(graph)
        q = self.q_net(z)
        return q, weight

    def get_target(self, graph):
        """Q-values from the target networks."""
        z = self.target_conv(graph)
        q = self.target_q(z)
        return q

    def save_samples(self, g, a, r, n_g, t):
        self.buffer.push(g, a, r, n_g, t)

    def fit(self):
        """Run one optimisation step.

        Returns:
            ``(False, 0)`` while the buffer holds fewer than ``10 * batch_size``
            transitions; otherwise ``(True, loss_value)``.
        """
        if len(self.buffer) < self.batch_size*10:
            return False, 0

        state, act, reward, n_state, done = self.buffer.sample(self.batch_size)
        curr_qs, curr_weight = self.get_q(state)
        selected_qs = curr_qs.gather(1, act).reshape(-1)
        with torch.no_grad():
            next_qs = self.get_target(n_state).max(dim=1)[0]
        target = reward + self.gamma * next_qs * (1 - done)

        _, next_weight = self.get_q(n_state)
        eps = 1e-20  # avoids 0 * log(0) = nan when a weight underflows
        KL = (curr_weight * torch.log((curr_weight + eps) / (next_weight + eps))).sum(-1)

        # Regress only the Q-value of the action that was actually taken.
        loss = F.mse_loss(selected_qs, target) + self.lamb*KL.mean()
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        self.target_update()

        return True, loss.item()

    def target_update(self):
        """Polyak-average the online weights into the targets at rate ``beta``."""
        with torch.no_grad():
            for target, param in zip(self.target_conv.parameters(),
                                     self.conv_net.parameters()):
                target.mul_(1 - self.beta).add_(param, alpha=self.beta)
            # target first, online second: the target moves toward the online net.
            for target, param in zip(self.target_q.parameters(),
                                     self.q_net.parameters()):
                target.mul_(1 - self.beta).add_(param, alpha=self.beta)

## DGN-R Agent with Bayesian Attention

In [27]:
class BayesDGN_Conv(nn.Module):
    """Observation encoder followed by two Bayesian multi-head attention layers.

    The output is the concatenation of the encoder output and both layers'
    outputs. Non-target networks also return the last layer's attention.
    A target network is kept permanently in eval mode, so its attention
    heads use the deterministic softmax path.
    """
    def __init__(self, obs_dim, h_dim=128, num_heads=8,
                 se_dim=1, sigma=1e-15, sigma_0=1e15,
                 target=False):
        super(BayesDGN_Conv, self).__init__()
        self.encoder = ObsEncoder(in_dim=obs_dim, o_dim=h_dim)
        self.conv1 = BayesMultiHeadGATLayer(h_dim, h_dim, num_heads,
                                            se_dim, sigma, sigma_0)
        self.conv2 = BayesMultiHeadGATLayer(h_dim, h_dim, num_heads,
                                            se_dim, sigma, sigma_0)
        self.target = target
        if self.target:
            # Assigning self.training directly would only change this module;
            # the attention heads check their own ``training`` flag.
            self.train(False)

    def train(self, mode=True):
        """Set training mode recursively; a target network always stays in eval mode."""
        keep_eval = getattr(self, 'target', False)
        return super(BayesDGN_Conv, self).train(bool(mode) and not keep_eval)

    def forward(self, graph):
        obs = graph.ndata['obs']
        z1 = self.encoder(obs)
        z2, _ = self.conv1(graph, z1)
        z3, alpha = self.conv2(graph, z2)
        out = torch.cat([z1, z2, z3], dim=1)
        if self.target:
            return out
        return out, alpha

    def kl(self):
        """Sum of both attention layers' KL terms from the latest forward pass."""
        return self.conv1.KL_backward + self.conv2.KL_backward

class BayesDGNAgent(nn.Module):
    """DGN-R agent whose relational kernels use Bayesian attention.

    The loss is: Q-regression + ``lamb`` * temporal attention KL +
    ``sigmoid(rho * t)`` * attention-posterior KL. Here ``t`` counts completed
    ``fit`` steps.
    """
    def __init__(self, n_agents, obs_dim, act_dim, h_dim=128,
                 num_heads=8, gamma=0.96, batch_size=10,
                 buffer_size=2*1e5, epsilon=0.6, episilon_min=0.01,
                 decay_rate=0.996, lr=1e-4, neighbors=3,
                 lamb=0.03, beta=0.01, rho=0.1,
                 se_dim=1, sigma=1e-15, sigma_0=1e15,
                 *args, **kwargs):
        super(BayesDGNAgent, self).__init__()
        # Forward the Bayesian-attention hyper-parameters instead of hard-coding them.
        self.conv_net = BayesDGN_Conv(obs_dim, h_dim, num_heads,
                                      se_dim=se_dim, sigma=sigma, sigma_0=sigma_0)
        self.target_conv = BayesDGN_Conv(obs_dim, h_dim, num_heads,
                                         se_dim=se_dim, sigma=sigma, sigma_0=sigma_0,
                                         target=True)
        self.q_net = nn.Linear(3*h_dim, act_dim)
        self.target_q = nn.Linear(3*h_dim, act_dim)
        self.target_conv.load_state_dict(self.conv_net.state_dict())
        self.target_q.load_state_dict(self.q_net.state_dict())
        # parameters() returns generators, which cannot be joined with '+'.
        self.optimizer = Adam(list(self.conv_net.parameters()) + list(self.q_net.parameters()),
                              lr=lr)
        self.beta = beta
        self.gamma = gamma
        # The default buffer_size is a float; the buffer needs an int capacity.
        self.buffer = ReplayBuffer(int(buffer_size))
        self.batch_size = batch_size

        self.epsilon = epsilon
        # The signature spells this keyword ``episilon_min``; the correctly
        # spelled ``epsilon_min`` is also accepted through **kwargs.
        self.epsilon_min = kwargs.pop('epsilon_min', episilon_min)
        self.decay_rate = decay_rate

        self.n_agents = n_agents
        self.n_act = act_dim
        self.n_neighbor = neighbors
        self.lamb_temp = lamb
        self.rho = rho
        self.t = 0
        self.KL_backward = 0.

    def get_action(self, graph):
        """Pick one action per agent epsilon-greedily, then decay epsilon."""
        if random.random() < self.epsilon:
            action = torch.randint(0, self.n_act, size=(self.n_agents,))
        else:
            with torch.no_grad():
                q_value, _ = self.get_q(graph)
            action = q_value.argmax(dim=-1)
        self.epsilon = max(self.epsilon*self.decay_rate, self.epsilon_min)

        return action

    def get_q(self, graph):
        """Online Q-values and attention; stores this pass's KL in ``KL_backward``."""
        z, weight = self.conv_net(graph)
        self.KL_backward = self.conv_net.kl()
        q = self.q_net(z)
        return q, weight

    def get_target(self, graph):
        """Q-values from the target networks."""
        z = self.target_conv(graph)
        q = self.target_q(z)
        return q

    def save_samples(self, g, a, r, n_g, t):
        self.buffer.push(g, a, r, n_g, t)

    def fit(self):
        """Run one optimisation step.

        Returns:
            ``(False, 0)`` while the buffer holds fewer than ``10 * batch_size``
            transitions; otherwise ``(True, loss_value)``.
        """
        if len(self.buffer) < self.batch_size*10:
            return False, 0
        # F.sigmoid rejects Python floats; wrap the annealing argument in a tensor.
        lamb_elbo = torch.sigmoid(torch.tensor(float(self.rho * self.t))).item()
        state, act, reward, n_state, done = self.buffer.sample(self.batch_size)
        curr_qs, curr_weight = self.get_q(state)
        # Capture the posterior KL of the *current-state* pass now; the
        # next-state forward below overwrites the layers' KL_backward.
        elbo_kl = self.KL_backward
        selected_qs = curr_qs.gather(1, act).reshape(-1)
        with torch.no_grad():
            next_qs = self.get_target(n_state).max(dim=1)[0]
        target = reward + self.gamma * next_qs * (1 - done)

        _, next_weight = self.get_q(n_state)
        eps = 1e-20  # avoids 0 * log(0) = nan when a weight underflows
        temporal_kl = (curr_weight * torch.log((curr_weight + eps) / (next_weight + eps))).sum(-1)
        # Reduce to a scalar so loss.backward() is valid.
        reg = self.lamb_temp*temporal_kl.mean() + lamb_elbo*elbo_kl
        loss = F.mse_loss(selected_qs, target) + reg
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        self.target_update()
        self.t += 1

        return True, loss.item()

    def target_update(self):
        """Polyak-average the online weights into the targets at rate ``beta``."""
        with torch.no_grad():
            for target, param in zip(self.target_conv.parameters(),
                                     self.conv_net.parameters()):
                target.mul_(1 - self.beta).add_(param, alpha=self.beta)
            # target first, online second: the target moves toward the online net.
            for target, param in zip(self.target_q.parameters(),
                                     self.q_net.parameters()):
                target.mul_(1 - self.beta).add_(param, alpha=self.beta)

## Get graph from observation

In [17]:
def get_edges(feature, n_agents, n_neighbor=3):
    """Build k-nearest-neighbour edges for the agent graph.

    Each agent's position is read from the last two entries of its feature
    vector. For every agent ``i``, the ``n_neighbor + 1`` closest agents are
    selected; normally these are ``i`` itself plus ``n_neighbor`` others. The
    count is capped at ``n_agents``. An edge ``j -> i`` is emitted for each
    selected agent.

    As a result, message passing aggregates *into* ``i`` from its own
    neighbourhood B_{+i}, and every node has the same in-degree.

    Returns:
        (from_idx, to_idx): lists of source and destination node ids.
    """
    positions = [(feature[a][-2], feature[a][-1]) for a in range(n_agents)]
    k = min(n_neighbor + 1, n_agents)
    from_idx = []  # source: the neighbour sending the message
    to_idx = []    # destination: the agent whose neighbourhood it belongs to
    for i in range(n_agents):
        xi, yi = positions[i]
        # sorted() is stable, so equal distances keep ascending index order.
        order = sorted(range(n_agents),
                       key=lambda j: (positions[j][0] - xi)**2 + (positions[j][1] - yi)**2)
        for j in order[:k]:
            from_idx.append(j)
            to_idx.append(i)
    return from_idx, to_idx

def observation(view, feature, n_agents=None):
    """Build one flat observation vector per agent.

    Agent ``j``'s vector is ``(view[j][:, :, 1] - view[j][:, :, 5]).flatten()``
    followed by ``feature[j][-1]`` and ``feature[j][-2]``, in that order.

    Args:
        view: per-agent arrays indexed as ``[row, col, channel]``, with at
            least 6 channels.
        feature: per-agent feature sequences.
        n_agents: number of agents to process; defaults to ``len(feature)``.

    Returns:
        list of 1-D numpy arrays.
    """
    if n_agents is None:
        n_agents = len(feature)
    obs = []
    for j in range(n_agents):
        obs.append(np.hstack(((view[j][:, :, 1] - view[j][:, :, 5]).flatten(),
                              feature[j][-1:-3:-1])))
    return obs

def gen_graph(view, feature, n_neighbor=3):
    """Build the agent graph for one time step.

    Nodes are agents (one per entry of ``feature``). Edges come from
    ``get_edges`` (k nearest neighbours by position). Each node's ``'obs'``
    feature is its flattened observation from ``observation``.

    Args:
        view: per-agent local views.
        feature: per-agent feature vectors; the last two entries are positions.
        n_neighbor: neighbours per agent, passed to ``get_edges``.
    """
    n_agents = len(feature)
    g = dgl.DGLGraph()
    g.add_nodes(n_agents)

    from_idx, to_idx = get_edges(feature, n_agents, n_neighbor)
    g.add_edges(from_idx, to_idx)

    # observation() requires n_agents; omitting it raised TypeError.
    obs = observation(view, feature, n_agents)
    # shape (n_agents, H*W + 2), where H, W are the spatial dims of each view
    g.ndata['obs'] = torch.tensor(np.stack(obs), dtype=torch.float32)

    return g

## Test code

In [28]:
# TODO: add a smoke test (build a graph with gen_graph, run get_action and one fit step).
pass

## Need to do

1) Test whether the agents work on the **battle** environment from "Mean Field Multi-Agent Reinforcement Learning" (MFRL):
https://github.com/mlii/mfrl/blob/master/examples/battle_model/python/magent/builtin/config/battle.py

2) Train models under different group match-ups (DGN-R vs. DGN, BAM-DGN vs. DGN, ...).

3) Discuss the results and write up the findings.