In [1]:
import torch
from transformers import AutoTokenizer, EsmForMaskedLM
from transformers.models.esm.openfold_utils.protein import to_pdb, Protein as OFProtein
from transformers.models.esm.openfold_utils.feats import atom14_to_atom37

from Bio import SeqIO
import pandas as pd
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


### CUDA/Torch GPU Setup

In [2]:
if torch.cuda.is_available():
    device = torch.device("cuda")
    print("Using GPU:", torch.cuda.get_device_name(0))
else:
    device = torch.device("cpu")
    print("Using CPU")

Using GPU: NVIDIA GeForce RTX 4060 Ti


## Convert FASTA file to pd.DataFrame for easier processing

In [3]:
def fasta_to_dataframe(fasta_file):
    records = []
    for seq_record in SeqIO.parse(fasta_file, 'fasta'):
        # Extract the ID after '>tr'
        id_info = seq_record.id.split('|')[1] # Adjust the index based on your FASTA file format
        sequence = str(seq_record.seq)
        
        # Generate a random index within the range of the sequence length
        random_index = np.random.randint(0, len(sequence))
        # Replace the character at the random index with '<mask>'
        masked_sequence = sequence[:random_index] + '<mask>' + sequence[random_index+1:]
        
        records.append([id_info, sequence, masked_sequence])
    
    # Create a DataFrame from the records
    df = pd.DataFrame(records, columns=['ID', 'Sequence', 'Masked_Sequence'])
    return df

## Create Dataset

In [4]:
fasta_df = fasta_to_dataframe("human_protein_seq/uniprotkb_proteome_UP000005640.fasta")

In [5]:
fasta_df.head()

Unnamed: 0,ID,Sequence,Masked_Sequence
0,A0A075B6G3,MLWWEEVEDCYEREDVQKKTFTKWVNAQFSKFGKQHIENLFSDLQD...,MLWWEEVEDCYEREDVQKKTFTKWVNAQFSKFGKQHIENLFSDLQD...
1,A0A087WV00,MDAAGRGCHLLPLPAARGPARAPAAAAAAAASPPGPCSGAACAPSA...,MDAAGRGCHLLPLPAARGPARAPAAAAAAAASPPGPCSGAACAPSA...
2,A0A087WZT3,MELSAEYLREKLQRDLEAEHVLPSPGGVGQVRGETAASETQLGS,MELSAEYLREKLQRDLEAEHVLPSP<mask>GVGQVRGETAASETQLGS
3,A0A087X1C5,MGLEALVPLAMIVAIFLLLVDLMHRHQRWAARYPPGPLPLPGLGNL...,MGLEALVPLAMIVAIFLLLVDLMHRHQRWAARYPPGPLPLPGLGNL...
4,A0A087X296,MSRSLLLWFLLFLLLLPPLPVLLADPGAPTPVNPCCYYPCQHQGIC...,MSRSLLLWFLLFLLLLPPLPVLLADPGAPTPVNPCCYYPCQHQGIC...


## Preparing your model and tokenizer

Now we load our model and tokenizer. If using GPU, use `model.cuda()` to transfer the model to GPU.

In [6]:
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
model = EsmForMaskedLM.from_pretrained("facebook/esm2_t6_8M_UR50D")

In [7]:
model = model.cuda()
accuracy = []

for id, fasta_row in fasta_df.iterrows():
    inputs = tokenizer(fasta_row["Masked_Sequence"], max_length=4096, truncation=True, padding='max_length', return_tensors='pt')
    print(inputs)
    
    with torch.no_grad():
        logits = model(**inputs).logits
    
    # Retrieve index of <mask>
    mask_token_index = (inputs.input_ids == tokenizer.mask_token_id)[0].nonzero(as_tuple=True)[0]
    predicted_token_id = logits[0, mask_token_index].argmax(axis=-1).item()
    
    predicted_token = tokenizer.decode([predicted_token_id], skip_special_tokens=True)
    
    unmasked_text = fasta_row['Sequence']
    mask_position = unmasked_text.find('<mask>')
    actual_token = unmasked_text[mask_position]

    if predicted_token == actual_token:
        # print(f"Prediction matches the unmasked token: {predicted_token}")
        accuracy.append(1)
    else:
        # print(f"Prediction does not match the unmasked token. Predicted: {predicted_token}, Actual: {actual_token}")
        accuracy.append(0)

{'input_ids': tensor([[ 0, 20,  4,  ...,  1,  1,  1]]), 'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0]])}


RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument index in method wrapper_CUDA__index_select)

In [None]:
print("Accuracy: " +  str(sum(accuracy)/len(accuracy)))

Get latency, memory, power consumption, energy consumption
- latency is per sequence
- latency graph of sequence length vs latency
- accuracy graph of sequence length vs accuracy