In [14]:
from lib.datasets import ATPBind3D

from torchdrug import transforms

truncuate_transform = transforms.TruncateProtein(max_length=350, random=False)
protein_view_transform = transforms.ProteinView(view='residue')
transform = transforms.Compose([truncuate_transform, protein_view_transform])

dataset = ATPBind3D(transform=transform)

train_set, valid_set, test_set = dataset.split()
print("train samples: %d, valid samples: %d, test samples: %d" %
      (len(train_set), len(valid_set), len(test_set)))


Split num:  [337, 41, 41]
train samples: 337, valid samples: 41, test samples: 41


In [16]:
from torchdrug.utils import download
import torch
from torchdrug import models

md5_dict = {
    "mc": "c35402108f14e43f20feb475918f9c26",
    "distance": "a5b4781506b51a146a1b26c564917110",
    "dihedral": "3f4cc8a1d0401c4aea68bbac0ce9d990",
    "attr": "77ea524ffe0c11ec93a403696b1c80a9",
    "angle": "1f2c1bb27f8fdb3577e0476210a9692e"
}

def pretrained_gearnet(name):
    # Check https://zenodo.org/record/7593637 for the pretrained weights
    if name not in ["mc", "distance", "dihedral", "attr", "angle"]:
        raise ValueError("Unknown pretrained weights %s" % name)
    weights_path = download(
        "https://zenodo.org/record/7593637/files/%s_gearnet_edge.pth?download=1" % name,
        "data/weight",
        md5=md5_dict[name]
    )
    # load the weights from the .pth file
    state_dict = torch.load(weights_path, map_location=torch.device("cpu"))
    
    gearnet = models.GearNet(input_dim=21, hidden_dims=[512, 512, 512, 512, 512, 512], num_relation=7,
                         edge_input_dim=59, num_angle_bin=8,
                         batch_norm=True, concat_hidden=True, short_cut=True, readout="sum")

    gearnet.load_state_dict(state_dict)
    
    return gearnet


# This could take ~5min for first run; second run should be fast
gearnet = pretrained_gearnet("angle")

In [20]:
def _freeze_gearnet_edge(
    gearnet_edge, freeze_all=True, freeze_layer_count=-1
):
    if freeze_all:
        for param in gearnet_edge.parameters():
            param.requires_grad = False
    else:
        print("Freezing %d layers, total %d layers" % (freeze_layer_count, len(gearnet_edge.layers)))
        for layer in gearnet_edge.layers[:freeze_layer_count]:
            for param in layer.parameters():
                param.requires_grad = False
        
        for layer in gearnet_edge.edge_layers[:freeze_layer_count]:
            for param in layer.parameters():
                param.requires_grad = False
        
        for layer in gearnet_edge.batch_norms[:freeze_layer_count]:
            for param in layer.parameters():
                param.requires_grad = False


_freeze_gearnet_edge(gearnet, freeze_all=False, freeze_layer_count=3)


Freezing 3 layers, total 6 layers


In [21]:
from torchdrug import core, layers, models
from torchdrug.layers import geometry
import torch
from lib.tasks import NodePropertyPrediction
from lib.disable_logger import DisableLogger

graph_construction_model = layers.GraphConstruction(node_layers=[geometry.AlphaCarbonNode()],
                                                    edge_layers=[geometry.SpatialEdge(radius=10.0, min_distance=5),
                                                                 geometry.KNNEdge(
                                                                     k=10, min_distance=5),
                                                                 geometry.SequentialEdge(max_distance=2)],
                                                    edge_feature="gearnet")


task = NodePropertyPrediction(
    gearnet,
    graph_construction_model=graph_construction_model,
    normalization=False,
    num_mlp_layer=2,
    metric=("micro_auroc", "mcc")
)

optimizer = torch.optim.Adam(task.parameters(), lr=5e-4)
with DisableLogger():
    solver = core.Engine(task, train_set, valid_set, test_set, optimizer,
                         log_interval=10000000,
                         gpus=[1], batch_size=1)

In [22]:
metrics = []
for i in range(10):
    solver.train(num_epoch=1)
    metrics.append(solver.evaluate("test"))

20:35:21   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
20:35:21   Epoch 0 begin
20:35:21   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
20:35:21   binary cross entropy: 0.694982
20:36:01   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
20:36:01   Epoch 0 end
20:36:01   duration: 41.57 secs
20:36:01   speed: 8.11 batch / sec
20:36:01   ETA: 0.00 secs
20:36:01   max GPU memory: 1032.1 MiB
20:36:01   ------------------------------
20:36:01   average binary cross entropy: 0.0698818
20:36:01   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
20:36:01   Evaluate on test
20:36:05   ------------------------------
20:36:05   mcc: 0.228477
20:36:05   micro_auroc: 0.75708
20:36:05   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
20:36:05   Epoch 1 begin
20:36:45   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
20:36:45   Epoch 1 end
20:36:45   duration: 44.48 secs
20:36:45   speed: 7.58 batch / sec
20:36:45   ETA: 0.00 secs
20:36:45   max GPU memory: 1032.1 MiB
20:36:45   ------------------------------
20:36:45   average binary cross entropy: 0.0169162
20:36:45   >>>>>>>>>>>>>>>>>>>>>>>>