In [1]:
import requests

query_url ="https://rest.uniprot.org/uniprotkb/stream?compressed=true&fields=accession%2Csequence%2Cft_strand%2Cft_helix&format=tsv&query=%28%28organism_id%3A9606%29%20AND%20%28reviewed%3Atrue%29%20AND%20%28length%3A%5B80%20TO%20500%5D%29%29"

uniprot_request = requests.get(query_url)

In [2]:
from io import BytesIO
import pandas

bio = BytesIO(uniprot_request.content)

df = pandas.read_csv(bio, compression='gzip', sep='\t')
df

Unnamed: 0,Entry,Sequence,Beta strand,Helix
0,A0A0K2S4Q6,MTQRAGAAMLPSALLLLCVPGCLTVSGPSTVMGAVGESLSVQCRYE...,,
1,A0A5B9,DLKNVFPPKVAVFEPSEAEISHTQKATLVCLATGFYPDHVELSWWV...,"STRAND 9..14; /evidence=""ECO:0007829|PDB:4UDT""...","HELIX 2..4; /evidence=""ECO:0007829|PDB:4UDT""; ..."
2,A0AVI4,MDSPEVTFTLAYLVFAVCFVFTPNEFHAAGLTVQNLLSGWLGSEDA...,,
3,A0JLT2,MENFTALFGAQADPPPPPTALGFGPGKPPPPPPPPAGGGPGTAPPP...,"STRAND 79..81; /evidence=""ECO:0007829|PDB:7EMF""","HELIX 83..86; /evidence=""ECO:0007829|PDB:7EMF""..."
4,A0M8Q6,GQPKAAPSVTLFPPSSEELQANKATLVCLVSDFNPGAVTVAWKADG...,,
...,...,...,...,...
11975,Q9NZ38,MAFPGQSDTKMQWPEVPALPLLSSLCMAMVRKSSALGKEVGRRSEG...,,
11976,Q9UFV3,MAETYRRSRQHEQLPGQRHMDLLTGYSKLIQSRLKLLLHLGSQPPV...,,
11977,Q9Y6C7,MAHHSLNTFYIWHNNVLHTHLVFFLPHLLNQPFSRGSFLIWLLLCW...,,
11978,X6R8D5,MGRKEHESPSQPHMCGWEDSQKPSVPSHGPKTPSCKGVKAPHSSRP...,,


In [3]:
no_structure_rows = df["Beta strand"].isna() & df["Helix"].isna()
df = df[~no_structure_rows]
df

Unnamed: 0,Entry,Sequence,Beta strand,Helix
1,A0A5B9,DLKNVFPPKVAVFEPSEAEISHTQKATLVCLATGFYPDHVELSWWV...,"STRAND 9..14; /evidence=""ECO:0007829|PDB:4UDT""...","HELIX 2..4; /evidence=""ECO:0007829|PDB:4UDT""; ..."
3,A0JLT2,MENFTALFGAQADPPPPPTALGFGPGKPPPPPPPPAGGGPGTAPPP...,"STRAND 79..81; /evidence=""ECO:0007829|PDB:7EMF""","HELIX 83..86; /evidence=""ECO:0007829|PDB:7EMF""..."
14,A1L3X0,MAFSDLTSRTVHLYDNWIKDADPRVEDWLLMSSPLPQTILLGFYVY...,"STRAND 97..99; /evidence=""ECO:0007829|PDB:6Y7F""","HELIX 17..20; /evidence=""ECO:0007829|PDB:6Y7F""..."
16,A1Z1Q3,MYPSNKKKKVWREEKERLLKMTLEERRKEYLRDYIPLNSILSWKEE...,"STRAND 71..77; /evidence=""ECO:0007829|PDB:4IQY...","HELIX 11..19; /evidence=""ECO:0007829|PDB:4IQY""..."
20,A2RUC4,MAGQHLPVPRLEGVSREQFMQHLYPQRKPLVLEGIDLGPCTSKWTV...,"STRAND 10..13; /evidence=""ECO:0007829|PDB:3AL5...","HELIX 16..22; /evidence=""ECO:0007829|PDB:3AL5""..."
...,...,...,...,...
11551,Q96I45,MVNLGLSRVDDAVAAKHPGLGEYAACQSHAFMKGVFTFVTGTGMAF...,"STRAND 3..5; /evidence=""ECO:0007829|PDB:2LOR"";...","HELIX 6..16; /evidence=""ECO:0007829|PDB:2LOR"";..."
11615,Q9H0W7,MPTNCAAAGCATTYNKHINISFHRFPLDPKRRKEWVRLVRRKNFVP...,"STRAND 7..9; /evidence=""ECO:0007829|PDB:2D8R"";...","HELIX 29..38; /evidence=""ECO:0007829|PDB:2D8R"""
11660,Q9P1F3,MNVDHEVNLLVEEIHRLGSKNADGKLSVKFGVLFRDDKCANLFEAL...,"STRAND 24..29; /evidence=""ECO:0007829|PDB:2L2O...","HELIX 3..17; /evidence=""ECO:0007829|PDB:2L2O"";..."
11662,Q9P298,MSANRRWWVPPDDEDCVSEKLLRKTRESPLVPIGLGGCLVVAAYRI...,"STRAND 11..14; /evidence=""ECO:0007829|PDB:2LON...","HELIX 18..24; /evidence=""ECO:0007829|PDB:2LON""..."


In [4]:
df.iloc[0]["Helix"]

'HELIX 2..4; /evidence="ECO:0007829|PDB:4UDT"; HELIX 17..23; /evidence="ECO:0007829|PDB:4UDT"; HELIX 83..86; /evidence="ECO:0007829|PDB:4UDT"'

In [5]:
import re

strand_re = r"STRAND\s(\d+)\.\.(\d+)\;"
helix_re = r"HELIX\s(\d+)\.\.(\d+)\;"

re.findall(helix_re, df.iloc[0]["Helix"])

[('2', '4'), ('17', '23'), ('83', '86')]

In [6]:
import numpy as np

def build_labels(sequence, strands, helices):
    # Start with all 0s
    labels = np.zeros(len(sequence), dtype=np.int64)
    
    if isinstance(helices, float): # Indicates missing (NaN)
        found_helices = []
    else:
        found_helices = re.findall(helix_re, helices)
    for helix_start, helix_end in found_helices:
        helix_start = int(helix_start) - 1
        helix_end = int(helix_end)
        assert helix_end <= len(sequence)
        labels[helix_start: helix_end] = 1  # Helix category
    
    if isinstance(strands, float): # Indicates missing (NaN)
        found_strands = []
    else:
        found_strands = re.findall(strand_re, strands)
    for strand_start, strand_end in found_strands:
        strand_start = int(strand_start) - 1
        strand_end = int(strand_end)
        assert strand_end <= len(sequence)
        labels[strand_start: strand_end] = 2  # Strand category
    return labels


sequences = []
labels = []

for row_idx, row in df.iterrows():
    row_labels = build_labels(row["Sequence"], row["Beta strand"], row["Helix"])
    sequences.append(row["Sequence"])
    labels.append(row_labels)

In [7]:
from sklearn.model_selection import train_test_split

train_sequences, test_sequences, train_labels, test_labels = train_test_split(sequences, labels, test_size=0.25, shuffle=True)

In [8]:
from transformers import AutoTokenizer

model_checkpoint = "facebook/esm2_t12_35M_UR50D"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

train_tokenized = tokenizer(train_sequences)
test_tokenized = tokenizer(test_sequences)

2023-02-07 23:21:07.382044: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE3 SSE4.1 SSE4.2 AVX AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [9]:
from datasets import Dataset

train_dataset = Dataset.from_dict(train_tokenized)
test_dataset = Dataset.from_dict(test_tokenized)

train_dataset = train_dataset.add_column("labels", train_labels)
test_dataset = test_dataset.add_column("labels", test_labels)

In [10]:
from transformers import AutoModelForTokenClassification, TrainingArguments, Trainer

num_labels = 3
model = AutoModelForTokenClassification.from_pretrained(model_checkpoint, num_labels=num_labels)

Some weights of the model checkpoint at facebook/esm2_t12_35M_UR50D were not used when initializing EsmForTokenClassification: ['lm_head.decoder.weight', 'lm_head.layer_norm.weight', 'lm_head.dense.bias', 'lm_head.dense.weight', 'lm_head.bias', 'lm_head.layer_norm.bias']
- This IS expected if you are initializing EsmForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing EsmForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of EsmForTokenClassification were not initialized from the model checkpoint at facebook/esm2_t12_35M_UR50D and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream t

In [11]:
from transformers import DataCollatorForTokenClassification

data_collator = DataCollatorForTokenClassification(tokenizer)

In [12]:
from evaluate import load

model_name = model_checkpoint.split("/")[-1]
batch_size = 8

args = TrainingArguments(
    f"{model_name}-finetuned-secondary-structure",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=2.5e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=6,
    weight_decay=0.001,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    push_to_hub=False,
)

metric = load("accuracy")

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    labels = labels.reshape((-1,))
    predictions = np.argmax(predictions, axis=2)
    predictions = predictions.reshape((-1,))
    predictions = predictions[labels!=-100]
    labels = labels[labels!=-100]
    return metric.compute(predictions=predictions, references=labels)


trainer = Trainer(
    model,
    args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
    data_collator=data_collator,
)

In [13]:
trainer.train()

***** Running training *****
  Num examples = 2934
  Num Epochs = 6
  Instantaneous batch size per device = 8
  Total train batch size (w. parallel, distributed & accumulation) = 8
  Gradient Accumulation steps = 1
  Total optimization steps = 2202
  Number of trainable parameters = 33763444


Epoch,Training Loss,Validation Loss,Accuracy
1,No log,0.498732,0.796631
2,0.564200,0.470429,0.807404
3,0.430900,0.465695,0.81079
4,0.430900,0.4784,0.809629
5,0.385200,0.485697,0.80959
6,0.348800,0.49204,0.809191


***** Running Evaluation *****
  Num examples = 978
  Batch size = 8
Saving model checkpoint to esm2_t12_35M_UR50D-finetuned-secondary-structure/checkpoint-367
Configuration saved in esm2_t12_35M_UR50D-finetuned-secondary-structure/checkpoint-367/config.json
Model weights saved in esm2_t12_35M_UR50D-finetuned-secondary-structure/checkpoint-367/pytorch_model.bin
tokenizer config file saved in esm2_t12_35M_UR50D-finetuned-secondary-structure/checkpoint-367/tokenizer_config.json
Special tokens file saved in esm2_t12_35M_UR50D-finetuned-secondary-structure/checkpoint-367/special_tokens_map.json
***** Running Evaluation *****
  Num examples = 978
  Batch size = 8
Saving model checkpoint to esm2_t12_35M_UR50D-finetuned-secondary-structure/checkpoint-734
Configuration saved in esm2_t12_35M_UR50D-finetuned-secondary-structure/checkpoint-734/config.json
Model weights saved in esm2_t12_35M_UR50D-finetuned-secondary-structure/checkpoint-734/pytorch_model.bin
tokenizer config file saved in esm2_t1

TrainOutput(global_step=2202, training_loss=0.4230736949896401, metrics={'train_runtime': 635.332, 'train_samples_per_second': 27.708, 'train_steps_per_second': 3.466, 'total_flos': 1600290440180880.0, 'train_loss': 0.4230736949896401, 'epoch': 6.0})