
# Token classification tutorial with Meta's pre-trained ESM2 protein language model

***

Another common task we can perform with pre-trained large language models is **token classification**. Instead of classifying the entire protein sequence into a single class, we assign each token (every amino acid letter) to one or more categories or classes. Models like this are used for:

- Secondary structure prediction,
- Prediction of exposed versus buried residues,
- Prediction of residues likely to undergo post-translational modifications,
- Prediction of residues involved in binding pockets or active sites,
- Prediction of transmembrane protein topology,
- And many more applications.

***
# Libraries installation and model loading

In [None]:
! pip install transformers evaluate datasets requests pandas scikit-learn

In [None]:
from huggingface_hub import notebook_login

notebook_login()

In [None]:
! apt install git-lfs

Reading package lists... Done
Building dependency tree       
Reading state information... Done
git-lfs is already the newest version (2.9.2-1).
0 upgraded, 0 newly installed, 0 to remove and 24 not upgraded.


In [None]:
model_checkpoint = "facebook/esm2_t12_35M_UR50D"

***
# Data preparation/mining

As always, the first step is to gather training data, either from the literature or from a specialized database such as **UniProt**.

As in the sequence classification tutorial, we aim to create two lists: `sequences` and `labels`. This time, however, the labels are more than single integers: the label for each sample is one integer per token in the input. This makes sense, because in token classification different tokens in the same input can have different labels.
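To make the shape of the data concrete, here is a toy example (the sequence and labels are invented for illustration, not taken from UniProt): in sequence classification each sample gets a single integer, while in token classification each sample gets one integer per residue.

```python
# Toy data, invented for illustration only
sequence = "MKTAYIAKQR"

# Sequence classification: one label for the whole protein
sequence_label = 1

# Token classification: one label per residue, same length as the sequence
token_labels = [0, 0, 1, 1, 1, 1, 0, 2, 2, 0]

assert len(token_labels) == len(sequence)
for residue, label in zip(sequence, token_labels):
    print(residue, label)
```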

To demonstrate token classification, we are going to use UniProt to get some training data about protein secondary structure.

In [None]:
import requests

query_url = "https://rest.uniprot.org/uniprotkb/stream?compressed=true&fields=accession%2Csequence%2Cft_strand%2Cft_helix&format=tsv&query=%28%28organism_id%3A9606%29%20AND%20%28reviewed%3Atrue%29%20AND%20%28length%3A%5B80%20TO%20500%5D%29%29"

This time, our UniProt search was `(organism_id:9606) AND (reviewed:true) AND (length:[80 TO 500])`, as in the first example, but instead of the `Subcellular location [CC]` column we request the `Helix` and `Beta strand` columns, since they contain the secondary structure information we want.
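If you want to change the search, it can be easier to build the URL from the readable query than to edit the percent-encoded string by hand. Here is a minimal sketch using the standard library's `urllib.parse.urlencode`; the parameter names and values are simply the ones already present in `query_url`, and `quote_via=quote` is chosen so that spaces become `%20` as in the original URL.

```python
from urllib.parse import quote, urlencode

base_url = "https://rest.uniprot.org/uniprotkb/stream"
params = {
    "compressed": "true",
    "fields": "accession,sequence,ft_strand,ft_helix",
    "format": "tsv",
    "query": "((organism_id:9606) AND (reviewed:true) AND (length:[80 TO 500]))",
}

# quote_via=quote encodes spaces as %20 (the default, quote_plus, would use '+')
built_url = f"{base_url}?{urlencode(params, quote_via=quote)}"
print(built_url)
```

Editing the `query` string (for example, the length range) and rebuilding the URL avoids hand-encoding mistakes.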

In [None]:
uniprot_request = requests.get(query_url)

To load this data into pandas, we wrap it in a `BytesIO` object, which pandas treats like a file. If you downloaded the data as a file, you can skip this step and pass the file path directly to `read_csv`.
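The same pattern works on any gzip-compressed bytes. The sketch below builds a tiny TSV in memory (made-up entries, not real UniProt data), compresses it with `gzip.compress` to mimic the compressed download, and reads it back the same way as the next cell does:

```python
import gzip
from io import BytesIO

import pandas

# Made-up TSV with the same columns as the UniProt download
tsv_text = (
    "Entry\tSequence\tBeta strand\tHelix\n"
    "P00001\tMKTAYIAKQR\t\tHELIX 2..4\n"
    "P00002\tGSHMLEDPVA\t\t\n"
)
compressed_bytes = gzip.compress(tsv_text.encode("utf-8"))

# BytesIO lets pandas treat the in-memory bytes like a file on disk
df_toy = pandas.read_csv(BytesIO(compressed_bytes), compression="gzip", sep="\t")
print(df_toy)
```

Empty fields come back as `NaN`, which is why missing annotations show up as blank cells in the table below.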

In [None]:
from io import BytesIO
import pandas

bio = BytesIO(uniprot_request.content)

df = pandas.read_csv(bio, compression='gzip', sep='\t')
df

Unnamed: 0,Entry,Sequence,Beta strand,Helix
0,A0A0K2S4Q6,MTQRAGAAMLPSALLLLCVPGCLTVSGPSTVMGAVGESLSVQCRYE...,,
1,A0A5B9,DLKNVFPPKVAVFEPSEAEISHTQKATLVCLATGFYPDHVELSWWV...,"STRAND 9..14; /evidence=""ECO:0007829|PDB:4UDT""...","HELIX 2..4; /evidence=""ECO:0007829|PDB:4UDT""; ..."
2,A0AVI4,MDSPEVTFTLAYLVFAVCFVFTPNEFHAAGLTVQNLLSGWLGSEDA...,,
3,A0JLT2,MENFTALFGAQADPPPPPTALGFGPGKPPPPPPPPAGGGPGTAPPP...,"STRAND 79..81; /evidence=""ECO:0007829|PDB:7EMF""","HELIX 83..86; /evidence=""ECO:0007829|PDB:7EMF""..."
4,A0M8Q6,GQPKAAPSVTLFPPSSEELQANKATLVCLVSDFNPGAVTVAWKADG...,,
...,...,...,...,...
11985,Q9NZ38,MAFPGQSDTKMQWPEVPALPLLSSLCMAMVRKSSALGKEVGRRSEG...,,
11986,Q9UFV3,MAETYRRSRQHEQLPGQRHMDLLTGYSKLIQSRLKLLLHLGSQPPV...,,
11987,Q9Y6C7,MAHHSLNTFYIWHNNVLHTHLVFFLPHLLNQPFSRGSFLIWLLLCW...,,
11988,X6R8D5,MGRKEHESPSQPHMCGWEDSQKPSVPSHGPKTPSCKGVKAPHSSRP...,,


Since not all proteins have this structural information, we discard proteins that have no annotated beta strands or alpha helices.

In [None]:
no_structure_rows = df["Beta strand"].isna() & df["Helix"].isna()
df = df[~no_structure_rows]
df

Unnamed: 0,Entry,Sequence,Beta strand,Helix
1,A0A5B9,DLKNVFPPKVAVFEPSEAEISHTQKATLVCLATGFYPDHVELSWWV...,"STRAND 9..14; /evidence=""ECO:0007829|PDB:4UDT""...","HELIX 2..4; /evidence=""ECO:0007829|PDB:4UDT""; ..."
3,A0JLT2,MENFTALFGAQADPPPPPTALGFGPGKPPPPPPPPAGGGPGTAPPP...,"STRAND 79..81; /evidence=""ECO:0007829|PDB:7EMF""","HELIX 83..86; /evidence=""ECO:0007829|PDB:7EMF""..."
14,A1L3X0,MAFSDLTSRTVHLYDNWIKDADPRVEDWLLMSSPLPQTILLGFYVY...,"STRAND 97..99; /evidence=""ECO:0007829|PDB:6Y7F""","HELIX 17..20; /evidence=""ECO:0007829|PDB:6Y7F""..."
16,A1Z1Q3,MYPSNKKKKVWREEKERLLKMTLEERRKEYLRDYIPLNSILSWKEE...,"STRAND 71..77; /evidence=""ECO:0007829|PDB:4IQY...","HELIX 11..19; /evidence=""ECO:0007829|PDB:4IQY""..."
20,A2RUC4,MAGQHLPVPRLEGVSREQFMQHLYPQRKPLVLEGIDLGPCTSKWTV...,"STRAND 10..13; /evidence=""ECO:0007829|PDB:3AL5...","HELIX 16..22; /evidence=""ECO:0007829|PDB:3AL5""..."
...,...,...,...,...
11564,Q96I45,MVNLGLSRVDDAVAAKHPGLGEYAACQSHAFMKGVFTFVTGTGMAF...,"STRAND 3..5; /evidence=""ECO:0007829|PDB:2LOR"";...","HELIX 6..16; /evidence=""ECO:0007829|PDB:2LOR"";..."
11630,Q9H0W7,MPTNCAAAGCATTYNKHINISFHRFPLDPKRRKEWVRLVRRKNFVP...,"STRAND 7..9; /evidence=""ECO:0007829|PDB:2D8R"";...","HELIX 29..38; /evidence=""ECO:0007829|PDB:2D8R"""
11674,Q9P1F3,MNVDHEVNLLVEEIHRLGSKNADGKLSVKFGVLFRDDKCANLFEAL...,"STRAND 24..29; /evidence=""ECO:0007829|PDB:2L2O...","HELIX 3..17; /evidence=""ECO:0007829|PDB:2L2O"";..."
11676,Q9P298,MSANRRWWVPPDDEDCVSEKLLRKTRESPLVPIGLGGCLVVAAYRI...,"STRAND 11..14; /evidence=""ECO:0007829|PDB:2LON...","HELIX 18..24; /evidence=""ECO:0007829|PDB:2LON""..."


That works, but the data still isn't in a clean format that we can use to build our labels. Let's look at one sample to see exactly what we're dealing with:

In [None]:
df.iloc[0]["Helix"]

'HELIX 2..4; /evidence="ECO:0007829|PDB:4UDT"; HELIX 17..23; /evidence="ECO:0007829|PDB:4UDT"; HELIX 83..86; /evidence="ECO:0007829|PDB:4UDT"'

We'll use a [regex](https://docs.python.org/3/howto/regex.html) to pull out each segment marked as a STRAND or HELIX. We're asking for every place where the word STRAND or HELIX is followed by two numbers separated by two dots (and then a semicolon). Wherever this pattern is found, the regex extracts the two numbers as a tuple.

In [None]:
import re

strand_re = r"STRAND\s(\d+)\.\.(\d+)\;"
helix_re = r"HELIX\s(\d+)\.\.(\d+)\;"

re.findall(helix_re, df.iloc[0]["Helix"])

[('2', '4'), ('17', '23'), ('83', '86')]

Looks good! We can use this to build our training data. Recall that the **labels** need to be a list or array of integers that's the same length as the input sequence. We're going to use 0 to indicate residues without any annotated structure, 1 for residues in an alpha helix, and 2 for residues in a beta strand. To build that, we'll start with an array of all 0s, and then fill in values based on the positions that our regex pulls out of the UniProt results.

We'll use NumPy arrays rather than lists here, since these allow [slice assignment](https://numpy.org/doc/stable/user/basics.indexing.html#assigning-values-to-indexed-arrays), which will be a lot simpler than editing a list of integers. Note also that UniProt annotates residues starting from 1 (unlike Python, which starts from 0), and region annotations are inclusive (so 1..3 means residues 1, 2 and 3). To turn these into Python slices, we subtract 1 from the start of each annotation, but not the end.
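As a worked example of that conversion, take the annotation `HELIX 2..4` from the sample above, applied to a made-up 8-residue sequence: UniProt residues 2, 3 and 4 are Python indices 1, 2 and 3, which is the slice `[1:4]`.

```python
import numpy as np

labels = np.zeros(8, dtype=np.int64)

start, end = "2", "4"      # as returned by re.findall for "HELIX 2..4"
py_start = int(start) - 1  # 1-based -> 0-based: subtract 1 from the start
py_end = int(end)          # inclusive UniProt end == exclusive Python end: keep as-is
labels[py_start:py_end] = 1

print(labels.tolist())  # [0, 1, 1, 1, 0, 0, 0, 0]
```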

In [None]:
import numpy as np

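def build_labels(sequence, strands, helices):
    """Return one label per residue: 0 = no annotated structure, 1 = helix, 2 = strand.

    Strands are written after helices, so a residue annotated as both ends up labelled 2.
    """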
    # Start with all 0s
    labels = np.zeros(len(sequence), dtype=np.int64)
    
    if isinstance(helices, float): # Indicates missing (NaN)
        found_helices = []
    else:
        found_helices = re.findall(helix_re, helices)
    for helix_start, helix_end in found_helices:
        helix_start = int(helix_start) - 1
        helix_end = int(helix_end)
        assert helix_end <= len(sequence)
        labels[helix_start: helix_end] = 1  # Helix category
    
    if isinstance(strands, float): # Indicates missing (NaN)
        found_strands = []
    else:
        found_strands = re.findall(strand_re, strands)
    for strand_start, strand_end in found_strands:
        strand_start = int(strand_start) - 1
        strand_end = int(strand_end)
        assert strand_end <= len(sequence)
        labels[strand_start: strand_end] = 2  # Strand category
    return labels

Now that we've defined a helper function, let's build our lists of sequences and labels:

In [None]:
sequences = []
labels = []

for row_idx, row in df.iterrows():
    row_labels = build_labels(row["Sequence"], row["Beta strand"], row["Helix"])
    sequences.append(row["Sequence"])
    labels.append(row_labels)

***
# Creating the dataset

Nice! Now we'll split and tokenize the data, and then create datasets. I'll go through this quickly, since it follows essentially the same steps as the sequence classification tutorial.

In [None]:
from sklearn.model_selection import train_test_split

train_sequences, test_sequences, train_labels, test_labels = train_test_split(sequences, labels, test_size=0.25, shuffle=True)
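`train_test_split` shuffles differently on every run unless you fix its `random_state`. If you want a reproducible split, you could pass one; here is a small sketch on made-up lists (the value 42 is arbitrary):

```python
from sklearn.model_selection import train_test_split

toy_sequences = [f"SEQ{i}" for i in range(8)]
toy_labels = [[i] for i in range(8)]

# Two calls with the same random_state produce exactly the same split
split_a = train_test_split(toy_sequences, toy_labels, test_size=0.25, shuffle=True, random_state=42)
split_b = train_test_split(toy_sequences, toy_labels, test_size=0.25, shuffle=True, random_state=42)
assert split_a == split_b

train_seqs, test_seqs, train_labs, test_labs = split_a
print(len(train_seqs), len(test_seqs))  # 6 2
```

Note that sequences and labels are shuffled together, so each label list stays paired with its sequence.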

In [None]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

train_tokenized = tokenizer(train_sequences)
test_tokenized = tokenizer(test_sequences)

In [None]:
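from datasets import Dataset

train_dataset = Dataset.from_dict(train_tokenized)
test_dataset = Dataset.from_dict(test_tokenized)

# The ESM tokenizer adds a <cls> token before and an <eos> token after every sequence,
# so we pad each label array with -100 (ignored by the loss) at both ends. This keeps
# each residue's label aligned with that residue's token.
train_labels = [np.pad(lab, (1, 1), constant_values=-100) for lab in train_labels]
test_labels = [np.pad(lab, (1, 1), constant_values=-100) for lab in test_labels]

train_dataset = train_dataset.add_column("labels", train_labels)
test_dataset = test_dataset.add_column("labels", test_labels)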

***
# Model loading

The key difference from the sequence classification tutorial is that we use `TFAutoModelForTokenClassification` instead of `TFAutoModelForSequenceClassification`. We also need a `data_collator` this time, because in this slightly more complex case both the inputs and the labels must be padded in each batch.

In [None]:
from transformers import TFAutoModelForTokenClassification

num_labels = 3
model = TFAutoModelForTokenClassification.from_pretrained(model_checkpoint, num_labels=num_labels)

Some layers from the model checkpoint at facebook/esm2_t12_35M_UR50D were not used when initializing TFEsmForTokenClassification: ['lm_head']
- This IS expected if you are initializing TFEsmForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing TFEsmForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some layers of TFEsmForTokenClassification were not initialized from the model checkpoint at facebook/esm2_t12_35M_UR50D and are newly initialized: ['dropout_36', 'classifier']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
from transformers import DataCollatorForTokenClassification

data_collator = DataCollatorForTokenClassification(tokenizer, return_tensors="np")

Now we create our `tf.data.Dataset` objects as before. Remember to pass the data collator, though! When you pass a data collator, there's no need to pass your tokenizer, because the data collator handles padding for us.
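Roughly speaking, for each batch the collator pads `input_ids` to the longest sequence in the batch and then right-pads `labels` with `-100` to that same length, so the loss ignores the padded positions. The sketch below mimics only the label-padding step with NumPy; `pad_labels` is a hypothetical helper for illustration, not the library's actual code, and the toy lengths are made up.

```python
import numpy as np


def pad_labels(label_arrays, sequence_length, pad_id=-100):
    """Right-pad each label array with pad_id up to sequence_length (hypothetical helper)."""
    batch = np.full((len(label_arrays), sequence_length), pad_id, dtype=np.int64)
    for row, labels in enumerate(label_arrays):
        batch[row, : len(labels)] = labels
    return batch


# Two examples of different lengths; assume the padded input_ids are 6 tokens long
toy_labels = [np.array([0, 1, 1, 2]), np.array([2, 2])]
padded = pad_labels(toy_labels, sequence_length=6)
print(padded.tolist())  # [[0, 1, 1, 2, -100, -100], [2, 2, -100, -100, -100, -100]]
```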

In [None]:
tf_train_set = model.prepare_tf_dataset(
    train_dataset,
    batch_size=8,
    shuffle=True,
    collate_fn=data_collator
)

tf_test_set = model.prepare_tf_dataset(
    test_dataset,
    batch_size=8,
    shuffle=False,
    collate_fn=data_collator
)

Our metric is a bit more complex than in the sequence classification task, because we need to ignore padding tokens (those whose label is `-100`). This means we need our own metric function that computes accuracy only on non-padding tokens.
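To check the logic of the metric function below by hand, here is the same computation in NumPy on a made-up batch containing one sequence with four positions, the last of which is padding (`-100`):

```python
import numpy as np

y_true = np.array([[0, 1, 2, -100]])
# Logits for 3 classes at each position (made-up numbers)
y_pred = np.array([[[2.0, 0.1, 0.1],    # predicts 0 (correct)
                    [0.1, 3.0, 0.2],    # predicts 1 (correct)
                    [0.3, 1.5, 0.2],    # predicts 1 (wrong, true label is 2)
                    [0.0, 0.0, 9.0]]])  # padding position, ignored

predictions = y_pred.argmax(axis=-1)  # highest logit = predicted category
mask = y_true != -100                 # True only for real (non-padding) tokens
accuracy = np.sum((predictions == y_true) & mask) / np.sum(mask)
print(accuracy)  # 2 correct out of 3 non-padding tokens -> 0.666...
```

Without the mask, the padding position would be counted in the denominator and drag the accuracy down even though the model is never asked to predict it.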

In [None]:
from transformers import AdamWeightDecay
import tensorflow as tf

def masked_accuracy(y_true, y_pred):
    predictions = tf.math.argmax(y_pred, axis=-1)  # Highest logit corresponds to predicted category
    numerator = tf.math.count_nonzero((predictions == tf.cast(y_true, predictions.dtype)) & (y_true != -100), dtype=tf.float32)
    denominator = tf.math.count_nonzero(y_true != -100, dtype=tf.float32)
    return numerator / denominator

model.compile(optimizer=AdamWeightDecay(2e-5), metrics=[masked_accuracy])

No loss specified in compile() - the model's internal loss computation will be used as the loss. Don't panic - this is a common way to train TensorFlow models in Transformers! To disable this behaviour please pass a loss argument, or explicitly pass `loss=None` if you do not want your model to compute a loss.


And now we're ready to train our model! 

In [None]:
model.fit(tf_train_set, validation_data=tf_test_set, epochs=1)


  1/370 [..............................] - ETA: 2:06:44 - loss: 0.5335 - masked_accuracy: 0.7683

KeyboardInterrupt: ignored

This task is clearly harder than the first one (classifying the whole protein), but the model still reaches a very respectable accuracy. Remember that, to keep this demo lightweight, we used one of the smallest ESM models, focused on human proteins only, and didn't put much work into making sure the training set only contained completely annotated proteins. With a bigger model and a cleaner, broader training set, accuracy on this task could likely go a lot higher!

Now, let's push this model to the hub as we did before, while also setting the category labels appropriately.

In [None]:
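# Store the label names on the config so they are saved and pushed with the model
model.config.label2id = {"unstructured": 0, "helix": 1, "strand": 2}
model.config.id2label = {val: key for key, val in model.config.label2id.items()}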

model_name = model_checkpoint.split('/')[-1]
finetuned_model_name = f"{model_name}-finetuned-secondary-structure-classification"

model.push_to_hub(finetuned_model_name)
tokenizer.push_to_hub(finetuned_model_name)

If you used the code above, you can now share this model with all your friends: they can load it with the identifier `"your-username/the-name-you-picked"`, for instance:

In [None]:
from transformers import TFAutoModelForTokenClassification

model = TFAutoModelForTokenClassification.from_pretrained("your-username/my-awesome-model")
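Once loaded, the model returns one row of logits per token, including the `<cls>` and `<eos>` tokens that the ESM tokenizer adds around each sequence. Here is a minimal sketch of turning such logits into a per-residue string. It uses made-up logits in place of a real model call and the same label mapping as above; the one-letter display codes are an arbitrary choice for this illustration.

```python
import numpy as np

id2label = {0: "unstructured", 1: "helix", 2: "strand"}
symbols = {"unstructured": "-", "helix": "H", "strand": "E"}  # display codes (arbitrary)

sequence = "MKTAY"
# Made-up logits of shape (len(sequence) + 2, 3): <cls>, 5 residues, <eos>
logits = np.array([
    [5.0, 0.0, 0.0],  # <cls>
    [3.0, 0.1, 0.2],
    [0.1, 2.0, 0.3],
    [0.2, 2.5, 0.1],
    [0.1, 0.2, 1.9],
    [1.0, 0.5, 0.2],
    [5.0, 0.0, 0.0],  # <eos>
])

predicted_ids = logits.argmax(axis=-1)[1:-1]  # drop the <cls> and <eos> positions
structure = "".join(symbols[id2label[int(i)]] for i in predicted_ids)
print(sequence)
print(structure)  # -HHE-
```

With a real model, the logits would come from the model output for a single tokenized sequence, and the same slicing removes the two special tokens before mapping IDs to labels.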