In [40]:
from ogb.graphproppred import PygGraphPropPredDataset
import torch
import numpy as np
import math

In [85]:
class GraphAttnBias(torch.nn.Module):
    """
    Compute a per-head additive attention bias for a batch of graphs.

    The bias starts from ``batched_data["attn_bias"]`` (copied to every head)
    and adds a learned per-head scalar, looked up by the integer indices in
    ``batched_data["spatial_pos"]``, to the node-to-node block ``[1:, 1:]``.
    Row 0 and column 0 receive no learned bias (presumably the graph/virtual
    token slot -- TODO confirm against the collator).

    Args:
        num_heads: Number of attention heads; each spatial-position index is
            embedded to one scalar per head.
        num_spatial: Size of the spatial-position vocabulary. Index 0 is the
            padding index and always contributes a zero bias.
    """

    def __init__(
        self,
        num_heads,
        num_spatial
    ):
        super().__init__()
        self.num_heads = num_heads

        # One learned scalar per head for every spatial-position index.
        self.spatial_pos_encoder = torch.nn.Embedding(num_spatial, num_heads, padding_idx=0)

        # NOTE(review): allocated and initialized but not used by forward();
        # kept so existing state_dicts/checkpoints still load. Under DDP an
        # unused parameter may require find_unused_parameters=True.
        self.graph_token_virtual_distance = torch.nn.Embedding(1, num_heads)

        self.spatial_pos_encoder.weight.data.normal_(mean=0.0, std=0.02)
        # normal_() overwrites the whole table, including the padding row that
        # nn.Embedding zero-initializes. Re-zero it so padded positions add no
        # bias; padding_idx also blocks gradients to this row, so a random
        # value written here would otherwise stay fixed for the whole training.
        self.spatial_pos_encoder.weight.data[self.spatial_pos_encoder.padding_idx].zero_()

        self.graph_token_virtual_distance.weight.data.normal_(mean=0.0, std=0.02)

    def forward(self, batched_data):
        """
        Build the multi-head attention bias for one batch.

        Args:
            batched_data: Mapping providing at least
                ``"attn_bias"`` -- tensor ``[n_graph, n_node + 1, n_node + 1]``;
                ``"spatial_pos"`` -- integer tensor ``[n_graph, n_node, n_node]``
                of indices into ``spatial_pos_encoder``.
                Other keys are ignored.

        Returns:
            Tensor ``[n_graph, num_heads, n_node + 1, n_node + 1]``. The
            caller's ``attn_bias`` tensor is not modified.
        """
        attn_bias = batched_data["attn_bias"]
        spatial_pos = batched_data["spatial_pos"]

        # repeat() materializes a fresh tensor, so writing into it below never
        # touches the caller's attn_bias (no separate clone() needed).
        graph_attn_bias = attn_bias.unsqueeze(1).repeat(
            1, self.num_heads, 1, 1
        )  # [n_graph, n_head, n_node+1, n_node+1]

        # spatial pos
        # [n_graph, n_node, n_node, n_head] -> [n_graph, n_head, n_node, n_node]
        spatial_pos_bias = self.spatial_pos_encoder(spatial_pos).permute(0, 3, 1, 2)
        graph_attn_bias[:, :, 1:, 1:] = graph_attn_bias[:, :, 1:, 1:] + spatial_pos_bias

        # NOTE: only the spatial-position bias is applied. The graph-token
        # virtual distance and edge-feature encodings are not added here.
        return graph_attn_bias