In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from dgl.nn.pytorch import GATConv

Using backend: pytorch


## Relational Kernel

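\begin{equation}
\alpha_{i,j}^{m}=\frac{\exp\left(\tau\,(\mathbf{W}_{Q}^{m}h_{i})^{\top}\mathbf{W}_{K}^{m}h_{j}\right)}{\sum_{k\in\mathbb{B}_{+i}}\exp\left(\tau\,(\mathbf{W}_{Q}^{m}h_{i})^{\top}\mathbf{W}_{K}^{m}h_{k}\right)}
\end{equation}

where $h_i$ is the feature vector of node $i$, $\mathbf{W}_{Q}^{m}$ and $\mathbf{W}_{K}^{m}$ are the query and key projections associated with index $m$ (e.g. one attention head), $\tau$ is a temperature that scales the dot product, and $\mathbb{B}_{+i}$ is the neighbourhood of node $i$ including $i$ itself. The weights $\alpha_{i,j}^{m}$ are therefore a softmax over the neighbours of $i$.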

In [None]:
class DotGATLayer(nn.Module):
    r"""Single-head dot-product graph attention with stochastic (log-space Gaussian) weights.

    Attention logits follow the relational kernel
    ``e_ij = tau * <W_Q h_i, W_K h_j>`` and are normalised with a softmax over the
    in-neighbours of each destination node ``i``. The kernel sums over neighbours
    *including* ``i``, so add self-loops to the graph beforehand if a node should
    attend to itself.

    In training mode the log attention weights are perturbed with Gaussian noise of
    std ``sigma`` and re-normalised. A KL term between that posterior and a learned
    contextual prior (computed from source-node input features) is stored in
    ``self.KL_backward`` for the caller to add to the loss. In eval mode the
    deterministic attention is used and the KL is zero.

    Args:
        in_dim: input node-feature size.
        h_dim: query/key projection size.
        out_dim: output node-feature size (value projection size).
        feat_drop: dropout probability applied to input features before projection.
        attn_drop: dropout probability applied to the normalised attention weights.
        sigma: posterior noise std in log space.
        sigma_0: prior std in log space.
        se_dim: hidden size of the prior network; defaults to ``h_dim``.
        tau: logit temperature; defaults to ``1 / sqrt(h_dim)``.
    """

    def __init__(self, in_dim, h_dim, out_dim, feat_drop=0., attn_drop=0.,
                 sigma=1., sigma_0=1., se_dim=None, tau=None):
        super(DotGATLayer, self).__init__()
        self.q_fc = nn.Linear(in_dim, h_dim)
        self.k_fc = nn.Linear(in_dim, h_dim)
        # Values are what gets aggregated, so they are projected directly to out_dim.
        self.v_fc = nn.Linear(in_dim, out_dim)
        self.tau = float(h_dim) ** -0.5 if tau is None else float(tau)

        self.feat_drop = nn.Dropout(feat_drop)
        self.attn_drop = nn.Dropout(attn_drop)

        if se_dim is None:
            se_dim = h_dim
        # Prior network: maps a source node's raw input features to one prior logit.
        self.se_fc1 = nn.Linear(in_dim, se_dim)
        self.se_fc2 = nn.Linear(se_dim, 1)
        self.se_act = nn.ReLU()

        # Buffers (not Parameters): they move with .to(device) but are not trained.
        self.register_buffer('sigma', torch.tensor(float(sigma), dtype=torch.float32))
        self.register_buffer('sigma_0', torch.tensor(float(sigma_0), dtype=torch.float32))
        # Latest KL regulariser; overwritten by every forward() call.
        self.KL_backward = 0.

    def edge_attention(self, edges):
        """Compute the per-edge attention logit ``e`` and prior logit ``p``.

        Both outputs keep a trailing singleton dim (keepdim / a 1-unit Linear), so
        for 2-D node features they are (num_edges, 1).
        """
        # Destination node i queries, source node j provides the key.
        e = self.tau * (edges.dst['q'] * edges.src['k']).sum(dim=-1, keepdim=True)
        p = self.se_fc2(self.se_act(self.se_fc1(edges.src['x'])))
        return {'e': e, 'p': p}

    def message_func(self, edges):
        """Send the source value plus both edge logits to the destination mailbox."""
        return {'v': edges.src['v'], 'e': edges.data['e'], 'p': edges.data['p']}

    def reduce_func(self, nodes):
        """Aggregate neighbour values with (optionally sampled) attention weights.

        Mailbox tensors are indexed (node, neighbour, ...); every normalisation
        below is over dim=1, i.e. over the neighbours of each node.

        Returns:
            dict with ``'h'`` (aggregated values) and ``'kl'`` (per-node mean KL).
        """
        alpha = F.softmax(nodes.mailbox['e'], dim=1)
        logprobs = torch.log(alpha + 1e-20)
        prior = F.softmax(nodes.mailbox['p'], dim=1)
        mean_prior = torch.log(prior + 1e-20)
        if self.training:
            # Shift the mean so that E[exp(mean_posterior + sigma * eps)] == alpha.
            mean_posterior = logprobs - self.sigma ** 2 / 2
            noise = self.sigma * torch.randn_like(logprobs)
            # Normalise over neighbours (dim=1); the trailing dim has size 1, so a
            # softmax over dim=-1 would turn every weight into 1.
            out_weight = F.softmax(mean_posterior + noise, dim=1)
            # KL( N(mean_posterior, sigma^2) || N(mean_prior, sigma_0^2) ).
            KL = torch.log(self.sigma_0 / self.sigma + 1e-20) + (
                self.sigma ** 2 + (mean_posterior - mean_prior) ** 2) / (2 * self.sigma_0 ** 2) - 0.5
        else:
            out_weight = alpha
            KL = torch.zeros_like(out_weight)
        out_weight = self.attn_drop(out_weight)
        h = torch.sum(out_weight * nodes.mailbox['v'], dim=1)
        return {'h': h, 'kl': KL.mean(dim=1)}

    def forward(self, g, h):
        """Run one attention layer over DGL graph ``g`` with node features ``h``.

        Intermediate features are written inside ``g.local_scope()`` so the caller's
        graph is left unmodified. Sets ``self.KL_backward`` as a side effect.

        Returns:
            Aggregated node features; the last dim is ``out_dim``.
        """
        with g.local_scope():
            h_in = self.feat_drop(h)
            g.ndata['x'] = h  # raw features feed the prior network
            g.ndata['q'] = self.q_fc(h_in)
            g.ndata['k'] = self.k_fc(h_in)
            g.ndata['v'] = self.v_fc(h_in)
            g.apply_edges(self.edge_attention)
            g.update_all(self.message_func, self.reduce_func)
            self.KL_backward = g.ndata['kl'].mean()
            return g.ndata['h']