In [1]:
# This function is based on ProteinMPNN/Pifold/esm, under the MIT License.
# Source: https://github.com/dauparas/ProteinMPNN, https://github.com/A4Bio/PiFold,https://github.com/facebookresearch/esm

In [2]:
import torch
import numpy as np 
import pandas as pd
import argparse
import os.path
import json, time, os, sys, glob
import shutil
import warnings
from torch import optim
from torch.utils.data import DataLoader
import queue
import copy
import torch.nn as nn
import torch.nn.functional as F
import random
import os.path
import subprocess
from concurrent.futures import ProcessPoolExecutor    
from utils import worker_init_fn, get_pdbs, loader_pdb, build_training_clusters, PDB_dataset, StructureDataset, StructureLoader
from model_utils import featurize, loss_smoothed, loss_nll, get_std_opt
from model_utils import *

  from .autonotebook import tqdm as notebook_tqdm


In [3]:

import esm 
from  esm.model.esm2 import ESM2, ESM2_decoder
# Notebook-wide compute device: the first CUDA GPU when available, otherwise the CPU.
device = torch.device("cuda:0" if (torch.cuda.is_available()) else "cpu")
def esm_model():
    """Build the ESM-2 (33 layer, 1280 dim) encoder and a matching decoder.

    Reads ``./esm2_t33_650M_UR50D.pt`` and
    ``./esm2_t33_650M_UR50D-contact-regression.pt`` from the working directory,
    merges the contact-regression weights into the main state dict, strips the
    checkpoint key prefixes and loads the result into an ``ESM2`` instance.
    An ``ESM2_decoder`` is then initialised from the embedding and ``lm_head``
    tensors of that same state dict.

    Returns:
        tuple: ``(model, decoder)``. Both are left on the CPU; the caller moves
        them to another device if needed.

    Raises:
        FileNotFoundError: if either checkpoint file is missing.
        RuntimeError: if the state-dict keys do not match (``strict=True``).
        KeyError: if a tensor needed by the decoder is absent.
    """
    import re

    def upgrade_state_dict(state_dict):
        """Strip a leading 'encoder.sentence_encoder.' or 'encoder.' from every key."""
        prefixes = ["encoder.sentence_encoder.", "encoder."]
        # The alternatives are grouped and escaped so that "^" anchors both of
        # them and "." only matches a literal dot. Without the group the second
        # prefix would be removed wherever it occurs inside a key.
        pattern = re.compile("^(?:" + "|".join(re.escape(p) for p in prefixes) + ")")
        return {pattern.sub("", name): param for name, param in state_dict.items()}

    # map_location="cpu" keeps loading independent of the device the checkpoint
    # was saved from, so the notebook also starts on CPU-only machines.
    regression_data = torch.load('./esm2_t33_650M_UR50D-contact-regression.pt', map_location="cpu")
    model_data = torch.load('./esm2_t33_650M_UR50D.pt', map_location="cpu")
    model_data["model"].update(regression_data["model"])

    alphabet = esm.data.Alphabet.from_architecture("ESM-1b")
    model = ESM2(
        num_layers=33,
        embed_dim=1280,
        attention_heads=20,
        alphabet=alphabet,
        token_dropout=True,
    )
    state_dict = upgrade_state_dict(model_data["model"])
    model.load_state_dict(state_dict, strict=True)

    decoder = ESM2_decoder(
        num_layers=33,
        embed_dim=1280,
        attention_heads=20,
        alphabet=alphabet,
        token_dropout=True,
    )
    # The decoder only owns the token embedding and the language-model head.
    decoder_keys = ['embed_tokens.weight','lm_head.weight', 'lm_head.bias', 'lm_head.dense.weight','lm_head.dense.bias', 'lm_head.layer_norm.weight', 'lm_head.layer_norm.bias']
    decoder_data = {key: state_dict[key] for key in decoder_keys}
    decoder.load_state_dict(decoder_data, strict=True)
    return model, decoder

# Build both ESM modules once for the whole notebook.
esm_encoder, esm_decoder = esm_model()
# Alphabet whose batch converter is used later to tokenize sequences for the encoder.
alphabet = esm.data.Alphabet.from_architecture("ESM-1b")
# Only the encoder is moved to `device` here; esm_decoder stays where it was created.
esm_encoder.to(device)
print("OK")

OK


In [4]:

def randomize_list(lst, ratio):
    """Return a copy of ``lst`` in which some positions hold random labels.

    Each position is independently replaced, with probability ``ratio``, by an
    integer drawn uniformly from 0..20 (both ends included). The input itself
    is left untouched.
    """
    noisy = lst.copy()
    position = 0
    total = len(noisy)
    while position < total:
        # One uniform draw per position; a second draw only when replacing.
        if random.random() < ratio:
            noisy[position] = random.randint(0, 20)
        position += 1
    return noisy

def set_nan(arr):
    """Overwrite a random 10% of the first-axis slices of ``arr`` with NaN.

    The array is modified in place and also returned. The number of affected
    slices is ``int(0.1 * arr.shape[0])``; they are picked without repetition.
    """
    n_rows = arr.shape[0]
    n_blank = int(0.1 * n_rows)
    blank_rows = np.random.choice(n_rows, n_blank, replace=False)
    arr[blank_rows, :, :] = np.nan
    return arr

def featurize(batch,lst_chain ,device = "cpu",is_train = True):
    """Pack a list of multi-chain structure dicts into padded tensors.

    For every entry, the chain named by the matching element of ``lst_chain`` is
    moved to the front of the chain order and receives a chain mask of 1; every
    other chain receives 0.

    NOTE(review): this definition replaces the ``featurize`` imported from
    ``model_utils`` at the top of the notebook.
    NOTE(review): ``b['masked_list']`` and ``b['visible_list']`` are edited in
    place (visible chains whose sequence equals a masked chain's sequence are
    moved to the masked list), so the caller's entries differ after the call.

    Args:
        batch: list of dicts. Keys read here: ``'seq'``, ``'masked_list'``,
            ``'visible_list'``, ``'num_of_chains'``, ``'seq_chain_<id>'`` and
            ``'coords_chain_<id>'`` (a dict with ``N_chain_<id>``,
            ``CA_chain_<id>``, ``C_chain_<id>`` and ``O_chain_<id>`` entries).
        lst_chain: one chain id per batch entry; that chain is placed first and
            marked for prediction.
        device: device the returned tensors are created on.
        is_train: not read by the active code (only by commented-out lines).

    Returns:
        tuple ``(X, S, mask, mask_rp, lengths, chain_M, residue_idx, mask_self,
        chain_encoding_all, lst_seq_str)`` where
        ``X`` is float ``[B, L_max, 4, 3]`` backbone coordinates (NaNs zeroed),
        ``S`` is long ``[B, L_max]`` residue indices into ``'ACDEFGHIKLMNPQRSTVWYX'``,
        ``mask`` is 1 where all four atoms of a residue are finite,
        ``mask_rp`` is 1 on real (non-padding) positions,
        ``lengths`` is a NumPy array of ``len(b['seq'])``,
        ``chain_M`` is 1 on the chain selected by ``lst_chain``,
        ``residue_idx`` holds positions with an extra offset of 100 per chain
        (padding stays at -100),
        ``mask_self`` is 0 inside each chain's own block and 1 elsewhere,
        ``chain_encoding_all`` numbers the chains 1, 2, ... (0 on padding), and
        ``lst_seq_str`` is, per entry, the list of chain sequences in the order used.
    """
    alphabet = 'ACDEFGHIKLMNPQRSTVWYX'
    # One class label per alphabet letter; only used to fill S_s, which is not returned.
    Clabel = [0,1,3,3,0,1,2,0,2,0,0,1,0,1,2,1,1,0,0,1,0]
    B = len(batch)
     
    lengths = np.array([len(b['seq']) for b in batch], dtype=np.int32) # length of b['seq'] (presumably all chains concatenated - TODO confirm)
    # NOTE(review): padding below uses the summed chain lengths; if that sum exceeds L_max, np.pad raises.
    L_max = max([len(b['seq']) for b in batch])
    X = np.zeros([B, L_max, 4, 3])
    residue_idx = -100*np.ones([B, L_max], dtype=np.int32) #residue idx with jumps across chains
    chain_M = np.zeros([B, L_max], dtype=np.int32) #1.0 for the bits that need to be predicted, 0.0 for the bits that are given
    mask_self = np.ones([B, L_max, L_max], dtype=np.int32) #for interface loss calculation - 0.0 for self interaction, 1.0 for other
    chain_encoding_all = np.zeros([B, L_max], dtype=np.int32) #integer encoding for chains 0, 0, 0,...0, 1, 1,..., 1, 2, 2, 2...
    S = np.zeros([B, L_max], dtype=np.int32) #sequence AAs integers
    S_noise = np.zeros([B, L_max], dtype=np.int32) #randomly perturbed copy of S; filled below but not returned
    S_s = np.zeros([B, L_max], dtype=np.int32) #Clabel class of every residue; filled below but not returned
    mask_rp = np.zeros([B, L_max], dtype=np.int32) #1 on real residues, 0 on padding
    # chain_letters is assembled here but not read anywhere below.
    init_alphabet = ['A', 'B', 'C', 'D', 'E', 'F', 'G','H', 'I', 'J','K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T','U', 'V','W','X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g','h', 'i', 'j','k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't','u', 'v','w','x', 'y', 'z']
    extra_alphabet = [str(item) for item in list(np.arange(300))]
    chain_letters = init_alphabet + extra_alphabet
    # ids walks lst_chain in step with the batch index i.
    ids = 0
    lst_seq_str = []
    for i, b in enumerate(batch):
        lst_seq_str_1 = []
        # These are the entry's own lists, not copies: the appends/removes below change the caller's data.
        masked_chains = b['masked_list']
        visible_chains = b['visible_list']
        all_chains = masked_chains + visible_chains
        visible_temp_dict = {}
        masked_temp_dict = {}
        for step, letter in enumerate(all_chains):
            chain_seq = b[f'seq_chain_{letter}']
            if letter in visible_chains:
                visible_temp_dict[letter] = chain_seq
            elif letter in masked_chains:
                masked_temp_dict[letter] = chain_seq
        # Any visible chain with exactly the same sequence as a masked chain is reclassified as masked.
        for km, vm in masked_temp_dict.items():
            for kv, vv in visible_temp_dict.items():
                if vm == vv:
                    if kv not in masked_chains:
                        masked_chains.append(kv)
                    if kv in visible_chains:
                        visible_chains.remove(kv)
        all_chains = masked_chains + visible_chains
        # Move the requested chain to position 0; raises ValueError if it is not among the entry's chains.
        index_of_a = all_chains.index(lst_chain[ids])
        all_chains.insert(0, all_chains.pop(index_of_a))
        
        
        #random.shuffle(all_chains) #randomly shuffle chain order
        num_chains = b['num_of_chains']
        mask_dict = {}
        x_chain_list = []
        chain_mask_list = []
        chain_seq_list = []
        chain_encoding_list = []
        c = 1   # 1-based chain counter
        l0 = 0  # start offset of the current chain in the concatenated sequence
        l1 = 0  # end offset (exclusive) of the current chain
        # The two branches below differ only in the chain mask value (0 vs 1).
        for step, letter in enumerate(all_chains):
            if letter != lst_chain[ids]:
                chain_seq = b[f'seq_chain_{letter}']
                
                lst_seq_str_1.append(chain_seq)
                chain_length = len(chain_seq)
                chain_coords = b[f'coords_chain_{letter}'] #this is a dictionary
                chain_mask = np.zeros(chain_length) #0.0 for context chains
                
                
                # The comprehension variable c is local to the comprehension and does not touch the chain counter.
                x_chain = np.stack([chain_coords[c] for c in [f'N_chain_{letter}', f'CA_chain_{letter}', f'C_chain_{letter}', f'O_chain_{letter}']], 1) #[chain_length,4,3]
                #if is_train:
                    #x_chain = set_nan(x_chain)
                
                x_chain_list.append(x_chain)
                chain_mask_list.append(chain_mask)
                chain_seq_list.append(chain_seq)
                chain_encoding_list.append(c*np.ones(np.array(chain_mask).shape[0]))
                l1 += chain_length
                mask_self[i, l0:l1, l0:l1] = np.zeros([chain_length, chain_length])
                residue_idx[i, l0:l1] = 100*(c-1)+np.arange(l0, l1)
                l0 += chain_length
                c+=1
            else: 
                
                chain_seq = b[f'seq_chain_{letter}']
                
                lst_seq_str_1.append(chain_seq)
                
                chain_length = len(chain_seq)
                chain_coords = b[f'coords_chain_{letter}'] #this is a dictionary
                chain_mask = np.ones(chain_length) #1.0 for the chain selected by lst_chain
                x_chain = np.stack([chain_coords[c] for c in [f'N_chain_{letter}', f'CA_chain_{letter}', f'C_chain_{letter}', f'O_chain_{letter}']], 1) #[chain_length,4,3]
                #if is_train:
                    #x_chain = set_nan(x_chain)
                
                x_chain_list.append(x_chain)
                chain_mask_list.append(chain_mask)
                chain_seq_list.append(chain_seq)
                chain_encoding_list.append(c*np.ones(np.array(chain_mask).shape[0]))
                l1 += chain_length
                mask_self[i, l0:l1, l0:l1] = np.zeros([chain_length, chain_length])
                residue_idx[i, l0:l1] = 100*(c-1)+np.arange(l0, l1)
                l0 += chain_length
                c+=1
        x = np.concatenate(x_chain_list,0) #[L, 4, 3]
        all_sequence = "".join(chain_seq_list)
        m = np.concatenate(chain_mask_list,0) #[L,], 1.0 for places that need to be predicted
        chain_encoding = np.concatenate(chain_encoding_list,0)

        l = len(all_sequence)
        
        
        # Letters outside the 21-symbol alphabet are mapped to "X" before indexing.
        all_sequence = list(all_sequence)
        for aas in range(len(all_sequence)):
            if all_sequence[aas] not in alphabet:
                all_sequence[aas] = "X"
        all_sequence = "".join(all_sequence)
        
        # Coordinates are padded with NaN so padding is later detected by the finiteness check.
        x_pad = np.pad(x, [[0,L_max-l], [0,0], [0,0]], 'constant', constant_values=(np.nan, ))
        X[i,:,:,:] = x_pad

        m_pad = np.pad(m, [[0,L_max-l]], 'constant', constant_values=(0.0, ))
        chain_M[i,:] = m_pad

        chain_encoding_pad = np.pad(chain_encoding, [[0,L_max-l]], 'constant', constant_values=(0.0, ))
        chain_encoding_all[i,:] = chain_encoding_pad

        # Convert to labels
        indices = np.asarray([alphabet.index(a) for a in all_sequence], dtype=np.int32)
        
        S_s_s = []
        for ids_aa in indices:
            S_s_s.append(Clabel[ids_aa])
        S[i, :l] = indices
        
        # Consumes values from the global `random` generator even though S_noise is not returned.
        S_noise[i, :l] = randomize_list(indices, 0.1)
        S_s[i, :l] = S_s_s
        
        mask_rp[i,:l] = np.ones([l], dtype=np.int32)
        
        lst_seq_str.append(lst_seq_str_1)
        
        ids+=1

    # A residue is valid only when all 12 coordinate values are finite; NaNs are then zeroed.
    isnan = np.isnan(X)
    mask = np.isfinite(np.sum(X,(2,3))).astype(np.float32)
    X[isnan] = 0.

        

    # Conversion
    mask = torch.from_numpy(mask).to(dtype=torch.float32, device=device)
    mask_rp = torch.from_numpy(mask_rp).to(dtype=torch.float32, device=device)
    
    residue_idx = torch.from_numpy(residue_idx).to(dtype=torch.long,device=device)
    S = torch.from_numpy(S).to(dtype=torch.long,device=device)
    S_noise = torch.from_numpy(S_noise).to(dtype=torch.long,device=device)
    S_s = torch.from_numpy(S_s).to(dtype=torch.long,device=device)
    X = torch.from_numpy(X).to(dtype=torch.float32, device=device)
    
    mask_self = torch.from_numpy(mask_self).to(dtype=torch.float32, device=device)
    chain_M = torch.from_numpy(chain_M).to(dtype=torch.float32, device=device)
    chain_encoding_all = torch.from_numpy(chain_encoding_all).to(dtype=torch.long, device=device)
    return X, S, mask,mask_rp, lengths, chain_M, residue_idx, mask_self, chain_encoding_all,lst_seq_str

In [5]:
from produalnet_main import ProDualNet

In [6]:
    import argparse
    import os.path

    import json, time, os, sys, glob
    import shutil
    import warnings
    import numpy as np
    import torch
    from torch import optim
    from torch.utils.data import DataLoader
    import queue
    import copy
    import torch.nn as nn
    import torch.nn.functional as F
    import random
    import os.path
    import subprocess
    from concurrent.futures import ProcessPoolExecutor    

     
    
    
    
    # Checkpoint to evaluate; an alternative file name is kept in the trailing comment.
    PATH = "./esm_test128/model_weights/best_esm.pt"#"./produalnet_esm.pt"
    
    # Hyper-parameters must match the ones the checkpoint was trained with.
    model = ProDualNet(node_features=128, 
                        edge_features=128, 
                        hidden_dim=128, 
                        num_encoder_layers=4, 
                        num_decoder_layers=4, 
                        k_neighbors=32, 
                        dropout=0.1, 
                        augment_eps=0.2)
    model.to(device)


    # PATH is a non-empty literal above, so the else branch only runs if it is edited to "" or None.
    if PATH:
        
        checkpoint = torch.load(PATH, map_location=device)
        total_step = checkpoint['step'] #write total_step from the checkpoint
        epoch = checkpoint['epoch'] #write epoch from the checkpoint
        # NOTE(review): strict=False ignores missing and unexpected keys without reporting them.
        model.load_state_dict(checkpoint['model_state_dict'], strict=False)
        print("load ok----------")
    else:
        total_step = 0
        epoch = 0

    #optimizer = torch.optim.Adam(model.parameters(), lr=0.00005, betas=(0.9, 0.98), eps=1e-9)#




load ok----------


In [7]:
import random
class StructureLoader():
    """Iterate over paired dataset entries, one anchor entry per batch.

    ``filtered_dict`` maps an anchor key of ``dataset`` to a dict whose keys are
    partner keys of ``dataset`` and whose values hold two chain-id sequences
    (anchor side first, partner side second). Anchors are ordered by the length
    of their ``'seq'`` and each one forms its own cluster.

    Iterating yields ``(entries, chain_ids)``: the anchor entry followed by its
    first partner entry, and the first chain id of each side. Both lists are
    empty when the pair does not fit the size budget ``batch_size * 15``.

    ``shuffle``, ``collate_fn`` and ``drop_last`` are accepted but not used.
    """

    def __init__(self, dataset, filtered_dict, batch_size=5000, shuffle=True,
        collate_fn=lambda x:x, drop_last=False):
        self.dataset = dataset
        self.filtered_dict = filtered_dict
        self.batch_size = batch_size
        self.size = list(filtered_dict.keys())
        self.lengths = [len(dataset[key]['seq']) for key in self.size]
        # Shortest anchors first; every anchor becomes a single-item cluster.
        self.clusters = [[self.size[pos]] for pos in np.argsort(self.lengths)]

    def __len__(self):
        return len(self.clusters)

    def __iter__(self):
        for cluster in self.clusters:
            entries, chain_ids = [], []
            n_pairs, longest = 0, 0
            for anchor_key in cluster:
                partners = self.filtered_dict[anchor_key]
                n_partners = len(partners)
                partner_keys = list(partners.keys())
                # The draw is discarded (the first partner is always taken) but it
                # still advances the RNG and raises for an empty partner dict.
                random.randint(0, n_partners - 1)
                partner_key = partner_keys[0]
                candidate = max(
                    longest,
                    len(self.dataset[anchor_key]["seq"]),
                    len(self.dataset[partner_key]["seq"]),
                )
                if (n_pairs + 1) * candidate < self.batch_size * 15:
                    n_pairs += 1
                    longest = candidate
                    entries.append(self.dataset[anchor_key])
                    entries.append(self.dataset[partner_key])
                    chain_pair = partners[partner_key]
                    chain_ids.append(chain_pair[0][0])
                    chain_ids.append(chain_pair[1][0])
            yield entries, chain_ids

In [8]:
# Structures to evaluate. torch.load unpickles the file, so only load files from a trusted source.
x_test = torch.load("./x_test_multi.pt")
# Which pairing file to evaluate: "test1", "test2" or "test3".
data_set = "test1"
if data_set == "test1":
    dict_x_test = torch.load("./dict_x_test_30_159.pt")
    StructureLoader_test = StructureLoader(x_test,dict_x_test)
elif data_set == "test2":
    dict_x_test = torch.load("./dict_x_test_sim_50_rmsd_2.pt")
    StructureLoader_test = StructureLoader(x_test,dict_x_test)
elif data_set == "test3":
    # The loaded object is used directly in place of a StructureLoader;
    # presumably it is a pre-built sequence of (entries, chain_ids) pairs - TODO confirm.
    StructureLoader_test = torch.load("./lst_diff_inter_38_data.pt")

In [9]:

def seq_esm_embed_func(lst_seq,L,esm_encoder,alphabet):
    """Embed the context chains of every entry with ESM, leaving the first chain blank.

    Each item of ``lst_seq`` is a list of chain sequences. The first chain gets
    an all-zero block, the remaining chains are concatenated and passed through
    ``esm_encoder`` (layer-33 representation without the first and last token),
    and the result is zero-padded to ``L`` rows.

    Returns a tensor of shape ``[len(lst_seq), L, 1280]`` on the notebook-level
    ``device``.
    """
    esm_encoder.eval()
    per_entry = []
    for chains in lst_seq:
        blank_len = len(chains[0])
        context = "".join(chains[1:])
        used = blank_len + len(context)
        # designing sequences embedding is zeros
        pieces = [torch.zeros(blank_len, 1280)]
        tokens = alphabet.get_batch_converter()([[1, context]])[-1].to(device)
        with torch.no_grad():
            layer_out = esm_encoder(tokens, [33])["representations"][33]
            pieces.append(layer_out[0, 1:-1].cpu())
        pieces.append(torch.zeros(int(L - used), 1280))
        per_entry.append(torch.cat(pieces, 0))

    return torch.stack(per_entry).to(device)

def seq_esm_embed_func_re(lst_seq,L,esm_encoder,alphabet):
    """Embed all chains of every entry with ESM, including the first one.

    Each item of ``lst_seq`` is a list of chain sequences. They are concatenated
    in order, passed through ``esm_encoder`` (layer-33 representation without
    the first and last token) and zero-padded to ``L`` rows.

    Returns a tensor of shape ``[len(lst_seq), L, 1280]`` on the notebook-level
    ``device``.
    """
    esm_encoder.eval()
    per_entry = []
    for chains in lst_seq:
        used = 0
        joined = ""
        for chain in chains:
            used += len(chain)
            joined = joined + chain
        tokens = alphabet.get_batch_converter()([[1, joined]])[-1].to(device)
        pieces = []
        with torch.no_grad():
            layer_out = esm_encoder(tokens, [33])["representations"][33]
            pieces.append(layer_out[0, 1:-1].cpu())
        pieces.append(torch.zeros(int(L - used), 1280))
        per_entry.append(torch.cat(pieces, 0))

    return torch.stack(per_entry).to(device)

def find_first_last_one(lst):
    """Return ``(first, last)`` positions of the value 1 in ``lst``.

    Both positions are -1 when no element equals 1.
    """
    hits = [pos for pos, item in enumerate(lst) if item == 1]
    if not hits:
        return -1, -1
    return hits[0], hits[-1]
def indices_to_chars(indices):
    """Translate residue indices into a string over 'ACDEFGHIKLMNPQRSTVWYX'.

    Every index is converted with ``int()`` before the lookup.
    """
    letters = 'ACDEFGHIKLMNPQRSTVWYX'
    decoded = []
    for index in indices:
        decoded.append(letters[int(index)])
    return ''.join(decoded)
def seqs_get_signle(lst_x,lst_seq_pre):
    length_lst = len(lst_x)
    lst_seq_all = []

    for i in range(length_lst):
        lst_seq_1 = []
        masked_list = lst_x[i]['masked_list']
        masked_seq_name = "seq_chain_"+masked_list[0]
        masked_seq = lst_x[i][masked_seq_name]
        length_seq = len(masked_seq)
        #a,b = find_first_last_one(lst_mask_chain[i])
        seq_pre_AA = lst_seq_pre[i][:length_seq]
        ######
        #seq_pre_str = indices_to_chars(seq_pre_AA[a:b+1])
        seq_pre_str = indices_to_chars(seq_pre_AA)
        ######
        lst_seq_1.append(seq_pre_str)


        lst_seq_keys_1 = []
        for keys_1 in lst_x[i]:
            if "seq_chain_" in keys_1 and keys_1 != masked_seq_name:
                lst_seq_keys_1.append(keys_1)

        for keys_seq in lst_seq_keys_1:
            lst_seq_1.append(lst_x[i][keys_seq])

        lst_seq_all.append(lst_seq_1)

    return lst_seq_all

def seq_single_complex_cross(lst_seq_all):
    """Swap the first chain between the two members of every consecutive pair.

    Items ``2k`` and ``2k+1`` of ``lst_seq_all`` form a pair. The result holds,
    for each pair, the first member's remaining chains headed by the second
    member's first chain, then the reverse. A trailing unpaired item is
    dropped. ``"a"`` is printed for every remaining chain that contains "X".
    """
    crossed = []
    for start in range(0, 2 * (len(lst_seq_all) // 2), 2):
        first = lst_seq_all[start]
        second = lst_seq_all[start + 1]
        swapped = ([second[0]], [first[0]])
        for target, remaining in zip(swapped, (first[1:], second[1:])):
            for chain in remaining:
                if "X" in chain:
                    print("a")
                target.append(chain)
        crossed.extend(swapped)
    return crossed


#Unconditional sequence prediction without context, temperature

In [10]:
# Evaluate the model on every batch of StructureLoader_test and accumulate
# masked NLL / accuracy over the positions selected by mask * chain_M.
model.eval()
lst_acc = []   # one accuracy value per row of the reshaped loss mask
lst_name = []  # name of every entry seen, in iteration order
recycle_num = 1  # number of re-embedding passes after the first forward pass
with torch.no_grad():
    validation_sum, validation_weights = 0., 0.
    validation_acc = 0.
    lst_x_all = []        # every entry seen; read again by the FASTA export cell
    lst_seq_pre_all = []  # final predicted index sequence per entry; also read by the export cell
    
    # Each item is (entries, chain_ids) as produced by the loader.
    for _, batch in enumerate(StructureLoader_test):
        lst_x = []
        X, S, mask, mask_train, lengths, chain_M, residue_idx, mask_self, chain_encoding_all, S_lst = featurize(
            batch[0], batch[1], device, is_train=False)
        
        # Collect names and sequences
        for i in batch[0]:
            lst_name.append(i["name"])
            lst_x.append(i)
            lst_x_all.append(i)
        
        # First pass: ESM embedding of the context chains only (first chain is zeros).
        S_embed = seq_esm_embed_func(S_lst, S.shape[-1], esm_encoder, alphabet)

        B = S.shape[0]
        
        # Forward pass to get log probabilities.
        # The indexing log_probs[i // 2] below suggests one output row per pair of entries - TODO confirm.
        log_probs = model(X, S, S_embed, mask, mask_train, chain_M, residue_idx, chain_encoding_all, is_eval=True)
        mask_for_loss = mask * chain_M

        

        # Recycled predictions for a set number of iterations
        for _ in range(recycle_num):
            lst_seq_pre = []
            for i in range(len(mask)):
                # Both members of a pair use the even member's mask and sequence.
                ks = i if i % 2 == 0 else i - 1
                # Predicted index where mask * chain_M is 1, original index elsewhere.
                seq1 = torch.argmax(log_probs[int(i // 2)], -1).cpu() * (mask[ks] * chain_M[ks]).cpu() + \
                       (torch.ones_like((mask[ks] * chain_M[ks]).cpu()) - (mask[ks] * chain_M[ks]).cpu()) * S[ks].cpu()
                lst_seq_pre.append(seq1)
            
            # Turn the predictions into chain strings, re-embed all chains with ESM and run the model again.
            lst_pre_mpnn = seqs_get_signle(lst_x, lst_seq_pre)
            S_embed = seq_esm_embed_func_re(lst_pre_mpnn, S.shape[-1], esm_encoder, alphabet)
            log_probs = model(X, S, S_embed, mask, mask_train, chain_M, residue_idx, chain_encoding_all, is_eval=True)

        # Collect the predictions from the last model call (after recycling).
        for i in range(len(mask)):
            ks = i if i % 2 == 0 else i - 1
            seq1 = torch.argmax(log_probs[int(i // 2)], -1).cpu() * (mask[ks] * chain_M[ks]).cpu() + \
                   (torch.ones_like((mask[ks] * chain_M[ks]).cpu()) - (mask[ks] * chain_M[ks]).cpu()) * S[ks].cpu()
            lst_seq_pre_all.append(seq1)
        
        # Keep the second member of every pair as the target for the loss.
        S = S.reshape(B // 2, 2, -1)[:, 1]
        mask_for_loss = mask_for_loss.reshape(B // 2, 2, -1)[:, 1]
        
        loss, loss_av, true_false = loss_nll(S, log_probs, mask_for_loss)
        
        # Update accuracy and validation metrics.
        # A row whose mask sums to 0 yields NaN in lst_acc.
        lst_acc += list(torch.sum(true_false * mask_for_loss, -1).cpu().data.numpy() / 
                        torch.sum(mask_for_loss, -1).cpu().data.numpy())
        validation_sum += torch.sum(loss * mask_for_loss).cpu().data.numpy()
        validation_acc += torch.sum(true_false * mask_for_loss).cpu().data.numpy()
        validation_weights += torch.sum(mask_for_loss).cpu().data.numpy()

# Calculate final metrics (pooled over all masked positions)
validation_loss = validation_sum / validation_weights
validation_accuracy = validation_acc / validation_weights
validation_perplexity = np.exp(validation_loss)

# Format results for output
validation_perplexity_ = np.format_float_positional(np.float32(validation_perplexity), unique=False, precision=3)
validation_accuracy_ = np.format_float_positional(np.float32(validation_accuracy), unique=False, precision=3)




In [12]:
# Number of evaluated entries, pooled accuracy, perplexity, and mean of the per-row accuracies.
print(len(lst_name),validation_accuracy_,validation_perplexity_,np.mean(lst_acc))

318 0.579 3.913 0.5810868


#Evaluation interface,Unconditional sequence prediction without context, temperature

In [51]:
import numpy as np
import torch

def seqs_get_single(lst_x, lst_seq_pre):
    """
    Processes sequences from lst_x and lst_seq_pre and returns a list of all sequences.
    """
    lst_seq_all = []

    for entry in lst_x:
        seq_data = []
        masked_list = entry['masked_list']
        masked_seq_name = f"seq_chain_{masked_list[0]}"
        masked_seq = entry[masked_seq_name]
        length_seq = len(masked_seq)

        # Process predicted sequence
        seq_pre_AA = lst_seq_pre[len(seq_data)][:length_seq]
        seq_pre_str = indices_to_chars(seq_pre_AA)
        seq_data.append(seq_pre_str)

        # Append other sequences related to the chain
        lst_seq_keys = [key for key in entry if "seq_chain_" in key and key != masked_seq_name]
        seq_data.extend([entry[key] for key in lst_seq_keys])

        lst_seq_all.append(seq_data)

    return lst_seq_all


def calculate_distance(p1, p2):
    """Return the straight-line (L2) distance separating ``p1`` and ``p2``."""
    offset = p1 - p2
    squared_length = np.sum(offset * offset)
    return np.sqrt(squared_length)


def find_close_points(x1, x2, threshold=10):
    """Return the indices of points in ``x1`` that lie near any point of ``x2``.

    A point of ``x1`` is reported when its Euclidean distance to at least one
    point of ``x2`` is strictly below ``threshold``. Points with a NaN
    coordinate never match, on either side.

    All pairwise distances are computed in one broadcast NumPy operation
    instead of a Python double loop; this allocates an array of
    ``len(x1) * len(x2) * dim`` floats.

    Args:
        x1: sequence of points (e.g. ``[[x, y, z], ...]``) whose indices are returned.
        x2: sequence of points to compare against.
        threshold: distance cut-off (exclusive).

    Returns:
        list of int: ascending indices into ``x1``.
    """
    x1 = np.asarray(x1, dtype=float)
    x2 = np.asarray(x2, dtype=float)
    # With no points on one side there is nothing to compare.
    if x1.size == 0 or x2.size == 0:
        return []

    # Treat each first-axis item as one point, also for 1-D inputs.
    x1 = x1.reshape(x1.shape[0], -1)
    x2 = x2.reshape(x2.shape[0], -1)

    offsets = x1[:, None, :] - x2[None, :, :]
    distances = np.sqrt(np.sum(offsets ** 2, axis=-1))
    # NaN distances compare as False, which reproduces the "skip NaN points" rule.
    with np.errstate(invalid="ignore"):
        is_close = distances < threshold
    return np.flatnonzero(is_close.any(axis=1)).tolist()


def interface_point(x1):
    """Collect the designed-chain residue indices that are near any visible chain.

    The designed chain is ``x1["masked_list"][0]``. Its CA coordinates are
    compared with the CA coordinates of every chain in ``x1["visible_list"]``
    through ``find_close_points`` (default threshold). The returned list holds
    each matching index once, in no particular order.
    """
    designed = x1["masked_list"][0]
    hits = [
        index
        for partner in x1["visible_list"]
        for index in find_close_points(
            x1[f"coords_chain_{designed}"]["CA_chain_" + designed],
            x1[f"coords_chain_{partner}"]["CA_chain_" + partner],
        )
    ]
    return list(set(hits))


# Main processing loop
# Main processing loop: same evaluation as the unconditional cell, but the
# loss/accuracy mask is restricted to residues returned by interface_point.
model.eval()
lst_acc = []
lst_name = []
lst_x = []
recycle_num = 1

with torch.no_grad():
    validation_sum, validation_weights = 0., 0.
    validation_acc = 0.

    for _, batch in enumerate(StructureLoader_test):
        lst_x = []
        lst_seq_pre = []

        # Featurize the batch
        X, S, mask, mask_train, lengths, chain_M, residue_idx, mask_self, chain_encoding_all, S_lst = featurize(batch[0], batch[1], device, is_train=False)
        B = S.shape[0]

        # Process batch data
        for i in batch[0]:
            lst_name.append(i["name"])
            lst_x.append(i)

        # Get sequence embeddings (context chains only) and run the first forward pass
        S_embed = seq_esm_embed_func(S_lst, S.shape[-1], esm_encoder, alphabet)
        log_probs = model(X, S, S_embed, mask, mask_train, chain_M, residue_idx, chain_encoding_all, is_eval=True)
        mask_for_loss = mask * chain_M

        # Recycle predictions
        for _ in range(recycle_num):
            lst_seq_pre = []

            for i in range(len(mask)):
                # Both members of a pair use the even member's mask and sequence.
                ks = i if i % 2 == 0 else i - 1
                seq1 = torch.argmax(log_probs[int(i // 2)], -1).cpu() * (mask[ks] * chain_M[ks]).cpu() + \
                       (torch.ones_like((mask[ks] * chain_M[ks]).cpu()) - (mask[ks] * chain_M[ks]).cpu()) * S[ks].cpu()
                lst_seq_pre.append(seq1)

            # Turn the predictions into chain strings, re-embed all chains with ESM and run the model again.
            lst_pre_mpnn = seqs_get_single(lst_x, lst_seq_pre)
            S_embed = seq_esm_embed_func_re(lst_pre_mpnn, S.shape[-1], esm_encoder, alphabet)
            log_probs = model(X, S, S_embed, mask, mask_train, chain_M, residue_idx, chain_encoding_all, is_eval=True)

        # Process interfaces: for each pair, take the union of both members'
        # interface indices and use it as the chain mask of both members.
        for i in range(len(batch[0])):
            if i % 2 == 0:
                lst_interface = []

            #lst_name.append(batch[0][i]["name"])
            # NOTE(review): every entry was already appended to lst_x above, so lst_x
            # ends up holding each entry twice; it is not read again in this iteration.
            lst_x.append(batch[0][i])
            lst_interface.extend(interface_point(batch[0][i]))

            if i % 2 != 0:
                mask_z = torch.zeros_like(chain_M[i])
                # NOTE(review): the indices are positions inside the chain masked_list[0];
                # using them as positions in the padded row assumes that chain comes first - TODO confirm.
                for ints in lst_interface:
                    mask_z[ints] = 1
                chain_M[i] = mask_z
                chain_M[i - 1] = mask_z

        # Recompute the loss mask from the interface-only chain mask; keep the second member of each pair.
        mask_for_loss = mask * chain_M
        B = S.shape[0]
        S = S.reshape(B // 2, 2, -1)[:, 1]
        mask_for_loss = mask_for_loss.reshape(B // 2, 2, -1)[:, 1]

        # Calculate loss and accuracy.
        # A pair with no interface residues yields NaN in lst_acc.
        loss, loss_av, true_false = loss_nll(S, log_probs, mask_for_loss)
        lst_acc.extend(torch.sum(true_false * mask_for_loss, -1).cpu().data.numpy() / torch.sum(mask_for_loss, -1).cpu().data.numpy())

        validation_sum += torch.sum(loss * mask_for_loss).cpu().data.numpy()
        validation_acc += torch.sum(true_false * mask_for_loss).cpu().data.numpy()
        validation_weights += torch.sum(mask_for_loss).cpu().data.numpy()

    # Final validation metrics (pooled over all interface positions)
    validation_loss = validation_sum / validation_weights
    validation_accuracy = validation_acc / validation_weights
    validation_perplexity = np.exp(validation_loss)

    validation_perplexity_ = np.format_float_positional(np.float32(validation_perplexity), unique=False, precision=3)
    validation_accuracy_ = np.format_float_positional(np.float32(validation_accuracy), unique=False, precision=3)


In [52]:
# Number of evaluated entries, pooled interface accuracy, perplexity, and mean of the per-row accuracies.
print(len(lst_name),validation_accuracy_,validation_perplexity_,np.mean(lst_acc))

(104, '0.605', '3.831', 0.6071687)

Save the unconditional sequence predictions (without context, temperature) to a FASTA file

In [16]:
def find_first_last_one(lst):
    """Return ``(first, last)`` positions of the value 1 in ``lst``.

    Both positions are -1 when no element equals 1. This repeats the helper of
    the same name defined in an earlier cell of the notebook.
    """
    hits = [pos for pos, item in enumerate(lst) if item == 1]
    if not hits:
        return -1, -1
    return hits[0], hits[-1]
def indices_to_chars(indices):
    """Translate residue indices into a string over 'ACDEFGHIKLMNPQRSTVWYX'.

    Every index is converted with ``int()`` before the lookup. This repeats the
    helper of the same name defined in an earlier cell of the notebook.
    """
    letters = 'ACDEFGHIKLMNPQRSTVWYX'
    decoded = []
    for index in indices:
        decoded.append(letters[int(index)])
    return ''.join(decoded)
def seqs_get_signle(lst_x,lst_seq_pre):
    length_lst = len(lst_x)
    lst_seq_all = []

    for i in range(length_lst):
        lst_seq_1 = []
        masked_list = lst_x[i]['masked_list']
        masked_seq_name = "seq_chain_"+masked_list[0]
        masked_seq = lst_x[i][masked_seq_name]
        length_seq = len(masked_seq)
        #a,b = find_first_last_one(lst_mask_chain[i])
        seq_pre_AA = lst_seq_pre[i][:length_seq]
        ######
        #seq_pre_str = indices_to_chars(seq_pre_AA[a:b+1])
        seq_pre_str = indices_to_chars(seq_pre_AA)
        ######
        lst_seq_1.append(seq_pre_str)


        lst_seq_keys_1 = []
        for keys_1 in lst_x[i]:
            if "seq_chain_" in keys_1 and keys_1 != masked_seq_name:
                lst_seq_keys_1.append(keys_1)

        for keys_seq in lst_seq_keys_1:
            lst_seq_1.append(lst_x[i][keys_seq])

        lst_seq_all.append(lst_seq_1)

    return lst_seq_all

def seq_single_complex_cross(lst_seq_all):
    """Swap the first chain between the two members of every consecutive pair.

    Items ``2k`` and ``2k+1`` of ``lst_seq_all`` form a pair. The result holds,
    for each pair, the first member's remaining chains headed by the second
    member's first chain, then the reverse. A trailing unpaired item is
    dropped. ``"a"`` is printed for every remaining chain that contains "X".
    This repeats the helper of the same name defined in an earlier cell.
    """
    crossed = []
    for start in range(0, 2 * (len(lst_seq_all) // 2), 2):
        first = lst_seq_all[start]
        second = lst_seq_all[start + 1]
        swapped = ([second[0]], [first[0]])
        for target, remaining in zip(swapped, (first[1:], second[1:])):
            for chain in remaining:
                if "X" in chain:
                    print("a")
                target.append(chain)
        crossed.extend(swapped)
    return crossed

#def get_fasta_lst(lst):
# Decode the stored predictions and write one FASTA record per entry:
# a ">name " header line, then the entry's chains joined by ":" with every
# "X" replaced by "A".
lst_pre_mpnn = seqs_get_signle(lst_x_all, lst_seq_pre_all)
with open("esm_j_r1_159_uncondition_test.fasta", "w") as file:
    for record_idx, record_name in enumerate(lst_name):
        file.write(f">{record_name} \n")
        cleaned_chains = [chain.replace("X", "A") for chain in lst_pre_mpnn[record_idx]]
        file.write(":".join(cleaned_chains) + "\n")