In [1]:
from lib.datasets import ATPBind3D

from torchdrug import transforms

truncuate_transform = transforms.TruncateProtein(max_length=350, random=False)
protein_view_transform = transforms.ProteinView(view='residue')
transform = transforms.Compose([truncuate_transform, protein_view_transform])

dataset = ATPBind3D(transform=transform)

train_set, valid_set, test_set = dataset.split()
print("train samples: %d, valid samples: %d, test samples: %d" %
      (len(train_set), len(valid_set), len(test_set)))


Split num:  [337, 41, 41]
train samples: 337, valid samples: 41, test samples: 41


In [2]:
from torchdrug import core, layers, models
from torchdrug.layers import geometry
from torchdrug import transforms
import torch
from lib.tasks import NodePropertyPrediction
from lib.lr_scheduler import CosineAnnealingLR
from lib.disable_logger import DisableLogger

GPU = 2

graph_construction_model = layers.GraphConstruction(node_layers=[geometry.AlphaCarbonNode()],
                                                    edge_layers=[geometry.SpatialEdge(radius=10.0, min_distance=5),
                                                                 geometry.KNNEdge(
                                                                     k=10, min_distance=5),
                                                                 geometry.SequentialEdge(max_distance=2)],
                   
                                                    edge_feature="gearnet")

gearnet = models.GearNet(input_dim=21, hidden_dims=[512] * 9, num_relation=7,
                   edge_input_dim=59, num_angle_bin=8,
                   batch_norm=True, concat_hidden=True, short_cut=True, readout="sum").cuda(GPU)

task = NodePropertyPrediction(
    gearnet,
    graph_construction_model=graph_construction_model,
    normalization=False,
    num_mlp_layer=2,
    metric=("micro_auroc", "mcc")
)

optimizer = torch.optim.Adam(task.parameters(), lr=1e-4)
scheduler = CosineAnnealingLR(optimizer=optimizer, T_max=20, eta_min=1e-5)
with DisableLogger():
    solver = core.Engine(task, train_set, valid_set, test_set, optimizer,
                        gpus=[GPU], batch_size=1)


In [3]:
state_dict = torch.load("ResidueType_9_512_0.38827.pth")

new_state_dict = {}
for k in state_dict['model'].keys():
    if k.startswith("model"):
        new_state_dict[k.replace("model.", "")] = state_dict['model'][k]

gearnet.load_state_dict(new_state_dict)


<All keys matched successfully>

In [4]:
metrics = []
for i in range(10):
    solver.train(num_epoch=1)
    metrics.append(solver.evaluate("test"))


16:25:18   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
16:25:18   Epoch 0 begin
16:25:21   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
16:25:21   binary cross entropy: 0.771378
16:25:57   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
16:25:57   binary cross entropy: 0.125302
16:26:32   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
16:26:32   binary cross entropy: 0.211382
16:27:10   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
16:27:10   binary cross entropy: 0.069091
16:27:23   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
16:27:23   Epoch 0 end
16:27:23   duration: 3.30 mins
16:27:23   speed: 1.70 batch / sec
16:27:23   ETA: 0.00 secs
16:27:23   max GPU memory: 2628.2 MiB
16:27:23   ------------------------------
16:27:23   average binary cross entropy: 0.15316
16:27:23   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
16:27:23   Evaluate on test
16:27:33   ------------------------------
16:27:33   mcc: 0.438988
16:27:33   micro_auroc: 0.898515
16:27:33   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
16:27:33   Epoch 1 begin
16:27:56   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
16:27:56   binary cross entropy:

In [None]:
def _freeze_gearnet_edge(
    gearnet_edge, freeze_all=True, freeze_layer_count=-1
):
    if freeze_all:
        for param in gearnet_edge.parameters():
            param.requires_grad = False
    else:
        print("Freezing %d layers, total %d layers" % (freeze_layer_count, len(gearnet_edge.layers)))
        for layer in gearnet_edge.layers[:freeze_layer_count]:
            for param in layer.parameters():
                param.requires_grad = False
        
        for layer in gearnet_edge.edge_layers[:freeze_layer_count]:
            for param in layer.parameters():
                param.requires_grad = False
        
        for layer in gearnet_edge.batch_norms[:freeze_layer_count]:
            for param in layer.parameters():
                param.requires_grad = False


_freeze_gearnet_edge(gearnet, freeze_all=False, freeze_layer_count=3)