In [1]:
import math
from timeit import default_timer
from typing import Union, Tuple

import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score
from ogb.graphproppred import DglGraphPropPredDataset, Evaluator
from torch.utils.data import DataLoader
from tqdm import trange, tqdm

from torch.profiler.profiler import tensorboard_trace_handler

torch.manual_seed(13)

Using backend: pytorch


<torch._C.Generator at 0x7f59c40bad70>

In [2]:
class ProcessedMolhiv(dgl.data.DGLDataset):
    def __init__(self, ogb_dataset: dgl.data.DGLDataset, normalize: bool = False) -> None:
        self._ogb_dataset = ogb_dataset
        self._normalize = normalize
        self.graphs = []
        self.line_graphs = []
        self.labels = []
        super().__init__(name='ProcessedMolhiv')

    def process(self):
        max_node = 0
        min_node = 0
        max_edge = 0
        min_edge = 0

        for i in trange(len(self._ogb_dataset)):
            g = self._ogb_dataset[i][0].add_self_loop()
            lg = dgl.line_graph(g, backtracking=False).add_self_loop()

            g.ndata['feat'] = g.ndata['feat'].float()
            g.edata['feat'] = g.edata['feat'].float()

            if self._normalize:
                g.ndata['feat'] /= 91
                g.edata['feat'] /= 3

            # if g.ndata['feat'].max() > max_node:
            #     max_node = g.ndata['feat'].max()
            
            # if g.ndata['feat'].min() < min_node:
            #     min_node = g.ndata['feat'].min()

            # if g.edata['feat'].max() > max_edge:
            #     max_edge = g.edata['feat'].max()
            
            # if g.edata['feat'].min() < min_edge:
            #     min_edge = g.edata['feat'].min()

            self.graphs.append(g)
            self.line_graphs.append(lg)
            self.labels.append(self._ogb_dataset[i][1])

        # print(f'max_node: {max_node}')
        # print(f'min_node: {min_node}')
        # print(f'max_edge: {max_edge}')
        # print(f'min_edge: {min_edge}')

    def __getitem__(self, idx: Union[int, torch.Tensor]):
        if isinstance(idx, int):
            return self.graphs[idx], self.line_graphs[idx], self.labels[idx]
        elif torch.is_tensor(idx) and idx.dtype == torch.long:
            if idx.dim() == 0:
                return self.graphs[idx], self.line_graphs[idx], self.labels[idx]
            elif idx.dim() == 1:
                return dgl.data.utils.Subset(self, idx.cpu())

    def __len__(self):
        return len(self.graphs)

dataset = DglGraphPropPredDataset(root='/home/ksadowski/datasets', name='ogbg-molhiv')
processed_dataset = ProcessedMolhiv(dataset, normalize=True)

split_idx = dataset.get_idx_split()


100%|██████████| 41127/41127 [00:30<00:00, 1343.03it/s]


In [3]:
BATCH_SIZE = 64

train_dataloader = dgl.dataloading.pytorch.GraphDataLoader(
    processed_dataset[split_idx['train']],
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=4,
)

val_dataloader = dgl.dataloading.pytorch.GraphDataLoader(
    processed_dataset[split_idx['valid']],
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=4,
)

test_dataloader = dgl.dataloading.pytorch.GraphDataLoader(
    processed_dataset[split_idx['test']],
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=4,
)

print(f'Train samples: {len(train_dataloader)}')
print(f'Val samples: {len(val_dataloader)}')
print(f'Test samples: {len(test_dataloader)}')

Train samples: 515
Val samples: 65
Test samples: 65


In [4]:
class LinearLayer(nn.Module):
    def __init__(
        self,
        in_feats: int,
        out_feats: int,
        activation: str = None,
    ) -> None:
        super().__init__()
        self._linear = nn.Linear(in_feats, out_feats)
        self._activation = activation

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        x = self._linear(inputs)

        if self._activation == 'relu':
            x = F.relu(x)
        elif self._activation == 'relu6':
            x = F.relu6(x)
        elif self._activation == 'leaky_relu':
            x = F.leaky_relu(x)
        elif self._activation == 'elu':
            x = F.elu(x)
        elif self._activation == 'selu':
            x = F.selu(x)
        elif self._activation == 'celu':
            x = F.celu(x)

        return x


class BilinearReadoutLayer(nn.Module):
    def __init__(
        self,
        node_in_feats: int,
        edge_in_feats: int,
        out_feats: int,
        activation: str = None,
    ) -> None:
        super().__init__()
        self._bilinear = nn.Bilinear(node_in_feats, edge_in_feats, out_feats)
        self._activation = activation

    def forward(
        self,
        node_inputs: torch.Tensor,
        edge_inputs: torch.Tensor,
    ) -> float:
        x = self._bilinear(node_inputs, edge_inputs)

        if self._activation == 'relu':
            x = F.relu(x)
        elif self._activation == 'softplus':
            x = F.softplus(x)
        elif self._activation == 'sigmoid':
            x = F.sigmoid(x)
        elif self._activation == 'softmax':
            x = F.softmax(x, dim=1)

        return x


class MutualMultiAttentionHead(nn.Module):
    def __init__(
        self,
        node_in_feats: int,
        edge_in_feats: int,
        num_heads: int,
        dropout_probability: float,
        message_aggregation_type: str,
        head_pooling_type: str,
        linear_projection_activation: str = None,
    ) -> None:
        super().__init__()
        self._node_in_feats = node_in_feats
        self._edge_in_feats = edge_in_feats
        self._num_heads = num_heads
        self._message_aggregation_type = message_aggregation_type
        self._head_pooling_type = head_pooling_type
        self._device = nn.Parameter(torch.empty(0))
        self._node_key_linear = LinearLayer(
            node_in_feats, num_heads, linear_projection_activation)
        self._node_value_linear = LinearLayer(
            node_in_feats,
            num_heads * node_in_feats,
            linear_projection_activation,
        )
        self._edge_key_linear = LinearLayer(
            edge_in_feats, num_heads, linear_projection_activation)
        self._edge_value_linear = LinearLayer(
            edge_in_feats,
            num_heads * edge_in_feats,
            linear_projection_activation,
        )
        self._node_dropout = nn.Dropout(dropout_probability)
        self._edge_dropout = nn.Dropout(dropout_probability)

    def _calculate_self_attention(
        self,
        key: torch.Tensor,
        in_feats: int,
    ) -> torch.Tensor:
        x = key / math.sqrt(in_feats)
        x = F.softmax(x, dim=1)

        return x

    @torch.jit.script
    def _node_attention_script(
        edge_self_attention: torch.Tensor, 
        source_nodes: torch.Tensor, 
        destination_nodes: torch.Tensor, 
        num_heads: int, 
        num_nodes: int, 
        num_edges: int
    ) -> torch.Tensor:
        attention_projection = torch.zeros(
            [num_heads, num_nodes, num_nodes],
            dtype=torch.float32,
            device='cpu',
        )

        for edge in range(num_edges):
            source = source_nodes[edge]
            destination = destination_nodes[edge]

            for head in range(num_heads):
                attention_score = edge_self_attention[head][edge]

                attention_projection[head][source][destination] = attention_score

        return attention_projection

    def _create_node_attention_projection(
        self,
        g: dgl.DGLGraph,
        edge_self_attention: torch.Tensor,
    ) -> torch.Tensor:
        edges = g.edges()

        node_attention_projection = self._node_attention_script(
            edge_self_attention,
            edges[0],
            edges[1],
            self._num_heads,
            g.num_nodes(),
            g.num_edges(),
        )

        return node_attention_projection

    @torch.jit.script
    def _edge_attention_script(
        node_self_attention: torch.Tensor,
        source_lg_nodes: torch.Tensor,
        destination_lg_nodes: torch.Tensor,
        destination_g_nodes: torch.Tensor,
        num_heads: int, 
        num_g_edges: int,
        num_lg_edges: int,
    ):
        attention_projection = torch.zeros(
            [num_heads, num_g_edges, num_g_edges],
            dtype=torch.float32,
            device='cpu',
        )

        for lg_edge in range(num_lg_edges):
            source = source_lg_nodes[lg_edge]
            destination = destination_lg_nodes[lg_edge]

            connecting_g_node = destination_g_nodes[source]

            for head in range(num_heads):
                attention_score = node_self_attention[head][connecting_g_node]

                attention_projection[head][source][destination] = attention_score

        return attention_projection

    def _create_edge_attention_projection(
        self,
        g: dgl.DGLGraph,
        lg: dgl.DGLGraph,
        node_self_attention: torch.Tensor,
    ) -> torch.Tensor:

        edge_attention_projection = self._edge_attention_script(
            node_self_attention,
            lg.edges()[0],
            lg.edges()[1],
            g.edges()[1],
            self._num_heads, 
            g.num_edges(),
            lg.num_edges(),
            )

        return edge_attention_projection

    def _calculate_message_passing(
        self,
        g: dgl.DGLGraph,
        value: torch.Tensor,
        attention_projection: torch.Tensor,
    ) -> torch.Tensor:
        if self._message_aggregation_type == 'sum':
            x = attention_projection
        elif self._message_aggregation_type == 'mean':
            degree_inv = torch.linalg.inv(torch.diag(g.in_degrees().float()))

            x = degree_inv @ attention_projection
        elif self._message_aggregation_type == 'gcn':
            adjacency = g.adj(ctx=self._device.device).to_dense()

            degree_inv_sqrt = torch.sqrt(torch.linalg.inv(
                torch.diag(g.in_degrees().float())))
            adjacency_inv_sqrt = torch.sqrt(torch.linalg.inv(adjacency))

            x = degree_inv_sqrt @ attention_projection * \
                adjacency_inv_sqrt @ degree_inv_sqrt

        message_passing = x @ value

        return message_passing

    def forward(
        self,
        g: dgl.DGLGraph,
        lg: dgl.DGLGraph,
        node_inputs: torch.Tensor,
        edge_inputs: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        node_key = self._node_key_linear(node_inputs)
        node_key = node_key.view(self._num_heads, -1, 1)
        node_value = self._node_value_linear(node_inputs)
        node_value = node_value.view(self._num_heads, -1, self._node_in_feats)

        edge_key = self._edge_key_linear(edge_inputs)
        edge_key = edge_key.view(self._num_heads, -1, 1)
        edge_value = self._edge_value_linear(edge_inputs)
        edge_value = edge_value.view(self._num_heads, -1, self._edge_in_feats)

        node_self_attention = self._calculate_self_attention(
            node_key, self._node_in_feats)
        edge_self_attention = self._calculate_self_attention(
            edge_key, self._edge_in_feats)

        node_attention_projection = self._create_node_attention_projection(
            g, edge_self_attention)
        edge_attention_projection = self._create_edge_attention_projection(
            g, lg, node_self_attention)

        node_message_passing = self._calculate_message_passing(
            g, node_value, node_attention_projection)
        edge_message_passing = self._calculate_message_passing(
            lg, edge_value, edge_attention_projection)

        node_message_passing = self._node_dropout(node_message_passing)
        edge_message_passing = self._edge_dropout(edge_message_passing)

        if self._head_pooling_type == 'sum':
            node_message_passing = node_message_passing.sum(dim=-3)
            edge_message_passing = edge_message_passing.sum(dim=-3)
        elif self._head_pooling_type == 'mean':
            node_message_passing = node_message_passing.mean(dim=-3)
            edge_message_passing = edge_message_passing.mean(dim=-3)

        return node_message_passing, edge_message_passing


class MutualAttentionTransformerLayer(nn.Module):
    def __init__(
        self,
        node_in_feats: int,
        node_out_feats: int,
        edge_in_feats: int,
        edge_out_feats: int,
        num_heads: int,
        long_residual: bool,
        dropout_probability: float,
        message_aggregation_type: str,
        head_pooling_type: str,
        normalization_type: str,
        linear_projection_activation: str = None,
        linear_embedding_activation: str = None,
    ) -> None:
        super().__init__()
        self._long_residual = long_residual
        self._mutual_multi_attention_head = MutualMultiAttentionHead(
            node_in_feats,
            edge_in_feats,
            num_heads,
            dropout_probability,
            message_aggregation_type,
            head_pooling_type,
            linear_projection_activation,
        )
        self._node_linear_embedding = LinearLayer(
            node_in_feats, node_out_feats, linear_embedding_activation)
        self._edge_linear_embedding = LinearLayer(
            edge_in_feats, edge_out_feats, linear_embedding_activation)

        if normalization_type == 'layer':
            self._node_normalization_1 = nn.LayerNorm(node_in_feats)
            self._node_normalization_2 = nn.LayerNorm(node_out_feats)

            self._edge_normalization_1 = nn.LayerNorm(edge_in_feats)
            self._edge_normalization_2 = nn.LayerNorm(edge_out_feats)
        elif normalization_type == 'batch':
            self._node_normalization_1 = nn.BatchNorm1d(node_in_feats)
            self._node_normalization_2 = nn.BatchNorm1d(node_out_feats)

            self._edge_normalization_1 = nn.BatchNorm1d(edge_in_feats)
            self._edge_normalization_2 = nn.BatchNorm1d(edge_out_feats)

    def forward(
        self,
        g: dgl.DGLGraph,
        lg: dgl.DGLGraph,
        node_inputs: torch.Tensor,
        edge_inputs: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        node_embedding, edge_embedding = self._mutual_multi_attention_head(
            g, lg, node_inputs, edge_inputs)

        if self._long_residual:
            node_embedding += node_inputs
            edge_embedding += edge_inputs

        node_embedding = self._node_normalization_1(node_embedding)
        edge_embedding = self._edge_normalization_1(edge_embedding)

        node_embedding = self._node_linear_embedding(node_embedding)
        edge_embedding = self._edge_linear_embedding(edge_embedding)

        node_embedding = self._node_normalization_2(node_embedding)
        edge_embedding = self._edge_normalization_2(edge_embedding)

        return node_embedding, edge_embedding


class GraphMutualAttentionTransformer(nn.Module):
    def __init__(
        self,
        node_in_feats: int,
        node_hidden_feats: int,
        node_out_feats: int,
        edge_in_feats: int,
        edge_hidden_feats: int,
        edge_out_feats: int,
        num_layers: int,
        num_heads: int,
        long_residual: bool,
        dropout_probability: float,
        message_aggregation_type: str,
        head_pooling_type: str,
        readout_pooling_type: str,
        normalization_type: str,
        linear_projection_activation: str = None,
        linear_embedding_activation: str = None,
        bilinear_readout_activation: str = None,
    ) -> None:
        super().__init__()
        self._node_out_feats = node_out_feats
        self._edge_out_feats = edge_out_feats
        self._num_layers = num_layers
        self._transformer_layers = self._create_transformer_layers(
            node_in_feats,
            node_hidden_feats,
            node_out_feats,
            edge_in_feats,
            edge_hidden_feats,
            edge_out_feats,
            num_layers,
            num_heads,
            long_residual,
            dropout_probability,
            message_aggregation_type,
            head_pooling_type,
            normalization_type,
            linear_projection_activation,
            linear_embedding_activation,
        )
        self._bilinear_readout = BilinearReadoutLayer(
            node_out_feats, edge_out_feats, 1, bilinear_readout_activation)

        if readout_pooling_type == 'sum':
            self._readout_pooling = dgl.nn.pytorch.SumPooling()
        elif readout_pooling_type == 'mean':
            self._readout_pooling = dgl.nn.pytorch.AvgPooling()
        elif readout_pooling_type == 'attention':
            pass

    def _create_transformer_layers(
        self,
        node_in_feats: int,
        node_hidden_feats: int,
        node_out_feats: int,
        edge_in_feats: int,
        edge_hidden_feats: int,
        edge_out_feats: int,
        num_layers: int,
        num_heads: int,
        long_residual: bool,
        dropout_probability: float,
        message_aggregation_type: str,
        head_pooling_type: str,
        normalization_type: str,
        linear_projection_activation: str = None,
        linear_embedding_activation: str = None,
    ) -> nn.ModuleList:
        transformer_layers = nn.ModuleList()

        if num_layers > 1:
            transformer_layers.append(MutualAttentionTransformerLayer(
                node_in_feats,
                node_hidden_feats,
                edge_in_feats,
                edge_hidden_feats,
                num_heads,
                long_residual,
                dropout_probability,
                message_aggregation_type,
                head_pooling_type,
                normalization_type,
                linear_projection_activation,
                linear_embedding_activation,
            ))

            for _ in range(num_layers - 2):
                transformer_layers.append(MutualAttentionTransformerLayer(
                    node_hidden_feats,
                    node_hidden_feats,
                    edge_hidden_feats,
                    edge_hidden_feats,
                    num_heads,
                    long_residual,
                    dropout_probability,
                    message_aggregation_type,
                    head_pooling_type,
                    normalization_type,
                    linear_projection_activation,
                    linear_embedding_activation,
                ))

            transformer_layers.append(MutualAttentionTransformerLayer(
                node_hidden_feats,
                node_out_feats,
                edge_hidden_feats,
                edge_out_feats,
                num_heads,
                long_residual,
                dropout_probability,
                message_aggregation_type,
                head_pooling_type,
                normalization_type,
                linear_projection_activation,
                linear_embedding_activation,
            ))
        else:
            transformer_layers.append(MutualAttentionTransformerLayer(
                node_in_feats,
                node_out_feats,
                edge_in_feats,
                edge_out_feats,
                num_heads,
                long_residual,
                dropout_probability,
                message_aggregation_type,
                head_pooling_type,
                normalization_type,
                linear_projection_activation,
                linear_embedding_activation,
            ))

        return transformer_layers

    def forward(
        self,
        g: dgl.DGLGraph,
        lg: dgl.DGLGraph,
        node_inputs: torch.Tensor,
        edge_inputs: torch.Tensor,
    ) -> torch.Tensor:
        node_embedding = node_inputs
        edge_embedding = edge_inputs

        for transformer_layer in self._transformer_layers:
            node_embedding, edge_embedding = transformer_layer(
                g, lg, node_embedding, edge_embedding)

        node_embedding = self._readout_pooling(g, node_embedding)
        edge_embedding = self._readout_pooling(lg, edge_embedding)

        readout = self._bilinear_readout(node_embedding, edge_embedding)

        return readout

In [5]:
def train(model: nn.Module, device: str, dataloader: DataLoader) -> None:
    optimizer = torch.optim.Adam(model.parameters())
    loss_accum = 0

    model.train()

    # with torch.profiler.profile(
    #     schedule=torch.profiler.schedule(
    #         wait=2,
    #         warmup=2,
    #         active=6,
    #         repeat=1,
    #     ),
    #     activities=[
    #         torch.profiler.ProfilerActivity.CPU,
    #         # torch.profiler.ProfilerActivity.CUDA,
    #     ],
    #     on_trace_ready=tensorboard_trace_handler('/home/ksadowski/projects/data_science/gmat/ogb_molhiv/profiler_logs'),
    # ) as profiler:
    # with tqdm(total=len(dataloader), desc='Batch steps') as pbar:
    for batched_g, batched_lg, labels in dataloader:
        batched_g = batched_g.to(device)
        batched_lg = batched_lg.to(device)
        labels = labels.to(device)

        node_inputs = batched_g.ndata.pop('feat')
        edge_inputs = batched_g.edata.pop('feat')

        optimizer.zero_grad()

        pred = model(batched_g, batched_lg, node_inputs, edge_inputs)

        loss = F.binary_cross_entropy_with_logits(pred.to(torch.float32), labels.to(torch.float32))
        loss_accum += loss

        loss.backward()
        optimizer.step()
        # profiler.step()

            # pbar.update(1)

    return loss_accum / len(dataloader)


def eval(model: nn.Module, device: str, dataloader: DataLoader, evaluator: Evaluator):
    model.eval()

    y_true = []
    y_pred = []

    for batched_g, batched_lg, labels in dataloader:
        batched_g = batched_g.to(device)
        batched_lg = batched_lg.to(device)
        labels = labels.to(device)

        node_inputs = batched_g.ndata.pop('feat')
        edge_inputs = batched_g.edata.pop('feat')

        with torch.no_grad():
            pred = model(batched_g, batched_lg, node_inputs, edge_inputs)

        y_true.append(labels.view(pred.shape).detach().cpu())
        y_pred.append(pred.detach().cpu())

    y_true = torch.cat(y_true, dim = 0).numpy()
    y_pred = torch.cat(y_pred, dim = 0).numpy()

    input_dict = {"y_true": y_true, "y_pred": y_pred}

    return evaluator.eval(input_dict)

In [6]:
g = processed_dataset[0][0]
lg = processed_dataset[0][1]

node_inputs = g.ndata['feat']
edge_inputs = g.edata['feat']



In [7]:
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = torch.device('cpu')

model = GraphMutualAttentionTransformer(
    node_in_feats=9,
    node_hidden_feats=9,
    node_out_feats=9,
    edge_in_feats=3,
    edge_hidden_feats=3,
    edge_out_feats=3,
    num_layers=3,
    num_heads=4,
    # short_residual=True,
    long_residual=True,
    dropout_probability=0.01,
    message_aggregation_type='sum',
    head_pooling_type='mean',
    readout_pooling_type='mean',
    normalization_type='batch',
    linear_projection_activation='relu',
    linear_embedding_activation='relu',
    # bilinear_readout_activation='softmax',
).to(device)

# model(g, node_inputs, edge_inputs)



In [8]:
NUM_EPOCHS = 100

evaluator = Evaluator(name='ogbg-molhiv')

for epoch in range(1, 1 + NUM_EPOCHS):
    start = default_timer()

    train_perf = train(model, device, train_dataloader)

    # train_perf = eval(model, device, train_loader, evaluator)
    val_perf = eval(model, device, val_dataloader, evaluator)
    val_perf = val_perf['rocauc']
    test_perf = eval(model, device, test_dataloader, evaluator)
    test_perf = test_perf['rocauc']

    stop = default_timer()

    print(
        f'Epoch: {epoch: 3} Train Loss: {train_perf:.4f} '
        f'ROC AUC // Val: {val_perf:.4f} Test: {test_perf:.4f} // '
        f'Epoch Time: {(stop - start) / 60:.2f} min.'
    )

Epoch:   1 Train Loss: 0.3739 ROC AUC // Val: 0.5584 Test: 0.5190 // Epoch Time: 2.10 min.
Epoch:   2 Train Loss: 0.1573 ROC AUC // Val: 0.6440 Test: 0.6328 // Epoch Time: 2.10 min.
Epoch:   3 Train Loss: 0.1521 ROC AUC // Val: 0.6373 Test: 0.6291 // Epoch Time: 2.11 min.
Epoch:   4 Train Loss: 0.1499 ROC AUC // Val: 0.6546 Test: 0.6215 // Epoch Time: 2.11 min.
Epoch:   5 Train Loss: 0.1495 ROC AUC // Val: 0.6977 Test: 0.6453 // Epoch Time: 2.12 min.
Epoch:   6 Train Loss: 0.1484 ROC AUC // Val: 0.6929 Test: 0.6702 // Epoch Time: 2.10 min.
Epoch:   7 Train Loss: 0.1479 ROC AUC // Val: 0.7019 Test: 0.6541 // Epoch Time: 2.12 min.
Epoch:   8 Train Loss: 0.1474 ROC AUC // Val: 0.6839 Test: 0.6556 // Epoch Time: 2.13 min.
Epoch:   9 Train Loss: 0.1475 ROC AUC // Val: 0.7083 Test: 0.6550 // Epoch Time: 2.11 min.
Epoch:  10 Train Loss: 0.1469 ROC AUC // Val: 0.7024 Test: 0.6791 // Epoch Time: 2.10 min.
Epoch:  11 Train Loss: 0.1479 ROC AUC // Val: 0.6706 Test: 0.6607 // Epoch Time: 2.12 min.

KeyboardInterrupt: 