<a href="https://colab.research.google.com/github/agemagician/ProtTrans/blob/master/Fine-Tuning/ProtBert_BFD_FineTuning_MS.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

**2. Load necessry libraries including huggingface transformers**

In [1]:
import torch
from transformers import AutoTokenizer, Trainer, TrainingArguments, AutoModelForSequenceClassification
from torch.utils.data import Dataset
import os
import pandas as pd
import requests
from tqdm.auto import tqdm
import numpy as np
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
import re
np.random.seed(42)

**3. Select the model you want to fine-tune**

In [2]:
model_name = 'Rostlab/prot_bert_bfd'

**4. Create the Membrane dataset class**

In [3]:
class DiseaseIdentificationDataset(Dataset):
    """Face Landmarks dataset."""

    def __init__(self, split="train", tokenizer_name='Rostlab/prot_bert_bfd', max_length=1024):
        """
        Args:
            csv_file (string): Path to the csv file with annotations.
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.datasetFolderPath = '../Data/ProtTrans_Finetune/'
        self.trainFilePath = os.path.join(self.datasetFolderPath, 'healthy_vs_covid_train.csv')
        self.testFilePath = os.path.join(self.datasetFolderPath, 'healthy_vs_covid_test.csv')

        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, do_lower_case=False)

        if split=="train":
          self.seqs, self.labels = self.load_dataset(self.trainFilePath)
        else:
          self.seqs, self.labels = self.load_dataset(self.testFilePath)

        self.max_length = max_length

    def load_dataset(self,path):
        df = pd.read_csv(path,names=['input','label']).dropna()
        # df = df.loc[df['labels'].isin(["0","1", 0, 1])]
        # print(df)
        # self.labels_dic = {0:'Healthy',
        #                    1:'COVID'}

        df['labels'] = np.where(df['label']==1, 1, 0)
        
        seq = list(df['input'])
        label = list(df['labels'])

        assert len(seq) == len(label)
        return seq, label

    def __len__(self):   
        return len(self.labels)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        seq = " ".join("".join(self.seqs[idx].split()))
        seq = re.sub(r"[UZOB]", "X", seq)

        seq_ids = self.tokenizer(seq, truncation=True, padding='max_length', max_length=self.max_length)
        sample = {key: torch.tensor(val) for key, val in seq_ids.items()}
        sample['labels'] = torch.tensor(self.labels[idx])     

        return sample

**4. Create the train / val / test datasets**

In [4]:
train_dataset = DiseaseIdentificationDataset(split="train", tokenizer_name=model_name, max_length=256) # max_length is only capped to speed-up example.
val_dataset = DiseaseIdentificationDataset(split="valid", tokenizer_name=model_name, max_length=256)
test_dataset = DiseaseIdentificationDataset(split="test", tokenizer_name=model_name, max_length=256)

**5. Define the evaluation metrics**

In [5]:
def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='binary')
    acc = accuracy_score(labels, preds)
    return {
        'accuracy': acc,
        'f1': f1,
        'precision': precision,
        'recall': recall
    }

**6. Create the model**

In [6]:
def model_init():
  return AutoModelForSequenceClassification.from_pretrained(model_name)

**7. Define the training args and start the trainer**

In [7]:
training_args = TrainingArguments(
    output_dir='./results',          # output directory
    num_train_epochs=1,              # total number of training epochs
    per_device_train_batch_size=1,   # batch size per device during training
    per_device_eval_batch_size=10,   # batch size for evaluation
    warmup_steps=1000,               # number of warmup steps for learning rate scheduler
    weight_decay=0.01,               # strength of weight decay
    logging_dir='./logs',            # directory for storing logs
    logging_steps=200,               # How often to print logs
    do_train=True,                   # Perform training
    do_eval=True,                    # Perform evaluation
    evaluation_strategy="epoch",     # evalute after each epoch
    gradient_accumulation_steps=64,  # total number of steps before back propagation
    # fp16=True,                       # Use mixed precision
    # fp16_opt_level="02",             # mixed precision mode
    run_name="ProBert-BFD-MS",       # experiment name
    seed=3                           # Seed for experiment reproducibility 3x3
)

trainer = Trainer(
    model_init=model_init,                # the instantiated 🤗 Transformers model to be trained
    args=training_args,                   # training arguments, defined above
    train_dataset=train_dataset,          # training dataset
    eval_dataset=val_dataset,             # evaluation dataset
    compute_metrics = compute_metrics,    # evaluation metrics
)

trainer.train()

loading configuration file config.json from cache at /Users/joseph/.cache/huggingface/hub/models--Rostlab--prot_bert_bfd/snapshots/6c5c8a55a52ff08a664dfd584aa1773f125a0487/config.json
Model config BertConfig {
  "_name_or_path": "Rostlab/prot_bert_bfd",
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.0,
  "classifier_dropout": null,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.0,
  "hidden_size": 1024,
  "initializer_range": 0.02,
  "intermediate_size": 4096,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 40000,
  "model_type": "bert",
  "num_attention_heads": 16,
  "num_hidden_layers": 30,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "transformers_version": "4.24.0",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 30
}

loading weights file pytorch_model.bin from cache at /Users/joseph/.cache/huggingface/hub/models--Rostlab--prot_bert_bfd/snapshots/6c5c8a55a52ff08a664dfd584aa1773f125a0487/pytorch_model.

  0%|          | 0/303 [00:00<?, ?it/s]

**8. Save the model**

In [10]:
trainer.save_model('prot_bert_bfd_huggingface_finetune/')

Saving model checkpoint to models/
Configuration saved in models/config.json
Model weights saved in models/pytorch_model.bin


**9. Check Tensorboard**

In [11]:
%load_ext tensorboard
%tensorboard --logdir logs

<IPython.core.display.Javascript object>

**10. Make predictions**

In [12]:
predictions, label_ids, metrics = trainer.predict(test_dataset)

***** Running Prediction *****
  Num examples = 1020
  Batch size = 10


In [13]:
idx = 0
sample_ground_truth = test_dataset.labels_dic[int(test_dataset[idx]['labels'])]
sample_predictions =  test_dataset.labels_dic[np.argmax(predictions[idx], axis=0)]
sample_sequence = test_dataset.tokenizer.decode(test_dataset[idx]['input_ids'], skip_special_tokens=True)

In [14]:
print("Sequence: {} \nGround Truth is: {}\nprediction is: {}".format(sample_sequence,
                                                                      sample_ground_truth,
                                                                      sample_predictions))

Sequence: M W P L V V V V L L G S A Y C G S A Q L I F N I T K S V E F T V C N T T V T I P C F V N N M E A K N I S E L Y V K W K F K G K D I F I F D G A Q H I S K P S E A F P S S K I S P S E L L H G I A S L K M D K R D A V I G N Y T C E V T E L S R E G E T I I E L K R R F V S W F S P N E N I L I V I F P I L A I L L F W G Q F G I L T L K Y K S S Y T K E K T I F L L V A G L M L T I I V I V G A I L F I P G E Y S T K N A C G L G L I V I P T A I L I L L Q Y C V F M M A L G M S S F T I A I L I L Q V L G H V L S V V G L 
Ground Truth is: Membrane
prediction is: Membrane
