## Imports

In [1]:
import torch
import pandas as pd
import numpy as np
from tqdm import tqdm
from transformers import EsmForProteinFolding

from protxlstm.utils import load_sequences_from_msa_file, tokenizer, AA_TO_ID, reorder_masked_sequence, load_model
from protxlstm.generation import generate_sequence
from protxlstm.models.xlstm import xLSTMLMHeadModel
from protxlstm.dataloaders import ProteinMemmapDataset

from protxlstm.applications.generation_utils.score_hamming import align_sequences
from protxlstm.applications.generation_utils.score_hmmer import make_hmm_from_a3m_msa, align_and_score_sequences_in_a3m_with_hmm
from protxlstm.applications.generation_utils.score_structure import compute_structure

  from .autonotebook import tqdm as notebook_tqdm


## Settings

Please provide the path to your Prot-xLSTM model:

In [2]:
checkpoint = "../checkpoints/protxlstm_102M_60B"

Define your context sequences either by providing a path to an MSA file or by entering a list of protein sequences:

In [3]:
msa_path = "./example_msas/A0A1C5UJ41.a3m"

# protein_list = ["MRIDIDKFAGPCSCGREHEIDVKEIIIESGALKKLPEILSKYGLREYKNPAVICDTNTYAAAGELVEELLPRCEVIILDPEGLHADEHAVEKVMKQLDEDIDLLIAVGSGTIHDITRYCAYERGIPFISVPTAASVDGFVSTVAAMTWNGFKKTFPAVAPILVVADTDIFSKAPLRLTASGVGDILGKYIALADWKIAHLLTGEYICPEICDMEEKALDTVCSCLDGIAAGDEDAYEQLMYALILSGLAMQMVGNSRPASGAEHHMSHLWEMEVINGHIDALHGEKVGVGTVLVSDEYHKLAEAIRDGRCKVKPYMPLEEELLEETFGKKGLYEGILKENTPNPLEDVDPEMLEEKWPEIRDIIDELPSAEELRALLKKAGCKTTLEDIGLPESLKEETLRLSPYVRNRLTLMRLLKMLDFY",
#                 "MTEIMENLSVDGISGAEIKCRCGKMHKNQIKEIIIERGALAKIPDIIKKHGGSNVYVIADRNTYAAAGETVCKNIERYNLPYSLYVFDSERIEPDELAVGKAIMHYDGKCDFIVGIGSGTINDIGKMVACITGKPYMIVATAPSMDGYASATSSMIRDGIKVSLGTVCPCVIVADTEVLCNAPKILLQAGIGDMLAKYISICEWRLSHLITGEYYCEEIASMVRNALKNCM-QIESLEFTEPDDIKPVIEGLIISGIAMSFAGLSRPASGMEHYFSHLWDMRAIEFNTPSALHGIQCGVATVLCLRVYEFIARLVPDRKKACDFVNSFSLKEWNRFLAGFLGRSAEGLIELERKERKYNPESHAKRLDIIVNNWDEIVKIISEELPPAEQVEKYMKKLGMPTMPKELGFSDGEVQGAFLATKDIRDKYIGSRLLWDLGLLDEAKHVCRSVW",
#                 "MESKFSTTRVLPINQIFHLKQGVISAMMIDSKKYSGACACGHDHSMDTNLAVIQAGCLNQLDDYLQQFGLQGPRAAIYDENTYHAQGLVRPRAEQEIILAPENLHANEIAVEKVLSQLRGDIAILIAVGSGTIHDITRYCAHDRGILFISCPTAATVDGFCSTVSAMTWYGFKKTLPGVAPALVLADLNVICKAPAYLALSGVGDILGKYTALADWKISSAVSGEFFCPQIESMTRKAVQAVYQSARRLADRNEEAYEELTYGLLLSGLAMQLMGNSRPASGAEHHISHLIEMEPDGLGVHSNALHGEKVGAATLLVAREYHHLAETEDIAPHVHTYRFPDRYYLFPIFGERLTDAVSEENRDSCMKPVTPTALIEHWAEIRSIIAEIPAADELQSLYRDVGMKSTLADLGVPQSALPKLMEYSPCVRNRMTLMRIRRMIDLPYCE",
#                 "MFEEILDVSGCACGKNHTLQTREYIVEKDAMKKLPALLARLFPSAKPLAVFDRNTHRAAYPKFGAALPEVPACILADDEIHADERQIDLVTQALRDGGHDLLLAVGSGVICDVVRYVAFKQELPFIVVPTAASVDGFVSNSAAMTLNGAKITLPAKAPNAVVADLEVVAAAPKKMTASGVGDMLSKYISIADWKIGHLITGEYFCPFVADLTIEAVDMIVQNIEKINSGDIDSFGILMKGLLLSGVAMQMVGITRPASSFEHHFSHYLEIVPVEGVNRAALHGEKVGIATIQAAKYYPIFARRLSRIYKENIPNQFDIERVKGYYAQYPAGIVAAIEKENTPTITAKLDRRLLEQNYDEVLRIAGEVPSAEALTETLRAIGGYTSYHDINMTDEQFKETMKVCCYIRNRFTLLRLVCDFALFDFDAELKV",
#                 "MDVDLGHLSKPRVCGREHPDGIREIRIEPGATARLDDILLEYQYQNPVFICDSSTRAAAEPYLEEEFKDYLVIELDPTGLQADEASKQKILSQVEDCDLGLSSVPVDILVAIGAGTIHDLTRYAAEEFEIPFISVPTAASTDGFSCSMILRDPDGIRKEVPSVAPSWILADTNLFVHAPKRLTLAGVSDVISRLTALADWKVSHLVSDAWFDEEIYQEMRSRISRVIDQLEDICAGDVFATEALMDTLIYFGIMTGVPGENQAVCGAEHHVAHLWKMAVINPAPDALYGESVLTAMFLVLDQYKKMVPAIRQGKLRVDTEESKGIEYMLLERVFRDPEVLEQIIAENTPNPLEDIDLDAFEDSLEAIADVIDSLPRPDGLQRHLRAAGCRTALTQLGLPENIAALSLDAAPYLRGTITLLRLRKLLE",
#                 "MRVDADDFARPCSCGREHQIAVKEILIEAGAVEKLEEEMSEGMLREYISPLVICDTNTYAATEprotein_listELMEDIYDRCQVLVLDAEGLQADRHAIKIVENNMEEDIDLILAVGAGTIHDISRYIAHNYKVPFISVPTAASGDGFVTTVAAITLDGVKKTVPSVAPICVYADTDIFSKAPQRLTAAGISDLMAKYICLADWKIANLVTGEYFCRETVKLEEKALKTVKSSIQDITEGEEDECEQLMYALILSGLAMQMIGNSRPASCAEHQVTHLWDMEVINGPLDALHGEKVSVAALLVLEEYKRIAAAITQGRCHAKPYENEDEELLKETFGKKGLLEEIRKENEPELLETISPQHLEKCLNGIEEIIDELPSEQTMFRLLEKAGCAKTVYDIGLDESAVLPSLRLAPYTRRRLSLLRISKMLDIRGE",
#                 "MKIDANHLSGPCSCGGEHLLATQICVIQEGALFHLEEILSSIPVVGKRCAVYDENTYRAIPNSIHPRAEQEIILSPSGLHADENSTASVLARLEPDIQVMLAIGGGTVHDITRYCSTERGIPFISIPTAASCDGFCSNVAAMTWHGYKKTIPCQAPLLVVADLDVISAAPWRLTASGIGDMLGKFIALTDWRISHLLTGEKLCPVIYQIMEDAVDSIWTRCRDLRSGGSAAYEAVVYGLLMSGLAMQMIGTSRPASGAEHHVSHFIEVEPAALRTHSSALHGEKVGVGTLLIAQEYQRLSQIENIASLALPYAPVSDERLMEVFGPRLFSACREENLHDCLAQVTPERLIQQWPQIRQIIAKIPPAAQIHQFLTDLKASASLSDLGVPEAALELILEASPLIRNRLTFMRVRRIIRH",
#                 "MIMDCAKYAGLCECGRDHELETKMVVVEYGAINNFEKYMADVGLAGKKRAVVYDSVIYKLTEGKHVAADQEIVLEAQGLRAEDTLIEDMMKKLDDPEVIVAYGAGTIMDFGRYPAYKLGIPFVAIPTLASSDGFTANICSAIINGQKKSTPMCAPTLVVTDLDIIKGAPMRLVSSGINDILSKYVSVFDWKVSHMVADEYFCPKVCELAEHALKIMRDAADKLAKTGEVDHEAMTMAQMESGLTMQLLNHSRAASGAEHLAAHLVEMHPPRFEEAEGIHGECVGVGTYLCIKEYHRLASLPTPKAKKFEPLSEEWIREKFGDRLAPGIIKENANDVLGTFDPQNIVDHWDEIRDMINKLPSAEEMEALYKACGCKYLPEHIGIKPELADEMLAVSSAIRNRNTLIRMRRVLDFGE",
#                 "MQIDINSFRRPCNCGRTHEIFVKDILIEENALKRLPEKVRSIFDGRNTEIAVICDTNTYQAAGKTVEKLLPGCELIILPANDLRADNCGITLARKGLLSSGRIKLIIAAGAGTIHDISRYLAMEFRIPFVSVPTAASTDSYASVISILTMNGSKKNIPGDSPVLIIADTLILAKAPYRLTASGITKILRKYTALTDWEISHMVTGEYICQRICEMEMSALKEVCLYSNDLKGNTRDKNTLRAYEKLIYALLLSGIAMQMVGSISSASGGDDAAHLWEKEAVNELFETYHGEKISIGLMVAVHTCHKLKNTVKNGINKVMPNREIESMGKGRTYEEVVKENALDSLPAISGILGKLPTESDLRKLLTAAGYKREIRDIKLEERLVPLTKRLDFDTRNRLIFLKFTKFFKLKNEA",
#                 "MNKPSTEKIVINGGCAAECRSYAREHFGDAYAVVCDGNTEPIARRAFPGDELIVFPAGSHATEQAADDCISRIKSDELCGLIACGSGSVHDIARYSAHDRKIPFVSFPTAASVDGFASGVAAMTWHGRKVTFPSAPPIALFADDDVYSSAPRELLASGVGDIVGKYVSIFDWIFTSLLTSETVEDDIYKLENESLETVMHCDISSPDYPHGVMDCLVKSGIAIQLKDSSRPASGAEHHLSHLWEMGCIGTPKHAYHGEQVGVSTLFVLDRYKRNPRPQLRPKPLDRELLRPTFGTLTDGIIEENTPDSLAEITQSALDANADRIAELIKALPDPEEIREYLLSVGAKTTLTELGLPDSTEFIQRSLDWAPYVRRRLTYLKVI"]

Define your sampling parameters:

In [4]:
# number of sequences to sample as context; if set to -1 all sequences provided above are used
num_context_sequences = 100
# number of sequences to generate
num_sequences = 10  

# controls the randomness of the model’s output; the higher the more diverse
temperature = 0.9  
# limits the model's choices to the top k most likely next tokens
top_k = 10
# limits the model's choices smallest set of next tokens whose cumulative probability exceeds p
top_p = 0.9  

Set your device:

In [5]:
device = "cuda:0"

## Load Model and Data Class

In [6]:
# read msa file
if msa_path != None:
    msa_sequences = load_sequences_from_msa_file(msa_path)
    protein_list = [msa.upper() for msa in msa_sequences]

# tokenize context sequences
tokens = tokenizer(protein_list, concatenate=True)

# load data class
data_class = ProteinMemmapDataset(
        sample=False,
        max_msa_len=-1,
        reverse=False,
        seed=0,
        troubleshoot=False,
        fim_strategy="multiple_span",
        always_mask=False,
        max_position_embeddings=2048,
        max_seq_position_embeddings=512,
        add_position_ids="1d",
        mask_fraction=0.2,
        max_patches=5
    )

# get number of context sequences
if num_context_sequences == -1:
    num_context_sequences = len(protein_list)
else:
    num_context_sequences = min(num_context_sequences, len(protein_list))

In [7]:
# load the model

config_update_kwargs = {
                "mlstm_backend": "chunkwise_variable",
                "mlstm_chunksize": 1024,
                "mlstm_return_last_state": True}

model = load_model(checkpoint,
                    model_class=xLSTMLMHeadModel,
                    device=device,
                    dtype=torch.bfloat16,
                    **config_update_kwargs,
                    )
model = model.eval()

detected slstm_block
In newest xlstm


## Generate Sequences

In [8]:
# create a dataframe for the results
generation_df = pd.DataFrame(columns = ['Generated Sequence', 'Perplexity'])

for i in tqdm(range(num_sequences)):

    # sample context sequences and corresponding positional embeddings
    input_ids, pos_ids = data_class.sample_sequences(tokens.numpy()[0], num_sequences=num_context_sequences)
    input_ids.append(AA_TO_ID["<cls>"])
    input_ids = torch.asarray(input_ids, dtype=torch.int64)[None,:].to(device)
    pos_ids.append(0)
    pos_ids = torch.asarray(pos_ids, dtype=torch.int64)[None,:].to(device)

    # generate sequences
    output = generate_sequence(model,
                                input_ids,
                                position_ids=pos_ids,
                                is_fim={},
                                max_length=(input_ids.shape[1]+1000),
                                temperature=temperature,
                                top_k=top_k,
                                top_p=top_p,
                                return_dict_in_generate=True,
                                output_scores=True,
                                eos_token_id=torch.tensor([AA_TO_ID["<cls>"]]).to(device),
                                chunk_chunk_size=2**15,
                                device=device)
    
    # calculate perplexity
    perplexity = float(torch.exp(torch.nn.functional.cross_entropy(torch.from_numpy(output["scores"]).permute(0, 2, 1), torch.from_numpy(output["generated_tokens"][0][None,:]))))
    
    # append sequence and perplexity to data frame
    generation_df = pd.concat([generation_df, pd.DataFrame({'Generated Sequence': [reorder_masked_sequence(output["generated"][0])], 'Perplexity': [perplexity]})], ignore_index=True)
    


  generation_df = pd.concat([generation_df, pd.DataFrame({'Generated Sequence': [reorder_masked_sequence(output["generated"][0])], 'Perplexity': [perplexity]})], ignore_index=True)
100%|██████████| 10/10 [01:19<00:00,  7.95s/it]


In [9]:
display(generation_df)

Unnamed: 0,Generated Sequence,Perplexity
0,VLLVSDTGILNSGVLERIREKLKGLGIKVELFPLPESEPTFQQVEK...,4.358632
1,PVTVLSGPDAIARVGDELVEAGAKKALVVTGARAVDHCGVLDALAA...,3.33546
2,VPATTTTATRRLALGEGALGRVPAVLDALGGRPLLVLADAGVAAAA...,4.730563
3,IVSGPGARAAVGDLVAEHGGSRVLVITDPGVAGAGLAPALTGVLEG...,3.036892
4,KPTTVIYDQKALEELEELVEKNGFERPLLVTGRGSFKKSGVYENVM...,3.524508
5,MVTDDTTYAAAAAVVEGLGITAEAIDVAGEGDRKDLTTVDRVWRAA...,3.225808
6,PRIIFGEGAADRAAGYLKSFGKKVFIVTGKGSIKNSGAYDLVSKTL...,3.018368
7,TVSAVESGALAELRGELRDLGAGRVVLVTDENTARSYGERVRETLG...,3.241559
8,SASIEALDAALAERGGGLLLVDSGVLSRLPEELARASRVRGLELAP...,3.711276
9,MSTVHVATGEIAEVLRDLEDAGRERLVVVTDAGLRDAGVAGRVRAV...,3.514427


## Score Sequences

Calculate Hamming distances to context sequences:

In [10]:
# create new column in dataframe
generation_df["Minimum Hamming Distance"] = pd.Series()

for i in tqdm(range(len(generation_df))):

    # calculate pairwise Hamming distances to all context sequences
    all_hamming = []
    for ctx_seq in protein_list:
        hamming, _, _ = align_sequences(ctx_seq, generation_df["Generated Sequence"].iloc[i], print_alignments=False)
        all_hamming.append(hamming)

    # add the Hamming distance to the closest context sequence to the data frame
    min_hamming = np.mean(all_hamming)
    generation_df.loc[i, "Minimum Hamming Distance"] = min_hamming

  0%|          | 0/10 [00:00<?, ?it/s]

100%|██████████| 10/10 [00:24<00:00,  2.44s/it]


Calculate HMMER scores (only if an MSA is available):

In [11]:
if not msa_path == None:

    # train HMM
    hmm = make_hmm_from_a3m_msa(msa_path)
    
    # score all sequences
    scores = align_and_score_sequences_in_a3m_with_hmm(hmm, sequences_list=list(generation_df["Generated Sequence"]))

    # add HMMER scores to the data frame
    for seq in list(generation_df["Generated Sequence"]):
        generation_df.loc[generation_df["Generated Sequence"] == seq, "HMMER Score"] = scores[seq]["score"] if seq in scores.keys() else 0

Calculate folding scores (pTM and PLDDT) using ESMFold:

In [12]:
# import the folding model
model = EsmForProteinFolding.from_pretrained("facebook/esmfold_v1", cache_dir="/system/user/publicdata/pxlstm_temp/esm-fold", low_cpu_mem_usage=True)
model = model.cuda(device)
model.esm = model.esm.half()
torch.backends.cuda.matmul.allow_tf32 = True

for seq in tqdm(list(generation_df["Generated Sequence"])):

    # compute structural scores
    ptm, pae, mean_plddt, pos_plddt = compute_structure(seq, model)
    
    # add scores to the data frame
    generation_df.loc[generation_df["Generated Sequence"] == seq, "pTM"] = ptm
    generation_df.loc[generation_df["Generated Sequence"] == seq, "Mean pLDDT"] = mean_plddt

Some weights of EsmForProteinFolding were not initialized from the model checkpoint at facebook/esmfold_v1 and are newly initialized: ['esm.contact_head.regression.bias', 'esm.contact_head.regression.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
100%|██████████| 10/10 [01:35<00:00,  9.55s/it]


In [13]:
display(generation_df)

Unnamed: 0,Generated Sequence,Perplexity,Minimum Hamming Distance,HMMER Score,pTM,Mean pLDDT
0,VLLVSDTGILNSGVLERIREKLKGLGIKVELFPLPESEPTFQQVEK...,4.358632,0.614762,217.472214,0.961031,0.911549
1,PVTVLSGPDAIARVGDELVEAGAKKALVVTGARAVDHCGVLDALAA...,3.33546,0.620925,232.49736,0.964514,0.930103
2,VPATTTTATRRLALGEGALGRVPAVLDALGGRPLLVLADAGVAAAA...,4.730563,0.722965,173.413864,0.679946,0.589188
3,IVSGPGARAAVGDLVAEHGGSRVLVITDPGVAGAGLAPALTGVLEG...,3.036892,0.551833,150.928665,0.959316,0.943711
4,KPTTVIYDQKALEELEELVEKNGFERPLLVTGRGSFKKSGVYENVM...,3.524508,0.612628,281.122864,0.95696,0.916718
5,MVTDDTTYAAAAAVVEGLGITAEAIDVAGEGDRKDLTTVDRVWRAA...,3.225808,0.555919,139.610992,0.939767,0.91296
6,PRIIFGEGAADRAAGYLKSFGKKVFIVTGKGSIKNSGAYDLVSKTL...,3.018368,0.660326,296.930481,0.937441,0.879868
7,TVSAVESGALAELRGELRDLGAGRVVLVTDENTARSYGERVRETLG...,3.241559,0.620738,274.203796,0.971223,0.942209
8,SASIEALDAALAERGGGLLLVDSGVLSRLPEELARASRVRGLELAP...,3.711276,0.593117,123.15612,0.898233,0.863185
9,MSTVHVATGEIAEVLRDLEDAGRERLVVVTDAGLRDAGVAGRVRAV...,3.514427,0.619098,223.633408,0.962831,0.923391
