In [6]:
import os
from typing import Tuple

import dgl
import dgl.function as fn
import torch
import torch.nn as nn
import torch.nn.functional as F
from ogb.graphproppred import DglGraphPropPredDataset

In [2]:
# Directory where the dataset is stored. Set the OGB_DATASET_ROOT environment
# variable to point somewhere else without editing this cell; the fallback is
# the path this notebook was originally run with.
DATASET_ROOT = os.environ.get('OGB_DATASET_ROOT', '/home/ksadowski/datasets')

dataset = DglGraphPropPredDataset(root=DATASET_ROOT, name='ogbg-molhiv')

In [4]:
# Take the graph of the first dataset entry (the second element of the entry
# is not used here).
g = dataset[0][0]

# Cast node and edge features to float32 so they can be fed to nn.Linear.
g.ndata['feat'] = g.ndata['feat'].to(torch.float32)
g.edata['feat'] = g.edata['feat'].to(torch.float32)

# One scalar weight per edge, produced by a freshly initialised (untrained)
# linear layer and normalised with a softmax over all edges (dim=0).
# The input width is read from the data instead of being hard-coded.
edge_feat_dim = g.edata['feat'].shape[-1]
edge_scorer = nn.Linear(edge_feat_dim, 1)
g.edata['weight'] = F.softmax(edge_scorer(g.edata['feat']), dim=0)

print(g.edata['weight'])

print(g.ndata['feat'])

# Message passing: every edge carries (source node 'feat' * edge 'weight'),
# and each node stores the sum of its incoming messages in 'projection'.
g.update_all(
    message_func=fn.u_mul_e('feat', 'weight', 'message'),
    reduce_func=fn.sum('message', 'projection'),
)

print(g.ndata['projection'])

tensor([[0.0243],
        [0.0243],
        [0.0243],
        [0.0243],
        [0.0278],
        [0.0278],
        [0.0243],
        [0.0243],
        [0.0243],
        [0.0243],
        [0.0278],
        [0.0278],
        [0.0243],
        [0.0243],
        [0.0243],
        [0.0243],
        [0.0243],
        [0.0243],
        [0.0243],
        [0.0243],
        [0.0278],
        [0.0278],
        [0.0243],
        [0.0243],
        [0.0243],
        [0.0243],
        [0.0243],
        [0.0243],
        [0.0243],
        [0.0243],
        [0.0243],
        [0.0243],
        [0.0243],
        [0.0243],
        [0.0278],
        [0.0278],
        [0.0243],
        [0.0243],
        [0.0243],
        [0.0243]], grad_fn=<SoftmaxBackward>)
tensor([[ 5.,  0.,  4.,  5.,  3.,  0.,  2.,  0.,  0.],
        [ 5.,  0.,  4.,  5.,  2.,  0.,  2.,  0.,  0.],
        [ 5.,  0.,  3.,  5.,  0.,  0.,  1.,  0.,  1.],
        [ 7.,  0.,  2.,  6.,  0.,  0.,  1.,  0.,  1.],
        [28.,  0.,  4.,  2.,  0.

In [56]:
import math
from typing import Tuple

import dgl
import dgl.function as fn
import torch
import torch.nn as nn
import torch.nn.functional as F


class LinearBlock(nn.Module):
    """Linear layer optionally followed by a normalization and an activation.

    Args:
        in_feats: Size of the last dimension of the input.
        out_feats: Size of the last dimension of the output.
        normalization: ``'batch'`` (``nn.BatchNorm1d``), ``'layer'``
            (``nn.LayerNorm``) or ``None`` for no normalization.
        activation: ``'relu'``, ``'leaky_relu'``, ``'sigmoid'`` or ``None``
            for no activation.

    Raises:
        ValueError: If ``normalization`` or ``activation`` is a name that is
            not listed above. Previously such a name was silently accepted and
            the block failed later, inside ``forward``, with an
            ``AttributeError``.
    """

    def __init__(
        self,
        in_feats,
        out_feats,
        normalization: str = None,
        activation: str = None,
    ):
        super().__init__()
        self._linear = nn.Linear(in_feats, out_feats)
        self._normalization = self._build_normalization(normalization, out_feats)
        self._activation = self._build_activation(activation)

    @staticmethod
    def _build_normalization(name, num_feats):
        """Return the normalization module for ``name`` (``None`` -> ``None``)."""
        if name is None:
            return None
        if name == 'batch':
            return nn.BatchNorm1d(num_feats)
        if name == 'layer':
            return nn.LayerNorm(num_feats)
        raise ValueError(
            "Unsupported normalization {!r}; expected 'batch', 'layer' "
            "or None.".format(name)
        )

    @staticmethod
    def _build_activation(name):
        """Return the activation module for ``name`` (``None`` -> ``None``)."""
        if name is None:
            return None
        if name == 'relu':
            return nn.ReLU()
        if name == 'leaky_relu':
            return nn.LeakyReLU()
        if name == 'sigmoid':
            return nn.Sigmoid()
        raise ValueError(
            "Unsupported activation {!r}; expected 'relu', 'leaky_relu', "
            "'sigmoid' or None.".format(name)
        )

    def forward(self, inputs: torch.Tensor):
        """Apply linear -> normalization (if any) -> activation (if any)."""
        x = self._linear(inputs)

        if self._normalization is not None:
            x = self._normalization(x)

        if self._activation is not None:
            x = self._activation(x)

        return x


class MessageProjection(nn.Module):
    """Project node/edge features per head and run one message-passing step.

    Node features are projected to ``num_heads`` vectors of size
    ``node_in_feats``; edge features are projected to one scalar weight per
    head. Each edge then combines the source node's projection with the edge
    weight (``message_func``) and every node aggregates its incoming messages
    (``reduce_func``).

    Args:
        node_in_feats: Size of the node feature vectors in ``g.ndata['feat']``.
        edge_in_feats: Size of the edge feature vectors in ``g.edata['feat']``.
        num_heads: Number of heads.
        message_func: ``'add'``, ``'sub'``, ``'mul'`` or ``'div'``.
        reduce_func: ``'sum'`` or ``'mean'``.
        node_activation: Activation name passed to the node ``LinearBlock``.
        edge_activation: Activation name passed to the edge ``LinearBlock``.

    Raises:
        ValueError: If ``message_func`` or ``reduce_func`` is not one of the
            supported names. Previously an unknown name was accepted and the
            module failed later, inside ``forward``, with an
            ``AttributeError``.
    """

    # Supported option name -> name of the dgl.function built-in it selects.
    _MESSAGE_BUILTINS = {
        'add': 'u_add_e',
        'sub': 'u_sub_e',
        'mul': 'u_mul_e',
        'div': 'u_div_e',
    }
    _REDUCE_BUILTINS = {
        'sum': 'sum',
        'mean': 'mean',
    }

    def __init__(
        self,
        node_in_feats: int,
        edge_in_feats: int,
        num_heads: int,
        message_func: str,
        reduce_func: str,
        node_activation: str = None,
        edge_activation: str = None,
    ):
        super().__init__()

        # Validate the option names before allocating any parameters.
        if message_func not in self._MESSAGE_BUILTINS:
            raise ValueError(
                'Unsupported message_func {!r}; expected one of {}.'.format(
                    message_func, sorted(self._MESSAGE_BUILTINS))
            )
        if reduce_func not in self._REDUCE_BUILTINS:
            raise ValueError(
                'Unsupported reduce_func {!r}; expected one of {}.'.format(
                    reduce_func, sorted(self._REDUCE_BUILTINS))
            )

        self._node_in_feats = node_in_feats
        self._num_heads = num_heads
        self._node_linear = LinearBlock(
            node_in_feats,
            node_in_feats * num_heads,
            activation=node_activation,
        )
        self._edge_linear = LinearBlock(
            edge_in_feats,
            num_heads,
            activation=edge_activation,
        )

        # Message: source node 'projection' combined with edge 'weight'.
        message_builtin = getattr(fn, self._MESSAGE_BUILTINS[message_func])
        self._message_func = message_builtin('projection', 'weight', 'message')

        # Reduce: aggregate incoming 'message' back into node 'projection'.
        reduce_builtin = getattr(fn, self._REDUCE_BUILTINS[reduce_func])
        self._reduce_func = reduce_builtin('message', 'projection')

    def forward(self, g: 'dgl.DGLGraph') -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the aggregated node projection and the per-head edge weights.

        Args:
            g: Graph with ``g.ndata['feat']`` and ``g.edata['feat']``.

        Returns:
            A tuple ``(node_projection, edge_projection)``. The node tensor is
            the result of the message-passing step; the edge tensor is the
            projected weight, reshaped to ``(-1, num_heads, 1)``.
        """
        # local_scope keeps the temporary 'projection'/'weight' fields from
        # leaking into (or overwriting fields already stored on) the graph.
        with g.local_scope():
            g.ndata['projection'] = self._node_linear(g.ndata['feat']).view(
                -1, self._num_heads, self._node_in_feats)

            # Trailing singleton dimension lets the weight broadcast over the
            # node feature dimension inside the message function.
            g.edata['weight'] = self._edge_linear(g.edata['feat']).view(
                -1, self._num_heads, 1)

            g.update_all(
                message_func=self._message_func,
                reduce_func=self._reduce_func,
            )

            node_projection = g.ndata['projection']
            edge_projection = g.edata['weight']

        return node_projection, edge_projection


class LinearProjection(nn.Module):
    """Per-head linear projection of node and edge features.

    Unlike a message-passing projection, no information is exchanged between
    nodes and edges: each feature vector is simply mapped to ``num_heads``
    vectors of its original size.

    Args:
        node_in_feats: Size of the node feature vectors in ``g.ndata['feat']``.
        edge_in_feats: Size of the edge feature vectors in ``g.edata['feat']``.
        num_heads: Number of heads.
        activation: Activation name passed to both ``LinearBlock`` instances.
    """

    def __init__(
        self,
        node_in_feats: int,
        edge_in_feats: int,
        num_heads: int,
        activation: str = None,
    ):
        super().__init__()
        self._num_heads = num_heads
        self._node_in_feats = node_in_feats
        self._edge_in_feats = edge_in_feats

        # Each block widens its input by a factor of num_heads.
        self._node_linear = LinearBlock(
            node_in_feats, num_heads * node_in_feats, activation=activation)
        self._edge_linear = LinearBlock(
            edge_in_feats, num_heads * edge_in_feats, activation=activation)

    def forward(self, g: dgl.DGLGraph) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(node_projection, edge_projection)``.

        The node tensor is reshaped to ``(-1, num_heads, node_in_feats)`` and
        the edge tensor to ``(-1, num_heads, edge_in_feats)``.
        """
        heads = self._num_heads

        nodes_flat = self._node_linear(g.ndata['feat'])
        nodes_per_head = nodes_flat.view(-1, heads, self._node_in_feats)

        edges_flat = self._edge_linear(g.edata['feat'])
        edges_per_head = edges_flat.view(-1, heads, self._edge_in_feats)

        return nodes_per_head, edges_per_head


class MultiMutualAttentionHead(nn.Module):
    """Multi-head attention computed separately for nodes and for edges.

    Queries and keys come from two independent ``MessageProjection`` modules;
    values come from a ``LinearProjection``. For each node (and each edge) the
    attention matrix is ``query @ key^T`` over the second-to-last axis of the
    projected tensors, i.e. the head axis.

    Args:
        node_in_feats: Size of the node feature vectors.
        edge_in_feats: Size of the edge feature vectors.
        num_heads: Number of heads.
        head_pooling_func: ``'sum'`` or ``'mean'`` to pool over the head axis.
            Any other value leaves the per-head embeddings unpooled.
        message_func: Passed to both ``MessageProjection`` modules.
        reduce_func: Passed to both ``MessageProjection`` modules.
        message_projection_node_activation: Node activation of the query/key
            projections.
        message_projection_edge_activation: Edge activation of the query/key
            projections.
        linear_projection_activation: Activation of the value projection.
    """

    def __init__(
        self,
        node_in_feats: int,
        edge_in_feats: int,
        num_heads: int,
        head_pooling_func: str,
        message_func: str,
        reduce_func: str,
        message_projection_node_activation: str = None,
        message_projection_edge_activation: str = None,
        linear_projection_activation: str = None,
    ):
        super().__init__()
        self._node_scale_const = math.sqrt(node_in_feats)
        # NOTE(review): MessageProjection returns edge tensors whose last
        # dimension is 1, so sqrt(edge_in_feats) may not be the intended
        # scale for the edge scores -- confirm.
        self._edge_scale_const = math.sqrt(edge_in_feats)
        self._head_pooling_func = head_pooling_func
        self._query_linear = MessageProjection(
            node_in_feats,
            edge_in_feats,
            num_heads,
            message_func,
            reduce_func,
            message_projection_node_activation,
            message_projection_edge_activation,
        )
        self._key_linear = MessageProjection(
            node_in_feats,
            edge_in_feats,
            num_heads,
            message_func,
            reduce_func,
            message_projection_node_activation,
            message_projection_edge_activation,
        )
        self._value_linear = LinearProjection(
            node_in_feats,
            edge_in_feats,
            num_heads,
            linear_projection_activation,
        )

    def _calculate_attention_score(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        scale_const: float,
    ):
        """Return softmax-normalised scaled dot-product attention weights.

        The scaled logits are clamped to ``[-5, 5]`` *before* the softmax for
        numerical stability. The previous implementation applied ``exp`` first
        and clamped afterwards, which made the lower bound unreachable and
        exponentiated the scores twice (once explicitly, once in the softmax).

        Args:
            query: Tensor of shape ``(..., H, d)``.
            key: Tensor of shape ``(..., H, d)``.
            scale_const: Divisor applied to the raw dot products.

        Returns:
            Tensor of shape ``(..., H, H)`` whose last dimension sums to 1.
        """
        logits = (query @ torch.transpose(key, -1, -2)) / scale_const
        logits = logits.clamp(-5, 5)

        return F.softmax(logits, dim=-1)

    def forward(self, g: 'dgl.DGLGraph') -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute attention-weighted node and edge embeddings for ``g``.

        Returns:
            A tuple ``(node_embedding, edge_embedding)``. With
            ``head_pooling_func`` set to ``'sum'`` or ``'mean'`` the head axis
            (dim ``-2``) is reduced; otherwise it is kept.
        """
        node_query, edge_query = self._query_linear(g)
        node_key, edge_key = self._key_linear(g)
        node_value, edge_value = self._value_linear(g)

        node_attention = self._calculate_attention_score(
            node_query, node_key, self._node_scale_const)
        edge_attention = self._calculate_attention_score(
            edge_query, edge_key, self._edge_scale_const)

        node_embedding = node_attention @ node_value
        edge_embedding = edge_attention @ edge_value

        if self._head_pooling_func == 'sum':
            node_embedding = node_embedding.sum(-2)
            edge_embedding = edge_embedding.sum(-2)
        elif self._head_pooling_func == 'mean':
            node_embedding = node_embedding.mean(-2)
            edge_embedding = edge_embedding.mean(-2)

        return node_embedding, edge_embedding



# Build one attention head for the example graph `g` and run it once.
attention_head = MultiMutualAttentionHead(
    node_in_feats=9,
    edge_in_feats=3,
    num_heads=4,
    head_pooling_func='sum',
    message_func='mul',
    reduce_func='sum',
    message_projection_node_activation='relu',
    message_projection_edge_activation='sigmoid',
    linear_projection_activation='relu',
)
node_emb, edge_emb = attention_head(g)

edge_emb.shape



torch.Size([19, 9])