In [2]:

import torch
import torch.nn as nn
from transformers import GPT2Config, GPT2Tokenizer, GPT2LMHeadModel


# Residue -> token-id vocabulary; must stay identical to Char_To_INT.txt used
# at training time. Each symbol's id is its position in the string below.
CHAR_TO_INT = {char: idx for idx, char in enumerate("GIERLQSXBDFNMTYPCVZ-HAWK")}
# Architecture config and tokenizer must be the ones used during training.
# n_positions=512 overrides GPT-2's position-embedding table size, which
# bounds the longest sequence the model can take.
model_name = "gpt2-medium"
config = GPT2Config.from_pretrained(model_name, n_positions=512)
# NOTE(review): `tokenizer` is not used by the cells visible here (inputs are
# encoded with CHAR_TO_INT); kept in case other cells rely on it.
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
class GPT2ForProteinPrediction(GPT2LMHeadModel):
    """GPT-2 language model whose output head predicts protein-alphabet tokens.

    The transformer backbone comes from ``GPT2LMHeadModel``. Only ``lm_head``
    is replaced, with a bias-free linear map from the hidden size
    (``config.n_embd``) to ``num_classes`` logits.

    Args:
        config: GPT-2 configuration (must match the one used for training).
        num_classes: Number of output classes. Defaults to the module-level
            ``NUM_CLASSES``, resolved when an instance is built (not when the
            class is defined), so the global may be assigned after this class.
    """

    def __init__(self, config, num_classes=None):
        super().__init__(config)
        if num_classes is None:
            num_classes = NUM_CLASSES
        # NOTE(review): GPT2LMHeadModel presumably ties lm_head to the input
        # embeddings during super().__init__. If tie_weights() runs again
        # later (e.g. via from_pretrained / resize_token_embeddings), it may
        # re-tie this head to the vocab-sized embedding. Confirm, or set
        # config.tie_word_embeddings = False before construction.
        self.lm_head = nn.Linear(config.n_embd, num_classes, bias=False)
# Recreate the model with the training-time architecture, then load the
# fine-tuned weights onto whichever device is available.
NUM_CLASSES = len(CHAR_TO_INT)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = GPT2ForProteinPrediction(config).to(device)

# map_location is required for the CPU fallback: without it, a checkpoint
# saved from GPU tensors cannot be deserialised on a machine without CUDA.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints
# (on torch>=1.13 consider weights_only=True).
model.load_state_dict(torch.load('protein_gpt_model_mab.pth', map_location=device))

# Switch layers such as dropout to inference behaviour.
model.eval()


In [8]:

import numpy as np
import random
def introduce_gaps_and_errors(sequence, gap_probability=0.1, error_probability=0.05, rng=None):
    """Corrupt a protein sequence with random gaps and point substitutions.

    Each residue is independently replaced by ``'-'`` with probability
    ``gap_probability``. Otherwise, if it is one of the 20 standard amino
    acids, it is swapped for a different standard amino acid with probability
    ``error_probability``. Other characters are never substituted, only gapped.

    Args:
        sequence: Iterable of residue characters (typically a str).
        gap_probability: Per-residue probability of becoming a gap.
        error_probability: Per-residue substitution probability, applied only
            to residues that were not gapped.
        rng: Optional random source with ``random()`` and ``choice(seq)``
            methods (e.g. ``numpy.random.Generator`` or ``random.Random``) for
            reproducible corruption. ``None`` uses the global ``np.random`` and
            ``random`` state.

    Returns:
        The corrupted sequence as a str with one character per input residue.
    """
    # Alphabetically ordered: candidate order must not depend on set iteration
    # order, which for str varies with PYTHONHASHSEED and broke seeded runs.
    alphabet = "ACDEFGHIKLMNPQRSTVWY"
    valid_amino_acids = frozenset(alphabet)
    if rng is None:
        draw, pick = np.random.rand, random.choice
    else:
        draw, pick = rng.random, rng.choice

    new_sequence = []
    for aa in sequence:
        if draw() < gap_probability:
            new_sequence.append('-')
        elif draw() < error_probability and aa in valid_amino_acids:
            new_sequence.append(pick([c for c in alphabet if c != aa]))
        else:
            new_sequence.append(aa)
    return ''.join(new_sequence)

def custom_tokenize_pred(sequence, char_to_int, pad_to=None):
    """Encode a residue string as a list of token ids for inference.

    Characters not present in ``char_to_int`` are encoded with the id of
    ``'-'`` (or ``None`` if ``'-'`` is itself missing from the mapping).

    Args:
        sequence: Residue string; may contain ``'-'`` gaps.
        char_to_int: Mapping from single characters to integer ids.
        pad_to: Optional target length. Shorter encodings are right-padded
            with the ``'-'`` id; longer ones are returned unchanged (never
            truncated). ``None`` (default) disables padding.

    Returns:
        List of ids: one per character of ``sequence``, plus any padding.
    """
    gap_id = char_to_int.get('-')
    token_ids = [char_to_int.get(char, gap_id) for char in sequence]
    if pad_to is not None:
        # A non-positive count yields [], so long inputs are left as-is.
        token_ids.extend([gap_id] * (pad_to - len(token_ids)))
    return token_ids
def predict_sequence(input_sequence):
    """Run the fine-tuned model on a (possibly gapped) sequence and decode it.

    Relies on the module-level ``model``, ``CHAR_TO_INT``, ``INT_TO_CHAR``
    and ``custom_tokenize_pred``. One token is predicted per input position
    (argmax over the head's logits), so the result has the same length as
    ``input_sequence``.

    Args:
        input_sequence: Residue string; ``'-'`` marks positions to fill in.

    Returns:
        The predicted residue string. Ids missing from ``INT_TO_CHAR`` decode
        to ``'-'``. An empty input yields ``''``.

    Raises:
        ValueError: if the sequence is longer than the model's position
            embedding table (``model.config.n_positions``).
    """
    if not input_sequence:
        return ''
    max_len = model.config.n_positions
    if len(input_sequence) > max_len:
        # Fail clearly instead of an out-of-range position lookup (which on
        # CUDA surfaces as an opaque device-side assert).
        raise ValueError(
            f"input_sequence has {len(input_sequence)} residues; the model supports at most {max_len}"
        )
    input_tokens = torch.tensor(
        [custom_tokenize_pred(input_sequence, CHAR_TO_INT)], dtype=torch.long
    ).to(model.device)
    with torch.no_grad():
        logits = model(input_tokens)[0]  # first output is the logits (batch of 1)
    # Take batch row 0 rather than squeeze(): for a length-1 input squeeze()
    # gives a 0-d tensor whose tolist() is a bare int, not a list.
    predicted_ids = torch.argmax(logits, dim=-1)[0].tolist()
    return ''.join(INT_TO_CHAR.get(token_id, '-') for token_id in predicted_ids)


# Token-id -> residue vocabulary; must stay identical to INT_To_CHAR.txt used
# at training time.
INT_TO_CHAR = {0: 'G', 1: 'I', 2: 'E', 3: 'R', 4: 'L', 5: 'Q', 6: 'S', 7: 'X', 8: 'B', 9: 'D', 10: 'F', 11: 'N', 12: 'M', 13: 'T', 14: 'Y', 15: 'P', 16: 'C', 17: 'V', 18: 'Z', 19: '-', 20: 'H', 21: 'A', 22: 'W', 23: 'K'}
# The two vocab tables are copied by hand; make sure they are exact inverses.
assert INT_TO_CHAR == {idx: char for char, idx in CHAR_TO_INT.items()}, \
    "INT_TO_CHAR is not the inverse of CHAR_TO_INT"

# Ground-truth sequence (adjacent string literals are concatenated).
input_seq = (
    "DIQMTQSPSSLSASVGDRVTITCKASQNIDKYLNWYQQKPGKAPKLLIYNTNNLQTGVPS"
    "RFSGSGSGTDFTFTISSLQPEDIATYYCLQHISRPRTFGQGTKVEIKRTVAAPSVFIFPP"
    "SDEQLKSGTASVVCLLNNFYPREAKVQWKVDNALQSGNSQESVTEQDSKDSTYSLSSTLT"
    "LSKADYEKHKVYACEVTHQGLSSPVTKSFNRGEC"
)
# The scaffold below is hard-coded; a random one can be produced with:
# scaffold_seq = introduce_gaps_and_errors(input_seq, gap_probability=(0.3) , error_probability=0.10)
scaffold_seq = "---MTQSPSSISASVGDRVTITCK---NIDKYINWYQQKPGKAPKIIIYNTNNIQTGVPSRF---G----FTFTI-----------YCIQHISRPRTFGQGTKVEIKRSIAAPSVFIFPPSDEQIKSGTASVVCIINNFYPREAQPRRKVDNAIQSGNSQESVTEQDSKDSTYSISSTITISKADYEKHKVYACEVTHQGISSPVTKSFN----"
# Later cells compare input/scaffold/prediction position by position.
assert len(scaffold_seq) == len(input_seq), "scaffold_seq must align 1:1 with input_seq"
print(scaffold_seq)
predicted_seq = predict_sequence(scaffold_seq)
print("Predicted Sequence:", predicted_seq)



---MTQSPSSISASVGDRVTITCK---NIDKYINWYQQKPGKAPKIIIYNTNNIQTGVPSRF---G----FTFTI-----------YCIQHISRPRTFGQGTKVEIKRSIAAPSVFIFPPSDEQIKSGTASVVCIINNFYPREAQPRRKVDNAIQSGNSQESVTEQDSKDSTYSISSTITISKADYEKHKVYACEVTHQGISSPVTKSFN----
Predicted Sequence: DIQMTQSPSSLSASVGDRVTITCKASQNIDKYLNWYQQKPGKAPKLLIYNTNNLQTGVPSRFSGSGSGTDFTFTISSLQPEDIATYYCLQHISRPRTFGQGTKVEIKRTVAAPSVFIFPPSDEQLKSGTASVVCLLNNFYPREAKVQWKVDNALQSGNSQESVTEQDSKDSTYSLSSTLTLSKADYEKHKVYACEVTHQGLSSPVTKSFNRGEC


In [9]:
# Colour-code each predicted residue against the ground truth:
#   red   - the prediction is wrong;
#   green - correct, and the scaffold did not already have it (it was a gap
#           "-" or a different residue at that position);
#   plain - correct and already present in the scaffold.
seq = ""
for i in range(len(input_seq)):
    if input_seq[i] != predicted_seq[i]:
        seq += "\033[91m" + predicted_seq[i] + "\033[0m"
    elif scaffold_seq[i] == "-" or scaffold_seq[i] != input_seq[i]:
        seq += "\033[92m" + predicted_seq[i] + "\033[0m"
    else:
        seq += predicted_seq[i]
    # Fixed wrap point so the long sequence prints on two lines.
    if i == 110:
        seq += "\n"
print(seq)

[92mD[0m[92mI[0m[92mQ[0mMTQSPSS[92mL[0mSASVGDRVTITCK[92mA[0m[92mS[0m[92mQ[0mNIDKY[92mL[0mNWYQQKPGKAPK[92mL[0m[92mL[0mIYNTNN[92mL[0mQTGVPSRF[92mS[0m[92mG[0m[92mS[0mG[92mS[0m[92mG[0m[92mT[0m[92mD[0mFTFTI[92mS[0m[92mS[0m[92mL[0m[92mQ[0m[92mP[0m[92mE[0m[92mD[0m[92mI[0m[92mA[0m[92mT[0m[92mY[0mYC[92mL[0mQHISRPRTFGQGTKVEIKR[92mT[0m[92mV[0mA
APSVFIFPPSDEQ[92mL[0mKSGTASVVC[92mL[0m[92mL[0mNNFYPREA[92mK[0m[92mV[0m[92mQ[0m[92mW[0mKVDNA[92mL[0mQSGNSQESVTEQDSKDSTYS[92mL[0mSST[92mL[0mT[92mL[0mSKADYEKHKVYACEVTHQG[92mL[0mSSPVTKSFN[92mR[0m[92mG[0m[92mE[0m[92mC[0m


In [10]:
# Colour-code the scaffold (the model input) against the ground truth:
# positions that differ (gaps or substitutions) are shown in red.
seq1 = ""
for i in range(len(input_seq)):
    seq1 += (scaffold_seq[i] if scaffold_seq[i] == input_seq[i]
             else "\033[91m" + scaffold_seq[i] + "\033[0m")
    # Same wrap point as the prediction view so the two printouts line up.
    if i == 110:
        seq1 += "\n"
print("input seq:\n" + seq1)
print("\npredicted seq:\n" + seq)

input seq:
[91m-[0m[91m-[0m[91m-[0mMTQSPSS[91mI[0mSASVGDRVTITCK[91m-[0m[91m-[0m[91m-[0mNIDKY[91mI[0mNWYQQKPGKAPK[91mI[0m[91mI[0mIYNTNN[91mI[0mQTGVPSRF[91m-[0m[91m-[0m[91m-[0mG[91m-[0m[91m-[0m[91m-[0m[91m-[0mFTFTI[91m-[0m[91m-[0m[91m-[0m[91m-[0m[91m-[0m[91m-[0m[91m-[0m[91m-[0m[91m-[0m[91m-[0m[91m-[0mYC[91mI[0mQHISRPRTFGQGTKVEIKR[91mS[0m[91mI[0mA
APSVFIFPPSDEQ[91mI[0mKSGTASVVC[91mI[0m[91mI[0mNNFYPREA[91mQ[0m[91mP[0m[91mR[0m[91mR[0mKVDNA[91mI[0mQSGNSQESVTEQDSKDSTYS[91mI[0mSST[91mI[0mT[91mI[0mSKADYEKHKVYACEVTHQG[91mI[0mSSPVTKSFN[91m-[0m[91m-[0m[91m-[0m[91m-[0m

predicted seq:
[92mD[0m[92mI[0m[92mQ[0mMTQSPSS[92mL[0mSASVGDRVTITCK[92mA[0m[92mS[0m[92mQ[0mNIDKY[92mL[0mNWYQQKPGKAPK[92mL[0m[92mL[0mIYNTNN[92mL[0mQTGVPSRF[92mS[0m[92mG[0m[92mS[0mG[92mS[0m[92mG[0m[92mT[0m[92mD[0mFTFTI[92mS[0m[92mS[0m[92mL[0m[92mQ[0m[92mP[0m[92mE[0m[92mD[0m[92mI[0m[92mA[