<a href="https://colab.research.google.com/github/ayyucedemirbas/Protein_Language_Models/blob/main/amino_acid_token_classification.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install datasets evaluate

In [2]:
import requests
from io import BytesIO
import pandas
from sklearn.model_selection import train_test_split
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer, DataCollatorForTokenClassification
from datasets import Dataset
from evaluate import load
import numpy as np
import re

In [3]:
query_url ="https://rest.uniprot.org/uniprotkb/stream?compressed=true&fields=accession%2Csequence%2Cft_strand%2Cft_helix&format=tsv&query=%28%28organism_id%3A9606%29%20AND%20%28reviewed%3Atrue%29%20AND%20%28length%3A%5B80%20TO%20500%5D%29%29"

In [4]:
uniprot_request = requests.get(query_url)

In [5]:
bio = BytesIO(uniprot_request.content)

df = pandas.read_csv(bio, compression='gzip', sep='\t')
df

Unnamed: 0,Entry,Sequence,Beta strand,Helix
0,A0A0K2S4Q6,MTQRAGAAMLPSALLLLCVPGCLTVSGPSTVMGAVGESLSVQCRYE...,,
1,A0AVI4,MDSPEVTFTLAYLVFAVCFVFTPNEFHAAGLTVQNLLSGWLGSEDA...,,
2,A0JLT2,MENFTALFGAQADPPPPPTALGFGPGKPPPPPPPPAGGGPGTAPPP...,"STRAND 79..81; /evidence=""ECO:0007829|PDB:7EMF""","HELIX 83..86; /evidence=""ECO:0007829|PDB:7EMF""..."
3,A0M8Q6,GQPKAAPSVTLFPPSSEELQANKATLVCLVSDFNPGAVTVAWKADG...,,
4,A0PJY2,MDSSCHNATTKMLATAPARGNMMSTSKPLAFSIERIMARTPEPKAL...,,
...,...,...,...,...
11976,Q9H8V8,MKPDWPRRGAAGTRVRSRGEGDGTYFARRGAGRRRREIKAPIRAAW...,,
11977,Q9HAA7,MLFGIRILVNTPSPLVTGLHHYNPSIHRDQGECANQWRKGPGSAHL...,,
11978,Q9NZ38,MAFPGQSDTKMQWPEVPALPLLSSLCMAMVRKSSALGKEVGRRSEG...,,
11979,Q9UFV3,MAETYRRSRQHEQLPGQRHMDLLTGYSKLIQSRLKLLLHLGSQPPV...,,


In [6]:
no_structure_rows = df["Beta strand"].isna() & df["Helix"].isna()
df = df[~no_structure_rows]
df

Unnamed: 0,Entry,Sequence,Beta strand,Helix
2,A0JLT2,MENFTALFGAQADPPPPPTALGFGPGKPPPPPPPPAGGGPGTAPPP...,"STRAND 79..81; /evidence=""ECO:0007829|PDB:7EMF""","HELIX 83..86; /evidence=""ECO:0007829|PDB:7EMF""..."
13,A1L3X0,MAFSDLTSRTVHLYDNWIKDADPRVEDWLLMSSPLPQTILLGFYVY...,"STRAND 97..99; /evidence=""ECO:0007829|PDB:6Y7F""","HELIX 17..20; /evidence=""ECO:0007829|PDB:6Y7F""..."
14,A1XBS5,MMRRTLENRNAQTKQLQTAVSNVEKHFGELCQIFAAYVRKTARLRD...,,"HELIX 2..6; /evidence=""ECO:0007829|PDB:8CEG""; ..."
15,A1Z1Q3,MYPSNKKKKVWREEKERLLKMTLEERRKEYLRDYIPLNSILSWKEE...,"STRAND 71..77; /evidence=""ECO:0007829|PDB:4IQY...","HELIX 11..19; /evidence=""ECO:0007829|PDB:4IQY""..."
19,A2RUC4,MAGQHLPVPRLEGVSREQFMQHLYPQRKPLVLEGIDLGPCTSKWTV...,"STRAND 10..13; /evidence=""ECO:0007829|PDB:3AL5...","HELIX 16..22; /evidence=""ECO:0007829|PDB:3AL5""..."
...,...,...,...,...
11592,Q96I45,MVNLGLSRVDDAVAAKHPGLGEYAACQSHAFMKGVFTFVTGTGMAF...,"STRAND 3..5; /evidence=""ECO:0007829|PDB:2LOR"";...","HELIX 6..16; /evidence=""ECO:0007829|PDB:2LOR"";..."
11650,Q9H0W7,MPTNCAAAGCATTYNKHINISFHRFPLDPKRRKEWVRLVRRKNFVP...,"STRAND 7..9; /evidence=""ECO:0007829|PDB:2D8R"";...","HELIX 29..38; /evidence=""ECO:0007829|PDB:2D8R"""
11689,Q9P1F3,MNVDHEVNLLVEEIHRLGSKNADGKLSVKFGVLFRDDKCANLFEAL...,"STRAND 24..29; /evidence=""ECO:0007829|PDB:2L2O...","HELIX 3..17; /evidence=""ECO:0007829|PDB:2L2O"";..."
11691,Q9P298,MSANRRWWVPPDDEDCVSEKLLRKTRESPLVPIGLGGCLVVAAYRI...,"STRAND 11..14; /evidence=""ECO:0007829|PDB:2LON...","HELIX 18..24; /evidence=""ECO:0007829|PDB:2LON""..."


In [7]:
df.iloc[0]["Helix"]

'HELIX 83..86; /evidence="ECO:0007829|PDB:7EMF"; HELIX 90..96; /evidence="ECO:0007829|PDB:7EMF"; HELIX 112..116; /evidence="ECO:0007829|PDB:7EMF"; HELIX 128..138; /evidence="ECO:0007829|PDB:7EMF"; HELIX 147..152; /evidence="ECO:0007829|PDB:7EMF"'

In [8]:
strand_re = r"STRAND\s(\d+)\.\.(\d+)\;"
helix_re = r"HELIX\s(\d+)\.\.(\d+)\;"

re.findall(helix_re, df.iloc[0]["Helix"])

[('83', '86'), ('90', '96'), ('112', '116'), ('128', '138'), ('147', '152')]

In [9]:
def build_labels(sequence, strands, helices):
    # Start with all 0s
    labels = np.zeros(len(sequence), dtype=np.int64)

    if isinstance(helices, float): # Indicates missing (NaN)
        found_helices = []
    else:
        found_helices = re.findall(helix_re, helices)
    for helix_start, helix_end in found_helices:
        helix_start = int(helix_start) - 1
        helix_end = int(helix_end)
        assert helix_end <= len(sequence)
        labels[helix_start: helix_end] = 1  # Helix category

    if isinstance(strands, float): # Indicates missing (NaN)
        found_strands = []
    else:
        found_strands = re.findall(strand_re, strands)
    for strand_start, strand_end in found_strands:
        strand_start = int(strand_start) - 1
        strand_end = int(strand_end)
        assert strand_end <= len(sequence)
        labels[strand_start: strand_end] = 2  # Strand category
    return labels

In [10]:
sequences = []
labels = []

for row_idx, row in df.iterrows():
    row_labels = build_labels(row["Sequence"], row["Beta strand"], row["Helix"])
    sequences.append(row["Sequence"])
    labels.append(row_labels)

In [11]:
train_sequences, test_sequences, train_labels, test_labels = train_test_split(sequences, labels, test_size=0.25, shuffle=True)

In [13]:
model_checkpoint = "facebook/esm2_t12_35M_UR50D"

In [14]:
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

train_tokenized = tokenizer(train_sequences)
test_tokenized = tokenizer(test_sequences)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/95.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/93.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]

In [15]:
train_dataset = Dataset.from_dict(train_tokenized)
test_dataset = Dataset.from_dict(test_tokenized)

train_dataset = train_dataset.add_column("labels", train_labels)
test_dataset = test_dataset.add_column("labels", test_labels)

  block_group = [InMemoryTable(cls._concat_blocks(list(block_group), axis=axis))]
  table = cls._concat_blocks(blocks, axis=0)


In [17]:
from transformers import AutoModelForTokenClassification

In [18]:
num_labels = 3
model = AutoModelForTokenClassification.from_pretrained(model_checkpoint, num_labels=num_labels)

config.json:   0%|          | 0.00/778 [00:00<?, ?B/s]

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


model.safetensors:   0%|          | 0.00/136M [00:00<?, ?B/s]

Some weights of EsmForTokenClassification were not initialized from the model checkpoint at facebook/esm2_t12_35M_UR50D and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [19]:
model_name = model_checkpoint.split("/")[-1]
batch_size = 8

In [20]:
from huggingface_hub import notebook_login

notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [21]:
args = TrainingArguments(
    f"{model_name}-finetuned-secondary-structure",
    eval_strategy = "epoch",
    save_strategy = "epoch",
    learning_rate=1e-4,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=3,
    weight_decay=0.001,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    push_to_hub=True,
    report_to="none",
)

In [22]:
metric = load("accuracy")

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    labels = labels.reshape((-1,))
    predictions = np.argmax(predictions, axis=2)
    predictions = predictions.reshape((-1,))
    predictions = predictions[labels!=-100]
    labels = labels[labels!=-100]
    return metric.compute(predictions=predictions, references=labels)

Downloading builder script:   0%|          | 0.00/4.20k [00:00<?, ?B/s]

In [24]:
data_collator = DataCollatorForTokenClassification(tokenizer)

In [25]:
trainer = Trainer(
    model,
    args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
    data_collator=data_collator,
)

  trainer = Trainer(


In [26]:
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy
1,No log,0.464615,0.81045
2,0.503000,0.458124,0.812158
3,0.378500,0.462791,0.814856


TrainOutput(global_step=1215, training_loss=0.42199833873858666, metrics={'train_runtime': 616.8958, 'train_samples_per_second': 15.756, 'train_steps_per_second': 1.97, 'total_flos': 882160536634944.0, 'train_loss': 0.42199833873858666, 'epoch': 3.0})