In [2]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

from dgl.nn.pytorch import GATConv

## Observation Encoder

\begin{equation}
h_{i}=\mathrm{MLP}(o_{i})
\end{equation}

In [7]:
class ObsEncoder(nn.Module):
    """Two-layer ReLU MLP that embeds an observation: ``h_i = MLP(o_i)``.

    Args:
        in_dim: Size of the last dimension of the input observation.
        o_dim: Size of the output embedding.
        h_dim: Width of the hidden layer.
    """

    def __init__(self, in_dim, o_dim=128, h_dim=128):
        super(ObsEncoder, self).__init__()
        # Fixed: the two arguments were separated by "." instead of ",",
        # which evaluated ``in_dim.h_dim`` and raised AttributeError.
        self.fc1 = nn.Linear(in_dim, h_dim)
        self.fc2 = nn.Linear(h_dim, o_dim)

    def forward(self, o):
        """Encode observations.

        Args:
            o: Tensor whose last dimension is ``in_dim``.

        Returns:
            Tensor with last dimension ``o_dim``; values are non-negative
            because the final layer is followed by ReLU.
        """
        hidden = F.relu(self.fc1(o))
        return F.relu(self.fc2(hidden))

## Relational Kernel

\begin{equation}
\alpha_{i,j}^{m}=\frac{\exp(\tau\cdot \mathbf{W}_{Q}^{m}h_{i}\cdot(\mathbf{W}_{K}^{m}h_{j})^\top)}{\sum_{k\in\mathbb{B}_{+i}}\exp(\tau\cdot\mathbf{W}_{Q}^{m}h_{i}\cdot(\mathbf{W}_{K}^{m}h_{k})^{\top})}
\end{equation}

\begin{equation}
h_{i}^{'}=\sigma\left( \mathrm{concat}_{m\in M}\left[ \sum_{j\in\mathbb{B}_{+i}}\alpha_{i,j}^{m}\mathbf{W}_{V}^{m}h_{j} \right] \right)
\end{equation}

In [9]:
class DotGATLayer(nn.Module):
    """Single-head scaled dot-product graph attention layer.

    For every edge ``j -> i`` the score is ``tau * (W_Q h_i) . (W_K h_j)``.
    Scores are softmax-normalised over the incoming edges of each node and
    used to weight the value projections ``W_V h_j``.

    Args:
        in_dim: Size of the input node features.
        out_dim: Size of the query/key/value projections and of the output.
    """

    def __init__(self, in_dim, out_dim):
        super(DotGATLayer, self).__init__()
        self.fc_q = nn.Linear(in_dim, out_dim)
        self.fc_k = nn.Linear(in_dim, out_dim)
        self.fc_v = nn.Linear(in_dim, out_dim)
        # Fixed: this read an undefined ``h_dim`` (NameError). The scale uses
        # the projection width, i.e. the length of the dotted vectors.
        self.tau = 1 / math.sqrt(out_dim)

    def edge_attention(self, edges):
        """Compute one unnormalised attention score per edge.

        Returns:
            ``{'e': scores}`` where ``scores`` has shape ``(num_edges, 1)``.
        """
        k = self.fc_k(edges.src['z'])
        q = self.fc_q(edges.dst['z'])
        # Fixed: ``matmul(q, k^T)`` on edge batches produced a
        # (num_edges, num_edges) matrix pairing every edge with every other
        # edge. Each edge needs only the dot product of its own q and k; the
        # trailing size-1 axis lets the weights broadcast over value features.
        a = (q * k).sum(dim=-1, keepdim=True) * self.tau
        return {'e': a}

    def message_func(self, edges):
        """Send the source node features and the edge score to the destination."""
        return {'z': edges.src['z'], 'e': edges.data['e']}

    def reduce_func(self, nodes):
        """Aggregate incoming messages into attention-weighted value sums.

        Returns:
            ``{'h': h, 'alpha': alpha}`` where ``alpha`` is the softmax of the
            incoming scores over mailbox dim 1 and ``h`` is the alpha-weighted
            sum of the value projections over that same dim.
        """
        s = nodes.mailbox['e']
        alpha = F.softmax(s, dim=1)
        v = self.fc_v(nodes.mailbox['z'])
        h = torch.sum(alpha * v, dim=1)
        # NOTE(review): ``alpha`` has one entry per incoming edge, so storing it
        # as a node feature presumably requires every node to have the same
        # in-degree -- confirm against the graphs this is used with.
        return {'h': h, 'alpha': alpha}

    def forward(self, g, h):
        """Run one round of attention message passing on graph ``g``.

        Args:
            g: Graph object providing ``ndata``, ``apply_edges`` and
                ``update_all``.
            h: Node feature tensor with last dimension ``in_dim``.

        Returns:
            Tuple ``(h, alpha)``: the updated node features and the attention
            weights produced by ``reduce_func``.
        """
        # Fixed: this assigned an undefined name ``z`` instead of the input.
        g.ndata['z'] = h
        g.apply_edges(self.edge_attention)
        g.update_all(self.message_func, self.reduce_func)
        h = g.ndata.pop('h')
        alpha = g.ndata.pop('alpha')
        return h, alpha

class MultiHeadDotGATLayer(nn.Module):
    """Runs several ``DotGATLayer`` heads in parallel and merges their outputs.

    Args:
        in_dim: Size of the input node features.
        out_dim: Size of the merged output node features.
        num_heads: Number of attention heads.
        merge: ``'cat'`` concatenates the head outputs along dim 1 (each head
            then has width ``out_dim // num_heads``); any other value averages
            the head outputs (each head then has width ``out_dim``).
    """

    def __init__(self, in_dim, out_dim,
                 num_heads, merge='cat'):
        super(MultiHeadDotGATLayer, self).__init__()
        if merge == 'cat':
            h_dim = out_dim // num_heads
            assert (h_dim * num_heads) == out_dim, \
                "out_dim must be divisible by num_heads when merge='cat'"
        else:
            # Averaging keeps the feature width, so each head emits out_dim.
            h_dim = out_dim
        # Fixed: heads were built with ``out_dim`` even though ``h_dim`` had
        # been computed, so concatenation produced num_heads * out_dim features
        # instead of out_dim.
        self.heads = nn.ModuleList(
            DotGATLayer(in_dim, h_dim) for _ in range(num_heads)
        )
        self.merge = merge

    def forward(self, g, h):
        """Apply every head to ``(g, h)`` and merge the results.

        Returns:
            Tuple ``(h, alpha)``. ``h`` is the merged node features; ``alpha``
            is the per-head attention weights concatenated along dim 1.
        """
        hs, alphas = zip(*(head(g, h) for head in self.heads))
        # NOTE(review): dim 1 is the axis DotGATLayer.reduce_func applies its
        # softmax over, so the heads end up stacked along that same axis --
        # confirm this is the layout downstream code expects.
        alpha = torch.cat(alphas, dim=1)
        if self.merge == 'cat':
            # Fixed: this referenced an undefined ``head_outs`` (NameError).
            return torch.cat(hs, dim=1), alpha
        # Fixed: the old code took a scalar mean over all elements and returned
        # no attention weights. Average across heads only and return the same
        # (h, alpha) pair as the 'cat' branch.
        return torch.stack(hs).mean(dim=0), alpha

## DGN Agent

\begin{equation}
Q(o_{i}, \cdot)=\mathrm{Linear}\left(\mathrm{concat}\left[ h_{i}, h_{i}^{'}, h_{i}^{''} \right]\right)
\end{equation}

In [None]:
class DGNAgent(nn.Module):
    """DGN agent: observation encoder, two relational kernels, linear Q head.

    Implements ``Q(o_i, .) = Linear(concat[h_i, h_i', h_i''])`` where ``h_i``
    is the encoder output and ``h_i'``, ``h_i''`` are the outputs of the first
    and second attention layers.

    Args:
        in_dim: Size of each agent's observation vector.
        act_dim: Number of actions (size of the Q-value output).
        h_dim: Width of the encoder output and of each attention layer output.
        num_heads: Number of attention heads in each attention layer.
    """

    def __init__(self, in_dim, act_dim,
                 h_dim=128, num_heads=8):
        super(DGNAgent, self).__init__()
        # The second positional argument of ObsEncoder is its output width.
        self.encoder = ObsEncoder(in_dim, h_dim)
        self.conv1 = MultiHeadDotGATLayer(h_dim, h_dim, num_heads)
        self.conv2 = MultiHeadDotGATLayer(h_dim, h_dim, num_heads)
        self.fc_out = nn.Linear(3*h_dim, act_dim)

    def forward(self, x, g=None):
        """Compute Q-values for every agent (node).

        Fixed: the method previously had no body, so the cell did not parse.

        Args:
            x: Observation tensor with last dimension ``in_dim``, one row per
                node of ``g``.
            g: Graph passed to both attention layers. It is keyword-optional
                only to keep the original ``forward(x)`` signature callable;
                a graph is required.

        Returns:
            Tensor of Q-values with last dimension ``act_dim``.

        Raises:
            ValueError: If ``g`` is not provided.
        """
        if g is None:
            raise ValueError("DGNAgent.forward requires the agent graph `g`")
        h0 = self.encoder(x)
        # The attention layers return (features, attention weights); only the
        # features feed the Q head.
        h1, _ = self.conv1(g, h0)
        # sigma from the relational-kernel equation; ReLU is assumed here to
        # match the encoder -- TODO confirm the intended nonlinearity.
        h1 = F.relu(h1)
        h2, _ = self.conv2(g, h1)
        h2 = F.relu(h2)
        return self.fc_out(torch.cat([h0, h1, h2], dim=-1))

In [18]:
def dummy():
    """Return the constant pair ``(1, 2)``."""
    first, second = 1, 2
    return first, second

# Transpose ten (first, second) pairs into one list per tuple position.
x, y = (list(column) for column in zip(*(dummy() for _ in range(10))))
y

[2, 2, 2, 2, 2, 2, 2, 2, 2, 2]