In [11]:
from lib.datasets import ATPBind3D

from torchdrug import transforms
from torchdrug import data, core, layers, tasks, metrics, utils, models
from torchdrug.layers import functional
from torchdrug.core import Registry as R

import torch
from torch.utils import data as torch_data
from torch.nn import functional as F
from lib.tasks import NodePropertyPrediction


In [12]:
# Per-sample preprocessing pipeline: first a TruncateProtein step configured
# with max_length=350 and random=False, then a ProteinView step with
# view='residue'. The variable names are notebook globals and are kept as-is
# (including the "truncuate" spelling) so other cells can still reference them.
truncuate_transform = transforms.TruncateProtein(max_length=350, random=False)
protein_view_transform = transforms.ProteinView(view='residue')
transform = transforms.Compose([truncuate_transform, protein_view_transform])

# The composed transform is handed to the dataset at construction time.
dataset = ATPBind3D(transform=transform)

# dataset.split() yields the train / validation / test subsets, in that order.
train_set, valid_set, test_set = dataset.split()
split_sizes = tuple(len(subset) for subset in (train_set, valid_set, test_set))
print("train samples: %d, valid samples: %d, test samples: %d" % split_sizes)


Split num:  [337, 41, 41]
train samples: 337, valid samples: 41, test samples: 41


In [13]:
from transformers import BertModel, BertTokenizer

# Load the pretrained "Rostlab/prot_bert" weights once and move the model to
# GPU index 2. LMGearNetModel reads this notebook-level global in its
# __init__, so this cell has to run before the model is instantiated.
bert_model_global = BertModel.from_pretrained("Rostlab/prot_bert").to('cuda:2')

Some weights of the model checkpoint at Rostlab/prot_bert were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [14]:
from transformers import BertModel, BertTokenizer


def _freeze_bert(
    bert_model: BertModel, freeze_bert=True, freeze_layer_count=-1
):
    """Freeze parameters in BertModel (in place)

    Args:
        bert_model: HuggingFace bert model
        freeze_bert: Bool whether or not to freeze the bert model
        freeze_layer_count: If freeze_bert, up to what layer to freeze.

    Returns:
        bert_model
    """
    if freeze_bert:
        # freeze the entire bert model
        for param in bert_model.parameters():
            param.requires_grad = False
    else:
        # freeze the embeddings
        for param in bert_model.embeddings.parameters():
            param.requires_grad = False
        if freeze_layer_count != -1:
            # freeze layers in bert_model.encoder
            for layer in bert_model.encoder.layer[:freeze_layer_count]:
                for param in layer.parameters():
                    param.requires_grad = False
    return None


def separate_alphabets(text):
    """Return the alphabetic characters of ``text`` separated by single spaces.

    Every character for which ``str.isalpha()`` is False (digits, punctuation,
    whitespace, ...) is dropped. The result has no leading or trailing space;
    an input without alphabetic characters yields the empty string.

    Args:
        text: Input string, e.g. a protein sequence such as ``"MKV"``.

    Returns:
        A string such as ``"M K V"``.
    """
    return " ".join(char for char in text if char.isalpha())

class LMGearNetModel(torch.nn.Module, core.Configurable):
    """ProtBert per-residue embeddings fed into a GearNet encoder.

    The forward pass tokenizes the sequences of the incoming graph batch, runs
    them through the BERT model, drops the special [CLS]/[SEP] tokens and uses
    the remaining per-residue hidden states as node input features of GearNet.

    Note:
        The BERT model is the notebook-level ``bert_model_global`` object; it
        is shared (not copied) and its parameters are frozen in place here.
    """

    def __init__(self):
        super().__init__()
        self.bert_tokenizer = BertTokenizer.from_pretrained(
            "Rostlab/prot_bert", do_lower_case=False)
        self.bert_model = bert_model_global
        # Freeze the embeddings and encoder.layer[:29]; the remaining encoder
        # layers stay trainable.
        _freeze_bert(self.bert_model, freeze_bert=False, freeze_layer_count=29)
        # Put GearNet on whatever device the language model already lives on,
        # instead of repeating a hard-coded device string.
        lm_device = next(self.bert_model.parameters()).device
        self.gearnet = models.GearNet(
            input_dim=1024,  # self.bert_model.config.hidden_size
            hidden_dims=[512, 512, 512, 512, 512, 512],
            num_relation=7,
            edge_input_dim=59,
            num_angle_bin=8,
            batch_norm=True,
            concat_hidden=True,
            short_cut=True,
            readout="sum"
        ).to(lm_device)
        # NOTE(review): 21 is presumably the raw residue feature width the
        # surrounding task expects from `input_dim` -- confirm against the task.
        self.input_dim = 21
        self.output_dim = self.gearnet.output_dim

    def forward(self, graph, _, all_loss=None, metric=None):
        """Encode a batch of proteins.

        Args:
            graph: Protein graph batch; must provide ``to_sequence()`` yielding
                one sequence string per protein.
            _: Node input features supplied by the caller; ignored, because the
                node features are produced by the language model instead.
            all_loss: Unused; accepted for interface compatibility.
            metric: Unused; accepted for interface compatibility.

        Returns:
            Whatever ``self.gearnet(graph, residue_features)`` returns.
        """
        sequences = [separate_alphabets(seq) for seq in graph.to_sequence()]

        lm_device = next(self.bert_model.parameters()).device
        # padding=True lets several sequences of different length share one
        # tensor; for a single sequence it is a no-op.
        encoded_input = self.bert_tokenizer(
            sequences, return_tensors='pt', padding=True).to(lm_device)
        hidden = self.bert_model(**encoded_input).last_hidden_state

        # Number of real (non-padding) tokens per sequence, including the
        # leading [CLS] and trailing [SEP].
        token_counts = encoded_input["attention_mask"].sum(dim=1).tolist()

        # Skip [CLS] and [SEP] for every sequence (they are not residues) and
        # stack the residues of all proteins along the node dimension.
        lm_output = torch.cat(
            [hidden[i, 1:count - 1] for i, count in enumerate(token_counts)],
            dim=0,
        )

        gearnet_output = self.gearnet(graph, lm_output)
        return gearnet_output
    

# Build the combined language-model + GearNet encoder. The constructor reads
# the notebook-level `bert_model_global`, so that variable must already exist.
lm_gearnet = LMGearNetModel()


In [15]:
from torchdrug import core, layers
from torchdrug.layers import geometry
import torch
from lib.disable_logger import DisableLogger

# Graph construction: alpha-carbon nodes, with spatial, k-nearest-neighbour
# and sequential edges, and "gearnet" edge features.
node_layers = [geometry.AlphaCarbonNode()]
edge_layers = [
    geometry.SpatialEdge(radius=10.0, min_distance=5),
    geometry.KNNEdge(k=10, min_distance=5),
    geometry.SequentialEdge(max_distance=2),
]
graph_construction_model = layers.GraphConstruction(
    node_layers=node_layers,
    edge_layers=edge_layers,
    edge_feature="gearnet",
)

# Per-node prediction task wrapped around the LM + GearNet encoder.
task = NodePropertyPrediction(
    lm_gearnet,
    graph_construction_model=graph_construction_model,
    normalization=False,
    num_mlp_layer=2,
    metric=("micro_auroc", "mcc"),
)

optimizer = torch.optim.Adam(task.parameters(), lr=1e-3)

# Engine construction is wrapped in DisableLogger to silence its log output.
with DisableLogger():
    solver = core.Engine(
        task, train_set, valid_set, test_set, optimizer,
        batch_size=1, log_interval=1000, gpus=[2],
    )


In [20]:
solver.train(num_epoch=1)


21:21:08   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
21:21:08   Epoch 2 begin
21:21:42   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
21:21:42   binary cross entropy: 0.240249
21:22:42   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
21:22:42   Epoch 2 end
21:22:42   duration: 2.57 mins
21:22:42   speed: 2.19 batch / sec
21:22:42   ETA: 0.00 secs
21:22:42   max GPU memory: 5132.6 MiB
21:22:42   ------------------------------
21:22:42   average binary cross entropy: 0.0889349


In [23]:
solver.evaluate("test")

21:25:17   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
21:25:17   Evaluate on test
21:25:25   ------------------------------
21:25:25   mcc: 0.446728
21:25:25   micro_auroc: 0.883612


{'micro_auroc': tensor(0.8836, device='cuda:2'), 'mcc': 0.44672766097239364}