### Calculate Top-k Accuracy (TTA20)

In [78]:
import pandas as pd
import rdkit
from rdkit import Chem
print('rdkit version:', rdkit.__version__)

def clear_map_canonical_smiles(smi):
    """Return the canonical SMILES of ``smi`` with all atom-map numbers removed.

    Parameters
    ----------
    smi : str
        SMILES string, optionally annotated with atom-map numbers.

    Returns
    -------
    str
        Canonical SMILES without ``molAtomMapNumber`` properties, or ``smi``
        itself (unchanged) when RDKit cannot parse it.
    """
    mol = Chem.MolFromSmiles(smi)
    if mol is None:
        # Unparseable input is handed back untouched.
        return smi

    # Strip the atom-map annotation from every atom that carries one.
    for atom in mol.GetAtoms():
        if atom.HasProp('molAtomMapNumber'):
            atom.ClearProp('molAtomMapNumber')

    # Write the SMILES, re-parse it and write it again, as the original did.
    # NOTE(review): presumably the round-trip keeps the final string independent
    # of the atom order inherited from the mapped input -- confirm before removing.
    smi = Chem.MolToSmiles(mol, canonical=True)
    return Chem.MolToSmiles(Chem.MolFromSmiles(smi))

from e_smiles import get_edit_from_e_smiles

rdkit version: 2019.03.2


### Label

In [79]:
# Label E-smiles (50k-test)
n_best = 10
argtimes = 1
product_lis = []
true_e_smiles_lis = []
with open('datasets/50k_e_smiles/aug20_test/2023_7_13_50k_test_r20_v.txt') as f:
    for raw_line in f:
        # Drop the first two comma-separated fields; the rest is the E-SMILES.
        e_smiles = ','.join(raw_line.replace('\n','').split(',')[2:])
        # The text before '>>>' is stored n_best times, once per predicted candidate.
        product_lis.extend([e_smiles.split('>>>')[0]] * n_best)
        true_e_smiles_lis.append(e_smiles)

# Total number of edits encoded in each of the first 5000 label E-SMILES.
edit_times_lis = []
for idx in range(5000):
    core, chain, stereo, charge, core_add, lg_maps = get_edit_from_e_smiles(true_e_smiles_lis[idx])
    edit_groups = (core, chain, stereo, charge, core_add, lg_maps)
    edit_times_lis.append(sum(map(len, edit_groups)))
#     if edit_times == 0:
#         print(i,e_smiles_lis[i])

# Label SMILES of Reactants (50k-test)
DATA = 'test'

# For every supported split: the raw CSV and the row labels removed from it.
# NOTE(review): the dropped rows are presumably the reactions that are absent
# from the E-SMILES files -- confirm against the preprocessing step.
split_sources = {
    'val': ("datasets/50k_raw/raw_val.csv", [2302, 2527, 2950, 4368, 4863, 4890]),
    'test': ("datasets/50k_raw/raw_test.csv", [822, 1282, 1490, 1558, 2810, 3487, 4958]),
}
if DATA not in split_sources:
    # Fail here with a clear message instead of a NameError further down.
    raise ValueError(f"Unsupported DATA split {DATA!r}; expected 'val' or 'test'")

raw_path, indexes_to_drop = split_sources[DATA]
raw_df = pd.read_csv(raw_path)
df = raw_df.drop(indexes_to_drop)

# Keep the text before the first '>>' of every reaction and canonicalise it
# without atom-map numbers.
truth = [
    clear_map_canonical_smiles(reaction.split(">>")[0])
    for reaction in df['reactants>reagents>production']
]
print(len(truth))

5000


### Prediction

In [80]:
# Predicted E-SMILES (50k-test)
# Each line is stripped of its newline and of every space.
with open('output/tgt_50k_e_smiles_aug100_train_aug20_test_infer.txt') as f:
    pred_e_smiles_lis_h = [row.replace('\n','').replace(" ","") for row in f]

# Prefix every prediction with the matching entry of product_lis to obtain "<product>>>><prediction>".
e_smiles_lis = [
    product_lis[i] + ">>>" + pred_e_smiles_lis_h[i]
    for i in range(len(product_lis))
]


# Predicted SMILES of Reactants (50k-test)
with open('output/pred_reactants_50k_e_smiles_aug100_train_aug20_test_infer.txt') as f:
    pred = [row.replace('\n','').replace(" ","") for row in f]

### Calculate top-k accuracy

In [81]:
fold = 20
num = 5000
n = 10

# Weight of a candidate at 1-based position k of its group of n: 1/k**2 (square).
# The same n weights repeat for every (fold, reaction) pair, so the flat list is
# that pattern tiled fold*num times.
weights_per_group = [1/k**2 for k in range(1, n+1)]
score_1 = weights_per_group * (fold * num)

In [82]:
# Aggregated vote of every prediction, aligned with `pred` (-1 = not filled yet).
vote_score_lis = [-1]*len(pred)

for j in range(num):
    # Flat positions of all fold*n candidates of reaction j
    # (layout: fold i, then reaction j, then position k inside the group).
    positions = [i * num * n + j*n + k for i in range(fold) for k in range(n)]

    # Sum the positional weights of identical SMILES; empty predictions add 0.
    smiles_score_dic = {}
    for idx in positions:
        candidate = pred[idx]
        weight = 0 if candidate == '' else score_1[idx]
        smiles_score_dic[candidate] = smiles_score_dic.get(candidate, 0) + weight

    # Every occurrence of a SMILES receives the total accumulated for that SMILES.
    for idx in positions:
        vote_score_lis[idx] = smiles_score_dic[pred[idx]]

In [83]:
# For every reaction: up to 50 distinct predicted SMILES, best vote first, and
# the E-SMILES that accompanies the first occurrence of each of them.
pre_smiles_50_lis = []
pre_e_smiles_50_lis = []

for j in range(num):
    # Flat positions of all fold*n candidates of reaction j.
    positions = [i * num * n + j*n + k for i in range(fold) for k in range(n)]

    # Highest vote first. sorted() is stable (also with reverse=True), so tied
    # candidates keep their original relative order.
    ranked = sorted(positions, key=lambda idx: vote_score_lis[idx], reverse=True)

    pre_smiles_50 = []
    pre_e_smiles_50 = []
    seen = set()
    for idx in ranked:
        if len(pre_smiles_50) == 50:
            break  # list is full, nothing more can be added
        smiles = pred[idx]
        if smiles in seen:
            continue
        seen.add(smiles)
        pre_smiles_50.append(smiles)
        pre_e_smiles_50.append(e_smiles_lis[idx])

    # Pad short lists to exactly 50 entries: '' for the SMILES and
    # "<text before '>>>'>>>>" (taken from the last kept E-SMILES) for the E-SMILES.
    missing = 50 - len(pre_smiles_50)
    if missing > 0:
        filler = pre_e_smiles_50[-1].split('>>>')[0] + '>>>'
        pre_smiles_50 = pre_smiles_50 + [''] * missing
        pre_e_smiles_50 = pre_e_smiles_50 + [filler] * missing

    pre_smiles_50_lis.append(pre_smiles_50)
    pre_e_smiles_50_lis.append(pre_e_smiles_50)

In [84]:
# calculate the accuracy of predicted smiles

print("Top-K Pred_Set Prediction: ")

# 1-based position of the ground truth among the first n ranked predictions of
# each reaction; 11 when it is not among them.
k_lis = []
for i in range(num):
    shortlist = pre_smiles_50_lis[i][:n]
    if truth[i] in shortlist:
        k_lis.append(shortlist.index(truth[i]) + 1)
    else:
        k_lis.append(11)

# Number of reactions whose ground truth is ranked within each cut-off.
top1 = sum(1 for rank in k_lis if rank <= 1)
top3 = sum(1 for rank in k_lis if rank <= 3)
top5 = sum(1 for rank in k_lis if rank <= 5)
top10 = sum(1 for rank in k_lis if rank <= 10)

print(f'Top-1 Pred accuracy after TTA: {top1/num:.3f}')
print(f'Top-3 Pred accuracy after TTA: {top3/num:.3f}')
print(f'Top-5 Pred accuracy after TTA: {top5/num:.3f}')
print(f'Top-10 Pred accuracy after TTA: {top10/num:.3f}')

Top-K Pred_Set Prediction: 
Top-1 Pred accuracy after TTA: 0.589
Top-3 Pred accuracy after TTA: 0.805
Top-5 Pred accuracy after TTA: 0.864
Top-10 Pred accuracy after TTA: 0.914


In [85]:
# # To see the wrong cases

# print("Top-K Pred_Set Prediction: ")
# ks = [1, 3, 5, 10] # 
# pred_k = {k:0 for k in ks}

# for i in range(len(truth)):
#   for k in ks:
#     if truth[i] in pre_smiles_50_lis[i][:k]:
#         pred_k[k] += 1
#     else:
#       if k == 1:
#           print(i)
#           print("label:")
#           print(truth[i])
#           print("pred")
#           for j in range(n_best):
#               print(pre_smiles_50_lis[i][j], end='\n')
#           print()    

# for k in ks:
#   # print(pred_k[k])
#   print ('Top-%d Pred accuracy after TTA: %.3f' % (k, pred_k[k]/len(truth)))

### Calculate top-k accuracy for each reaction type

In [86]:
# Second comma-separated field of each line, kept for the first 5000 lines only.
with open('datasets/50k_e_smiles/aug20_test/2023_7_13_50k_test_r20_v.txt') as f:
    class_lis = [row.split(',')[1] for row in f]
class_lis = class_lis[:5000]

In [87]:
# Top-1/3/5/10 accuracy per reaction class ('class_1' ... 'class_10').
# Uses its own counters so the notebook-wide `num` is left untouched.
for class_id in range(1, 11):
    label = 'class_' + str(class_id)
    # Ranks (from k_lis) of the reactions labelled with this class.
    class_ranks = [rank for rank, class_ in zip(k_lis, class_lis) if class_ == label]
    class_total = len(class_ranks)

    print(label, class_total)
    if class_total == 0:
        # An empty class would divide by zero and abort the remaining classes.
        print('no reactions in this class')
        print('')
        continue

    for cutoff in (1, 3, 5, 10):
        hits = sum(rank <= cutoff for rank in class_ranks)
        print(round(hits / class_total, 4))
    print('')

class_1 1511
0.6036
0.8345
0.8974
0.9437

class_2 1190
0.6891
0.9008
0.9437
0.9731

class_3 566
0.447
0.659
0.7261
0.811

class_4 91
0.5385
0.7473
0.7912
0.8462

class_5 68
0.6176
0.9118
0.9412
0.9412

class_6 824
0.5364
0.7524
0.8143
0.8568

class_7 462
0.6017
0.7987
0.8701
0.9329

class_8 82
0.7195
0.9024
0.9268
0.9756

class_9 183
0.3825
0.5683
0.6721
0.8142

class_10 23
0.7826
0.8696
0.8696
0.913



### Calculate top-k accuracy for each number of edits (edit times)

In [88]:
# Number of edits in each of the first 5000 label E-SMILES: the summed lengths
# of the six collections returned by get_edit_from_e_smiles.
edit_times_lis = []
for idx in range(5000):
    (core_edits, chai_edits, stereo_edits,
     charge_edits, core_edits_add, lg_map_lis) = get_edit_from_e_smiles(true_e_smiles_lis[idx])
    parts = [core_edits, chai_edits, stereo_edits, charge_edits, core_edits_add, lg_map_lis]
    edit_times_lis.append(sum(len(part) for part in parts))

In [89]:
e = 1
# Ranks of the reactions whose label contains exactly `e` edits.
ranks_for_e = [rank for rank, edit_time in zip(k_lis, edit_times_lis) if edit_time == e]
# `num` is reused here as the size of this subset.
num = len(ranks_for_e)
top1 = sum(rank <= 1 for rank in ranks_for_e)
top3 = sum(rank <= 3 for rank in ranks_for_e)
top5 = sum(rank <= 5 for rank in ranks_for_e)
top10 = sum(rank <= 10 for rank in ranks_for_e)

# Edit count, subset size, share of the 5000 reactions, then top-1/3/5/10 accuracy.
print(e,num,num/5000)
print(round(top1/num,4))
print(round(top3/num,4))
print(round(top5/num,4))
print(round(top10/num,4))
print('')


1 194 0.0388
0.5361
0.7938
0.8866
0.9639



In [90]:
e = 2
# Ranks of the reactions whose label contains exactly `e` edits.
ranks_for_e = [rank for rank, edit_time in zip(k_lis, edit_times_lis) if edit_time == e]
# `num` is reused here as the size of this subset.
num = len(ranks_for_e)
top1 = sum(rank <= 1 for rank in ranks_for_e)
top3 = sum(rank <= 3 for rank in ranks_for_e)
top5 = sum(rank <= 5 for rank in ranks_for_e)
top10 = sum(rank <= 10 for rank in ranks_for_e)

# Edit count, subset size, share of the 5000 reactions, then top-1/3/5/10 accuracy.
print(e,num,num/5000)
print(round(top1/num,4))
print(round(top3/num,4))
print(round(top5/num,4))
print(round(top10/num,4))
print('')


2 3844 0.7688
0.6155
0.8356
0.891
0.9334



In [91]:
e = 3
# Ranks of the reactions whose label contains exactly `e` edits.
ranks_for_e = [rank for rank, edit_time in zip(k_lis, edit_times_lis) if edit_time == e]
# `num` is reused here as the size of this subset.
num = len(ranks_for_e)
top1 = sum(rank <= 1 for rank in ranks_for_e)
top3 = sum(rank <= 3 for rank in ranks_for_e)
top5 = sum(rank <= 5 for rank in ranks_for_e)
top10 = sum(rank <= 10 for rank in ranks_for_e)

# Edit count, subset size, share of the 5000 reactions, then top-1/3/5/10 accuracy.
print(e,num,num/5000)
print(round(top1/num,4))
print(round(top3/num,4))
print(round(top5/num,4))
print(round(top10/num,4))
print('')


3 648 0.1296
0.3997
0.6265
0.7083
0.8086



In [92]:
e = 4
# Ranks of the reactions whose label contains exactly `e` edits.
ranks_for_e = [rank for rank, edit_time in zip(k_lis, edit_times_lis) if edit_time == e]
# `num` is reused here as the size of this subset.
num = len(ranks_for_e)
top1 = sum(rank <= 1 for rank in ranks_for_e)
top3 = sum(rank <= 3 for rank in ranks_for_e)
top5 = sum(rank <= 5 for rank in ranks_for_e)
top10 = sum(rank <= 10 for rank in ranks_for_e)

# Edit count, subset size, share of the 5000 reactions, then top-1/3/5/10 accuracy.
print(e,num,num/5000)
print(round(top1/num,4))
print(round(top3/num,4))
print(round(top5/num,4))
print(round(top10/num,4))
print('')


4 238 0.0476
0.7353
0.8319
0.8655
0.8908



In [93]:
e = 5
# Ranks of the reactions whose label contains exactly `e` edits.
ranks_for_e = [rank for rank, edit_time in zip(k_lis, edit_times_lis) if edit_time == e]
# `num` is reused here as the size of this subset.
num = len(ranks_for_e)
top1 = sum(rank <= 1 for rank in ranks_for_e)
top3 = sum(rank <= 3 for rank in ranks_for_e)
top5 = sum(rank <= 5 for rank in ranks_for_e)
top10 = sum(rank <= 10 for rank in ranks_for_e)

# Edit count, subset size, share of the 5000 reactions, then top-1/3/5/10 accuracy.
print(e,num,num/5000)
print(round(top1/num,4))
print(round(top3/num,4))
print(round(top5/num,4))
print(round(top10/num,4))
print('')


5 28 0.0056
0.5
0.6071
0.6429
0.7143



In [94]:
e = 6
# Ranks of the reactions whose label contains `e` or more edits (note the >=).
ranks_for_e = [rank for rank, edit_time in zip(k_lis, edit_times_lis) if edit_time >= e]
# `num` is reused here as the size of this subset.
num = len(ranks_for_e)
top1 = sum(rank <= 1 for rank in ranks_for_e)
top3 = sum(rank <= 3 for rank in ranks_for_e)
top5 = sum(rank <= 5 for rank in ranks_for_e)
top10 = sum(rank <= 10 for rank in ranks_for_e)

# Threshold, subset size, share of the 5000 reactions, then top-1/3/5/10 accuracy.
print(e,num,num/5000)
print(round(top1/num,4))
print(round(top3/num,4))
print(round(top5/num,4))
print(round(top10/num,4))
print('')


6 48 0.0096
0.5208
0.75
0.7917
0.8333



### Calculate top-k accuracy for Reaction Center Identification

In [95]:
# Earlier cells rebind `num` as a running counter, so restore it to 5000 before
# it is used below for indexing and as the accuracy denominator.
num = 5000

In [106]:
# Reload the label E-SMILES: everything after the first two comma-separated
# fields of each line, with the newline removed.
# The list used by the following cells is `true_e_smiles_lis`; the original
# spelling `ture_e_smiles_lis` was a typo that left this cell without effect.
with open('datasets/50k_e_smiles/aug20_test/2023_7_13_50k_test_r20_v.txt') as f:
    true_e_smiles_lis = [','.join(line.split(',')[2:]).replace('\n','') for line in f]

# Backward-compatible alias for code that still refers to the misspelled name.
ture_e_smiles_lis = true_e_smiles_lis

In [97]:
# One group per reaction: its label E-SMILES from each of the `fold` passes over
# the test set. Reaction i of pass j sits at flat position i + j*num.
# `fold` replaces the previously hard-coded 20 (same value).
true_e_smiles_lis_g = [[true_e_smiles_lis[i + j*num] for j in range(fold)] for i in range(num)]

In [98]:
# Sanity check: the list holds one group per value of range(num).
len(true_e_smiles_lis_g)

5000

In [100]:
# Keep only the text before the first '<' of every ranked E-SMILES.
pre_e_smiles_50_lis_ = []
for ranked_e_smiles in pre_e_smiles_50_lis:
    pre_e_smiles_50_lis_.append([e_smiles.partition('<')[0] for e_smiles in ranked_e_smiles])

In [101]:
# Per reaction: drop repeated entries (first occurrence wins, order preserved)
# and pad with '' so that every list has at least 10 entries.
pre_e_smiles_50_lis_1 = []
for candidates in pre_e_smiles_50_lis_:
    unique = []
    for candidate in candidates:
        if candidate in unique:
            continue
        unique.append(candidate)
    shortfall = 10 - len(unique)
    if shortfall > 0:
        unique.extend([''] * shortfall)
    pre_e_smiles_50_lis_1.append(unique)

In [103]:
# Rank (1-based) of the first of the top-n truncated predictions that matches
# any of the reaction's label E-SMILES, compared on the text before the first
# '<'; 11 when none of the first n matches.
k_lis = []
for i in range(num):
    # Built once per reaction instead of once per candidate; a set gives
    # constant-time membership tests.
    true_centers = {e_smiles.split('<')[0] for e_smiles in true_e_smiles_lis_g[i]}
    rank = 11
    for k in range(n):
        if pre_e_smiles_50_lis_1[i][k] in true_centers:
            rank = k+1
            break
    k_lis.append(rank)
    
#k_lis

In [109]:
# Number of reactions whose rank in k_lis falls within each cut-off.
top1, top2, top3, top5, top10 = (
    sum(1 for rank in k_lis if rank <= cutoff) for cutoff in (1, 2, 3, 5, 10)
)

print(f'Top-1 Pred Reaction Center Identification accuracy after TTA: {top1/num:.3f}')
print(f'Top-2 Pred Reaction Center Identification accuracy after TTA: {top2/num:.3f}')
print(f'Top-3 Pred Reaction Center Identification accuracy after TTA: {top3/num:.3f}')
print(f'Top-5 Pred Reaction Center Identification accuracy after TTA: {top5/num:.3f}')
print(f'Top-10 Pred Reaction Center Identification accuracy after TTA: {top10/num:.3f}')

Top-1 Pred Reaction Center Identification accuracy after TTA: 0.731
Top-2 Pred Reaction Center Identification accuracy after TTA: 0.873
Top-3 Pred Reaction Center Identification accuracy after TTA: 0.920
Top-5 Pred Reaction Center Identification accuracy after TTA: 0.955
Top-10 Pred Reaction Center Identification accuracy after TTA: 0.975
