In [1]:
from lib.datasets import ATPBind3D

from torchdrug import transforms
from torchdrug import data, core, layers, tasks, metrics, utils, models
from torchdrug.layers import functional
from torchdrug.core import Registry as R

import torch
from torch.utils import data as torch_data
from torch.nn import functional as F
from lib.tasks import NodePropertyPrediction


In [2]:
truncuate_transform = transforms.TruncateProtein(max_length=350, random=False)
protein_view_transform = transforms.ProteinView(view='residue')
transform = transforms.Compose([truncuate_transform, protein_view_transform])

dataset = ATPBind3D(transform=transform)

train_set, valid_set, test_set = dataset.split()
print("train samples: %d, valid samples: %d, test samples: %d" %
      (len(train_set), len(valid_set), len(test_set)))


Split num:  [337, 41, 41]
train samples: 337, valid samples: 41, test samples: 41


In [3]:
from torchdrug import core, layers
from torchdrug.layers import geometry
import torch
from lib.tasks import NodePropertyPrediction

graph_construction_model = layers.GraphConstruction(node_layers=[geometry.AlphaCarbonNode()],
                                                    edge_layers=[geometry.SpatialEdge(radius=10.0, min_distance=5),
                                                                 geometry.KNNEdge(
                                                                     k=10, min_distance=5),
                                                                 geometry.SequentialEdge(max_distance=2)],
                                                    edge_feature="gearnet")
task = NodePropertyPrediction(
    models.GearNet(input_dim=21, hidden_dims=[512, 512, 512, 512, 512, 512], num_relation=7,
                   edge_input_dim=59, num_angle_bin=8,
                   batch_norm=True, concat_hidden=True, short_cut=True, readout="sum").cuda(),
    graph_construction_model=graph_construction_model,
    normalization=False,
    num_mlp_layer=2,
    metric=("micro_auroc")
)

optimizer = torch.optim.Adam(task.parameters(), lr=1e-4)
solver = core.Engine(task, train_set, valid_set, test_set, optimizer,
                     gpus=[0], batch_size=1)


13:25:03   Preprocess training set
13:25:04   {'batch_size': 1,
 'class': 'core.Engine',
 'gpus': [0],
 'gradient_interval': 1,
 'log_interval': 100,
 'logger': 'logging',
 'num_worker': 0,
 'optimizer': {'amsgrad': False,
               'betas': (0.9, 0.999),
               'class': 'optim.Adam',
               'eps': 1e-08,
               'lr': 0.0001,
               'weight_decay': 0},
 'scheduler': None,
 'task': {'class': 'NodePropertyPrediction',
          'criterion': 'bce',
          'graph_construction_model': {'class': 'layers.GraphConstruction',
                                       'edge_feature': 'gearnet',
                                       'edge_layers': [SpatialEdge(),
                                                       KNNEdge(),
                                                       SequentialEdge()],
                                       'node_layers': [AlphaCarbonNode()]},
          'metric': 'micro_auroc',
          'model': {'activation': 'relu',
        

In [4]:
solver.train(num_epoch=10)
solver.evaluate("valid")


13:25:04   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
13:25:04   Epoch 0 begin
13:25:05   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
13:25:05   binary cross entropy: 0.715257
13:25:26   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
13:25:26   binary cross entropy: 0.179901
13:25:47   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
13:25:47   binary cross entropy: 0.276355
13:26:10   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
13:26:10   binary cross entropy: 0.15135
13:26:18   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
13:26:18   Epoch 0 end
13:26:18   duration: 1.23 mins
13:26:18   speed: 4.57 batch / sec
13:26:18   ETA: 11.06 mins
13:26:18   max GPU memory: 1581.6 MiB
13:26:18   ------------------------------
13:26:18   average binary cross entropy: 0.183451
13:26:18   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
13:26:18   Epoch 1 begin
13:26:31   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
13:26:31   binary cross entropy: 0.0931136
13:26:53   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
13:26:53   binary cross entropy: 0.145226
13:27:15   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
13:27:15   binary cross entropy:

{'micro_auroc': tensor(0.8721, device='cuda:0')}

In [5]:
solver.evaluate("test")

13:41:11   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
13:41:11   Evaluate on test
13:41:16   ------------------------------
13:41:16   micro_auroc: 0.846686


{'micro_auroc': tensor(0.8467, device='cuda:0')}