# Dependencies

In [None]:
import math
from timeit import default_timer
from typing import Union, Tuple

import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F
from ogb.lsc import DglPCQM4MDataset
from ogb.utils import smiles2graph
from tqdm import trange

# Seed PyTorch's global RNG so that random draws made later in the notebook
# (e.g. the torch.randperm sub-sampling of the splits) are repeatable.
torch.manual_seed(13)

# Dataset

In [None]:
class ProcessedPCQM4M(dgl.data.DGLDataset):
    """PCQM4M molecules loaded from pre-computed DGL binary files.

    Every sample is a ``(graph, line_graph, label)`` triple. Graphs and line
    graphs are read from two separate files and paired by position; labels
    are read from the label dictionary stored alongside the graphs, keyed by
    the stringified sample position (``'0'``, ``'1'``, ...).

    Args:
        molecules_path: Path of the binary file holding the molecule graphs
            and the label dictionary.
        molecules_lg_path: Path of the binary file holding the line graphs.
    """

    def __init__(self, molecules_path: str, molecules_lg_path: str):
        self.molecules_path = molecules_path
        self.molecules_lg_path = molecules_lg_path
        # Created before calling the base constructor so the attributes already
        # exist if it invokes ``process()`` (presumably it does - see DGLDataset).
        self.graphs = []
        self.line_graphs = []
        self.labels = []
        super().__init__(name='processed_PCQM4M')

    def process(self):
        """Load graphs, line graphs and labels from disk and report the time taken."""
        start = default_timer()

        self.graphs, labels = dgl.data.utils.load_graphs(self.molecules_path)
        self.line_graphs, _ = dgl.data.utils.load_graphs(self.molecules_lg_path)

        # Rebuild (rather than extend) the list so that calling ``process()``
        # more than once cannot duplicate labels.
        self.labels = [labels[str(i)] for i in range(len(self.graphs))]

        stop = default_timer()

        print(f'Processed data loading time: {(stop - start) / 60} min.')

    def __getitem__(self, idx: Union[int, torch.Tensor]):
        """Return one sample or a subset view of the dataset.

        Args:
            idx: A Python ``int`` or a 0-dim ``torch.long`` tensor selects a
                single ``(graph, line_graph, label)`` triple; a 1-dim
                ``torch.long`` tensor selects a ``dgl.data.utils.Subset``.

        Raises:
            TypeError: If ``idx`` is none of the supported index kinds.
        """
        if isinstance(idx, int):
            return self.graphs[idx], self.line_graphs[idx], self.labels[idx]

        if torch.is_tensor(idx) and idx.dtype == torch.long:
            if idx.dim() == 0:
                position = int(idx)
                return (
                    self.graphs[position],
                    self.line_graphs[position],
                    self.labels[position],
                )
            if idx.dim() == 1:
                return dgl.data.utils.Subset(self, idx.cpu())

        # Previously this fell through and returned None, which only surfaced
        # later as a confusing error inside the data loader.
        raise TypeError(
            'ProcessedPCQM4M indices must be an int or a 0-/1-dim torch.long '
            f'tensor, got {type(idx).__name__}'
            + (f' with dtype {idx.dtype} and {idx.dim()} dims'
               if torch.is_tensor(idx) else '')
        )

    def __len__(self):
        """Return the number of molecule graphs."""
        return len(self.graphs)

# Build the dataset from the graph and line-graph binaries stored under ./data.
processed_dataset = ProcessedPCQM4M('./data/molecules_norm.bin', './data/molecules_lg.bin')

In [None]:
# start = default_timer()

# dataset = DglPCQM4MDataset(root='/home/ksadowski/datasets', smiles2graph=smiles2graph)

# split_dict = dataset.get_idx_split()

# train_idx = split_dict['train']
# val_idx = split_dict['valid']
# test_idx = split_dict['test']

# stop = default_timer()

# print(f'Data loading time: {(stop - start) / 60} min.')

In [None]:
# Load the saved train / validation / test index files, in that order.
split_index_paths = (
    './data/train_idx.pt',
    './data/val_idx.pt',
    './data/test_idx.pt',
)
train_idx, val_idx, test_idx = map(torch.load, split_index_paths)

In [None]:
# Fraction of every split that is kept.
subset_ratio = 0.01

# Random positions *within* each split index tensor (not dataset indices yet).
train_subset = torch.randperm(len(train_idx))[:int(subset_ratio * len(train_idx))]
val_subset = torch.randperm(len(val_idx))[:int(subset_ratio * len(val_idx))]
test_subset = torch.randperm(len(test_idx))[:int(subset_ratio * len(test_idx))]

batch_size = 64

# Every loader maps its sampled positions through the matching split index
# tensor, so that it only ever draws molecules belonging to its own split.
train_dataloader = dgl.dataloading.pytorch.GraphDataLoader(
    processed_dataset[train_idx[train_subset]],
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
)

val_dataloader = dgl.dataloading.pytorch.GraphDataLoader(
    processed_dataset[val_idx[val_subset]],
    batch_size=batch_size,
    shuffle=False,
    num_workers=4,
)

test_dataloader = dgl.dataloading.pytorch.GraphDataLoader(
    processed_dataset[test_idx[test_subset]],
    batch_size=batch_size,
    shuffle=False,
    num_workers=4,
)

# len(dataloader) is the number of batches; report the sample counts instead.
print(f'Train samples: {len(train_subset)}')
print(f'Val samples: {len(val_subset)}')
print(f'Test samples: {len(test_subset)}')

In [None]:
# class ProcessedPCQM4M(dgl.data.DGLDataset):
#     def __init__(self, ogb_dataset: dgl.data.DGLDataset):
#         self.ogb_dataset = ogb_dataset
#         self.graphs = []
#         self.line_graphs = []
#         self.labels = []
#         super().__init__(name='processed_PCQM4M')

#     def process(self):
#         for i in trange(len(self.ogb_dataset)):
#             g = self.ogb_dataset[i][0].add_self_loop()
#             lg = dgl.line_graph(g, backtracking=False).add_self_loop()

#             g.ndata['feat'] = g.ndata['feat'].float()
#             g.edata['feat'] = g.edata['feat'].float()

#             self.graphs.append(g)
#             self.line_graphs.append(lg)
#             self.labels.append(self.ogb_dataset[i][1])

#     def __getitem__(self, index: Union[int, torch.Tensor]):
#         if isinstance(index, int):
#             return self.graphs[index], self.line_graphs[index], self.labels[index]
#         elif torch.is_tensor(index) and index.dtype == torch.long:
#             if index.dim() == 0:
#                 return self.graphs[index], self.line_graphs[index], self.labels[index]
#             elif index.dim() == 1:
#                 return dgl.data.utils.Subset(self, index.cpu())

#     def __len__(self):
#         return len(self.graphs)

# processed_dataset = ProcessedPCQM4M(dataset)

In [None]:
# labels = {f'{i}': processed_dataset[i][2] for i in range(len(processed_dataset))}

# dgl.data.utils.save_graphs('./molecules.bin', processed_dataset.graphs, labels)
# dgl.data.utils.save_graphs('./molecules_lg.bin', processed_dataset.line_graphs, labels)

# Model

## Transformer

In [None]:
# import math
# from typing import Tuple

# import dgl
# import torch
# import torch.nn as nn
# import torch.nn.functional as F


# class LinearLayer(nn.Module):
#     def __init__(
#         self,
#         in_feats: int,
#         out_feats: int,
#         activation: str = None,
#     ) -> None:
#         super().__init__()
#         self._linear = nn.Linear(in_feats, out_feats)
#         self._activation = activation

#     def forward(self, inputs: torch.Tensor) -> torch.Tensor:
#         x = self._linear(inputs)

#         if self._activation == 'relu':
#             x = F.relu(x)
#         elif self._activation == 'relu6':
#             x = F.relu6(x)
#         elif self._activation == 'leaky_relu':
#             x = F.leaky_relu(x)
#         elif self._activation == 'elu':
#             x = F.elu(x)
#         elif self._activation == 'selu':
#             x = F.selu(x)
#         elif self._activation == 'celu':
#             x = F.celu(x)

#         return x


# class BilinearReadoutLayer(nn.Module):
#     def __init__(
#         self,
#         node_in_feats: int,
#         edge_in_feats: int,
#         out_feats: int,
#         activation: str = None,
#     ) -> None:
#         super().__init__()
#         self._bilinear = nn.Bilinear(node_in_feats, edge_in_feats, out_feats)
#         self._activation = activation

#     def forward(
#         self,
#         node_inputs: torch.Tensor,
#         edge_inputs: torch.Tensor,
#     ) -> float:
#         x = self._bilinear(node_inputs, edge_inputs)

#         if self._activation == 'relu':
#             x = F.relu(x)
#         elif self._activation == 'softplus':
#             x = F.softplus(x)

#         return x


# class MutualMultiAttentionHead(nn.Module):
#     def __init__(
#         self,
#         node_in_feats: int,
#         edge_in_feats: int,
#         num_heads: int,
#         short_residual: bool,
#         dropout_probability: float,
#         message_aggregation_type: str,
#         head_pooling_type: str,
#         linear_projection_activation: str = None,
#     ) -> None:
#         super().__init__()
#         self._node_in_feats = node_in_feats
#         self._edge_in_feats = edge_in_feats
#         self._num_heads = num_heads
#         self._short_residual = short_residual
#         self._message_aggregation_type = message_aggregation_type
#         self._head_pooling_type = head_pooling_type
#         self._device = nn.Parameter(torch.empty(0))
#         self._node_query_linear = LinearLayer(
#             node_in_feats,
#             num_heads * node_in_feats,
#             linear_projection_activation,
#         )
#         self._node_key_linear = LinearLayer(
#             node_in_feats, num_heads, linear_projection_activation)
#         self._node_value_linear = LinearLayer(
#             node_in_feats,
#             num_heads * node_in_feats,
#             linear_projection_activation,
#         )
#         self._edge_query_linear = LinearLayer(
#             edge_in_feats,
#             num_heads * edge_in_feats,
#             linear_projection_activation,
#         )
#         self._edge_key_linear = LinearLayer(
#             edge_in_feats, num_heads, linear_projection_activation)
#         self._edge_value_linear = LinearLayer(
#             edge_in_feats,
#             num_heads * edge_in_feats,
#             linear_projection_activation,
#         )
#         self._node_dropout = nn.Dropout(dropout_probability)
#         self._edge_dropout = nn.Dropout(dropout_probability)

#     def _calculate_self_attention(
#         self,
#         query: torch.Tensor,
#         key: torch.Tensor,
#         in_feats: int,
#         short_residual: torch.Tensor = None,
#     ) -> torch.Tensor:
#         if short_residual is not None:
#             x = query @ torch.transpose(short_residual, -1, -2) @ key
#         else:
#             x = query @ torch.transpose(query, -1, -2) @ key

#         x /= math.sqrt(in_feats)
#         x = F.softmax(x, dim=1)

#         return x

#     def _create_node_attention_projection(
#         self,
#         g: dgl.DGLGraph,
#         edge_self_attention: torch.Tensor,
#     ) -> torch.Tensor:
#         node_attention_projection = torch.zeros(
#             [self._num_heads, g.num_nodes(), g.num_nodes()],
#             device=self._device.device,
#         )

#         for edge in range(g.num_edges()):
#             nodes = g.find_edges(edge)

#             source = nodes[0].item()
#             destination = nodes[1].item()

#             for head in range(self._num_heads):
#                 attention_score = edge_self_attention[head][edge]

#                 node_attention_projection[head][source][destination] = attention_score

#             return node_attention_projection

#     def _create_edge_attention_projection(
#         self,
#         g: dgl.DGLGraph,
#         lg: dgl.DGLGraph,
#         node_self_attention: torch.Tensor,
#     ) -> torch.Tensor:
#         edge_attention_projection = torch.zeros(
#             [self._num_heads, g.num_edges(), g.num_edges()],
#             device=self._device.device,
#         )

#         for node in range(lg.num_edges()):
#             edges = lg.find_edges(node)

#             source = edges[0].item()
#             destination = edges[1].item()

#             connecting_node = g.find_edges(source)[1].item()

#             for head in range(self._num_heads):
#                 attention_score = node_self_attention[head][connecting_node]

#                 edge_attention_projection[head][source][destination] = attention_score

#             return edge_attention_projection

#     def _calculate_message_passing(
#         self,
#         g: dgl.DGLGraph,
#         value: torch.Tensor,
#         attention_projection: torch.Tensor,
#     ):
#         adjacency = g.adj(ctx=self._device.device).to_dense()

#         if self._message_aggregation_type == 'sum':
#             x = attention_projection * adjacency
#         elif self._message_aggregation_type == 'mean':
#             degree_inv = torch.linalg.inv(torch.diag(g.in_degrees().float()))

#             x = degree_inv @ attention_projection * adjacency
#         elif self._message_aggregation_type == 'gcn':
#             degree_inv_sqrt = torch.sqrt(torch.linalg.inv(
#                 torch.diag(g.in_degrees().float())))
#             adjacency_inv_sqrt = torch.sqrt(torch.linalg.inv(adjacency))

#             x = degree_inv_sqrt @ attention_projection * \
#                 adjacency_inv_sqrt @ degree_inv_sqrt

#         message_passing = x @ value

#         return message_passing

#     def forward(
#         self,
#         g: dgl.DGLGraph,
#         lg: dgl.DGLGraph,
#         node_inputs: torch.Tensor,
#         edge_inputs: torch.Tensor,
#     ) -> Tuple[torch.Tensor, torch.Tensor]:
#         node_query = self._node_query_linear(node_inputs)
#         node_query = node_query.view(self._num_heads, -1, self._node_in_feats)
#         node_key = self._node_key_linear(node_inputs)
#         node_key = node_key.view(self._num_heads, -1, 1)
#         node_value = self._node_value_linear(node_inputs)
#         node_value = node_value.view(self._num_heads, -1, self._node_in_feats)

#         edge_query = self._edge_query_linear(edge_inputs)
#         edge_query = edge_query.view(self._num_heads, -1, self._edge_in_feats)
#         edge_key = self._edge_key_linear(edge_inputs)
#         edge_key = edge_key.view(self._num_heads, -1, 1)
#         edge_value = self._edge_value_linear(edge_inputs)
#         edge_value = edge_value.view(self._num_heads, -1, self._edge_in_feats)

#         if self._short_residual:
#             node_self_attention = self._calculate_self_attention(
#                 node_query, node_key, self._node_in_feats, node_inputs)
#             edge_self_attention = self._calculate_self_attention(
#                 edge_query, edge_key, self._edge_in_feats, edge_inputs)
#         else:
#             node_self_attention = self._calculate_self_attention(
#                 node_query, node_key, self._node_in_feats)
#             edge_self_attention = self._calculate_self_attention(
#                 edge_query, edge_key, self._edge_in_feats)

#         node_attention_projection = self._create_node_attention_projection(
#             g, edge_self_attention)
#         edge_attention_projection = self._create_edge_attention_projection(
#             g, lg, node_self_attention)

#         node_message_passing = self._calculate_message_passing(
#             g, node_value, node_attention_projection)
#         edge_message_passing = self._calculate_message_passing(
#             lg, edge_value, edge_attention_projection)

#         node_message_passing = self._node_dropout(node_message_passing)
#         edge_message_passing = self._edge_dropout(edge_message_passing)

#         if self._head_pooling_type == 'sum':
#             node_message_passing = node_message_passing.sum(dim=-3)
#             edge_message_passing = edge_message_passing.sum(dim=-3)
#         elif self._head_pooling_type == 'mean':
#             node_message_passing = node_message_passing.mean(dim=-3)
#             edge_message_passing = edge_message_passing.mean(dim=-3)

#         return node_message_passing, edge_message_passing


# class MutualAttentionTransformerLayer(nn.Module):
#     def __init__(
#         self,
#         node_in_feats: int,
#         node_out_feats: int,
#         edge_in_feats: int,
#         edge_out_feats: int,
#         num_heads: int,
#         short_residual: bool,
#         long_residual: bool,
#         dropout_probability: float,
#         message_aggregation_type: str,
#         head_pooling_type: str,
#         normalization_type: str,
#         linear_projection_activation: str = None,
#         linear_embedding_activation: str = None,
#     ) -> None:
#         super().__init__()
#         self._long_residual = long_residual
#         self._mutual_multi_attention_head = MutualMultiAttentionHead(
#             node_in_feats,
#             edge_in_feats,
#             num_heads,
#             short_residual,
#             dropout_probability,
#             message_aggregation_type,
#             head_pooling_type,
#             linear_projection_activation,
#         )
#         self._node_linear_embedding = LinearLayer(
#             node_in_feats, node_out_feats, linear_embedding_activation)
#         self._edge_linear_embedding = LinearLayer(
#             edge_in_feats, edge_out_feats, linear_embedding_activation)

#         if normalization_type == 'layer':
#             self._node_normalization_1 = nn.LayerNorm(node_in_feats)
#             self._node_normalization_2 = nn.LayerNorm(node_out_feats)

#             self._edge_normalization_1 = nn.LayerNorm(edge_in_feats)
#             self._edge_normalization_2 = nn.LayerNorm(edge_out_feats)
#         elif normalization_type == 'batch':
#             self._node_normalization_1 = nn.BatchNorm1d(node_in_feats)
#             self._node_normalization_2 = nn.BatchNorm1d(node_out_feats)

#             self._edge_normalization_1 = nn.BatchNorm1d(edge_in_feats)
#             self._edge_normalization_2 = nn.BatchNorm1d(edge_out_feats)

#     def forward(
#         self,
#         g: dgl.DGLGraph,
#         lg: dgl.DGLGraph,
#         node_inputs: torch.Tensor,
#         edge_inputs: torch.Tensor,
#     ) -> Tuple[torch.Tensor, torch.Tensor]:
#         node_embedding, edge_embedding = self._mutual_multi_attention_head(
#             g, lg, node_inputs, edge_inputs)

#         if self._long_residual:
#             node_embedding += node_inputs
#             edge_embedding += edge_inputs

#         node_embedding = self._node_normalization_1(node_embedding)
#         edge_embedding = self._edge_normalization_1(edge_embedding)

#         node_embedding = self._node_linear_embedding(node_embedding)
#         edge_embedding = self._edge_linear_embedding(edge_embedding)

#         node_embedding = self._node_normalization_2(node_embedding)
#         edge_embedding = self._edge_normalization_2(edge_embedding)

#         return node_embedding, edge_embedding


# class GraphMutualAttentionTransformer(nn.Module):
#     def __init__(
#         self,
#         node_in_feats: int,
#         node_hidden_feats: int,
#         node_out_feats: int,
#         edge_in_feats: int,
#         edge_hidden_feats: int,
#         edge_out_feats: int,
#         num_layers: int,
#         num_heads: int,
#         short_residual: bool,
#         long_residual: bool,
#         dropout_probability: float,
#         message_aggregation_type: str,
#         head_pooling_type: str,
#         readout_pooling_type: str,
#         normalization_type: str,
#         linear_projection_activation: str = None,
#         linear_embedding_activation: str = None,
#         bilinear_readout_activation: str = None,
#     ) -> None:
#         super().__init__()
#         self._node_out_feats = node_out_feats
#         self._edge_out_feats = edge_out_feats
#         self._num_layers = num_layers
#         self._transformer_layers = self._create_transformer_layers(
#             node_in_feats,
#             node_hidden_feats,
#             node_out_feats,
#             edge_in_feats,
#             edge_hidden_feats,
#             edge_out_feats,
#             num_layers,
#             num_heads,
#             short_residual,
#             long_residual,
#             dropout_probability,
#             message_aggregation_type,
#             head_pooling_type,
#             normalization_type,
#             linear_projection_activation,
#             linear_embedding_activation,
#         )
#         self._bilinear_readout = BilinearReadoutLayer(
#             node_out_feats, edge_out_feats, 1, bilinear_readout_activation)

#         if readout_pooling_type == 'sum':
#             self._readout_pooling = dgl.nn.pytorch.SumPooling()
#         elif readout_pooling_type == 'mean':
#             self._readout_pooling = dgl.nn.pytorch.AvgPooling()
#         elif readout_pooling_type == 'attention':
#             pass

#     def _create_transformer_layers(
#         self,
#         node_in_feats: int,
#         node_hidden_feats: int,
#         node_out_feats: int,
#         edge_in_feats: int,
#         edge_hidden_feats: int,
#         edge_out_feats: int,
#         num_layers: int,
#         num_heads: int,
#         short_residual: bool,
#         long_residual: bool,
#         dropout_probability: float,
#         message_aggregation_type: str,
#         head_pooling_type: str,
#         normalization_type: str,
#         linear_projection_activation: str = None,
#         linear_embedding_activation: str = None,
#     ) -> nn.ModuleList:
#         transformer_layers = nn.ModuleList()

#         if num_layers > 1:
#             transformer_layers.append(MutualAttentionTransformerLayer(
#                 node_in_feats,
#                 node_hidden_feats,
#                 edge_in_feats,
#                 edge_hidden_feats,
#                 num_heads,
#                 short_residual,
#                 long_residual,
#                 dropout_probability,
#                 message_aggregation_type,
#                 head_pooling_type,
#                 normalization_type,
#                 linear_projection_activation,
#                 linear_embedding_activation,
#             ))

#             for _ in range(num_layers - 2):
#                 transformer_layers.append(MutualAttentionTransformerLayer(
#                     node_hidden_feats,
#                     node_hidden_feats,
#                     edge_hidden_feats,
#                     edge_hidden_feats,
#                     num_heads,
#                     short_residual,
#                     long_residual,
#                     dropout_probability,
#                     message_aggregation_type,
#                     head_pooling_type,
#                     normalization_type,
#                     linear_projection_activation,
#                     linear_embedding_activation,
#                 ))

#             transformer_layers.append(MutualAttentionTransformerLayer(
#                 node_hidden_feats,
#                 node_out_feats,
#                 edge_hidden_feats,
#                 edge_out_feats,
#                 num_heads,
#                 short_residual,
#                 long_residual,
#                 dropout_probability,
#                 message_aggregation_type,
#                 head_pooling_type,
#                 normalization_type,
#                 linear_projection_activation,
#                 linear_embedding_activation,
#             ))
#         else:
#             transformer_layers.append(MutualAttentionTransformerLayer(
#                 node_in_feats,
#                 node_out_feats,
#                 edge_in_feats,
#                 edge_out_feats,
#                 num_heads,
#                 short_residual,
#                 long_residual,
#                 dropout_probability,
#                 message_aggregation_type,
#                 head_pooling_type,
#                 normalization_type,
#                 linear_projection_activation,
#                 linear_embedding_activation,
#             ))

#         return transformer_layers

#     def forward(
#         self,
#         g: dgl.DGLGraph,
#         lg: dgl.DGLGraph,
#         node_inputs: torch.Tensor,
#         edge_inputs: torch.Tensor,
#     ) -> torch.Tensor:
#         node_embedding = node_inputs
#         edge_embedding = edge_inputs

#         for transformer_layer in self._transformer_layers:
#             node_embedding, edge_embedding = transformer_layer(
#                 g, lg, node_inputs, edge_inputs)
        
#         node_embedding = self._readout_pooling(g, node_embedding)
#         edge_embedding = self._readout_pooling(lg, edge_embedding)

#         readout = self._bilinear_readout(node_embedding, edge_embedding)

#         return readout

## Simple

In [None]:
import math
from typing import Tuple

import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F


class LinearLayer(nn.Module):
    """Affine projection followed by an optional element-wise activation.

    Args:
        in_feats: Size of the last dimension of the input.
        out_feats: Size of the last dimension of the output.
        activation: One of ``'relu'``, ``'relu6'``, ``'leaky_relu'``,
            ``'elu'``, ``'selu'`` or ``'celu'``. Any other value (including
            ``None``) leaves the projection output unchanged.
    """

    # Activation name -> functional implementation.
    _ACTIVATIONS = {
        'relu': F.relu,
        'relu6': F.relu6,
        'leaky_relu': F.leaky_relu,
        'elu': F.elu,
        'selu': F.selu,
        'celu': F.celu,
    }

    def __init__(
        self,
        in_feats: int,
        out_feats: int,
        activation: str = None,
    ) -> None:
        super().__init__()
        self._linear = nn.Linear(in_feats, out_feats)
        self._activation = activation

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Project ``inputs`` and apply the configured activation, if any."""
        projected = self._linear(inputs)
        activation_fn = self._ACTIVATIONS.get(self._activation)

        # Unrecognised names behave like "no activation".
        if activation_fn is None:
            return projected

        return activation_fn(projected)


class BilinearReadoutLayer(nn.Module):
    """Bilinear combination of node and edge features with optional activation.

    Args:
        node_in_feats: Size of the last dimension of the node input.
        edge_in_feats: Size of the last dimension of the edge input.
        out_feats: Size of the last dimension of the output.
        activation: ``'relu'`` or ``'softplus'``. Any other value (including
            ``None``) leaves the bilinear output unchanged.
    """

    def __init__(
        self,
        node_in_feats: int,
        edge_in_feats: int,
        out_feats: int,
        activation: str = None,
    ) -> None:
        super().__init__()
        self._bilinear = nn.Bilinear(node_in_feats, edge_in_feats, out_feats)
        self._activation = activation

    def forward(
        self,
        node_inputs: torch.Tensor,
        edge_inputs: torch.Tensor,
    ) -> torch.Tensor:
        """Return the (optionally activated) bilinear form of both inputs.

        The result is a tensor whose last dimension is ``out_feats``; it is
        not a Python ``float`` (the previous annotation was wrong).
        """
        x = self._bilinear(node_inputs, edge_inputs)

        if self._activation == 'relu':
            x = F.relu(x)
        elif self._activation == 'softplus':
            x = F.softplus(x)

        return x


class MutualMultiAttentionHead(nn.Module):
    """Multi-head message passing where nodes and edges attend through each other.

    Node features are propagated over the graph ``g`` with weights derived
    from *edge* attention scores, while edge features are propagated over the
    line graph ``lg`` with weights derived from *node* attention scores.

    Args:
        node_in_feats: Node feature size (input and output).
        edge_in_feats: Edge feature size (input and output).
        num_heads: Number of attention heads.
        dropout_probability: Dropout applied to the propagated features.
        message_aggregation_type: ``'sum'``, ``'mean'`` or ``'gcn'``.
        head_pooling_type: ``'sum'`` or ``'mean'``; any other value leaves
            the per-head outputs un-pooled.
        linear_projection_activation: Activation name forwarded to the key
            and value projections.
    """

    def __init__(
        self,
        node_in_feats: int,
        edge_in_feats: int,
        num_heads: int,
        dropout_probability: float,
        message_aggregation_type: str,
        head_pooling_type: str,
        linear_projection_activation: str = None,
    ) -> None:
        super().__init__()
        self._node_in_feats = node_in_feats
        self._edge_in_feats = edge_in_feats
        self._num_heads = num_heads
        self._message_aggregation_type = message_aggregation_type
        self._head_pooling_type = head_pooling_type
        # Empty parameter whose only purpose is to expose the module's device.
        self._device = nn.Parameter(torch.empty(0))
        self._node_key_linear = LinearLayer(
            node_in_feats, num_heads, linear_projection_activation)
        self._node_value_linear = LinearLayer(
            node_in_feats,
            num_heads * node_in_feats,
            linear_projection_activation,
        )
        self._edge_key_linear = LinearLayer(
            edge_in_feats, num_heads, linear_projection_activation)
        self._edge_value_linear = LinearLayer(
            edge_in_feats,
            num_heads * edge_in_feats,
            linear_projection_activation,
        )
        self._node_dropout = nn.Dropout(dropout_probability)
        self._edge_dropout = nn.Dropout(dropout_probability)

    def _split_heads(
        self,
        projected: torch.Tensor,
        feats_per_head: int,
    ) -> torch.Tensor:
        """Reshape ``(items, heads * feats)`` into ``(heads, items, feats)``.

        A plain ``view(heads, -1, feats)`` would only reinterpret the memory
        and mix features of different items into the same head slot, so the
        head axis is split off first and then moved to the front.
        """
        return projected.view(
            -1, self._num_heads, feats_per_head).transpose(0, 1)

    def _calculate_self_attention(
        self,
        key: torch.Tensor,
        in_feats: int,
    ) -> torch.Tensor:
        """Scale ``key`` by ``sqrt(in_feats)`` and soft-max it over dim 1.

        With the ``(heads, items, 1)`` layout produced in ``forward`` dim 1 is
        the item (node or edge) axis.
        """
        x = key / math.sqrt(in_feats)
        x = F.softmax(x, dim=1)

        return x

    def _create_node_attention_projection(
        self,
        g: dgl.DGLGraph,
        edge_self_attention: torch.Tensor,
    ) -> torch.Tensor:
        """Scatter per-edge scores into a ``(heads, nodes, nodes)`` tensor.

        Entry ``[h, u, v]`` receives the score of edge ``u -> v`` for head
        ``h``; node pairs without an edge stay zero. A graph without edges
        yields an all-zero tensor.
        """
        device = self._device.device
        node_attention_projection = torch.zeros(
            [self._num_heads, g.num_nodes(), g.num_nodes()],
            dtype=edge_self_attention.dtype,
            device=device,
        )

        # Endpoints are returned in edge-ID order, matching the score rows.
        source, destination = g.edges()
        source = source.long().to(device)
        destination = destination.long().to(device)

        # NOTE(review): with parallel edges it is unspecified which score is
        # kept for the shared (u, v) slot - confirm inputs are simple graphs.
        node_attention_projection[:, source, destination] = \
            edge_self_attention[:, :, 0]

        return node_attention_projection

    def _create_edge_attention_projection(
        self,
        g: dgl.DGLGraph,
        lg: dgl.DGLGraph,
        node_self_attention: torch.Tensor,
    ) -> torch.Tensor:
        """Scatter per-node scores into a ``(heads, edges, edges)`` tensor.

        For every line-graph edge ``e1 -> e2`` entry ``[h, e1, e2]`` receives
        the score of the node that ``e1`` points to in ``g``; all other
        entries stay zero.
        """
        device = self._device.device
        edge_attention_projection = torch.zeros(
            [self._num_heads, g.num_edges(), g.num_edges()],
            dtype=node_self_attention.dtype,
            device=device,
        )

        source_edge, destination_edge = lg.edges()
        source_edge = source_edge.long().to(device)
        destination_edge = destination_edge.long().to(device)

        # Destination node of every edge of ``g``, indexed by edge ID.
        edge_destination = g.edges()[1].long().to(device)
        connecting_node = edge_destination[source_edge]

        edge_attention_projection[:, source_edge, destination_edge] = \
            node_self_attention[:, connecting_node, 0]

        return edge_attention_projection

    def _calculate_message_passing(
        self,
        g: dgl.DGLGraph,
        value: torch.Tensor,
        attention_projection: torch.Tensor,
    ) -> torch.Tensor:
        """Propagate ``value`` over ``g`` using attention-weighted adjacency.

        Args:
            g: Graph whose dense adjacency masks the attention weights.
            value: ``(heads, items, feats)`` features to propagate.
            attention_projection: ``(heads, items, items)`` attention weights.

        Raises:
            ValueError: If the configured aggregation type is unknown.
        """
        device = self._device.device
        adjacency = g.adj(ctx=device).to_dense()
        # NOTE(review): assumes the adjacency orientation matches the
        # [source, destination] layout of the projection - this is trivially
        # true for symmetric (bidirected) graphs; confirm otherwise.
        masked_attention = attention_projection * adjacency

        if self._message_aggregation_type == 'sum':
            x = masked_attention
        elif self._message_aggregation_type in ('mean', 'gcn'):
            # Clamp so that isolated items do not cause a division by zero.
            degrees = g.in_degrees().float().to(device).clamp(min=1)

            if self._message_aggregation_type == 'mean':
                # Row scaling, i.e. D^-1 (A * attention).
                x = masked_attention / degrees.view(-1, 1)
            else:
                # Symmetric normalisation D^-1/2 (A * attention) D^-1/2.
                # The earlier code inverted the adjacency matrix itself,
                # which is singular / NaN-prone and not a GCN normalisation.
                degree_inv_sqrt = degrees.pow(-0.5)
                x = degree_inv_sqrt.view(-1, 1) * masked_attention \
                    * degree_inv_sqrt.view(1, -1)
        else:
            raise ValueError(
                'Unknown message aggregation type: '
                f'{self._message_aggregation_type!r}')

        message_passing = x @ value

        return message_passing

    def forward(
        self,
        g: dgl.DGLGraph,
        lg: dgl.DGLGraph,
        node_inputs: torch.Tensor,
        edge_inputs: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run one round of mutual node/edge attention.

        Args:
            g: Input graph.
            lg: Line graph of ``g``.
            node_inputs: ``(nodes, node_in_feats)`` node features.
            edge_inputs: ``(edges, edge_in_feats)`` edge features.

        Returns:
            The propagated node and edge features, pooled over heads when the
            head pooling type is ``'sum'`` or ``'mean'``.
        """
        node_key = self._split_heads(self._node_key_linear(node_inputs), 1)
        node_value = self._split_heads(
            self._node_value_linear(node_inputs), self._node_in_feats)

        edge_key = self._split_heads(self._edge_key_linear(edge_inputs), 1)
        edge_value = self._split_heads(
            self._edge_value_linear(edge_inputs), self._edge_in_feats)

        node_self_attention = self._calculate_self_attention(
            node_key, self._node_in_feats)
        edge_self_attention = self._calculate_self_attention(
            edge_key, self._edge_in_feats)

        # "Mutual": node messages are weighted by edge scores and vice versa.
        node_attention_projection = self._create_node_attention_projection(
            g, edge_self_attention)
        edge_attention_projection = self._create_edge_attention_projection(
            g, lg, node_self_attention)

        node_message_passing = self._calculate_message_passing(
            g, node_value, node_attention_projection)
        edge_message_passing = self._calculate_message_passing(
            lg, edge_value, edge_attention_projection)

        node_message_passing = self._node_dropout(node_message_passing)
        edge_message_passing = self._edge_dropout(edge_message_passing)

        if self._head_pooling_type == 'sum':
            node_message_passing = node_message_passing.sum(dim=-3)
            edge_message_passing = edge_message_passing.sum(dim=-3)
        elif self._head_pooling_type == 'mean':
            node_message_passing = node_message_passing.mean(dim=-3)
            edge_message_passing = edge_message_passing.mean(dim=-3)

        return node_message_passing, edge_message_passing


class MutualAttentionTransformerLayer(nn.Module):
    """Mutual attention block followed by normalisation and a linear embedding.

    The layer applies, for nodes and edges alike: mutual multi-head attention,
    an optional residual connection to the layer input, a first
    normalisation, a linear embedding to the output size and a second
    normalisation.

    Args:
        node_in_feats: Input node feature size.
        node_out_feats: Output node feature size.
        edge_in_feats: Input edge feature size.
        edge_out_feats: Output edge feature size.
        num_heads: Number of attention heads.
        long_residual: Whether the layer input is added to the attention
            output before the first normalisation.
        dropout_probability: Dropout probability used inside the attention.
        message_aggregation_type: Aggregation type forwarded to the attention.
        head_pooling_type: Head pooling type forwarded to the attention.
        normalization_type: ``'layer'`` or ``'batch'``.
        linear_projection_activation: Activation of the attention projections.
        linear_embedding_activation: Activation of the linear embeddings.

    Raises:
        ValueError: If ``normalization_type`` is not ``'layer'`` or ``'batch'``.
    """

    def __init__(
        self,
        node_in_feats: int,
        node_out_feats: int,
        edge_in_feats: int,
        edge_out_feats: int,
        num_heads: int,
        long_residual: bool,
        dropout_probability: float,
        message_aggregation_type: str,
        head_pooling_type: str,
        normalization_type: str,
        linear_projection_activation: str = None,
        linear_embedding_activation: str = None,
    ) -> None:
        super().__init__()
        self._long_residual = long_residual
        self._mutual_multi_attention_head = MutualMultiAttentionHead(
            node_in_feats,
            edge_in_feats,
            num_heads,
            dropout_probability,
            message_aggregation_type,
            head_pooling_type,
            linear_projection_activation,
        )
        self._node_linear_embedding = LinearLayer(
            node_in_feats, node_out_feats, linear_embedding_activation)
        self._edge_linear_embedding = LinearLayer(
            edge_in_feats, edge_out_feats, linear_embedding_activation)

        if normalization_type == 'layer':
            self._node_normalization_1 = nn.LayerNorm(node_in_feats)
            self._node_normalization_2 = nn.LayerNorm(node_out_feats)

            self._edge_normalization_1 = nn.LayerNorm(edge_in_feats)
            self._edge_normalization_2 = nn.LayerNorm(edge_out_feats)
        elif normalization_type == 'batch':
            self._node_normalization_1 = nn.BatchNorm1d(node_in_feats)
            self._node_normalization_2 = nn.BatchNorm1d(node_out_feats)

            self._edge_normalization_1 = nn.BatchNorm1d(edge_in_feats)
            self._edge_normalization_2 = nn.BatchNorm1d(edge_out_feats)
        else:
            # Fail at construction time instead of with an AttributeError on
            # the missing normalisation modules during the first forward pass.
            raise ValueError(
                f'Unknown normalization type: {normalization_type!r} '
                "(expected 'layer' or 'batch')")

    def forward(
        self,
        g: dgl.DGLGraph,
        lg: dgl.DGLGraph,
        node_inputs: torch.Tensor,
        edge_inputs: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the new node and edge embeddings for graph ``g``.

        Args:
            g: Input graph.
            lg: Line graph of ``g``.
            node_inputs: Node features with ``node_in_feats`` columns.
            edge_inputs: Edge features with ``edge_in_feats`` columns.
        """
        node_embedding, edge_embedding = self._mutual_multi_attention_head(
            g, lg, node_inputs, edge_inputs)

        if self._long_residual:
            # Out-of-place addition keeps the attention output intact for autograd.
            node_embedding = node_embedding + node_inputs
            edge_embedding = edge_embedding + edge_inputs

        node_embedding = self._node_normalization_1(node_embedding)
        edge_embedding = self._edge_normalization_1(edge_embedding)

        node_embedding = self._node_linear_embedding(node_embedding)
        edge_embedding = self._edge_linear_embedding(edge_embedding)

        node_embedding = self._node_normalization_2(node_embedding)
        edge_embedding = self._edge_normalization_2(edge_embedding)

        return node_embedding, edge_embedding


class GraphMutualAttentionTransformer(nn.Module):
    """Stack of mutual-attention transformer layers with a graph readout.

    Node and edge features are passed through ``num_layers`` chained
    ``MutualAttentionTransformerLayer`` modules, pooled per graph, and
    combined by a ``BilinearReadoutLayer``.
    """

    def __init__(
        self,
        node_in_feats: int,
        node_hidden_feats: int,
        node_out_feats: int,
        edge_in_feats: int,
        edge_hidden_feats: int,
        edge_out_feats: int,
        num_layers: int,
        num_heads: int,
        long_residual: bool,
        dropout_probability: float,
        message_aggregation_type: str,
        head_pooling_type: str,
        readout_pooling_type: str,
        normalization_type: str,
        linear_projection_activation: str = None,
        linear_embedding_activation: str = None,
        bilinear_readout_activation: str = None,
    ) -> None:
        """Build the layer stack, the bilinear readout and the pooling.

        Args:
            node_in_feats: Node feature size expected by the first layer.
            node_hidden_feats: Node feature size between layers.
            node_out_feats: Node feature size produced by the last layer.
            edge_in_feats: Edge feature size expected by the first layer.
            edge_hidden_feats: Edge feature size between layers.
            edge_out_feats: Edge feature size produced by the last layer.
            num_layers: Number of transformer layers; values below 2 yield
                a single input-to-output layer.
            num_heads: Forwarded to every transformer layer.
            long_residual: Forwarded to every transformer layer.
            dropout_probability: Forwarded to every transformer layer.
            message_aggregation_type: Forwarded to every transformer layer.
            head_pooling_type: Forwarded to every transformer layer.
            readout_pooling_type: ``'sum'`` or ``'mean'``.
            normalization_type: Forwarded to every transformer layer.
            linear_projection_activation: Forwarded to every layer.
            linear_embedding_activation: Forwarded to every layer.
            bilinear_readout_activation: Forwarded to the readout layer.

        Raises:
            NotImplementedError: If ``readout_pooling_type`` is
                ``'attention'``, which has no implementation yet.
            ValueError: If ``readout_pooling_type`` is not recognised.
        """
        super().__init__()
        self._node_out_feats = node_out_feats
        self._edge_out_feats = edge_out_feats
        self._num_layers = num_layers
        self._transformer_layers = self._create_transformer_layers(
            node_in_feats,
            node_hidden_feats,
            node_out_feats,
            edge_in_feats,
            edge_hidden_feats,
            edge_out_feats,
            num_layers,
            num_heads,
            long_residual,
            dropout_probability,
            message_aggregation_type,
            head_pooling_type,
            normalization_type,
            linear_projection_activation,
            linear_embedding_activation,
        )
        self._bilinear_readout = BilinearReadoutLayer(
            node_out_feats, edge_out_feats, 1, bilinear_readout_activation)
        self._readout_pooling = self._create_readout_pooling(
            readout_pooling_type)

    @staticmethod
    def _create_readout_pooling(readout_pooling_type: str) -> nn.Module:
        """Return the pooling module selected by ``readout_pooling_type``.

        Failing here keeps a bad configuration from surfacing later as a
        missing ``_readout_pooling`` attribute inside ``forward``.
        """
        if readout_pooling_type == 'sum':
            return dgl.nn.pytorch.SumPooling()
        if readout_pooling_type == 'mean':
            return dgl.nn.pytorch.AvgPooling()
        if readout_pooling_type == 'attention':
            raise NotImplementedError(
                "readout_pooling_type 'attention' is not implemented yet")

        raise ValueError(
            f'Unknown readout_pooling_type: {readout_pooling_type!r} '
            "(expected 'sum' or 'mean')")

    def _create_transformer_layers(
        self,
        node_in_feats: int,
        node_hidden_feats: int,
        node_out_feats: int,
        edge_in_feats: int,
        edge_hidden_feats: int,
        edge_out_feats: int,
        num_layers: int,
        num_heads: int,
        long_residual: bool,
        dropout_probability: float,
        message_aggregation_type: str,
        head_pooling_type: str,
        normalization_type: str,
        linear_projection_activation: str = None,
        linear_embedding_activation: str = None,
    ) -> nn.ModuleList:
        """Create the chain of transformer layers.

        With ``num_layers > 1`` the feature sizes run
        ``in -> hidden -> ... -> hidden -> out``; otherwise a single
        ``in -> out`` layer is created.

        Returns:
            An ``nn.ModuleList`` holding the layers in application order.
        """
        if num_layers > 1:
            num_hidden = num_layers - 1
            node_dims = (
                [node_in_feats]
                + [node_hidden_feats] * num_hidden
                + [node_out_feats]
            )
            edge_dims = (
                [edge_in_feats]
                + [edge_hidden_feats] * num_hidden
                + [edge_out_feats]
            )
        else:
            node_dims = [node_in_feats, node_out_feats]
            edge_dims = [edge_in_feats, edge_out_feats]

        transformer_layers = nn.ModuleList()

        # Consecutive entries of the dimension lists are the (in, out)
        # sizes of one layer.
        for node_in, node_out, edge_in, edge_out in zip(
                node_dims[:-1], node_dims[1:], edge_dims[:-1], edge_dims[1:]):
            transformer_layers.append(MutualAttentionTransformerLayer(
                node_in,
                node_out,
                edge_in,
                edge_out,
                num_heads,
                long_residual,
                dropout_probability,
                message_aggregation_type,
                head_pooling_type,
                normalization_type,
                linear_projection_activation,
                linear_embedding_activation,
            ))

        return transformer_layers

    def forward(
        self,
        g: dgl.DGLGraph,
        lg: dgl.DGLGraph,
        node_inputs: torch.Tensor,
        edge_inputs: torch.Tensor,
    ) -> torch.Tensor:
        """Compute the readout for a (batched) graph.

        Args:
            g: Graph whose nodes carry ``node_inputs``; also used to pool
                the final node embeddings.
            lg: Graph used to pool the final edge embeddings (presumably
                the line graph of ``g`` -- confirm against the data
                pipeline).
            node_inputs: Input node features.
            edge_inputs: Input edge features.

        Returns:
            The output of the bilinear readout applied to the pooled node
            and edge embeddings.
        """
        node_embedding = node_inputs
        edge_embedding = edge_inputs

        for transformer_layer in self._transformer_layers:
            # Feed each layer the previous layer's output; passing the raw
            # inputs here would discard every layer except the last one.
            node_embedding, edge_embedding = transformer_layer(
                g, lg, node_embedding, edge_embedding)

        node_embedding = self._readout_pooling(g, node_embedding)
        edge_embedding = self._readout_pooling(lg, edge_embedding)

        readout = self._bilinear_readout(node_embedding, edge_embedding)

        return readout


# Training

In [None]:
def _run_epoch(
    model: nn.Module,
    dataloader,
    device: str,
    criterion: nn.Module,
    optimizer: torch.optim.Optimizer = None,
) -> float:
    """Run one pass over ``dataloader`` and return the mean batch loss.

    Args:
        model: Model called as ``model(g, lg, node_inputs, edge_inputs)``.
        dataloader: Iterable of ``(batched_g, batched_lg, labels)`` batches.
        device: Device the batches are moved to.
        criterion: Loss applied to the flattened predictions and labels.
        optimizer: When given, the weights are updated after every batch;
            when ``None``, the pass runs without gradient tracking.

    Returns:
        The mean of the per-batch losses, or ``nan`` for an empty loader.
    """
    is_training = optimizer is not None
    total_loss = 0.0
    num_batches = 0

    with torch.set_grad_enabled(is_training):
        for batched_g, batched_lg, labels in dataloader:
            batched_g = batched_g.to(device)
            batched_lg = batched_lg.to(device)
            labels = labels.to(device)

            node_inputs = batched_g.ndata.pop('feat')
            edge_inputs = batched_g.edata.pop('feat')

            if is_training:
                optimizer.zero_grad()

            pred = model(
                batched_g, batched_lg, node_inputs, edge_inputs).view(-1,)
            loss = criterion(pred, labels)

            if is_training:
                loss.backward()
                optimizer.step()

            # Accumulate a Python float: summing the loss tensors would keep
            # every batch's autograd graph alive until the end of the epoch.
            total_loss += loss.item()
            num_batches += 1

    return total_loss / num_batches if num_batches else float('nan')


def train(model: nn.Module, train_dataloader, val_dataloader, test_dataloader, num_epochs: int, device: str) -> None:
    """Train ``model`` with Adam on an L1 loss, validating after each epoch.

    One line with the mean training loss, the mean validation loss and the
    epoch wall time is printed per epoch.

    Args:
        model: Model called as ``model(g, lg, node_inputs, edge_inputs)``.
        train_dataloader: Iterable of ``(batched_g, batched_lg, labels)``
            batches used for optimisation.
        val_dataloader: Batches of the same form used for validation.
        test_dataloader: Currently unused; evaluation on the test split is
            not implemented.
        num_epochs: Number of passes over ``train_dataloader``.
        device: Device the batches are moved to.
    """
    optimizer = torch.optim.Adam(model.parameters())
    criterion = nn.L1Loss()

    for epoch in range(1, 1 + num_epochs):
        start = default_timer()

        model.train()
        train_loss = _run_epoch(
            model, train_dataloader, device, criterion, optimizer)

        model.eval()
        val_loss = _run_epoch(model, val_dataloader, device, criterion)

        stop = default_timer()

        print(f'Epoch: {epoch:3} Train loss: {train_loss:.2f} Validation loss: {val_loss:.2f} Epoch time: {stop - start:.2f}')


In [None]:
# The run is pinned to the CPU; the auto-detecting variant is kept for reference.
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = torch.device('cpu')

model = GraphMutualAttentionTransformer(
    # Feature sizes: input, hidden and output widths are kept equal.
    node_in_feats=9,
    node_hidden_feats=9,
    node_out_feats=9,
    edge_in_feats=3,
    edge_hidden_feats=3,
    edge_out_feats=3,
    # Depth and attention settings.
    num_layers=5,
    num_heads=9,
    # short_residual=True,
    long_residual=True,
    dropout_probability=0.01,
    # Aggregation, pooling and normalization choices.
    message_aggregation_type='sum',
    head_pooling_type='sum',
    readout_pooling_type='mean',
    normalization_type='batch',
    # Activations.
    linear_projection_activation='relu',
    linear_embedding_activation='relu',
    bilinear_readout_activation='relu',
)
model = model.to(device)

In [None]:
# Uncomment to have autograd report the forward operation behind a failing backward pass.
# torch.autograd.set_detect_anomaly(True)
train(
    model=model,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    test_dataloader=test_dataloader,
    num_epochs=100,
    device=device,
)