In [1]:
from torchdrug import  layers, datasets,transforms,core, models
from torchdrug.core import Registry as R
from torchdrug.layers import geometry

import torch
from torchdrug import data

from torch_scatter import scatter_add
import torch.nn as nn
from torchdrug import utils
from torch.utils import checkpoint
import torch.nn.functional as F
import torch.optim as optim
from einops import rearrange, repeat, pack, unpack
from collections.abc import Sequence

import time

# Select the first CUDA device for the whole notebook.
# Only `cuda:0` is used, so a single visible GPU is sufficient: require at
# least one device (the previous `> 1` check rejected single-GPU machines
# while reporting that device 0 was unavailable).
if torch.cuda.is_available() and torch.cuda.device_count() > 0:
    device = torch.device('cuda:0')
    torch.cuda.set_device(device)
else:
    raise ValueError("CUDA device 0 is not available")

In [2]:
import matplotlib.pyplot as plt

# Get datasets

In [3]:
# Resolve the dataset and transform classes through the torchdrug registry.
EnzymeCommission = R.search("datasets.EnzymeCommission")
PV = R.search("transforms.ProteinView")

# View every protein at the residue level.
trans = PV(view="residue")
dataset = EnzymeCommission(
    "~/scratch/protein-datasets/",
    test_cutoff=0.95,
    atom_feature="full",
    bond_feature="full",
    verbose=1,
    transform=trans,
)

# Simplified representation that keeps only the alpha carbons.
node_layers = [geometry.AlphaCarbonNode()]
edge_layers = [
    geometry.SpatialEdge(radius=10.0, min_distance=5),
    geometry.KNNEdge(k=10, min_distance=5),
    geometry.SequentialEdge(max_distance=2),
]
graph_construction_model = layers.GraphConstruction(
    node_layers=node_layers,
    edge_layers=edge_layers,
    edge_feature="gearnet",
)

17:21:40   Extracting /home/xiaotong/scratch/protein-datasets/EnzymeCommission.zip to /home/xiaotong/scratch/protein-datasets


Loading /home/xiaotong/scratch/protein-datasets/EnzymeCommission/enzyme_commission.pkl.gz: 100%|██████████| 18716/18716 [00:45<00:00, 409.68it/s]


In [9]:
# Pack the "graph" entry of the first dataset item into a batch, then run the
# graph-construction model on it.
graphs = data.Protein.pack([element["graph"] for element in dataset[0:1]])
graph = graph_construction_model(graphs)
print(graph)

PackedProtein(batch_size=1, num_atoms=[185], num_bonds=[3754], num_residues=[185])


In [10]:

edge_list = graph.edge_list

# Show the per-graph node counts, then the edge list and its shape.
print(graph.num_nodes)
print(edge_list)
print(edge_list.shape)


tensor([185])
tensor([[ 90,  82,   1],
        [ 85,  78,   1],
        [ 83,  78,   1],
        ...,
        [ 75, 107,   0],
        [ 50, 108,   0],
        [ 51, 108,   0]])
torch.Size([3754, 3])


# Relational graph convolution

In [13]:
class relationalGraphConv(layers.MessagePassingBase):
    """Relational graph convolution that can also aggregate over a dense rewired adjacency.

    With ``new_edge_list=None`` messages are aggregated over ``graph.edge_list``
    using edge weights normalised by the weighted degree of each
    ``(node_out, relation)`` pair. Otherwise ``new_edge_list`` is interpreted as a
    dense ``(num_node, num_relation * num_node)`` matrix whose column index is
    ``relation * num_node + node`` (the layout :meth:`trans` unpacks); it is
    symmetrically normalised per relation and used as the adjacency.

    Parameters:
        input_dim (int): size of each input node feature
        output_dim (int): size of each output node feature
        num_relation (int): number of relations
        edge_input_dim (int, optional): size of the edge features; enables ``edge_linear``
        batch_norm (bool, optional): apply ``BatchNorm1d`` to the output
        activation (str or callable, optional): name of a ``torch.nn.functional``
            function, or a callable
    """

    eps = 1e-10

    def __init__(self, input_dim, output_dim, num_relation, edge_input_dim=None, batch_norm=False, activation="relu"):
        super(relationalGraphConv, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.num_relation = num_relation
        self.edge_input_dim = edge_input_dim

        if batch_norm:
            self.batch_norm = nn.BatchNorm1d(output_dim)
        else:
            self.batch_norm = None
        if isinstance(activation, str):
            self.activation = getattr(F, activation)
        else:
            self.activation = activation

        self.self_loop = nn.Linear(input_dim, output_dim)
        self.linear = nn.Linear(num_relation * input_dim, output_dim)
        if edge_input_dim:
            self.edge_linear = nn.Linear(edge_input_dim, input_dim)
        else:
            self.edge_linear = None

    def trans(self, A, graph):
        """Symmetrically normalise a dense multi-relational adjacency.

        ``A`` has shape ``(N, num_relation * N)`` with columns grouped by relation.
        Each ``(N, N)`` relation slice ``A_r`` is replaced by
        ``D_r^{-1/2} A_r D_r^{-1/2}`` where ``D_r`` holds the row sums of ``A_r``.
        Rows whose sum is zero are mapped to zero instead of producing inf/NaN.
        The result has the same shape and layout as ``A``.
        """
        n_rel = graph.num_relation
        A = A.view(A.size(0), n_rel, A.size(1) // n_rel).permute(1, 0, 2)  # (R, N, N)

        degree = A.sum(dim=2)  # row sums per relation, (R, N)
        # 0 ** -0.5 is inf and inf * 0 is NaN, so isolated rows are zeroed explicitly.
        degree_inv_sqrt = degree.pow(-0.5).masked_fill(degree == 0, 0.0)

        # Broadcasting is equivalent to diag(d) @ A_r @ diag(d) without building dense diagonals.
        A_norm = degree_inv_sqrt.unsqueeze(2) * A * degree_inv_sqrt.unsqueeze(1)
        return A_norm.permute(1, 0, 2).contiguous().view(A_norm.size(1), A_norm.size(0) * A_norm.size(2))

    def message_and_aggregate(self, graph, input, new_edge_list):
        """Aggregate neighbour features per relation.

        Returns a ``(num_node, num_relation * input_dim)`` tensor in which row ``n``
        holds the concatenation over relations of node ``n``'s aggregated messages.
        """
        assert graph.num_relation == self.num_relation
        device = input.device  # Ensure device consistency
        num_node = input.size(0)

        node_out = edge_weight = None
        # The edge-list quantities are needed for the default path and for edge features.
        if new_edge_list is None or self.edge_linear is not None:
            node_in, node_out, relation = graph.edge_list.t().to(device)
            node_out = node_out * self.num_relation + relation
            degree_out = scatter_add(graph.edge_weight, node_out, dim_size=graph.num_node * graph.num_relation)
            edge_weight = graph.edge_weight / degree_out[node_out]

        if new_edge_list is None:
            adjacency = utils.sparse_coo_tensor(torch.stack([node_in, node_out]), edge_weight,
                                                (graph.num_node, graph.num_node * graph.num_relation))
            # Rows are indexed node * num_relation + relation (node-major).
            update = torch.sparse.mm(adjacency.t().to(device), input.to(device))
        else:
            adjacency = self.trans(new_edge_list, graph).to(device)
            # Rows are indexed relation * num_node + node (relation-major) ...
            update = torch.mm(adjacency.t(), input.to(device))
            # ... so reorder to node-major before flattening relations into features.
            update = update.view(self.num_relation, num_node, -1).transpose(0, 1)
            update = update.reshape(num_node * self.num_relation, -1)

        if self.edge_linear is not None:
            # Edge features always come from the original graph edges.
            self.edge_linear.to(device)
            edge_input = graph.edge_feature.float().to(device)
            edge_input = self.edge_linear(edge_input)
            edge_update = scatter_add(edge_input * edge_weight.unsqueeze(-1).to(device), node_out, dim=0,
                                      dim_size=graph.num_node * graph.num_relation)
            update = update + edge_update

        return update.reshape(num_node, self.num_relation * self.input_dim)

    def combine(self, input, update):
        """Mix the aggregated messages with a self-loop, then apply batch norm and activation."""
        device = input.device  # use the input's device rather than a notebook-global
        self.linear.to(device)  # Ensure the linear layers are on the correct device
        self.self_loop.to(device)
        output = self.linear(update) + self.self_loop(input)
        if self.batch_norm is not None:
            self.batch_norm.to(device)
            output = self.batch_norm(output)
        if self.activation:
            output = self.activation(output)
        return output

    def forward(self, graph, input, new_edge_list=None):
        """Run one layer; ``new_edge_list`` optionally replaces the graph's edges (see class docstring)."""
        device = input.device
        graph = graph.to(device)
        if self.gradient_checkpoint:
            # Tensors are passed explicitly so checkpoint can track them; the graph is closed over.
            def run(node_input, edges):
                return self.message_and_aggregate(graph, node_input, edges)

            update = checkpoint.checkpoint(run, input, new_edge_list)
        else:
            update = self.message_and_aggregate(graph, input, new_edge_list)
        output = self.combine(input, update)
        return output

In [14]:
class relationalGraph(layers.MessagePassingBase):
    """Relational message passing that keeps one output row per (relation, node) pair.

    Unlike a relational convolution that concatenates relations into the feature
    axis, this layer returns a ``(num_relation * num_node, output_dim)`` tensor
    whose row index is ``relation * num_node + node``.

    With ``new_edge_list=None`` messages are aggregated over ``graph.edge_list``;
    otherwise ``new_edge_list`` is a dense ``(num_node, num_relation * num_node)``
    matrix with column index ``relation * num_node + node`` that is normalised by
    :meth:`trans` and used as the adjacency.

    Parameters:
        input_dim (int): size of each input node feature
        output_dim (int): size of each output feature
        num_relation (int): number of relations
        edge_input_dim (int, optional): size of the edge features; enables ``edge_linear``
        batch_norm (bool, optional): apply ``BatchNorm1d`` to the output
        activation (str or callable, optional): name of a ``torch.nn.functional``
            function, or a callable
    """

    def __init__(self, input_dim, output_dim, num_relation, edge_input_dim=None, batch_norm=False, activation="relu"):
        super(relationalGraph, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.num_relation = num_relation
        self.edge_input_dim = edge_input_dim

        if batch_norm:
            self.batch_norm = nn.BatchNorm1d(output_dim)
        else:
            self.batch_norm = None
        if isinstance(activation, str):
            self.activation = getattr(F, activation)
        else:
            self.activation = activation

        self.self_loop = nn.Linear(input_dim, output_dim)
        self.linear = nn.Linear(input_dim, output_dim)
        if edge_input_dim:
            self.edge_linear = nn.Linear(edge_input_dim, input_dim)
        else:
            self.edge_linear = None

    def trans(self, A, graph):
        """Symmetrically normalise a dense multi-relational adjacency.

        ``A`` has shape ``(N, num_relation * N)`` with columns grouped by relation.
        Each ``(N, N)`` relation slice ``A_r`` is replaced by
        ``D_r^{-1/2} A_r D_r^{-1/2}`` where ``D_r`` holds the row sums of ``A_r``.
        Rows whose sum is zero are mapped to zero instead of producing inf/NaN.
        The result has the same shape and layout as ``A``.
        """
        n_rel = graph.num_relation
        A = A.view(A.size(0), n_rel, A.size(1) // n_rel).permute(1, 0, 2)  # (R, N, N)

        degree = A.sum(dim=2)  # row sums per relation, (R, N)
        # 0 ** -0.5 is inf and inf * 0 is NaN, so isolated rows are zeroed explicitly.
        degree_inv_sqrt = degree.pow(-0.5).masked_fill(degree == 0, 0.0)

        # Broadcasting is equivalent to diag(d) @ A_r @ diag(d) without building dense diagonals.
        A_norm = degree_inv_sqrt.unsqueeze(2) * A * degree_inv_sqrt.unsqueeze(1)
        return A_norm.permute(1, 0, 2).contiguous().view(A_norm.size(1), A_norm.size(0) * A_norm.size(2))

    def message_and_aggregate(self, graph, input, new_edge_list):
        """Aggregate neighbour features per relation.

        Returns a ``(num_relation * num_node, input_dim)`` tensor with row index
        ``relation * num_node + node``, matching the layout ``combine`` builds for
        the self-loop term via ``input.repeat(num_relation, 1)``.
        """
        assert graph.num_relation == self.num_relation
        device = input.device  # Ensure device consistency
        num_node = input.size(0)

        node_out = edge_weight = None
        # The edge-list quantities are needed for the default path and for edge features.
        if new_edge_list is None or self.edge_linear is not None:
            node_in, node_out, relation = graph.edge_list.t().to(device)
            node_out = node_out * self.num_relation + relation
            degree_out = scatter_add(graph.edge_weight, node_out, dim_size=graph.num_node * graph.num_relation)
            edge_weight = graph.edge_weight / degree_out[node_out]

        if new_edge_list is None:
            adjacency = utils.sparse_coo_tensor(torch.stack([node_in, node_out]), edge_weight,
                                                (graph.num_node, graph.num_node * graph.num_relation))
            # Rows are indexed node * num_relation + relation (node-major).
            update = torch.sparse.mm(adjacency.t().to(device), input.to(device))
        else:
            adjacency = self.trans(new_edge_list, graph).to(device)
            # Rows are indexed relation * num_node + node; bring them to node-major so that
            # both paths (and the edge update below) share one layout.
            update = torch.mm(adjacency.t(), input.to(device))
            update = update.view(self.num_relation, num_node, -1).transpose(0, 1)
            update = update.reshape(num_node * self.num_relation, -1)

        if self.edge_linear is not None:
            # Edge features always come from the original graph edges.
            self.edge_linear.to(device)
            edge_input = graph.edge_feature.float().to(device)
            edge_input = self.edge_linear(edge_input)
            edge_update = scatter_add(edge_input * edge_weight.unsqueeze(-1).to(device), node_out, dim=0,
                                      dim_size=graph.num_node * graph.num_relation)
            update = update + edge_update

        # node-major -> relation-major, the layout used by combine() and by the caller's reshape.
        update = update.view(num_node, self.num_relation, -1).transpose(0, 1)
        return update.reshape(self.num_relation * num_node, -1)

    def combine(self, input, update):
        """Add the self-loop term (tiled once per relation), then batch norm and activation."""
        device = input.device
        self.linear.to(device)  # Ensure the linear layers are on the correct device
        self.self_loop.to(device)
        # Tile the node features once per relation: row r * num_node + n is node n.
        input = input.repeat(self.num_relation, 1).to(device)
        loop_update = self.self_loop(input).to(device)

        output = self.linear(update) + loop_update
        if self.batch_norm is not None:
            self.batch_norm.to(device)
            output = self.batch_norm(output)
        if self.activation:
            output = self.activation(output)
        return output

    def forward(self, graph, input, new_edge_list=None):
        """Run one layer; ``new_edge_list`` optionally replaces the graph's edges (see class docstring)."""
        device = input.device
        graph = graph.to(device)
        if self.gradient_checkpoint:
            # Tensors are passed explicitly so checkpoint can track them; the graph is closed over.
            def run(node_input, edges):
                return self.message_and_aggregate(graph, node_input, edges)

            update = checkpoint.checkpoint(run, input, new_edge_list)
        else:
            update = self.message_and_aggregate(graph, input, new_edge_list)
        output = self.combine(input, update)
        return output

In [9]:
class relationalGraphStack(nn.Module):
    """Stack of ``relationalGraphConv`` layers closed by a single ``relationalGraph`` layer.

    ``dims`` lists the feature sizes; consecutive pairs up to the second-to-last
    entry configure the ``relationalGraphConv`` layers and the last pair
    configures the final ``relationalGraph`` layer.
    """

    def __init__(self, dims, num_relation, edge_input_dim=None, batch_norm=True, activation="relu"):
        super().__init__()
        self.num_layers = len(dims) - 1

        conv_count = self.num_layers - 1
        conv_layers = [
            relationalGraphConv(dims[idx], dims[idx + 1], num_relation, edge_input_dim, batch_norm, activation)
            for idx in range(conv_count)
        ]
        self.layers = nn.ModuleList(conv_layers)
        self.layers.append(
            relationalGraph(dims[-2], dims[-1], num_relation, edge_input_dim, batch_norm, activation)
        )

    def forward(self, graph, input, new_edge_list=None):
        """Apply every layer in order and reshape the result to ``(num_relation, input.size(0), -1)``."""
        device = input.device
        hidden = input
        for layer in self.layers:
            hidden = layer(graph.to(device), hidden, new_edge_list)
        return hidden.view(graph.num_relation, input.size(0), -1)

In [10]:
rel_dims = [[21, 256, 512, 512]]
num_relations = graph.num_relation

# Build the stack from the first dims specification and run it on the float node features.
model = relationalGraphStack(rel_dims[0], num_relations, batch_norm=True)
node_input = graph.node_feature.float().to(device)
relational_output = model(graph, node_input)
print(relational_output.shape)

torch.Size([7, 600, 512])


# Local attention for rewire

In [11]:
class Rewirescorelayer(nn.Module):
    """Score candidate edges with windowed (local) multi-head attention.

    Per-relation node features are padded so every graph in the batch occupies a
    whole number of windows plus one extra padding window, queries attend to
    keys in their own and the neighbouring windows, heads are averaged, and the
    top-``k`` entries of each row are selected with a straight-through
    estimator. The result is returned as a dense
    ``(num_node, num_relation * num_node)`` matrix whose column index is
    ``relation * num_node + node``.

    Parameters:
        in_features (int): size of each input node feature
        out_features (int): size of each attention head
        num_heads (int): number of attention heads
        window_size (int): number of nodes per attention window
        k (int): number of entries kept per row by the hard top-k selection
        temperature (float, optional): temperature of the Gumbel-softmax relaxation
    """

    def __init__(self, in_features, out_features, num_heads, window_size, k, temperature=0.5):
        super(Rewirescorelayer, self).__init__()
        self.input_dim = in_features
        self.output_dim = out_features
        self.num_heads = num_heads
        self.window_size = window_size
        self.k = k
        self.temperature = temperature

        # The projections are moved to the input's device in forward(), so construction
        # does not depend on a notebook-global `device` variable.
        self.query = nn.Linear(in_features, out_features * num_heads)
        self.key = nn.Linear(in_features, out_features * num_heads)
        self.scale = 1 / (out_features ** 0.5)

    class LocalAttention(nn.Module):
        """Windowed attention that returns attention weights only (no values).

        Parameters:
            window_size (int): number of positions per window
            look_backward (int, optional): number of preceding windows each window attends to
            look_forward (int, optional): number of following windows each window attends to;
                ``None`` means 1
            dim (int, optional): accepted for signature compatibility; unused
            scale (float, optional): query scaling; defaults to ``dim_head ** -0.5``
            pad_start_position (list of int, optional): flat list whose odd-indexed entries
                are start positions of windows to drop from the output
        """

        def __init__(
            self,
            window_size,
            look_backward = 1,
            look_forward = None,
            dim = None,
            scale = None,
            pad_start_position = None
        ):
            super().__init__()

            self.scale = scale
            self.window_size = window_size
            self.look_backward = look_backward
            # A `None` look_forward would crash F.pad in look_around; use one window ahead.
            self.look_forward = 1 if look_forward is None else look_forward
            self.pad_start_position = pad_start_position

        def exists(self, val):
            return val is not None

        # Return d when value is missing.
        def default(self, value, d):
            return d if not self.exists(value) else value

        # Formerly named `to`, which shadowed nn.Module.to.
        def tensor_kwargs(self, t):
            return {'device': t.device, 'dtype': t.dtype}

        def max_neg_value(self, tensor):
            # Most negative finite value representable in the tensor's dtype.
            return -torch.finfo(tensor.dtype).max

        def look_around(self, x, backward = 1, forward = 0, pad_value = -1, dim = 2):
            """Concatenate every window with its ``backward`` previous and ``forward`` next windows.

            ``x`` is ``(batch, windows, ...)``. The window axis is padded with
            ``pad_value`` and the shifted copies are concatenated along ``dim``.
            """
            t = x.shape[1]  # number of windows
            dims = (len(x.shape) - dim) * (0, 0)  # no padding on the trailing axes
            padded_x = F.pad(x, (*dims, backward, forward), value = pad_value)
            tensors = [padded_x[:, ind:(ind + t), ...] for ind in range(forward + backward + 1)]
            return torch.cat(tensors, dim = dim)

        def forward(
            self,
            q, k,
            mask = None,
            input_mask = None,
            window_size = None
        ):
            """Return softmax attention weights of shape
            ``(batch, kept_windows, window_size, (look_backward + 1 + look_forward) * window_size)``.
            """
            mask = self.default(mask, input_mask)
            pad_value = -1
            window_size = self.default(window_size, self.window_size)
            look_backward, look_forward = self.look_backward, self.look_forward

            # Collapse any leading dimensions into a single batch axis.
            (q, _), (k, _) = map(lambda t: pack([t], '* n d'), (q, k))

            b, n, dim_head = q.shape
            device = q.device
            scale = self.default(self.scale, dim_head ** -0.5)
            assert (n % window_size) == 0, f'sequence length {n} must be divisible by window size {window_size} for local attention'

            windows = n // window_size

            seq = torch.arange(n, device = device)
            b_t = rearrange(seq, '(w n) -> 1 w n', w = windows, n = window_size)

            # bucketing
            bq, bk = map(lambda t: rearrange(t, 'b (w n) d -> b w n d', w = windows), (q, k))
            bq = bq * scale

            look_around_kwargs = dict(
                backward = look_backward,
                forward = look_forward,
                pad_value = pad_value
            )

            bk = self.look_around(bk, **look_around_kwargs)

            # Key positions; entries equal to pad_value come from the look-around padding.
            bq_k = self.look_around(b_t, **look_around_kwargs)
            bq_k = rearrange(bq_k, '... j -> ... 1 j')
            pad_mask = bq_k == pad_value

            sim = torch.einsum('b h i e, b h j e -> b h i j', bq, bk)

            mask_value = self.max_neg_value(sim)
            sim = sim.masked_fill(pad_mask, mask_value)

            if self.exists(mask):
                batch = mask.shape[0]
                assert (b % batch) == 0

                h = b // batch  # number of heads folded into the batch axis

                mask = rearrange(mask, '... (w n) -> (...) w n', w = windows, n = window_size)
                mask = self.look_around(mask, **{**look_around_kwargs, 'pad_value': False})
                mask = rearrange(mask, '... j -> ... 1 j')
                mask = repeat(mask, 'b ... -> (b h) ...', h = h)

                sim = sim.masked_fill(~mask, mask_value)
                del mask

            # Odd-indexed entries of pad_start_position give the windows to drop.
            pad_positions = [] if self.pad_start_position is None else self.pad_start_position
            dropped = {pad_positions[i] // window_size for i in range(1, len(pad_positions), 2)}
            remaining_indices = [idx for idx in range(windows) if idx not in dropped]

            # Keep only the remaining windows.
            rest_sim = sim[:, remaining_indices, :, :]

            # attention
            attn = rest_sim.softmax(dim = -1)

            return attn

    def insert_zero_rows(self, tensor, lengths, target_lengths):
        """Pad each consecutive segment of ``tensor`` (along dim 1) with zero rows.

        Segment ``i`` has ``lengths[i]`` rows and is padded up to ``target_lengths[i]`` rows.
        Returns the padded tensor and a boolean mask of shape
        ``(tensor.size(0), padded_length)`` that is True on original rows.
        """
        assert len(lengths) == len(target_lengths), "Lengths and target lengths must be of the same length."

        parts = []
        mask_parts = []
        start = 0

        for length, target in zip(lengths, target_lengths):
            end = start + length

            # Original rows.
            parts.append(tensor[:, start:end, :])
            mask_parts.append(torch.ones(tensor.size(0), length, dtype=torch.bool, device=tensor.device))

            # Zero rows, created in the input dtype so concatenation never mixes dtypes.
            missing = target - length
            if missing > 0:
                parts.append(torch.zeros(tensor.size(0), missing, tensor.size(2),
                                         device=tensor.device, dtype=tensor.dtype))
                mask_parts.append(torch.zeros(tensor.size(0), missing, dtype=torch.bool, device=tensor.device))

            start = end

        padded_tensor = torch.cat(parts, dim=1)
        mask = torch.cat(mask_parts, dim=1)

        return padded_tensor, mask

    def round_up_to_nearest_k_and_a_window_size(self, lst, k):
        """Compute padded segment lengths and the positions where padding starts.

        Every length is rounded up to a multiple of ``k`` and extended by one more
        block of ``k``. For segment ``i`` two positions are appended to the second
        return value:

        * the first padded row of the segment, expressed in coordinates where the
          extra block of every earlier segment has been removed;
        * the start of the segment's extra block, in the fully padded coordinates.
        """
        result_lst = [(x + k - 1) // k * k + k for x in lst]
        pad_start_position = []
        offset = 0  # running sum of padded lengths, replaces repeated sum() calls
        for i, (length, padded) in enumerate(zip(lst, result_lst)):
            pad_start_position.append(offset - i * k + length)
            offset += padded
            pad_start_position.append(offset - k)
        return result_lst, pad_start_position

    def gumbel_softmax_top_k(self, logits, top_k, hard=False):
        """Gumbel-softmax relaxation with an optional hard top-k selection.

        With ``hard=True`` the forward value is a 0/1 mask of the ``top_k`` largest
        entries of the unperturbed ``logits`` along the last axis, while gradients
        flow through the soft, noise-perturbed distribution (straight-through).
        """
        gumbels = -torch.empty_like(logits).exponential_().log()
        gumbels = (logits + gumbels) / self.temperature

        y_soft = F.softmax(gumbels, dim=-1)

        if hard:
            topk_indices = logits.topk(top_k, dim=-1)[1]
            y_hard = torch.zeros_like(logits).scatter_(-1, topk_indices, 1.0)
            y = (y_hard - y_soft).detach() + y_soft
        else:
            y = y_soft

        return y

    def displace_tensor_blocks_to_rectangle(self, tensor, displacement):
        """Lay blocks out along a diagonal band of one dense matrix.

        ``tensor`` is ``(batch, num_blocks, block_height, block_width)``. Block ``i`` is
        written with its top-left corner at ``(i * displacement, i * displacement)`` of a
        zero matrix of shape ``(batch, num_blocks * displacement, (num_blocks + 2) * displacement)``.
        """
        batch_size, num_blocks, block_height, block_width = tensor.shape

        height = num_blocks * displacement
        width = (2 + num_blocks) * displacement

        new_tensor = torch.zeros(batch_size, height, width, device=tensor.device, dtype=tensor.dtype)

        for i in range(num_blocks):
            row = i * displacement
            col = i * displacement
            new_tensor[:, row:row + block_height, col:col + block_width] = tensor[:, i, :, :]

        return new_tensor

    def forward(self, graph, node_features):
        """Compute the rewired adjacency scores.

        ``node_features`` is indexed as ``(num_relation, num_node, in_features)`` and
        ``graph.num_nodes`` gives the number of nodes of each graph in the batch.
        Returns a ``(num_node, num_relation * num_node)`` tensor.
        """
        device = node_features.device
        self.query.to(device)
        self.key.to(device)

        window = self.window_size
        num_relation = node_features.size(0)
        index = graph.num_nodes.tolist()

        target_input, pad_start_position = self.round_up_to_nearest_k_and_a_window_size(index, window)
        padding_input, mask = self.insert_zero_rows(node_features, index, target_input)
        padded_len = padding_input.size(1)

        # (num_relation, padded_len, heads * dim) -> (num_relation * heads, padded_len, dim)
        Q = self.query(padding_input).view(num_relation, padded_len, self.num_heads, self.output_dim).permute(0, 2, 1, 3)
        K = self.key(padding_input).view(num_relation, padded_len, self.num_heads, self.output_dim).permute(0, 2, 1, 3)
        Q = Q.reshape(num_relation * self.num_heads, padded_len, self.output_dim)
        K = K.reshape(num_relation * self.num_heads, padded_len, self.output_dim)

        attn = self.LocalAttention(
            dim = self.output_dim,
            window_size = window,
            look_backward = 1,   # each window looks at the window before it ...
            look_forward = 1,    # ... and at the window after it
            pad_start_position = pad_start_position
        )

        # Average the heads: (num_relation, kept_windows, window, 3 * window).
        attn = attn(Q, K, mask = mask).view(num_relation, self.num_heads, -1, window, 3 * window).mean(dim=1)
        score = self.gumbel_softmax_top_k(attn, self.k, hard=True)

        result_tensor = self.displace_tensor_blocks_to_rectangle(score, window)
        # Drop the one-window margins that correspond to the look-around padding.
        # (Previously hard-coded as 10, which is only right for window_size == 10.)
        result_tensor = result_tensor[:, :, window:-window]

        # Padded rows run from each graph's pad start up to the next window boundary.
        pad_rows = set()
        for start in pad_start_position[0::2]:
            next_boundary = ((start + window - 1) // window) * window
            pad_rows.update(range(start, next_boundary))
        remaining_indices = [idx for idx in range(result_tensor.size(1)) if idx not in pad_rows]

        result_tensor = result_tensor[:, remaining_indices, :]
        result_tensor = result_tensor[:, :, remaining_indices]

        return result_tensor.permute(1, 0, 2).contiguous().view(result_tensor.size(1), result_tensor.size(0)*result_tensor.size(2))
        


In [12]:
# Hyper-parameters of the rewiring score layer.
input_dim = relational_output.shape[-1]
output_dim, num_heads, window_size, k = 64, 8, 10, 3
num_nodes = relational_output.size(0)

model = Rewirescorelayer(input_dim, output_dim, num_heads, window_size, k)

# Time a single forward pass with the wall clock.
start = time.time()
attn_output = model(graph, relational_output)
end = time.time()
print(f"运行时间: {end - start:.6f} 秒")
print(attn_output.shape)

运行时间: 0.005996 秒
torch.Size([600, 4200])


# Diffusion

In [13]:
class RewireGearnet(nn.Module):
    """Relational convolution over a dense rewired adjacency.

    ``new_edge_list`` is a dense ``(num_node, num_relation * num_node)`` matrix whose
    column index is ``relation * num_node + node``. For every node the messages
    aggregated under each relation are concatenated, projected by ``linear`` and
    added to a self-loop projection of the node's own features.

    Parameters:
        input_dim (int): size of each input node feature
        output_dim (int): size of each output node feature
        num_relation (int): number of relations
        edge_input_dim (int, optional): size of the edge features; enables ``edge_linear``
        batch_norm (bool, optional): apply ``BatchNorm1d`` to the output
        activation (str or callable, optional): name of a ``torch.nn.functional``
            function, a callable, or a falsy value for no activation
    """

    gradient_checkpoint = False

    def __init__(self, input_dim, output_dim, num_relation, edge_input_dim=None, batch_norm=False, activation="relu"):
        super(RewireGearnet, self).__init__()
        self.num_relation = num_relation
        self.input_dim = input_dim
        self.output_dim = output_dim

        self.linear = nn.Linear(num_relation * input_dim, output_dim)
        self.self_loop = nn.Linear(input_dim, output_dim)
        self.batch_norm = nn.BatchNorm1d(output_dim) if batch_norm else None
        if isinstance(activation, str):
            self.activation = getattr(F, activation) if activation else None
        else:
            self.activation = activation
        # Edge messages are added to the aggregated update, whose feature size is
        # input_dim, so the projection must target input_dim (not output_dim).
        self.edge_linear = nn.Linear(edge_input_dim, input_dim) if edge_input_dim else None

    def message_and_aggregate(self, graph, input, new_edge_list):
        """Aggregate node features over the rewired adjacency.

        Returns a ``(num_node, num_relation * input_dim)`` tensor in which row ``n``
        is the concatenation over relations of node ``n``'s aggregated messages.

        Raises:
            ValueError: if ``new_edge_list`` is None.
        """
        assert graph.num_relation == self.num_relation
        if new_edge_list is None:
            raise ValueError("RewireGearnet requires a rewired adjacency `new_edge_list`")

        device = input.device  # Ensure device consistency
        num_node = input.size(0)

        # Rows of the product are indexed relation * num_node + node ...
        update = torch.mm(new_edge_list.t().to(device), input.to(device))
        # ... so regroup them as (node, relation, feature) before flattening relations.
        update = update.view(self.num_relation, num_node, -1).transpose(0, 1)

        if self.edge_linear is not None:
            self.edge_linear.to(device)
            edge_input = graph.edge_feature.float().to(device)
            edge_input = self.edge_linear(edge_input)
            edge_weight = graph.edge_weight.unsqueeze(-1).to(device)
            # Target index node_out * num_relation + relation is node-major.
            _, node_out, relation = graph.edge_list.t().to(device)
            node_out = node_out * self.num_relation + relation
            edge_update = scatter_add(
                edge_input * edge_weight, node_out, dim=0,
                dim_size=num_node * self.num_relation
            )
            update = update + edge_update.view(num_node, self.num_relation, -1)

        return update.reshape(num_node, self.num_relation * self.input_dim)

    def combine(self, input, update):
        """Mix the aggregated messages with a self-loop, then apply batch norm and activation."""
        device = input.device
        self.linear.to(device)  # Ensure the linear layers are on the correct device
        self.self_loop.to(device)
        if self.batch_norm is not None:
            self.batch_norm.to(device)

        output = self.linear(update) + self.self_loop(input)
        if self.batch_norm is not None:
            output = self.batch_norm(output)
        if self.activation:
            output = self.activation(output)
        return output

    def forward(self, graph, input, new_edge_list=None):
        """Run one layer over the rewired adjacency ``new_edge_list``."""
        if self.gradient_checkpoint:
            # All three arguments must be forwarded to message_and_aggregate.
            update = checkpoint.checkpoint(self.message_and_aggregate, graph, input, new_edge_list)
        else:
            update = self.message_and_aggregate(graph, input, new_edge_list)
        output = self.combine(input, update)
        return output

In [14]:
class gearnetlayer(layers.RelationalGraphConv):
    """``RelationalGraphConv`` variant that aggregates over a dense rewired adjacency.

    ``new_edge_list`` is a dense ``(num_node, num_relation * num_node)`` matrix whose
    column index is ``relation * num_node + node``. Constructor arguments are
    forwarded unchanged to ``layers.RelationalGraphConv``.
    """

    def __init__(self, input_dim, output_dim, num_relation, edge_input_dim=None, batch_norm=False, activation="relu"):
        super(gearnetlayer, self).__init__(input_dim, output_dim, num_relation, edge_input_dim,
                                           batch_norm, activation)

    def message_and_aggregate(self, graph, input, new_edge_list):
        """Aggregate node features over the rewired adjacency.

        Returns a ``(num_node, num_relation * input_dim)`` tensor in which row ``n``
        is the concatenation over relations of node ``n``'s aggregated messages.

        Raises:
            ValueError: if ``new_edge_list`` is None.
        """
        assert graph.num_relation == self.num_relation
        if new_edge_list is None:
            raise ValueError("gearnetlayer requires a rewired adjacency `new_edge_list`")

        device = input.device  # Ensure device consistency
        num_node = input.size(0)

        # Rows of the product are indexed relation * num_node + node ...
        update = torch.mm(new_edge_list.t().to(device), input.to(device))
        # ... so regroup them as (node, relation, feature) before flattening relations.
        update = update.view(self.num_relation, num_node, -1).transpose(0, 1)

        if self.edge_linear is not None:
            self.edge_linear.to(device)
            edge_input = graph.edge_feature.float().to(device)
            edge_input = self.edge_linear(edge_input)
            edge_weight = graph.edge_weight.unsqueeze(-1).to(device)
            # Target index node_out * num_relation + relation is node-major.
            _, node_out, relation = graph.edge_list.t().to(device)
            node_out = node_out * self.num_relation + relation
            edge_update = scatter_add(
                edge_input * edge_weight, node_out, dim=0,
                dim_size=num_node * self.num_relation
            )
            update = update + edge_update.view(num_node, self.num_relation, -1)

        return update.reshape(num_node, self.num_relation * self.input_dim)

    def combine(self, input, update):
        """Mix the aggregated messages with a self-loop, then apply batch norm and activation."""
        device = input.device
        self.linear.to(device)  # Ensure the linear layers are on the correct device
        self.self_loop.to(device)
        if self.batch_norm is not None:
            self.batch_norm.to(device)

        output = self.linear(update) + self.self_loop(input)
        if self.batch_norm is not None:
            output = self.batch_norm(output)
        if self.activation:
            output = self.activation(output)
        return output

    def forward(self, graph, input, new_edge_list=None):
        """Run one layer over the rewired adjacency ``new_edge_list``."""
        if self.gradient_checkpoint:
            # All three arguments must be forwarded to message_and_aggregate.
            update = checkpoint.checkpoint(self.message_and_aggregate, graph, input, new_edge_list)
        else:
            update = self.message_and_aggregate(graph, input, new_edge_list)
        output = self.combine(input, update)
        return output

In [15]:
class rewireGearNetstack(nn.Module):
    """One ``RewireGearnet`` layer followed by ``gearnetlayer`` layers.

    ``dims`` lists the feature sizes: the first pair configures the
    ``RewireGearnet`` layer and every following consecutive pair configures a
    ``gearnetlayer``.
    """

    def __init__(self, dims, num_relation, edge_input_dim=None, batch_norm=True, activation="relu"):
        super().__init__()
        self.num_layers = len(dims) - 1

        modules = [RewireGearnet(dims[0], dims[1], num_relation, edge_input_dim, batch_norm, activation)]
        for idx in range(1, self.num_layers):
            modules.append(
                gearnetlayer(dims[idx], dims[idx + 1], num_relation, edge_input_dim, batch_norm, activation)
            )
        self.layers = nn.ModuleList(modules)

    def forward(self, graph, input, new_edge_list=None):
        """Feed ``input`` through every layer, passing the same ``new_edge_list`` to each."""
        device = input.device
        hidden = input
        for layer in self.layers:
            hidden = layer(graph.to(device), hidden, new_edge_list)
        return hidden

In [16]:
dims = [21, 512, 512]
num_relations = graph.num_relation
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Run the rewired stack on the float node features, using the attention scores as adjacency.
rewire_stack = rewireGearNetstack(dims, num_relations)
node_input = graph.node_feature.to(device).float()
new_node_feature = rewire_stack(graph, node_input, attn_output).to(device)

print(new_node_feature.shape)
print(new_node_feature)

torch.Size([600, 512])
tensor([[0.9666, 0.0000, 0.6164,  ..., 0.2444, 0.8194, 1.3916],
        [0.0902, 0.9526, 0.5950,  ..., 0.5453, 0.0078, 2.0728],
        [0.0000, 0.3094, 1.7195,  ..., 1.0820, 0.5567, 0.0000],
        ...,
        [0.0000, 0.6042, 0.0000,  ..., 0.0000, 0.0000, 0.2695],
        [0.0764, 0.0000, 0.5411,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.7347, 0.0338,  ..., 0.9431, 0.0000, 0.0000]],
       device='cuda:0', grad_fn=<ReluBackward0>)


# Final model

In [17]:
# Wrapper used to time individual layers.
def time_layer(layer, layer_name):
    """Wrap ``layer`` so that every call prints how long it took.

    Parameters:
        layer (callable): the callable to time
        layer_name (str): label printed in front of the elapsed time

    Returns:
        callable: forwards all arguments to ``layer``, prints
        ``"<layer_name>: <elapsed> seconds"`` and returns the layer's output.
    """
    def timed_layer(*args, **kwargs):
        # perf_counter is monotonic and high-resolution, unlike time.time().
        # NOTE(review): GPU work may still be in flight when the call returns,
        # so GPU timings could be under-reported without a synchronize - confirm.
        start_time = time.perf_counter()
        output = layer(*args, **kwargs)
        elapsed = time.perf_counter() - start_time
        print(f'{layer_name}: {elapsed:.6f} seconds')
        return output
    return timed_layer

In [27]:
class DGMGearnet(nn.Module, core.Configurable):
    """GearNet-style encoder that re-scores the graph before every diffusion stack.

    For every index ``i`` the model owns three modules:

    * ``score_layers[2 * i]``     - a ``relationalGraphStack`` built from ``relation_dims[i]``
    * ``score_layers[2 * i + 1]`` - a ``Rewirescorelayer``
    * ``layers[i]``               - a ``rewireGearNetstack`` built from ``diffusion_dims[i]``

    Args:
        relation_dims: one list of dims per stage for the relational stacks.
        score_in_dim, score_out_dim, num_heads, window_size, k: forwarded to
            every ``Rewirescorelayer`` (which is always built with
            ``temperature=0.5``).
        diffusion_dims: one list of dims per stage for the diffusion stacks.
            Must have the same length as ``relation_dims``.
        num_relation: number of relations, forwarded to every stack.
        edge_input_dim: accepted for interface compatibility; the sub-stacks
            are currently always built with ``edge_input_dim=None``.
        num_angle_bin: not supported; must be falsy (see Raises).
        short_cut: add the stage input to the stage output when shapes match.
        batch_norm: apply a ``BatchNorm1d`` after every stage. Note that the
            sub-stacks themselves are always built with ``batch_norm=True``.
        activation: accepted for interface compatibility; the sub-stacks are
            currently always built with ``activation="relu"``.
        concat_hidden: concatenate the outputs of all stages instead of
            returning only the last one.
        readout: ``"sum"`` or ``"mean"``.

    Raises:
        ValueError: if ``relation_dims`` and ``diffusion_dims`` differ in
            length, or if ``readout`` is unknown.
        NotImplementedError: if ``num_angle_bin`` is truthy.
    """

    def __init__(self, relation_dims, score_in_dim, score_out_dim, diffusion_dims, num_relation, num_heads, window_size, k, edge_input_dim=None, num_angle_bin=None,
                 short_cut=False, batch_norm=False, activation="relu", concat_hidden=True, readout="sum"):
        super(DGMGearnet, self).__init__()

        if len(relation_dims) != len(diffusion_dims):
            # One diffusion stack is built per relational stack; a mismatch would
            # either fail with an opaque IndexError or report a wrong output size.
            raise ValueError("`relation_dims` and `diffusion_dims` must have the same length, got %d and %d"
                             % (len(relation_dims), len(diffusion_dims)))
        if num_angle_bin:
            # The edge message passing path relied on `self.edge_dims`, which this
            # model never defines, and on `.linear` / `.activation` attributes that
            # `rewireGearNetstack` does not have, so it could never run. Fail with
            # an explicit message instead of an AttributeError.
            raise NotImplementedError("`num_angle_bin` is not supported by DGMGearnet: "
                                      "the edge message passing path is not implemented")

        self.relation_dims = relation_dims
        self.score_in_dim = score_in_dim
        self.score_out_dim = score_out_dim
        self.diffusion_dims = diffusion_dims
        self.num_heads = num_heads
        self.window_size = window_size
        self.k = k
        self.num_relation = num_relation
        self.num_angle_bin = num_angle_bin
        self.short_cut = short_cut
        self.concat_hidden = concat_hidden
        self.batch_norm = batch_norm

        # Width of the returned node/graph features. With concatenation this is
        # the sum of the final dims of every stage (they need not all be equal).
        if concat_hidden:
            self.output_dim = sum(stage_dims[-1] for stage_dims in self.diffusion_dims)
        else:
            self.output_dim = self.diffusion_dims[-1][-1]
        # Original (misspelled) attribute name, kept so existing code keeps working.
        self.ouput_dim = self.output_dim
        print("output_dim", self.output_dim)

        self.layers = nn.ModuleList()
        self.score_layers = nn.ModuleList()
        for i in range(len(self.relation_dims)):
            # Per stage: relational stack, then score layer, then diffusion stack.
            self.score_layers.append(relationalGraphStack(self.relation_dims[i], num_relation,
                                                          edge_input_dim=None, batch_norm=True, activation="relu"))
            self.score_layers.append(Rewirescorelayer(self.score_in_dim, self.score_out_dim, self.num_heads, self.window_size,
                                                      self.k, temperature=0.5))
            self.layers.append(rewireGearNetstack(self.diffusion_dims[i], num_relation,
                                                  edge_input_dim=None, batch_norm=True, activation="relu"))

        if batch_norm:
            self.batch_norms = nn.ModuleList()
            for i in range(len(self.diffusion_dims)):
                self.batch_norms.append(nn.BatchNorm1d(self.diffusion_dims[i][-1]))

        if readout == "sum":
            self.readout = layers.SumReadout()
        elif readout == "mean":
            self.readout = layers.MeanReadout()
        else:
            raise ValueError("Unknown readout `%s`" % readout)

    def forward(self, graph, input, edge_list=None, all_loss=None, metric=None):
        """Encode ``graph``.

        Args:
            graph: graph providing ``edge_list``, ``edge_weight``, ``num_node``
                and ``num_relation``.
            input: node features; all tensors built here are placed on
                ``input.device``.
            edge_list: passed to the first relational stack; later stages
                receive the rewired adjacency of the previous stage.
            all_loss, metric: unused; accepted for interface compatibility.

        Returns:
            dict with ``"graph_feature"`` (readout of the node features) and
            ``"node_feature"`` (concatenation of all stage outputs when
            ``concat_hidden`` is set, otherwise the last stage output).
        """
        # Follow the input's device rather than a notebook-level global.
        device = input.device

        node_in, node_out, relation = graph.edge_list.t().to(device)
        # One column per (node_out, relation) pair.
        node_out = node_out * self.num_relation + relation
        adjacency = torch.sparse_coo_tensor(
            torch.stack([node_in, node_out]),
            graph.edge_weight.to(device),
            (graph.num_node, graph.num_node * graph.num_relation),
            device=device
        )
        adjacency = adjacency.to_dense()

        hiddens = []
        layer_input = input
        score_layer_input = input

        for i in range(len(self.layers)):
            relational_output = time_layer(self.score_layers[2*i], 'relational output')(graph, score_layer_input, edge_list)
            new_edge_list = time_layer(self.score_layers[2*i+1], 'new edge list')(graph, relational_output)
            # Element-wise maximum of the original and the scored adjacency.
            new_edge_list = torch.max(adjacency, new_edge_list)
            hidden = time_layer(self.layers[i], 'hidden')(graph, layer_input, new_edge_list)

            if self.short_cut and hidden.shape == layer_input.shape:
                hidden = hidden + layer_input

            if self.batch_norm:
                hidden = self.batch_norms[i](hidden)

            hiddens.append(hidden)
            # The next relational stack sees this stage's output next to the
            # relational output reshaped to one row per node.
            score_layer_input = torch.cat([hidden, relational_output.view(hidden.size(0), self.num_relation*hidden.size(-1))], dim=-1)
            layer_input = hidden
            print(score_layer_input.shape)
            edge_list = new_edge_list

        if self.concat_hidden:
            node_feature = torch.cat(hiddens, dim=-1)
        else:
            node_feature = hiddens[-1]
        graph_feature = self.readout(graph, node_feature)

        return {
            "graph_feature": graph_feature,
            "node_feature": node_feature
        }

In [28]:
# Configuration for a two-stage DGMGearnet forward pass on `graph`.
# One list per stage; the second relational stack starts at 4096, the width of
# the concatenated score input printed by the model's forward pass.
relation_dims = [[21, 128, 512, 512], [4096, 1024, 512, 512]]
diffusion_dims = [[21, 512, 512, 512], [512, 512, 512, 512]]
score_in_dim, score_out_dim = 512, 64
num_relations = graph.num_relation
num_heads, window_size, k = 8, 10, 5

# Build the model, move it to the target device, then run it once.
dgm_model = DGMGearnet(
    relation_dims, score_in_dim, score_out_dim, diffusion_dims,
    num_relations, num_heads, window_size, k, batch_norm=True,
).to(device)
output = dgm_model(graph.to(device), graph.node_feature.to(device).float())

output_dim 1024
relational output: 0.004743 seconds
new edge list: 0.004357 seconds
hidden: 0.001370 seconds
torch.Size([600, 4096])
relational output: 0.006383 seconds
new edge list: 0.003749 seconds
hidden: 0.001331 seconds
torch.Size([600, 4096])


In [29]:
# Show the node-level and graph-level embeddings together with their shapes.
node_embedding = output["node_feature"]
print(node_embedding)

print(node_embedding.shape)
print("\n")

graph_embedding = output["graph_feature"]
print(graph_embedding)
print(graph_embedding.shape)
# Width expected when the outputs of all stages are concatenated.
a = len(diffusion_dims) * diffusion_dims[-1][-1]
print(a)

tensor([[-0.6396,  0.3717, -0.1841,  ..., -0.6632, -0.6266,  0.3609],
        [-0.6695, -0.6594, -0.6908,  ...,  0.0965, -0.6266,  0.2488],
        [ 0.3171, -0.6594, -0.6908,  ...,  0.1406, -0.6266,  3.9892],
        ...,
        [-0.6695, -0.6594,  0.1060,  ..., -0.6632, -0.6266,  0.6253],
        [-0.6695, -0.6594, -0.6908,  ..., -0.2858,  0.6364, -0.6156],
        [-0.6695, -0.6594, -0.5598,  ...,  0.0772, -0.6266, -0.2836]],
       device='cuda:0', grad_fn=<CatBackward>)
torch.Size([600, 1024])


tensor([[-34.3780,   2.6419,  21.6274,  ..., -26.3327,   0.0387,   7.7938],
        [ 34.3782,  -2.6417, -21.6273,  ...,  26.3327,  -0.0388,  -7.7937]],
       device='cuda:0', grad_fn=<ScatterAddBackward>)
torch.Size([2, 1024])
1024


# Test

In [26]:
# Baseline: run the stock GearNet from the TorchDrug registry on the same graph.
input_dim = 21
hidden_dims = [512] * 6

GearNetModel = R.search("models.GearNet")
gearnet = GearNetModel(input_dim, hidden_dims, num_relations, batch_norm=True, concat_hidden=True, readout="sum")
print(gearnet.output_dim)

gearnet_output = gearnet(graph, graph.node_feature.float())
gearnet_node_feature = gearnet_output["node_feature"]
print(gearnet_node_feature)
print(gearnet_node_feature.shape)
print("\n")

gearnet_graph_feature = gearnet_output["graph_feature"]
print(gearnet_graph_feature)
print(gearnet_graph_feature.shape)

3072
tensor([[ 2.6553,  0.6771, -0.7213,  ..., -0.6440, -0.3518,  0.5691],
        [ 1.1149,  0.7099, -0.7213,  ..., -0.6440, -0.5663,  1.2097],
        [-0.3453, -0.1319, -0.7213,  ..., -0.6440, -0.5663,  0.8508],
        ...,
        [-0.7127, -0.6336,  1.3473,  ...,  0.1884,  0.2556,  0.4028],
        [-0.7127, -0.6336, -0.6713,  ...,  0.1304,  0.0762,  0.0320],
        [ 1.1920,  0.0447,  0.5976,  ..., -0.1654, -0.5416, -0.2907]],
       grad_fn=<CatBackward>)
torch.Size([600, 3072])


tensor([[ -17.7990,  -14.1211,   29.1059,  ..., -113.1195,  -75.7146,
          100.5911],
        [  17.7990,   14.1209,  -29.1059,  ...,  113.1196,   75.7146,
         -100.5909]], grad_fn=<ScatterAddBackward>)
torch.Size([2, 3072])


### 画图 (Plot the graph)

In [46]:
import networkx as nx

num_nodes = graph.num_nodes
edge_list = graph.edge_list
print(edge_list)

# Keep only the rows whose last column (the relation id) equals 1, then drop
# that column so each remaining row holds just the two node indices.
relation_one_mask = edge_list[:, 2] == 1
filtered_rows = edge_list[relation_one_mask][:, :2]

# Turn the rows into a plain Python list of 2-tuples for networkx.
edge_list = list(map(tuple, filtered_rows.tolist()))
print(edge_list)

G = nx.Graph()
G.add_edges_from(edge_list)

plt.figure(figsize=(10, 10))  # canvas size
draw_options = dict(with_labels=True, node_size=70, node_color='lightblue',
                    font_size=6, font_color='black', edge_color='black')
nx.draw(G, **draw_options)
plt.title('Graph Visualization')
plt.savefig('tu.png')

tensor([[ 90,  82,   1],
        [ 85,  78,   1],
        [ 83,  78,   1],
        ...,
        [ 75, 107,   0],
        [ 50, 108,   0],
        [ 51, 108,   0]])
[(90, 82), (85, 78), (83, 78), (84, 79), (73, 79), (85, 79), (85, 80), (87, 80), (106, 81), (91, 81), (90, 81), (76, 81), (105, 81), (91, 82), (106, 82), (84, 78), (92, 82), (76, 82), (90, 83), (89, 83), (91, 83), (89, 84), (79, 84), (90, 84), (79, 85), (80, 85), (78, 85), (79, 86), (96, 86), (80, 86), (70, 75), (75, 69), (92, 69), (75, 70), (37, 70), (172, 71), (171, 71), (175, 71), (93, 72), (94, 72), (92, 72), (92, 73), (78, 73), (40, 74), (69, 74), (97, 87), (44, 75), (69, 75), (92, 75), (43, 75), (40, 75), (92, 76), (82, 76), (91, 76), (43, 76), (43, 77), (40, 77), (92, 77), (73, 78), (102, 95), (72, 92), (69, 93), (72, 93), (115, 93), (68, 93), (112, 93), (73, 93), (89, 94), (72, 94), (73, 94), (84, 94), (89, 95), (90, 95), (115, 95), (106, 92), (100, 95), (89, 96), (88, 96), (90, 96), (102, 96), (88, 97), (102, 97), (

In [25]:
# Sanity check: convert the dense relational adjacency between two column
# layouts and back, and confirm that the round trip is lossless.

# The three columns of the edge list are unpacked as (node_in, node_out, relation).
node_in, node_out, relation = graph.edge_list.t()
# Fold the relation id into the target index: one slot per (node_out, relation) pair.
node_out = node_out * graph.num_relation + relation

# Sum of edge weights falling into each (node_out, relation) slot.
degree_out = scatter_add(graph.edge_weight, node_out, dim_size=graph.num_node * graph.num_relation)
print(degree_out.shape)
# Adjacency of shape (num_node, num_node * num_relation); rows are node_in and
# columns use the folded index computed above.
adjacency = utils.sparse_coo_tensor(torch.stack([node_in, node_out]), graph.edge_weight,
                                    (graph.num_node, graph.num_node * graph.num_relation))

adjacency = adjacency.to_dense()

# Swap the two factors of the column index (node, relation) -> (relation, node).
# NOTE(review): i.e. columns appear to become relation * num_node + node_out - confirm.
adjacency2 = adjacency.t().view(graph.num_node, graph.num_relation, graph.num_node).permute(1, 0, 2).reshape(graph.num_relation*graph.num_node , graph.num_node).t()

# Move the relation axis to the rows, giving shape (num_relation * num_node, num_node).
# NOTE(review): looks like the per-relation adjacency matrices stacked vertically - confirm.
adjacency3 = adjacency2.view(graph.num_node, graph.num_relation, graph.num_node).permute(1, 0, 2).reshape(graph.num_relation*graph.num_node , graph.num_node)
print(adjacency3)

# Inverse of the previous step: adjacency3 -> layout of adjacency2.
adjacency4 = adjacency3.view(graph.num_relation, graph.num_node, graph.num_node).permute(1, 0, 2).reshape(graph.num_node, graph.num_relation*graph.num_node)
print("\nEqual:", torch.equal(adjacency2, adjacency4))

# Inverse of the adjacency -> adjacency2 step: back to the original column layout.
adjacency5 = adjacency4.t().view(graph.num_relation, graph.num_node, graph.num_node).permute(1, 0, 2).reshape(graph.num_node*graph.num_relation, graph.num_node).t()
print("\nEqual:", torch.equal(adjacency5, adjacency))

torch.Size([1295])
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 1.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])

Equal: True

Equal: True


In [32]:
# Reference: divide every edge weight by the total weight of its
# (node_out, relation) slot and build the dense normalised adjacency.
# Relies on `node_in`, `node_out` (already folded with the relation id) and
# `adjacency2` from the preceding cell.
degree_out = scatter_add(graph.edge_weight, node_out, dim_size=graph.num_node * graph.num_relation)
edge_weight = graph.edge_weight / degree_out[node_out]
adjacency_with_degree = utils.sparse_coo_tensor(torch.stack([node_in, node_out]), edge_weight,
                                            (graph.num_node, graph.num_node * graph.num_relation)).to_dense()



# Rebuild the same quantity starting from the re-ordered dense matrix `adjacency2`.
# This is the same reshape chain that produced `adjacency5` from `adjacency4`.
new_edge_list = adjacency2.t().view(graph.num_relation, graph.num_node, graph.num_node).permute(1, 0, 2).reshape(graph.num_node*graph.num_relation, graph.num_node).t()
# Recover (row, column) indices of the non-zero entries of the dense matrix.
row, col = new_edge_list.nonzero(as_tuple=True)
# Every recovered entry gets weight 1.
# NOTE(review): matching the reference presumably requires unit weights and no
# duplicate edges in the original graph - confirm.
new_edge_weight = torch.ones_like(col, dtype=torch.float32)
degree_out2 = scatter_add(new_edge_weight, col, dim_size=graph.num_node * graph.num_relation)
edge_weight2 = new_edge_weight / degree_out2[col]
adjacency_with_degree2 = utils.sparse_coo_tensor(torch.stack([row, col]), edge_weight2,
                                            (graph.num_node, graph.num_node * graph.num_relation)).to_dense()

# Both constructions should give the exact same dense matrix.
print("\nEqual:", torch.equal(adjacency_with_degree2, adjacency_with_degree))


Equal: True
