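## This notebook streamlines the code of **generating_sequences.ipynb**

It loads the superfamily test split, picks one representative domain per leading "x.y" prefix of the superfamily labels, samples sequences for each representative from the ArDCA, ESM-IF1 and Potts models, and pickles the selected samples.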

In [1]:
import numpy as np
import os, sys
#

sys.path.insert(1, "./../util/")
sys.path.insert(1, "./../model/")
from encoded_protein_dataset_new import get_embedding, EncodedProteinDataset_new, EncodedProteinDataset_aux
from dynamic_loader import dynamic_collate_fn, dynamic_cluster

from pseudolikelihood import get_npll2, get_npll3
import torch, torchvision
from potts_decoder import PottsDecoder
from test_utils import load_model, get_samples_potts, compute_distances, select_sequences
from esm_utils import get_samples_esm

from torch.utils.data import DataLoader, RandomSampler
from functools import partial

from typing import Sequence, Tuple, List
import scipy
from tqdm import tqdm
import pandas as pd
import csv
import time


sys.path.insert(1, "./../../esm/")
import esm.pretrained as pretrained
from torch.nn.utils.rnn import pad_sequence
from collections import defaultdict
from Bio import SeqIO
from dynamic_loader import dynamic_collate_fn, dynamic_cluster
from torch.nn.functional import one_hot

from esm_utils import load_structure, extract_coords_from_structure, get_atom_coords_residuewise
from esm_utils import sample_esm_batch2
from esm_utils import align_esm


import pyhmmer

### First, load the test dataset (superfamily split)

In [6]:
# ---- Paths and dataset -------------------------------------------------------
max_msas = 2
msa_dir = "./../../split2/"
encoding_dir = "./../../structure_encodings/"

# This run uses the *superfamily* test split; the variable keeps the generic
# name `test_dataset` for convenience.
test_dataset = EncodedProteinDataset_aux(
    os.path.join(msa_dir, 'test/superfamily'),
    encoding_dir,
    noise=0.0,
    max_msas=max_msas,
)

# ---- Batching parameters -----------------------------------------------------
batch_structure_size_train = 1
batch_structure_size = 1
perc_subset_test = 1.0
batch_msa_size = 128
q = 21  # presumably 20 amino acids + gap — TODO confirm

collate_fn = partial(
    dynamic_collate_fn,
    q=q,
    batch_size=batch_structure_size,
    batch_msa_size=batch_msa_size,
)
# Unshuffled, so the loader yields dataset items in index order.
test_loader = DataLoader(
    test_dataset,
    batch_size=batch_structure_size,
    collate_fn=collate_fn,
    shuffle=False,
    num_workers=1,
    pin_memory=True,
)



Counter is:2 length data:1

  encodings = torch.tensor(read_encodings(encoding_path, trim=False))


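## Then select a set of representatives from the test dataset

One domain is chosen per leading "x.y" prefix of the superfamily labels, so that the representatives are not all drawn from the same superfamily.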

In [7]:
# Read the (tab-separated) test-split summary table and restrict it to the
# domains loaded in `test_dataset`, then to the superfamily test type.
summary_path = os.path.join(msa_dir, "./test/summary.txt")
summary_df = pd.read_csv(summary_path, sep="\t")
# Domain id = characters [-14:-7] of each MSA path (assumes a fixed-width
# filename suffix — TODO confirm against the dataset's naming scheme).
ids = [msa_path[-14:-7] for msa_path in test_dataset.msas_paths]
filtered_df = summary_df.loc[summary_df['# domain'].isin(ids)]
sf = filtered_df.loc[filtered_df["test_type"] == "superfamily"]

In [8]:
# Group the dotted superfamily labels by their leading "x.y" prefix, collecting
# the third component of every label (duplicates kept, in encounter order).
grouped_dict = {}
for s in sf['superfamily']:
    parts = s.split('.')
    key = f'{parts[0]}.{parts[1]}'
    grouped_dict.setdefault(key, []).append(parts[2])

In [9]:
# Number of distinct "x.y" groups among the superfamily labels.
len(grouped_dict)

2

In [10]:
from collections import Counter

# For each "x.y" group, count how often each third label component occurs.
# grouped_dict_2[key] = [counts, values]: two parallel lists in first-seen
# order (Counter preserves insertion order), consumed by the next cell.
grouped_dict_2 = {}
for key, sequence in grouped_dict.items():
    counts = Counter(sequence)
    grouped_dict_2[key] = [list(counts.values()), list(counts.keys())]
    

In [11]:
# Pick one representative domain per "x.y" group: take the most frequent third
# component (ties -> first seen), then the first superfamily row whose label
# starts with exactly those three components.
repr_set1 = []  # representative domain ids (values of sf['# domain'])
repr_set = []   # their indices into test_dataset.msas_paths
superfamily_labels = list(sf['superfamily'].values)  # hoisted: loop-invariant
domain_ids = sf['# domain'].values
for key, (counts, values) in grouped_dict_2.items():
    fold = values[int(np.argmax(counts))]
    fold_hom = key + "." + fold
    target = fold_hom.split(".")
    # Compare whole dot-separated components: a plain substring test would also
    # match e.g. "1.10.8" inside "1.10.80.10" and pick the wrong topology.
    result = next((i for i, s in enumerate(superfamily_labels) if s.split(".")[:3] == target), None)
    domain_id = domain_ids[result]
    repr_set1.append(domain_id)
    # Same [-14:-7] slice that produced `ids`, so this is an exact id match
    # rather than a substring search over the whole path.
    result2 = next((i for i, path in enumerate(test_dataset.msas_paths) if path[-14:-7] == domain_id), None)
    repr_set.append(result2)
    

## Generate samples from the different models (ArDCA, ESM-IF1, Potts)

In [12]:
### repr_set (previous cell) is chosen so that the test dataset under analysis is
### properly represented: we don't want, for instance, to always select members
### from the same superfamily.
import warnings
warnings.filterwarnings("ignore")

res_full_esm = {}
res_full_potts = {}
res_full_ardca = {}
pdb_dir = './../../dompdb/'

# Shape of test_dataset[idx][0] for each representative, read as `M, N = ...shape`
# below (presumably M = number of MSA sequences, N = length — TODO confirm).
Ns = torch.zeros(len(repr_set))
Ms = torch.zeros(len(repr_set))

############################## FOR ESM ######################################
# 21-letter alphabet; any character not in it maps to the gap index.
alphabet = 'ACDEFGHIKLMNPQRSTVWY-'
default_index = alphabet.index('-')
aa_index = defaultdict(lambda: default_index, {alphabet[i]: i for i in range(len(alphabet))})
aa_index_inv = dict(map(reversed, aa_index.items()))
model_esm, alphabet_esm = pretrained.esm_if1_gvp4_t16_142M_UR50()
model_esm.eval();

device = 0
################################ FOR ARDCA ####################################
bk_dir = "./../../bk_models2/"
fname_par_ardca = 'model_11_07_2023_epoch_' + str(94.0) + '_ardca' + '.pt'
model_path_ardca = os.path.join(bk_dir, fname_par_ardca)
decoder_ardca = load_model(model_path_ardca, device=device)

############################## FOR POTTS ########################################
fname_par_potts = 'model_25_06_2023_epoch_' + str(94.0) + '.pt'
model_path_potts = os.path.join(bk_dir, fname_par_potts)
decoder_potts = load_model(model_path_potts, device=device)


def _select_percs(distances):
    """Return the `percs` thresholds passed to `select_sequences`.

    0.65 is included only when at least one sample has distance < 0.65 to the
    native sequence; otherwise the list starts at 0.7.
    """
    if np.min(distances) < 0.65:
        return [0.65, 0.7, 0.75, 0.8, 0.85]
    return [0.7, 0.75, 0.8, 0.85]


counter = 0  # number of representatives processed so far
for idx_bk in repr_set:
    print(f"I am at index: {counter+1} out of {len(repr_set)}")
    M, N = test_dataset[idx_bk][0].shape
    Ns[counter] = N
    Ms[counter] = M
    with torch.no_grad():
        # Assumes loader batch `idx` corresponds to dataset item `idx`
        # (shuffle=False, batch size 1) — TODO confirm if batching changes.
        for idx, inputs_packed in enumerate(test_loader):
            if idx != idx_bk:
                continue
            for inputs in inputs_packed[1]:
                msas, encodings, padding_mask = [t.to(device, non_blocking=True) for t in inputs]
                _, _, N = msas.shape
                pdb_name = test_dataset.msas_paths[idx][-14:-7]
                samples_ardca = decoder_ardca.sample_ardca_full(encodings, padding_mask, device=device, n_samples=10000)
                # detach + move rather than torch.tensor(tensor), which copies and warns.
                samples_ardca = samples_ardca.detach().to('cpu', dtype=torch.long)

            ####### ArDCA: distance of every sample from the native sequence, then subselect
            distances_ardca = compute_distances(samples_ardca, idx_bk, test_dataset, pdb_name, aa_index, aa_index_inv)
            percs = _select_percs(distances_ardca)
            res_ardca = select_sequences(distances_ardca, samples_ardca, percs)
            print("Got results for ardca")

            ####### ESM-IF1: sample from the coordinates of the domain's PDB structure
            pdb_path = os.path.join(pdb_dir, pdb_name)
            structure = load_structure(pdb_path)
            coords, native_seq = extract_coords_from_structure(structure)
            native_seq_num = torch.tensor([aa_index[char] for char in native_seq], dtype=torch.long)
            # NOTE(review): ESM reuses the `percs` derived from the ArDCA distances — confirm intended.
            res_esm = get_samples_esm(model_esm, coords, idx_bk, test_dataset, native_seq_num, pdb_name, percs, nfill=10, device=device)
            print("Got results for esm")

            ####### Potts: decode couplings/fields, sample, then subselect by distance
            couplings_potts, fields_potts = decoder_potts(encodings, padding_mask)
            samples_potts = get_samples_potts(couplings_potts, fields_potts, aa_index, aa_index_inv, N, q)
            print("GOT SAMPLES FROM POTTS")
            distances_potts = compute_distances(samples_potts, idx_bk, test_dataset, pdb_name, aa_index, aa_index_inv)
            percs = _select_percs(distances_potts)
            res_potts = select_sequences(distances_potts, samples_potts, percs)

            ######################## SAVE THE RESULTS ##########################
            # pdb_name == test_dataset.msas_paths[idx_bk][-14:-7] since idx == idx_bk here.
            res_full_potts[pdb_name] = res_potts
            res_full_ardca[pdb_name] = res_ardca
            res_full_esm[pdb_name] = res_esm
            counter += 1
            # Only one batch matches idx_bk: stop instead of loading the rest of the test set.
            break
                 

I am at index: 1 out of 2
Got results for ardca
Got results for esm
initializing sampler... 0.323166 sec

sampling model with mcmc... 9.69749 sec
updating mcmc stats with samples... 0.628784 sec
computing sequence energies and correlations... 0.066324 sec
writing final sequences... done
GOT SAMPLES FROM POTTS
I am at index: 2 out of 2


KeyboardInterrupt: 

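## Save the results

Note: the recorded output of the sampling cell ends in a `KeyboardInterrupt` while processing representative 2 of 2. If that is the run being saved, the pickles contain only the representatives completed before the interruption.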

In [55]:
import pickle

# One pickle file per model, written in this order.
save_targets = {
    "samples_ardca_superfamily": res_full_ardca,
    "samples_potts_superfamily": res_full_potts,
    "samples_esm_superfamily": res_full_esm,
}
for filename, results in save_targets.items():
    with open(filename, mode="wb") as f:
        pickle.dump(results, f)
