## 反应直接预测结果分析
> 2024-11-08

### 1. 导入必要的包

In [1]:
# Standard Library Imports
import os
import sys

# Third-party Imports
import json
import numpy as np
import pandas as pd
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor, as_completed
import plotly.graph_objects as go
from IPython.display import HTML
from pandarallel import pandarallel  # Importing pandarallel for parallel processing

# Setting up the path for the module
sys.path.insert(0, os.path.dirname(os.path.realpath('__file__')))
sys.path.insert(1, '../')

# Local Imports
from config import conf as cfg
from tools import btools
import evTools

FIRST_TIME_RUN = False

# Initialize parallel processing
pandarallel.initialize(progress_bar=False)

# Enable autoreloading of modules in IPython
%load_ext autoreload
%autoreload 2

INFO: Pandarallel will run on 128 workers.
INFO: Pandarallel will use Memory file system to transfer data between the main process and workers.


### 2. 加载测试数据集

In [2]:
# 从 JSON 文件加载反应编码字典
with open(cfg.FILE_DS_DICT_RXN2ID, "r") as json_file:
    dict_rxn2id = json.load(json_file)
    print(f'加载反应编码字典完成，共有 {len(dict_rxn2id)} 个反应。')  # 打印加载的数据
    
print('Loading validation datasets feather path ...')
vali_feather_files = [
    f'{cfg.DIR_DATASET}validation/fold{fold_index}/valid.feather' 
    for fold_index in range(1, 11)
]

# load datasets
ds_test =[pd.read_feather(vali_feather_files[item])[['uniprot_id', 'reaction_id']].rename(columns={'reaction_id': 'rxn_groundtruth'}) for item in tqdm(range(10))]


print('Loading uniprot_rxn_dict ...' )
d1 = pd.read_feather(cfg.FILE_DS_TRAIN)
d2 = pd.read_feather(cfg.FILE_DS_TEST)
uniprot_rxn_dict = pd.concat([d1,d2], axis=0).reset_index(drop=True)[['uniprot_id', 'reaction_id']].set_index('uniprot_id')['reaction_id'].to_dict()


加载反应编码字典完成，共有 10479 个反应。
Loading validation datasets feather path ...


100%|██████████| 10/10 [00:04<00:00,  2.06it/s]


Loading uniprot_rxn_dict ...


## 4. Load results from EC based method

### 4.1 Blast

In [4]:
vali_res_blast = [
    f'{cfg.DIR_RES_BASELINE}results/direct_methods/blast/fold{item}.tsv' 
    for item in range(1, 11)
]
res_blast = evTools.read_10fold_res_csv_files(vali_res_blast)
df_blast_no_pred = evTools.process_no_res(res_blast, rxnkey='rxn_blast')
# 补充groud truth
res_blast = [ds_test[item].merge(res_blast[item], on='uniprot_id', how='left') for item in range(10)]
# add labels
res_blast = evTools.apply_labels(res_blast, 'rxn_groundtruth', 'rxn_blast', 'lb_rxn_groundtruth', 'lb_rxn_blast', dict_rxn2id)


# res_blast_metrics =evTools.calculate_metrics_parallel(res_unirep=res_blast, ground_truth_col='lb_rxn_groundtruth', pred_col='lb_rxn_blast', avg_method='macro',  max_workers=15)
# res_blast_metrics = res_blast_metrics.reset_index(drop=True)
# res_blast_metrics = pd.concat([res_blast_metrics, df_blast_no_pred], axis=1)
# res_blast_metrics.baselineName = 'blast_direct'
# res_blast_metrics['runFold'] = res_blast_metrics.index+1
# res_blast_metrics.to_feather(f'{cfg.DIR_PROJECT_ROOT}/evaluation/data/res_blast_direct_metrics.feather')
# # res_blast_metrics.to_feather('/hpcfs/fhome/shizhenkun/codebase/RXNRECer/evaluation/data/res_blast_direct_metrics.feather')
# res_blast_fold_std = res_blast_metrics[['mAccuracy', 'mPrecision', 'mRecall', 'mF1', 'no_prediction']].agg(['mean', 'std'])
# evTools.display_html_results(res_blast_metrics, res_blast_fold_std, 'Blast')

100%|██████████| 10/10 [01:43<00:00, 10.38s/it]


In [None]:
vali_res_blast = [
    f'{cfg.DIR_RES_BASELINE}results/direct_methods/blast/fold{item}.tsv' 
    for item in range(1, 11)
]
res_blast = evTools.read_10fold_res_csv_files(vali_res_blast)
df_blast_no_pred = evTools.process_no_res(res_blast, rxnkey='rxn_blast')
# 补充groud truth
res_blast = [ds_test[item].merge(res_blast[item], on='uniprot_id', how='left') for item in range(10)]
# add labels
res_blast = evTools.apply_labels(res_blast, 'rxn_groundtruth', 'rxn_blast', 'lb_rxn_groundtruth', 'lb_rxn_blast', dict_rxn2id)


res_blast_metrics =evTools.calculate_metrics_parallel(res_unirep=res_blast, ground_truth_col='lb_rxn_groundtruth', pred_col='lb_rxn_blast', avg_method='marco',  max_workers=15)
res_blast_metrics = res_blast_metrics.reset_index(drop=True)
res_blast_metrics = pd.concat([res_blast_metrics, df_blast_no_pred], axis=1)
res_blast_metrics.baselineName = 'blast_direct'
res_blast_metrics['runFold'] = res_blast_metrics.index+1
res_blast_metrics.to_feather(f'{cfg.DIR_PROJECT_ROOT}/evaluation/data/res_blast_direct_metrics.feather')
# res_blast_metrics.to_feather('/hpcfs/fhome/shizhenkun/codebase/RXNRECer/evaluation/data/res_blast_direct_metrics.feather')
res_blast_fold_std = res_blast_metrics[['mAccuracy', 'mPrecision', 'mRecall', 'mF1', 'no_prediction']].agg(['mean', 'std'])
evTools.display_html_results(res_blast_metrics, res_blast_fold_std, 'Blast')

100%|██████████| 10/10 [01:42<00:00, 10.29s/it]


Unnamed: 0,baselineName,mAccuracy,mPrecision,mRecall,mF1,test_size,no_prediction,runFold
0,blast_direct,0.831649,0.879623,0.950236,0.895777,50858,2424,1
1,blast_direct,0.832711,0.875751,0.949407,0.892806,50858,2386,2
2,blast_direct,0.832435,0.876919,0.951163,0.894495,50858,2379,3
3,blast_direct,0.834008,0.879929,0.951624,0.897056,50858,2301,4
4,blast_direct,0.83391,0.879941,0.950015,0.896753,50858,2418,5
5,blast_direct,0.835896,0.882497,0.950218,0.89812,50858,2391,6
6,blast_direct,0.832593,0.876341,0.949912,0.894036,50858,2369,7
7,blast_direct,0.832671,0.882323,0.950886,0.89806,50858,2381,8
8,blast_direct,0.830076,0.87714,0.949523,0.894221,50858,2449,9
9,blast_direct,0.83102,0.877479,0.950026,0.894609,50858,2431,10

Unnamed: 0,mAccuracy,mPrecision,mRecall,mF1,no_prediction
mean,0.832697,0.878794,0.950301,0.895593,2392.9
std,0.001643,0.00242,0.000712,0.001833,41.631852


### 4.2 Unirep

In [None]:
embd_methd = 'unirep'
file_res_unirep = [f'{cfg.RESULTS_DIR}simi/fold_{fold_num}_{embd_methd}_results.h5' for fold_num in range(1,11)]
res_unirep = [evTools.read_h5_file(item)for item in tqdm(file_res_unirep)]

# 获取反应ID
for i in tqdm(range(10)):
    res_unirep[i]['rxn_euclidean'] = res_unirep[i].euclidean.apply(lambda x : evTools.get_simi_Pred(pred_list=x, uniprot_rxn_dict=uniprot_rxn_dict))
    res_unirep[i]['rxn_cosine'] = res_unirep[i].cosine.apply(lambda x : evTools.get_simi_Pred(pred_list=x, uniprot_rxn_dict=uniprot_rxn_dict))
    
    

# 将反应ID标签化    
res_unirep = evTools.apply_labels(res_unirep, 'reaction_id', 'rxn_euclidean', 'lb_rxn_groundtruth', 'lb_rxn_unirep_euclidean', dict_rxn2id)
for i in tqdm(range(10)):
    res_unirep[i]['lb_rxn_unirep_cosine'] = res_unirep[i].rxn_cosine.parallel_apply(lambda x :btools.make_label(reaction_id=x, rxn_label_dict=dict_rxn2id))
    

# 计算评价指标
res_unirep_euclidean_metrics = evTools.calculate_metrics_parallel(res_unirep=res_unirep, ground_truth_col='lb_rxn_groundtruth', pred_col='lb_rxn_unirep_euclidean', max_workers=15)
res_unirep_cosine_metrics = evTools.calculate_metrics_parallel(res_unirep=res_unirep, ground_truth_col='lb_rxn_groundtruth', pred_col='lb_rxn_unirep_cosine', max_workers=15)


res_unirep_euclidean_metrics['baselineName'] = 'unirep_eu'
res_unirep_cosine_metrics['baselineName'] = 'unirep_cos'
res_unirep_euclidean_metrics['runFold'] = res_unirep_euclidean_metrics.index+1
res_unirep_cosine_metrics['runFold'] = res_unirep_cosine_metrics.index+1

res_unirep = pd.concat([res_unirep_euclidean_metrics, res_unirep_cosine_metrics], axis=0).reset_index(drop=True)
res_unirep.to_feather(f'{cfg.DIR_PROJECT_ROOT}/evaluation/data/res_unirep_metrics.feather')


res_unirep_euclidean_fold_std = res_unirep_euclidean_metrics[['mAccuracy', 'mPrecision', 'mRecall', 'mF1']].agg(['mean', 'std'])
res_unirep_cosine_fold_std = res_unirep_cosine_metrics[['mAccuracy', 'mPrecision', 'mRecall', 'mF1']].agg(['mean', 'std'])


In [12]:
res_unirep

Unnamed: 0,baselineName,mAccuracy,mPrecision,mRecall,mF1,runFold
0,unirep_eu,0.889575,0.892769,0.963311,0.915861,1
1,unirep_eu,0.889595,0.892241,0.960985,0.913871,2
2,unirep_eu,0.888946,0.88975,0.961411,0.913392,3
3,unirep_eu,0.891345,0.894308,0.96068,0.915687,4
4,unirep_eu,0.887687,0.893174,0.960435,0.914768,5
5,unirep_eu,0.889437,0.891764,0.962008,0.914212,6
6,unirep_eu,0.891325,0.893932,0.961476,0.9151,7
7,unirep_eu,0.891895,0.894992,0.963322,0.917196,8
8,unirep_eu,0.887353,0.890655,0.962801,0.914369,9
9,unirep_eu,0.88985,0.893834,0.961377,0.915201,10


In [13]:
HTML(f"""
         <div style="float:left; width:600px;">
              <h2 style='color:blue'>{'Unirep (Euclidean) Evaluation 10 Fold Details'} Evaluation 10 Fold Details</h2>
              {res_unirep_euclidean_metrics.to_html()}
         </div>
         <div  style="float:left; width:600px;" >
              <h2 style='color:blue' >{'Unirep (Euclidean) Evaluation 10 Fold Overview'} Evaluation 10 Fold Overview</h2>
                   {res_unirep_euclidean_fold_std.to_html()}
         </div>
         
        <div style="float:left; display:block; width:600px;">
              <h2 style='color:blue'>{'Unirep (Cosine) Evaluation 10 Fold Details'} Evaluation 10 Fold Details</h2>
              {res_unirep_cosine_metrics.to_html()}
         </div>
         <div  style="float:left; width:600px;" >
              <h2 style='color:blue' >{'Unirep (Cosine) Evaluation 10 Fold Overview'} Evaluation 10 Fold Overview</h2>
                   {res_unirep_cosine_fold_std.to_html()}
         </div>
         """)

Unnamed: 0,baselineName,mAccuracy,mPrecision,mRecall,mF1,runFold
0,unirep_eu,0.889575,0.892769,0.963311,0.915861,1
1,unirep_eu,0.889595,0.892241,0.960985,0.913871,2
2,unirep_eu,0.888946,0.88975,0.961411,0.913392,3
3,unirep_eu,0.891345,0.894308,0.96068,0.915687,4
4,unirep_eu,0.887687,0.893174,0.960435,0.914768,5
5,unirep_eu,0.889437,0.891764,0.962008,0.914212,6
6,unirep_eu,0.891325,0.893932,0.961476,0.9151,7
7,unirep_eu,0.891895,0.894992,0.963322,0.917196,8
8,unirep_eu,0.887353,0.890655,0.962801,0.914369,9
9,unirep_eu,0.88985,0.893834,0.961377,0.915201,10

Unnamed: 0,mAccuracy,mPrecision,mRecall,mF1
mean,0.889701,0.892742,0.961781,0.914966
std,0.001506,0.001661,0.001047,0.001106

Unnamed: 0,baselineName,mAccuracy,mPrecision,mRecall,mF1,runFold
0,unirep_cos,0.890656,0.89393,0.964254,0.916958,1
1,unirep_cos,0.890578,0.894111,0.961342,0.914934,2
2,unirep_cos,0.890696,0.890728,0.961991,0.914388,3
3,unirep_cos,0.892249,0.895324,0.961105,0.916438,4
4,unirep_cos,0.888356,0.894634,0.961083,0.915965,5
5,unirep_cos,0.890578,0.894673,0.962042,0.916023,6
6,unirep_cos,0.892642,0.895804,0.961767,0.916512,7
7,unirep_cos,0.893134,0.897074,0.963579,0.918411,8
8,unirep_cos,0.888474,0.890809,0.962595,0.914526,9
9,unirep_cos,0.890243,0.894445,0.962025,0.915833,10

Unnamed: 0,mAccuracy,mPrecision,mRecall,mF1
mean,0.890761,0.894153,0.962178,0.915999
std,0.00159,0.002005,0.00104,0.001208


### 4.3 ESM

In [None]:
embd_methd = 'esm'
file_res_esm = [f'{cfg.RESULTS_DIR}simi/fold_{fold_num}_{embd_methd}_results.h5' for fold_num in range(1,11)]
res_esm = [evTools.read_h5_file(item)for item in tqdm(file_res_esm)]

# 获取反应ID
for i in tqdm(range(10)):
    # res_esm[i]['rxn_euclidean'] = res_esm[i].euclidean.apply(lambda x : uniprot_rxn_dict.get(x[0][0]))
    res_esm[i]['rxn_euclidean'] = res_esm[i].euclidean.apply(lambda x : evTools.get_simi_Pred(pred_list=x, uniprot_rxn_dict=uniprot_rxn_dict, topk=6))
    res_esm[i]['rxn_cosine'] = res_esm[i].cosine.apply(lambda x : evTools.get_simi_Pred(pred_list=x, uniprot_rxn_dict=uniprot_rxn_dict, topk=6))

# 将反应ID标签化    
res_esm = evTools.apply_labels(res_esm, 'reaction_id', 'rxn_euclidean', 'lb_rxn_groundtruth', 'lb_rxn_esm_euclidean', dict_rxn2id)
for i in tqdm(range(10)):
    res_esm[i]['lb_rxn_esm_cosine'] = res_esm[i].rxn_cosine.parallel_apply(lambda x :btools.make_label(reaction_id=x, rxn_label_dict=dict_rxn2id))
    
# 计算评价指标
res_esm_euclidean_metrics = evTools.calculate_metrics_parallel(res_unirep=res_esm, ground_truth_col='lb_rxn_groundtruth', pred_col='lb_rxn_esm_euclidean', max_workers=15)
res_esm_cosine_metrics = evTools.calculate_metrics_parallel(res_unirep=res_esm, ground_truth_col='lb_rxn_groundtruth', pred_col='lb_rxn_esm_cosine', max_workers=15)


res_esm_euclidean_metrics['baselineName'] = 'esm_eu'
res_esm_cosine_metrics['baselineName'] = 'esm_cos'

res_esm_euclidean_metrics['runFold'] = res_esm_euclidean_metrics.index+1
res_esm_cosine_metrics['runFold'] = res_esm_cosine_metrics.index+1

res_esm = pd.concat([res_esm_euclidean_metrics, res_esm_cosine_metrics], axis=0).reset_index(drop=True)
res_esm.to_feather(f'{cfg.DIR_PROJECT_ROOT}/evaluation/data/res_esm_metrics.feather')

# res_esm_euclidean_metrics.to_feather('/hpcfs/fhome/shizhenkun/codebase/RXNRECer/evaluation/data/res_esm_euclidean_metrics.feather')
# res_esm_cosine_metrics.to_feather('/hpcfs/fhome/shizhenkun/codebase/RXNRECer/evaluation/data/res_esm_cosine_metrics.feather')
res_esm_euclidean_fold_std = res_esm_euclidean_metrics[['mAccuracy', 'mPrecision', 'mRecall', 'mF1']].agg(['mean', 'std'])
res_esm_cosine_fold_std = res_esm_cosine_metrics[['mAccuracy', 'mPrecision', 'mRecall', 'mF1']].agg(['mean', 'std'])

100%|██████████| 10/10 [10:02<00:00, 60.27s/it]
100%|██████████| 10/10 [00:14<00:00,  1.46s/it]
100%|██████████| 10/10 [03:20<00:00, 20.05s/it]
100%|██████████| 10/10 [32:42<00:00, 196.28s/it]


In [None]:
HTML(f"""
         <div style="float:left; width:600px;">
              <h2 style='color:blue'>{'ESM (Euclidean) Evaluation 10 Fold Details'} Evaluation 10 Fold Details</h2>
              {res_esm_euclidean_metrics.to_html()}
         </div>
         <div  style="float:left; width:600px;" >
              <h2 style='color:blue' >{'ESM (Euclidean) Evaluation 10 Fold Overview'} Evaluation 10 Fold Overview</h2>
                   {res_esm_euclidean_fold_std.to_html()}
         </div>
         
        <div style="float:left; display:block; width:600px;">
              <h2 style='color:blue'>{'ESM (Cosine) Evaluation 10 Fold Details'} Evaluation 10 Fold Details</h2>
              {res_esm_cosine_metrics.to_html()}
         </div>
         <div  style="float:left; width:600px;" >
              <h2 style='color:blue' >{'ESM (Cosine) Evaluation 10 Fold Overview'} Evaluation 10 Fold Overview</h2>
                   {res_esm_cosine_fold_std.to_html()}
         </div>
         """)

Unnamed: 0,baselineName,mAccuracy,mPrecision,mRecall,mF1,runFold
0,esm_eu,0.887491,0.882065,0.984269,0.918213,1
0,esm_eu,0.886095,0.881635,0.982224,0.915974,1
0,esm_eu,0.886508,0.880745,0.982181,0.916297,1
0,esm_eu,0.889654,0.883257,0.982297,0.917988,1
0,esm_eu,0.888454,0.883698,0.983287,0.918886,1
0,esm_eu,0.888081,0.882835,0.983223,0.917706,1
0,esm_eu,0.889083,0.884422,0.982791,0.918351,1
0,esm_eu,0.888061,0.882121,0.984369,0.917552,1
0,esm_eu,0.886252,0.880343,0.983591,0.91646,1
0,esm_eu,0.885819,0.880933,0.982683,0.916066,1

Unnamed: 0,mAccuracy,mPrecision,mRecall,mF1
mean,0.88755,0.882205,0.983091,0.917349
std,0.001334,0.001342,0.000802,0.00106

Unnamed: 0,baselineName,mAccuracy,mPrecision,mRecall,mF1,runFold
0,fold1,0.891109,0.88514,0.9848,0.921256,1
0,fold2,0.889909,0.884424,0.982394,0.918346,1
0,fold3,0.890381,0.883653,0.982539,0.918754,1
0,fold4,0.892839,0.886935,0.982893,0.921122,1
0,fold5,0.892819,0.886306,0.983253,0.921114,1
0,fold6,0.892485,0.88647,0.983377,0.920763,1
0,fold7,0.892721,0.887791,0.982808,0.921154,1
0,fold8,0.892131,0.885634,0.984454,0.920569,1
0,fold9,0.889378,0.882434,0.983916,0.918471,1
0,fold10,0.890774,0.885339,0.983143,0.919844,1

Unnamed: 0,mAccuracy,mPrecision,mRecall,mF1
mean,0.891455,0.885413,0.983358,0.920139
std,0.001305,0.001597,0.000801,0.001191


### 4.4 T5

In [None]:
embd_methd = 't5'
file_res_t5 = [f'{cfg.RESULTS_DIR}simi/fold_{fold_num}_{embd_methd}_results.h5' for fold_num in range(1,11)]
print('Loading T5 results...')
res_t5 = [evTools.read_h5_file(item)for item in tqdm(file_res_t5)]

print('Adding reaction ID labels to T5 results...')
# 获取反应ID
for i in tqdm(range(10)):
    res_t5[i]['rxn_euclidean'] = res_t5[i].euclidean.apply(lambda x : evTools.get_simi_Pred(pred_list=x, uniprot_rxn_dict=uniprot_rxn_dict, topk=10))
    res_t5[i]['rxn_cosine'] = res_t5[i].cosine.apply(lambda x : evTools.get_simi_Pred(pred_list=x, uniprot_rxn_dict=uniprot_rxn_dict, topk=10))

# 将反应ID标签化    
res_t5 = evTools.apply_labels(res_t5, 'reaction_id', 'rxn_euclidean', 'lb_rxn_groundtruth', 'lb_rxn_t5_euclidean', dict_rxn2id)
for i in tqdm(range(10)):
    res_t5[i]['lb_rxn_t5_cosine'] = res_t5[i].rxn_cosine.parallel_apply(lambda x :btools.make_label(reaction_id=x, rxn_label_dict=dict_rxn2id))


print('Calculating metrics for T5 (Euclidean)...')
# 计算评价指标
res_t5_euclidean_metrics = evTools.calculate_metrics_parallel(res_unirep=res_t5, ground_truth_col='lb_rxn_groundtruth', pred_col='lb_rxn_t5_euclidean', max_workers=15)
print('Calculating metrics for T5 (Cosine)...')
res_t5_cosine_metrics = evTools.calculate_metrics_parallel(res_unirep=res_t5, ground_truth_col='lb_rxn_groundtruth', pred_col='lb_rxn_t5_cosine', max_workers=15)


res_t5_euclidean_metrics['baselineName'] = 't5_eu'
res_t5_cosine_metrics['baselineName'] = 't5_cos'

res_t5_euclidean_metrics['runFold'] = res_t5_euclidean_metrics.index+1
res_t5_cosine_metrics['runFold'] = res_t5_cosine_metrics.index+1

res_t5 = pd.concat([res_t5_euclidean_metrics, res_t5_cosine_metrics], axis=0).reset_index(drop=True)
res_t5.to_feather(f'{cfg.DIR_PROJECT_ROOT}/evaluation/data/res_t5_metrics.feather')

# res_t5_euclidean_metrics.to_feather('/hpcfs/fhome/shizhenkun/codebase/RXNRECer/evaluation/data/res_t5_euclidean_metrics.feather')
# res_t5_cosine_metrics.to_feather('/hpcfs/fhome/shizhenkun/codebase/RXNRECer/evaluation/data/res_t5_cosine_metrics.feather')
res_t5_euclidean_fold_std = res_t5_euclidean_metrics[['mAccuracy', 'mPrecision', 'mRecall', 'mF1']].agg(['mean', 'std'])
res_t5_cosine_fold_std = res_t5_cosine_metrics[['mAccuracy', 'mPrecision', 'mRecall', 'mF1']].agg(['mean', 'std'])

HTML(f"""
         <div style="float:left; width:600px;">
              <h2 style='color:blue'>{'T5 (Euclidean) Evaluation 10 Fold Details'} Evaluation 10 Fold Details</h2>
              {res_t5_euclidean_metrics.to_html()}
         </div>
         <div  style="float:left; width:600px;" >
              <h2 style='color:blue' >{'T5 (Euclidean) Evaluation 10 Fold Overview'} Evaluation 10 Fold Overview</h2>
                   {res_t5_euclidean_fold_std.to_html()}
         </div>
         
        <div style="float:left; display:block; width:600px;">
              <h2 style='color:blue'>{'T5 (Cosine) Evaluation 10 Fold Details'} Evaluation 10 Fold Details</h2>
              {res_t5_cosine_metrics.to_html()}
         </div>
         <div  style="float:left; width:600px;" >
              <h2 style='color:blue' >{'T5 (Cosine) Evaluation 10 Fold Overview'} Evaluation 10 Fold Overview</h2>
                   {res_t5_cosine_fold_std.to_html()}
         </div>
         """)

Loading T5 results...


100%|██████████| 10/10 [09:41<00:00, 58.11s/it]


Adding reaction ID labels to T5 results...


100%|██████████| 10/10 [00:14<00:00,  1.48s/it]
100%|██████████| 10/10 [02:53<00:00, 17.37s/it]
100%|██████████| 10/10 [14:44<00:00, 88.46s/it] 


Calculating metrics for T5 (Euclidean)...
Calculating metrics for T5 (Cosine)...


Unnamed: 0,baselineName,mAccuracy,mPrecision,mRecall,mF1,runFold
0,esm_eu,0.8963,0.895696,0.988347,0.927489,1
1,esm_eu,0.895906,0.892883,0.987284,0.924237,2
2,esm_eu,0.895257,0.892595,0.987569,0.92497,3
3,esm_eu,0.897774,0.898979,0.987319,0.929176,4
4,esm_eu,0.897637,0.895469,0.987653,0.927142,5
5,esm_eu,0.896732,0.899115,0.98761,0.928954,6
6,esm_eu,0.89801,0.895741,0.986948,0.926036,7
7,esm_eu,0.896712,0.894307,0.98823,0.92627,8
8,esm_eu,0.893448,0.88967,0.988005,0.923425,9
9,esm_eu,0.894333,0.891083,0.987728,0.924059,10

Unnamed: 0,mAccuracy,mPrecision,mRecall,mF1
mean,0.896211,0.894554,0.987669,0.926176
std,0.001506,0.003101,0.000434,0.002018

Unnamed: 0,baselineName,mAccuracy,mPrecision,mRecall,mF1,runFold
0,esm_cos,0.898226,0.897244,0.988467,0.928613,1
1,esm_cos,0.898246,0.894508,0.987165,0.925326,2
2,esm_cos,0.897184,0.893745,0.987705,0.926053,3
3,esm_cos,0.899033,0.898972,0.987251,0.929163,4
4,esm_cos,0.89976,0.89707,0.987772,0.928342,5
5,esm_cos,0.898679,0.900294,0.987901,0.929988,6
6,esm_cos,0.900016,0.896876,0.986759,0.926894,7
7,esm_cos,0.899603,0.89695,0.988093,0.928039,8
8,esm_cos,0.895867,0.890981,0.987988,0.924218,9
9,esm_cos,0.895887,0.893186,0.98766,0.92538,10

Unnamed: 0,mAccuracy,mPrecision,mRecall,mF1
mean,0.89825,0.895982,0.987676,0.927202
std,0.001506,0.002829,0.000499,0.001908


In [7]:
embd_methd = 'tdit5'
file_res_t5tdi = [f'{cfg.DIR_PROJECT_ROOT}/results/intermediate/direct/{embd_methd}_fold{item}.tsv' for item in range(1, 11)]
print('Loading T5TDI results...')
res_t5tdi = [pd.read_csv(file_res_t5tdi[item], sep='\t') for item in range(10)]
print('Adding reaction ID labels to T5TDI results...')

Loading T5TDI results...
Adding reaction ID labels to T5TDI results...


In [8]:
file_res_t5 = [f'{cfg.RESULTS_DIR}simi/fold_{fold_num}_{embd_methd}_results.h5' for fold_num in range(1,11)]
print('Loading T5 results...')
res_t5 = [evTools.read_h5_file(item)for item in tqdm(file_res_t5)]

Loading T5 results...




  0%|          | 0/10 [19:16<?, ?it/s]
  0%|          | 0/10 [06:19<?, ?it/s]
  0%|          | 0/10 [02:25<?, ?it/s]


[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

100%|██████████| 10/10 [10:08<00:00, 60.83s/it]


In [None]:
embd_methd = 't5'
file_res_t5 = [f'{cfg.RESULTS_DIR}simi/fold_{fold_num}_{embd_methd}_results.h5' for fold_num in range(1,11)]
print('Loading T5 results...')
res_t5 = [evTools.read_h5_file(item)for item in tqdm(file_res_t5)]

print('Adding reaction ID labels to T5 results...')
# 获取反应ID
for i in tqdm(range(10)):
    res_t5[i]['rxn_euclidean'] = res_t5[i].euclidean.apply(lambda x : evTools.get_simi_Pred(pred_list=x, uniprot_rxn_dict=uniprot_rxn_dict, topk=10))
    res_t5[i]['rxn_cosine'] = res_t5[i].cosine.apply(lambda x : evTools.get_simi_Pred(pred_list=x, uniprot_rxn_dict=uniprot_rxn_dict, topk=10))

# 将反应ID标签化    
res_t5 = evTools.apply_labels(res_t5, 'reaction_id', 'rxn_euclidean', 'lb_rxn_groundtruth', 'lb_rxn_t5_euclidean', dict_rxn2id)
for i in tqdm(range(10)):
    res_t5[i]['lb_rxn_t5_cosine'] = res_t5[i].rxn_cosine.parallel_apply(lambda x :btools.make_label(reaction_id=x, rxn_label_dict=dict_rxn2id))


print('Calculating metrics for T5 (Euclidean)...')
# 计算评价指标
res_t5_euclidean_metrics = evTools.calculate_metrics_parallel(res_unirep=res_t5, ground_truth_col='lb_rxn_groundtruth', pred_col='lb_rxn_t5_euclidean', max_workers=15)
print('Calculating metrics for T5 (Cosine)...')
res_t5_cosine_metrics = evTools.calculate_metrics_parallel(res_unirep=res_t5, ground_truth_col='lb_rxn_groundtruth', pred_col='lb_rxn_t5_cosine', max_workers=15)


res_t5_euclidean_metrics['baselineName'] = 't5_eu'
res_t5_cosine_metrics['baselineName'] = 't5_cos'

res_t5_euclidean_metrics['runFold'] = res_t5_euclidean_metrics.index+1
res_t5_cosine_metrics['runFold'] = res_t5_cosine_metrics.index+1

res_t5 = pd.concat([res_t5_euclidean_metrics, res_t5_cosine_metrics], axis=0).reset_index(drop=True)
res_t5.to_feather(f'{cfg.DIR_PROJECT_ROOT}/evaluation/data/res_t5_metrics.feather')

# res_t5_euclidean_metrics.to_feather('/hpcfs/fhome/shizhenkun/codebase/RXNRECer/evaluation/data/res_t5_euclidean_metrics.feather')
# res_t5_cosine_metrics.to_feather('/hpcfs/fhome/shizhenkun/codebase/RXNRECer/evaluation/data/res_t5_cosine_metrics.feather')
res_t5_euclidean_fold_std = res_t5_euclidean_metrics[['mAccuracy', 'mPrecision', 'mRecall', 'mF1']].agg(['mean', 'std'])
res_t5_cosine_fold_std = res_t5_cosine_metrics[['mAccuracy', 'mPrecision', 'mRecall', 'mF1']].agg(['mean', 'std'])

HTML(f"""
         <div style="float:left; width:600px;">
              <h2 style='color:blue'>{'T5 (Euclidean) Evaluation 10 Fold Details'} Evaluation 10 Fold Details</h2>
              {res_t5_euclidean_metrics.to_html()}
         </div>
         <div  style="float:left; width:600px;" >
              <h2 style='color:blue' >{'T5 (Euclidean) Evaluation 10 Fold Overview'} Evaluation 10 Fold Overview</h2>
                   {res_t5_euclidean_fold_std.to_html()}
         </div>
         
        <div style="float:left; display:block; width:600px;">
              <h2 style='color:blue'>{'T5 (Cosine) Evaluation 10 Fold Details'} Evaluation 10 Fold Details</h2>
              {res_t5_cosine_metrics.to_html()}
         </div>
         <div  style="float:left; width:600px;" >
              <h2 style='color:blue' >{'T5 (Cosine) Evaluation 10 Fold Overview'} Evaluation 10 Fold Overview</h2>
                   {res_t5_cosine_fold_std.to_html()}
         </div>
         """)

# 5. 整合指标

In [23]:
res_metrics_blast  = pd.read_feather(f'{cfg.DIR_PROJECT_ROOT}/evaluation/data/res_blast_direct_metrics.feather')
res_metrics_unirep = pd.read_feather(f'{cfg.DIR_PROJECT_ROOT}/evaluation/data/res_unirep_metrics.feather')
res_metrics_esm = pd.read_feather(f'{cfg.DIR_PROJECT_ROOT}/evaluation/data/res_esm_metrics.feather')
res_metrics_t5 = pd.read_feather(f'{cfg.DIR_PROJECT_ROOT}/evaluation/data/res_t5_metrics.feather')

In [None]:
res_metrics_blast  = pd.read_feather('/hpcfs/fhome/shizhenkun/codebase/RXNRECer/evaluation/data/res_blast_direct_metrics.feather')
res_metrics_unirep_eu = pd.read_feather('/hpcfs/fhome/shizhenkun/codebase/RXNRECer/evaluation/data/res_unirep_euclidean_metrics.feather')
res_metrics_unirep_cos = pd.read_feather('/hpcfs/fhome/shizhenkun/codebase/RXNRECer/evaluation/data/res_unirep_cosine_metrics.feather')
res_metrics_esm_eu = pd.read_feather('/hpcfs/fhome/shizhenkun/codebase/RXNRECer/evaluation/data/res_esm_euclidean_metrics.feather')
res_metrics_esm_cos = pd.read_feather('/hpcfs/fhome/shizhenkun/codebase/RXNRECer/evaluation/data/res_esm_cosine_metrics.feather')


res_metrics = pd.concat([res_metrics_blast, res_metrics_unirep_eu, res_metrics_unirep_cos, res_metrics_esm_eu, res_metrics_esm_cos, res_metrics_t5_eu, res_metrics_t5_cos], axis=0).reset_index(drop=True)

In [24]:
res_metrics_blast

Unnamed: 0,baselineName,mAccuracy,mPrecision,mRecall,mF1,test_size,no_prediction,runFold
0,blast_direct,0.831649,0.879623,0.950236,0.895777,50858,2424,1
1,blast_direct,0.832711,0.875751,0.949407,0.892806,50858,2386,2
2,blast_direct,0.832435,0.876919,0.951163,0.894495,50858,2379,3
3,blast_direct,0.834008,0.879929,0.951624,0.897056,50858,2301,4
4,blast_direct,0.83391,0.879941,0.950015,0.896753,50858,2418,5
5,blast_direct,0.835896,0.882497,0.950218,0.89812,50858,2391,6
6,blast_direct,0.832593,0.876341,0.949912,0.894036,50858,2369,7
7,blast_direct,0.832671,0.882323,0.950886,0.89806,50858,2381,8
8,blast_direct,0.830076,0.87714,0.949523,0.894221,50858,2449,9
9,blast_direct,0.83102,0.877479,0.950026,0.894609,50858,2431,10


In [22]:
res_metrics_unirep

Unnamed: 0,baselineName,mAccuracy,mPrecision,mRecall,mF1,runFold
0,unirep_eu,0.889575,0.892769,0.963311,0.915861,1
1,unirep_eu,0.889595,0.892241,0.960985,0.913871,2
2,unirep_eu,0.888946,0.88975,0.961411,0.913392,3
3,unirep_eu,0.891345,0.894308,0.96068,0.915687,4
4,unirep_eu,0.887687,0.893174,0.960435,0.914768,5
5,unirep_eu,0.889437,0.891764,0.962008,0.914212,6
6,unirep_eu,0.891325,0.893932,0.961476,0.9151,7
7,unirep_eu,0.891895,0.894992,0.963322,0.917196,8
8,unirep_eu,0.887353,0.890655,0.962801,0.914369,9
9,unirep_eu,0.88985,0.893834,0.961377,0.915201,10


In [None]:
res_metrics_unirep_eu

Unnamed: 0,baselineName,mAccuracy,mPrecision,mRecall,mF1
0,fold1,0.889575,0.892769,0.963311,0.915861
0,fold2,0.889595,0.892241,0.960985,0.913871
0,fold3,0.888946,0.88975,0.961411,0.913392
0,fold4,0.891345,0.894308,0.96068,0.915687
0,fold5,0.887687,0.893174,0.960435,0.914768
0,fold6,0.889437,0.891764,0.962008,0.914212
0,fold7,0.891325,0.893932,0.961476,0.9151
0,fold8,0.891895,0.894992,0.963322,0.917196
0,fold9,0.887353,0.890655,0.962801,0.914369
0,fold10,0.88985,0.893834,0.961377,0.915201
