In [1]:
!pip install transformers
!pip install transformers[torch]
!pip install accelerate --upgrade



In [1]:
import os
import pandas as pd
import numpy as np
import torch
import re

from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from torch.utils.data import Dataset, DataLoader

torch.cuda.is_available()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

from transformers import Trainer, TrainingArguments, AdamW, BertTokenizer, BertForMaskedLM, AutoTokenizer, AutoModelForMaskedLM

cuda


In [2]:
class amp_data():
    def __init__(self, df, tokenizer_name='Rostlab/prot_bert_bfd', max_len=200):
        # Initialize the tokenizer and maximum sequence length
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, do_lower_case=False)
        self.max_len = max_len
        # Extract the sequences from the dataframe
        self.seqs = list(df['aa_seq'])

    def __len__(self):
        # Return the number of sequences in the dataset
        return len(self.seqs)

    def __getitem__(self, idx):
        # Get the sequence at the given index
        seq = " ".join("".join(self.seqs[idx].split()))
        # Tokenize the sequence and pad/truncate it to the maximum length
        seq_ids = self.tokenizer(seq, truncation=True, padding='max_length', max_length=self.max_len, return_tensors='pt')

        # Extract the input_ids and attention_mask from the tokenized sequence
        input_ids = seq_ids.input_ids[0]
        attention_mask = seq_ids.attention_mask[0]
        # Create a copy of the input_ids to use as labels for masked language modeling
        masked_lm_labels = input_ids.clone()

        # Create a random mask with the same shape as the input_ids
        probability_mask = torch.rand(input_ids.shape)
        # Select 15% of the tokens at random
        mask_indices = (probability_mask < 0.15) & attention_mask.bool()
        # Set the labels for all unmasked tokens to -100 so that they are ignored during training
        masked_lm_labels[~mask_indices] = -100
        # Find the indices of the masked tokens
        masked_indices = torch.nonzero(mask_indices, as_tuple=False)
        # Create a random mask for the masked tokens
        probability_mask[mask_indices] = torch.rand(sum(mask_indices))
        # Replace the masked tokens with either the [MASK] token, a random token from the vocabulary, or leave them unchanged with probabilities 0.8, 0.1, and 0.1 respectively
        input_ids[mask_indices] = torch.where(probability_mask[mask_indices] < 0.8, self.tokenizer.mask_token_id, torch.where(probability_mask[mask_indices] < 0.9, torch.randint(len(self.tokenizer), (sum(mask_indices),)), input_ids[mask_indices]))

        # Return a dictionary containing the input_ids, attention_mask, and labels for this sequence
        sample = {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': masked_lm_labels,
        }

        return sample

In [3]:
# read in the train dataset
# create an amp_data class of the dataset

data_url = 'https://raw.githubusercontent.com/GIST-CSBL/AMP-BERT/main/all_veltri.csv'
df = pd.read_csv(data_url, index_col = 0)
df = df[df["AMP"] == True]
df = df.sample(frac=1, random_state = 0)
print(df.head(7))

train_dataset = amp_data(df)

                                                    aa_seq  aa_len   AMP
AP01472                                      QLYENKPRRPYIL      13  True
AP00605                                     ILPILSLIGGLLGK      14  True
AP01892                           IPWKLPATFRPVERPFSKPFCRKD      24  True
AP02733  LFGSVKAWFKGAKKGFQDYRYQKDMAKMNKRYGPNWQQRGGQEPPA...      55  True
AP00045           QGVRSYLSCWGNRGICLLNRCPGRMRQIGTCLAPRVKCCR      40  True
AP00108                                      FLPFLATLLSKVL      13  True
AP01206                               CTFTLPGGGGVCTLTSECIC      20  True


In [4]:
# Replace the model and tokenizer with the BERT model and tokenizer for MLM

def model_init():
    return BertForMaskedLM.from_pretrained('Rostlab/prot_bert_bfd')

tokenizer = BertTokenizer.from_pretrained('Rostlab/prot_bert_bfd')


In [None]:
# Modify the TrainingArguments for MLM
training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=7,
    learning_rate=5e-5,
    per_device_train_batch_size=1,
    warmup_steps=0,
    weight_decay=0.1,
    logging_dir='./logs',
    logging_steps=10,
    do_train=True,
    do_eval=True,
    evaluation_strategy="no",
    save_strategy='no',
    gradient_accumulation_steps=64,
    fp16=True,
    fp16_opt_level="O2",
    run_name="BERT-MLM",
    seed=0,
    load_best_model_at_end=True,
)

# Use Trainer with model for masked language modeling
from transformers import DataCollatorForLanguageModeling

trainer = Trainer(
    model_init=model_init,
    args=training_args,
    data_collator=DataCollatorForLanguageModeling(tokenizer),
    train_dataset=train_dataset,
)

trainer.train()

Some weights of the model checkpoint at Rostlab/prot_bert_bfd were not used when initializing BertForMaskedLM: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at Rostlab/prot_bert_bfd were not used when initializing BertForMaskedLM: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a

Step,Training Loss
10,1653.8639
20,116.5264
30,149.54
40,115.9875
50,75.9254
60,62.3669
70,85.5157
80,66.407
90,150.6399
100,18.3741


In [None]:
# save the model, if desired

#from google.colab import drive
#drive.mount('/content/drive')

trainer.save_model('/content/results/Colab Notebooks/AMP-BERT/Fine-tuned_model/')

In [None]:
# predict AMP/non-AMP for a single example

# IMPORTANT:
# one must mount their Google Drive and load their own fine-tuned model before running the below cell for individual predictions
#from google.colab import drive
#drive.mount('/content/drive')

from transformers import pipeline

# load appropriate tokenizer and fine-tuned model
tokenizer = BertTokenizer.from_pretrained('Rostlab/prot_bert_bfd', do_lower_case=False)
model = BertForMaskedLM.from_pretrained("/content/results/Colab Notebooks/AMP-BERT/Fine-tuned_model/")
unmasker = pipeline('fill-mask', model=model, tokenizer=tokenizer)

In [8]:
from random import randint
import json

#@markdown **Input peptide sequence (upper case only)**
input_seq = 'SLGPAIKATRQVCPKATRFVTVSCKKSDCQ' #@param {type:"string"}

input_seq_spaced = ' '.join([ input_seq[i:i+1] for i in range(0, len(input_seq), 1) ])

tokens = input_seq_spaced.split()
random_index = randint(0, len(tokens) - 1)
tokens[random_index] = '[MASK]'

# Join the words back into a string with spaces
masked_seq_spaced = " ".join(tokens)
print('Original input: ' + input_seq_spaced)
print('Masked input:   ' + masked_seq_spaced)
print(json.dumps(unmasker(masked_seq_spaced), indent=2))

Original input: S L G P A I K A T R Q V C P K A T R F V T V S C K K S D C Q
Masked input:   S L G P A I K A T R Q V C P K A T R F V T V S C K K S [MASK] C Q
[
  {
    "score": 0.14264661073684692,
    "token": 12,
    "token_str": "K",
    "sequence": "S L G P A I K A T R Q V C P K A T R F V T V S C K K S K C Q"
  },
  {
    "score": 0.0845257043838501,
    "token": 13,
    "token_str": "R",
    "sequence": "S L G P A I K A T R Q V C P K A T R F V T V S C K K S R C Q"
  },
  {
    "score": 0.08332563936710358,
    "token": 10,
    "token_str": "S",
    "sequence": "S L G P A I K A T R Q V C P K A T R F V T V S C K K S S C Q"
  },
  {
    "score": 0.08325883746147156,
    "token": 7,
    "token_str": "G",
    "sequence": "S L G P A I K A T R Q V C P K A T R F V T V S C K K S G C Q"
  },
  {
    "score": 0.06272854655981064,
    "token": 8,
    "token_str": "V",
    "sequence": "S L G P A I K A T R Q V C P K A T R F V T V S C K K S V C Q"
  }
]
