In [1]:
from torchdrug import transforms
from torchdrug import data, core, layers, tasks, metrics, utils, models
from torchdrug.layers import functional
from torchdrug.core import Registry as R

import torch
from torch.utils import data as torch_data
from torch.nn import functional as F
from lib.tasks import NodePropertyPrediction

In [2]:
from lib.datasets import ATPBind

# Preprocessing pipeline: truncate each protein to at most 350 residues
# (random=False), then apply the residue-level protein view.
truncuate_transform = transforms.TruncateProtein(
    max_length=350,
    random=False,
)
protein_view_transform = transforms.ProteinView(view="residue")
transform = transforms.Compose(
    [truncuate_transform, protein_view_transform]
)

# Only the default residue features are requested; atom and bond features
# are switched off.
dataset = ATPBind(
    transform=transform,
    residue_feature="default",
    atom_feature=None,
    bond_feature=None,
)

train_set, valid_set, test_set = dataset.split()
print(
    "train samples: %d, valid samples: %d, test samples: %d"
    % tuple(map(len, (train_set, valid_set, test_set)))
)


Split num:  [346, 42, 41]
train samples: 346, valid samples: 42, test samples: 41


In [3]:
# Adapted from `freeze_bert` in LM-GVP: https://github.com/aws-samples/lm-gvp/blob/0b7a6d96486e2ee222929917570432296554cfe7/lmgvp/modules.py#L47

from transformers import BertModel, BertTokenizer

def _freeze_bert(
    bert_model: "BertModel", freeze_bert=True, freeze_layer_count=-1
):
    """Freeze parameters of a HuggingFace BERT model (in place).

    Args:
        bert_model: HuggingFace bert model; its parameters are modified.
        freeze_bert: If True, every parameter of the model is frozen and
            ``freeze_layer_count`` is ignored. If False, only the embeddings
            are frozen, plus the encoder layers selected by
            ``freeze_layer_count``.
        freeze_layer_count: Only used when ``freeze_bert`` is False. Applied
            as the slice ``bert_model.encoder.layer[:freeze_layer_count]``,
            i.e. the first ``freeze_layer_count`` encoder layers are frozen.
            The sentinel ``-1`` (default) freezes no encoder layer.

    Returns:
        The same ``bert_model`` instance that was passed in, so the call can
        be chained. Callers that ignore the return value are unaffected.
    """
    if freeze_bert:
        # Freeze the entire bert model.
        frozen_modules = [bert_model]
    else:
        # Always freeze the embeddings ...
        frozen_modules = [bert_model.embeddings]
        if freeze_layer_count != -1:
            # ... and, optionally, the leading layers of bert_model.encoder.
            frozen_modules.extend(bert_model.encoder.layer[:freeze_layer_count])

    for module in frozen_modules:
        for param in module.parameters():
            param.requires_grad = False
    return bert_model


In [4]:
# Custom model wrapping BERT; see https://torchdrug.ai/docs/notes/model.html
class BertWrapModel(torch.nn.Module, core.Configurable):
    """TorchDrug-style model that produces per-residue features with ProtBert.

    The protein graph is converted to amino-acid sequences, tokenized with the
    pretrained tokenizer and encoded with the pretrained BERT model. The hidden
    states of the residue tokens (everything except [CLS], [SEP] and padding)
    are returned as ``residue_feature``.

    Args:
        pretrained_name: HuggingFace checkpoint used for both the tokenizer
            and the encoder.
        freeze_layer_count: Forwarded to ``_freeze_bert`` with
            ``freeze_bert=False``: the embeddings and the first
            ``freeze_layer_count`` encoder layers are frozen.
    """

    def __init__(self, pretrained_name="Rostlab/prot_bert", freeze_layer_count=29):
        super().__init__()
        # Prefer the GPU (as before) but do not crash on CPU-only machines.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.bert_tokenizer = BertTokenizer.from_pretrained(
            pretrained_name, do_lower_case=False)
        self.bert_model = BertModel.from_pretrained(pretrained_name).to(device)
        _freeze_bert(self.bert_model, freeze_bert=False,
                     freeze_layer_count=freeze_layer_count)
        # NOTE(review): 21 is presumably the width of the dataset's default
        # residue feature; the positional input is unused in forward() -- confirm.
        self.input_dim = 21
        self.output_dim = self.bert_model.config.hidden_size

    def forward(self, graph, _, all_loss=None, metric=None):
        """Encode the proteins in ``graph`` and return per-residue features.

        Args:
            graph: protein graph (or batch) exposing ``to_sequence()``.
            _: positional input feature required by the caller; ignored here.
            all_loss: unused.
            metric: unused.

        Returns:
            dict with key ``"residue_feature"``: the last hidden states of the
            residue tokens, concatenated over the sequences of the batch in
            order, with shape (total number of residue tokens, hidden size).
        """
        # Follow the encoder's actual device instead of hard-coding 'cuda':
        # the training engine may have moved the module to another GPU.
        device = next(self.bert_model.parameters()).device

        # The tokenizer expects residues separated by spaces. Dropping every
        # '.' and re-joining gives the same text as replacing '.' with ' '
        # when a '.' sits between every residue, and stays correct otherwise.
        # NOTE(review): assumes '.' is the only non-residue character emitted
        # by to_sequence() -- confirm against the torchdrug version in use.
        sequences = [" ".join(seq.replace(".", ""))
                     for seq in graph.to_sequence()]

        # padding=True is a no-op for a single sequence and makes batches of
        # differently sized proteins tensorizable.
        encoded_input = self.bert_tokenizer(
            sequences, return_tensors="pt", padding=True).to(device)
        hidden = self.bert_model(**encoded_input).last_hidden_state

        # Non-padding tokens per sequence, minus [CLS] and [SEP].
        residue_counts = (encoded_input["attention_mask"].sum(dim=1) - 2).tolist()

        # Skip [CLS] (index 0), [SEP] and padding, since they are not part of
        # the original sequence; for a batch of one this equals
        # squeeze(last_hidden_state)[1:-1].
        residue_feature = torch.cat(
            [hidden[i, 1:1 + count] for i, count in enumerate(residue_counts)])
        return {"residue_feature": residue_feature}


In [5]:
# Wrap the BERT encoder in the node-level prediction task and hand it,
# together with the three data splits, to the TorchDrug engine.
bert_wrap_model = BertWrapModel()
bert_task = NodePropertyPrediction(
    bert_wrap_model,
    metric=("micro_auroc", "mcc"),
    num_mlp_layer=2,
    normalization=False,
)
optimizer = torch.optim.Adam(bert_task.parameters(), lr=1e-3)
solver = core.Engine(
    bert_task,
    train_set,
    valid_set,
    test_set,
    optimizer,
    gpus=[2],
    batch_size=1,
    log_interval=1000,
)

Some weights of the model checkpoint at Rostlab/prot_bert were not used when initializing BertModel: ['cls.predictions.decoder.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


10:00:17   Preprocess training set
10:00:22   {'batch_size': 1,
 'class': 'core.Engine',
 'gpus': [2],
 'gradient_interval': 1,
 'log_interval': 1000,
 'logger': 'logging',
 'num_worker': 0,
 'optimizer': {'amsgrad': False,
               'betas': (0.9, 0.999),
               'class': 'optim.Adam',
               'eps': 1e-08,
               'lr': 0.001,
               'weight_decay': 0},
 'scheduler': None,
 'task': {'class': 'NodePropertyPrediction',
          'criterion': 'bce',
          'graph_construction_model': None,
          'metric': ('micro_auroc', 'mcc'),
          'model': {'class': 'BertWrapModel'},
          'normalization': False,
          'num_class': None,
          'num_mlp_layer': 2,
          'verbose': 0},
 'test_set': {'class': 'dataset.Subset',
              'dataset': {'atom_feature': None,
                          'bond_feature': None,
                          'class': 'ATPBind',
                          'path': None,
                          'residue_fe

In [9]:
# Train for a single epoch, then evaluate on the validation and test splits.
solver.train(num_epoch=1)
solver.evaluate("valid")
# Kept as the last bare expression so the notebook displays the returned
# metric dict as the cell output.
solver.evaluate("test")


10:04:06   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
10:04:06   Epoch 1 begin
10:04:28   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
10:04:28   Epoch 1 end
10:04:28   duration: 3.56 mins
10:04:28   speed: 1.62 batch / sec
10:04:28   ETA: 0.00 secs
10:04:28   max GPU memory: 1911.9 MiB
10:04:28   ------------------------------
10:04:28   average binary cross entropy: 0.122717
10:04:28   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
10:04:28   Evaluate on valid
10:04:31   ------------------------------
10:04:31   matthews correlation coefficient: 0.46477
10:04:31   micro_auroc: 0.909262
10:04:31   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
10:04:31   Evaluate on test
10:04:33   ------------------------------
10:04:33   matthews correlation coefficient: 0.431995
10:04:33   micro_auroc: 0.909152


{'micro_auroc': tensor(0.9092, device='cuda:2'),
 'matthews correlation coefficient': 0.4319947093100603}