In [46]:
import peft
import requests
from io import BytesIO
import pandas as pd
import torch

import transformers
import evaluate 
import datasets
import requests
import pandas
import sklearn
from datasets import Dataset
from transformers import TrainingArguments, Trainer


In [57]:
transformers.__version__

'4.38.2'

# Play with a small ESM2 checkpoint

First, let's play around with a simple ESM2 checkpoint on a sequence classification problem

In [4]:
model_checkpoint = 'facebook/esm2_t6_8M_UR50D' # This is the smallest of the ESM2 models: 6 layers, 8M params. 
from transformers import AutoTokenizer, AutoModel
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
model = AutoModel.from_pretrained(model_checkpoint)

Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Let's download some data for a protein binary classification problem. In this case, we will attempt to predict whether a protein lives iinside a cell or on its membrane. 

In [5]:
query_url ="https://rest.uniprot.org/uniprotkb/stream?compressed=true&fields=accession%2Csequence%2Ccc_subcellular_location&format=tsv&query=%28%28organism_id%3A9606%29%20AND%20%28reviewed%3Atrue%29%20AND%20%28length%3A%5B80%20TO%20500%5D%29%29"
uniprot_request = requests.get(query_url)
bio = BytesIO(uniprot_request.content)
df = pandas.read_csv(bio, compression='gzip', sep='\t')
df['seq_len'] = list(map(len, df.Sequence))
df = df.dropna()
df.sort_values('seq_len', ascending = False)
df['ind'] = list(df.index)
cytosolic = df['Subcellular location [CC]'].str.contains("Cytosol") | df['Subcellular location [CC]'].str.contains("Cytoplasm")
membrane = df['Subcellular location [CC]'].str.contains("Membrane") | df['Subcellular location [CC]'].str.contains("Cell membrane")
cytosolic_df = df[cytosolic & ~membrane]
cytosolic_df['label'] = 0
membrane_df = df[membrane & ~cytosolic]
membrane_df['label'] = 1
df = pd.concat([cytosolic_df, membrane_df]).sort_values('ind').sample(frac = 1)
df.head()

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  cytosolic_df['label'] = 0
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  membrane_df['label'] = 1


Unnamed: 0,Entry,Sequence,Subcellular location [CC],seq_len,ind,label
10108,Q8NGJ6,MSIINTSYVEITTFFLVGMPGLEYAHIWISIPICSMYLIAILGNGT...,SUBCELLULAR LOCATION: Cell membrane; Multi-pas...,313,10108,1
2827,P56597,MEISMPPPQIYVEKTLAIIKPDIVDKEEEIQDIILRSGFTIVQRRK...,"SUBCELLULAR LOCATION: Cell projection, cilium ...",212,2827,0
5248,Q8WWB3,MESIYLQKHLGACLTQGLAEVARVRPVDPIEYLALWIYKYKENVTM...,"SUBCELLULAR LOCATION: Cytoplasm, cytoskeleton,...",177,5248,0
6922,Q9NS68,MALKVLLEQEKTFFTLLVLLGYLSCKVTCESGDCRQQEFRDRSGNC...,SUBCELLULAR LOCATION: Membrane {ECO:0000305}; ...,423,6922,1
5992,Q99932,METNESTEGSRSRSRSLDIQPSSEGLGPTSEPFPSSDDSPRSALAA...,SUBCELLULAR LOCATION: Cytoplasm {ECO:0000250|U...,485,5992,0


In [6]:
len(df)

5149

Let's try passing a sequence through the pre-trained model

In [7]:
df.Sequence[3286]

'MCLAGCTPRKAAAPGRGALPRARLPRTAPAAATMFQPAAKRGFTIESLVAKDGGTGGGTGGGGAGSHLLAAAASEEPLRPTALNYPHPSAAEAAFVSGFPAAAAAGAGRSLYGGPELVFPEAMNHPALTVHPAHQLGASPLQPPHSFFGAQHRDPLHFYPWVLRNRFFGHRFQASDVPQDGLLLHGPFARKPKRIRTAFSPSQLLRLERAFEKNHYVVGAERKQLAGSLSLSETQVKVWFQNRRTKYKRQKLEEEGPESEQKKKGSHHINRWRIATKQANGEDIDVTSND'

In [8]:
len(df.Sequence[3286])

290

In [9]:
k = 3286
idx = tokenizer(df.Sequence[k], return_tensors = 'pt')
output = model(**idx)
output.last_hidden_state

tensor([[[ 0.1205,  0.5255,  0.1919,  ...,  1.1485,  0.0487, -0.3363],
         [ 0.2743,  0.4252, -0.3717,  ...,  0.8866,  0.0379,  0.0960],
         [-0.3534, -0.2014,  0.1076,  ...,  0.1796,  0.2354,  0.4582],
         ...,
         [-0.0576, -0.3005,  0.1158,  ..., -0.1239, -0.3607, -0.2614],
         [-0.0735, -0.2186,  0.0888,  ...,  0.1062, -0.1766, -0.1515],
         [ 0.0370, -0.1893,  0.1269,  ...,  0.4626, -0.8497, -0.3220]]],
       grad_fn=<NativeLayerNormBackward0>)

In [10]:
output.last_hidden_state.size()

torch.Size([1, 292, 320])

It seems like the number of tokens (292) in the last layer is 2 more than the protein sequence length (290). This might be because the model has added a couple of special tokens like [CLS] and [END]. Let's confirm this by using ESM's own generate embeddings tool and then reading back the embedding below. 

In [11]:
e = torch.load("/home/suhas/research/drug_design/esm/esm/examples/data/some_proteins_emb_esm2/UniRef50_A0SUHASP16.pt")

In [12]:
type(e), e.keys()

(dict, dict_keys(['label', 'representations', 'mean_representations']))

In [13]:
e['representations']

{6: tensor([[ 0.2743,  0.4252, -0.3717,  ...,  0.8866,  0.0379,  0.0960],
         [-0.3534, -0.2014,  0.1076,  ...,  0.1796,  0.2354,  0.4582],
         [ 0.0503, -0.5071,  0.0169,  ...,  0.5133, -0.1251,  0.2827],
         ...,
         [-0.3096, -0.5735, -0.3464,  ...,  0.2152,  0.1514,  0.0554],
         [-0.0576, -0.3005,  0.1158,  ..., -0.1239, -0.3607, -0.2614],
         [-0.0735, -0.2186,  0.0888,  ...,  0.1062, -0.1766, -0.1515]])}

In [14]:
e['representations'][6].size()

torch.Size([290, 320])

Ok so this matches what I get from the last layer above, except for the number of tokens -- so the first and the last token are probably the CLS and SEP tokens added on by ESM. 

# PEFT

4 steps in training a peft/lora model:
1. Instantiate a base model.
2. Create a configuration (LoraConfig) where you define LoRA-specific parameters.
3. Wrap the base model with get_peft_model() to get a trainable PeftModel.
4. Train the PeftModel as you normally would train the base model.

We will use the dataset we downloaded above to fine-tune the pre-trained ESM2 model using LoRA. 

In [15]:
df = df[['Sequence','label']]

In [17]:
df.groupby('label').size()

label
0    2599
1    2550
dtype: int64

In [18]:
from transformers import AutoModelForSequenceClassification
from peft import get_peft_config, get_peft_model, LoraConfig, TaskType
model_checkpoint = 'facebook/esm2_t6_8M_UR50D'

target_modules = []
for layer in range(6):
    for elem in ['query','key','value']:
        target_modules.append("esm.encoder.layer."+str(layer)+".attention.self." + elem)
        
peft_config = LoraConfig(
    task_type=TaskType.SEQ_CLS, 
    inference_mode=False, 
    r=4, 
    lora_alpha=32, 
    lora_dropout=0.1,
    target_modules=target_modules
)

model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint, num_labels = len(set(df.label)))

Some weights of EsmForSequenceClassification were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [19]:
model = get_peft_model(model, peft_config)

In [20]:
model.print_trainable_parameters()

trainable params: 149,442 || all params: 7,990,205 || trainable%: 1.870314966887583


Let us create a training and test dataset from df, and also let us tokenize it

In [21]:
len(df), df.columns

(5149, Index(['Sequence', 'label'], dtype='object'))

In [22]:
sequences = list(df.Sequence)
labels = list(df.label)

# Quick check to make sure we got it right
len(sequences) == len(labels)

True

In [36]:
from sklearn.model_selection import train_test_split

train_sequences, test_sequences, train_labels, test_labels = train_test_split(sequences, labels, test_size=0.25, shuffle=True)

In [37]:
type(train_sequences), type(test_sequences), type(train_labels), type(test_labels)

(list, list, list, list)

In [38]:
len(train_sequences), len(test_sequences), len(train_labels), len(test_labels)

(3861, 1288, 3861, 1288)

In [39]:
train_tokenized = tokenizer(train_sequences)
test_tokenized = tokenizer(test_sequences)

In [40]:
z = tokenizer(train_sequences[0])
type(z), len(z)

(transformers.tokenization_utils_base.BatchEncoding, 2)

In [41]:
len(z['input_ids'])

201

In [42]:
train_dataset = Dataset.from_dict(train_tokenized)
test_dataset = Dataset.from_dict(test_tokenized)

Dataset({
    features: ['input_ids', 'attention_mask'],
    num_rows: 3861
})

In [44]:
train_dataset = train_dataset.add_column("labels", train_labels)
test_dataset = test_dataset.add_column("labels", test_labels)
train_dataset, test_dataset

(Dataset({
     features: ['input_ids', 'attention_mask', 'labels'],
     num_rows: 3861
 }),
 Dataset({
     features: ['input_ids', 'attention_mask', 'labels'],
     num_rows: 1288
 }))

In [52]:
model_name = model_checkpoint.split("/")[1]
model_name

'esm2_t6_8M_UR50D'

In [53]:
batch_size = 8

args = TrainingArguments(
    f"{model_name}-lora-finetuned-localization",
    evaluation_strategy = "epoch",
    save_strategy = "epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=3,
    weight_decay=0.01,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    push_to_hub=False,
)

In [54]:
from evaluate import load
import numpy as np

metric = load("accuracy")

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    return metric.compute(predictions=predictions, references=labels)

In [55]:
trainer = Trainer(
    model,
    args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)


In [56]:
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy
1,No log,0.243445,0.926242
2,0.459200,0.21787,0.932453




KeyboardInterrupt: 