In [2]:
import os

from torchdrug import core
from torchdrug import datasets, transforms,layers
from torchdrug import utils
from torchdrug.core import Registry as R
from torchdrug.layers import geometry

import torch
import torchdrug
from torchdrug import data

import matplotlib as mpl
import matplotlib.pyplot as plt
# Select matplotlib's "Agg" backend for every figure created in this notebook.
mpl.use("Agg")

# 数据集展示

In [3]:
# Look up the dataset class and the transform class by their registry keys.
EnzymeCommission = R.search("datasets.EnzymeCommission")
PV = R.search("transforms.ProteinView")

# The transform is built with view="residue" and handed to the dataset below.
trans = PV(view="residue")

# Load the dataset from the scratch directory with the options listed here.
dataset = EnzymeCommission(
    "~/scratch/protein-datasets/",
    test_cutoff=0.95,
    atom_feature="full",
    bond_feature="full",
    verbose=1,
    transform=trans,
)

17:41:25   Extracting /home/xiaotong/scratch/protein-datasets/EnzymeCommission.zip to /home/xiaotong/scratch/protein-datasets


Loading /home/xiaotong/scratch/protein-datasets/EnzymeCommission/enzyme_commission.pkl.gz: 100%|██████████| 18716/18716 [00:47<00:00, 393.54it/s]


展示第一个样本的前两个残基的原子

In [4]:
# First sample of the dataset: keep only the residues numbered 1 and 2.
protein = dataset[0]["graph"]
is_first_two = (protein.residue_number == 1) | (protein.residue_number == 2)
first_two = protein.residue_mask(is_first_two, compact=True)

first_two.visualize()

# savefig does not create missing parent directories, so make sure the output
# folder exists; otherwise a fresh checkout fails here with FileNotFoundError.
os.makedirs("fig", exist_ok=True)
plt.savefig("fig/first_two.png")

测试edge_feature：查看edge_list，以及每条边两端原子所属的残基（atom2residue）

In [33]:
# Inspect the raw edge list of the first sample and map every edge endpoint
# (an atom index) to the residue that atom belongs to.
graph = dataset[0]["graph"]
print(graph)

edge_list = graph.edge_list
print(edge_list)

# NOTE(review): not used anywhere in this cell; kept because later cells may read it.
num_relations = 7

# edge_list has three columns and only the first two (the endpoints) are needed.
# Slicing instead of unpacking into `_` avoids rebinding `_` in the notebook
# namespace, which would stop IPython from updating its last-output variable.
node_in, node_out = edge_list.t()[:2]
print("node_in: ", node_in)
print("node_out: ", node_out)
print("\n")


print("atom2residue:", graph.atom2residue)
print("atom2residue.shape:", graph.atom2residue.shape)
print("\n")

# Index atom2residue with the endpoint indices to get one residue id per edge end.
residue_in, residue_out = graph.atom2residue[node_in], graph.atom2residue[node_out]
print("residue_in: ", residue_in)
print("residue_out: ", residue_out)
print("residue_in.shape: ", residue_in.shape)
print("\n")




Protein(num_atom=1596, num_bond=2920, num_residue=349)
tensor([[   1,    0,    0],
        [   0,    1,    0],
        [   2,    1,    0],
        ...,
        [1429, 1430,    0],
        [1431, 1421,    0],
        [1421, 1431,    0]])
node_in:  tensor([   1,    0,    2,  ..., 1429, 1431, 1421])
node_out:  tensor([   0,    1,    1,  ..., 1430, 1421, 1431])


atom2residue: tensor([  0,   0,   0,  ..., 346, 347, 348])
atom2residue.shape: torch.Size([1596])


residue_in:  tensor([  0,   0,   0,  ..., 184, 184, 184])
residue_out:  tensor([  0,   0,   0,  ..., 184, 184, 184])
residue_in.shape:  torch.Size([2920])




只保留alpha碳，以及按照gearnet格式简化图

In [5]:
# Graph construction: one node layer (AlphaCarbonNode) and three edge layers,
# with the edge feature mode set to "gearnet".
node_layers = [geometry.AlphaCarbonNode()]
edge_layers = [
    geometry.SpatialEdge(radius=10.0, min_distance=5),
    geometry.KNNEdge(k=10, min_distance=5),
    geometry.SequentialEdge(max_distance=2),
]
graph_construction_model = layers.GraphConstruction(
    node_layers=node_layers,
    edge_layers=edge_layers,
    edge_feature="gearnet",
)

# Pack the single protein into a batch of size one, then run the construction.
_protein = data.Protein.pack([protein])
protein_ = graph_construction_model(_protein)
print("Graph before: ", _protein)
print("Graph after: ", protein_)

# Report the shapes of the tensors attached to the constructed graph.
for attr in ("node_feature", "edge_feature", "edge_weight", "node_position"):
    print(f"{attr}: ", getattr(protein_, attr).shape)

# Try unsqueeze: append a trailing dimension to the edge weights.
edge_weight = protein_.edge_weight.unsqueeze(-1)
print("new_edge_weight: ", edge_weight.shape, edge_weight)



Graph before:  PackedProtein(batch_size=1, num_atoms=[1596], num_bonds=[2920], num_residues=[349])
Graph after:  PackedProtein(batch_size=1, num_atoms=[185], num_bonds=[3754], num_residues=[185])
node_feature:  torch.Size([185, 21])
edge_feature:  torch.Size([3754, 59])
edge_weight:  torch.Size([3754])
node_position:  torch.Size([185, 3])
new_edge_weight:  torch.Size([3754, 1]) tensor([[1.],
        [1.],
        [1.],
        ...,
        [1.],
        [1.],
        [1.]])


# Gearnet 流程

生成稀疏的邻接矩阵

In [6]:
# Build a sparse adjacency matrix of shape (num_node, num_node * num_relation).
# `utils` (torchdrug.utils) is imported in the notebook's top import cell rather
# than mid-cell, so the notebook's dependencies are all visible in one place.
node_in, node_out, relation = protein_.edge_list.t()

# Fold the relation id into the column index:
# column = node_out * num_relation + relation.
node_out = node_out * protein_.num_relation + relation
print(node_in, node_out)
print(node_in.shape, node_out.shape)

# Rows are indexed by node_in, columns by the folded index; values are the edge weights.
adjacency = utils.sparse_coo_tensor(torch.stack([node_in, node_out]), protein_.edge_weight,
                                    (protein_.num_node, protein_.num_node * protein_.num_relation))

print(adjacency)
print(adjacency.shape)

tensor([90, 85, 83,  ..., 75, 50, 51]) tensor([575, 547, 547,  ..., 749, 756, 756])
torch.Size([3754]) torch.Size([3754])
tensor(indices=tensor([[ 90,  85,  83,  ...,  75,  50,  51],
                       [575, 547, 547,  ..., 749, 756, 756]]),
       values=tensor([1., 1., 1.,  ..., 1., 1., 1.]),
       size=(185, 1295), nnz=3754, layout=torch.sparse_coo)
torch.Size([185, 1295])


# DGM模块

In [7]:
# From this cell on, `graph` refers to the constructed graph `protein_`.
graph = protein_
print(graph)
print("graph_node_feature :", graph.node_feature.shape)
print("graph_edge_feature :", graph.edge_feature.shape)
print("\n")


# The MLP input width is the last dimension of the node feature matrix.
input_dim = graph.node_feature.shape[-1]
embed_function = layers.MultiLayerPerceptron(
    input_dim,
    hidden_dims=[512, 512],
    short_cut=True,
    batch_norm=True,
    activation="relu",
)

# NOTE(review): `k` is not used in this cell; presumably consumed by a later
# DGM step -- confirm against the cells that follow.
k = 5


PackedProtein(batch_size=1, num_atoms=[185], num_bonds=[3754], num_residues=[185])
graph_node_feature : torch.Size([185, 21])
graph_edge_feature : torch.Size([3754, 59])




In [14]:
# Cast the node features to float32 and run them through the embedding MLP.
# NOTE(review): `input` shadows the Python builtin of the same name; the name
# is kept unchanged here because other cells may read it.
input = graph.node_feature.to(dtype=torch.float32)
H = embed_function(input)
# Debug aid (disabled): print(H) and print(H.shape) to inspect the embedding.

