In [22]:
from lib.datasets import ATPBind3D

from torchdrug import transforms

# Preprocessing pipeline: a deterministic (random=False) truncation of each
# protein to at most 350 residues, followed by switching to the residue-level
# view. NOTE: `truncuate_transform` keeps its original (misspelled) name in
# case later cells refer to it.
truncuate_transform = transforms.TruncateProtein(max_length=350, random=False)
protein_view_transform = transforms.ProteinView(view='residue')
transform = transforms.Compose([truncuate_transform, protein_view_transform])

dataset = ATPBind3D(transform=transform)

# Partition as defined by ATPBind3D.split(); report the size of each subset.
train_set, valid_set, test_set = dataset.split()
print("train samples: %d, valid samples: %d, test samples: %d"
      % tuple(len(subset) for subset in (train_set, valid_set, test_set)))


Split num:  [337, 41, 41]
train samples: 337, valid samples: 41, test samples: 41


In [26]:
from torchdrug.utils import download
import torch
from torchdrug import models

# Expected MD5 checksums of the pretrained GearNet-Edge checkpoints hosted at
# https://zenodo.org/record/7593637, keyed by checkpoint name (the prefix of
# "<name>_gearnet_edge.pth"). Iteration order matters: later cells load and
# train the models in this order.
md5_dict = dict(
    mc="c35402108f14e43f20feb475918f9c26",
    distance="a5b4781506b51a146a1b26c564917110",
    dihedral="3f4cc8a1d0401c4aea68bbac0ce9d990",
    attr="77ea524ffe0c11ec93a403696b1c80a9",
    angle="1f2c1bb27f8fdb3577e0476210a9692e",
)

def pretrained_gearnet(name, weights_dir="data/weight"):
    """Download (if needed) and load a pretrained GearNet-Edge checkpoint.

    Args:
        name: checkpoint name; must be a key of ``md5_dict``
            ("mc", "distance", "dihedral", "attr", "angle").
        weights_dir: directory the checkpoint file is downloaded into.
            Defaults to "data/weight" (the previously hard-coded location).

    Returns:
        A ``models.GearNet`` whose state dict has been loaded from the
        checkpoint (tensors mapped to CPU).

    Raises:
        ValueError: if ``name`` is not a known checkpoint name.
    """
    import os

    # Check https://zenodo.org/record/7593637 for the pretrained weights.
    # Validate against md5_dict itself so the accepted names and the checksum
    # table cannot drift apart (a hard-coded list could miss new entries).
    if name not in md5_dict:
        raise ValueError("Unknown pretrained weights %s" % name)

    # Make sure the target directory exists on a fresh checkout.
    os.makedirs(weights_dir, exist_ok=True)
    # NOTE(review): the expected md5 is passed to download(); it presumably
    # lets a matching cached file skip the download — confirm whether it also
    # verifies a freshly downloaded file.
    weights_path = download(
        "https://zenodo.org/record/7593637/files/%s_gearnet_edge.pth?download=1" % name,
        weights_dir,
        md5=md5_dict[name],
    )
    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code; only load checkpoints from trusted sources. On torch >= 1.13,
    # consider weights_only=True if the checkpoint is a plain state dict.
    state_dict = torch.load(weights_path, map_location=torch.device("cpu"))

    # The architecture must match the one the checkpoint was saved from,
    # otherwise load_state_dict (strict by default) fails on missing or
    # unexpected keys / shape mismatches.
    gearnet = models.GearNet(input_dim=21, hidden_dims=[512, 512, 512, 512, 512, 512], num_relation=7,
                             edge_input_dim=59, num_angle_bin=8,
                             batch_norm=False, concat_hidden=False, short_cut=False, readout="sum")
    gearnet.load_state_dict(state_dict)

    return gearnet


# Load every checkpoint listed in md5_dict, keyed by checkpoint name.
# The first run downloads the weights and can take ~5 min; second run should
# be fast.
gearnets = {name: pretrained_gearnet(name) for name in md5_dict.keys()}

In [29]:
def _freeze_gearnet_edge(
    gearnet_edge, freeze_all=True, freeze_layer_count=-1
):
    if freeze_all:
        for param in gearnet_edge.parameters():
            param.requires_grad = False
    else:
        print("Freezing %d layers, total %d layers" % (freeze_layer_count, len(gearnet_edge.layers)))
        for layer in gearnet_edge.layers[:freeze_layer_count]:
            for param in layer.parameters():
                param.requires_grad = False
        
        for layer in gearnet_edge.edge_layers[:freeze_layer_count]:
            for param in layer.parameters():
                param.requires_grad = False
        
        for layer in gearnet_edge.batch_norms[:freeze_layer_count]:
            for param in layer.parameters():
                param.requires_grad = False


# Disable gradients for every parameter of the "distance" GearNet.
# (freeze_layer_count is ignored when freeze_all=True, so it is not passed.)
_freeze_gearnet_edge(gearnets["distance"], freeze_all=True)
# Alternative experiment: freeze only the first layer of the "attr" model.
# _freeze_gearnet_edge(gearnets["attr"], freeze_all=False, freeze_layer_count=1)


In [None]:
from torchdrug import core, layers, models
from torchdrug.layers import geometry
import torch
from lib.tasks import NodePropertyPrediction
from lib.disable_logger import DisableLogger

# Residue graph fed to GearNet-Edge: alpha-carbon nodes; spatial edges
# (radius=10.0) and k-nearest-neighbour edges (k=10), both with
# min_distance=5; sequential edges up to max_distance=2; "gearnet" edge
# features. NOTE(review): presumably this must match the configuration the
# pretrained checkpoints were trained with — confirm before changing.
graph_construction_model = layers.GraphConstruction(
    node_layers=[geometry.AlphaCarbonNode()],
    edge_layers=[
        geometry.SpatialEdge(radius=10.0, min_distance=5),
        geometry.KNNEdge(k=10, min_distance=5),
        geometry.SequentialEdge(max_distance=2),
    ],
    edge_feature="gearnet",
)


# One node-property-prediction task per pretrained encoder, all sharing the
# same graph construction and head configuration.
tasks = {
    name: NodePropertyPrediction(
        gn,
        graph_construction_model=graph_construction_model,
        normalization=False,
        num_mlp_layer=2,
        metric=("micro_auroc", "mcc"),
    )
    for name, gn in gearnets.items()
}

# Fine-tune each pretrained variant and evaluate after every epoch.
# Validation metrics are recorded alongside test metrics so that epoch
# selection can be done on the validation split instead of the test split.
# NOTE(review): the tasks presumably wrap the modules in `gearnets`, which
# training updates in place; re-running this cell without reloading the
# weights continues from already fine-tuned parameters.
SKIP_TASKS = {"mc", "distance"}  # NOTE(review): presumably run earlier — confirm
GPUS = [3]  # hard-coded device index; adjust to the local machine
NUM_EPOCHS = 10
SEED = 0  # same seed per variant so head init / shuffling are comparable

results = {}
for name, task in tasks.items():
    if name in SKIP_TASKS:
        continue
    # Seed before the Engine is built so the prediction head initialisation
    # and data order are reproducible across runs and variants.
    torch.manual_seed(SEED)
    optimizer = torch.optim.Adam(task.parameters(), lr=5e-4)
    print(name)
    # A very large log_interval effectively silences per-batch logging.
    solver = core.Engine(task, train_set, valid_set, test_set, optimizer,
                         log_interval=10000000,
                         gpus=GPUS, batch_size=1)
    valid_metrics = []
    metrics = []  # per-epoch test metrics (original variable name kept)
    for _epoch in range(NUM_EPOCHS):
        solver.train(num_epoch=1)
        valid_metrics.append(solver.evaluate("valid"))
        metrics.append(solver.evaluate("test"))
    # Keep every variant's history instead of only the last one's.
    results[name] = {"valid": valid_metrics, "test": metrics}
    print(metrics)

dihedral
16:16:59   Preprocess training set
16:17:02   {'batch_size': 1,
 'class': 'core.Engine',
 'gpus': [3],
 'gradient_interval': 1,
 'log_interval': 10000000,
 'logger': 'logging',
 'num_worker': 0,
 'optimizer': {'amsgrad': False,
               'betas': (0.9, 0.999),
               'class': 'optim.Adam',
               'eps': 1e-08,
               'lr': 0.0005,
               'weight_decay': 0},
 'scheduler': None,
 'task': {'class': 'NodePropertyPrediction',
          'criterion': 'bce',
          'graph_construction_model': {'class': 'layers.GraphConstruction',
                                       'edge_feature': 'gearnet',
                                       'edge_layers': [SpatialEdge(),
                                                       KNNEdge(),
                                                       SequentialEdge()],
                                       'node_layers': [AlphaCarbonNode()]},
          'metric': ('micro_auroc', 'mcc'),
          'model': {'activa