In [1]:
import torch
from transformers import MT5ForConditionalGeneration, MT5Config, MT5EncoderModel, MT5Tokenizer, Trainer, TrainingArguments
from progeny_tokenizer import TAPETokenizer
import numpy as np
import math
import random
import scipy
import time
import pandas as pd
from torch.utils.data import DataLoader, RandomSampler, Dataset, BatchSampler
import typing
from pathlib import Path
import argparse
from collections import OrderedDict
import pickle
import matplotlib.pyplot as plt

In [2]:
# Workflow switch: when True, the guarded cells further down write the FoldX
# input files (top-K TSV, E[min] pickle and E[min] TSV); when False those
# writes are skipped.
before_foldx = False

In [3]:
# Seed NumPy, Python's `random` module and PyTorch with the same value.
seed = 30
np.random.seed(seed)
random.seed(seed)
torch.manual_seed(seed)


<torch._C.Generator at 0x7f744c0b2250>

# Analyze 250K gen seqs and prepare for FoldX

Save an output TSV file on which to run FoldX inference.

In [4]:
# Wild-type sequence and the constant region used by the filters below.
wt_seq = 'STIEEQAKTFLDKFNHEAEDLFYQSSLASWNYNTNITEENVQNMNNAGDKWSAFLKEQSTLAQMYPLQEIQNLTVKLQLQALQ'
constant_region = 'NTNITEEN'
# Position of the constant region inside the wild type; generated sequences
# are compared against this position.
wt_cs_ind = wt_seq.index(constant_region)

In [5]:
# Path of the TSV holding the generated sequences analyzed in this notebook.
gen250k_tsv_name = 'generated_seqs/baseline_gen/rerunwlatentheadpred_tophalf_12ep_250K-basegen_seqs260000.tsv'

In [6]:
# Load the tab-separated file into a DataFrame.
gen250k_df = pd.read_table(gen250k_tsv_name)

In [7]:
# Display the loaded frame.
gen250k_df

Unnamed: 0,disc_pred,latent_head_pred,MT_seq,PDB,Chain,Start_index,WT_seq
0,0.468239,-1.780831,SSYEEQIKTFIDKFKHVAEMLFHQSEQGMMFYMMNYLMMQFMLFMK...,template2.pdb,A,19,STIEEQAKTFLDKFNHEAEDLFYQSSLASWNYNTNITEENVQNMNN...
1,0.577867,-4.800953,STKEMVAKTFLDMFNHEFFIVFLYSFMMAEMLFLFIKFQSTLAQYY...,template2.pdb,A,19,STIEEQAKTFLDKFNHEAEDLFYQSSLASWNYNTNITEENVQNMNN...
2,0.609596,0.808763,SSIEGQAKMFLDKHEHEYEDLFENFFTKMMLFMFMMFMNYNMKAFQ...,template2.pdb,A,19,STIEEQAKTFLDKFNHEAEDLFYQSSLASWNYNTNITEENVQNMNN...
3,0.714211,0.079907,STIETVAKSFLDKFNVEAETGFGQFMMQMYAMMMMQLFELMLQLMK...,template2.pdb,A,19,STIEEQAKTFLDKFNHEAEDLFYQSSLASWNYNTNITEENVQNMNN...
4,0.787052,-2.055216,KTIEEAAGTMLDKLKAFAKNMLMMMKYEMQAFFMFFNNILKNFLMM...,template2.pdb,A,19,STIEEQAKTFLDKFNHEAEDLFYQSSLASWNYNTNITEENVQNMNN...
...,...,...,...,...,...,...,...
260649,34.154423,3.179507,HTGVLQAKTFLDKFAHMSYDLFTLNIMEEFEQAFLKFHMAANAALQ...,template2.pdb,A,19,STIEEQAKTFLDKFNHEAEDLFYQSSLASWNYNTNITEENVQNMNN...
260650,34.228806,5.039864,NTIEIQFKTHLDKFTHEAEDLFYQSSLASMNYNTNITEENVQAMNF...,template2.pdb,A,19,STIEEQAKTFLDKFNHEAEDLFYQSSLASWNYNTNITEENVQNMNN...
260651,34.324348,4.344697,STIEEQYKTFLDKENHEVEDLFYQRSLASMNMNTNITEENAQNMAN...,template2.pdb,A,19,STIEEQAKTFLDKFNHEAEDLFYQSSLASWNYNTNITEENVQNMNN...
260652,34.395195,3.548505,STIETQAKTMLDLFNHEFFDEFYQSALGSMNKNTIIFEQFFSFLQK...,template2.pdb,A,19,STIEEQAKTFLDKFNHEAEDLFYQSSLASWNYNTNITEENVQNMNN...


Filter out sequences that lack the constant region or have it at a different position than the wild type.

In [8]:
# Flag every generated sequence whose first occurrence of the constant region
# is not at the wild-type position. `str.find` returns -1 when the region is
# absent, and wt_cs_ind is never negative, so a single comparison covers both
# "region missing" and "region shifted". This replaces a row-by-row
# `iterrows` scan with one vectorized pass.
cs_positions = gen250k_df['MT_seq'].str.find(constant_region)
misplaced = gen250k_df.loc[cs_positions != wt_cs_ind, 'MT_seq']

# Keep the same two plain lists the later cells expect: the index labels to
# drop and the matching sequences, both in frame order.
indices_to_drop = misplaced.index.tolist()
dropped_seqs = misplaced.tolist()

In [9]:
# Report how many rows were flagged, then list their index labels and sequences.
print(len(indices_to_drop))
print(indices_to_drop)
print(dropped_seqs)

3321
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 72, 308, 539, 614, 1010, 1014, 1058, 1287, 1549, 2036, 2248, 2679, 3726, 4023, 4201, 4567, 4748, 5336, 5351, 5702, 6041, 6151, 6237, 6827, 6844, 6928, 6954, 7356, 8160, 8451, 10238, 10286, 11152, 11455, 12208, 12290, 12550, 12708, 13593, 15019, 15109, 15194, 16346, 16712, 16830, 17322, 17976, 18013, 18089, 18123, 18532, 18660, 18741, 18866, 19008, 19137, 21469, 22216, 22365, 22654, 23740, 24149, 24474, 25396, 26130, 26147, 26423, 26766, 26955, 27003, 27236, 27595, 27696, 28281, 28392, 29035, 29177, 29202, 29921, 30217, 30464, 32473, 32695, 32709, 33462, 33954, 34042, 34048, 34075, 34627, 35248, 35585, 35629, 36674, 37213, 37281, 37478, 37825, 38185, 38442, 38681, 39138, 39214, 39371, 39699, 39816, 40792, 41151, 41167, 41216, 42458, 44026, 44130, 44851, 44949, 45835, 45885, 46087, 46570, 47983, 48289, 49873, 49999, 50428, 51196, 51211, 51431, 51682, 53101, 53710, 53807, 54491, 545

In [10]:
# Remove the flagged rows; `drop` returns a new frame, so `gen250k_df` is left unchanged.
gen250k_df_dropped_nocon = gen250k_df.drop(indices_to_drop)

In [11]:
# Display the frame after the constant-region filter.
gen250k_df_dropped_nocon

Unnamed: 0,disc_pred,latent_head_pred,MT_seq,PDB,Chain,Start_index,WT_seq
28,2.789537,-10.144451,SDIEEQAKTFLDKFNHEAEDLFYQSSLRSMVYNTNITEENIQNMNR...,template2.pdb,A,19,STIEEQAKTFLDKFNHEAEDLFYQSSLASWNYNTNITEENVQNMNN...
29,2.796858,-11.077489,SSIEEQAKTFLDKFNHEAEDYFYQSSLASMNYNTNITEENVQMMNN...,template2.pdb,A,19,STIEEQAKTFLDKFNHEAEDLFYQSSLASWNYNTNITEENVQNMNN...
30,2.805789,-11.647189,SDIEEQAKTFLDKFNHEAEDLFYQSSLASMNYNTNITEENVQNMNN...,template2.pdb,A,19,STIEEQAKTFLDKFNHEAEDLFYQSSLASWNYNTNITEENVQNMNN...
31,2.808174,-12.024853,SDIEEQAKTFLDRFNHEAEDLFYQSSLASANYNTNITEENVQNMNN...,template2.pdb,A,19,STIEEQAKTFLDKFNHEAEDLFYQSSLASWNYNTNITEENVQNMNN...
32,2.808599,-10.844015,SDIEEQAKTFLDKFLHEAEDLFYQSSLASFNYNTNITEENVQNMNN...,template2.pdb,A,19,STIEEQAKTFLDKFNHEAEDLFYQSSLASWNYNTNITEENVQNMNN...
...,...,...,...,...,...,...,...
260646,33.955925,4.574232,STIEEQFKTFLDTFNHEAEDLEPQYSLANMNYNTNITEENPMNMNM...,template2.pdb,A,19,STIEEQAKTFLDKFNHEAEDLFYQSSLASWNYNTNITEENVQNMNN...
260647,33.976833,5.299477,DTIEEFFKTFLDKFNHNAEDKFYQSSLASRNYNTNITEENVQSMNN...,template2.pdb,A,19,STIEEQAKTFLDKFNHEAEDLFYQSSLASWNYNTNITEENVQNMNN...
260650,34.228806,5.039864,NTIEIQFKTHLDKFTHEAEDLFYQSSLASMNYNTNITEENVQAMNF...,template2.pdb,A,19,STIEEQAKTFLDKFNHEAEDLFYQSSLASWNYNTNITEENVQNMNN...
260651,34.324348,4.344697,STIEEQYKTFLDKENHEVEDLFYQRSLASMNMNTNITEENAQNMAN...,template2.pdb,A,19,STIEEQAKTFLDKFNHEAEDLFYQSSLASWNYNTNITEENVQNMNN...


Filter out sequences that contain non-amino-acid (special) tokens.

In [12]:
# Special-token strings; a sequence containing any of them is flagged for removal.
rejected_tokens = ["<pad>", "<sep>", "<cls>", "<mask>", "<unk>"]

In [13]:
# Collect the index label and sequence of every row whose MT_seq contains at
# least one rejected token. Each offending row is recorded once, regardless of
# how many tokens it contains.
indices_to_drop = []
dropped_seqs = []
for idx, candidate in gen250k_df_dropped_nocon['MT_seq'].items():
    if any(token in candidate for token in rejected_tokens):
        indices_to_drop.append(idx)
        dropped_seqs.append(candidate)
            

In [14]:
# Report how many rows were flagged, then list their index labels and sequences.
print(len(indices_to_drop))
print(indices_to_drop)
print(dropped_seqs)

188
[208, 7986, 16374, 23177, 26388, 32913, 35817, 39603, 60979, 61524, 63429, 63570, 66154, 67374, 68401, 77476, 77909, 84386, 84549, 87354, 92112, 93118, 96224, 99693, 104626, 107403, 109843, 112811, 116471, 116871, 130677, 131927, 133969, 141555, 141756, 141789, 143183, 146545, 148769, 149116, 152758, 156237, 158249, 161612, 161898, 162709, 164091, 164602, 164623, 168463, 168528, 168825, 169338, 170455, 170568, 175487, 176323, 177378, 179015, 180445, 186059, 186230, 190620, 190727, 192849, 193312, 194481, 194549, 195088, 195996, 196206, 197140, 199732, 200861, 201787, 201915, 202074, 202645, 205518, 207152, 208062, 208173, 208564, 208853, 210176, 211210, 211874, 212425, 214270, 214487, 215969, 216386, 218755, 219209, 220178, 221092, 221620, 222831, 222935, 222961, 223946, 225286, 227065, 227484, 227514, 228184, 228186, 228269, 228497, 228552, 230232, 231063, 231935, 232788, 233530, 234363, 234751, 234933, 235211, 235347, 235757, 239058, 239276, 240309, 240736, 240973, 241126, 241246

In [15]:
# Remove the rows flagged for rejected tokens and print how many rows remain.
gen250k_df_dropped = gen250k_df_dropped_nocon.drop(indices_to_drop)
print(len(gen250k_df_dropped))

257145


In [16]:
# Sanity check: re-scan the fully filtered frame for rejected tokens and print
# how many rows still contain one.
leftover = [
    (idx, candidate)
    for idx, candidate in gen250k_df_dropped['MT_seq'].items()
    if any(token in candidate for token in rejected_tokens)
]
indices_to_drop = [idx for idx, _ in leftover]
dropped_seqs = [candidate for _, candidate in leftover]

print(len(indices_to_drop))

0


In [17]:
# Sanity check: re-scan the fully filtered frame for sequences whose constant
# region is missing or not at the wild-type position, and print the count.
indices_to_drop, dropped_seqs = [], []
for idx, candidate in gen250k_df_dropped['MT_seq'].items():
    region_in_place = (
        constant_region in candidate
        and candidate.index(constant_region) == wt_cs_ind
    )
    if not region_in_place:
        indices_to_drop.append(idx)
        dropped_seqs.append(candidate)

print(len(indices_to_drop))

0


In [18]:
# Display the frame after both filters.
gen250k_df_dropped

Unnamed: 0,disc_pred,latent_head_pred,MT_seq,PDB,Chain,Start_index,WT_seq
28,2.789537,-10.144451,SDIEEQAKTFLDKFNHEAEDLFYQSSLRSMVYNTNITEENIQNMNR...,template2.pdb,A,19,STIEEQAKTFLDKFNHEAEDLFYQSSLASWNYNTNITEENVQNMNN...
29,2.796858,-11.077489,SSIEEQAKTFLDKFNHEAEDYFYQSSLASMNYNTNITEENVQMMNN...,template2.pdb,A,19,STIEEQAKTFLDKFNHEAEDLFYQSSLASWNYNTNITEENVQNMNN...
30,2.805789,-11.647189,SDIEEQAKTFLDKFNHEAEDLFYQSSLASMNYNTNITEENVQNMNN...,template2.pdb,A,19,STIEEQAKTFLDKFNHEAEDLFYQSSLASWNYNTNITEENVQNMNN...
31,2.808174,-12.024853,SDIEEQAKTFLDRFNHEAEDLFYQSSLASANYNTNITEENVQNMNN...,template2.pdb,A,19,STIEEQAKTFLDKFNHEAEDLFYQSSLASWNYNTNITEENVQNMNN...
32,2.808599,-10.844015,SDIEEQAKTFLDKFLHEAEDLFYQSSLASFNYNTNITEENVQNMNN...,template2.pdb,A,19,STIEEQAKTFLDKFNHEAEDLFYQSSLASWNYNTNITEENVQNMNN...
...,...,...,...,...,...,...,...
260646,33.955925,4.574232,STIEEQFKTFLDTFNHEAEDLEPQYSLANMNYNTNITEENPMNMNM...,template2.pdb,A,19,STIEEQAKTFLDKFNHEAEDLFYQSSLASWNYNTNITEENVQNMNN...
260647,33.976833,5.299477,DTIEEFFKTFLDKFNHNAEDKFYQSSLASRNYNTNITEENVQSMNN...,template2.pdb,A,19,STIEEQAKTFLDKFNHEAEDLFYQSSLASWNYNTNITEENVQNMNN...
260650,34.228806,5.039864,NTIEIQFKTHLDKFTHEAEDLFYQSSLASMNYNTNITEENVQAMNF...,template2.pdb,A,19,STIEEQAKTFLDKFNHEAEDLFYQSSLASWNYNTNITEENVQNMNN...
260651,34.324348,4.344697,STIEEQYKTFLDKENHEVEDLFYQRSLASMNMNTNITEENAQNMAN...,template2.pdb,A,19,STIEEQAKTFLDKFNHEAEDLFYQSSLASWNYNTNITEENVQNMNN...


In [19]:
# Number of rows exported for FoldX evaluation and the TSV they are written to.
topK_saved = 10000
top10K_Dscore_gen250k_tsv_name = 'generated_seqs/baseline_gen/tophalf-basegen_top10Klatentheadfiltered.tsv'
# top10K_Dscore_gen250k_tsv_name = 'generated_seqs/tophalf-basegen_top10K-Dscore_250Kgen_dropped.tsv'

# Keep the first 250,000 filtered rows (positional slice), then order them by
# ascending latent_head_pred so the lowest predictions come first.
gen250k_df_dropped = (
    gen250k_df_dropped
    .iloc[:250000]
    .sort_values(by='latent_head_pred', ascending=True)
)

# The top-K set is the leading block of the sorted frame.
topK_gen250k_df_dropped = gen250k_df_dropped.head(topK_saved)

In [20]:
# Display the top-K rows selected by latent_head_pred.
topK_gen250k_df_dropped

Unnamed: 0,disc_pred,latent_head_pred,MT_seq,PDB,Chain,Start_index,WT_seq
51,2.834782,-12.848934,GDIEEQAKTFLDSFNHEAENLFMQSSLASMNYNTNITEENVQNMNG...,template2.pdb,A,19,STIEEQAKTFLDKFNHEAEDLFYQSSLASWNYNTNITEENVQNMNN...
99,2.855145,-12.228134,SDIEEQAKTFLDKFNYEAEDLFLQSSLASMNYNTNITEENYQNMNN...,template2.pdb,A,19,STIEEQAKTFLDKFNHEAEDLFYQSSLASWNYNTNITEENVQNMNN...
191,2.888042,-12.105712,STIEEQAKTFLDKFNHEAEKLFYQFSLASMNYNTNITEENQQNMNN...,template2.pdb,A,19,STIEEQAKTFLDKFNHEAEDLFYQSSLASWNYNTNITEENVQNMNN...
136,2.870499,-12.081381,KTIEEFAKTFLMKFNHEAEMLFYQSSLLSMNYNTNITEENVQNMNN...,template2.pdb,A,19,STIEEQAKTFLDKFNHEAEDLFYQSSLASWNYNTNITEENVQNMNN...
31,2.808174,-12.024853,SDIEEQAKTFLDRFNHEAEDLFYQSSLASANYNTNITEENVQNMNN...,template2.pdb,A,19,STIEEQAKTFLDKFNHEAEDLFYQSSLASWNYNTNITEENVQNMNN...
...,...,...,...,...,...,...,...
12127,7.200650,-5.630886,STIEEIAKTFLDKFNHEAEDLFYQSSLLSMNYNTNITEENVQNMNN...,template2.pdb,A,19,STIEEQAKTFLDKFNHEAEDLFYQSSLASWNYNTNITEENVQNMNN...
12915,7.376093,-5.630832,STIEEQAKTFLDKFNHEAEDLFYQMRLASMNYNTNITEENVQNMLN...,template2.pdb,A,19,STIEEQAKTFLDKFNHEAEDLFYQSSLASWNYNTNITEENVQNMNN...
10896,6.886302,-5.630334,ATTEEQAKTFLEMFNHEAEVLFYQSSLASMNYNTNITEENSQNMIN...,template2.pdb,A,19,STIEEQAKTFLDKFNHEAEDLFYQSSLASWNYNTNITEENVQNMNN...
16795,7.971996,-5.630300,STIEEQAVTFLDKFNHEAEDLFYQSSLASMMYNTNITEENDQNMAN...,template2.pdb,A,19,STIEEQAKTFLDKFNHEAEDLFYQSSLASWNYNTNITEENVQNMNN...


# Save top 10K seqs for FoldX Evaluation

In [21]:
# Write the top-K rows as a tab-separated file (without the index column),
# only in the pre-FoldX phase.
if before_foldx:
    topK_gen250k_df_dropped.to_csv(top10K_Dscore_gen250k_tsv_name, sep="\t", index=False)

# Sample for E[min] FoldX Computation

In [22]:
# Display the filtered frame, now sorted by latent_head_pred, that the rounds sample from.
gen250k_df_dropped

Unnamed: 0,disc_pred,latent_head_pred,MT_seq,PDB,Chain,Start_index,WT_seq
51,2.834782,-12.848934,GDIEEQAKTFLDSFNHEAENLFMQSSLASMNYNTNITEENVQNMNG...,template2.pdb,A,19,STIEEQAKTFLDKFNHEAEDLFYQSSLASWNYNTNITEENVQNMNN...
99,2.855145,-12.228134,SDIEEQAKTFLDKFNYEAEDLFLQSSLASMNYNTNITEENYQNMNN...,template2.pdb,A,19,STIEEQAKTFLDKFNHEAEDLFYQSSLASWNYNTNITEENVQNMNN...
191,2.888042,-12.105712,STIEEQAKTFLDKFNHEAEKLFYQFSLASMNYNTNITEENQQNMNN...,template2.pdb,A,19,STIEEQAKTFLDKFNHEAEDLFYQSSLASWNYNTNITEENVQNMNN...
136,2.870499,-12.081381,KTIEEFAKTFLMKFNHEAEMLFYQSSLLSMNYNTNITEENVQNMNN...,template2.pdb,A,19,STIEEQAKTFLDKFNHEAEDLFYQSSLASWNYNTNITEENVQNMNN...
31,2.808174,-12.024853,SDIEEQAKTFLDRFNHEAEDLFYQSSLASANYNTNITEENVQNMNN...,template2.pdb,A,19,STIEEQAKTFLDKFNHEAEDLFYQSSLASWNYNTNITEENVQNMNN...
...,...,...,...,...,...,...,...
196640,22.194649,7.554973,STIEEQAKTFLDKFNHEAEFLFYNSSLASMNNNTNITEENVNNMNN...,template2.pdb,A,19,STIEEQAKTFLDKFNHEAEDLFYQSSLASWNYNTNITEENVQNMNN...
173991,20.747000,7.556052,STIEFQAKTFLDKFNHEAEDLFYQSSLASMNPNTNITEENVQNMNN...,template2.pdb,A,19,STIEEQAKTFLDKFNHEAEDLFYQSSLASWNYNTNITEENVQNMNN...
160532,19.980900,7.620140,STISEQAKTFLDKFNHEAEDLFYQSSLASMNYNTNITEENNQNMNN...,template2.pdb,A,19,STIEEQAKTFLDKFNHEAEDLFYQSSLASWNYNTNITEENVQNMNN...
166822,20.352039,7.622699,STIAEQAKTFLDKFNHEAEDLFYQSSLASGNYNTNITEENVQNMNN...,template2.pdb,A,19,STIEEQAKTFLDKFNHEAEDLFYQSSLASWNYNTNITEENVQNMNN...


In [23]:
# Get topk seqs
# For each sort column, draw `num_rounds` random pools of `round_pool_size`
# rows and keep the `topk` lowest-scoring rows of each pool. The union of all
# kept rows, with duplicate sequences removed, is the set sent to FoldX.
num_rounds = 100 # N
round_pool_size = 10000
topk = 10 # K

round_topk = {}
cols_to_sort = ['latent_head_pred']
# cols_to_sort = ['disc_pred']
# cols_to_sort = ['disc_pred', 'latent_head_pred']

foldx_df = None
in_count = 0
round_frames = []  # every round's top-k frame, across all sort columns so far
for col_to_sort in cols_to_sort:
    print("col_to_sort: ", col_to_sort)
    round_topk[col_to_sort] = {}
    for round_ind in range(num_rounds):
        # `sample` has no random_state here, so it draws from NumPy's global
        # RNG; results depend on the seed set earlier and on execution order.
        sampled_rows = gen250k_df_dropped.sample(n=round_pool_size)
        round_topk[col_to_sort][round_ind] = sampled_rows.sort_values(by=col_to_sort, ascending=True)[:topk]

    # Build the FoldX candidate set with one concat + drop_duplicates instead
    # of appending row by row (DataFrame.append was removed in pandas 2.0 and
    # was quadratic). Deduplicating the whole set also catches repeated
    # sequences inside a single round, which the per-round membership check
    # missed. The first occurrence of each sequence is kept.
    round_frames.extend(round_topk[col_to_sort].values())
    all_round_rows = pd.concat(round_frames)
    foldx_df = all_round_rows.drop_duplicates(subset='MT_seq', keep='first')

    # Number of selected rows that were duplicates of an already-kept sequence.
    in_count = len(all_round_rows) - len(foldx_df)

    print("len(foldx_df)+in_count: ", len(foldx_df)+in_count)

col_to_sort:  latent_head_pred
len(foldx_df)+in_count:  1000


In [24]:
# Display the deduplicated set of sequences selected across all rounds.
foldx_df

Unnamed: 0,disc_pred,latent_head_pred,MT_seq,PDB,Chain,Start_index,WT_seq
174,2.882024,-11.690946,SDIEEAAKTFLDKFNHEAEELFYQSSRAAMNYNTNITEENVQNMNK...,template2.pdb,A,19,STIEEQAKTFLDKFNHEAEDLFYQSSLASWNYNTNITEENVQNMNN...
63,2.841469,-11.372101,RTIEEQAKTFLIKFNHEAKDLFYQSMLASMLYNTNITEENVQNMNN...,template2.pdb,A,19,STIEEQAKTFLDKFNHEAEDLFYQSSLASWNYNTNITEENVQNMNN...
411,2.960012,-10.752542,ATIEEQAKTFLDKFNHEAEDLFYQSSLASMEYNTNITEENVRNMNN...,template2.pdb,A,19,STIEEQAKTFLDKFNHEAEDLFYQSSLASWNYNTNITEENVQNMNN...
201,2.891267,-10.747655,VTIEEFAKTFLDKFNHEAEDLFYQFSFASMNYNTNITEENVQMMNA...,template2.pdb,A,19,STIEEQAKTFLDKFNHEAEDLFYQSSLASWNYNTNITEENVQNMNN...
322,2.929900,-10.705520,SDIEEQAKIFLDKFNHEANDLFYMSALSSMLYNTNITEENVQNMNN...,template2.pdb,A,19,STIEEQAKTFLDKFNHEAEDLFYQSSLASWNYNTNITEENVQNMNN...
...,...,...,...,...,...,...,...
418,2.962836,-10.224881,STIEEQAKTFLDRFNHEAEDLFYQSSLASMNYNTNITEENVQNMNN...,template2.pdb,A,19,STIEEQAKTFLDKFNHEAEDLFYQSSLASWNYNTNITEENVQNMNN...
278,2.914169,-10.110140,STIEEQAKTFLDKFNHEAEDLFYQSSLLSMNYNTNITEENVQNMNN...,template2.pdb,A,19,STIEEQAKTFLDKFNHEAEDLFYQSSLASWNYNTNITEENVQNMNN...
130,2.866821,-10.319946,SDIEEYAKTFLLEFNHEAEDLFYQSSLQEMNYNTNITEENFQNMNN...,template2.pdb,A,19,STIEEQAKTFLDKFNHEAEDLFYQSSLASWNYNTNITEENVQNMNN...
511,2.998430,-10.007259,STIEEVAKTFLDLFNHEAEDLFYQKSLAEMNYNTNITEENVNNMLK...,template2.pdb,A,19,STIEEQAKTFLDKFNHEAEDLFYQSSLASWNYNTNITEENVQNMNN...


In [25]:
# Number of selected rows skipped because their sequence was already in foldx_df.
in_count

671

# save E[min] seqs to do FoldX

In [26]:
# Pickle file holding the per-round top-k frames (the `round_topk` dict).
seqsforEmin_dict_name = 'generated_seqs/baseline_gen/latentheadfiltered_tophalf-basegen_seqsforEmin_df.pkl'

# Only written in the pre-FoldX phase.
if before_foldx:
    with open(seqsforEmin_dict_name, 'wb') as f:
        pickle.dump(round_topk, f)

# with open(seqsforEmin_dict_name, 'rb') as f:
#     b = pickle.load(f)

In [27]:
# TSV of the deduplicated E[min] sequences, written without the index column.
seqsforEmin_tsv_name = 'generated_seqs/baseline_gen/latentheadfiltered_tophalf-basegen_seqsforEmin_foldx.tsv'

# Only written in the pre-FoldX phase.
if before_foldx:
    foldx_df.to_csv(seqsforEmin_tsv_name, sep="\t", index=False)

# <<===== After Foldx Computation =====>>

In [28]:
# FoldX results file that is loaded next and used for the E[min] lookup.
foldx_results_name = "foldx_sim_results/tophalf-basegen_top10Klatentheadfiltered/results_full.tsv"
# Emin_results_tsv_name = "foldx_sim_results/tophalf-basegen_seqsforEmin_foldx_results/results_full.tsv"


In [29]:
# Load the FoldX results table; its `ddG` column is read by the E[min] computation.
foldx_results_df = pd.read_table(foldx_results_name)



In [30]:
# Display the loaded FoldX results.
foldx_results_df

Unnamed: 0,disc_pred,latent_head_pred,MT_seq,PDB,Chain,Start_index,WT_seq,ddG
0,6.204562,-6.198444,STIEEQAKTFIDKFNHEAEDLFYQSSLASMNYNTNITEENVQNMNN...,template2.pdb,A,19,STIEEQAKTFLDKFNHEAEDLFYQSSLASWNYNTNITEENVQNMNN...,-0.943782
1,6.566597,-6.198411,VTIEEQAKTFLDSFIHEAEDLFYQSSLASMNYNTNITEENVQNMNN...,template2.pdb,A,19,STIEEQAKTFLDKFNHEAEDLFYQSSLASWNYNTNITEENVQNMNN...,-0.763596
2,3.709481,-6.197896,SGIEEFAKRFLDKFNHNAEDLFYFMILAKMNYNTNITEENVINMNN...,template2.pdb,A,19,STIEEQAKTFLDKFNHEAEDLFYQSSLASWNYNTNITEENVQNMNN...,-4.828940
3,5.815841,-6.197204,SSIEEQAKTFLDKFNHEAMDLFYQSSLASMNYNTNITEENVQNMNN...,template2.pdb,A,19,STIEEQAKTFLDKFNHEAEDLFYQSSLASWNYNTNITEENVQNMNN...,-0.616186
4,6.044264,-6.196896,STIEEQAKTFLDKFNHEAEDLFYQSSLASMNYNTNITEENVQNYNN...,template2.pdb,A,19,STIEEQAKTFLDKFNHEAEDLFYQSSLASWNYNTNITEENVQNMNN...,-0.670446
...,...,...,...,...,...,...,...,...
9995,3.506202,-9.032804,STIEEQAKAFLDKFNHEAEDLFYQMSLASMNYNTNITEENVQNMNI...,template2.pdb,A,19,STIEEQAKTFLDKFNHEAEDLFYQSSLASWNYNTNITEENVQNMNN...,-3.044100
9996,3.213400,-9.030958,STIEEQAKTFLDKFNHEAEDLFYQSMLASMNYNTNITEENVQNMNN...,template2.pdb,A,19,STIEEQAKTFLDKFNHEAEDLFYQSSLASWNYNTNITEENVQNMNN...,-2.146290
9997,3.841218,-9.030471,STHEEMAKQFLDKFNHEAEDLFYQMSLAKRNYNTNITEENIQNMMN...,template2.pdb,A,19,STIEEQAKTFLDKFNHEAEDLFYQSSLASWNYNTNITEENVQNMNN...,-6.530910
9998,3.684652,-9.029869,SDIEEQAKTFLEKFNHEAEDLFYQSSLASMYYNTNITEENVQNLNN...,template2.pdb,A,19,STIEEQAKTFLDKFNHEAEDLFYQSSLASWNYNTNITEENVQNMNN...,-1.944460


In [31]:
# # for debug
# foldx_results_df = foldx_df

In [32]:
# Compute Emin from foldx values
# For each sort column: look up the FoldX ddG of every sequence in each
# round's top-k frame, take the minimum ddG per round, and average those
# minima over the rounds.
rows_to_patch = None
missing_rows = []  # round rows whose sequence has no FoldX result
Emin_results_dict = {}
for col_to_sort in round_topk:
    print(col_to_sort)
    current_score_round_topk = round_topk[col_to_sort]

    round_min_list = []

    for round_ind in current_score_round_topk:
        round_topk_df = current_score_round_topk[round_ind]

        round_ddG = []
        for row_ind, row in round_topk_df.iterrows():
            row_seq = row['MT_seq']
            matched_row = foldx_results_df.loc[foldx_results_df['MT_seq'] == row_seq]
            if len(matched_row) != 1:
                # Zero or several matches is unexpected; show what was found.
                print("matched_row: ", matched_row)
            if len(matched_row) == 0:
                # Collect in a list: `Series.append` returned a new object (and
                # is gone in pandas 2.0), so earlier code lost these rows.
                missing_rows.append(row)
                continue
            # Always store a scalar (first match), so np.min never sees a mix
            # of scalars and one-element Series.
            round_ddG.append(matched_row['ddG'].iloc[0])

        if not round_ddG:
            # np.min([]) would raise; report the gap and leave this round out.
            print("round {} has no FoldX results; skipped".format(round_ind))
            continue
        round_min = np.min(round_ddG)
        round_min_list.append(round_min)

    Emin = np.mean(round_min_list)

    Emin_results_dict[col_to_sort] = Emin

# Stays None when every sequence was found; otherwise one row per missing sequence.
if missing_rows:
    rows_to_patch = pd.DataFrame(missing_rows)

latent_head_pred


In [33]:
# Show the rows that had no FoldX match (None when every sequence was found).
print(rows_to_patch)

None


# Save Emin Results

In [34]:
# Text file that receives the E[min] summary.
Emin_results_name = 'generated_seqs/baseline_gen/Emin_results/latentheadfiltered_tophalf-basegen_seqsforEmin_results.txt'

In [35]:
# Assemble the report (header, source pickle name, then one "key = value"
# line per sort column in sorted key order) and write it in a single call.
report_lines = [
    "***** E[min] results *****\n",
    "seqsforEmin_dict_name: {}\n".format(seqsforEmin_dict_name),
]
for key in sorted(Emin_results_dict.keys()):
    report_lines.append("%s = %s\n" % (key, str(Emin_results_dict[key])))

with open(Emin_results_name, "w") as writer:
    writer.writelines(report_lines)