In [1]:
# This function is based on ProteinMPNN/Pifold/esm, under the MIT License.
# Source: https://github.com/dauparas/ProteinMPNN, https://github.com/A4Bio/PiFold,https://github.com/facebookresearch/esm

"""
This is a real-world design case for SAR425899, a GLP-1R/GCGR dual agonist, using the PDB structures 8JIU and 8JIR.
It uses the PDB processing script from ProteinMPNN; the chain being designed must have the same length in both structures.
Currently, only dual-agonist designs are supported. Conserved sites (e.g., 1-13) are fixed first, and the remaining
positions are then designed to complete the sequence.

Note: Non-natural amino acids are not supported. Any non-natural amino acid in the design must be handled specially,
for example by marking its position as a conserved site and representing it as a natural amino acid.
"""

In [2]:
import torch
import numpy as np 
import pandas as pd
import argparse
import os.path
import json, time, os, sys, glob
import shutil
import warnings
from torch import optim
from torch.utils.data import DataLoader
import queue
import copy
import torch.nn as nn
import torch.nn.functional as F
import random
import os.path
import subprocess
from concurrent.futures import ProcessPoolExecutor    
from utils import worker_init_fn, get_pdbs, loader_pdb, build_training_clusters, PDB_dataset, StructureDataset, StructureLoader
from model_utils import *

  from .autonotebook import tqdm as notebook_tqdm


# Parse the PDB files with the ProteinMPNN helper scripts

In [None]:
# Run the helper script on the ./dual_pdb directory; its output path is ./dual_SAR.jsonl,
# which is the file read back by the data-loading cell further down.
!python ./helper_scripts/parse_multiple_chains.py --input_path=./dual_pdb --output_path=./dual_SAR.jsonl

In [3]:

import esm 
from  esm.model.esm2 import ESM2, ESM2_decoder
# Global compute device: first CUDA GPU when available, otherwise CPU.
# Several functions below read this module-level name directly.
device = torch.device("cuda:0" if (torch.cuda.is_available()) else "cpu")
def esm_model():
    """Build the ESM-2 (33 layers, 1280-dim) encoder and its decoder head from local checkpoints.

    Reads ``./esm2_t33_650M_UR50D.pt`` and
    ``./esm2_t33_650M_UR50D-contact-regression.pt`` from the working
    directory, merges the contact-regression weights into the main state
    dict, strips the ``encoder.sentence_encoder.`` / ``encoder.`` key
    prefixes and loads the result (strictly) into an ``ESM2`` instance.
    An ``ESM2_decoder`` is then populated with the token-embedding and
    ``lm_head`` weights taken from the same state dict.

    Returns:
        tuple: ``(model, decoder)``; neither is moved to a device here.

    Raises:
        FileNotFoundError: if a checkpoint file is missing (from ``torch.load``).
        KeyError: if one of the expected decoder keys is absent.
        RuntimeError: if the state dicts do not match the modules (``strict=True``).
    """
    import re

    # NOTE: torch.load unpickles the file - only use checkpoints from a trusted source.
    # map_location="cpu" lets checkpoints that were saved from a GPU load on CPU-only
    # hosts; the caller moves the model to the target device afterwards.
    regression_data = torch.load('./esm2_t33_650M_UR50D-contact-regression.pt', map_location="cpu")
    model_data = torch.load('./esm2_t33_650M_UR50D.pt', map_location="cpu")
    model_data["model"].update(regression_data["model"])
    alphabet = esm.data.Alphabet.from_architecture("ESM-1b")
    model = ESM2(
        num_layers=33,
        embed_dim=1280,
        attention_heads=20,
        alphabet=alphabet,
        token_dropout=True,
    )

    def upgrade_state_dict(state_dict):
        """Removes the leading prefixes 'encoder.sentence_encoder.' and 'encoder.' from keys."""
        # Longest prefix first so that it wins over the shorter one.
        prefixes = ["encoder.sentence_encoder.", "encoder."]
        # Group the alternatives so that "^" anchors every one of them, and escape
        # the dots so they only match a literal ".". The previous pattern
        # ("^a|b") anchored only the first alternative and could strip
        # "encoder<any char>" from the middle of a key.
        pattern = re.compile("^(?:" + "|".join(re.escape(p) for p in prefixes) + ")")
        return {pattern.sub("", name): param for name, param in state_dict.items()}

    model_data = upgrade_state_dict(model_data["model"])
    model.load_state_dict(model_data, strict=True)

    decoder = ESM2_decoder(
        num_layers=33,
        embed_dim=1280,
        attention_heads=20,
        alphabet=alphabet,
        token_dropout=True,
    )
    # The decoder only needs the token embedding and the language-model head.
    decoder_keys = ['embed_tokens.weight','lm_head.weight', 'lm_head.bias', 'lm_head.dense.weight','lm_head.dense.bias', 'lm_head.layer_norm.weight', 'lm_head.layer_norm.bias']
    decoder_data = {key: model_data[key] for key in decoder_keys}

    decoder.load_state_dict(decoder_data, strict=True)
    return model, decoder

# Instantiate the ESM-2 encoder and decoder once for the whole notebook.
esm_encoder, esm_decoder = esm_model()
# Alphabet object that is passed to the embedding helpers below for tokenisation.
alphabet = esm.data.Alphabet.from_architecture("ESM-1b")
# Only the encoder is moved to `device`; esm_decoder stays where esm_model() created it.
esm_encoder.to(device)
print("OK")

OK


In [4]:

def randomize_list(lst, ratio):
    """Return a copy of ``lst`` in which each element is independently replaced
    by a random integer in ``[0, 20]`` with probability ``ratio``.

    The input is not modified; it only needs ``.copy()``, ``len`` and item
    assignment. Randomness comes from the global ``random`` module.
    """
    noisy = lst.copy()
    position = 0
    total = len(noisy)
    while position < total:
        # One uniform draw per position; a second draw only when replacing.
        if random.random() < ratio:
            noisy[position] = random.randint(0, 20)
        position += 1
    return noisy

def set_nan(arr):
    """Overwrite a random 10% of the first-axis entries of ``arr`` with NaN, in place.

    ``int(0.1 * N)`` distinct indices along axis 0 are drawn with
    ``np.random.choice`` and the matching ``arr[idx, :, :]`` slices are set to
    NaN. The same (mutated) array object is returned.
    """
    n_rows = arr.shape[0]
    n_drop = int(0.1 * n_rows)
    dropped = np.random.choice(n_rows, n_drop, replace=False)
    arr[dropped, :, :] = np.nan
    return arr

def featurize(batch,lst_chain ,device = "cpu",is_train = True):
    """Turn a batch of parsed multi-chain structures into padded model tensors.

    For every entry the chain named by ``lst_chain[i]`` is placed first and is
    the only chain flagged as "to be designed" in ``chain_M``; all other
    chains follow in ``masked_list + visible_list`` order. Visible chains
    whose sequence equals that of a masked chain are treated as masked when
    the chain order is built.

    Changes with respect to the previous implementation:
      * ``b['masked_list']`` / ``b['visible_list']`` are copied before being
        edited, so the caller's dictionaries are no longer mutated.
      * The duplicated per-chain branches were merged.
      * Values that were computed but never returned (a noised copy of the
        sequence, a per-residue class label) were dropped. As a consequence
        this function no longer consumes numbers from the global ``random``
        generator, and the ``'num_of_chains'`` key is no longer required.

    Args:
        batch: list of dicts with keys ``'seq'``, ``'masked_list'``,
            ``'visible_list'`` and, per chain letter ``c``, ``'seq_chain_c'``
            and ``'coords_chain_c'`` (a dict holding ``'N_chain_c'``,
            ``'CA_chain_c'``, ``'C_chain_c'``, ``'O_chain_c'`` coordinate lists).
        lst_chain: for each batch entry, the letter of the chain to design.
        device: torch device for the returned tensors.
        is_train: unused; kept for interface compatibility.

    Returns:
        tuple ``(X, S, mask, mask_rp, lengths, chain_M, residue_idx,
        mask_self, chain_encoding_all, lst_seq_str)`` where
          * ``X`` - float32 ``[B, L_max, 4, 3]`` backbone coordinates (NaN -> 0),
          * ``S`` - long ``[B, L_max]`` indices into ``'ACDEFGHIKLMNPQRSTVWYX'``
            (unknown letters map to ``X``),
          * ``mask`` - float32 ``[B, L_max]``, 1 where all coordinates are finite,
          * ``mask_rp`` - float32 ``[B, L_max]``, 1 on real (non-padding) positions,
          * ``lengths`` - int32 numpy array of ``len(b['seq'])``,
          * ``chain_M`` - float32 ``[B, L_max]``, 1 on the designed chain,
          * ``residue_idx`` - long ``[B, L_max]``, position plus a jump of 100
            per preceding chain, -100 on padding,
          * ``mask_self`` - float32 ``[B, L_max, L_max]``, 0 within a chain, 1 otherwise,
          * ``chain_encoding_all`` - long ``[B, L_max]``, 1-based chain number, 0 on padding,
          * ``lst_seq_str`` - per entry, the raw chain sequences in featurized order.

    Raises:
        ValueError: if ``lst_chain[i]`` is not one of the entry's chains, or
            if ``batch`` is empty.
        IndexError: if ``lst_chain`` is shorter than ``batch``.
    """
    alphabet = 'ACDEFGHIKLMNPQRSTVWYX'
    aa_to_index = {aa: k for k, aa in enumerate(alphabet)}
    unknown_index = aa_to_index['X']
    backbone_atoms = ('N', 'CA', 'C', 'O')

    B = len(batch)
    lengths = np.array([len(b['seq']) for b in batch], dtype=np.int32) #sum of chain seq lengths
    L_max = max([len(b['seq']) for b in batch])
    X = np.zeros([B, L_max, 4, 3])
    residue_idx = -100*np.ones([B, L_max], dtype=np.int32) #residue idx with jumps across chains
    chain_M = np.zeros([B, L_max], dtype=np.int32) #1 for positions to predict, 0 for given positions
    mask_self = np.ones([B, L_max, L_max], dtype=np.int32) #0 for intra-chain pairs, 1 otherwise
    chain_encoding_all = np.zeros([B, L_max], dtype=np.int32) #chain number per residue: 1, 1, ..., 2, 2, ...
    S = np.zeros([B, L_max], dtype=np.int32) #sequence AAs integers
    mask_rp = np.zeros([B, L_max], dtype=np.int32) #1 on real residues, 0 on padding
    lst_seq_str = []

    for i, b in enumerate(batch):
        # Work on copies: the lists are edited below and must not leak back
        # into the caller's data.
        masked_chains = list(b['masked_list'])
        visible_chains = list(b['visible_list'])

        visible_seqs = {}
        masked_seqs = {}
        for letter in masked_chains + visible_chains:
            chain_seq = b[f'seq_chain_{letter}']
            if letter in visible_chains:
                visible_seqs[letter] = chain_seq
            elif letter in masked_chains:
                masked_seqs[letter] = chain_seq

        # A visible chain with the same sequence as a masked chain is moved
        # to the masked group.
        for masked_seq in masked_seqs.values():
            for letter, visible_seq in visible_seqs.items():
                if masked_seq == visible_seq:
                    if letter not in masked_chains:
                        masked_chains.append(letter)
                    if letter in visible_chains:
                        visible_chains.remove(letter)

        # The chain to design always comes first.
        design_letter = lst_chain[i]
        all_chains = masked_chains + visible_chains
        all_chains.insert(0, all_chains.pop(all_chains.index(design_letter)))

        x_chain_list = []
        chain_mask_list = []
        chain_seq_list = []
        chain_encoding_list = []
        c = 1   # 1-based chain counter
        l0 = 0  # start offset of the current chain in the concatenated sequence
        for letter in all_chains:
            chain_seq = b[f'seq_chain_{letter}']
            chain_length = len(chain_seq)
            chain_coords = b[f'coords_chain_{letter}'] #this is a dictionary
            # 1.0 on the designed chain, 0.0 on every other chain
            if letter == design_letter:
                chain_mask = np.ones(chain_length)
            else:
                chain_mask = np.zeros(chain_length)
            x_chain = np.stack([chain_coords[f'{atom}_chain_{letter}'] for atom in backbone_atoms], 1) #[chain_length,4,3]

            x_chain_list.append(x_chain)
            chain_mask_list.append(chain_mask)
            chain_seq_list.append(chain_seq)
            chain_encoding_list.append(c*np.ones(chain_length))
            l1 = l0 + chain_length
            mask_self[i, l0:l1, l0:l1] = 0
            residue_idx[i, l0:l1] = 100*(c-1)+np.arange(l0, l1)
            l0 = l1
            c += 1

        x = np.concatenate(x_chain_list,0) #[L, 4, 3]
        all_sequence = "".join(chain_seq_list)
        m = np.concatenate(chain_mask_list,0) #[L,], 1.0 for places that need to be predicted
        chain_encoding = np.concatenate(chain_encoding_list,0)
        l = len(all_sequence)

        # Missing coordinates are padded with NaN so that `mask` excludes them below.
        X[i,:,:,:] = np.pad(x, [[0,L_max-l], [0,0], [0,0]], 'constant', constant_values=(np.nan, ))
        chain_M[i,:] = np.pad(m, [[0,L_max-l]], 'constant', constant_values=(0.0, ))
        chain_encoding_all[i,:] = np.pad(chain_encoding, [[0,L_max-l]], 'constant', constant_values=(0.0, ))

        # Convert to labels; letters outside the alphabet are treated as "X".
        S[i, :l] = np.asarray([aa_to_index.get(a, unknown_index) for a in all_sequence], dtype=np.int32)
        mask_rp[i,:l] = 1

        # Raw (unsanitised) chain sequences in featurized order.
        lst_seq_str.append(list(chain_seq_list))

    isnan = np.isnan(X)
    mask = np.isfinite(np.sum(X,(2,3))).astype(np.float32)
    X[isnan] = 0.

    # Conversion
    mask = torch.from_numpy(mask).to(dtype=torch.float32, device=device)
    mask_rp = torch.from_numpy(mask_rp).to(dtype=torch.float32, device=device)
    residue_idx = torch.from_numpy(residue_idx).to(dtype=torch.long,device=device)
    S = torch.from_numpy(S).to(dtype=torch.long,device=device)
    X = torch.from_numpy(X).to(dtype=torch.float32, device=device)
    mask_self = torch.from_numpy(mask_self).to(dtype=torch.float32, device=device)
    chain_M = torch.from_numpy(chain_M).to(dtype=torch.float32, device=device)
    chain_encoding_all = torch.from_numpy(chain_encoding_all).to(dtype=torch.long, device=device)
    return X, S, mask,mask_rp, lengths, chain_M, residue_idx, mask_self, chain_encoding_all,lst_seq_str

In [5]:
from produalnet_main import ProDualNet

In [6]:
    import argparse
    import os.path

    import json, time, os, sys, glob
    import shutil
    import warnings
    import numpy as np
    import torch
    from torch import optim
    from torch.utils.data import DataLoader
    import queue
    import copy
    import torch.nn as nn
    import torch.nn.functional as F
    import random
    import os.path
    import subprocess
    from concurrent.futures import ProcessPoolExecutor    

     
    
    
    
    # Checkpoint with the trained ProDualNet weights (alternative: "./produalnet_esm.pt").
    PATH = "./esm_test128/model_weights/best_esm.pt"#"./produalnet_esm.pt"
    
    model = ProDualNet(node_features=128, 
                        edge_features=128, 
                        hidden_dim=128, 
                        num_encoder_layers=4, 
                        num_decoder_layers=4, 
                        k_neighbors=32, 
                        dropout=0.1, 
                        augment_eps=0.2)
    model.to(device)


    if PATH:
        # NOTE: torch.load unpickles the file - only load checkpoints from a trusted source.
        checkpoint = torch.load(PATH, map_location=device)
        total_step = checkpoint['step'] #write total_step from the checkpoint
        epoch = checkpoint['epoch'] #write epoch from the checkpoint
        # strict=False tolerates key mismatches; report them instead of silently
        # running with partially initialised (random) weights.
        incompatible_keys = model.load_state_dict(checkpoint['model_state_dict'], strict=False)
        if incompatible_keys.missing_keys or incompatible_keys.unexpected_keys:
            warnings.warn(
                "Checkpoint does not fully match ProDualNet: "
                f"missing keys={list(incompatible_keys.missing_keys)}, "
                f"unexpected keys={list(incompatible_keys.unexpected_keys)}"
            )
        print("load ok----------")
    else:
        total_step = 0
        epoch = 0

    #optimizer = torch.optim.Adam(model.parameters(), lr=0.00005, betas=(0.9, 0.98), eps=1e-9)#




load ok----------


In [7]:

def seq_esm_embed_func(lst_seq,L,esm_encoder,alphabet):
    """Embed the context chains of every sample with ESM, leaving the designed chain blank.

    For each sample (a list of chain sequences whose first item is the chain
    being designed) the result is the concatenation of
      * zeros for the designed chain,
      * the layer-33 ESM representation of all remaining chains joined into
        one string (first and last token positions stripped),
      * zeros up to length ``L``.

    Args:
        lst_seq: list of samples, each a list of chain sequence strings.
        L: padded length of every returned row.
        esm_encoder: ESM encoder; switched to eval mode here.
        alphabet: object providing ``get_batch_converter()``.

    Returns:
        Tensor of shape ``[len(lst_seq), L, 1280]`` on the global ``device``.
    """
    esm_encoder.eval()
    per_sample = []
    for chains in lst_seq:
        design_len = len(chains[0])
        context = "".join(chains[1:])
        covered = design_len + len(context)

        tokens = alphabet.get_batch_converter()([[1,context]])[-1].to(device)
        with torch.no_grad():
            reps = esm_encoder(tokens,[33])["representations"][33][0,1:-1]

        pieces = [
            torch.zeros(design_len,1280),  # designed chain: embedding is all zeros
            reps.cpu(),
            torch.zeros(int(L-covered),1280),  # padding up to L
        ]
        per_sample.append(torch.cat(pieces,0))

    return torch.stack(per_sample).to(device)

def seq_esm_embed_func_re(lst_seq,L,esm_encoder,alphabet):
    """Embed all chains of every sample (designed chain included) with ESM.

    All chain sequences of a sample are joined into one string, encoded, and
    the layer-33 representation (first and last token positions stripped) is
    zero-padded to length ``L``.

    Args:
        lst_seq: list of samples, each a list of chain sequence strings.
        L: padded length of every returned row.
        esm_encoder: ESM encoder; switched to eval mode here.
        alphabet: object providing ``get_batch_converter()``.

    Returns:
        Tensor of shape ``[len(lst_seq), L, 1280]`` on the global ``device``.
    """
    esm_encoder.eval()
    per_sample = []
    for chains in lst_seq:
        full_seq = "".join(chains)

        tokens = alphabet.get_batch_converter()([[1,full_seq]])[-1].to(device)
        with torch.no_grad():
            reps = esm_encoder(tokens,[33])["representations"][33][0,1:-1]

        padding = torch.zeros(int(L-len(full_seq)),1280)
        per_sample.append(torch.cat([reps.cpu(), padding],0))

    return torch.stack(per_sample).to(device)

def find_first_last_one(lst):
    """Return ``(first, last)`` positions of elements equal to 1 in ``lst``.

    Both values are ``-1`` when no element equals 1.
    """
    hits = [position for position, item in enumerate(lst) if item == 1]
    if not hits:
        return -1, -1
    return hits[0], hits[-1]
def indices_to_chars(indices):
    """Map integer-like residue indices to a string over 'ACDEFGHIKLMNPQRSTVWYX'."""
    alphabet = 'ACDEFGHIKLMNPQRSTVWYX'
    letters = []
    for token in indices:
        letters.append(alphabet[int(token)])
    return ''.join(letters)
def seqs_get_signle(lst_x,lst_seq_pre):
    """Build, per entry, the chain list ``[designed sequence, other chains...]``.

    For entry ``i`` the designed sequence is decoded from the first
    ``len(seq of the first masked chain)`` indices of ``lst_seq_pre[i]``;
    it is followed by every other ``seq_chain_*`` value of the entry in
    dictionary order.

    Args:
        lst_x: list of structure dicts (``'masked_list'`` plus ``'seq_chain_*'`` keys).
        lst_seq_pre: per entry, a sequence of predicted residue indices.

    Returns:
        list of lists of strings, one inner list per entry of ``lst_x``.
    """
    collected = []
    for idx, entry in enumerate(lst_x):
        design_key = "seq_chain_" + entry['masked_list'][0]
        design_len = len(entry[design_key])
        # Only the leading positions belong to the designed chain.
        designed = indices_to_chars(lst_seq_pre[idx][:design_len])
        partners = [
            entry[key]
            for key in entry
            if "seq_chain_" in key and key != design_key
        ]
        collected.append([designed] + partners)
    return collected

def seq_single_complex_cross(lst_seq_all):
    """Swap the first (designed) sequence between consecutive pairs of entries.

    Entries are taken two at a time; each output entry keeps its own partner
    chains but receives the first sequence of the other entry of the pair.
    A trailing unpaired entry is dropped. Prints "a" for every partner chain
    that contains an "X".
    """
    crossed = []
    for first, second in zip(lst_seq_all[0::2], lst_seq_all[1::2]):
        with_second_design = [second[0]]
        with_first_design = [first[0]]

        for partner in first[1:]:
            if "X" in partner:
                print("a")
            with_second_design.append(partner)
        for partner in second[1:]:
            if "X" in partner:
                print("a")
            with_first_design.append(partner)

        crossed += [with_second_design, with_first_design]

    return crossed


In [8]:
import json
# Load the parsed structures (one JSON record per line) and tag each record
# with the chain to design ("P"), the visible chain ("R") and a common name.
data_dual = []
with open('./dual_SAR.jsonl', 'r', encoding='utf-8') as file:
    for line in file:
        # Skip blank lines (e.g. a trailing newline at the end of the file),
        # which would otherwise make json.loads raise.
        if not line.strip():
            continue
        # Parse the JSON record on this line
        data1 = json.loads(line)
        data1["masked_list"] = ["P"]
        data1["visible_list"] = ["R"]
        data1["name"] = "dual_design"
        data_dual.append( data1)

Dual-target protein design

In [None]:
def recycle_esm(seq1, mask, chain_M, S, lst_x, esm_encoder, alphabet):
    """Re-embed the current design with ESM for another sampling round.

    The first row of ``seq1`` is used as the designed sequence for every
    batch entry (one copy per row of ``mask``). It is decoded together with
    the other chains of each entry in ``lst_x`` and embedded by
    ``seq_esm_embed_func_re``, padded to ``S.shape[-1]``.

    ``chain_M`` is accepted but not used.

    Returns:
        The ESM embedding tensor produced by ``seq_esm_embed_func_re``.
    """
    # Same designed sequence (row 0) for every entry of the batch.
    shared_design = [seq1.cpu()[0] for _ in range(len(mask))]

    # Decode to strings and attach the remaining chains of each entry.
    chains_per_entry = seqs_get_signle(lst_x, shared_design)

    return seq_esm_embed_func_re(chains_per_entry, S.shape[-1], esm_encoder, alphabet)


def _scores(S, log_probs, mask):
    """ Calculate the negative log probabilities (NLLLoss) """
    criterion = torch.nn.NLLLoss(reduction='none')
    
    # Flatten log_probs and S for loss calculation
    loss = criterion(log_probs.contiguous().view(-1, log_probs.size(-1)),
                     S.contiguous().view(-1)).view(S.size())
    
    # Calculate scores by summing the loss weighted by the mask
    scores = torch.sum(loss * mask, dim=-1) / torch.sum(mask, dim=-1)
    
    return scores


def run_design_sequence(model, data_dual, esm_encoder, alphabet, device, temp_a=0.1, design_num=10, recycle_num=1, conserved_sites=[]):
    """Sample ``design_num`` sequences for chain "P" of a two-entry batch.

    The batch is featurized once. For every design the model samples a
    sequence, then ``recycle_num`` times the sampled sequence is re-embedded
    with ESM (``recycle_esm``) and the model samples again with the same
    noise tensor.

    Args:
        model: network exposing ``eval()`` and ``sample(...)``.
        data_dual: list of structure dicts. NOTE(review): the code hard-codes
            ``["P", "P"]``, ``B = 2`` and rows 0/1, so exactly two entries are
            assumed - confirm before using other batch sizes.
        esm_encoder: ESM encoder passed to the embedding helpers.
        alphabet: ESM alphabet passed to the embedding helpers.
        device: torch device for the featurized tensors and the noise.
        temp_a: sampling temperature forwarded to ``model.sample``.
        design_num: number of designs to generate.
        recycle_num: number of embed-and-resample rounds per design.
        conserved_sites: 0-based positions in the featurized rows that are set
            to 0 in ``chain_M_pos`` for rows 0 and 1. The default list is never
            mutated here.

    Returns:
        tuple ``(lst_x, lst_seq_pre, lst_acc, lst_score, lst_name)``:
          * ``lst_x`` - the entries of ``data_dual``, repeated once per design,
          * ``lst_seq_pre`` - the sampled index rows (on CPU), one per batch row per design,
          * ``lst_acc`` - per design, agreement between ``S[0]`` and ``S_pre[0]``
            over ``mask * chain_M`` (on CPU),
          * ``lst_score`` - per design, the ``_scores`` value (left on ``device``),
          * ``lst_name`` - ``"{name}_{ns}_{sdp}"`` for each entry and design.
    """
    model.eval()

    lst_x = []  # input entries, one per batch row per design
    lst_seq_pre = []  # sampled sequences (index tensors on CPU)
    lst_acc = []  # per-design agreement with the original sequence
    lst_score = []  # per-design score
    lst_name = []  # names matching lst_x
    ns = 0  # batch counter; only one batch is processed, so names always carry 0

    with torch.no_grad():
        # Featurize the whole batch once; chain "P" is the designed chain in both entries.
        X, S, mask, mask_train, lengths, chain_M, residue_idx, mask_self, chain_encoding_all, S_lst = \
            featurize(data_dual, ["P", "P"], device, is_train=False)

        # Initial ESM embedding: zeros for the designed chain, ESM features for the rest.
        S_embed = seq_esm_embed_func(S_lst, S.shape[-1], esm_encoder, alphabet)

        # Position mask handed to the sampler, started from the coordinate mask.
        # NOTE(review): this clones `mask`, not `chain_M` - confirm that is intended.
        chain_M_pos = mask.clone()
        
        # Zero the conserved sites in both batch rows.
        for site in conserved_sites:
            chain_M_pos[0][site] = 0
            chain_M_pos[1][site] = 0

        B = 2  # Batch size (hard-coded: two entries)
        for sdp in range(design_num):
            lst_x_r = []
            randn = torch.randn(chain_M.shape).to(device)  # Random noise tensor for sampling

            # First sampling pass with the initial embedding.
            S_pre, log_probs = model.sample(
                X, randn, S, S_embed, chain_M, chain_encoding_all, residue_idx, mask=mask,
                temperature=temp_a, chain_M_pos=chain_M_pos, mask_train=mask_train
            )

            # Record one input entry and one name per batch row for this design.
            for i in data_dual:
                lst_name.append(f"{i['name']}_{ns}_{sdp}")
                lst_x.append(i)
                lst_x_r.append(i)

            # Recycling: re-embed the current sample with ESM and sample again
            # (the same noise tensor `randn` is reused).
            for _ in range(recycle_num):
                S_embed_r = recycle_esm(S_pre, mask, chain_M, S, lst_x_r, esm_encoder, alphabet)
                S_pre, log_probs = model.sample(
                    X, randn, S, S_embed_r, chain_M, chain_encoding_all, residue_idx, mask=mask,
                    temperature=temp_a, chain_M_pos=chain_M_pos, mask_train=mask_train
                )

            # Store the sampled sequence of every batch row.
            for ss in range(len(mask)):
                lst_seq_pre.append(S_pre[ss].cpu())

            # Positions that count for score/accuracy: valid coordinates on the designed chain.
            mask_for_loss = mask * chain_M
            S1 = S.reshape(B // 2, 2, -1)[:, 1]  # original sequence of the second entry
            mask_for_loss1 = mask_for_loss.reshape(B // 2, 2, -1)[:, 1]

            # Score of row-0 log-probabilities against S1.
            # NOTE(review): the target is the original sequence S, not the sample S_pre - confirm.
            score_design = _scores(S1, log_probs[0][None, :, :], mask_for_loss1)
            lst_score.append(score_design[0])

            # Fraction of scored positions where the sample equals the original sequence (row 0).
            acc_design = torch.sum(
                torch.sum(torch.nn.functional.one_hot(S[0], 21) * torch.nn.functional.one_hot(S_pre[0], 21), axis=-1) * mask_for_loss[0]
            ) / torch.sum(mask_for_loss[0])
            lst_acc.append(acc_design.cpu())

        ns += 1  # Increment batch counter

    # Return the final results: lst_x, lst_seq_pre, lst_acc, lst_score, lst_name
    return lst_x, lst_seq_pre, lst_acc, lst_score, lst_name


# The call below relies on objects created in earlier cells:
# - model: the ProDualNet model with loaded weights
# - data_dual: the structure records loaded from ./dual_SAR.jsonl
# - esm_encoder: the ESM encoder model
# - alphabet: the ESM alphabet used for tokenisation
# - device: the computing device (CPU or GPU)
# Keyword arguments:
# - temp_a: temperature for sampling, default is 0.1
# - design_num: number of designs to generate
# - recycle_num: number of recycling iterations
# - conserved_sites: 0-based positions that are kept fixed during design
# Example usage:
conserved_sites = [0,1,2,3,4,5,6,7,8]
lst_x,lst_seq_pre, lst_acc, lst_score,lst_name = run_design_sequence(model, data_dual, esm_encoder, alphabet,\
                                                                     device, temp_a=0.1, design_num=1000, recycle_num=1,conserved_sites = conserved_sites)




Save the designed sequences

In [12]:
# Output FASTA file; it is opened for writing (and overwritten) by the cell below.
path_fasta = "design_sequence_SAR_1000.fasta"

In [None]:
def find_first_last_one(lst):
    """Return ``(first, last)`` positions of elements equal to 1 in ``lst``.

    Both values are ``-1`` when no element equals 1.

    NOTE: this re-defines the function of the same name from an earlier cell.
    """
    hits = [position for position, item in enumerate(lst) if item == 1]
    if not hits:
        return -1, -1
    return hits[0], hits[-1]
def indices_to_chars(indices):
    """Map integer-like residue indices to a string over 'ACDEFGHIKLMNPQRSTVWYX'.

    NOTE: this re-defines the function of the same name from an earlier cell.
    """
    alphabet = 'ACDEFGHIKLMNPQRSTVWYX'
    letters = []
    for token in indices:
        letters.append(alphabet[int(token)])
    return ''.join(letters)
def seqs_get_signle(lst_x,lst_seq_pre):
    """Build, per entry, the chain list ``[designed sequence, other chains...]``.

    For entry ``i`` the designed sequence is decoded from the first
    ``len(seq of the first masked chain)`` indices of ``lst_seq_pre[i]``;
    it is followed by every other ``seq_chain_*`` value of the entry in
    dictionary order.

    NOTE: this re-defines the function of the same name from an earlier cell.

    Args:
        lst_x: list of structure dicts (``'masked_list'`` plus ``'seq_chain_*'`` keys).
        lst_seq_pre: per entry, a sequence of predicted residue indices.

    Returns:
        list of lists of strings, one inner list per entry of ``lst_x``.
    """
    collected = []
    for idx, entry in enumerate(lst_x):
        design_key = "seq_chain_" + entry['masked_list'][0]
        design_len = len(entry[design_key])
        # Only the leading positions belong to the designed chain.
        designed = indices_to_chars(lst_seq_pre[idx][:design_len])
        partners = [
            entry[key]
            for key in entry
            if "seq_chain_" in key and key != design_key
        ]
        collected.append([designed] + partners)
    return collected

def seq_single_complex_cross(lst_seq_all):
    """Swap the first (designed) sequence between consecutive pairs of entries.

    Entries are taken two at a time; each output entry keeps its own partner
    chains but receives the first sequence of the other entry of the pair.
    A trailing unpaired entry is dropped. Prints "a" for every partner chain
    that contains an "X".

    NOTE: this re-defines the function of the same name from an earlier cell.
    """
    crossed = []
    for first, second in zip(lst_seq_all[0::2], lst_seq_all[1::2]):
        with_second_design = [second[0]]
        with_first_design = [first[0]]

        for partner in first[1:]:
            if "X" in partner:
                print("a")
            with_second_design.append(partner)
        for partner in second[1:]:
            if "X" in partner:
                print("a")
            with_first_design.append(partner)

        crossed += [with_second_design, with_first_design]

    return crossed

#def get_fasta_lst(lst):
# Decode every stored design into its chain strings, then write one FASTA
# record per name: chains are joined with ":" and every "X" is written as "A".
lst_pre_mpnn =seqs_get_signle(lst_x,lst_seq_pre)
with open(path_fasta, "w") as file:
    for i, record_name in enumerate(lst_name):
        file.write(f">{record_name} \n")
        chains = [s.replace("X", "A") for s in lst_pre_mpnn[i]]
        file.write(":".join(chains) + "\n")