In [1]:
from torchdrug import transforms
from torchdrug import data, core, layers, tasks, metrics, utils, models
from torchdrug.layers import functional
from torchdrug.core import Registry as R

import torch
from torch.utils import data as torch_data
from torch.nn import functional as F
from lib.tasks import NodePropertyPrediction

In [2]:
from lib.datasets import ATPBind

# Preprocessing pipeline: cap each protein at 350 residues, then switch the
# protein to its residue-level view.
protein_view_transform = transforms.ProteinView(view='residue')
truncuate_transform = transforms.TruncateProtein(random=False, max_length=350)
transform = transforms.Compose([truncuate_transform, protein_view_transform])

# Only residue features are requested; atom and bond features are disabled.
dataset = ATPBind(
    transform=transform,
    residue_feature="default",
    atom_feature=None,
    bond_feature=None,
)

# Report the size of each split returned by the dataset.
train_set, valid_set, test_set = dataset.split()
print(
    "train samples: %d, valid samples: %d, test samples: %d"
    % tuple(map(len, (train_set, valid_set, test_set)))
)


Split num:  [346, 42, 41]
train samples: 346, valid samples: 42, test samples: 41


In [3]:
# freeze_bert in https://github.com/aws-samples/lm-gvp/blob/0b7a6d96486e2ee222929917570432296554cfe7/lmgvp/modules.py#L47

from transformers import BertModel, BertTokenizer

def _freeze_bert(
    bert_model: BertModel, freeze_bert=True, freeze_layer_count=-1
):
    """Freeze parameters in BertModel (in place)

    Args:
        bert_model: HuggingFace bert model
        freeze_bert: Bool whether or not to freeze the bert model
        freeze_layer_count: If freeze_bert, up to what layer to freeze.

    Returns:
        bert_model
    """
    if freeze_bert:
        # freeze the entire bert model
        for param in bert_model.parameters():
            param.requires_grad = False
    else:
        # freeze the embeddings
        for param in bert_model.embeddings.parameters():
            param.requires_grad = False
        if freeze_layer_count != -1:
            # freeze layers in bert_model.encoder
            for layer in bert_model.encoder.layer[:freeze_layer_count]:
                for param in layer.parameters():
                    param.requires_grad = False
    return None


In [4]:
# Custom model wrapping BERT; see https://torchdrug.ai/docs/notes/model.html
class BertWrapModel(torch.nn.Module, core.Configurable):
    """Wrap the pretrained ``Rostlab/prot_bert`` model as a residue-level encoder.

    The embeddings and the first 29 encoder layers are frozen through
    ``_freeze_bert``; the remaining parameters stay trainable.

    Attributes:
        bert_tokenizer: Tokenizer loaded from ``Rostlab/prot_bert``.
        bert_model: The wrapped BERT model.
        input_dim: Fixed to 21 (presumably the size of the default residue
            feature -- TODO confirm against the dataset).
        output_dim: Hidden size of the BERT model, i.e. the width of the
            returned residue features.
    """

    def __init__(self):
        super().__init__()
        self.bert_tokenizer = BertTokenizer.from_pretrained(
            "Rostlab/prot_bert", do_lower_case=False)
        # Prefer the GPU as before, but do not crash on CPU-only machines.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.bert_model = BertModel.from_pretrained(
            "Rostlab/prot_bert").to(device)
        _freeze_bert(self.bert_model, freeze_bert=False, freeze_layer_count=29)
        self.input_dim = 21
        self.output_dim = self.bert_model.config.hidden_size

    def forward(self, graph, _, all_loss=None, metric=None):
        """Compute one BERT embedding per residue for a batch of proteins.

        Args:
            graph: Batch object whose ``to_sequence()`` yields one sequence
                string per protein.
            _: Ignored; the sequences are read from ``graph`` instead.
            all_loss: Unused.
            metric: Unused.

        Returns:
            dict with key ``"residue_feature"``: a tensor of shape
            ``(total number of residue tokens, output_dim)`` holding the
            per-residue hidden states of all proteins, concatenated in
            batch order.
        """
        # The tokenizer expects space-separated residues, so every '.'
        # emitted by to_sequence() is turned into a space.
        sequences = [seq.replace('.', ' ') for seq in graph.to_sequence()]

        # Pad so that sequences of different lengths can share one tensor,
        # and follow the model's actual device instead of a hard-coded one
        # (the model may have been moved, e.g. to cuda:2, after __init__).
        encoded_input = self.bert_tokenizer(
            sequences, return_tensors='pt', padding=True
        ).to(self.bert_model.device)
        hidden = self.bert_model(**encoded_input).last_hidden_state

        # Number of real (non-padding) tokens per sequence, [CLS] and [SEP]
        # included. Padding is appended on the right (tokenizer default).
        token_counts = encoded_input["attention_mask"].sum(dim=1).tolist()

        # Skip the features of [CLS] (first) and [SEP] (last real token) as
        # well as padding, since they are not in the original sequence.
        residue_feature = torch.cat(
            [hidden[i, 1:count - 1] for i, count in enumerate(token_counts)],
            dim=0,
        )
        return {"residue_feature": residue_feature}


In [5]:
# Encoder: the BERT wrapper defined in this notebook.
bert_wrap_model = BertWrapModel()

# Node-level prediction task on top of the encoder, evaluated with
# micro/macro AUROC and AUPRC.
bert_task = NodePropertyPrediction(
    bert_wrap_model,
    metric=("micro_auroc", "micro_auprc", "macro_auprc", "macro_auroc"),
    num_mlp_layer=2,
    normalization=False,
)

# Adam over all task parameters.
optimizer = torch.optim.Adam(bert_task.parameters(), lr=1e-3)

# Engine that trains and evaluates on the three splits; runs on GPU index 2
# with one protein per batch.
solver = core.Engine(
    bert_task,
    train_set,
    valid_set,
    test_set,
    optimizer,
    gpus=[2],
    batch_size=1,
    log_interval=1000,
)

Some weights of the model checkpoint at Rostlab/prot_bert were not used when initializing BertModel: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.predictions.decoder.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


11:32:26   Preprocess training set
11:32:32   {'batch_size': 1,
 'class': 'core.Engine',
 'gpus': [2],
 'gradient_interval': 1,
 'log_interval': 1000,
 'logger': 'logging',
 'num_worker': 0,
 'optimizer': {'amsgrad': False,
               'betas': (0.9, 0.999),
               'class': 'optim.Adam',
               'eps': 1e-08,
               'lr': 0.001,
               'weight_decay': 0},
 'scheduler': None,
 'task': {'class': 'NodePropertyPrediction',
          'criterion': 'bce',
          'metric': ('micro_auroc',
                     'micro_auprc',
                     'macro_auprc',
                     'macro_auroc'),
          'model': {'class': 'BertWrapModel'},
          'normalization': False,
          'num_class': None,
          'num_mlp_layer': 2,
          'verbose': 0},
 'test_set': {'class': 'dataset.Subset',
              'dataset': {'atom_feature': None,
                          'bond_feature': None,
                          'class': 'ATPBind',
                    

In [6]:
# Train for a single epoch, then report metrics on the validation split.
solver.train(num_epoch=1)
solver.evaluate("valid")
# Last expression of the cell: its return value (the test-set metrics) is
# what the notebook displays as the cell output.
solver.evaluate("test")


11:32:32   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
11:32:32   Epoch 0 begin
11:32:32   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
11:32:32   binary cross entropy: 0.680496
11:32:54   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
11:32:54   Epoch 0 end
11:32:54   duration: 22.20 secs
11:32:54   speed: 15.59 batch / sec
11:32:54   ETA: 0.00 secs
11:32:54   max GPU memory: 1911.9 MiB
11:32:54   ------------------------------
11:32:54   average binary cross entropy: 0.160703
11:32:54   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
11:32:54   Evaluate on valid
11:32:56   ------------------------------
11:32:56   macro_auprc: 0.450954
11:32:56   macro_auroc: 0.869341
11:32:56   micro_auprc: 0.481848
11:32:56   micro_auroc: 0.899687
11:32:56   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
11:32:56   Evaluate on test
11:32:58   ------------------------------
11:32:58   macro_auprc: 0.505588
11:32:58   macro_auroc: 0.880655
11:32:58   micro_auprc: 0.49297
11:32:58   micro_auroc: 0.894856


{'micro_auroc': tensor(0.8949, device='cuda:2'),
 'micro_auprc': tensor(0.4930, device='cuda:2'),
 'macro_auprc': tensor(0.5056, device='cuda:2'),
 'macro_auroc': tensor(0.8807, device='cuda:2')}