In [1]:
from lib.datasets import ATPBind3D

from torchdrug import transforms

# Per-sample preprocessing: TruncateProtein (max_length=350, random=False)
# followed by ProteinView(view='residue'), chained with Compose.
truncuate_transform = transforms.TruncateProtein(random=False, max_length=350)
protein_view_transform = transforms.ProteinView(view='residue')
pipeline_steps = [truncuate_transform, protein_view_transform]
transform = transforms.Compose(pipeline_steps)

# The composed transform is handed to the dataset constructor.
dataset = ATPBind3D(transform=transform)

# Split into train / validation / test subsets and report how many samples
# each one holds.
train_set, valid_set, test_set = dataset.split()
split_sizes = tuple(len(subset) for subset in (train_set, valid_set, test_set))
print("train samples: %d, valid samples: %d, test samples: %d" % split_sizes)


Split num:  [337, 41, 41]
train samples: 337, valid samples: 41, test samples: 41


In [2]:
from torchdrug import core, layers, models
from torchdrug.layers import geometry
from torchdrug import transforms
import torch
from lib.tasks import NodePropertyPrediction
from lib.lr_scheduler import CosineAnnealingLR
from lib.disable_logger import DisableLogger
from lib.custom_models import LMGearNetModel

# GPU id handed to both LMGearNetModel and core.Engine(gpus=[...]).
GPU = 2

# Graph construction: one node layer and three edge layers, instantiated in the
# same order as they are listed, with the "gearnet" edge feature.
node_layers = [geometry.AlphaCarbonNode()]
edge_layers = [
    geometry.SpatialEdge(radius=10.0, min_distance=5),
    geometry.KNNEdge(k=10, min_distance=5),
    geometry.SequentialEdge(max_distance=2),
]
graph_construction_model = layers.GraphConstruction(
    node_layers=node_layers,
    edge_layers=edge_layers,
    edge_feature="gearnet",
)

# Project-local model wrapper; the GearNet part is configured with
# hidden-dim size 512 and hidden-dim count 6.
gearnet_config = dict(gearnet_hidden_dim_size=512, gearnet_hidden_dim_count=6)
lm_gearnet = LMGearNetModel(GPU, **gearnet_config)

# Node-level prediction task wrapping the model, evaluated with micro AUROC and MCC.
eval_metrics = ("micro_auroc", "mcc")
task = NodePropertyPrediction(
    lm_gearnet,
    graph_construction_model=graph_construction_model,
    normalization=False,
    num_mlp_layer=2,
    metric=eval_metrics,
)

# Adam over every parameter exposed by the task.
optimizer = torch.optim.Adam(task.parameters(), lr=1e-4)

# The engine is created inside DisableLogger (project helper; presumably it
# silences logging emitted during construction — confirm in lib.disable_logger).
with DisableLogger():
    solver = core.Engine(
        task,
        train_set,
        valid_set,
        test_set,
        optimizer,
        gpus=[GPU],
        batch_size=1,
    )

Some weights of the model checkpoint at Rostlab/prot_bert were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [3]:
# Read the checkpoint for the `gearnet` submodule from disk.
# map_location="cpu" deserializes every tensor onto the CPU no matter which
# device it was saved from, so loading does not require the original GPU index
# to exist on this machine and does not allocate memory on an unintended GPU.
# NOTE(review): torch.load unpickles the file — only load checkpoints that come
# from a trusted source.
state_dict = torch.load("ResidueType_lmg_6_512_0.51948.pth", map_location="cpu")

# load_state_dict copies the values into the module's existing parameters, so
# they end up on whatever device the model already lives on. Kept as the
# cell's last expression so the key-matching report is displayed.
lm_gearnet.gearnet.load_state_dict(state_dict)

<All keys matched successfully>

In [4]:
# Run ten single-epoch training passes; after each pass evaluate on the
# validation split and keep the result, so `metrics` holds one entry per epoch
# in order.
metrics = []
for epoch in range(10):
    solver.train(num_epoch=1)
    valid_result = solver.evaluate("valid")
    metrics.append(valid_result)


18:27:15   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
18:27:15   Epoch 0 begin
18:27:18   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
18:27:18   binary cross entropy: 0.744193
18:27:57   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
18:27:57   binary cross entropy: 0.175487
18:28:36   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
18:28:36   binary cross entropy: 0.235156
18:29:15   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
18:29:15   binary cross entropy: 0.0679581
18:29:29   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
18:29:29   Epoch 0 end
18:29:29   duration: 2.58 mins
18:29:29   speed: 2.18 batch / sec
18:29:29   ETA: 0.00 secs
18:29:29   max GPU memory: 3814.7 MiB
18:29:29   ------------------------------
18:29:29   average binary cross entropy: 0.163202
18:29:29   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
18:29:29   Evaluate on valid
18:29:41   ------------------------------
18:29:41   mcc: 0.352142
18:29:41   micro_auroc: 0.864055
18:29:41   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
18:29:41   Epoch 1 begin
18:30:06   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
18:30:06   binary cross entro