In [3]:
from torchdrug import  layers, datasets,transforms,core
from torchdrug.core import Registry as R
from torchdrug.layers import geometry

import torch
from torchdrug import data

from torch_scatter import scatter_add
import torch.nn as nn
from torchdrug import utils
from torch.utils import checkpoint
import torch.nn.functional as F
import torch.optim as optim

import time

if torch.cuda.is_available() and torch.cuda.device_count() > 1:
    device = torch.device('cuda:1')
    torch.cuda.set_device(device)
else:
    raise ValueError("CUDA device 2 is not available")

# 数据集获取

In [4]:
EnzymeCommission = R.search("datasets.EnzymeCommission")
PV = R.search("transforms.ProteinView")
trans = PV(view = "residue")
dataset = EnzymeCommission("~/scratch/protein-datasets/", test_cutoff=0.95, 
                           atom_feature="full", bond_feature="full", verbose=1, transform = trans)

# 只保留alpha碳的简化格式
graph_construction_model = layers.GraphConstruction(node_layers=[geometry.AlphaCarbonNode()], 
                                                    edge_layers=[geometry.SpatialEdge(radius=10.0, min_distance=5),
                                                                 geometry.KNNEdge(k=10, min_distance=5),
                                                                 geometry.SequentialEdge(max_distance=2)
                                                                 ],
                                                    edge_feature="gearnet"
                                                    )

13:22:36   Extracting /home/xiaotong/scratch/protein-datasets/EnzymeCommission.zip to /home/xiaotong/scratch/protein-datasets


Loading /home/xiaotong/scratch/protein-datasets/EnzymeCommission/enzyme_commission.pkl.gz: 100%|██████████| 18716/18716 [00:50<00:00, 370.72it/s]


In [35]:
graphs = dataset[0:2]
graphs = [element["graph"] for element in graphs]
graphs = data.Protein.pack(graphs)
print("\n\n")
graph = graph_construction_model(graphs)
print(graph)




PackedProtein(batch_size=2, num_atoms=[185, 415], num_bonds=[3754, 8999], num_residues=[185, 415])


In [36]:
print(graph.num_nodes)
print(graph.batch_size)
print(graph.edge_list)


tensor([185, 415])
2
tensor([[ 95,  96,   5],
        [109, 110,   5],
        [108, 109,   5],
        ...,
        [438, 470,   0],
        [489, 470,   0],
        [493, 470,   0]])


# 关系卷积神经网络，获取多个不同的嵌入

In [37]:
class relationalGraph(layers.MessagePassingBase):
    
    def __init__(self, input_dim, output_dim, num_relation, edge_input_dim=None, batch_norm=False, activation="relu"):
        super(relationalGraph, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.num_relation = num_relation
        self.edge_input_dim = edge_input_dim

        if batch_norm:
            self.batch_norm = nn.BatchNorm1d(output_dim)
        else:
            self.batch_norm = None
        if isinstance(activation, str):
            self.activation = getattr(F, activation)
        else:
            self.activation = activation

        self.self_loop = nn.Linear(input_dim, output_dim)
        self.linear = nn.Linear(input_dim, output_dim)
        if edge_input_dim:
            self.edge_linear = nn.Linear(edge_input_dim, input_dim)
        else:
            self.edge_linear = None
            
    def trans(self, A, graph):
    
        Degree_inv_sqrt = torch.diag(torch.pow(torch.sum(A, dim=1), -0.5))
        A_norm = torch.mm(torch.mm(Degree_inv_sqrt, A), Degree_inv_sqrt)
        
        n_rel = graph.num_relation
        n = A_norm.size(0)
        n_rel = n_rel.item()  # 将 n_rel 从 Tensor 转换为 int
        assert n % n_rel == 0, "n must be divisible by n_rel"
        
        block_size = n // n_rel
        
        # 初始化一个张量来存储累加结果
        accumulated = torch.zeros_like(A_norm[:block_size])
        
        # 将后面的所有块累加到第一块
        for i in range(n_rel):
            accumulated += A_norm[i * block_size: (i + 1) * block_size]
        
        # 用累加后的第一块替换原始矩阵的第一块
        A_trans = accumulated
    
        return A_trans

    def message_and_aggregate(self, graph, input, new_edge_list):
        assert graph.num_relation == self.num_relation
        device = input.device  # Ensure device consistency
        
        if new_edge_list is None:
            node_in, node_out, relation = graph.edge_list.t().to(device)
            node_out = node_out * self.num_relation + relation
        
            degree_out = scatter_add(graph.edge_weight, node_out, dim_size=graph.num_node * graph.num_relation)
            edge_weight = graph.edge_weight / degree_out[node_out]
            adjacency = utils.sparse_coo_tensor(torch.stack([node_in, node_out]), edge_weight,
                                                (graph.num_node, graph.num_node * graph.num_relation))
            update = torch.sparse.mm(adjacency.t().to(device), input.to(device))
        else:
            adjacency = self.trans(new_edge_list, graph).to(device)
            update = torch.mm(adjacency.t().to(device), input.to(device))
        
        if self.edge_linear:
            edge_input = graph.edge_feature.float().to(device)
            edge_input = self.edge_linear(edge_input)
            edge_weight = edge_weight.unsqueeze(-1).to(device)
            edge_update = scatter_add(edge_input * edge_weight, node_out, dim=0,
                                      dim_size=graph.num_node * graph.num_relation)
            update += edge_update

        return update

    def combine(self, input, update):
        # 自环特征
        device = input.device
        self.linear.to(device)  # Ensure the linear layers are on the correct device
        self.self_loop.to(device)
        input = input.repeat(self.num_relation, 1).to(device)
        loop_update = self.self_loop(input).to(device)
        
        output = self.linear(update)+loop_update
        if self.batch_norm:
            output = self.batch_norm(output)
        if self.activation:
            output = self.activation(output)
        return output
    
    def forward(self, graph, input, new_edge_list=None):
        
        if self.gradient_checkpoint:
            update = checkpoint.checkpoint(self._message_and_aggregate, *graph.to_tensors(), input, new_edge_list)
        else:
            update = self.message_and_aggregate(graph, input, new_edge_list)
        output = self.combine(input, update).view(graph.num_relation, input.size(0), -1)
        return output

#### 测试

In [38]:
input_dim = graph.node_feature.shape[-1]
output_dim = 128
num_relations = graph.num_relation

input = graph.node_feature.float()
model = relationalGraph(input_dim, output_dim, num_relations)
relational_output = model(graph, input, new_edge_list = None)
print("output: ", relational_output.shape)
print("output: ", relational_output.grad_fn)



output:  torch.Size([7, 600, 128])
output:  <ViewBackward object at 0x7f523f7f9670>


# 重连接模块

### window self attention + gumble softmax

In [56]:
class Rewirescorelayer(nn.Module):
    def __init__(self, in_features, out_features, num_heads, window_size, k, temperature=0.5, dropout=0.1):
        super(Rewirescorelayer, self).__init__()
        self.num_heads = num_heads
        self.out_features = out_features
        self.window_size = window_size
        self.temperature = temperature
        self.dropout = nn.Dropout(dropout)
        self.k = k
        
        self.query = nn.Linear(in_features, out_features * num_heads)
        self.key = nn.Linear(in_features, out_features * num_heads)
        self.scale = 1 / (out_features ** 0.5)

    def split_windows(self, tensor, index, window_size, device):
        result = []
        index_list = []
        start = 0

        for idx in index:
            end = start + idx - 1
            while start <= end:
                if start + window_size <= end:
                    result.append(tensor[:, start:start + window_size, :])
                    index_list.append([start, start + window_size])
                    start += window_size
                else:
                    padding_rows = window_size - (end - start + 1)
                    restart = start - padding_rows
                    result.append(tensor[:, restart:restart + window_size, :])
                    index_list.append([restart, restart + window_size])

                    start = end+1
        
        result_tensor = torch.stack(result, dim=1).to(device)
        return result_tensor, index_list

    def gumbel_softmax_top_k(self, logits, tau=1.0, hard=False):
        gumbels = -torch.empty_like(logits).exponential_().log()
        gumbels = (logits + gumbels) / tau

        y_soft = F.softmax(gumbels, dim=-1)

        if hard:
            topk_indices = logits.topk(self.k, dim=-1)[1]
            y_hard = torch.zeros_like(logits).scatter_(-1, topk_indices, 1.0)
            y = (y_hard - y_soft).detach() + y_soft
        else:
            y = y_soft

        return y

    def windows2adjacent(self, windows, index_list, output, device):
        # 确保所有张量在相同设备上
        output = output.to(device)
        windows = windows.to(device)

        # 创建一个新的张量来存储更新后的输出
        new_output = torch.zeros_like(output, device=device)

        # 填充新输出张量
        for i, index in enumerate(index_list):
            start, end = index
            new_output[:, start:end, start:end] = torch.clamp(new_output[:, start:end, start:end] + windows[:, i, :, :], 0, 1)

        # 获取输出张量的形状
        num_relations, num_nodes, _ = new_output.shape
        
        # 创建一个新的结果张量
        result = torch.zeros(num_relations * num_nodes, num_relations * num_nodes, device=device)

        # 填充结果张量
        for i in range(num_relations):
            result[i * num_nodes:(i + 1) * num_nodes, i * num_nodes:(i + 1) * num_nodes] = result[i * num_nodes:(i + 1) * num_nodes, i * num_nodes:(i + 1) * num_nodes] + new_output[i]

        return result

    def forward(self, graph, node_features):
        torch.cuda.synchronize()  # 同步
        start0 = time.time()
        device = node_features.device
        num_relations = node_features.size(0)
        num_nodes = node_features.size(1)
        index = graph.num_nodes.tolist()
        end0 = time.time()
        
        torch.cuda.synchronize()  # 同步
        start1 = time.time()
        Q = self.query(node_features).view(num_relations, num_nodes, self.num_heads, self.out_features).permute(0, 2, 1, 3)
        K = self.key(node_features).view(num_relations, num_nodes, self.num_heads, self.out_features).permute(0, 2, 1, 3)
        Q = Q.reshape(num_relations * self.num_heads, num_nodes, self.out_features).to(device)
        K = K.reshape(num_relations * self.num_heads, num_nodes, self.out_features).to(device)
        end1 = time.time()
        
        torch.cuda.synchronize()  # 同步
        output = torch.zeros(num_relations, num_nodes, num_nodes).to(device)
        result = torch.zeros(num_relations * num_nodes, num_relations * num_nodes).to(device)

        torch.cuda.synchronize()  # 同步
        start2 = time.time()
        Q_windows, Q_index = self.split_windows(Q, index, self.window_size, device)
        K_windows, _ = self.split_windows(K, index, self.window_size, device)
        end2 = time.time()  
        
        torch.cuda.synchronize()  # 同步
        start3 = time.time()
        scores = torch.einsum('b h i e, b h j e -> b h i j', Q_windows, K_windows) / self.scale                                 # (num_relations*num_heads, num_windows, window_size, window_size)
        attn = scores.softmax(dim=-1).view(num_relations, self.num_heads, -1, self.window_size, self.window_size).mean(dim=1)   # (num_relations, num_windows, window_size, window_size)
        end3 = time.time()
        
        torch.cuda.synchronize()  # 同步
        start4 = time.time()
        attn = self.gumbel_softmax_top_k(attn, tau=self.temperature, hard=True)                                                 # (num_relations, num_windows, window_size, window_size)
        end4 = time.time()
        
        torch.cuda.synchronize()  # 同步
        start5 = time.time()
        result = result + self.windows2adjacent(attn, Q_index, output, device)
        end5 = time.time()
        
        print(f"0运行时间: {end0 - start0:.6f} 秒")
        print(f"1运行时间: {end1 - start1:.6f} 秒")
        print(f"2运行时间: {end2 - start2:.6f} 秒")
        print(f"3运行时间: {end3 - start3:.6f} 秒")
        print(f"4运行时间: {end4 - start4:.6f} 秒")
        print(f"5运行时间: {end5 - start5:.6f} 秒")
        
        return attn


#### 测试

In [57]:
input_dim = relational_output.shape[-1]
output_dim = 256
num_heads = 8
window_size = 10
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
k = 3
num_nodes = relational_output.size(0)


relational_output = relational_output.to(device)
module = Rewirescorelayer(input_dim, output_dim, num_heads, window_size, k).to(device)
start = time.time()
attn_output = module(graph, relational_output)
end = time.time()
print(f"运行时间: {end - start:.6f} 秒")
print(attn_output.shape)



0运行时间: 0.000011 秒
1运行时间: 0.000460 秒
2运行时间: 0.001107 秒
3运行时间: 0.000232 秒
4运行时间: 0.000257 秒
5运行时间: 0.003100 秒
运行时间: 0.044931 秒
torch.Size([7, 61, 10, 10])


In [53]:
print(attn_output)
indices = torch.nonzero(attn_output, as_tuple=True)
print(indices)


tensor([[[[0., 0., 0.,  ..., 1., 0., 1.],
          [0., 0., 0.,  ..., 1., 0., 0.],
          [0., 0., 1.,  ..., 1., 0., 0.],
          ...,
          [0., 1., 0.,  ..., 1., 0., 0.],
          [0., 0., 0.,  ..., 1., 0., 0.],
          [0., 0., 1.,  ..., 1., 0., 0.]],

         [[0., 1., 0.,  ..., 1., 0., 1.],
          [0., 1., 0.,  ..., 0., 0., 1.],
          [1., 0., 0.,  ..., 1., 0., 0.],
          ...,
          [1., 0., 0.,  ..., 0., 0., 1.],
          [1., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 1., 0., 0.]],

         [[0., 0., 1.,  ..., 1., 1., 0.],
          [0., 0., 1.,  ..., 0., 0., 0.],
          [1., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [1., 0., 1.,  ..., 0., 0., 0.],
          [0., 0., 1.,  ..., 0., 0., 0.],
          [0., 0., 1.,  ..., 1., 0., 0.]],

         ...,

         [[0., 0., 0.,  ..., 1., 1., 0.],
          [0., 0., 0.,  ..., 0., 1., 0.],
          [0., 1., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 1., 

### 测试不同degree进行采样

In [10]:
import torch
import torch.nn.functional as F

def gumbel_softmax_sample(logits, tau):
    # 从Gumbel(0, 1)分布中采样
    gumbel_noise = -torch.log(-torch.log(torch.rand_like(logits)))
    y = logits + gumbel_noise
    return F.softmax(y / tau, dim=-1)

def bernoulli_sampling_with_different_thresholds(probs, thresholds, tau=1.0):
    # 对数概率
    logits = torch.log(probs) - torch.log(1 - probs)
    # 进行Gumbel-Softmax采样
    y = gumbel_softmax_sample(logits, tau)
    # 硬化处理，根据每行的不同阈值
    z = (y > thresholds.unsqueeze(1)).float()
    return z

# 示例矩阵
n = 10
P = torch.rand(n, n)

# 为每一行设置不同的阈值
thresholds = torch.tensor([0.2, 0.5, 0.7, 0.9, 0.1, 0.3, 0.6, 0.8, 0.4, 0.2])

# 进行可微分伯努利采样并硬化处理
tau = 0.1  # 温度参数
sampled_matrix = bernoulli_sampling_with_different_thresholds(P, thresholds, tau)

#print("Probability Matrix:\n", P)
print("Thresholds:\n", thresholds)
print("Sampled Matrix:\n", sampled_matrix)


Thresholds:
 tensor([0.2000, 0.5000, 0.7000, 0.9000, 0.1000, 0.3000, 0.6000, 0.8000, 0.4000,
        0.2000])
Sampled Matrix:
 tensor([[1., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0., 0.]])


# Diffusion模块

### Rewired_gearnet 用于diffusion模块

In [18]:
class RewireGearnet(nn.Module):
    gradient_checkpoint = False

    def __init__(self, input_dim, output_dim, num_relation, edge_input_dim=None, batch_norm=False, activation="relu"):
        super(RewireGearnet, self).__init__()
        self.num_relation = num_relation
        self.input_dim = input_dim
        self.output_dim = output_dim

        self.linear = nn.Linear(num_relation * input_dim, output_dim)
        self.self_loop = nn.Linear(input_dim, output_dim)
        self.batch_norm = nn.BatchNorm1d(output_dim) if batch_norm else None
        self.activation = getattr(F, activation) if activation else None
        self.edge_linear = nn.Linear(edge_input_dim, output_dim) if edge_input_dim else None

    def trans(self, A, graph):
        n_rel = graph.num_relation
        n = A.size(0)
        n_rel = n_rel.item()  # 将 n_rel 从 Tensor 转换为 int
        assert n % n_rel == 0, "n must be divisible by n_rel"
        
        block_size = n // n_rel
        
        # 初始化一个张量来存储累加结果
        accumulated = torch.zeros_like(A[:block_size])
        
        # 将后面的所有块累加到第一块
        for i in range(n_rel):
            accumulated += A[i * block_size: (i + 1) * block_size]
        
        # 用累加后的第一块替换原始矩阵的第一块
        A_trans = accumulated
    
        return A_trans

    def message_and_aggregate(self, graph, input, new_edge_list):
        assert graph.num_relation == self.num_relation

        device = input.device  # Ensure device consistency

        if new_edge_list is None:
            node_in, node_out, relation = graph.edge_list.t().to(device)
            node_out = node_out * self.num_relation + relation
            adjacency = torch.sparse_coo_tensor(
                torch.stack([node_in, node_out]),
                graph.edge_weight.to(device),
                (graph.num_node, graph.num_node * graph.num_relation),
                device=device
            )
            update = torch.sparse.mm(adjacency.t(), input)
        else:
            adjacency = self.trans(new_edge_list, graph).to(device)
            update = torch.mm(adjacency.t(), input.to(device))
        
        if self.edge_linear:
            edge_input = graph.edge_feature.float().to(device)
            edge_input = self.edge_linear(edge_input)
            edge_weight = graph.edge_weight.unsqueeze(-1).to(device)
            edge_update = scatter_add(
                edge_input * edge_weight, node_out, dim=0,
                dim_size=graph.num_node * graph.num_relation
            )
            update += edge_update
            
        return update.view(graph.num_node, self.num_relation * self.input_dim).to(device)

    def combine(self, input, update):
        device = input.device
        self.linear.to(device)  # Ensure the linear layers are on the correct device
        self.self_loop.to(device)
        if self.batch_norm:
            self.batch_norm.to(device)
        
        output = self.linear(update) + self.self_loop(input)
        if self.batch_norm:
            output = self.batch_norm(output)
        if self.activation:
            output = self.activation(output)
        return output

    def forward(self, graph, input, new_edge_list=None):
        """
        Perform message passing over the graph(s).

        Parameters:
            graph (Graph): graph(s)
            input (Tensor): node representations of shape :math:`(|V|, ...)`
        """
        if self.gradient_checkpoint:
            update = checkpoint.checkpoint(self.message_and_aggregate, graph, input)
        else:
            update = self.message_and_aggregate(graph, input, new_edge_list)
        output = self.combine(input, update)
        return output

#### 测试

In [19]:
input_dim = graph.node_feature.shape[-1]
output_dim = 512
num_relations = graph.num_relation
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


new_node_feature = RewireGearnet(input_dim, output_dim, num_relations)(graph, graph.node_feature.to(device).float(), attn_output).to(device)

print(new_node_feature.shape)
print(new_node_feature)

torch.Size([1165, 512])
tensor([[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.1327, 0.0000],
        [0.0000, 0.0320, 0.0000,  ..., 0.0000, 0.0880, 0.1370],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.3040, 0.2115],
        ...,
        [0.0000, 0.0000, 0.0793,  ..., 0.2069, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.4240],
        [0.0000, 0.3808, 0.3437,  ..., 0.0000, 0.0000, 0.1364]],
       device='cuda:1', grad_fn=<ReluBackward0>)


# 最终模型

In [21]:
# 定义一个装饰器来计时
def time_layer(layer, layer_name):
    def timed_layer(*args, **kwargs):
        start_time = time.time()
        output = layer(*args, **kwargs)
        end_time = time.time()
        print(f'{layer_name}: {end_time - start_time:.6f} seconds')
        return output
    return timed_layer

In [22]:
class DGMGearnet(nn.Module, core.Configurable):

    def __init__(self, input_dim, hidden_dims, score_dim, num_relation, num_heads, window_size, k, edge_input_dim=None, num_angle_bin=None,
                 short_cut=False, batch_norm=False, activation="relu", concat_hidden=True, readout="sum"):
        super(DGMGearnet, self).__init__()

        #if not isinstance(hidden_dims, Sequence):
            #hidden_dims = [hidden_dims]
        self.input_dim = input_dim
        self.output_dim = sum(hidden_dims) if concat_hidden else hidden_dims[-1]
        self.dims = [input_dim] + list(hidden_dims)
        self.score_dim = score_dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.k = k
        self.edge_dims = [edge_input_dim] + self.dims[:-1]
        self.num_relation = num_relation
        self.num_angle_bin = num_angle_bin
        self.short_cut = short_cut
        self.concat_hidden = concat_hidden
        self.batch_norm = batch_norm

        self.layers = nn.ModuleList()
        self.score_layers = nn.ModuleList()
        for i in range(len(self.dims) - 1):
            
            self.score_layers.append(relationalGraph(self.dims[i], self.score_dim, num_relation, 
                                                     edge_input_dim=None, batch_norm=False, activation="relu")) 

            
            self.score_layers.append(Rewirescorelayer(self.score_dim, self.dims[i+1], self.num_heads, self.window_size, 
                                            self.k, temperature=0.5, dropout=0.1))
            
            self.layers.append(RewireGearnet(self.dims[i], self.dims[i + 1], num_relation,
                                            edge_input_dim=None, batch_norm=False, activation="relu"))
        
        if num_angle_bin:
            self.spatial_line_graph = layers.SpatialLineGraph(num_angle_bin)
            self.edge_layers = nn.ModuleList()
            for i in range(len(self.edge_dims) - 1):
                self.edge_layers.append(layers.GeometricRelationalGraphConv(
                    self.edge_dims[i], self.edge_dims[i + 1], num_angle_bin, None, batch_norm, activation))

        if batch_norm:
            self.batch_norms = nn.ModuleList()
            for i in range(len(self.dims) - 1):
                self.batch_norms.append(nn.BatchNorm1d(self.dims[i + 1]))

        if readout == "sum":
            self.readout = layers.SumReadout()
        elif readout == "mean":
            self.readout = layers.MeanReadout()
        else:
            raise ValueError("Unknown readout `%s`" % readout)

    def forward(self, graph, input, edge_list=None, all_loss=None, metric=None):
        """
        Compute the node representations and the graph representation(s).

        Parameters:
            graph (Graph): :math:`n` graph(s)
            input (Tensor): input node representations
            all_loss (Tensor, optional): if specified, add loss to this tensor
            metric (dict, optional): if specified, output metrics to this dict

        Returns:
            dict with ``node_feature`` and ``graph_feature`` fields:
                node representations of shape :math:`(|V|, d)`, graph representations of shape :math:`(n, d)`
        """
        hiddens = []
        layer_input = input
        if self.num_angle_bin:
            line_graph = self.spatial_line_graph(graph)
            edge_input = line_graph.node_feature.float()

        for i in range(len(self.layers)):
            
            
            relational_output = time_layer(self.score_layers[2*i], 'relational output')(graph, layer_input, edge_list)
            new_edge_list = time_layer(self.score_layers[2*i+1], 'new edge list')(graph, relational_output)
            
            hidden = time_layer(self.layers[i], 'hidden')(graph, layer_input, new_edge_list)
            
            if self.short_cut and hidden.shape == layer_input.shape:
                hidden = hidden + layer_input
            
            if self.num_angle_bin:
                edge_hidden = self.edge_layers[i](line_graph, edge_input)
                edge_weight = graph.edge_weight.unsqueeze(-1)
                if new_edge_list is None:
                    node_out = graph.edge_list[:, 1] * self.num_relation + graph.edge_list[:, 2]
                else:
                    node_out = new_edge_list[:, 1] * self.num_relation + new_edge_list[:, 2]
                
                    update = scatter_add(edge_hidden * edge_weight, node_out, dim=0,
                                        dim_size=graph.num_node * self.num_relation)
                update = update.view(graph.num_node, self.num_relation * edge_hidden.shape[1])
                update = self.layers[i].linear(update)
                update = self.layers[i].activation(update)
                hidden = hidden + update
                edge_input = edge_hidden
                
            if self.batch_norm:
                hidden = self.batch_norms[i](hidden)
                
            hiddens.append(hidden)
            layer_input = hidden
            edge_list = new_edge_list

        if self.concat_hidden:
            node_feature = torch.cat(hiddens, dim=-1)
        else:
            node_feature = hiddens[-1]
        graph_feature = self.readout(graph, node_feature)

        return {
            "graph_feature": graph_feature,
            "node_feature": node_feature
        }

### 测试

In [24]:
input_dim = graph.node_feature.shape[-1]
hidden_dims = [512, 512, 512, 512, 512]
score_dim = 128
num_relations = graph.num_relation
num_heads = 8
window_size = 10
k = 5


output = DGMGearnet(input_dim, hidden_dims, score_dim, num_relations, num_heads, window_size, k).to(device)(graph.to(device), graph.node_feature.to(device).float())

relational output: 0.001634 seconds
new edge list: 0.131950 seconds
hidden: 0.003069 seconds
relational output: 0.044167 seconds
new edge list: 0.135501 seconds
hidden: 0.003238 seconds
relational output: 0.043538 seconds
new edge list: 0.131423 seconds
hidden: 0.003260 seconds
relational output: 0.042964 seconds
new edge list: 0.130124 seconds
hidden: 0.003250 seconds
relational output: 0.042751 seconds
new edge list: 0.147749 seconds
hidden: 0.003250 seconds


In [15]:
print(output["node_feature"])
print(output["node_feature"].shape)
print("\n")

print(output["graph_feature"])
print(output["graph_feature"].shape)

tensor([[ 0.5056,  0.0000,  0.0000,  ...,  0.0000, 12.6781,  0.0000],
        [ 0.1331,  0.0000,  0.0000,  ...,  0.0000,  1.6680,  3.1325],
        [ 0.4216,  0.0000,  0.0000,  ...,  0.0000,  4.0469,  6.2530],
        ...,
        [ 0.4255,  0.0000,  0.0000,  ...,  0.0000, 12.3883,  0.0000],
        [ 0.6212,  0.0000,  0.0000,  ...,  0.0000,  5.0837,  1.6823],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  7.2654,  0.0000]],
       device='cuda:0', grad_fn=<CatBackward>)
torch.Size([572, 2560])


tensor([[  60.2262,    4.7113,   13.5840,  ...,   23.6377, 2026.7983,
          116.4902],
        [  66.2441,    3.1500,   15.5229,  ...,   13.4601, 2375.6035,
          124.7936]], device='cuda:0', grad_fn=<ScatterAddBackward>)
torch.Size([2, 2560])


### 验证反向传播

In [16]:
model = DGMGearnet(input_dim, hidden_dims, score_dim, num_relations, num_heads, window_size, k).to(device)
optimizer = optim.SGD(model.parameters(), lr=0.01)

output = model(graph.to(device), graph.node_feature.to(device).float())

output = output["graph_feature"]

# 定义一个简单的损失函数，例如均方误差
criterion = F.binary_cross_entropy_with_logits
target = torch.rand_like(output)

# 计算损失
loss = criterion(output, target, reduction="mean")
print(loss)
# 反向传播
loss.backward()

# 检查梯度是否为 NaN 或 inf
for name, param in model.named_parameters():
    if param.grad is not None:
        if torch.isnan(param.grad).any():
            print(f"Gradient of {name} contains NaN values.")
        if torch.isinf(param.grad).any():
            print(f"Gradient of {name} contains inf values.")
        else:
            print(f"Layer {name} - Gradient_norm: {torch.norm(param.grad)}")
            
    else:
        print(f"No gradient found for {name}")

# 检查损失是否为 NaN
if torch.isnan(loss):
    print("Loss is NaN.")
        
# 更新参数
optimizer.step()

[[0, 10], [10, 20], [20, 30], [30, 40], [40, 50], [50, 60], [60, 70], [70, 80], [80, 90], [90, 100], [100, 110], [110, 120], [120, 130], [130, 140], [140, 150], [150, 160], [160, 170], [170, 180], [180, 190], [190, 200], [200, 210], [210, 220], [220, 230], [230, 240], [240, 250], [250, 260], [254, 264], [264, 274], [274, 284], [284, 294], [294, 304], [304, 314], [314, 324], [324, 334], [334, 344], [344, 354], [354, 364], [364, 374], [374, 384], [384, 394], [394, 404], [404, 414], [414, 424], [424, 434], [434, 444], [444, 454], [454, 464], [464, 474], [474, 484], [484, 494], [494, 504], [504, 514], [514, 524], [524, 534], [534, 544], [544, 554], [554, 564], [562, 572]]
[[0, 10], [10, 20], [20, 30], [30, 40], [40, 50], [50, 60], [60, 70], [70, 80], [80, 90], [90, 100], [100, 110], [110, 120], [120, 130], [130, 140], [140, 150], [150, 160], [160, 170], [170, 180], [180, 190], [190, 200], [200, 210], [210, 220], [220, 230], [230, 240], [240, 250], [250, 260], [254, 264], [264, 274], [274, 