In [53]:
from torchdrug import  layers
from torchdrug.core import Registry as R
from torchdrug.layers import geometry

import torch
from torchdrug import data

from torch_scatter import scatter_add
import torch.nn as nn
from torchdrug import utils
from torch.utils import checkpoint
import torch.nn.functional as F

# 数据集获取

In [34]:
# Look up the dataset class and the view transform by name in the torchdrug registry.
EnzymeCommission = R.search("datasets.EnzymeCommission")
PV = R.search("transforms.ProteinView")
trans = PV(view = "residue")
# Load the dataset rooted at ~/scratch/protein-datasets/ (the recorded run extracted
# EnzymeCommission.zip there and loaded 18716 entries).
# NOTE(review): the path is hard-coded to one user's scratch directory -- consider a configurable DATA_DIR.
dataset = EnzymeCommission("~/scratch/protein-datasets/", test_cutoff=0.95, 
                           atom_feature="full", bond_feature="full", verbose=1, transform = trans)

# Simplified representation that keeps only the alpha carbons.
# Edges come from three edge layers (spatial radius, k-nearest-neighbour, sequential);
# presumably each layer contributes its own relation type(s) -- TODO confirm against torchdrug docs.
graph_construction_model = layers.GraphConstruction(node_layers=[geometry.AlphaCarbonNode()], 
                                                    edge_layers=[geometry.SpatialEdge(radius=10.0, min_distance=5),
                                                                 geometry.KNNEdge(k=10, min_distance=5),
                                                                 geometry.SequentialEdge(max_distance=2)
                                                                 ],
                                                    edge_feature="gearnet"
                                                    )


# Take the first two dataset items, keep only their "graph" entries, and pack them into one batch.
# NOTE(review): `graphs` is rebound three times to different kinds of object
# (dataset slice -> list of graphs -> packed batch); only the final packed batch is used below.
graphs = dataset[:2]
graphs = [element["graph"] for element in graphs]
graphs = data.Protein.pack(graphs)
print("\n\n")
# Apply the graph construction model to the packed batch; `graph` is what the later cells consume.
graph = graph_construction_model(graphs)
print(graph)

13:48:29   Extracting /home/xiaotong/scratch/protein-datasets/EnzymeCommission.zip to /home/xiaotong/scratch/protein-datasets


Loading /home/xiaotong/scratch/protein-datasets/EnzymeCommission/enzyme_commission.pkl.gz: 100%|██████████| 18716/18716 [00:44<00:00, 416.65it/s]





PackedProtein(batch_size=2, num_atoms=[185, 415], num_bonds=[3754, 8999], num_residues=[185, 415])


In [80]:
# Inspect the packed batch: per-graph node counts and the number of graphs
# (the recorded run printed tensor([185, 415]) and 2).
print(graph.num_nodes)
print(graph.batch_size)

tensor([185, 415])
2


# 关系卷积神经网络，获取多个不同的嵌入

In [50]:
class relationalGraph(layers.MessagePassingBase):
    """
    Relational graph convolution that keeps one embedding per (relation, node) pair.

    Instead of concatenating the per-relation messages of a node, the layer returns
    ``num_node * num_relation`` rows laid out relation-major: row ``r * num_node + v``
    holds the embedding of node ``v`` under relation ``r``.

    Parameters:
        input_dim (int): size of the input node features
        output_dim (int): size of each output embedding
        num_relation (int): number of relations; must equal ``graph.num_relation``
        edge_input_dim (int, optional): size of the edge features; when falsy, edge
            features are ignored
        batch_norm (bool, optional): apply ``nn.BatchNorm1d`` to the output
        activation (str or callable, optional): name of a function in
            ``torch.nn.functional`` or a callable; a falsy value disables it
    """

    def __init__(self, input_dim, output_dim, num_relation, edge_input_dim=None, batch_norm=False, activation="relu"):
        super(relationalGraph, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.num_relation = num_relation
        self.edge_input_dim = edge_input_dim

        if batch_norm:
            self.batch_norm = nn.BatchNorm1d(output_dim)
        else:
            self.batch_norm = None
        if isinstance(activation, str):
            self.activation = getattr(F, activation)
        else:
            self.activation = activation

        # Separate projections for a node's own features and for the aggregated messages.
        self.self_loop = nn.Linear(input_dim, output_dim)
        self.linear = nn.Linear(input_dim, output_dim)
        if edge_input_dim:
            # Edge features are projected to input_dim so they can be summed with node messages.
            self.edge_linear = nn.Linear(edge_input_dim, input_dim)
        else:
            self.edge_linear = None

    def message_and_aggregate(self, graph, input, edge_list=None):
        """
        Mean-aggregate incoming messages separately for every (relation, node) pair.

        Parameters:
            graph: graph providing ``num_node``, ``num_relation``, ``edge_list`` and,
                when edge features are enabled, ``edge_feature``
            input (Tensor): node features with one row per node
            edge_list (Tensor, optional): rows of ``(node_in, node_out, relation)``;
                defaults to ``graph.edge_list``

        Returns:
            Tensor with ``num_node * num_relation`` rows in relation-major order.
        """
        assert graph.num_relation == self.num_relation

        if edge_list is None:
            edge_list = graph.edge_list
        node_in, node_out, relation = edge_list.t()
        # Relation-major flattening. It has to match `input.repeat(num_relation, 1)` in
        # `combine`, which places node v of relation r at row r * num_node + v; an
        # interleaved index (node * num_relation + relation) would pair the messages
        # of one node with the self-loop features of another.
        node_out = relation * graph.num_node + node_out

        # Normalise by the in-degree of each (relation, node) slot so that the sparse
        # product below computes a mean rather than a sum.
        edge_weight = torch.ones_like(node_out, dtype=input.dtype)
        degree_out = scatter_add(edge_weight, node_out, dim_size=graph.num_node * graph.num_relation)
        edge_weight = edge_weight / degree_out[node_out]
        adjacency = utils.sparse_coo_tensor(torch.stack([node_in, node_out]), edge_weight,
                                            (graph.num_node, graph.num_node * graph.num_relation))
        update = torch.sparse.mm(adjacency.t(), input)

        if self.edge_linear is not None:
            # NOTE(review): edge features are always read from `graph`, so a custom
            # `edge_list` must list its edges in the same order as graph.edge_feature.
            edge_input = graph.edge_feature.float()
            edge_input = self.edge_linear(edge_input)
            edge_weight = edge_weight.unsqueeze(-1)
            edge_update = scatter_add(edge_input * edge_weight, node_out, dim=0,
                                      dim_size=graph.num_node * graph.num_relation)
            update += edge_update

        return update

    def combine(self, input, update):
        """
        Add the self-loop projection to the projected messages, then normalise and activate.

        Parameters:
            input (Tensor): node features with one row per node
            update (Tensor): output of :meth:`message_and_aggregate`

        Returns:
            Tensor with ``num_node * num_relation`` rows and ``output_dim`` columns.
        """
        # Self-loop features: tile the node features once per relation (relation-major),
        # which is the row order produced by `message_and_aggregate`.
        input = input.repeat(self.num_relation, 1)
        loop_update = self.self_loop(input)

        output = self.linear(update) + loop_update
        if self.batch_norm is not None:
            output = self.batch_norm(output)
        if self.activation:
            output = self.activation(output)
        return output

    def forward(self, graph, input, edge_list=None):
        """
        Run message passing followed by the combine step.

        Parameters:
            graph: input graph
            input (Tensor): node features with one row per node
            edge_list (Tensor, optional): edges to use instead of ``graph.edge_list``

        Returns:
            Tensor with ``num_node * num_relation`` rows and ``output_dim`` columns.
        """
        if self.gradient_checkpoint:
            # The inherited `_message_and_aggregate` helper rebuilds the graph from
            # tensors; presumably it calls `message_and_aggregate(graph, input)` without
            # an edge list (TODO confirm against torchdrug), which is why `edge_list`
            # defaults to None. On this path a custom `edge_list` is not forwarded.
            update = checkpoint.checkpoint(self._message_and_aggregate, *graph.to_tensors(), input)
        else:
            update = self.message_and_aggregate(graph, input, edge_list)
        output = self.combine(input, update)
        return output

### 测试

In [52]:
# Smoke-test the layer on the constructed batch: feature width and relation count come from `graph`.
input_dim = graph.node_feature.shape[-1]
output_dim = 512
num_relations = graph.num_relation

# No edge features, batch norm off, default activation. The result has
# num_node * num_relation rows (the recorded run printed torch.Size([4200, 512])).
relational_output = relationalGraph(input_dim, output_dim, num_relations)(graph, graph.node_feature.float(), graph.edge_list)
print("output: ", relational_output.shape)

output:  torch.Size([4200, 512])


# 重连接模块

In [156]:
import torch

def get_start_end(current, graph):
    """
    Find the segment of the stacked node axis that contains row ``current``.

    The node axis is treated as ``graph.num_relation`` consecutive copies of the
    batch, so the segment sizes are ``graph.num_nodes`` repeated ``num_relation``
    times, and the segment boundaries are their running sums.

    :param current: row index on the stacked node axis
    :param graph: object exposing ``num_nodes`` (1-D integer tensor of per-graph
        node counts) and ``num_relation``
    :return: ``(start, end)`` as Python ints with ``start <= current < end``;
        ``(total, total)`` when ``current`` equals the total number of rows;
        ``None`` when ``current`` lies outside ``[0, total]``
    """
    segment = graph.num_nodes.repeat(graph.num_relation)
    # Boundaries 0, s0, s0 + s1, ... as plain ints, so the scan below does not
    # create a tensor per comparison.
    bounds = [0] + torch.cumsum(segment, dim=0).tolist()

    # Empty segments (lower == upper) never match, so a row that sits at the start
    # of an empty segment resolves to the next non-empty one.
    for lower, upper in zip(bounds[:-1], bounds[1:]):
        if lower <= current < upper:
            return (lower, upper)

    # One past the last row: report an empty range at the end.
    if current == bounds[-1]:
        return (bounds[-1], bounds[-1])

    return None
    
    

# Example usage: look up the segment that contains row 185 (the recorded run printed `185 600`).
start, end = get_start_end(185, graph)
print(start, end)


185 600


In [172]:
from torch import nn
import torch
import torch.nn.functional as F

class MultiHeadSelfAttention(nn.Module):
    """
    Sliding-window attention weights between rows of a stacked node-feature matrix.

    Every row attends only to rows inside a window of ``window_size`` centred on
    itself, clipped to the segment the row belongs to. Segments are
    ``graph.num_nodes`` repeated ``graph.num_relation`` times. The result is a dense
    ``[num_rows, num_rows]`` matrix of head-averaged attention weights; the value
    projection and dropout are created but not applied in ``forward``.

    Parameters:
        in_features (int): width of the input features
        out_features (int): width of each head's query/key projection
        num_heads (int): number of attention heads
        window_size (int): total window width; ``window_size // 2`` rows are taken
            on each side of the centre row
        temperature (float, optional): softmax temperature applied after scaling
        dropout (float, optional): dropout probability of ``self.dropout``
    """

    def __init__(self, in_features, out_features, num_heads, window_size, temperature=0.5, dropout=0.1):
        super(MultiHeadSelfAttention, self).__init__()
        self.num_heads = num_heads
        self.out_features = out_features
        self.window_size = window_size
        self.temperature = temperature
        self.dropout = nn.Dropout(dropout)

        self.query = nn.Linear(in_features, out_features * num_heads)
        self.key = nn.Linear(in_features, out_features * num_heads)
        self.value = nn.Linear(in_features, out_features * num_heads)
        # Registered as a buffer so it follows the module through `.to(device)` instead
        # of being pinned to CUDA whenever CUDA happens to be available. Non-persistent
        # keeps the state_dict keys unchanged.
        self.register_buffer("scale", torch.sqrt(torch.FloatTensor([out_features])), persistent=False)

    @staticmethod
    def _segment_bounds(graph):
        """Return the segment boundaries [0, s0, s0 + s1, ...] as a list of Python ints."""
        segment = graph.num_nodes.repeat(graph.num_relation)
        return [0] + torch.cumsum(segment, dim=0).tolist()

    def get_start_end(self, current, graph):
        """
        Find the segment of the stacked node axis that contains row ``current``.

        Returns ``(start, end)`` with ``start <= current < end``, ``(total, total)``
        when ``current`` equals the total number of rows, and ``None`` when
        ``current`` is outside ``[0, total]``.
        """
        bounds = self._segment_bounds(graph)

        # Empty segments never match, so they are skipped.
        for lower, upper in zip(bounds[:-1], bounds[1:]):
            if lower <= current < upper:
                return (lower, upper)

        # One past the last row: report an empty range at the end.
        if current == bounds[-1]:
            return (bounds[-1], bounds[-1])

        return None

    def forward(self, node_features, graph):
        """
        Compute the windowed attention matrix.

        Parameters:
            node_features (Tensor): ``[num_rows, in_features]``
            graph: object exposing ``num_nodes`` and ``num_relation``

        Returns:
            Tensor ``[num_rows, num_rows]``; row ``i`` holds the head-averaged softmax
            weights over its window and zeros elsewhere.

        Raises:
            ValueError: if ``node_features`` has more rows than the graph segments cover.
        """
        device = node_features.device
        num_nodes = node_features.size(0)
        half_window = self.window_size // 2

        # Apply linear layers and split into multiple heads
        Q = self.query(node_features).view(num_nodes, self.num_heads, self.out_features)  # [num_nodes, num_heads, out_features]
        K = self.key(node_features).view(num_nodes, self.num_heads, self.out_features)    # [num_nodes, num_heads, out_features]

        # Initialize output tensor
        output = torch.zeros(num_nodes, num_nodes, device=device)

        # Compute the segment boundaries once and expand them to one (start, end)
        # pair per row, instead of rescanning the boundaries for every row.
        bounds = self._segment_bounds(graph)
        if num_nodes > bounds[-1]:
            raise ValueError(
                "node_features has %d rows but the graph segments only cover %d rows"
                % (num_nodes, bounds[-1])
            )
        row_ranges = []
        for lower, upper in zip(bounds[:-1], bounds[1:]):
            row_ranges.extend([(lower, upper)] * (upper - lower))

        # Compute sliding window attention
        for i in range(num_nodes):
            start_index, end_index = row_ranges[i]
            # Clip the window so a row never attends across a segment boundary.
            start = max(start_index, i - half_window)
            end = min(end_index, i + half_window + 1)

            Q_i = Q[i].unsqueeze(0)  # [1, num_heads, out_features]
            K_window = K[start:end]  # [window, num_heads, out_features]

            scores = torch.einsum("nhd,mhd->nhm", Q_i, K_window) / self.scale  # [1, num_heads, window]
            scores = scores / self.temperature

            attention_weights = F.softmax(scores, dim=-1)  # [1, num_heads, window]
            attention_weights = attention_weights.mean(dim=1)  # average over heads -> [1, window]

            output[i, start:end] = attention_weights.squeeze(0)

        return output

# 示例使用
# 假设 graph 和 node_features 已经定义
# graph.num_nodes 和 graph.num_relation 需要正确设置
# node_features 是一个 [num_nodes, in_features] 的张量


In [173]:
import numpy as np

# Configure the attention module from the relational embeddings computed earlier.
# NOTE(review): `input_dim` and `output_dim` rebind names set by the earlier test cell,
# so re-running that cell after this one uses different values.
input_dim = relational_output.shape[-1]
output_dim = 1024
num_heads = 8
window_size = 50
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Move the embeddings and the module to the selected device, then compute the attention matrix.
# NOTE(review): `graph` is passed as-is and is not moved to `device` -- confirm this is intended.
relational_output = relational_output.to(device)
module = MultiHeadSelfAttention(input_dim, output_dim, num_heads, window_size).to(device)
attn_output = module(relational_output, graph)