# Employ transfer learning with new LMs for IDR prediction
Dataset from [Disprot](https://www.disprot.org/download) (actually [older version with annotation](https://idpcentral.org/caid/data/1/reference/disprot-disorder.txt)). Methods used from [ProtTrans](https://github.com/agemagician/ProtTrans).

Based on [PytorchLightning implementation](https://github.com/agemagician/ProtTrans/blob/master/Fine-Tuning/ProtBert-BFD-FineTuning-PyTorchLightning-MS.ipynb).

In [1]:
!nvidia-smi

Mon May 16 19:07:00 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.57.02    Driver Version: 470.57.02    CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA GeForce ...  On   | 00000000:0A:00.0 Off |                  N/A |
| 33%   49C    P8    11W / 250W |      6MiB / 11178MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA GeForce ...  On   | 00000000:0B:00.0 Off |                  N/A |
| 27%   49C    P8    12W / 250W |      6MiB / 11178MiB |      0%      Defaul

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torch.utils.data import DataLoader, RandomSampler
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence, pad_sequence

import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning import Trainer, seed_everything
from torchmetrics import Accuracy

from transformers import T5EncoderModel, T5Tokenizer
from transformers import BertModel, BertTokenizer
from transformers import XLNetModel, XLNetTokenizer
from transformers import AlbertModel, AlbertTokenizer

from torchnlp.encoders import LabelEncoder
from torchnlp.datasets.dataset import Dataset
from torchnlp.utils import collate_tensors

from test_tube import HyperOptArgumentParser
import os
import re
import gc
from datetime import datetime
import logging as log
import glob

In [3]:
torch.cuda.is_available()

True

In [4]:
torch.cuda.device_count()

4

In [5]:
# Select the model
model_name = "Rostlab/prot_t5_xl_uniref50"

In [6]:
class DisorderDataset:
    """
    Loads the Dataset from the txt files passed to the parser.
    """

    def collate_lists(self, seqs: list, labels: list) -> list[dict]:
        """ Converts each line into a dictionary. """
        collated_dataset = []
        for i in range(len(seqs)):
            collated_dataset.append({"seq": seqs[i], "label": labels[i]})
        return collated_dataset

    def load_dataset(self, path):
        seqs = []
        labels = []
        with open(path) as file_handler:
            i = -1
            for line in file_handler:
                i += 1
                if i < 10 or len(line) > 1536:
                    continue
                i_offset = i - 10
                if i_offset % 7 == 1:
                    # Map rare amino acids
                    seqs.append(" ".join(list(re.sub(r"[UZOB]", "X", line.strip()))))
                elif i_offset % 7 == 2:
                    labels.append(line.strip())

        assert len(seqs) == len(labels)
        return Dataset(self.collate_lists(seqs, labels))

In [7]:
class ProtTransDisorderPredictor(pl.LightningModule):
    """
    ProtTrans model to predict intrinsical disorder in sequences.

    :param hp: ArgumentParser containing the hyperparameters.
    """

    def __init__(self, hp) -> None:
        super(ProtTransDisorderPredictor, self).__init__()
        self.hp = hp
        self.batch_size = self.hp.batch_size

        self.model_name = model_name

        self.dataset = DisorderDataset()

        self.metric_acc = Accuracy(ignore_index=-100)

        # build model
        self.__build_model()

        # Loss criterion initialization.
        self.__build_loss()

        if self.hp.nr_frozen_epochs > 0:
            self.freeze_encoder()
        else:
            self._frozen = False
        self.nr_frozen_epochs = self.hp.nr_frozen_epochs

    def __build_model(self) -> None:
        """ Init BERT model + tokenizer + classification head."""
        if "t5" in self.model_name:
            self.tokenizer = T5Tokenizer.from_pretrained(self.model_name, do_lower_case=False)
            self.LM = T5EncoderModel.from_pretrained(self.model_name)
        elif "albert" in self.model_name:
            self.tokenizer = AlbertTokenizer.from_pretrained(self.model_name, do_lower_case=False)
            self.LM = AlbertModel.from_pretrained(self.model_name)
        elif "bert" in self.model_name:
            self.tokenizer = BertTokenizer.from_pretrained(self.model_name, do_lower_case=False)
            self.LM = BertModel.from_pretrained(self.model_name)
        elif "xlnet" in self.model_name:
            self.tokenizer = XLNetTokenizer.from_pretrained(self.model_name, do_lower_case=False)
            self.LM = XLNetModel.from_pretrained(self.model_name)
        else:
            print("Unkown model name")

        if self.hp.gradient_checkpointing:
            self.LM.gradient_checkpointing_enable()

        # Label Encoder
        self.label_encoder = LabelEncoder(self.hp.label_set.split(","), reserved_labels=[], unknown_index=None)
        self.hidden_features = 1024

        # https://pytorch.org/tutorials/beginner/nlp/sequence_models_tutorial.html
        # https://www.dotlayer.org/en/training-rnn-using-pytorch/
        self.lstm = nn.LSTM(
            input_size=self.LM.config.hidden_size,
            hidden_size=self.hidden_features,
            num_layers=1,
            bidirectional=self.hp.bidirectional_lstm,
            batch_first=True,
        )

        self.hidden2label = nn.Linear(2 * self.hidden_features if self.hp.bidirectional_lstm else self.hidden_features,
                                      self.label_encoder.vocab_size)

    def __build_loss(self):
        """ Initializes the loss function/s. """
        self._loss = nn.CrossEntropyLoss()

    def unfreeze_encoder(self) -> None:
        """ un-freezes the encoder layer. """
        if self._frozen:
            log.info(f"\n-- Encoder model fine-tuning")
            for param in self.LM.parameters():
                param.requires_grad = True
            self._frozen = False

    def freeze_encoder(self) -> None:
        """ freezes the encoder layer. """
        for param in self.LM.parameters():
            param.requires_grad = False
        self._frozen = True

    def predict(self, sample: dict) -> dict:
        """ Predict function.
        :param sample: dictionary with the text we want to classify.
        Returns:
            Dictionary with the input text and the predicted label.
        """
        if self.training:
            self.eval()

        with torch.no_grad():
            model_input, _ = self.prepare_sample([sample], prepare_target=False)
            model_out = self.forward(**model_input)
            predicted_labels = [
                self.label_encoder.index_to_token[prediction]
                for prediction in model_out
            ]
            sample["predicted_label"] = predicted_labels[0]

        return sample

    def forward(self, input_ids, attention_mask, length):
        """ Usual pytorch forward function.
        input ids is already padded
        Returns:
            model outputs
        """
        input_ids = torch.tensor(input_ids, device=self.device)
        attention_mask = torch.tensor(attention_mask, device=self.device)

        padded_word_embeddings = self.LM(input_ids, attention_mask)[0]

        # We pack the padded sequence to improve the computational speed during training
        pack_padded_sequences_vectors = pack_padded_sequence(padded_word_embeddings, length, batch_first=True)

        lstm_out, _ = self.lstm(pack_padded_sequences_vectors)
        lstm_out, _ = pad_packed_sequence(lstm_out, batch_first=True, total_length=self.hp.max_length)

        tag_space = self.hidden2label(lstm_out)

        tag_scores = F.softmax(tag_space, dim=1)

        return tag_scores.transpose(-1, 1)

    def loss(self, predictions: torch.tensor, targets: torch.tensor) -> torch.tensor:
        """
        Computes Loss value according to a loss function.
        :param predictions: a tensor [batch_size x 1] with model predictions
        :param targets: Label values [batch_size]
        Returns:
            torch.tensor with loss value.
        """
        return self._loss(predictions, targets)

    def prepare_sample(self, sample: list, prepare_target: bool = True) -> (dict, dict):
        """
        Function that prepares a sample to input the model.
        :param prepare_target: also load label
        :param sample: list of dictionaries.

        Returns:
            - dictionary with the expected model inputs.
            - dictionary with the expected target labels.
        """
        sample = collate_tensors(sample)

        inputs = self.tokenizer.batch_encode_plus(sample["seq"],
                                                  add_special_tokens=True,
                                                  padding='max_length',
                                                  #is_split_into_words=True,
                                                  return_length=True,
                                                  truncation=True,
                                                  max_length=self.hp.max_length)

        if not prepare_target:
            return inputs, {}

        # Prepare target:
        try:
            labels = [self.label_encoder.batch_encode(l) for l in sample["label"]]
            labels.append(torch.empty(self.hp.max_length))
            padded_sequences_labels = pad_sequence(labels, batch_first=True, padding_value=-100)
            return inputs, padded_sequences_labels[:-1]
        except RuntimeError:
            print(sample["label"])
            raise Exception("Label encoder found an unknown label.")

    def training_step(self, batch: tuple, batch_nb: int, *args, **kwargs) -> dict:
        """
        Runs one training step. This usually consists in the forward function followed
            by the loss function.

        :param batch: The output of your dataloader.
        :param batch_nb: Integer displaying which batch this is
        Returns:
            - dictionary containing the loss and the metrics to be added to the lightning logger.
        """
        inputs, targets = batch
        model_out = self.forward(**inputs)
        loss = self.loss(model_out, targets)

        return {'loss': loss}

    def validation_step(self, batch: tuple, batch_nb: int, *args, **kwargs):
        """ Similar to the training step but with the model in eval mode.
        Returns:
            - dictionary passed to the validation_end function.
        """
        inputs, y = batch
        model_out = self.forward(**inputs)
        y_hat = torch.argmax(model_out, dim=1)

        self.log('val_loss', self.loss(model_out, y))
        self.log('val_acc', self.metric_acc(y_hat, y))

    def test_step(self, batch: tuple, batch_nb: int, *args, **kwargs):
        """ Similar to the training step but with the model in eval mode.
        Returns:
            - dictionary passed to the validation_end function.
        """
        inputs, y = batch
        model_out = self.forward(**inputs)
        y_hat = torch.argmax(model_out, dim=1)

        self.log('test_loss', self.loss(model_out, y))
        self.log('test_acc', self.metric_acc(y_hat, y))

    def configure_optimizers(self):
        """ Sets different Learning rates for different parameter groups. """
        parameters = [
            {"params": self.hidden2label.parameters()},
            {"params": self.lstm.parameters()},
            {
                "params": self.LM.parameters(),
                "lr": self.hp.encoder_learning_rate,
            },
        ]
        optimizer = optim.Adam(parameters, lr=self.hp.learning_rate)
        return optimizer

    def on_train_epoch_end(self):
        """ Pytorch lightning hook """
        if self.current_epoch + 1 >= self.nr_frozen_epochs:
            self.unfreeze_encoder()

    def __retrieve_dataset(self, train=False, val=False, test=False):
        """ Retrieves task specific dataset """
        if train:
            return self.dataset.load_dataset(self.hp.train_file)
        elif val:
            return self.dataset.load_dataset(self.hp.val_file)
        elif test:
            return self.dataset.load_dataset(self.hp.test_file)
        else:
            print('Incorrect dataset split')

    def train_dataloader(self) -> DataLoader:
        """ Function that loads the train set. """
        _train_dataset = self.__retrieve_dataset(train=True)
        return DataLoader(
            dataset=_train_dataset,
            sampler=RandomSampler(_train_dataset),
            batch_size=self.hp.batch_size,
            collate_fn=self.prepare_sample,
            num_workers=self.hp.loader_workers,
        )

    def val_dataloader(self) -> DataLoader:
        """ Function that loads the validation set. """
        _dev_dataset = self.__retrieve_dataset(val=True)
        return DataLoader(
            dataset=_dev_dataset,
            batch_size=self.hp.batch_size,
            collate_fn=self.prepare_sample,
            num_workers=self.hp.loader_workers,
        )

    def test_dataloader(self) -> DataLoader:
        """ Function that loads the validation set. """
        _test_dataset = self.__retrieve_dataset(test=True)
        return DataLoader(
            dataset=_test_dataset,
            batch_size=self.hp.batch_size,
            collate_fn=self.prepare_sample,
            num_workers=self.hp.loader_workers,
        )

    @classmethod
    def add_model_specific_args(
        cls, parser: HyperOptArgumentParser
    ) -> HyperOptArgumentParser:
        """ Parser for Estimator specific arguments/hyperparameters.
        :param parser: HyperOptArgumentParser obj
        Returns:
            - updated parser
        """
        parser.opt_list(
            "--max_length",
            default=1536,
            type=int,
            help="Maximum sequence length.",
        )
        parser.add_argument(
            "--encoder_learning_rate",
            default=5e-06,
            type=float,
            help="Encoder specific learning rate.",
        )
        parser.add_argument(
            "--learning_rate",
            default=3e-05,
            type=float,
            help="Classification head learning rate.",
        )
        parser.opt_list(
            "--nr_frozen_epochs",
            default=1,
            type=int,
            help="Number of epochs we want to keep the encoder model frozen.",
            tunable=True,
            options=[0, 1, 2, 3, 4, 5],
        )
        # Data Args:
        parser.add_argument(
            "--label_set",
            default="0,1",
            type=str,
            help="Classification labels set.",
        )
        parser.add_argument(
            "--train_file",
            default="../data/disprot/flDPnn_Training_Annotation.txt",
            type=str,
            help="Path to the file containing the train data.",
        )
        parser.add_argument(
            "--val_file",
            default="../data/disprot/flDPnn_Validation_Annotation.txt",
            type=str,
            help="Path to the file containing the validation data.",
        )
        parser.add_argument(
            "--test_file",
            default="../data/disprot/flDPnn_Test_Annotation.txt",
            type=str,
            help="Path to the file containing the test data.",
        )
        parser.add_argument(
            "--loader_workers",
            default=8,
            type=int,
            help="How many subprocesses to use for data loading. 0 means that \
                the data will be loaded in the main process.",
        )
        parser.add_argument(
            "--gradient_checkpointing",
            default=True,
            type=bool,
            help="Enable or disable gradient checkpointing which use the cpu memory \
                with the gpu memory to store the model.",
        )
        parser.add_argument(
            "--bidirectional_lstm",
            default=True,
            type=bool,
            help="Enable bidirectional LSTM in the decoder.",
        )
        return parser

In [14]:
parser = HyperOptArgumentParser(
    strategy="random_search",
    description="ProtTrans IDR Predictor",
    add_help=True,
)
parser.add_argument("--seed", type=int, default=3, help="Training seed.")
parser.add_argument(
    "--save_top_k",
    default=1,
    type=int,
    help="The best k models according to the quantity monitored will be saved.",
)
# Early Stopping
parser.add_argument(
    "--monitor", default="val_acc", type=str, help="Quantity to monitor."
)
parser.add_argument(
    "--metric_mode",
    default="max",
    type=str,
    help="If we want to min/max the monitored quantity.",
    choices=["min", "max"],
)
parser.add_argument(
    "--patience",
    default=5,
    type=int,
    help=(
        "Number of epochs with no improvement "
        "after which training will be stopped."
    ),
)
parser.add_argument(
    "--min_epochs",
    default=1,
    type=int,
    help="Limits training to a minimum number of epochs",
)
parser.add_argument(
    "--max_epochs",
    default=100,
    type=int,
    help="Limits training to a max number number of epochs",
)

# Batching
parser.add_argument(
    "--batch_size", default=1, type=int, help="Batch size to be used."
)
parser.add_argument(
    "--accumulate_grad_batches",
    default=64,
    type=int,
    help=(
        "Accumulated gradients runs K small batches of size N before "
        "doing a backwards pass."
    ),
)

# gpu/tpu args
parser.add_argument("--accelerator", type=str, default="auto", help="Which hardware accelerator to use", choices=["cpu", "gpu", "tpu"])
parser.add_argument("--devices", type=str, default="auto", help="How many devices to use")
parser.add_argument("--strategy", type=str, default="ddp", help="Which parallelization strategy to use", choices=["dp", "ddp", "ddp_spawn", "ddp2"])
parser.add_argument(
    "--limit_val_batches",
    default=1.0,
    type=float,
    help=(
        "If you don't want to use the entire validation set (for debugging or "
        "if it's huge), set how much of the validation set you want to use with this flag."
    ),
)

# mixed precision
parser.add_argument("--precision", type=int, default="32", help="full precision or mixed precision mode")

# each LightningModule defines arguments relevant to it
parser = ProtTransDisorderPredictor.add_model_specific_args(parser)
hparams = parser.parse_known_args()[0]

## Main Training

In [15]:
def setup_logger() -> TensorBoardLogger:
    """ Function that sets the TestTubeLogger to be used. """
    now = datetime.now()
    dt_string = now.strftime("%d-%m-%Y--%H-%M-%S")

    return TensorBoardLogger(
        save_dir="logs/",
        version=dt_string,
        name="lightning_logs",
    )

logger = setup_logger()

In [16]:
"""
Main training routine specific for this project
:param hparams:
"""
seed_everything(hparams.seed)

model = ProtTransDisorderPredictor(hparams)

# Init model checkpoint path and saver
ckpt_path = os.path.join(
    logger.save_dir,
    logger.name,
    f"version_{logger.version}",
    "checkpoints",
)
checkpoint_callback = ModelCheckpoint(
    dirpath=ckpt_path,
    filename="{epoch}-{val_loss:.2f}-{val_acc:.2f}",
    save_top_k=hparams.save_top_k,
    monitor=hparams.monitor,
    every_n_epochs=1,
    mode=hparams.metric_mode,
)

early_stop_callback = EarlyStopping(
    monitor=hparams.monitor,
    min_delta=0.0,
    patience=hparams.patience,
    verbose=True,
    mode=hparams.metric_mode,
)

trainer = Trainer(
    # fast_dev_run=True,
    accelerator=hparams.accelerator,
    devices=hparams.devices,
    strategy='dp',
    logger=logger,
    max_epochs=hparams.max_epochs,
    min_epochs=hparams.min_epochs,
    accumulate_grad_batches=hparams.accumulate_grad_batches,
    limit_val_batches=hparams.limit_val_batches,
    callbacks=[checkpoint_callback, early_stop_callback],
    precision=hparams.precision,
    deterministic=False,
)

Global seed set to 3
Some weights of the model checkpoint at Rostlab/prot_t5_xl_uniref50 were not used when initializing T5EncoderModel: ['decoder.block.19.layer.1.EncDecAttention.q.weight', 'decoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight', 'decoder.block.9.layer.2.layer_norm.weight', 'decoder.block.12.layer.0.SelfAttention.q.weight', 'decoder.block.22.layer.0.SelfAttention.q.weight', 'decoder.block.7.layer.2.DenseReluDense.wo.weight', 'decoder.block.14.layer.2.DenseReluDense.wi.weight', 'decoder.block.16.layer.0.SelfAttention.v.weight', 'decoder.block.3.layer.1.EncDecAttention.q.weight', 'decoder.block.20.layer.0.SelfAttention.k.weight', 'decoder.block.12.layer.2.layer_norm.weight', 'decoder.block.13.layer.2.DenseReluDense.wo.weight', 'decoder.block.2.layer.1.EncDecAttention.o.weight', 'decoder.block.1.layer.2.DenseReluDense.wo.weight', 'decoder.block.6.layer.0.layer_norm.weight', 'decoder.block.22.layer.2.layer_norm.weight', 'decoder.block.7.layer.0.SelfAttentio

In [None]:
gc.collect()

In [17]:
trainer.fit(model)

4305

In [18]:
#%reload_ext tensorboard
#%tensorboard --logdir f"{logger.save_dir}{logger.name}/" --port=8009 --bind_all

In [19]:
best_checkpoint_path = glob.glob(ckpt_path + "/*")[0]
print(best_checkpoint_path)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

  | Name         | Type             | Params
--------------------------------------------------
0 | metric_acc   | Accuracy         | 0     
1 | LM           | T5EncoderModel   | 1.2 B 
2 | lstm         | LSTM             | 8.4 M 
3 | hidden2label | Linear           | 2.0 K 
4 | _loss        | CrossEntropyLoss | 0     
--------------------------------------------------
8.4 M     Trainable params
1.2 B     Non-trainable params
1.2 B     Total params
2,433.081 Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]



Validation: 0it [00:00, ?it/s]

Metric val_acc improved. New best score: 0.640


RuntimeError: CUDA out of memory. Tried to allocate 124.00 MiB (GPU 0; 10.92 GiB total capacity; 9.87 GiB already allocated; 90.62 MiB free; 10.08 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [None]:
trainer.resume_from_checkpoint = best_checkpoint_path

In [236]:
best_checkpoint_path = glob.glob(f"{ckpt_path}/*")[0]
print(best_checkpoint_path)

IndexError: list index out of range

## Predict new sequence

In [None]:
seq = "MSDNDDIEVESDEEQPRFQSAADKRAHHNALERKRRDHIKDSFHSLRDSVPSLQGEKASRAQILDKATEYIQYMRRKNHTHQQDIDDLKRQNALLEQQVRALEKARSSAQLQTNYPSSDNSLYTNAKGSTISAFDGGSDSSSESEPEEPQSRKKLRMEAS"
label = "0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000011111111111111111111111111111111111111111111111111111111110000000000"

preds = model.predict({"seq": seq})

print("Sequence label is: {} - prediction is: {}".format(label, preds['predicted_label']))

In [None]:
seq = "MSDNDDIEVESDEEQPRFQSAADKRAHHNALERKRRDHIKDSFHSLRDSVPSLQGEKASRAQILDKATEYIQYMRRKNHTHQQDIDDLKRQNALLEQQVRALEKARSSAQLQTNYPSSDNSLYTNAKGSTISAFDGGSDSSSESEPEEPQSRKKLRMEAS"
label = "0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000011111111111111111111111111111111111111111111111111111111110000000000"

predictions = model.predict({"seq": seq})

print("Sequence label is: {} - prediction is: {}".format(label, predictions['predicted_label']))

### Hyperparameter Tuning
Using RayTune, I tested different models (ProtT5-XL-encoder, ProtBert, ESM-1b) with different parameters in two rounds (results in folder 'raytune_results' on raven, first 'tune_esm_prottrans_' then 'tune_rates').

The best model I found was: `{'model_name': 'Rostlab/prot_t5_xl_half_uniref50-enc', 'rnn': 'gru', 'crf_after_rnn': False, 'learning_rate': 2e-04, 'encoder_learning_rate': 1e-05, 'hidden_features': 1024}`

We still need to do timing tests, and I could imagine an ESM-1b model without dropout linear net to also work well while being faster. Test this!