In [1]:
from lib.datasets import ATPBind3D

from torchdrug import transforms

truncuate_transform = transforms.TruncateProtein(max_length=350, random=False)
protein_view_transform = transforms.ProteinView(view='residue')
transform = transforms.Compose([truncuate_transform, protein_view_transform])

dataset = ATPBind3D(transform=transform)

train_set, valid_set, test_set = dataset.split()
print("train samples: %d, valid samples: %d, test samples: %d" %
      (len(train_set), len(valid_set), len(test_set)))


Split num:  [337, 41, 41]
train samples: 337, valid samples: 41, test samples: 41


In [2]:
from torchdrug.utils import download
import torch
from torchdrug import models
possible_weights = [
    "mc_gearnet_edge.pth",
    "distance_gearnet_edge.pth",
    "dihedral_gearnet_edge.pth",
    "attr_gearnet_edge.pth",
    "angle_gearnet_edge.pth"
]
md5s = [
    "c35402108f14e43f20feb475918f9c26",
    "a5b4781506b51a146a1b26c564917110",
    "3f4cc8a1d0401c4aea68bbac0ce9d990",
    "77ea524ffe0c11ec93a403696b1c80a9",
    "1f2c1bb27f8fdb3577e0476210a9692e"
]

def pretrained_gearnet(index):
    weights_path = download(
        "https://zenodo.org/record/7593637/files/%s?download=1" % possible_weights[index],
        "data/weight",
        md5=md5s[index]
    )
    # load the weights from the .pth file
    state_dict = torch.load(weights_path, map_location=torch.device("cpu"))
    
    gearnet = models.GearNet(input_dim=21, hidden_dims=[512, 512, 512, 512, 512, 512], num_relation=7,
                         edge_input_dim=59, num_angle_bin=8,
                         batch_norm=True, concat_hidden=True, short_cut=True, readout="sum")

    gearnet.load_state_dict(state_dict)
    
    return gearnet


In [3]:
def _freeze_gearnet_edge(
    gearnet_edge, freeze_all=True, freeze_layer_count=-1
):
    if freeze_all:
        for param in gearnet_edge.parameters():
            param.requires_grad = False
    else:
        print("Freezing %d layers, total %d layers" % (freeze_layer_count, len(gearnet_edge.layers)))
        for layer in gearnet_edge.layers[:freeze_layer_count]:
            for param in layer.parameters():
                param.requires_grad = False
        
        for layer in gearnet_edge.edge_layers[:freeze_layer_count]:
            for param in layer.parameters():
                param.requires_grad = False
        
        for layer in gearnet_edge.batch_norms[:freeze_layer_count]:
            for param in layer.parameters():
                param.requires_grad = False
    return None

In [4]:
from torchdrug import core, layers, models
from torchdrug.layers import geometry
import torch
from lib.tasks import NodePropertyPrediction

graph_construction_model = layers.GraphConstruction(node_layers=[geometry.AlphaCarbonNode()],
                                                    edge_layers=[geometry.SpatialEdge(radius=10.0, min_distance=5),
                                                                 geometry.KNNEdge(
                                                                     k=10, min_distance=5),
                                                                 geometry.SequentialEdge(max_distance=2)],
                                                    edge_feature="gearnet")

gearnet = pretrained_gearnet(1)
_freeze_gearnet_edge(gearnet, freeze_all=False, freeze_layer_count=5)


Freezing 5 layers, total 6 layers


In [5]:
task = NodePropertyPrediction(
    gearnet,
    graph_construction_model=graph_construction_model,
    normalization=False,
    num_mlp_layer=2,
    metric=("micro_auroc")
)

optimizer = torch.optim.Adam(task.parameters(), lr=1e-4)
solver = core.Engine(task, train_set, valid_set, test_set, optimizer,
                     gpus=[1], batch_size=1)

16:32:04   Preprocess training set
16:32:10   {'batch_size': 1,
 'class': 'core.Engine',
 'gpus': [1],
 'gradient_interval': 1,
 'log_interval': 100,
 'logger': 'logging',
 'num_worker': 0,
 'optimizer': {'amsgrad': False,
               'betas': (0.9, 0.999),
               'class': 'optim.Adam',
               'eps': 1e-08,
               'lr': 0.0001,
               'weight_decay': 0},
 'scheduler': None,
 'task': {'class': 'NodePropertyPrediction',
          'criterion': 'bce',
          'graph_construction_model': {'class': 'layers.GraphConstruction',
                                       'edge_feature': 'gearnet',
                                       'edge_layers': [SpatialEdge(),
                                                       KNNEdge(),
                                                       SequentialEdge()],
                                       'node_layers': [AlphaCarbonNode()]},
          'metric': 'micro_auroc',
          'model': {'activation': 'relu',
        

In [6]:
solver.train(num_epoch=2)
solver.evaluate("valid")
solver.evaluate("test")

16:32:10   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
16:32:10   Epoch 0 begin
16:32:11   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
16:32:11   binary cross entropy: 0.708787
16:32:24   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
16:32:24   binary cross entropy: 0.181673
16:32:37   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
16:32:37   binary cross entropy: 0.273708
16:32:51   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
16:32:51   binary cross entropy: 0.135656
16:32:56   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
16:32:56   Epoch 0 end
16:32:56   duration: 46.39 secs
16:32:56   speed: 7.26 batch / sec
16:32:56   ETA: 46.39 secs
16:32:56   max GPU memory: 628.8 MiB
16:32:56   ------------------------------
16:32:56   average binary cross entropy: 0.181512
16:32:56   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
16:32:56   Epoch 1 begin
16:33:05   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
16:33:05   binary cross entropy: 0.114254
16:33:18   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
16:33:18   binary cross entropy: 0.144979
16:33:31   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
16:33:31   binary cross entropy:

{'micro_auroc': tensor(0.8013, device='cuda:1')}