# ProtTrans (Protein Transformer)

From https://github.com/agemagician/ProtTrans : ProtTrans is providing state of the art pre-trained models for proteins. ProtTrans was trained on thousands of GPUs from Summit and hundreds of Google TPUs using various Transformer models.

Have a look at our paper [ProtTrans: cracking the language of life’s code through self-supervised deep learning and high performance computing](https://doi.org/10.1109/TPAMI.2021.3095381) for more information about this work.

## Lightning's ProtBERT multi-task classifier
**This notebook illustrates how to build a "multi-task" classifier using ProtBERT with PyTorch Lightning**. It is an updated version of the official notebook [here](https://github.com/agemagician/ProtTrans/tree/master/Fine-Tuning). Compared to the original notebook, this notebook

* uses the up-to-date APIs of Hugging Face's transformers (4.18.0) and PyTorch Lightning (1.6.0)
* uses the new, bug-fixed dataset (details here: https://github.com/agemagician/ProtTrans/issues/74 )
* supports multi-task learning automatically (by detecting the columns in the input csv)
* supports the wandb logger through Kaggle's Secrets add-on, so that you can compare various experimental results
* uses Kaggle's free P100 GPU, so there is no need to subscribe to Colab Pro :)

## RAM warning
A Kaggle notebook with a GPU has just 13GB of RAM. With a max protein length of 512, the notebook uses most of that RAM and can occasionally ask for more, which makes the notebook crash. Reduce the max protein length (`MAX_PROTEIN_STR_LEN`) if you just want to test-run the program.

![](https://github.com/agemagician/ProtTrans/raw/master/images/transformers_attention.png)

In [None]:
# Model's hyperparameters
# These constants are the defaults of the command-line arguments declared
# further down in the notebook (project parser + add_model_specific_args).
BATCH_SIZE = 1  # default of --batch_size (per-step DataLoader batch size)
ACCUM_BATCH = 32  # default of --accumulate_grad_batches
MAX_PROTEIN_STR_LEN = 512  # default of --max_length and of Loc_dataset's max_len filter
ENCODER_LR = 5e-6  # default of --encoder_learning_rate (ProtBert encoder parameter group)
GENERAL_LR = 3e-05  # default of --learning_rate (classification heads)
N_FROZEN_EPOCH = 1  # default of --nr_frozen_epochs (encoder stays frozen this many epochs)
GRADIENT_CHECKPOINT = False # True if RAM < 16GB; default of --gradient_checkpointing
SEED = 43  # default of --seed, passed to seed_everything

# File's hyperparameters
# FILE_DIR only feeds the defaults of --train_csv/--dev_csv/--test_csv.
FILE_DIR = "/kaggle/working/data/" 
# Column names handed to Loc_dataset by the model: sequence column first, then one column per task.
COL_NAMES = ['input','loc','membrane'] # optional, but no need

# Trainer's hyperparameters
CPU_WORKERS = 2  # default of --loader_workers (DataLoader num_workers)
N_CHECKPOINTS = 1 # number of top-k models you want to save
MONITOR_METRIC = "val_loss"  # default of --monitor
MONITOR_MODE = "min"  # default of --metric_mode
MIN_EPOCHS = 1
MAX_EPOCHS = 7 # Increase to 10-20 epochs to maximize accuracy (needs more GPU hours on Kaggle)
PATIENCE = MAX_EPOCHS  # early-stopping patience equals the epoch budget
NUM_GPU = 1
PRECISION = 32  # setting 16 makes the install cell build Apex and switch AMP_BACKEND to 'apex'
AMP_BACKEND = "native"

# Weights & Biases settings (used by the logger cell below)
wandb_flag=True # Need a free wandb.ai account (retrieve a personal key there)
wandb_name="Run 4. %d-maxlen" % MAX_PROTEIN_STR_LEN  # run name embeds the max sequence length
wandb_project="ProtBERT multi-task DeepLoc"

print(wandb_name)

In [None]:
# Optional: for APEX mixed precision, but now Pytorch has a native mixed precision support already
# Apex is only cloned/installed when 16-bit precision is requested; in that
# case the AMP backend handed to the Trainer is switched from "native" to "apex".
if PRECISION==16:
    print('installing Apex...')
    !git clone https://github.com/NVIDIA/apex
    !cd apex && pip install -v --disable-pip-version-check --no-cache-dir ./
    AMP_BACKEND = 'apex'
    
# torchtext is removed before installing the pinned versions
# NOTE(review): presumably to avoid a version clash with the pinned packages - confirm.
!pip uninstall -y torchtext
# transformers and pytorch-lightning are pinned; pytorch-nlp and torchmetrics are not.
!pip install -q transformers==4.18.0 pytorch-lightning==1.6.0
!pip install -q pytorch-nlp torchmetrics 

# Show which AMP backend will be used by the Trainer.
print(AMP_BACKEND)

In [None]:
import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader, RandomSampler

import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning import Trainer, seed_everything
from torchmetrics import Accuracy

from torchnlp.encoders import LabelEncoder
from torchnlp.datasets.dataset import Dataset
from torchnlp.utils import collate_tensors

import pandas as pd
from argparse import ArgumentParser
import os
import re
import requests
from tqdm.auto import tqdm
from datetime import datetime
from collections import OrderedDict
import logging as log
import numpy as np
import glob

In [None]:
import transformers
from IPython.display import display
# Print the library versions actually in use, for reproducibility.
print(torch.__version__)
print(pl.__version__)
print(transformers.__version__)

# Alias: the rest of the notebook refers to HyperOptArgumentParser, which here
# is simply the standard-library argparse.ArgumentParser.
HyperOptArgumentParser = ArgumentParser

In [None]:
# Weights & Biases logger setup.
# The Kaggle secret is only looked up (and kaggle_secrets / wandb only imported)
# when wandb logging is enabled, so setting wandb_flag=False no longer requires
# a "wandb" secret or a Kaggle environment.
if wandb_flag:
    from kaggle_secrets import UserSecretsClient
    import wandb
    from pytorch_lightning.loggers import WandbLogger

    user_secrets = UserSecretsClient()
    secret_value_0 = user_secrets.get_secret("wandb")

    # wandb reads its API key from this environment variable.
    os.environ["WANDB_API_KEY"] = secret_value_0
    wandb_logger = WandbLogger(name=wandb_name,project=wandb_project)

    print('wandb is running!')
else:
    # No wandb logger is handed to the Trainer.
    wandb_logger = None

In [None]:
class Loc_dataset():
    """
    Downloads the DeepLoc CSV files and serves them as DataFrames / torchnlp Datasets.

    The first CSV column is treated as the protein sequence; every other column
    is treated as the label column of one classification task.

    :param col_names: optional list of column names assigned to every loaded CSV
        (its length must match the CSV's column count). When None, the names
        'input', 'task_0', 'task_1', ... are generated on the first load and
        reused for later loads.
    """
    def  __init__(self, col_names = None) -> None:
        self.downloadDeeplocDataset()
        self.col_names = col_names

    @staticmethod
    def _download_file(url, filename, timeout=60):
        """ Streams `url` into `filename`, showing a tqdm progress bar.

        The payload is written to '<filename>.part' and only renamed to
        `filename` once the download finished, so an interrupted or failed
        download is never mistaken for a complete file on the next run.

        Raises:
            requests.HTTPError if the server answers with an error status.
        """
        partial_path = filename + ".part"
        with requests.get(url, stream=True, timeout=timeout) as response:
            # Fail loudly instead of saving an HTML error page as a CSV.
            response.raise_for_status()
            total = int(response.headers.get('content-length', 0))
            with open(partial_path, "wb") as raw_file, \
                    tqdm.wrapattr(raw_file, "write", miniters=1,
                                  total=total, desc=filename) as fout:
                for chunk in response.iter_content(chunk_size=4096):
                    fout.write(chunk)
        os.replace(partial_path, filename)

    def downloadDeeplocDataset(self):
        """ Makes sure the train/valid/test/hard CSV files exist under 'data/'.

        Sets self.trainFilePath, self.validFilePath, self.testFilePath and
        self.hardFilePath. Files that already exist are not downloaded again.
        """
        # The old dataset is obsoleted as discussed in https://github.com/agemagician/ProtTrans/issues/74
        # (old URLs: https://www.dropbox.com/s/vgdqcl4vzqm9as0/deeploc_per_protein_train.csv?dl=1 ,
        #            https://www.dropbox.com/s/jfzuokrym7nflkp/deeploc_per_protein_test.csv?dl=1 )
        deeplocDatasetTrainUrl = 'https://rostlab.org/~deepppi/deeploc_data/deeploc_our_train_set.csv'
        deeplocDatasetValidUrl = 'https://rostlab.org/~deepppi/deeploc_data/deeploc_our_val_set.csv'
        deeplocDatasetTestUrl = 'https://rostlab.org/~deepppi/deeploc_data/deeploc_test_set.csv'
        deeplocDatasetHardUrl = 'https://rostlab.org/~deepppi/deeploc_data/setHARD.csv'

        datasetFolderPath = 'data/'
        self.trainFilePath = os.path.join(datasetFolderPath, 'deeploc_per_protein_train.csv')
        self.validFilePath = os.path.join(datasetFolderPath, 'deeploc_per_protein_valid.csv')
        self.testFilePath = os.path.join(datasetFolderPath, 'deeploc_per_protein_test.csv')
        self.hardFilePath = os.path.join(datasetFolderPath, 'deeploc_per_protein_hard.csv')

        os.makedirs(datasetFolderPath, exist_ok=True)

        # Same download order as before: train, test, valid, hard.
        downloads = [
            (deeplocDatasetTrainUrl, self.trainFilePath),
            (deeplocDatasetTestUrl, self.testFilePath),
            (deeplocDatasetValidUrl, self.validFilePath),
            (deeplocDatasetHardUrl, self.hardFilePath),
        ]
        for url, file_path in downloads:
            if not os.path.exists(file_path):
                self._download_file(url, file_path)

    def collate_lists(self, seq: list, label: list) -> list:
        """ Converts each (sequence, label string) pair into a dictionary.

        :param seq: list of sequences.
        :param label: list of label strings; each is split on whitespace into
            one label per task.
        Returns:
            - list of {"seq": str, "label": list of str} dictionaries.
        """
        return [{"seq": str(s), "label": str(l).split()}
                for s, l in zip(seq[:len(seq)], label[:len(seq)])]

    def protein_df_stat(self, df):
        """ Prints basic statistics of a sequence DataFrame.

        Printed: the DataFrame shape, the total number of residue letters, the
        distinct letters with their counts, and for every task column its
        number of labels and label names.

        :param df: DataFrame whose first column holds the sequences and whose
            remaining columns hold the task labels.
        """
        print(df.shape)
        all_strings = []
        for sequence in tqdm(df[df.columns[0]]):
            # Whitespace is dropped first so that spaced ("M S T") and unspaced
            # ("MST") sequences are both counted letter by letter.
            all_strings.extend("".join(sequence.split()))
        print('total alphabets', len(all_strings))
        print()

        # One pass gives both the distinct letters and how often each occurs.
        all_alphabets, alphabet_counts = np.unique(all_strings, return_counts=True)
        print(all_alphabets, len(all_alphabets))
        print()

        num_alphabets_dict = OrderedDict()
        for a, count in zip(all_alphabets, alphabet_counts):
            num_alphabets_dict[a] = count
            print(a, num_alphabets_dict[a])

        tasks_list = df.columns[1:].values
        tasks_info_dict = OrderedDict()
        for t in tasks_list:
            label_names = df[t].unique()
            tasks_info_dict[t] = {"n_labels": len(label_names),
                                  "label_names": label_names}
        print(tasks_info_dict)

    def load_df(self, df_name='train', max_len=MAX_PROTEIN_STR_LEN, print_stat=False):
        """ Loads one split as a DataFrame, keeping only sequences of at most max_len residues.

        :param df_name: 'train', 'valid'/'val' or 'hard'; any other value loads the test split.
        :param max_len: maximum number of residues a sequence may have to be kept.
        :param print_stat: when True, prints statistics before and after the length filter.
        Returns:
            - the filtered DataFrame with columns renamed to self.col_names.
        Raises:
            ValueError if self.col_names does not match the CSV's column count.
        """
        if df_name=='train':
            path = self.trainFilePath
        elif df_name=='valid' or df_name=='val':
            path = self.validFilePath
        elif df_name=='hard':
            path = self.hardFilePath
        else:
            path = self.testFilePath

        # read_csv uses its default header handling here, i.e. the first row of
        # the file is taken as the header and then renamed below.
        # For a file without a header row use pd.read_csv(path,names=col_names,skiprows=0) instead
        df = pd.read_csv(path)

        if self.col_names is not None:
            if len(self.col_names) != len(df.columns):
                raise ValueError(
                    "col_names has %d entries but %s has %d columns"
                    % (len(self.col_names), path, len(df.columns)))
        else:
            self.col_names = ['input']
            self.col_names += ['task_%d' % i for i in range(len(df.columns[1:]))]
        df.columns = self.col_names

        # Length in residues, ignoring whitespace. This matches the
        # normalisation done in load_dataset, so the filter also works when the
        # CSV stores sequences without spaces between residues.
        seq_lengths = df[df.columns[0]].apply(lambda x: len("".join(x.split())))

        if print_stat:
            print('original df')
            self.protein_df_stat(df)

        df = df[seq_lengths <= max_len]

        if print_stat:
            print('truncated df')
            self.protein_df_stat(df)

        return df

    def load_dataset(self,df_name='train', max_len=MAX_PROTEIN_STR_LEN, print_stat=False):
        """ Loads one split as a torchnlp Dataset of {"seq", "label"} dictionaries.

        Sequences are re-spaced to one residue per token and the letters
        U, Z, O and B are replaced by X. The labels of all tasks of a row are
        joined with spaces (and split again in collate_lists), so individual
        labels must not contain whitespace.

        :param df_name: split name, see load_df.
        :param max_len: maximum sequence length, see load_df.
        :param print_stat: forwarded to load_df.
        """
        df = self.load_df(df_name=df_name, max_len=max_len, print_stat=print_stat)

        seq = list(df[df.columns[0]])

        tasks_list = df.columns[1:].values
        # Pull each task column out once instead of once per row.
        task_values = [df[t].tolist() for t in tasks_list]
        label = [" ".join(str(values[i]) for values in task_values)
                 for i in range(len(df))]

        # Make sure there is a space between every token, and map rarely amino acids
        seq = [" ".join("".join(sample.split())) for sample in seq]
        seq = [re.sub(r"[UZOB]", "X", sample) for sample in seq]

        assert len(seq) == len(label)
        return Dataset(self.collate_lists(seq, label))

In [None]:
# Build the dataset helper (its constructor downloads any missing CSV file) and
# load the four splits. No col_names are given, so generic column names
# ('input', 'task_0', ...) are generated on the first load.
dataset = Loc_dataset()
train_dataset = dataset.load_dataset('train')
test_dataset = dataset.load_dataset('test')
valid_dataset = dataset.load_dataset('val')
hard_dataset = dataset.load_dataset('hard')

In [None]:
# use print_stat = True to see dataset's basic statistics (see example below)
# Load the (length-filtered) training split as a DataFrame and preview its first rows.
train_df = dataset.load_df('train', print_stat=False)
train_df.head()

In [None]:
# Load the test split and print its statistics before and after the length filter.
test_df = dataset.load_df('test', print_stat=True)

# 3. Create the ProtBert PyTorch Lightning class

In [None]:
from transformers import BertTokenizer, BertModel

class ProtBertBFDClassifier(pl.LightningModule):
    """
    Multi-task protein classifier built on the pre-trained ProtBert-BFD encoder.

    Adapted from https://github.com/minimalist-nlp/lightning-text-classification.git

    One classification head is created per task, where the tasks are the label
    columns (every column after the first) of the training data loaded through
    Loc_dataset.

    :param hparam: parsed arguments holding the hyperparameters; the
        model-specific ones are declared in add_model_specific_args.
    """

    def __init__(self, hparam) -> None:
        super(ProtBertBFDClassifier, self).__init__()
        self.hparam = hparam
        self.batch_size = self.hparam.batch_size

        self.model_name = "Rostlab/prot_bert_bfd"

        # TOFIX: the column names come from the module-level COL_NAMES, not from hparam.
        self.dataset = Loc_dataset(col_names=COL_NAMES) #self.hparam.col_names)
        self.metric_acc = Accuracy()

        # build multi-tasks model -> need (multi) tasks_info
        self.tasks_info_dict = self.extract_tasks_info()

        self.label_encoders_dict = self.create_label_encoders()
        self.__build_model()

        # Loss criterion initialization.
        self.__build_loss()

        if self.hparam.nr_frozen_epochs > 0:
            self.freeze_encoder()
        else:
            self._frozen = False
        self.nr_frozen_epochs = self.hparam.nr_frozen_epochs

    @staticmethod
    def _str2bool(value):
        """ Parses a command-line boolean.

        argparse's `type=bool` turns every non-empty string (including
        "False") into True, so boolean options are parsed explicitly.
        A ValueError makes argparse report an invalid value.
        """
        if isinstance(value, bool):
            return value
        lowered = str(value).strip().lower()
        if lowered in ("1", "true", "t", "yes", "y"):
            return True
        if lowered in ("0", "false", "f", "no", "n"):
            return False
        raise ValueError("expected a boolean value, got %r" % (value,))

    def create_label_encoders(self) -> OrderedDict:
        """ Builds one torchnlp LabelEncoder per task from the task's label names.

        Returns:
            - OrderedDict mapping task name -> LabelEncoder (no reserved labels,
              no unknown index).
        """
        label_encoders_dict = OrderedDict()

        for task in self.tasks_info_dict.keys():
            label_encoders_dict[task] = LabelEncoder(self.tasks_info_dict[task]["label_names"],
                                                  reserved_labels=[],
                                                  unknown_index=None)
        return label_encoders_dict

    def extract_tasks_info(self) -> OrderedDict:
        """ Derives the tasks from the training DataFrame.

        Every column after the first is one task.
        Returns:
            - OrderedDict mapping task name -> {"n_labels", "label_names"},
              both taken from the unique values of that column.
        """
        tasks_df = self.dataset.load_df('train')
        tasks_list = tasks_df.columns[1:].values # ALWAYS assuming the first column is about input, not task
        tasks_list = [t.strip() for t in tasks_list]
        tasks_info_dict = OrderedDict()
        for t in tasks_list:
            label_names = tasks_df[t].unique()
            tasks_info_dict[t] = {"n_labels": len(label_names),
                                  "label_names": label_names}
        return tasks_info_dict

    def __build_model(self) -> None:
        """ Init BERT model + tokenizer + classification head."""
        self.ProtBertBFD = BertModel.from_pretrained(self.model_name)
        if self.hparam.gradient_checkpointing:
            self.ProtBertBFD.gradient_checkpointing_enable() # HF >= 4.17
        self.encoder_features = 1024
        self.pooled_encoder_features = self.encoder_features*4 # 4 pooling strategies by default

        # Tokenizer
        self.tokenizer = BertTokenizer.from_pretrained(self.model_name, do_lower_case=False)

        # Classification head, one for each task
        self.classification_heads_dict = nn.ModuleDict()
        for task in self.tasks_info_dict.keys():
            self.classification_heads_dict[task] = nn.Sequential(nn.Linear(self.pooled_encoder_features,
                                                                        self.label_encoders_dict[task].vocab_size),
                                                              nn.Tanh())

            # Explicitly place each head on the encoder's current device.
            # NOTE (original author): "lighting does not auto transfer
            # nn.ModuleDict()/ModuleList() to cuda()" - TODO confirm this is still needed.
            self.classification_heads_dict[task] = self.classification_heads_dict[task].to(self.ProtBertBFD.device)

    def __build_loss(self) -> None:
        """ Initializes the loss function/s: one CrossEntropyLoss per task. """
        self.task_loss_functions_dict = OrderedDict()
        for t in self.tasks_info_dict.keys():
            # TODO: add customization
            self.task_loss_functions_dict[t] = nn.CrossEntropyLoss()

    def unfreeze_encoder(self) -> None:
        """ un-freezes the encoder layer (no-op when it is not frozen). """
        if self._frozen:
            log.info(f"\n-- Encoder model fine-tuning")
            for param in self.ProtBertBFD.parameters():
                param.requires_grad = True
            self._frozen = False

    def freeze_encoder(self) -> None:
        """ freezes the encoder layer. """
        for param in self.ProtBertBFD.parameters():
            param.requires_grad = False
        self._frozen = True

    def predict(self, sample: dict) -> dict:
        """ Predict function.

        Switches the module to eval mode if it is in training mode.

        :param sample: dictionary with key "seq" holding the sequence to classify.
        Returns:
            The same dictionary (modified in place) with an added
            "predicted_label" entry mapping each task name to its predicted label.
        """
        if self.training:
            self.eval()

        with torch.no_grad():
            model_input, _ = self.prepare_sample([sample], prepare_target=False)
            task_predictions_dict = self.forward(**model_input)
            predicted_labels_dict = OrderedDict()
            for t in self.tasks_info_dict.keys():
                # .cpu() so that prediction also works while the model lives on a GPU.
                logits = task_predictions_dict[t]["logits"].detach().cpu().numpy()
                predicted_labels_dict[t] = [self.label_encoders_dict[t].index_to_token[prediction]
                                            for prediction in np.argmax(logits, axis=1)
                                           ][0]
            sample["predicted_label"] = predicted_labels_dict

        return sample

    # https://github.com/UKPLab/sentence-transformers/blob/eb39d0199508149b9d32c1677ee9953a84757ae4/sentence_transformers/models/Pooling.py
    def pool_strategy(self, features,
                      pool_cls=True, pool_max=True, pool_mean=True,
                      pool_mean_sqrt=True):
        """ Pools per-token embeddings into one vector per sequence.

        The enabled strategies (CLS token, masked max, masked mean, masked
        mean divided by sqrt(length)) are concatenated along dimension 1 in
        that order.

        :param features: dictionary with keys 'token_embeddings',
            'cls_token_embeddings' and 'attention_mask'; an optional
            'token_weights_sum' replaces the mask sum as mean denominator.
        Returns:
            - tensor with the concatenated pooled vectors.
        """
        token_embeddings = features['token_embeddings']
        cls_token = features['cls_token_embeddings']
        attention_mask = features['attention_mask']

        ## Pooling strategy
        output_vectors = []
        if pool_cls:
            output_vectors.append(cls_token)
        if pool_max:
            input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
            # Padding positions get a large negative value so they never win
            # the max. This is done on a copy: writing into token_embeddings
            # would modify the caller's encoder output in place. The fill
            # value is clamped to the dtype's range so it stays finite in
            # half precision (-1e9 is unchanged for float32).
            fill_value = max(-1e9, torch.finfo(token_embeddings.dtype).min)
            masked_embeddings = token_embeddings.masked_fill(input_mask_expanded == 0, fill_value)
            max_over_time = torch.max(masked_embeddings, 1)[0]
            output_vectors.append(max_over_time)
        if pool_mean or pool_mean_sqrt:
            input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
            sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)

            #If tokens are weighted (by WordWeights layer), feature 'token_weights_sum' will be present
            if 'token_weights_sum' in features:
                sum_mask = features['token_weights_sum'].unsqueeze(-1).expand(sum_embeddings.size())
            else:
                sum_mask = input_mask_expanded.sum(1)

            # Avoid a division by zero for fully masked rows.
            sum_mask = torch.clamp(sum_mask, min=1e-9)

            if pool_mean:
                output_vectors.append(sum_embeddings / sum_mask)
            if pool_mean_sqrt:
                output_vectors.append(sum_embeddings / torch.sqrt(sum_mask))

        output_vector = torch.cat(output_vectors, 1)
        return output_vector

    def forward(self, input_ids, token_type_ids, attention_mask):
        """ Usual pytorch forward function.

        :param input_ids: token ids as produced by the tokenizer (list or tensor).
        :param token_type_ids: accepted because the tokenizer output contains
            it; not passed to the encoder.
        :param attention_mask: attention mask matching input_ids (list or tensor).
        Returns:
            OrderedDict mapping each task name to {"logits": tensor}.
        """
        # as_tensor accepts the plain lists coming from the tokenizer and does
        # not re-copy inputs that already are tensors.
        input_ids = torch.as_tensor(input_ids, device=self.device)
        attention_mask = torch.as_tensor(attention_mask, device=self.device)

        word_embeddings = self.ProtBertBFD(input_ids,
                                           attention_mask)[0]

        pooling = self.pool_strategy({"token_embeddings": word_embeddings,
                                      "cls_token_embeddings": word_embeddings[:, 0],
                                      "attention_mask": attention_mask,
                                      })

        task_predictions_dict = OrderedDict()
        for t in self.tasks_info_dict.keys():
            task_predictions_dict[t] = {"logits": self.classification_heads_dict[t](pooling)}
        return task_predictions_dict

    def compute_loss(self, predictions: dict, targets: dict) -> torch.tensor:
        """
        Computes the total loss as the sum of the per-task losses.

        The individual task losses are kept in self.task_losses_dict.

        :param predictions: per-task model output; predictions[task]["logits"]
            holds the logits of that task.
        :param targets: per-task targets; targets[task]["labels"] holds the
            encoded labels of that task.
        Returns:
            torch.tensor with loss value.
        """
        self.task_losses_dict = OrderedDict()
        loss = None
        for t in self.tasks_info_dict.keys():
            task_loss = self.task_loss_functions_dict[t](predictions[t]["logits"], targets[t]["labels"])
            self.task_losses_dict[t] = task_loss
            # Out-of-place accumulation: an in-place "+=" on a leaf tensor
            # that requires grad raises a RuntimeError when running on CPU.
            loss = task_loss if loss is None else loss + task_loss

        if loss is None:
            # No task configured: keep returning a zero loss tensor.
            loss = torch.tensor(0.0, dtype=torch.float32, requires_grad=True).to(self.ProtBertBFD.device)
        return loss

    def prepare_sample(self, sample: list, prepare_target: bool = True) -> (dict, dict):
        """
        Function that prepares a sample to input the model (used as collate_fn).

        :param sample: list of dictionaries with keys "seq" and, when targets
            are prepared, "label" (one label per task, in task order).
        :param prepare_target: when False, only the model inputs are built.

        Returns:
            - dictionary with the expected model inputs.
            - dictionary with the expected target labels (empty when
              prepare_target is False).
        """
        # https://pytorchnlp.readthedocs.io/en/latest/source/torchnlp.utils.html # collate_tensors
        sample = collate_tensors(sample)

        inputs = self.tokenizer.batch_encode_plus(sample["seq"],
                                                  add_special_tokens=True,
                                                  padding=True,
                                                  truncation=True,
                                                  max_length=self.hparam.max_length)

        if not prepare_target:
            return inputs, {}

        # Prepare target:
        try:
            targets = OrderedDict()
            for i,t in enumerate(self.tasks_info_dict.keys()):
                # i and t will corresponded by construction of ordered_dict
                targets[t] = {"labels": self.label_encoders_dict[t].batch_encode(sample["label"][i])}
            return inputs, targets
        except RuntimeError as err:
            print(sample["label"])
            raise Exception("Label encoder found an unknown label.") from err

    def training_step(self, batch: tuple, batch_nb: int, *args, **kwargs) -> dict:
        """
        Runs one training step. This usually consists in the forward function followed
            by the loss function.

        :param batch: The output of your dataloader.
        :param batch_nb: Integer displaying which batch this is
        Returns:
            - dictionary containing the loss and the metrics to be added to the lightning logger.
        """
        inputs, targets = batch
        model_out = self.forward(**inputs)
        loss_train = self.compute_loss(model_out, targets)

        output = OrderedDict(
            {"loss": loss_train})

        for t in self.tasks_info_dict.keys():
            y = targets[t]["labels"]
            y_hat = model_out[t]["logits"]

            labels_hat = torch.argmax(y_hat, dim=1)
            train_acc = self.metric_acc(labels_hat, y)
            output["train_acc_"+ t] = train_acc
            output["train_loss_"+ t] = self.task_losses_dict[t]

        self.log_dict(output, prog_bar=True)
        return output

    def validation_step(self, batch: tuple, batch_nb: int, *args, **kwargs) -> dict:
        """ Similar to the training step, for one validation batch.
        Returns:
            - dictionary with "val_loss" and one "val_acc_<task>" per task (also logged).
        """
        inputs, targets = batch

        model_out = self.forward(**inputs)
        loss_val = self.compute_loss(model_out, targets)

        output = OrderedDict({"val_loss": loss_val})
        for t in self.tasks_info_dict.keys():
            y = targets[t]["labels"]
            y_hat = model_out[t]["logits"]

            labels_hat = torch.argmax(y_hat, dim=1)
            val_acc = self.metric_acc(labels_hat, y)
            output["val_acc_"+ t] = val_acc

        self.log_dict(output)

        return output


    def test_step(self, batch: tuple, batch_nb: int, *args, **kwargs) -> dict:
        """ Similar to the training step, for one test batch.
        Returns:
            - dictionary with "test_loss" and one "test_acc_<task>" per task (also logged).
        """
        inputs, targets = batch
        model_out = self.forward(**inputs)
        loss_test = self.compute_loss(model_out, targets)

        output = OrderedDict({"test_loss": loss_test})
        for t in self.tasks_info_dict.keys():
            y = targets[t]["labels"]
            y_hat = model_out[t]["logits"]

            labels_hat = torch.argmax(y_hat, dim=1)
            test_acc = self.metric_acc(labels_hat, y)
            output["test_acc_"+ t] = test_acc

        self.log_dict(output)
        return output

    def configure_optimizers(self):
        """ Sets different Learning rates for different parameter groups.

        The classification heads use hparam.learning_rate, the encoder uses
        hparam.encoder_learning_rate.
        """
        parameters = [{"params": self.classification_heads_dict[t].parameters()}
                                 for t in self.tasks_info_dict.keys()]

        parameters += [{"params": self.ProtBertBFD.parameters(),
                        "lr": self.hparam.encoder_learning_rate,}]

        optimizer = optim.Adam(parameters, lr=self.hparam.learning_rate)
        return [optimizer], []

    def on_train_epoch_end(self):
        """ Pytorch lightning hook: un-freezes the encoder once nr_frozen_epochs epochs are done. """
        if self.current_epoch + 1 >= self.nr_frozen_epochs:
            self.unfreeze_encoder()
            if self.current_epoch + 1 == self.nr_frozen_epochs:
                print('\n Unfreeze the encoder!! \n')

    def __retrieve_dataset(self, train=True, val=True, test=True):
        """ Retrieves task specific dataset.

        The flags are checked in the order train, val, test and the first one
        that is True decides the split; with all flags False the 'hard' set
        is returned.
        """
        if train:
            return self.dataset.load_dataset('train')#[:64] # in case you want a quick test
        elif val:
            return self.dataset.load_dataset('val')#[:64]
        elif test:
            return self.dataset.load_dataset('test')
        else:
            print('hard dataset')
            return self.dataset.load_dataset('hard')

    def train_dataloader(self) -> DataLoader:
        """ Function that loads the train set (randomly sampled). """
        self._train_dataset = self.__retrieve_dataset(val=False, test=False)
        return DataLoader(
            dataset=self._train_dataset,
            sampler=RandomSampler(self._train_dataset),
            batch_size=self.hparam.batch_size,
            collate_fn=self.prepare_sample,
            num_workers=self.hparam.loader_workers,
        )

    def val_dataloader(self) -> DataLoader:
        """ Function that loads the validation set. """
        self._dev_dataset = self.__retrieve_dataset(train=False, test=False)
        return DataLoader(
            dataset=self._dev_dataset,
            batch_size=self.hparam.batch_size,
            collate_fn=self.prepare_sample,
            num_workers=self.hparam.loader_workers,
        )

    def test_dataloader(self) -> DataLoader:
        """ Function that loads the test set. """
        self._test_dataset = self.__retrieve_dataset(train=False, val=False)
        return DataLoader(
            dataset=self._test_dataset,
            batch_size=self.hparam.batch_size,
            collate_fn=self.prepare_sample,
            num_workers=self.hparam.loader_workers,
        )

    @classmethod
    def add_model_specific_args(
        cls, parser: HyperOptArgumentParser
    ) -> HyperOptArgumentParser:
        """ Parser for Estimator specific arguments/hyperparameters.
        :param parser: HyperOptArgumentParser obj
        Returns:
            - updated parser
        """
        parser.add_argument(
            "--max_length",
            default=MAX_PROTEIN_STR_LEN,
            type=int,
            help="Maximum sequence length.",
        )
        parser.add_argument(
            "--encoder_learning_rate",
            default=ENCODER_LR,
            type=float,
            help="Encoder specific learning rate.",
        )
        parser.add_argument(
            "--learning_rate",
            default=GENERAL_LR,
            type=float,
            help="Classification head learning rate.",
        )
        parser.add_argument(
            "--nr_frozen_epochs",
            default=N_FROZEN_EPOCH,
            type=int,
            help="Number of epochs we want to keep the encoder model frozen.",
            choices=[0, 1, 2, 3, 4, 5],
        )

        # DATA Argument
        parser.add_argument(
            "--train_csv",
            default=FILE_DIR + "deeploc_per_protein_train.csv",
            type=str,
            help="Path to the file containing the train data.",
        )
        parser.add_argument(
            "--dev_csv",
            default=FILE_DIR + "deeploc_per_protein_valid.csv",
            type=str,
            help="Path to the file containing the dev data.",
        )
        parser.add_argument(
            "--test_csv",
            default=FILE_DIR + "deeploc_per_protein_test.csv",
            type=str,
            help="Path to the file containing the test data.",
        )

        parser.add_argument(
            "--loader_workers",
            default=CPU_WORKERS,
            type=int,
            help="How many subprocesses to use for data loading. 0 means that \
                the data will be loaded in the main process.",
        )
        parser.add_argument(
            "--gradient_checkpointing",
            default=GRADIENT_CHECKPOINT,
            # type=bool would parse "False" as True, hence the explicit parser.
            type=cls._str2bool,
            help="Enable or disable gradient checkpointing which use the cpu memory \
                with the gpu memory to store the model.",
        )
        return parser

In [None]:
# these are project-wide arguments
parser = HyperOptArgumentParser(
    description="Minimalist ProtBERT Classifier",
    add_help=True,
)

# Every project-wide option as (flag, add_argument keyword arguments),
# registered below in exactly this order.
_PROJECT_ARGUMENTS = [
    ("--seed", dict(type=int, default=SEED, help="Training seed.")),
    ("--save_top_k", dict(
        default=N_CHECKPOINTS,
        type=int,
        help="The best k models according to the quantity monitored will be saved.",
    )),
    # Early Stopping
    ("--monitor", dict(default=MONITOR_METRIC, type=str, help="Quantity to monitor.")),
    ("--metric_mode", dict(
        default=MONITOR_MODE,
        type=str,
        help="If we want to min/max the monitored quantity.",
        choices=["auto", "min", "max"],
    )),
    ("--patience", dict(
        default=PATIENCE,
        type=int,
        help="Number of epochs with no improvement after which training will be stopped.",
    )),
    ("--min_epochs", dict(
        default=MIN_EPOCHS,
        type=int,
        help="Limits training to a minimum number of epochs",
    )),
    ("--max_epochs", dict(
        default=MAX_EPOCHS,
        type=int,
        help="Limits training to a max number number of epochs",
    )),
    # Batching
    ("--batch_size", dict(default=BATCH_SIZE, type=int, help="Batch size to be used.")),
    ("--accumulate_grad_batches", dict(
        default=ACCUM_BATCH,
        type=int,
        help="Accumulated gradients runs K small batches of size N before doing a backwards pass.",
    )),
    # gpu/tpu args
    ("--gpus", dict(type=int, default=NUM_GPU, help="How many gpus")),
    # mixed precision
    ("--precision", dict(type=int, default=PRECISION, help="full precision or mixed precision mode")),
]
for _flag, _options in _PROJECT_ARGUMENTS:
    parser.add_argument(_flag, **_options)

# each LightningModule defines arguments relevant to it
parser = ProtBertBFDClassifier.add_model_specific_args(parser)
# parse_known_args ignores arguments it does not recognise; only the parsed
# namespace (first element of the returned pair) is kept.
hparam, _unknown_args = parser.parse_known_args()

In [None]:
import gc
# Run a garbage-collection pass, then display the parsed hyperparameters
# (last expression of the cell).
gc.collect()
hparam

In [None]:
"""
Main training routine specific for this project
:param hparam:
"""
# Seed the random number generators from the --seed argument.
seed_everything(hparam.seed)

# ------------------------
# 1 INIT LIGHTNING MODEL
# ------------------------
model = ProtBertBFDClassifier(hparam)

# ------------------------
# 2 INIT EARLY STOPPING
# ------------------------
# Stops training when the monitored quantity has not improved for
# hparam.patience epochs.
early_stop_callback = EarlyStopping(
    monitor=hparam.monitor,
    min_delta=0.0,
    patience=hparam.patience,
    verbose=True,
    mode=hparam.metric_mode,
)

# --------------------------------
# 3 INIT MODEL CHECKPOINT CALLBACK
# -------------------------------
# initialize Model Checkpoint Saver
# The file name only uses quantities the model actually logs: the model logs
# "val_loss" and one "val_acc_<task>" per task, but never a plain "val_acc",
# so a "{val_acc}" placeholder could not be filled with a real value.
# NOTE: the Trainer cell below has its `callbacks=` line commented out, so
# neither callback is active unless that line is re-enabled.
checkpoint_callback = ModelCheckpoint(
    dirpath='./',
    filename="{epoch}-{val_loss:.2f}",
    save_top_k=hparam.save_top_k,
    verbose=True,
    monitor=hparam.monitor,
    mode=hparam.metric_mode,
)



In [None]:
# ------------------------
# 4 INIT TRAINER
# ------------------------

# "ddp" for several GPUs, "dp" otherwise.
parallel_strategy = "ddp" if hparam.gpus > 1 else "dp" # ddp-spwan- tpu

trainer = Trainer(
    gpus=hparam.gpus,
    logger=wandb_logger,  # WandbLogger, or None when wandb_flag is False
    strategy=parallel_strategy,
    max_epochs=hparam.max_epochs,
    min_epochs=hparam.min_epochs,
    accumulate_grad_batches=hparam.accumulate_grad_batches,
    # The checkpoint / early-stopping callbacks built above are NOT passed
    # here; uncomment the next line to activate them.
#     callbacks = [checkpoint_callback, early_stop_callback],
    precision=hparam.precision,
    amp_backend=AMP_BACKEND,  # "native", or "apex" when PRECISION == 16
#     amp_level='O2',#hparam.amp_level, # optional: for Apex backend
    deterministic=True,
    num_sanity_val_steps=32 # check 32 valid-batch before start
)

In [None]:
# ------------------------
# 6 START TRAINING
# ------------------------
# Timing note from the original author: with batch=2, its 13 min/epoch on P100
# Remember to set MAX_EPOCHS to 10-20 epochs to maximize accuracy (needs more GPU hours on Kaggle);
# this notebook trains for MAX_EPOCHS epochs as set in the hyperparameter cell.

import gc
# Run a garbage-collection pass before the training run.
gc.collect()

# The model provides its own train/val dataloaders, so only the model is passed.
trainer.fit(model)

# Test

In [None]:
# Evaluate on the test split (served by the model's test_dataloader).
gc.collect()
trainer.test(model)

In [None]:
# Optionally restore a saved checkpoint first:
# model = model.load_from_checkpoint(best_checkpoint_path, hparam=hparam)
gc.collect()
# Put the model in evaluation mode and freeze it before running predictions.
model.eval()
model.freeze()

In [None]:
# Single-sequence prediction example; residues are separated by spaces.
sample = {
  "seq": "M S T D T G V S L P S Y E E D Q G S K L I R K A K E A P F V P V G I A G F A A I V A Y G L Y K L K S R G N T K M S I H L I H M R V A A Q G F V V G A M T V G M G Y S M Y R E F W A K P K P",
}
# The model is moved to the CPU before predicting.
predictions = model.cpu().predict(sample)

# 'predicted_label' maps every task name to its predicted label.
print("Sequence Localization Ground Truth is: {} - prediction is: {}".format('Mitochondrion',predictions['predicted_label']))

In [None]:
# Second single-sequence prediction example.
sample = {
  "seq": "M R C L P V F I I L L L L I P S A P S V D A Q P T T K D D V P L A S L H D N A K R A L Q M F W N K R D C C P A K L L C C N P",
}

# The model is moved to the CPU before predicting.
predictions = model.cpu().predict(sample)

# 'predicted_label' maps every task name to its predicted label.
print("Sequence Localization Ground Truth is: {} - prediction is: {}".format('Extracellular',predictions['predicted_label']))

In [None]:
# Save the model weights under the key "ProtBertLoc", then read the file back
# and load the weights into the model again (round-trip check of the saved file).
torch.save({"ProtBertLoc": model.state_dict(),
#             "optim": model.optimizers.state_dict(),
            },
            "./ProtBertLoc.pt",
        )
checkpoint = torch.load("./ProtBertLoc.pt")
model.load_state_dict(checkpoint['ProtBertLoc'])

In [None]:
# may also use this save/load command
# ckpt_path='./'
# best_checkpoint_path = glob.glob(ckpt_path + "/*")[0]
# print(best_checkpoint_path)

# trainer.resume_from_checkpoint = best_checkpoint_path