## Protein classification

get data from UniProt

In [1]:
import requests

query_url ="https://rest.uniprot.org/uniprotkb/stream?compressed=true&fields=accession%2Csequence%2Ccc_subcellular_location&format=tsv&query=%28%28organism_id%3A9606%29%20AND%20%28reviewed%3Atrue%29%20AND%20%28length%3A%5B80%20TO%20500%5D%29%29"

uniprot_request = requests.get(query_url)

In [2]:
from io import BytesIO
import pandas

bio = BytesIO(uniprot_request.content)

df = pandas.read_csv(bio, compression='gzip', sep='\t')
df

Unnamed: 0,Entry,Sequence,Subcellular location [CC]
0,A0A0K2S4Q6,MTQRAGAAMLPSALLLLCVPGCLTVSGPSTVMGAVGESLSVQCRYE...,SUBCELLULAR LOCATION: [Isoform 1]: Membrane {E...
1,A0A5B9,DLKNVFPPKVAVFEPSEAEISHTQKATLVCLATGFYPDHVELSWWV...,SUBCELLULAR LOCATION: Cell membrane {ECO:00003...
2,A0AVI4,MDSPEVTFTLAYLVFAVCFVFTPNEFHAAGLTVQNLLSGWLGSEDA...,SUBCELLULAR LOCATION: Endoplasmic reticulum me...
3,A0JLT2,MENFTALFGAQADPPPPPTALGFGPGKPPPPPPPPAGGGPGTAPPP...,SUBCELLULAR LOCATION: Nucleus {ECO:0000305}.
4,A0M8Q6,GQPKAAPSVTLFPPSSEELQANKATLVCLVSDFNPGAVTVAWKADG...,SUBCELLULAR LOCATION: Secreted {ECO:0000303|Pu...
...,...,...,...
11975,Q9NZ38,MAFPGQSDTKMQWPEVPALPLLSSLCMAMVRKSSALGKEVGRRSEG...,
11976,Q9UFV3,MAETYRRSRQHEQLPGQRHMDLLTGYSKLIQSRLKLLLHLGSQPPV...,
11977,Q9Y6C7,MAHHSLNTFYIWHNNVLHTHLVFFLPHLLNQPFSRGSFLIWLLLCW...,
11978,X6R8D5,MGRKEHESPSQPHMCGWEDSQKPSVPSHGPKTPSCKGVKAPHSSRP...,


In [3]:
df = df.dropna()
df

Unnamed: 0,Entry,Sequence,Subcellular location [CC]
0,A0A0K2S4Q6,MTQRAGAAMLPSALLLLCVPGCLTVSGPSTVMGAVGESLSVQCRYE...,SUBCELLULAR LOCATION: [Isoform 1]: Membrane {E...
1,A0A5B9,DLKNVFPPKVAVFEPSEAEISHTQKATLVCLATGFYPDHVELSWWV...,SUBCELLULAR LOCATION: Cell membrane {ECO:00003...
2,A0AVI4,MDSPEVTFTLAYLVFAVCFVFTPNEFHAAGLTVQNLLSGWLGSEDA...,SUBCELLULAR LOCATION: Endoplasmic reticulum me...
3,A0JLT2,MENFTALFGAQADPPPPPTALGFGPGKPPPPPPPPAGGGPGTAPPP...,SUBCELLULAR LOCATION: Nucleus {ECO:0000305}.
4,A0M8Q6,GQPKAAPSVTLFPPSSEELQANKATLVCLVSDFNPGAVTVAWKADG...,SUBCELLULAR LOCATION: Secreted {ECO:0000303|Pu...
...,...,...,...
11916,Q8N8V8,MLLKVRRASLKPPATPHQGAFRAGNVIGQLIYLLTWSLFTAWLRPP...,SUBCELLULAR LOCATION: Membrane {ECO:0000305}; ...
11934,Q8WWF1,MDVLFVAIFAVPLILGQEYEDEERLGEDEYYQVVYYYTVTPSYDDF...,SUBCELLULAR LOCATION: Secreted {ECO:0000305}.
11956,Q96N68,MQGQGALKESHIHLPTEQPEASLVLQGQLAESSALGPKGALRPQAQ...,SUBCELLULAR LOCATION: Membrane {ECO:0000305}; ...
11963,Q9H0A3,MMNNTDFLMLNNPWNKLCLVSMDFCFPLDFVSNLFWIFASKFIIVT...,SUBCELLULAR LOCATION: Membrane {ECO:0000255}; ...


In [4]:
cytosolic = df['Subcellular location [CC]'].str.contains("Cytosol") | df['Subcellular location [CC]'].str.contains("Cytoplasm")
membrane = df['Subcellular location [CC]'].str.contains("Membrane") | df['Subcellular location [CC]'].str.contains("Cell membrane")
cytosolic_df = df[cytosolic & ~membrane]
cytosolic_df

Unnamed: 0,Entry,Sequence,Subcellular location [CC]
10,A1E959,MKIIILLGFLGATLSAPLIPQRLMSASNSNELLLNLNNGQLLPLQL...,SUBCELLULAR LOCATION: Secreted {ECO:0000250|Un...
15,A1XBS5,MMRRTLENRNAQTKQLQTAVSNVEKHFGELCQIFAAYVRKTARLRD...,SUBCELLULAR LOCATION: Cytoplasm {ECO:0000269|P...
19,A2RU49,MSSGNYQQSEALSKPTFSEEQASALVESVFGLKVSKVRPLPSYDDQ...,SUBCELLULAR LOCATION: Cytoplasm {ECO:0000305}.
21,A2RUH7,MEAATAPEVAAGSKLKVKEASPADAEPPQASPGQGAGSPTPQLLPP...,"SUBCELLULAR LOCATION: Cytoplasm, myofibril, sa..."
22,A4D126,MEAGPPGSARPAEPGPCLSGQRGADHTASASLQSVAGTEPGRHPQA...,"SUBCELLULAR LOCATION: Cytoplasm, cytosol {ECO:..."
...,...,...,...
11555,Q96L03,MATLARLQARSSTVGNQYYFRNSVVDPFRKKENDAAVKIQSWFRGC...,SUBCELLULAR LOCATION: Cytoplasm {ECO:0000250}.
11598,Q9BYD9,MNHCQLPVVIDNGSGMIKAGVAGCREPQFIYPNIIGRAKGQSRAAQ...,"SUBCELLULAR LOCATION: Cytoplasm, cytoskeleton ..."
11640,Q9NPB0,MEQRLAEFRAARKRAGLAAQPPAASQGAQTPGEKAEAAATLKAAPG...,SUBCELLULAR LOCATION: Cytoplasmic vesicle memb...
11653,Q9NUJ7,MGGQVSASNSFSRLHCRNANEDWMSALCPRLWDVPLHHLSIPGSHD...,SUBCELLULAR LOCATION: Cytoplasm {ECO:0000269|P...


In [5]:
membrane_df = df[membrane & ~cytosolic]
membrane_df

Unnamed: 0,Entry,Sequence,Subcellular location [CC]
0,A0A0K2S4Q6,MTQRAGAAMLPSALLLLCVPGCLTVSGPSTVMGAVGESLSVQCRYE...,SUBCELLULAR LOCATION: [Isoform 1]: Membrane {E...
1,A0A5B9,DLKNVFPPKVAVFEPSEAEISHTQKATLVCLATGFYPDHVELSWWV...,SUBCELLULAR LOCATION: Cell membrane {ECO:00003...
4,A0M8Q6,GQPKAAPSVTLFPPSSEELQANKATLVCLVSDFNPGAVTVAWKADG...,SUBCELLULAR LOCATION: Secreted {ECO:0000303|Pu...
18,A2RU14,MAGTVLGVGAGVFILALLWVAVLLLCVLLSRASGAARFSVIFLFFG...,SUBCELLULAR LOCATION: Membrane {ECO:0000305}; ...
34,A5X5Y0,MEGSWFHRKRFSFYLLLGFLLQGRGVTFTINCSGFGQHGADPTALN...,SUBCELLULAR LOCATION: Cell membrane {ECO:00002...
...,...,...,...
11841,Q6UWF5,MQIQNNLFFCCYTVMSAIFKWLLLYSLPALCFLLGTQESESFHSKA...,SUBCELLULAR LOCATION: Membrane {ECO:0000305}; ...
11916,Q8N8V8,MLLKVRRASLKPPATPHQGAFRAGNVIGQLIYLLTWSLFTAWLRPP...,SUBCELLULAR LOCATION: Membrane {ECO:0000305}; ...
11956,Q96N68,MQGQGALKESHIHLPTEQPEASLVLQGQLAESSALGPKGALRPQAQ...,SUBCELLULAR LOCATION: Membrane {ECO:0000305}; ...
11963,Q9H0A3,MMNNTDFLMLNNPWNKLCLVSMDFCFPLDFVSNLFWIFASKFIIVT...,SUBCELLULAR LOCATION: Membrane {ECO:0000255}; ...


In [6]:
cytosolic_sequences = cytosolic_df["Sequence"].tolist()
cytosolic_labels = [0 for protein in cytosolic_sequences]

membrane_sequences = membrane_df["Sequence"].tolist()
membrane_labels = [1 for protein in membrane_sequences]

In [7]:
sequences = cytosolic_sequences + membrane_sequences
labels = cytosolic_labels + membrane_labels

len(sequences) == len(labels)

True

In [8]:
from sklearn.model_selection import train_test_split

train_sequences, test_sequences, train_labels, test_labels = train_test_split(sequences, labels, test_size=0.25, shuffle=True)

In [9]:
from transformers import AutoTokenizer

model_checkpoint = "facebook/esm2_t12_35M_UR50D"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

2023-02-07 23:05:16.917367: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE3 SSE4.1 SSE4.2 AVX AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [10]:
tokenizer(train_sequences[0])

{'input_ids': [0, 20, 6, 5, 6, 5, 11, 6, 10, 5, 20, 13, 6, 14, 10, 4, 4, 4, 4, 4, 4, 4, 6, 7, 8, 4, 6, 6, 5, 15, 9, 5, 23, 14, 11, 6, 4, 19, 11, 21, 8, 6, 9, 23, 23, 15, 5, 23, 17, 4, 6, 9, 6, 7, 5, 16, 14, 23, 6, 5, 17, 16, 11, 7, 23, 9, 14, 23, 4, 13, 8, 7, 11, 18, 8, 13, 7, 7, 8, 5, 11, 9, 14, 23, 15, 14, 23, 11, 9, 23, 7, 6, 4, 16, 8, 20, 8, 5, 14, 23, 7, 9, 5, 13, 13, 5, 7, 23, 10, 23, 5, 19, 6, 19, 19, 16, 13, 9, 11, 11, 6, 10, 23, 9, 5, 23, 10, 7, 23, 9, 5, 6, 8, 6, 4, 7, 18, 8, 23, 16, 13, 15, 16, 17, 11, 7, 23, 9, 9, 23, 14, 13, 6, 11, 19, 8, 13, 9, 5, 17, 21, 7, 13, 14, 23, 4, 14, 23, 11, 7, 23, 9, 13, 11, 9, 10, 16, 4, 10, 9, 23, 11, 10, 22, 5, 13, 5, 9, 23, 9, 9, 12, 14, 6, 10, 22, 12, 11, 10, 8, 11, 14, 14, 9, 6, 8, 13, 8, 11, 5, 14, 8, 11, 16, 9, 14, 9, 5, 14, 14, 9, 16, 13, 4, 12, 5, 8, 11, 7, 5, 6, 7, 7, 11, 11, 7, 20, 6, 8, 8, 16, 14, 7, 7, 11, 10, 6, 11, 11, 13, 17, 4, 12, 14, 7, 19, 23, 8, 12, 4, 5, 5, 7, 7, 7, 6, 4, 7, 5, 19, 12, 5, 18, 15, 10, 22, 17, 8, 23, 15, 16

In [11]:
train_tokenized = tokenizer(train_sequences)
test_tokenized = tokenizer(test_sequences)

In [12]:
from datasets import Dataset
train_dataset = Dataset.from_dict(train_tokenized)
test_dataset = Dataset.from_dict(test_tokenized)

train_dataset

Dataset({
    features: ['input_ids', 'attention_mask'],
    num_rows: 3808
})

In [13]:
train_dataset = train_dataset.add_column("labels", train_labels)
test_dataset = test_dataset.add_column("labels", test_labels)
train_dataset

Dataset({
    features: ['input_ids', 'attention_mask', 'labels'],
    num_rows: 3808
})

In [14]:
from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer

num_labels = max(train_labels + test_labels) + 1
model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint, num_labels=num_labels)

Some weights of the model checkpoint at facebook/esm2_t12_35M_UR50D were not used when initializing EsmForSequenceClassification: ['lm_head.layer_norm.bias', 'lm_head.decoder.weight', 'lm_head.dense.weight', 'lm_head.bias', 'lm_head.dense.bias', 'lm_head.layer_norm.weight']
- This IS expected if you are initializing EsmForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing EsmForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of EsmForSequenceClassification were not initialized from the model checkpoint at facebook/esm2_t12_35M_UR50D and are newly initialized: ['classifier.dense.weight', 'classifier.dense.bias', 'classifier.out_proj.weight', '

In [15]:
from evaluate import load
import numpy as np


model_name = model_checkpoint.split("/")[-1]
batch_size = 8

args = TrainingArguments(
    f"{model_name}-finetuned-localization",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=3,
    weight_decay=0.01,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    push_to_hub=False,
)


metric = load("accuracy")

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    return metric.compute(predictions=predictions, references=labels)


trainer = Trainer(
    model,
    args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

In [16]:
trainer.train()

***** Running training *****
  Num examples = 3808
  Num Epochs = 3
  Instantaneous batch size per device = 8
  Total train batch size (w. parallel, distributed & accumulation) = 8
  Gradient Accumulation steps = 1
  Total optimization steps = 1428
  Number of trainable parameters = 33993843


Epoch,Training Loss,Validation Loss,Accuracy
1,No log,0.177176,0.949606
2,0.235900,0.177769,0.951969
3,0.149400,0.178976,0.949606


***** Running Evaluation *****
  Num examples = 1270
  Batch size = 8
Saving model checkpoint to esm2_t12_35M_UR50D-finetuned-localization/checkpoint-476
Configuration saved in esm2_t12_35M_UR50D-finetuned-localization/checkpoint-476/config.json
Model weights saved in esm2_t12_35M_UR50D-finetuned-localization/checkpoint-476/pytorch_model.bin
tokenizer config file saved in esm2_t12_35M_UR50D-finetuned-localization/checkpoint-476/tokenizer_config.json
Special tokens file saved in esm2_t12_35M_UR50D-finetuned-localization/checkpoint-476/special_tokens_map.json
***** Running Evaluation *****
  Num examples = 1270
  Batch size = 8
Saving model checkpoint to esm2_t12_35M_UR50D-finetuned-localization/checkpoint-952
Configuration saved in esm2_t12_35M_UR50D-finetuned-localization/checkpoint-952/config.json
Model weights saved in esm2_t12_35M_UR50D-finetuned-localization/checkpoint-952/pytorch_model.bin
tokenizer config file saved in esm2_t12_35M_UR50D-finetuned-localization/checkpoint-952/toke

TrainOutput(global_step=1428, training_loss=0.16585329200039392, metrics={'train_runtime': 395.6749, 'train_samples_per_second': 28.872, 'train_steps_per_second': 3.609, 'total_flos': 1027026171523296.0, 'train_loss': 0.16585329200039392, 'epoch': 3.0})