In [8]:
from torchdrug import  layers, datasets,transforms,core
from torchdrug.core import Registry as R
from torchdrug.layers import geometry

import torch
from torchdrug import data

from torch_scatter import scatter_add
import torch.nn as nn
from torchdrug import utils
from torch.utils import checkpoint
import torch.nn.functional as F
import torch.optim as optim
from einops import rearrange, repeat, pack, unpack

import time

# Pin the notebook to the first GPU; the models below are run on CUDA.
# One GPU is enough, so require at least one visible device (a `> 1` check would
# reject single-GPU machines even though cuda:0 exists).
if torch.cuda.is_available() and torch.cuda.device_count() >= 1:
    device = torch.device('cuda:0')
    torch.cuda.set_device(device)
else:
    raise ValueError("CUDA device 0 is not available")

# Get datasets

In [2]:
# Resolve the dataset and transform classes through the TorchDrug registry.
EnzymeCommission = R.search("datasets.EnzymeCommission")
PV = R.search("transforms.ProteinView")

# View every protein at residue level.
trans = PV(view="residue")
dataset = EnzymeCommission(
    "~/scratch/protein-datasets/",
    test_cutoff=0.95,
    atom_feature="full",
    bond_feature="full",
    verbose=1,
    transform=trans,
)

# Simplified representation that keeps only alpha-carbon nodes, connected by
# spatial-radius, k-nearest-neighbour and sequential edges.
graph_construction_model = layers.GraphConstruction(
    node_layers=[geometry.AlphaCarbonNode()],
    edge_layers=[
        geometry.SpatialEdge(radius=10.0, min_distance=5),
        geometry.KNNEdge(k=10, min_distance=5),
        geometry.SequentialEdge(max_distance=2),
    ],
    edge_feature="gearnet",
)

14:49:04   Extracting /home/xiaotong/scratch/protein-datasets/EnzymeCommission.zip to /home/xiaotong/scratch/protein-datasets


Loading /home/xiaotong/scratch/protein-datasets/EnzymeCommission/enzyme_commission.pkl.gz: 100%|██████████| 18716/18716 [00:49<00:00, 377.87it/s]


In [4]:
# Pack proteins 4..7 of the dataset into a single batched graph.
graphs = data.Protein.pack([sample["graph"] for sample in dataset[4:8]])
print("\n\n")
# Rebuild nodes/edges with the alpha-carbon construction model defined above.
graph = graph_construction_model(graphs)
print(graph)




PackedProtein(batch_size=4, num_atoms=[437, 437, 332, 224], num_bonds=[8141, 8207, 6874, 5209], num_residues=[437, 437, 332, 224])


In [5]:
# Per-protein node counts, batch size and the edge list; each edge row is
# unpacked elsewhere as (node_in, node_out, relation).
print(graph.num_nodes, graph.batch_size, graph.edge_list, sep="\n")

tensor([437, 437, 332, 224])
4
tensor([[ 132,  131,    3],
        [ 147,  146,    3],
        [ 146,  145,    3],
        ...,
        [1262, 1429,    0],
        [1260, 1429,    0],
        [1263, 1429,    0]])


# Relational graph convolution

In [106]:
class relationalGraph(layers.MessagePassingBase):
    """Relational graph convolution that returns one feature map per relation.

    The output has shape ``(num_relation, num_node, output_dim)``. Internally every
    per-relation tensor is laid out relation-major: row ``r * num_node + j`` is node
    ``j`` under relation ``r``. This matches ``input.repeat(num_relation, 1)`` in
    :meth:`combine` and the final ``view(num_relation, num_node, -1)``.

    Messages are aggregated over either:

    * the sparse edge list of ``graph`` (``new_edge_list is None``), with edge weights
      normalised by the in-degree of each (target node, relation) pair, or
    * a dense rewired adjacency ``new_edge_list`` of shape
      ``(num_node, num_relation * num_node)`` with relation-major columns, normalised
      per relation by :meth:`trans`.

    Parameters:
        input_dim (int): input node feature dimension
        output_dim (int): output node feature dimension
        num_relation (int): number of relations
        edge_input_dim (int, optional): edge feature dimension; edge features are only
            supported together with the sparse edge list of ``graph``
        batch_norm (bool, optional): apply batch normalization on the output
        activation (str or callable, optional): activation function
    """

    def __init__(self, input_dim, output_dim, num_relation, edge_input_dim=None, batch_norm=False, activation="relu"):
        super(relationalGraph, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.num_relation = num_relation
        self.edge_input_dim = edge_input_dim

        if batch_norm:
            self.batch_norm = nn.BatchNorm1d(output_dim)
        else:
            self.batch_norm = None
        if isinstance(activation, str):
            self.activation = getattr(F, activation)
        else:
            self.activation = activation

        self.self_loop = nn.Linear(input_dim, output_dim)
        self.linear = nn.Linear(input_dim, output_dim)
        # Edge features are added to the aggregated messages before `self.linear`,
        # hence the projection to `input_dim`.
        if edge_input_dim:
            self.edge_linear = nn.Linear(edge_input_dim, input_dim)
        else:
            self.edge_linear = None

    def trans(self, A, graph):
        """Normalise a dense relational adjacency per relation as ``D^-1/2 A_r D^-1/2``.

        ``D`` holds the row sums of each relation block ``A_r``.

        Parameters:
            A (Tensor): dense adjacency of shape ``(num_node, num_relation * num_node)``
                with relation-major columns (column ``r * num_node + j``)
            graph: only ``graph.num_relation`` is read

        Returns:
            Tensor: normalised adjacency with the same shape and layout as ``A``.
        """
        n_rel = graph.num_relation
        num_node = A.size(0)
        # (N, R*N) -> (R, N, N): one square block per relation.
        blocks = A.view(num_node, n_rel, A.size(1) // n_rel).permute(1, 0, 2)
        # Row sums of every block, shape (R, N).
        degree = blocks.sum(dim=-1)
        has_edges = degree > 0
        # Rows without any weight get scale 0. A plain pow(-0.5) would give inf, and
        # 0 * inf = NaN would poison the whole block. The safe base keeps the
        # backward pass NaN-free as well.
        safe_degree = torch.where(has_edges, degree, torch.ones_like(degree))
        degree_inv_sqrt = torch.where(has_edges, safe_degree.pow(-0.5), torch.zeros_like(degree))
        # diag(d) @ A_r @ diag(d), written as a broadcast instead of two dense matmuls.
        normalized = degree_inv_sqrt.unsqueeze(-1) * blocks * degree_inv_sqrt.unsqueeze(-2)
        return normalized.permute(1, 0, 2).reshape(num_node, -1)

    def message_and_aggregate(self, graph, input, new_edge_list):
        """Aggregate neighbour features per (relation, target node).

        Returns:
            Tensor of shape ``(num_relation * num_node, input_dim)``, relation-major.

        Raises:
            ValueError: if edge features are configured but a dense ``new_edge_list``
                is given (dense adjacencies carry no per-edge rows).
        """
        assert graph.num_relation == self.num_relation
        device = input.device
        num_node = graph.num_node

        if new_edge_list is None:
            node_in, node_out, relation = graph.edge_list.t().to(device)
            # Relation-major target index, consistent with combine()/forward().
            node_out = relation * num_node + node_out

            edge_weight = graph.edge_weight.to(device)
            degree_out = scatter_add(edge_weight, node_out, dim_size=num_node * self.num_relation)
            edge_weight = edge_weight / degree_out[node_out]
            adjacency = utils.sparse_coo_tensor(torch.stack([node_in, node_out]), edge_weight,
                                                (num_node, num_node * self.num_relation))
            update = torch.sparse.mm(adjacency.t().to(device), input)

            if self.edge_linear is not None:
                edge_input = self.edge_linear(graph.edge_feature.float().to(device))
                edge_update = scatter_add(edge_input * edge_weight.unsqueeze(-1), node_out, dim=0,
                                          dim_size=num_node * self.num_relation)
                update = update + edge_update
        else:
            if self.edge_linear is not None:
                raise ValueError("edge features are only supported with the graph's own edge list "
                                 "(new_edge_list=None)")
            adjacency = self.trans(new_edge_list.to(device), graph)
            # (R*N, N) @ (N, d) -> (R*N, d), row r*N + j.
            update = torch.mm(adjacency.t(), input)

        return update

    def combine(self, input, update):
        """Add the per-relation self-loop term, then apply batch norm and activation."""
        device = input.device
        # Modules are moved lazily so the layer also works when it was never moved to
        # the input's device explicitly.
        self.linear.to(device)
        self.self_loop.to(device)
        if self.batch_norm is not None:
            self.batch_norm.to(device)

        # Relation-major copy of the node features: row r*N + i is node i.
        loop_update = self.self_loop(input.repeat(self.num_relation, 1))
        output = self.linear(update) + loop_update
        if self.batch_norm is not None:
            output = self.batch_norm(output)
        if self.activation:
            output = self.activation(output)
        return output

    def forward(self, graph, input, new_edge_list=None):
        """Return per-relation node features of shape ``(num_relation, num_node, output_dim)``."""
        graph = graph.to(input.device)
        if self.gradient_checkpoint:
            # Pass the tensors explicitly so they are checkpoint inputs; the graph object
            # is bound in the closure because the 3-argument message_and_aggregate is
            # needed here.
            update = checkpoint.checkpoint(
                lambda x, edges: self.message_and_aggregate(graph, x, edges), input, new_edge_list)
        else:
            update = self.message_and_aggregate(graph, input, new_edge_list)
        return self.combine(input, update).view(self.num_relation, input.size(0), -1)

In [37]:
# Smoke test for relationalGraph on the demo batch, using the graph's own edge list.
num_relations = graph.num_relation
input_dim = graph.node_feature.shape[-1]
output_dim = 128

input = graph.node_feature.float().to(device)
model = relationalGraph(input_dim, output_dim, num_relations)
relational_output = model(graph, input, new_edge_list=None)
print("output: ", relational_output.shape)

output:  torch.Size([7, 1430, 128])


# Local attention for rewire

In [66]:
class Rewirescorelayer(nn.Module):
    """Score candidate edges with windowed attention and sample a rewired dense adjacency.

    Steps:

    1. Each relation's node features are padded per graph.
    2. The padded features are split into windows of ``window_size`` nodes.
    3. Each window of queries attends to its own window and the two neighbouring
       windows of keys.
    4. Head-averaged attention is discretised into ``k`` entries per query row with a
       straight-through Gumbel top-k.
    5. The result is assembled into a dense adjacency of shape
       ``(num_node, num_relation * num_node)`` with relation-major columns
       (column ``r * num_node + j``).

    Parameters:
        in_features (int): feature dimension of each relation's node features
        out_features (int): query/key dimension per head
        num_heads (int): number of attention heads
        window_size (int): number of nodes per attention window
        k (int): number of entries kept per query row
        temperature (float, optional): Gumbel-softmax temperature
    """

    def __init__(self, in_features, out_features, num_heads, window_size, k, temperature=0.5):
        super(Rewirescorelayer, self).__init__()
        self.input_dim = in_features
        self.output_dim = out_features
        self.num_heads = num_heads
        self.window_size = window_size
        self.k = k
        self.temperature = temperature

        # NOTE(review): `device` is the notebook-level global from the setup cell; the
        # projections are placed there eagerly so the layer works without `.to()`.
        self.query = nn.Linear(in_features, out_features * num_heads).to(device)
        self.key = nn.Linear(in_features, out_features * num_heads).to(device)
        self.scale = 1 / (out_features ** 0.5)

    class LocalAttention(nn.Module):
        """Windowed attention scores (queries and keys only, no values).

        Each window of ``window_size`` queries attends to ``look_backward`` previous
        windows, itself and ``look_forward`` following windows of keys. Windows that
        start at the odd entries of ``pad_start_position`` are dropped from the result.
        Those are the extra padding windows appended after every graph.
        """

        def __init__(
            self,
            window_size,
            look_backward = 1,
            look_forward = None,
            dim = None,
            scale = None,
            pad_start_position = None
        ):
            super().__init__()

            self.scale = scale
            self.window_size = window_size
            self.look_backward = look_backward
            self.look_forward = look_forward
            self.pad_start_position = pad_start_position

        def exists(self, val):
            return val is not None

        # Return `d` when `value` is None, otherwise `value`.
        def default(self, value, d):
            return d if not self.exists(value) else value

        # NOTE(review): this shadows nn.Module.to, so `.to(device)` must not be called
        # on a LocalAttention instance. It is not used in this file.
        def to(self, t):
            return {'device': t.device, 'dtype': t.dtype}

        def max_neg_value(self, tensor):
            # Most negative finite value representable in the tensor's dtype.
            return -torch.finfo(tensor.dtype).max

        def look_around(self, x, backward = 1, forward = 0, pad_value = -1, dim = 2):
            """Concatenate each window with its neighbouring windows along ``dim``.

            Example: x of shape (40, 32, 16, 64) = (batch, windows, window_size, d),
            with backward = forward = 1, becomes (40, 32, 48, 64).
            """
            t = x.shape[1]  # number of windows
            # F.pad lists dims from the last one backwards. Zero-pad the trailing dims so
            # that (backward, forward) lands on dim 1, the window axis.
            dims = (len(x.shape) - dim) * (0, 0)
            padded_x = F.pad(x, (*dims, backward, forward), value = pad_value)
            tensors = [padded_x[:, ind:(ind + t), ...] for ind in range(forward + backward + 1)]
            return torch.cat(tensors, dim = dim)

        def forward(
            self,
            q, k,
            mask = None,
            input_mask = None,
            window_size = None
        ):
            """Return softmax scores of shape
            ``(batch, kept_windows, window_size, (look_backward + 1 + look_forward) * window_size)``.
            """
            mask = self.default(mask, input_mask)
            # No xpos support here, so a call-time window size cannot be honoured.
            assert window_size is None, 'cannot perform window size extrapolation if xpos is not turned on'
            pad_value = -1
            window_size = self.default(window_size, self.window_size)
            look_backward, look_forward = self.look_backward, self.look_forward

            # Flatten all leading dims, e.g. (5, 8, 512, 64) -> (40, 512, 64).
            (q, _), (k, _) = map(lambda t: pack([t], '* n d'), (q, k))

            b, n, dim_head = q.shape
            device = q.device
            scale = self.default(self.scale, dim_head ** -0.5)
            assert (n % window_size) == 0, f'sequence length {n} must be divisible by window size {window_size} for local attention'

            windows = n // window_size

            # Position of every query, bucketed: (1, windows, window_size).
            seq = torch.arange(n, device = device)
            b_t = rearrange(seq, '(w n) -> 1 w n', w = windows, n = window_size)

            # Bucket queries and keys: (b, n, d) -> (b, windows, window_size, d).
            bq, bk = map(lambda t: rearrange(t, 'b (w n) d -> b w n d', w = windows), (q, k))
            bq = bq * scale

            look_around_kwargs = dict(
                backward = look_backward,
                forward = look_forward,
                pad_value = pad_value
            )

            bk = self.look_around(bk, **look_around_kwargs)

            # Key positions per window; look-around slots outside the sequence hold pad_value.
            bq_k = self.look_around(b_t, **look_around_kwargs)
            bq_k = rearrange(bq_k, '... j -> ... 1 j')
            pad_mask = bq_k == pad_value

            sim = torch.einsum('b h i e, b h j e -> b h i j', bq, bk)

            mask_value = self.max_neg_value(sim)
            sim = sim.masked_fill(pad_mask, mask_value)

            if self.exists(mask):
                batch = mask.shape[0]
                assert (b % batch) == 0

                # Heads per mask row; the batch dim of q/k is (mask batch, heads) flattened.
                h = b // mask.shape[0]

                mask = rearrange(mask, '... (w n) -> (...) w n', w = windows, n = window_size)
                mask = self.look_around(mask, **{**look_around_kwargs, 'pad_value': False})
                mask = rearrange(mask, '... j -> ... 1 j')
                mask = repeat(mask, 'b ... -> (b h) ...', h = h)

                sim = sim.masked_fill(~mask, mask_value)
                del mask

            # Drop the extra padding window appended after each graph (odd entries).
            dropped_windows = {self.pad_start_position[i] // window_size
                               for i in range(1, len(self.pad_start_position), 2)}
            remaining_indices = [idx for idx in range(windows) if idx not in dropped_windows]
            rest_sim = sim[:, remaining_indices, :, :]

            return rest_sim.softmax(dim = -1)

    def insert_zero_rows(self, tensor, lengths, target_lengths):
        """Pad each consecutive segment of ``tensor`` along dim 1 with zero rows.

        Parameters:
            tensor (Tensor): shape ``(B, sum(lengths), D)``
            lengths (list of int): segment lengths along dim 1
            target_lengths (list of int): padded length of each segment

        Returns:
            tuple: ``(padded_tensor, mask)``:

            * ``padded_tensor`` has shape ``(B, sum(target_lengths), D)``.
            * ``mask`` is a bool tensor of shape ``(B, sum(target_lengths))`` that is
              True on the original rows.
        """
        assert len(lengths) == len(target_lengths), "Lengths and target lengths must be of the same length."

        parts = []
        mask_parts = []
        start = 0
        for length, target in zip(lengths, target_lengths):
            end = start + length
            parts.append(tensor[:, start:end, :])
            mask_parts.append(torch.ones(tensor.size(0), length, dtype=torch.bool, device=tensor.device))

            num_zero_rows = target - length
            if num_zero_rows > 0:
                parts.append(torch.zeros(tensor.size(0), num_zero_rows, tensor.size(2),
                                         dtype=tensor.dtype, device=tensor.device))
                mask_parts.append(torch.zeros(tensor.size(0), num_zero_rows, dtype=torch.bool,
                                              device=tensor.device))
            start = end

        return torch.cat(parts, dim=1), torch.cat(mask_parts, dim=1)

    def round_up_to_nearest_k_and_a_window_size(self, lst, k):
        """Round each length up to a multiple of ``k`` and append one extra window of ``k``.

        Returns:
            tuple: ``(result_lst, pad_start_position)``. ``result_lst`` holds the padded
            lengths. ``pad_start_position`` holds two entries per graph ``i``:

            * even index: where graph ``i``'s zero padding starts once the ``i``
              preceding extra windows are removed;
            * odd index: where graph ``i``'s extra window starts in the padded layout.
        """
        result_lst = [(x + k - 1) // k * k + k for x in lst]
        pad_start_position = []
        offset = 0  # sum(result_lst[:i])
        for i, (length, padded) in enumerate(zip(lst, result_lst)):
            pad_start_position.append(offset - i * k + length)
            offset += padded
            pad_start_position.append(offset - k)
        return result_lst, pad_start_position

    def gumbel_softmax_top_k(self, logits, top_k, hard=False):
        """Gumbel-softmax relaxation along the last dim, with an optional hard top-k.

        With ``hard=True`` the forward value is a 0/1 mask of the ``top_k`` largest
        entries of the Gumbel-perturbed sample. Gradients flow through the soft sample
        (straight-through estimator).
        """
        gumbels = -torch.empty_like(logits).exponential_().log()
        y_soft = F.softmax((logits + gumbels) / self.temperature, dim=-1)
        if not hard:
            return y_soft
        # Discretise the perturbed sample itself (not the raw logits), so that the hard
        # forward value corresponds to the soft sample used for the gradient.
        topk_indices = y_soft.topk(top_k, dim=-1)[1]
        y_hard = torch.zeros_like(logits).scatter_(-1, topk_indices, 1.0)
        return (y_hard - y_soft).detach() + y_soft

    def displace_tensor_blocks_to_rectangle(self, tensor, displacement):
        """Lay blocks along a shifted diagonal of a rectangle.

        Block ``i`` of ``tensor`` (shape ``(B, num_blocks, h, w)``) is written at rows
        ``i*displacement`` and columns ``i*displacement`` of a zero tensor of shape
        ``(B, num_blocks*displacement, (num_blocks + 2)*displacement)``.
        """
        batch_size, num_blocks, block_height, block_width = tensor.shape
        height = num_blocks * displacement
        width = (2 + num_blocks) * displacement
        new_tensor = torch.zeros(batch_size, height, width, device=tensor.device, dtype=tensor.dtype)

        for i in range(num_blocks):
            offset = i * displacement
            new_tensor[:, offset:offset + block_height, offset:offset + block_width] = tensor[:, i, :, :]

        return new_tensor

    def forward(self, graph, node_features):
        """Return the rewired dense adjacency for ``node_features``.

        Parameters:
            graph: batched graph; only ``graph.num_nodes`` (per-graph node counts) is read
            node_features (Tensor): per-relation node features of shape
                ``(num_relation, num_node, in_features)``

        Returns:
            Tensor: shape ``(num_node, num_relation * num_node)``, relation-major columns.
        """
        ws = self.window_size
        num_relation = node_features.size(0)
        index = graph.num_nodes.tolist()

        # Pad every graph to a multiple of the window size, plus one extra window, so
        # that no window straddles two graphs.
        target_input, pad_start_position = self.round_up_to_nearest_k_and_a_window_size(index, ws)
        padding_input, mask = self.insert_zero_rows(node_features, index, target_input)
        num_padded = padding_input.size(1)

        # (R, L, H*D) -> (R, H, L, D) -> (R*H, L, D)
        Q = self.query(padding_input).view(num_relation, num_padded, self.num_heads, self.output_dim).permute(0, 2, 1, 3)
        K = self.key(padding_input).view(num_relation, num_padded, self.num_heads, self.output_dim).permute(0, 2, 1, 3)
        Q = Q.reshape(num_relation * self.num_heads, num_padded, self.output_dim)
        K = K.reshape(num_relation * self.num_heads, num_padded, self.output_dim)

        local_attention = self.LocalAttention(
            dim = self.output_dim,
            window_size = ws,
            look_backward = 1,   # each window also sees the previous window
            look_forward = 1,    # ... and the next one, hence 3 * ws keys per query
            pad_start_position = pad_start_position
        )

        # Average over heads: (R*H, W, ws, 3*ws) -> (R, W, ws, 3*ws).
        attn = local_attention(Q, K, mask = mask).view(num_relation, self.num_heads, -1, ws, 3 * ws).mean(dim=1)
        score = self.gumbel_softmax_top_k(attn, self.k, hard=True)

        # Place each window's (ws, 3*ws) block along the diagonal, then drop the leading
        # and trailing window of columns that only existed as look-around padding.
        result_tensor = self.displace_tensor_blocks_to_rectangle(score, ws)
        result_tensor = result_tensor[:, :, ws:-ws]

        # Remove padded (non-node) rows and columns: each graph's zero padding runs
        # from its pad start up to the next multiple of the window size.
        padded_positions = set()
        for pad_start in pad_start_position[0::2]:
            pad_end = (pad_start + ws - 1) // ws * ws
            padded_positions.update(range(pad_start, pad_end))
        remaining_indices = [idx for idx in range(result_tensor.size(1)) if idx not in padded_positions]

        result_tensor = result_tensor[:, remaining_indices, :]
        result_tensor = result_tensor[:, :, remaining_indices]

        # (R, N, N) -> (N, R, N) -> (N, R*N)
        return result_tensor.permute(1, 0, 2).contiguous().view(result_tensor.size(1), result_tensor.size(0) * result_tensor.size(2))
        


In [67]:
# Score the per-relation features of the demo batch with windowed attention and time it.
input_dim = relational_output.shape[-1]
output_dim = 256
num_heads = 8
window_size = 10
k = 3
# NOTE: relational_output is (num_relation, num_node, dim), so this is the relation
# count, not the node count; kept only so the notebook namespace stays unchanged.
num_nodes = relational_output.size(0)

model = Rewirescorelayer(input_dim, output_dim, num_heads, window_size, k)
# CUDA kernels run asynchronously, so synchronize around the call; otherwise only
# the launch time would be measured.
if torch.cuda.is_available():
    torch.cuda.synchronize()
start = time.perf_counter()
attn_output = model(graph, relational_output)
if torch.cuda.is_available():
    torch.cuda.synchronize()
end = time.perf_counter()
print(f"运行时间: {end - start:.6f} 秒")
print(attn_output.shape)

运行时间: 0.015247 秒
torch.Size([1430, 10010])


# Diffusion

In [69]:
class RewireGearnet(nn.Module):
    """GearNet-style relational message passing over the original or a rewired graph.

    Per-relation neighbour sums are concatenated per node into an update of shape
    ``(num_node, num_relation * input_dim)``. The update is mixed by ``self.linear``,
    and a self-loop term is added.

    Parameters:
        input_dim (int): input node feature dimension
        output_dim (int): output node feature dimension
        num_relation (int): number of relations
        edge_input_dim (int, optional): edge feature dimension; edge features are only
            supported together with the sparse edge list of ``graph``
        batch_norm (bool, optional): apply batch normalization on the output
        activation (str or callable, optional): activation function
    """
    gradient_checkpoint = False

    def __init__(self, input_dim, output_dim, num_relation, edge_input_dim=None, batch_norm=False, activation="relu"):
        super(RewireGearnet, self).__init__()
        self.num_relation = num_relation
        self.input_dim = input_dim
        self.output_dim = output_dim

        self.linear = nn.Linear(num_relation * input_dim, output_dim)
        self.self_loop = nn.Linear(input_dim, output_dim)
        self.batch_norm = nn.BatchNorm1d(output_dim) if batch_norm else None
        if not activation:
            self.activation = None
        elif isinstance(activation, str):
            self.activation = getattr(F, activation)
        else:
            self.activation = activation
        # Edge features are added to the per-relation update *before* `self.linear`,
        # so they must be projected to `input_dim`.
        self.edge_linear = nn.Linear(edge_input_dim, input_dim) if edge_input_dim else None

    def message_and_aggregate(self, graph, input, new_edge_list):
        """Sum neighbour features per relation and concatenate them per node.

        Parameters:
            graph: batched graph (``num_relation``, ``num_node``, and, for the sparse
                path, ``edge_list``, ``edge_weight`` and optionally ``edge_feature``)
            input (Tensor): node features of shape ``(num_node, input_dim)``
            new_edge_list (Tensor or None): dense adjacency of shape
                ``(num_node, num_relation * num_node)`` with relation-major columns
                (column ``r * num_node + j``); None uses ``graph.edge_list``

        Returns:
            Tensor: shape ``(num_node, num_relation * input_dim)``.

        Raises:
            ValueError: if edge features are configured but a dense adjacency is given.
        """
        assert graph.num_relation == self.num_relation
        device = input.device
        num_node = graph.num_node

        if new_edge_list is None:
            node_in, node_out, relation = graph.edge_list.t().to(device)
            # Node-major target index (j * R + r), so rows group per node below.
            node_out = node_out * self.num_relation + relation
            edge_weight = graph.edge_weight.to(device)
            adjacency = torch.sparse_coo_tensor(
                torch.stack([node_in, node_out]),
                edge_weight,
                (num_node, num_node * self.num_relation),
                device=device
            )
            update = torch.sparse.mm(adjacency.t(), input)

            if self.edge_linear is not None:
                edge_input = self.edge_linear(graph.edge_feature.float().to(device))
                edge_update = scatter_add(
                    edge_input * edge_weight.unsqueeze(-1), node_out, dim=0,
                    dim_size=num_node * self.num_relation
                )
                update = update + edge_update
        else:
            if self.edge_linear is not None:
                raise ValueError("edge features are only supported with the graph's own edge list "
                                 "(new_edge_list=None)")
            # (R*N, N) @ (N, d) -> (R*N, d) with row r*N + j (relation-major) ...
            update = torch.mm(new_edge_list.t().to(device), input)
            # ... reordered to node-major (row j*R + r) to match the sparse path.
            update = update.view(self.num_relation, num_node, -1).transpose(0, 1)

        return update.reshape(num_node, self.num_relation * self.input_dim)

    def combine(self, input, update):
        """Mix relations with ``self.linear``, add the self loop, then norm and activate."""
        device = input.device
        # Modules are moved lazily so the layer also works when it was never moved to
        # the input's device explicitly.
        self.linear.to(device)
        self.self_loop.to(device)
        if self.batch_norm is not None:
            self.batch_norm.to(device)

        output = self.linear(update) + self.self_loop(input)
        if self.batch_norm is not None:
            output = self.batch_norm(output)
        if self.activation:
            output = self.activation(output)
        return output

    def forward(self, graph, input, new_edge_list=None):
        """Return updated node features of shape ``(num_node, output_dim)``."""
        if self.gradient_checkpoint:
            # Forward the adjacency too; tensors are passed explicitly so they are
            # checkpoint inputs, while the graph object is bound in the closure.
            update = checkpoint.checkpoint(
                lambda x, edges: self.message_and_aggregate(graph, x, edges), input, new_edge_list)
        else:
            update = self.message_and_aggregate(graph, input, new_edge_list)
        return self.combine(input, update)

In [70]:
# Propagate the raw node features once over the rewired adjacency `attn_output`.
num_relations = graph.num_relation
input_dim = graph.node_feature.shape[-1]
output_dim = 512
# Re-selects the device, this time falling back to CPU when CUDA is unavailable.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


new_node_feature = RewireGearnet(input_dim, output_dim, num_relations)(
    graph,
    graph.node_feature.to(device).float(),
    attn_output,
).to(device)

print(new_node_feature.shape, new_node_feature, sep="\n")

torch.Size([1430, 512])
tensor([[0.0668, 0.2442, 0.0000,  ..., 0.1006, 0.0000, 0.1292],
        [0.0636, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.6651, 0.1436, 0.6689,  ..., 0.0000, 0.0000, 0.0905],
        ...,
        [0.0514, 0.3041, 0.0401,  ..., 0.0000, 0.0000, 0.1909],
        [0.0000, 0.0000, 0.2475,  ..., 0.3640, 0.0000, 0.3992],
        [0.0122, 0.1713, 0.0000,  ..., 0.0000, 0.0000, 0.0040]],
       device='cuda:0', grad_fn=<ReluBackward0>)


# Final model

In [114]:
# Wrapper that times each call of a layer.
def time_layer(layer, layer_name):
    """Wrap ``layer`` so that every call prints its wall-clock duration.

    CUDA kernels are launched asynchronously. When CUDA is available, the device is
    therefore synchronized before and after the call, so that the interval covers the
    actual GPU work and not just the launches.

    Parameters:
        layer (callable): module or function to time
        layer_name (str): label printed in front of the duration

    Returns:
        callable: wrapper with the same call signature, returning ``layer``'s output
    """
    def timed_layer(*args, **kwargs):
        use_cuda = torch.cuda.is_available()
        if use_cuda:
            torch.cuda.synchronize()
        start_time = time.perf_counter()
        output = layer(*args, **kwargs)
        if use_cuda:
            torch.cuda.synchronize()
        end_time = time.perf_counter()
        print(f'{layer_name}: {end_time - start_time:.6f} seconds')
        return output
    return timed_layer

In [117]:
class DGMGearnet(nn.Module, core.Configurable):
    """GearNet-style encoder whose graph is rewired at every layer.

    Layer ``i`` runs three steps:

    1. ``score_layers[2*i]`` (relationalGraph) computes per-relation node features.
    2. ``score_layers[2*i+1]`` (Rewirescorelayer) turns those features into a dense
       rewired adjacency of shape ``(num_node, num_relation * num_node)``.
    3. ``layers[i]`` (RewireGearnet) propagates node features over that adjacency.

    The adjacency produced by layer ``i`` is fed to the score layer of layer ``i + 1``.

    Parameters:
        input_dim (int): input node feature dimension
        hidden_dims (int or list of int): hidden dimension of each layer
        score_dim (int): feature dimension of the relational scoring layers
        num_relation (int): number of relations
        num_heads (int): attention heads in the rewiring layers
        window_size (int): attention window size in the rewiring layers
        k (int): entries kept per query row in the rewiring layers
        edge_input_dim (int, optional): edge feature dimension (line-graph branch)
        num_angle_bin (int, optional): number of angle bins; enables edge message passing
        short_cut (bool, optional): residual connection when shapes match
        batch_norm (bool, optional): batch norm after every layer
        activation (str or callable, optional): activation of the node and edge layers
        concat_hidden (bool, optional): concatenate all hidden states as node features
        readout (str, optional): ``"sum"`` or ``"mean"``
    """

    def __init__(self, input_dim, hidden_dims, score_dim, num_relation, num_heads, window_size, k, edge_input_dim=None, num_angle_bin=None,
                 short_cut=False, batch_norm=False, activation="relu", concat_hidden=True, readout="sum"):
        super(DGMGearnet, self).__init__()
        from collections.abc import Sequence

        # Accept a single int as a one-layer configuration.
        if not isinstance(hidden_dims, Sequence):
            hidden_dims = [hidden_dims]
        self.input_dim = input_dim
        self.output_dim = sum(hidden_dims) if concat_hidden else hidden_dims[-1]
        self.dims = [input_dim] + list(hidden_dims)
        self.score_dim = score_dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.k = k
        self.edge_dims = [edge_input_dim] + self.dims[:-1]
        self.num_relation = num_relation
        self.num_angle_bin = num_angle_bin
        self.short_cut = short_cut
        self.concat_hidden = concat_hidden
        self.batch_norm = batch_norm

        self.layers = nn.ModuleList()
        # Interleaved: [relationalGraph_0, Rewirescorelayer_0, relationalGraph_1, ...]
        self.score_layers = nn.ModuleList()
        for i in range(len(self.dims) - 1):
            self.score_layers.append(relationalGraph(self.dims[i], self.score_dim, num_relation,
                                                     edge_input_dim=None, batch_norm=False, activation="relu"))
            self.score_layers.append(Rewirescorelayer(self.score_dim, self.dims[i + 1], self.num_heads, self.window_size,
                                                      self.k, temperature=0.5))
            self.layers.append(RewireGearnet(self.dims[i], self.dims[i + 1], num_relation,
                                             edge_input_dim=None, batch_norm=False, activation=activation))

        if num_angle_bin:
            self.spatial_line_graph = layers.SpatialLineGraph(num_angle_bin)
            self.edge_layers = nn.ModuleList()
            for i in range(len(self.edge_dims) - 1):
                self.edge_layers.append(layers.GeometricRelationalGraphConv(
                    self.edge_dims[i], self.edge_dims[i + 1], num_angle_bin, None, batch_norm, activation))

        if batch_norm:
            self.batch_norms = nn.ModuleList()
            for i in range(len(self.dims) - 1):
                self.batch_norms.append(nn.BatchNorm1d(self.dims[i + 1]))

        if readout == "sum":
            self.readout = layers.SumReadout()
        elif readout == "mean":
            self.readout = layers.MeanReadout()
        else:
            raise ValueError("Unknown readout `%s`" % readout)

    def forward(self, graph, input, edge_list=None, all_loss=None, metric=None):
        """Encode a batch of graphs.

        Parameters:
            graph: batched protein graph
            input (Tensor): node features of shape ``(num_node, input_dim)``
            edge_list (Tensor, optional): dense adjacency for the first score layer;
                None means the score layer uses ``graph``'s own edge list
            all_loss, metric: unused (presumably kept for the TorchDrug model interface)

        Returns:
            dict: ``graph_feature`` of shape ``(batch_size, output_dim)`` and
            ``node_feature`` of shape ``(num_node, output_dim)``
        """
        hiddens = []
        layer_input = input
        if self.num_angle_bin:
            line_graph = self.spatial_line_graph(graph)
            edge_input = line_graph.node_feature.float()

        for i in range(len(self.layers)):
            relational_output = time_layer(self.score_layers[2 * i], 'relational output')(graph, layer_input, edge_list)
            new_edge_list = time_layer(self.score_layers[2 * i + 1], 'new edge list')(graph, relational_output)

            hidden = time_layer(self.layers[i], 'hidden')(graph, layer_input, new_edge_list)

            if self.short_cut and hidden.shape == layer_input.shape:
                hidden = hidden + layer_input

            if self.num_angle_bin:
                edge_hidden = self.edge_layers[i](line_graph, edge_input)
                edge_weight = graph.edge_weight.unsqueeze(-1)
                # Edge messages are per edge of the original graph (scaled by
                # graph.edge_weight), so they are scattered with the original edge list.
                # The dense rewired adjacency has no per-edge rows. The node-major index
                # (j * R + r) matches the update layout of RewireGearnet.linear.
                node_out = graph.edge_list[:, 1] * self.num_relation + graph.edge_list[:, 2]
                update = scatter_add(edge_hidden * edge_weight, node_out, dim=0,
                                     dim_size=graph.num_node * self.num_relation)
                update = update.view(graph.num_node, self.num_relation * edge_hidden.shape[1])
                update = self.layers[i].linear(update)
                update = self.layers[i].activation(update)
                hidden = hidden + update
                edge_input = edge_hidden

            if self.batch_norm:
                hidden = self.batch_norms[i](hidden)

            hiddens.append(hidden)
            layer_input = hidden
            # The next score layer works on this layer's rewired adjacency.
            edge_list = new_edge_list

        if self.concat_hidden:
            node_feature = torch.cat(hiddens, dim=-1)
        else:
            node_feature = hiddens[-1]
        graph_feature = self.readout(graph, node_feature)

        return {
            "graph_feature": graph_feature,
            "node_feature": node_feature
        }

In [120]:
# Five rewiring layers of width 512 on top of the raw node features.
num_relations = graph.num_relation
input_dim = graph.node_feature.shape[-1]
hidden_dims = [512] * 5
score_dim = 128
num_heads = 8
window_size = 10
k = 5


output = DGMGearnet(
    input_dim, hidden_dims, score_dim, num_relations, num_heads, window_size, k
).to(device)(
    graph.to(device),
    graph.node_feature.to(device).float(),
)

relational output: 0.001817 seconds
new edge list: 0.012704 seconds
hidden: 0.000778 seconds
relational output: 0.003483 seconds
new edge list: 0.011794 seconds
hidden: 0.001010 seconds
relational output: 0.003436 seconds
new edge list: 0.011773 seconds
hidden: 0.001004 seconds
relational output: 0.003420 seconds
new edge list: 0.011732 seconds
hidden: 0.000989 seconds
relational output: 0.003418 seconds
new edge list: 0.011899 seconds
hidden: 0.001000 seconds


In [121]:
# Node-level embeddings (hidden states concatenated across layers).
print(output["node_feature"], output["node_feature"].shape, sep="\n")
print("\n")

# Graph-level embeddings after the readout, one row per protein in the batch.
print(output["graph_feature"], output["graph_feature"].shape, sep="\n")

tensor([[ 0.0000,  0.0000,  0.0000,  ...,  3.4151,  1.6128,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  ..., 12.9921, 14.8914,  0.0000],
        [ 0.2311,  0.0000,  0.0000,  ..., 18.6177,  0.0000,  0.0000],
        ...,
        [ 0.0000,  0.4006,  0.0000,  ...,  6.1682,  8.7172,  0.0000],
        [ 0.4184,  0.0718,  0.0000,  ...,  0.0000,  2.0799,  0.0000],
        [ 0.0785,  0.1888,  0.0000,  ...,  3.8867,  0.0000,  0.0000]],
       device='cuda:0', grad_fn=<CatBackward>)
torch.Size([1430, 2560])


tensor([[5.9159e+01, 2.7642e+01, 5.1323e+00,  ..., 4.5037e+03, 1.9291e+03,
         3.3170e+03],
        [6.2541e+01, 2.6804e+01, 5.2057e+00,  ..., 4.1339e+03, 1.8076e+03,
         3.7433e+03],
        [4.4461e+01, 2.2220e+01, 5.0708e+00,  ..., 3.4908e+03, 1.5562e+03,
         2.3470e+03],
        [3.4962e+01, 1.7107e+01, 4.3736e+00,  ..., 1.9345e+03, 9.1557e+02,
         1.6042e+03]], device='cuda:0', grad_fn=<ScatterAddBackward>)
torch.Size([4, 2560])
