# 前期准备

In [1]:
from torchdrug import  layers, datasets,transforms,core
from torchdrug.core import Registry as R
from torchdrug.layers import geometry

import torch
from torchdrug import data

from torch_scatter import scatter_add
import torch.nn as nn
from torchdrug import utils
from torch.utils import checkpoint
import torch.nn.functional as F

import time

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

  from .autonotebook import tqdm as notebook_tqdm
  return torch._C._cuda_getDeviceCount() > 0


In [2]:
# Look up the dataset and the transform classes through the torchdrug registry.
EnzymeCommission = R.search("datasets.EnzymeCommission")
PV = R.search("transforms.ProteinView")
trans = PV(view="residue")
dataset = EnzymeCommission(
    "~/scratch/protein-datasets/",
    test_cutoff=0.95,
    atom_feature="full",
    bond_feature="full",
    verbose=1,
    transform=trans,
)

# Simplified representation that keeps only the alpha-carbon atoms as nodes.
graph_construction_model = layers.GraphConstruction(
    node_layers=[geometry.AlphaCarbonNode()],
    edge_layers=[
        geometry.SpatialEdge(radius=10.0, min_distance=5),
        geometry.KNNEdge(k=10, min_distance=5),
        geometry.SequentialEdge(max_distance=1),
    ],
    edge_feature="gearnet",
)

# Pack the first two samples into one batch, then build the graph from it.
graphs = data.Protein.pack([sample["graph"] for sample in dataset[:2]])
print("\n\n")
graph = graph_construction_model(graphs)
print(graph)

22:40:02   Extracting /home/cu/scratch/protein-datasets/EnzymeCommission.zip to /home/cu/scratch/protein-datasets


Loading /home/cu/scratch/protein-datasets/EnzymeCommission/enzyme_commission.pkl.gz: 100%|██████████| 18716/18716 [00:22<00:00, 831.02it/s] 





PackedProtein(batch_size=2, num_atoms=[185, 415], num_bonds=[3388, 8173], num_residues=[185, 415])


In [63]:
class relationalGraph(layers.MessagePassingBase):
    """Relational graph convolution that keeps one feature row per (relation, node).

    For every (relation, node) pair the layer averages the input features of
    the incoming neighbours under that relation, applies ``self.linear`` and
    adds a linear self-loop term computed from the node's own input.

    Internally all ``num_relation * num_node`` rows are stored relation-major
    (row ``relation * num_node + node``). This is the layout produced by
    ``input.repeat(num_relation, 1)`` in :meth:`combine` and assumed by the
    final ``view(num_relation, num_node, -1)`` in :meth:`forward`.

    Parameters:
        input_dim (int): size of the input node features
        output_dim (int): size of the output features
        num_relation (int): number of relations (edge types)
        edge_input_dim (int, optional): size of the edge features; when given,
            projected edge features are added to the aggregated messages
        batch_norm (bool, optional): apply ``BatchNorm1d`` to the output
        activation (str or callable, optional): activation, looked up in
            ``torch.nn.functional`` when given as a string
    """

    def __init__(self, input_dim, output_dim, num_relation, edge_input_dim=None, batch_norm=False, activation="relu"):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.num_relation = num_relation
        self.edge_input_dim = edge_input_dim

        self.batch_norm = nn.BatchNorm1d(output_dim) if batch_norm else None
        if isinstance(activation, str):
            self.activation = getattr(F, activation)
        else:
            self.activation = activation

        self.self_loop = nn.Linear(input_dim, output_dim)
        self.linear = nn.Linear(input_dim, output_dim)
        if edge_input_dim:
            self.edge_linear = nn.Linear(edge_input_dim, input_dim)
        else:
            self.edge_linear = None

    def trans(self, A, graph):
        """Symmetrically normalise ``A`` and sum its per-relation row blocks.

        ``A`` is normalised as ``D^-1/2 A D^-1/2`` where ``D`` holds the row
        sums of ``A``. The rows of the normalised matrix are then split into
        ``graph.num_relation`` equally sized consecutive blocks, which are
        summed element-wise.

        Parameters:
            A (Tensor): dense matrix whose row count is a multiple of
                ``graph.num_relation`` (assumed square, as required by the
                two-sided normalisation)
            graph: object exposing ``num_relation`` (int or 0-dim tensor)

        Returns:
            Tensor: matrix of shape ``(A.size(0) // num_relation, A.size(1))``
        """
        degree = torch.sum(A, dim=1)
        degree_inv_sqrt = torch.pow(degree, -0.5)
        # Rows with zero degree would give inf here and NaN after the
        # multiplication below; treat them as isolated (all-zero) rows instead.
        degree_inv_sqrt = torch.where(torch.isinf(degree_inv_sqrt),
                                      torch.zeros_like(degree_inv_sqrt), degree_inv_sqrt)
        # Same result as diag(d) @ A @ diag(d) without building dense diagonals.
        A_norm = degree_inv_sqrt.unsqueeze(1) * A * degree_inv_sqrt.unsqueeze(0)

        n_rel = int(graph.num_relation)  # accepts a Python int or a 0-dim tensor
        n = A_norm.size(0)
        assert n % n_rel == 0, "n must be divisible by n_rel"
        block_size = n // n_rel

        # Sum the n_rel consecutive row blocks into a single block.
        return A_norm.reshape(n_rel, block_size, -1).sum(dim=0)

    def message_and_aggregate(self, graph, input, new_edge_list=None):
        """Aggregate neighbour features separately for every relation.

        Parameters:
            graph: graph providing ``edge_list`` (columns: node_in, node_out,
                relation), ``num_node``, ``num_relation`` and, when edge
                features are used, ``edge_feature``
            input (Tensor): node features of shape ``(num_node, input_dim)``
            new_edge_list (Tensor, optional): edge list with the same column
                layout as ``graph.edge_list``; when given it is used instead
                of ``graph.edge_list``

        Returns:
            Tensor: aggregated messages of shape
            ``(num_relation * num_node, input_dim)``, relation-major

        Raises:
            ValueError: if edge features are enabled and their number does not
                match the number of edges being aggregated
        """
        assert graph.num_relation == self.num_relation
        device = input.device  # keep every intermediate on the input's device

        edge_list = graph.edge_list if new_edge_list is None else new_edge_list
        node_in, node_out, relation = edge_list.t().to(device)
        num_slot = graph.num_node * graph.num_relation
        # Relation-major slot index so the rows line up with combine()/forward().
        node_out = relation * graph.num_node + node_out

        # Mean aggregation: every edge is weighted by 1 / (in-degree of its slot).
        edge_weight = torch.ones(node_out.shape, dtype=input.dtype, device=device)
        degree_out = scatter_add(edge_weight, node_out, dim_size=num_slot)
        edge_weight = edge_weight / degree_out[node_out]
        adjacency = utils.sparse_coo_tensor(torch.stack([node_in, node_out]), edge_weight,
                                            (graph.num_node, num_slot))
        update = torch.sparse.mm(adjacency.t(), input)

        if self.edge_linear is not None:
            edge_input = graph.edge_feature.float().to(device)
            if edge_input.size(0) != node_out.size(0):
                raise ValueError(
                    "Expected one edge feature per edge, but got %d features for %d edges"
                    % (edge_input.size(0), node_out.size(0))
                )
            edge_input = self.edge_linear(edge_input)
            edge_update = scatter_add(edge_input * edge_weight.unsqueeze(-1), node_out, dim=0,
                                      dim_size=num_slot)
            update = update + edge_update

        return update

    def combine(self, input, update):
        """Merge aggregated messages with the self-loop term.

        Parameters:
            input (Tensor): node features of shape ``(num_node, input_dim)``
            update (Tensor): output of :meth:`message_and_aggregate`

        Returns:
            Tensor: features of shape ``(num_relation * num_node, output_dim)``
        """
        device = input.device
        # Kept from the original implementation: lets a layer that was built
        # on the CPU be applied to inputs that live on another device.
        self.linear.to(device)
        self.self_loop.to(device)

        # Self-loop features: the node features tiled once per relation
        # (relation-major, matching the row order of ``update``).
        tiled_input = input.repeat(int(self.num_relation), 1).to(device)
        loop_update = self.self_loop(tiled_input)

        output = self.linear(update) + loop_update
        if self.batch_norm is not None:
            output = self.batch_norm(output)
        if self.activation is not None:
            output = self.activation(output)
        return output

    def forward(self, graph, input, new_edge_list=None):
        """Run the layer.

        Parameters:
            graph: input graph
            input (Tensor): node features of shape ``(num_node, input_dim)``
            new_edge_list (Tensor, optional): replacement edge list, see
                :meth:`message_and_aggregate`

        Returns:
            Tensor: features of shape ``(num_relation, num_node, output_dim)``
        """
        # The checkpoint helper only receives the graph tensors and the input,
        # so it cannot carry a replacement edge list; in that case the
        # aggregation is called directly.
        if self.gradient_checkpoint and new_edge_list is None:
            update = checkpoint.checkpoint(self._message_and_aggregate, *graph.to_tensors(), input)
        else:
            update = self.message_and_aggregate(graph, input, new_edge_list)
        output = self.combine(input, update)
        return output.view(int(graph.num_relation), input.size(0), -1)

In [64]:
# Layer sizes are derived from the residue graph built above.
input_dim = graph.node_feature.shape[-1]
output_dim = 128
num_relations = graph.num_relation

# Build the layer, then apply it to the float node features of the graph.
relational_layer = relationalGraph(input_dim, output_dim, num_relations)
node_input = graph.node_feature.float()
relational_output = relational_layer(graph, node_input, new_edge_list=None)
print(relational_output)
print("output: ", relational_output.shape)

tensor([[[0.0504, 0.0000, 0.4579,  ..., 0.0000, 0.0063, 0.3589],
         [0.0000, 0.0000, 0.0664,  ..., 0.0092, 0.0000, 0.2898],
         [0.2564, 0.0778, 0.6395,  ..., 0.0000, 0.0000, 0.2653],
         ...,
         [0.0000, 0.1378, 0.4195,  ..., 0.0000, 0.0000, 0.3237],
         [0.1255, 0.0667, 0.2849,  ..., 0.0000, 0.0000, 0.2467],
         [0.0073, 0.0000, 0.0689,  ..., 0.0589, 0.0401, 0.3647]],

        [[0.0000, 0.0292, 0.4592,  ..., 0.0000, 0.0000, 0.3837],
         [0.0000, 0.0165, 0.2323,  ..., 0.1008, 0.0000, 0.2555],
         [0.0000, 0.0338, 0.3991,  ..., 0.0000, 0.0000, 0.3721],
         ...,
         [0.2455, 0.0450, 0.2922,  ..., 0.0449, 0.0000, 0.4263],
         [0.1255, 0.0667, 0.2849,  ..., 0.0000, 0.0000, 0.2467],
         [0.1074, 0.0109, 0.1830,  ..., 0.2791, 0.0126, 0.3196]],

        [[0.0132, 0.0319, 0.4318,  ..., 0.0000, 0.0000, 0.3909],
         [0.0000, 0.1368, 0.0188,  ..., 0.0974, 0.0000, 0.2854],
         [0.2005, 0.0696, 0.5348,  ..., 0.0295, 0.0000, 0.

# local attention

In [148]:

class test(nn.Module):
    """Windowed (local) attention that yields a block-diagonal 0/1 adjacency.

    Node features of shape ``(num_relations, num_nodes, in_features)`` are
    projected to multi-head queries and keys. Attention is evaluated inside
    fixed-size windows that tile each graph of the batch, averaged over the
    heads, and turned into a hard top-``k`` selection per row with a
    straight-through Gumbel-softmax estimator. The per-relation
    ``(num_nodes, num_nodes)`` matrices are finally laid out on the diagonal
    of one ``(num_relations * num_nodes, num_relations * num_nodes)`` matrix.

    Parameters:
        in_features (int): size of the input node features
        out_features (int): size of each attention head
        num_heads (int): number of attention heads
        window_size (int): number of consecutive nodes per attention window
        k (int): number of entries kept per row of every window
        temperature (float, optional): Gumbel-softmax temperature
        dropout (float, optional): dropout probability of ``self.dropout``
    """

    def __init__(self, in_features, out_features, num_heads, window_size, k, temperature=0.5, dropout=0.1):
        super().__init__()
        self.num_heads = num_heads
        self.out_features = out_features
        self.window_size = window_size
        self.temperature = temperature
        # Created for interface compatibility; forward() does not apply it.
        self.dropout = nn.Dropout(dropout)
        self.k = k

        self.query = nn.Linear(in_features, out_features * num_heads)
        self.key = nn.Linear(in_features, out_features * num_heads)
        # Standard dot-product attention factor 1 / sqrt(d).
        self.scale = 1 / (out_features ** 0.5)

    def split_windows(self, tensor, index, window_size):
        """Cut ``tensor`` into equally sized windows along dimension 1.

        The graphs of the batch occupy consecutive ranges of dimension 1 whose
        lengths are given by ``index``. Each graph is tiled with windows of
        ``window_size`` rows. When the tail of a graph is shorter than a
        window, the last window is shifted backwards so that it ends exactly
        at the end of the graph (it then overlaps earlier rows, and for a
        graph shorter than ``window_size`` it reaches into the preceding
        rows of the batch).

        Parameters:
            tensor (Tensor): tensor of shape ``(B, num_nodes, D)``
            index (list of int): number of nodes of every graph
            window_size (int): number of rows per window

        Returns:
            (Tensor, list): windows of shape ``(B, num_windows, window_size, D)``
            and the ``[start, stop)`` row range of every window

        Raises:
            ValueError: if ``window_size`` is not positive, if the tensor has
                fewer than ``window_size`` rows, or if no window is produced
        """
        if window_size <= 0:
            raise ValueError("window_size must be positive, got %d" % window_size)
        if tensor.size(1) < window_size:
            raise ValueError(
                "Cannot build windows of size %d from %d rows" % (window_size, tensor.size(1))
            )

        windows = []
        index_list = []
        start = 0  # first row of the current graph that is not covered yet
        for length in index:
            stop = start + length  # exclusive end of the current graph
            while start < stop:
                if start + window_size <= stop:
                    begin = start
                    start += window_size
                else:
                    # Tail shorter than a window: shift the window back so it
                    # ends at the graph boundary, never before row 0.
                    begin = max(stop - window_size, 0)
                    start = stop
                windows.append(tensor[:, begin:begin + window_size, :])
                index_list.append([begin, begin + window_size])

        if not windows:
            raise ValueError("No windows produced: `index` describes no nodes")
        return torch.stack(windows, dim=1), index_list

    def gumbel_softmax_top_k(self, logits, tau=1.0, hard=False):
        """Gumbel-softmax relaxation with an optional hard top-``k`` output.

        Parameters:
            logits (Tensor): unnormalised scores; the last dimension is the
                one that is normalised / selected over
            tau (float, optional): temperature
            hard (bool, optional): if ``True``, return a 0/1 tensor with
                ``self.k`` ones per row (at most the row length) whose
                gradient is that of the soft sample (straight-through)

        Returns:
            Tensor: tensor with the same shape as ``logits``
        """
        gumbels = -torch.empty_like(logits).exponential_().log()
        perturbed = (logits + gumbels) / tau
        y_soft = F.softmax(perturbed, dim=-1)
        if not hard:
            return y_soft

        # Select from the perturbed scores so that the hard sample is drawn
        # from the same noise as the soft one (as F.gumbel_softmax does).
        k = min(self.k, logits.size(-1))
        topk_indices = perturbed.topk(k, dim=-1)[1]
        y_hard = torch.zeros_like(logits).scatter_(-1, topk_indices, 1.0)
        # Forward value is y_hard, gradient flows through y_soft.
        return (y_hard - y_soft).detach() + y_soft

    def windows2adjacent(self, windows, index_list, output):
        """Scatter window matrices into a block-diagonal adjacency matrix.

        Parameters:
            windows (Tensor): tensor of shape
                ``(num_relations, num_windows, window_size, window_size)``
            index_list (list): ``[start, stop)`` row range of every window
            output (Tensor): buffer of shape
                ``(num_relations, num_nodes, num_nodes)``; written in place.
                Later windows overwrite earlier ones where they overlap.

        Returns:
            Tensor: matrix of shape
            ``(num_relations * num_nodes, num_relations * num_nodes)`` holding
            ``output[r]`` as its ``r``-th diagonal block
        """
        for i, (start, end) in enumerate(index_list):
            output[:, start:end, start:end] = windows[:, i, :, :]

        num_relations, num_nodes, _ = output.shape
        size = num_relations * num_nodes
        # Allocate on the device / dtype of the buffer instead of the default.
        result = output.new_zeros(size, size)
        for r in range(num_relations):
            lo = r * num_nodes
            result[lo:lo + num_nodes, lo:lo + num_nodes] = output[r]
        return result

    def forward(self, graph, node_features):
        """Compute the windowed top-``k`` adjacency.

        Parameters:
            graph: object whose ``num_nodes`` tensor lists the number of
                nodes of every graph in the batch
            node_features (Tensor): tensor of shape
                ``(num_relations, num_nodes, in_features)``

        Returns:
            Tensor: matrix of shape
            ``(num_relations * num_nodes, num_relations * num_nodes)`` on the
            device of ``node_features``
        """
        num_relations = node_features.size(0)
        num_nodes = node_features.size(1)
        index = graph.num_nodes.tolist()

        def project(linear):
            # (R, N, H * E) -> (R * H, N, E): heads are folded into the batch.
            heads = linear(node_features).view(num_relations, num_nodes, self.num_heads, self.out_features)
            heads = heads.permute(0, 2, 1, 3)
            return heads.reshape(num_relations * self.num_heads, num_nodes, self.out_features)

        Q = project(self.query)
        K = project(self.key)

        # Buffer on the same device / dtype as the input features.
        output = node_features.new_zeros(num_relations, num_nodes, num_nodes)

        Q_windows, Q_index = self.split_windows(Q, index, self.window_size)
        K_windows, _ = self.split_windows(K, index, self.window_size)

        # Scaled dot-product attention inside every window (b: batch * heads,
        # w: window, i / j: positions in the window, e: feature).
        scores = torch.einsum('bwie,bwje->bwij', Q_windows, K_windows) * self.scale
        attn = scores.softmax(dim=-1)
        attn = attn.reshape(num_relations, self.num_heads, -1, self.window_size, self.window_size).mean(dim=1)
        # NOTE(review): the head-averaged probabilities are passed as logits;
        # confirm whether log-probabilities were intended.
        attn = self.gumbel_softmax_top_k(attn, tau=self.temperature, hard=True)

        return self.windows2adjacent(attn, Q_index, output)
        

In [152]:
# Hyper-parameters of the local attention module.
input_dim = relational_output.shape[-1]
output_dim = 128
num_heads = 8
window_size = 10
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
k = 5
# relational_output is laid out as (num_relation, num_node, dim), so the node
# count is read from axis 1 (axis 0 is the relation axis).
num_nodes = relational_output.size(1)


relational_output = relational_output.to(device)
# perf_counter is monotonic and meant for measuring short intervals.
start_time = time.perf_counter()
module = test(input_dim, output_dim, num_heads, window_size, k).to(device)
attn_output = module(graph, relational_output)
end_time = time.perf_counter()
print(f"运行时间: {end_time - start_time:.6f} 秒")

运行时间: 0.029165 秒


In [150]:
# Re-run the module and show the shape and contents of the adjacency it returns.
attn_output = module(graph, relational_output)
print(attn_output.shape)
print(attn_output)
#a = attn_output[1200:, 599:1199]
# Row and column indices of every non-zero entry, as a tuple of two tensors.
indices = attn_output.nonzero(as_tuple=True)
print(indices)

torch.Size([3000, 3000])
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 1., 0.,  ..., 0., 0., 0.],
        [0., 1., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 1., 0.],
        [0., 0., 0.,  ..., 0., 1., 1.],
        [0., 0., 0.,  ..., 0., 1., 1.]], grad_fn=<CopySlices>)
(tensor([   0,    0,    0,  ..., 2999, 2999, 2999]), tensor([   3,    4,    5,  ..., 2995, 2998, 2999]))


In [147]:
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np

# Convert the attention output to a NumPy array for plotting.
matrix = attn_output.detach().cpu().numpy()

# Render the matrix as an image.
plt.imshow(matrix, cmap='viridis', interpolation='nearest')
plt.colorbar()  # show the colour scale
plt.title("Matrix as Image")
plt.xlabel("Column")
plt.ylabel("Row")

# savefig does not create missing directories, so make sure "fig/" exists.
figure_path = Path("fig/rewire.png")
figure_path.parent.mkdir(parents=True, exist_ok=True)
plt.savefig(figure_path)
