In [1]:
from lib.datasets import ATPBind3D

from torchdrug import transforms

truncuate_transform = transforms.TruncateProtein(max_length=350, random=False)
protein_view_transform = transforms.ProteinView(view='residue')
transform = transforms.Compose([truncuate_transform, protein_view_transform])

dataset = ATPBind3D(transform=transform)

train_set, valid_set, test_set = dataset.split()
print("train samples: %d, valid samples: %d, test samples: %d" %
      (len(train_set), len(valid_set), len(test_set)))


Split num:  [337, 41, 41]
train samples: 337, valid samples: 41, test samples: 41


In [2]:
from torchdrug import core, layers, models
from torchdrug.layers import geometry
from torchdrug import transforms
import torch
from lib.tasks import NodePropertyPrediction
from lib.lr_scheduler import CosineAnnealingLR
from lib.disable_logger import DisableLogger

GPU = 2

graph_construction_model = layers.GraphConstruction(node_layers=[geometry.AlphaCarbonNode()],
                                                    edge_layers=[geometry.SpatialEdge(radius=10.0, min_distance=5),
                                                                 geometry.KNNEdge(
                                                                     k=10, min_distance=5),
                                                                 geometry.SequentialEdge(max_distance=2)],
                   
                                                    edge_feature="gearnet")

gearnet = models.GearNet(input_dim=21, hidden_dims=[512, 512, 512, 512, 512, 512], num_relation=7,
                   edge_input_dim=59, num_angle_bin=8,
                   batch_norm=True, concat_hidden=True, short_cut=True, readout="sum").cuda(GPU)

task = NodePropertyPrediction(
    gearnet,
    graph_construction_model=graph_construction_model,
    normalization=False,
    num_mlp_layer=2,
    metric=("micro_auroc", "mcc")
)

optimizer = torch.optim.Adam(task.parameters(), lr=1e-4)
scheduler = CosineAnnealingLR(optimizer=optimizer, T_max=20, eta_min=1e-5)
with DisableLogger():
    solver = core.Engine(task, train_set, valid_set, test_set, optimizer,
                        gpus=[GPU], batch_size=1)


In [3]:
state_dict = torch.load("ResidueType_6_512_0.38472.pth")

new_state_dict = {}
for k in state_dict['model'].keys():
    if k.startswith("model"):
        new_state_dict[k.replace("model.", "")] = state_dict['model'][k]

gearnet.load_state_dict(new_state_dict)


<All keys matched successfully>

In [4]:
metrics = []
for i in range(10):
    solver.train(num_epoch=1)
    metrics.append(solver.evaluate("test"))


17:45:07   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
17:45:07   Epoch 0 begin
17:45:11   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
17:45:11   binary cross entropy: 0.697501
17:45:40   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
17:45:40   binary cross entropy: 0.139899
17:46:09   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
17:46:09   binary cross entropy: 0.220732
17:46:39   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
17:46:39   binary cross entropy: 0.0758589
17:46:49   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
17:46:49   Epoch 0 end
17:46:49   duration: 1.96 mins
17:46:49   speed: 2.86 batch / sec
17:46:49   ETA: 0.00 secs
17:46:49   max GPU memory: 1581.7 MiB
17:46:49   ------------------------------
17:46:49   average binary cross entropy: 0.154527
17:46:49   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
17:46:49   Evaluate on test
17:46:58   ------------------------------
17:46:58   mcc: 0.377613
17:46:58   micro_auroc: 0.900009
17:46:58   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
17:46:58   Epoch 1 begin
17:47:17   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
17:47:17   binary cross entrop