In [2]:
from torchdrug import core
from torchdrug import datasets, transforms,layers
from torchdrug.core import Registry as R
from torchdrug.layers import geometry

import torch
import torchdrug
from torchdrug import data

import matplotlib as mpl
import matplotlib.pyplot as plt
mpl.use('Agg')

# 数据集展示

In [3]:
EnzymeCommission = R.search("datasets.EnzymeCommission")
PV = R.search("transforms.ProteinView")
trans = PV(view = "residue")
dataset = EnzymeCommission("~/scratch/protein-datasets/", test_cutoff=0.95, 
                           atom_feature="full", bond_feature="full", verbose=1, transform = trans)

16:10:23   Extracting /home/xiaotong/scratch/protein-datasets/EnzymeCommission.zip to /home/xiaotong/scratch/protein-datasets


Loading /home/xiaotong/scratch/protein-datasets/EnzymeCommission/enzyme_commission.pkl.gz: 100%|██████████| 18716/18716 [00:47<00:00, 391.38it/s]


展示第一个样本的前两个残基的原子

In [4]:
# 数据集第一个样本，前两个残基的原子
protein = dataset[0]["graph"]
print(protein)
is_first_two = (protein.residue_number == 1) | (protein.residue_number == 2) | (protein.residue_number == 3)
first_two = protein.residue_mask(is_first_two, compact=True)

first_two.visualize()
#plt.savefig("fig/first_two_three.png")


Protein(num_atom=1596, num_bond=2920, num_residue=349)


In [5]:
first_two_elements = dataset[:2]
graphs = [element["graph"] for element in first_two_elements]
protein2 = data.Protein.pack(graphs)
print(protein2)

PackedProtein(batch_size=2, num_atoms=[1596, 3761], num_bonds=[2920, 6468], num_residues=[349, 997])


测试edge_feature

In [5]:
graph = dataset[0]["graph"]
print(graph)

edge_list = graph.edge_list
print(edge_list)

num_relations = 7

node_in, node_out, _ = edge_list.t()
print("node_in: ", node_in)
print("node_out: ", node_out)
print("\n")


print("atom2residue:", graph.atom2residue)
print("atom2residue.shape:", graph.atom2residue.shape)
print("\n")

residue_in, residue_out = graph.atom2residue[node_in], graph.atom2residue[node_out]
print("residue_in: ", residue_in)
print("residue_out: ", residue_out)
print("residue_in.shape: ", residue_in.shape)
print("\n")




Protein(num_atom=1596, num_bond=2920, num_residue=349)
tensor([[   1,    0,    0],
        [   0,    1,    0],
        [   2,    1,    0],
        ...,
        [1429, 1430,    0],
        [1431, 1421,    0],
        [1421, 1431,    0]])
node_in:  tensor([   1,    0,    2,  ..., 1429, 1431, 1421])
node_out:  tensor([   0,    1,    1,  ..., 1430, 1421, 1431])


atom2residue: tensor([  0,   0,   0,  ..., 346, 347, 348])
atom2residue.shape: torch.Size([1596])


residue_in:  tensor([  0,   0,   0,  ..., 184, 184, 184])
residue_out:  tensor([  0,   0,   0,  ..., 184, 184, 184])
residue_in.shape:  torch.Size([2920])




只保留alpha碳，以及按照gearnet格式简化图

In [7]:
graph_construction_model = layers.GraphConstruction(node_layers=[geometry.AlphaCarbonNode()], 
                                                    edge_layers=[geometry.SpatialEdge(radius=10.0, min_distance=5),
                                                                 geometry.KNNEdge(k=10, min_distance=5),
                                                                 geometry.SequentialEdge(max_distance=2)
                                                                 ],
                                                    edge_feature="gearnet"
                                                    )

_protein = data.Protein.pack([protein])
protein_ = graph_construction_model(_protein)
print("Graph before: ", _protein)
print("Graph after: ", protein_)

print("node_feature: ", protein_.node_feature.shape)
print("edge_feature: ", protein_.edge_feature.shape)

print("edge_weight: ", protein_.edge_weight.shape)

print("node_position: ", protein_.node_position.shape)

# 测试unsqueeze
edge_weight = protein_.edge_weight.unsqueeze(-1)
print("new_edge_weight: ", edge_weight.shape, edge_weight)



Graph before:  PackedProtein(batch_size=1, num_atoms=[1596], num_bonds=[2920], num_residues=[349])
Graph after:  PackedProtein(batch_size=1, num_atoms=[185], num_bonds=[3754], num_residues=[185])
node_feature:  torch.Size([185, 21])
edge_feature:  torch.Size([3754, 59])
edge_weight:  torch.Size([3754])
node_position:  torch.Size([185, 3])
new_edge_weight:  torch.Size([3754, 1]) tensor([[1.],
        [1.],
        [1.],
        ...,
        [1.],
        [1.],
        [1.]])


# Gearnet 流程

生成稀疏的邻接矩阵

In [7]:
node_in, node_out, relation = protein_.edge_list.t()
node_out = node_out * protein_.num_relation + relation
print(node_in, node_out)
print(node_in.shape, node_out.shape)

import torchdrug.utils as utils
adjacency = utils.sparse_coo_tensor(torch.stack([node_in, node_out]), protein_.edge_weight,
                                    (protein_.num_node, protein_.num_node * protein_.num_relation))

print(adjacency)
print(adjacency.shape)

tensor([90, 85, 83,  ..., 75, 50, 51]) tensor([575, 547, 547,  ..., 749, 756, 756])
torch.Size([3754]) torch.Size([3754])
tensor(indices=tensor([[ 90,  85,  83,  ...,  75,  50,  51],
                       [575, 547, 547,  ..., 749, 756, 756]]),
       values=tensor([1., 1., 1.,  ..., 1., 1., 1.]),
       size=(185, 1295), nnz=3754, layout=torch.sparse_coo)
torch.Size([185, 1295])


# DGM模块

In [147]:
graphs = dataset[:2]
graphs = [element["graph"] for element in graphs]
graphs = data.Protein.pack(graphs)
print(graphs)
protein2_ = graph_construction_model(graphs)
graph = protein2_
#graph = protein_
print("Graph before: ", protein2)
print("Graph after: ", graph)
print("\n")
print("graph_node_feature :",graph.node_feature.shape)
print("graph_edge_feature :",graph.edge_feature.shape)
print("\n")
print("graph_node _feature_type :",graph.node_feature.dtype)

len(graph.node_feature)
#print(graph.node_feature.max())

print("\n")
print("graph.edge_list: ", graph.edge_list)
print("graph.edge_list.shape: ", graph.edge_list.shape)

print("\n")
print("graph.edge_weight: ", graph.edge_weight)
print("graph.edge_weight.shape: ", graph.edge_weight.shape)

PackedProtein(batch_size=2, num_atoms=[1596, 3761], num_bonds=[2920, 6468], num_residues=[349, 997])
Graph before:  PackedProtein(batch_size=2, num_atoms=[1596, 3761], num_bonds=[2920, 6468], num_residues=[349, 997])
Graph after:  PackedProtein(batch_size=2, num_atoms=[185, 415], num_bonds=[3754, 8999], num_residues=[185, 415])


graph_node_feature : torch.Size([600, 21])
graph_edge_feature : torch.Size([12753, 59])


graph_node _feature_type : torch.int64


graph.edge_list:  tensor([[ 95,  96,   5],
        [109, 110,   5],
        [108, 109,   5],
        ...,
        [438, 470,   0],
        [489, 470,   0],
        [493, 470,   0]])
graph.edge_list.shape:  torch.Size([12753, 3])


graph.edge_weight:  tensor([1., 1., 1.,  ..., 1., 1., 1.])
graph.edge_weight.shape:  torch.Size([12753])


In [64]:
#dataloader = data.DataLoader(graph, batch_size=1, shuffle=True)

关系神经网络

In [148]:
from torch_scatter import scatter_mean, scatter_add, scatter_max
import torch.nn as nn
from torchdrug import utils

class relationalGraph(layers. MessagePassingBase):
    eps = 1e-10

    def __init__(self, input_dim, output_dim, num_relation, edge_input_dim=None, batch_norm=False, activation="relu"):
        super(relationalGraph, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.num_relation = num_relation
        self.edge_input_dim = edge_input_dim

        if batch_norm:
            self.batch_norm = nn.BatchNorm1d(output_dim)
        else:
            self.batch_norm = None
        if isinstance(activation, str):
            self.activation = getattr(F, activation)
        else:
            self.activation = activation

        #self.self_loop = nn.Linear(input_dim, output_dim)
        self.linear = nn.Linear(input_dim, output_dim)
        if edge_input_dim:
            self.edge_linear = nn.Linear(edge_input_dim, input_dim)
        else:
            self.edge_linear = None

    def message_and_aggregate(self, graph, input):
        assert graph.num_relation == self.num_relation

        node_in, node_out, relation = graph.edge_list.t()
        node_out = node_out * self.num_relation + relation
        degree_out = scatter_add(graph.edge_weight, node_out, dim_size=graph.num_node * graph.num_relation)
        edge_weight = graph.edge_weight / degree_out[node_out]
        adjacency = utils.sparse_coo_tensor(torch.stack([node_in, node_out]), edge_weight,
                                            (graph.num_node, graph.num_node * graph.num_relation))
        update = torch.sparse.mm(adjacency.t(), input)
        if self.edge_linear:
            edge_input = graph.edge_feature.float()
            edge_input = self.edge_linear(edge_input)
            edge_weight = edge_weight.unsqueeze(-1)
            edge_update = scatter_add(edge_input * edge_weight, node_out, dim=0,
                                      dim_size=graph.num_node * graph.num_relation)
            update += edge_update

        return update


    def combine(self, input, update):
        output = self.linear(update)
        if self.batch_norm:
            output = self.batch_norm(output)
        if self.activation:
            output = self.activation(output)
        return output

In [149]:
input_dim = graph.node_feature.shape[-1]
output_dim = 512
num_relations = graph.num_relation

output = relationalGraph(input_dim, output_dim, num_relations)(graph, graph.node_feature.float())
print("output: ", output.shape)

output:  torch.Size([4200, 512])


In [150]:
def split_matrix(matrix, num_relation):
    # 计算每份的行数
    rows_per_split = matrix.shape[0] // num_relation

    # 分割矩阵
    split_matrices = []
    for i in range(num_relation):
        start_idx = i * rows_per_split
        end_idx = start_idx + rows_per_split
        split_matrices.append(matrix[start_idx:end_idx])

    return split_matrices

relation_output = split_matrix(output, graph.num_relation)

print("relation_output: ", relation_output[0].shape)

relation_output:  torch.Size([600, 512])


计算score

In [151]:
from torchdrug.layers import functional
from torch import nn
import torch.nn.functional as F

class MultiHeadSelfAttention(nn.Module):
    def __init__(self, in_features, out_features, num_heads, temperature=0.5, dropout=0.1):
        super(MultiHeadSelfAttention, self).__init__()
        self.num_heads = num_heads
        self.out_features = out_features
        self.temperature = temperature
        self.dropout = nn.Dropout(dropout)

        self.query = nn.Linear(in_features, out_features * num_heads)
        self.key = nn.Linear(in_features, out_features * num_heads)
        self.value = nn.Linear(in_features, out_features * num_heads)
        self.scale = torch.sqrt(torch.FloatTensor([out_features]))
        

    def forward(self, node_features, graph):
        num_nodes = node_features.size(0)

        # 计算查询、键和值，并为多头自注意力进行变形
        Q = self.query(node_features).view(num_nodes, self.num_heads, self.out_features)  # [num_nodes, num_heads, out_features]
        K = self.key(node_features).view(num_nodes, self.num_heads, self.out_features)    # [num_nodes, num_heads, out_features]
        #V = self.value(node_features).view(num_nodes, self.num_heads, self.out_features)  # [num_nodes, num_heads, out_features]

        # 计算相似性分数
        scores = torch.einsum("nhd,mhd->nhm", Q, K) / self.scale  # [num_heads, num_nodes, num_nodes]
        
        scores = scores / self.temperature

        # 计算注意力权重
        attention_weights = F.softmax(scores, dim=-1)  # [num_heads, num_nodes, num_nodes]
        
        # 应用 dropout
        attention_weights = self.dropout(attention_weights)

        # 多头结果合并
        attention_weights = attention_weights.mean(dim=-2)  # [num_nodes, num_nodes]

        return attention_weights


In [177]:
import torch
import torch.nn as nn
from torchdrug import layers, utils

class CombinedModel(nn.Module):
    def __init__(self, input_dim, output_dim, num_relation, edge_input_dim, 
                 num_heads, attention_out_features, temperature=0.5, dropout=0.1):
        super(CombinedModel, self).__init__()
        
        self.relational_graph = relationalGraph(input_dim, output_dim, num_relation, edge_input_dim)
        self.attention = MultiHeadSelfAttention(output_dim, attention_out_features, num_heads, temperature, dropout)
    
    def split_matrix(self,matrix, num_relation):
        # 计算每份的行数
        rows_per_split = matrix.shape[0] // num_relation

        # 分割矩阵
        split_matrices = []
        for i in range(num_relation):
            start_idx = i * rows_per_split
            end_idx = start_idx + rows_per_split
            split_matrices.append(matrix[start_idx:end_idx])

        return split_matrices
    
    
    def merge_relation_matrices(self,matrix_list):
        merged_list = []

        # 遍历每个矩阵
        for idx, matrix in enumerate(matrix_list):
            num_rows = matrix.size(0)
            # 创建一个形状为 [num_rows, 1] 的张量，记录矩阵索引
            relation_idx = torch.full((num_rows, 1), idx, dtype=torch.long)
            # 将矩阵和索引列拼接在一起
            extended_matrix = torch.cat((matrix, relation_idx), dim=1)
            merged_list.append(extended_matrix)
        
        # 将所有扩展后的矩阵合并成一个矩阵
        merged_matrix = torch.cat(merged_list, dim=0)
        
        return merged_matrix

    def process_list(self,tensor_list):
        result_list = []
        
        for matrix in tensor_list:
            num_nodes = matrix.size(0)
            indices_list = []

            for i in range(num_nodes):
                row = matrix[i]
                # 找到前5个最大的数及其索引
                sorted_indices = torch.argsort(row, descending=True)  # 从大到小排序
                max_values = row[sorted_indices][:5]
                
                # 检查是否有超过5个相同的最大数
                count = (row == max_values[-1]).sum().item()
                
                if count > 5:
                    # 选择与当前行序号距离最近的列索引
                    relevant_indices = sorted_indices[:count]
                    distances = torch.abs(relevant_indices - i)
                    closest_indices = relevant_indices[torch.argsort(distances)][:5]
                    selected_indices = closest_indices
                else:
                    selected_indices = sorted_indices[:5]
                
                # 记录当前行号和所选列号
                for j in selected_indices:
                    indices_list.append([i, j.item()])
            
            result_matrix = torch.tensor(indices_list, dtype=torch.long)
            result_list.append(result_matrix)
            
        
        return merge_relation_matrices(result_list)
    
    
    
    def forward(self, node_features, graph):
        # Apply relational graph layer
        relational_output = self.relational_graph(graph, node_features)
        
        relational_output = split_matrix(relational_output, graph.num_relation)
        #print("relational_output: ", len(relational_output))
        
        attention_output = []
        for i in range(len(relational_output)):
            # Apply multi-head self attention
            output = self.attention(relational_output[i], graph)
            attention_output.append(output)
        
        
        return attention_output
    

# Example of how to initialize and use the model
input_dim = graph.node_feature.shape[-1]
output_dim = 512
num_relation = graph.num_relation
edge_input_dim = graph.edge_feature.shape[-1]
num_heads = 8
attention_out_features = 512

model = CombinedModel(input_dim, output_dim, num_relation, edge_input_dim, 
                      num_heads, attention_out_features)

# Example data


output = model(graph.node_feature.float(), graph)
print("output: ", output[0],output[1])



output:  tensor([[0.0018, 0.0016, 0.0019,  ..., 0.0019, 0.0016, 0.0016],
        [0.0014, 0.0018, 0.0019,  ..., 0.0019, 0.0011, 0.0016],
        [0.0016, 0.0016, 0.0012,  ..., 0.0016, 0.0018, 0.0019],
        ...,
        [0.0016, 0.0016, 0.0019,  ..., 0.0018, 0.0018, 0.0019],
        [0.0014, 0.0018, 0.0018,  ..., 0.0019, 0.0018, 0.0016],
        [0.0016, 0.0016, 0.0016,  ..., 0.0018, 0.0012, 0.0019]],
       grad_fn=<MeanBackward1>) tensor([[0.0016, 0.0014, 0.0016,  ..., 0.0018, 0.0018, 0.0019],
        [0.0016, 0.0019, 0.0016,  ..., 0.0018, 0.0019, 0.0016],
        [0.0016, 0.0016, 0.0016,  ..., 0.0016, 0.0014, 0.0014],
        ...,
        [0.0019, 0.0016, 0.0019,  ..., 0.0016, 0.0019, 0.0016],
        [0.0019, 0.0016, 0.0016,  ..., 0.0016, 0.0016, 0.0014],
        [0.0016, 0.0016, 0.0016,  ..., 0.0018, 0.0019, 0.0016]],
       grad_fn=<MeanBackward1>)


In [178]:
def merge_relation_matrices(matrix_list):
    merged_list = []

    # 遍历每个矩阵
    for idx, matrix in enumerate(matrix_list):
        num_rows = matrix.size(0)
        # 创建一个形状为 [num_rows, 1] 的张量，记录矩阵索引
        relation_idx = torch.full((num_rows, 1), idx, dtype=torch.long)
        # 将矩阵和索引列拼接在一起
        extended_matrix = torch.cat((matrix, relation_idx), dim=1)
        merged_list.append(extended_matrix)
    
    # 将所有扩展后的矩阵合并成一个矩阵
    merged_matrix = torch.cat(merged_list, dim=0)
    
    return merged_matrix

def process_list(tensor_list):
    result_list = []
    
    for matrix in tensor_list:
        num_nodes = matrix.size(0)
        indices_list = []

        for i in range(num_nodes):
            row = matrix[i]
            # 找到前5个最大的数及其索引
            sorted_indices = torch.argsort(row, descending=True)  # 从大到小排序
            max_values = row[sorted_indices][:5]
            
            # 检查是否有超过5个相同的最大数
            count = (row == max_values[-1]).sum().item()
            
            if count > 5:
                # 选择与当前行序号距离最近的列索引
                relevant_indices = sorted_indices[:count]
                distances = torch.abs(relevant_indices - i)
                closest_indices = relevant_indices[torch.argsort(distances)][:5]
                selected_indices = closest_indices
            else:
                selected_indices = sorted_indices[:5]
            
            # 记录当前行号和所选列号
            for j in selected_indices:
                indices_list.append([i, j.item()])
        
        result_matrix = torch.tensor(indices_list, dtype=torch.long)
        result_list.append(result_matrix)
        
    
    return merge_relation_matrices(result_list)

output = process_list(output)
print("output: ", output)

output:  tensor([[  0, 585,   0],
        [  0,  32,   0],
        [  0, 333,   0],
        ...,
        [599, 450,   6],
        [599, 296,   6],
        [599, 240,   6]])


In [179]:
graphs = dataset[:2]
graphs = [element["graph"] for element in graphs]
graphs = data.Protein.pack(graphs)
print(graphs)
protein2_ = graph_construction_model(graphs)


@R.register("models.DGMGearNet")
class DGMGeometryAwareRelationalGraphNeuralNetwork(nn.Module, core.Configurable):

    def __init__(self, input_dim, hidden_dims, num_relation, edge_input_dim=None, num_angle_bin=None,
                 short_cut=False, batch_norm=False, activation="relu", concat_hidden=False, readout="sum"):
        super(DGMGeometryAwareRelationalGraphNeuralNetwork, self).__init__()

        if not isinstance(hidden_dims, Sequence):
            hidden_dims = [hidden_dims]
        self.input_dim = input_dim
        self.output_dim = sum(hidden_dims) if concat_hidden else hidden_dims[-1]
        self.dims = [input_dim] + list(hidden_dims)
        self.edge_dims = [edge_input_dim] + self.dims[:-1]
        self.num_relation = num_relation
        self.num_angle_bin = num_angle_bin
        self.short_cut = short_cut
        self.concat_hidden = concat_hidden
        self.batch_norm = batch_norm

        self.layers = nn.ModuleList()
        for i in range(len(self.dims) - 1):
            
            self.layers.append(layers.GeometricRelationalGraphConv(self.dims[i], self.dims[i + 1], num_relation,
                                                                   None, batch_norm, activation))
        if num_angle_bin:
            self.spatial_line_graph = layers.SpatialLineGraph(num_angle_bin)
            self.edge_layers = nn.ModuleList()
            for i in range(len(self.edge_dims) - 1):
                self.edge_layers.append(layers.GeometricRelationalGraphConv(
                    self.edge_dims[i], self.edge_dims[i + 1], num_angle_bin, None, batch_norm, activation))

        if batch_norm:
            self.batch_norms = nn.ModuleList()
            for i in range(len(self.dims) - 1):
                self.batch_norms.append(nn.BatchNorm1d(self.dims[i + 1]))

        if readout == "sum":
            self.readout = layers.SumReadout()
        elif readout == "mean":
            self.readout = layers.MeanReadout()
        else:
            raise ValueError("Unknown readout `%s`" % readout)

    def forward(self, graph, input, all_loss=None, metric=None):
        """
        Compute the node representations and the graph representation(s).

        Parameters:
            graph (Graph): :math:`n` graph(s)
            input (Tensor): input node representations
            all_loss (Tensor, optional): if specified, add loss to this tensor
            metric (dict, optional): if specified, output metrics to this dict

        Returns:
            dict with ``node_feature`` and ``graph_feature`` fields:
                node representations of shape :math:`(|V|, d)`, graph representations of shape :math:`(n, d)`
        """
        hiddens = []
        layer_input = input
        if self.num_angle_bin:
            line_graph = self.spatial_line_graph(graph)
            edge_input = line_graph.node_feature.float()

        for i in range(len(self.layers)):
            hidden = self.layers[i](graph, layer_input)
            if self.short_cut and hidden.shape == layer_input.shape:
                hidden = hidden + layer_input
            if self.num_angle_bin:
                edge_hidden = self.edge_layers[i](line_graph, edge_input)
                edge_weight = graph.edge_weight.unsqueeze(-1)
                node_out = graph.edge_list[:, 1] * self.num_relation + graph.edge_list[:, 2]
                update = scatter_add(edge_hidden * edge_weight, node_out, dim=0,
                                     dim_size=graph.num_node * self.num_relation)
                update = update.view(graph.num_node, self.num_relation * edge_hidden.shape[1])
                update = self.layers[i].linear(update)
                update = self.layers[i].activation(update)
                hidden = hidden + update
                edge_input = edge_hidden
            if self.batch_norm:
                hidden = self.batch_norms[i](hidden)
            hiddens.append(hidden)
            layer_input = hidden

        if self.concat_hidden:
            node_feature = torch.cat(hiddens, dim=-1)
        else:
            node_feature = hiddens[-1]
        graph_feature = self.readout(graph, node_feature)

        return {
            "graph_feature": graph_feature,
            "node_feature": node_feature
        }

PackedProtein(batch_size=2, num_atoms=[1596, 3761], num_bonds=[2920, 6468], num_residues=[349, 997])


# Gearnet测试

In [11]:
from torchdrug import models

dataloader = data.DataLoader(graph, batch_size=1, shuffle=True)

In [13]:
GN = models.GearNet(input_dim=21,
                    hidden_dims=512,
                    batch_norm=True,
                    concat_hidden=True,
                    short_cut=True,
                    readout="sum",
                    num_relation=7,
                    edge_input_dim=59
                    )

# 获取一个批次的数据并传递给模型
for batch in dataloader:
    #graph = batch['graph']
    gearnet_output= GN(graph, graph.node_feature.to(torch.float32))
    node_feature1 = gearnet_output["node_feature"]
    print("hiddden_node_feature:", node_feature1.shape)

    break  # 这里只查看一个批次

hiddden_node_feature: torch.Size([600, 512])


edge_list:  tensor([[ 95,  96,   5],
        [109, 110,   5],
        [108, 109,   5],
        ...,
        [438, 470,   0],
        [489, 470,   0],
        [493, 470,   0]])


KeyError: tensor(5)