## 反应直接预测结果分析
> 2024-11-08

### 1. 导入必要的包

In [9]:
import sys,os
sys.path.insert(0, os.path.dirname(os.path.realpath('__file__')))
sys.path.insert(1,'../')
from config import conf as cfg
import pandas as pd
from tqdm.notebook import tqdm
import numpy as np
import json
from concurrent.futures import ThreadPoolExecutor, as_completed
import plotly.graph_objects as go
from tools import btools
from IPython.display import HTML
from pandarallel import pandarallel # 导入pandaralle

pandarallel.initialize(progress_bar=False)
%load_ext autoreload
%autoreload 2

INFO: Pandarallel will run on 128 workers.
INFO: Pandarallel will use Memory file system to transfer data between the main process and workers.
The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


### 2. 加载测试数据集

In [2]:
# Read CSV files serially
def read_csv_files(file_paths):
    """Read every tab-separated file in ``file_paths`` into a DataFrame.

    The files are read one after another. The returned list holds one
    DataFrame per entry of ``file_paths``, in the same order.
    """
    frames = []
    for path in file_paths:
        frames.append(pd.read_csv(path, sep='\t'))
    return frames

# Function to get ec_rxn_nores
def get_ec_rxn_nores(pred_detail,  rxnkey):
    """Count how many rows of a prediction table carry no prediction.

    Args:
        pred_detail: DataFrame of predictions for one fold.
        rxnkey: Name of the column holding the predicted reaction string.

    Returns:
        ``[total_rows, rows_without_prediction]``. A row counts as "without
        prediction" when its ``rxnkey`` value contains the literal text
        ``NO-PREDICTION``; missing values are not counted.
    """
    # na=False: a missing value yields False instead of NaN, so the mask stays
    # boolean (a NaN-containing mask cannot be used for counting/indexing).
    # regex=False: the marker is a literal substring, not a pattern.
    is_no_prediction = pred_detail[rxnkey].str.contains('NO-PREDICTION', regex=False, na=False)
    return [len(pred_detail), int(is_no_prediction.sum())]

def process_no_res(res_list,  rxnkey):
    """Summarise test size and missing-prediction count for every fold.

    Args:
        res_list: Sequence of per-fold prediction DataFrames, indexable by
            ``0 .. len(res_list) - 1``.
        rxnkey: Name of the prediction column inspected in each fold.

    Returns:
        DataFrame with one row per fold and the columns ``test_size`` and
        ``no_prediction``.
    """
    # Follow the actual number of folds instead of assuming exactly ten.
    rows = [
        get_ec_rxn_nores(pred_detail=res_list[fold_idx], rxnkey=rxnkey)
        for fold_idx in range(len(res_list))
    ]
    return pd.DataFrame(rows, columns=['test_size', 'no_prediction'])

# Make one-hot encoding label for each prediction
def make_labels(resdf, src_col1, src_col2, lb1, lb2, rxn_label_dict):
    """Add two label columns to ``resdf`` by encoding two reaction-ID columns.

    Every value of ``src_col1`` / ``src_col2`` is converted with ``str`` and
    passed to ``btools.make_label`` together with ``rxn_label_dict``; the
    results are stored in the columns ``lb1`` / ``lb2``.

    Args:
        resdf: DataFrame to annotate. It is modified in place.
        src_col1: Source column encoded into ``lb1``.
        src_col2: Source column encoded into ``lb2``.
        lb1: Name of the first label column to write.
        lb2: Name of the second label column to write.
        rxn_label_dict: Mapping forwarded to ``btools.make_label``.

    Returns:
        The same ``resdf`` object, with the two label columns set.
    """
    def _encode(reaction_id):
        return btools.make_label(reaction_id=str(reaction_id), rxn_label_dict=rxn_label_dict)

    # Encode column-wise: a row-wise DataFrame.apply that builds a pd.Series per
    # row is far slower and misbehaves on empty frames.
    # Both columns are computed before either is written, so a label column
    # that shares its name with a source column cannot affect the other one.
    labels_1 = resdf[src_col1].map(_encode)
    labels_2 = resdf[src_col2].map(_encode)
    resdf[lb1] = labels_1
    resdf[lb2] = labels_2
    return resdf

def apply_labels(res_list, src_col1, src_col2, lb1, lb2, rxn_label_dict):
    """Run ``make_labels`` on every fold of ``res_list``, showing a progress bar.

    Args:
        res_list: Mutable sequence of per-fold DataFrames; each entry is
            replaced by the value ``make_labels`` returns for it.
        src_col1, src_col2: Source columns forwarded to ``make_labels``.
        lb1, lb2: Label column names forwarded to ``make_labels``.
        rxn_label_dict: Mapping forwarded to ``make_labels``.

    Returns:
        The same ``res_list`` object.
    """
    # Follow the actual number of folds instead of assuming exactly ten.
    for fold_idx in tqdm(range(len(res_list))):
        res_list[fold_idx] = make_labels(
            resdf=res_list[fold_idx],
            src_col1=src_col1,
            src_col2=src_col2,
            lb1=lb1,
            lb2=lb2,
            rxn_label_dict=rxn_label_dict,
        )
    return res_list


# Function to calculate metrics
def calculate_metrics(eva_df, ground_truth_col, pred_col, eva_name):
    """Evaluate one prediction table via ``btools.rxn_eva_metric_with_colName``.

    Args:
        eva_df: Table holding both ground-truth and predicted labels.
        ground_truth_col: Column name passed on as ``col_groundtruth``.
        pred_col: Column name passed on as ``col_pred``.
        eva_name: Name passed on as ``eva_name``.

    Returns:
        Whatever ``btools.rxn_eva_metric_with_colName`` returns (callers in
        this notebook concatenate the results with ``pd.concat``).
    """
    return btools.rxn_eva_metric_with_colName(
        eva_df=eva_df,
        col_groundtruth=ground_truth_col,
        col_pred=pred_col,
        eva_name=eva_name,
    )

# Run the evaluation function for all folds in a thread pool
def calculate_metrics_parallel(res_unirep, ground_truth_col, pred_col, max_workers=None):
    """Evaluate every fold concurrently and stack the per-fold results.

    Args:
        res_unirep: Sequence of per-fold tables (used for every method in this
            notebook, not only UniRep).
        ground_truth_col: Ground-truth label column, forwarded to
            ``calculate_metrics``.
        pred_col: Predicted label column, forwarded to ``calculate_metrics``.
        max_workers: Thread-pool size passed to ``ThreadPoolExecutor``.

    Returns:
        ``pd.concat`` (``axis=0``) of the per-fold results in fold order. The
        index of each per-fold result is kept as is, not renumbered.
    """
    fold_count = len(res_unirep)

    def _evaluate_fold(fold_idx):
        # Folds are named 1-based: index 0 -> 'fold1'.
        return calculate_metrics(
            eva_df=res_unirep[fold_idx],
            ground_truth_col=ground_truth_col,
            pred_col=pred_col,
            eva_name=f'fold{fold_idx + 1}',
        )

    by_fold = {}
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        pending = {
            pool.submit(_evaluate_fold, fold_idx): fold_idx
            for fold_idx in range(fold_count)
        }
        # Futures complete in arbitrary order; key each result by its fold
        # index so the final concatenation is in fold order.
        for done in as_completed(pending):
            by_fold[pending[done]] = done.result()

    return pd.concat([by_fold[fold_idx] for fold_idx in range(fold_count)], axis=0)


# Function to display results as HTML
def display_html_results(metrics, fold_std, eva_name):
    """Build an HTML view with the per-fold table and the summary table.

    Args:
        metrics: Table rendered under the "... Fold Details" heading.
        fold_std: Table rendered under the "... Fold Overview" heading.
        eva_name: Label placed at the start of both headings.

    Returns:
        An ``HTML`` object wrapping two left-floated ``<div>`` blocks.
    """
    details_table = metrics.to_html()
    overview_table = fold_std.to_html()
    markup = f"""
         <div style="float:left; width:900px;">
              <h2 style='color:blue'>{eva_name} Evaluation 10 Fold Details</h2>
              {details_table}
         </div>
         <div  style="float:left; width:600px;" >
              <h2 style='color:blue' >{eva_name} Evaluation 10 Fold Overview</h2>
                   {overview_table}
         </div>
         """
    return HTML(markup)

In [3]:
# Load the reaction encoding dictionary from its JSON file
with open(cfg.FILE_DS_DICT_RXN2ID, "r") as json_file:
    dict_rxn2id = json.load(json_file)
    print(f'加载反应编码字典完成，共有 {len(dict_rxn2id)} 个反应。')  # report how many reactions were loaded

print('Loading validation datasets feather path ...')
vali_feather_files = [
    f'{cfg.DIR_DATASET}validation/fold{fold_index}/valid.feather'
    for fold_index in range(1, 11)
]

# Load each validation fold, keeping only the protein id and its reaction,
# with the reaction column renamed to mark it as ground truth.
ds_test = [
    pd.read_feather(feather_path)[['uniprot_id', 'reaction_id']].rename(
        columns={'reaction_id': 'rxn_groundtruth'}
    )
    for feather_path in tqdm(vali_feather_files)
]


def read_h5_file(file_path):
    """Return the object stored under the key ``'data'`` of an HDF5 file.

    The store is opened read-only and is closed again when the ``with`` block
    is left.
    """
    with pd.HDFStore(file_path, 'r') as store:
        return store['data']

print('Loading uniprot_rxn_dict ...' )
d1 = pd.read_feather(cfg.FILE_DS_TRAIN)
d2 = pd.read_feather(cfg.FILE_DS_TEST)
# uniprot_id -> reaction_id lookup built from the train and test sets together.
uniprot_rxn_dict = (
    pd.concat([d1, d2], axis=0)
    .reset_index(drop=True)
    .loc[:, ['uniprot_id', 'reaction_id']]
    .set_index('uniprot_id')['reaction_id']
    .to_dict()
)


加载反应编码字典完成，共有 10479 个反应。
Loading validation datasets feather path ...


  0%|          | 0/10 [00:00<?, ?it/s]

Loading uniprot_rxn_dict ...


## 4. Load results from direct prediction methods

### 4.1 Blast

In [None]:
# Per-fold BLAST result files of the direct (reaction-level) baseline.
vali_res_blast = [
    f'{cfg.DIR_RES_BASELINE}results/direct_methods/blast/fold{item}.tsv' 
    for item in range(1, 11)
]
res_blast = read_csv_files(vali_res_blast)
# Count NO-PREDICTION rows on the raw result files, before the merge below.
df_blast_no_pred = process_no_res(res_blast, rxnkey='rxn_blast')
# Attach the ground truth: left-merge keeps every validation protein, including
# those that have no row in the BLAST result file.
res_blast = [ds_test[item].merge(res_blast[item], on='uniprot_id', how='left') for item in range(10)]
# Add label columns for ground truth and BLAST prediction
res_blast = apply_labels(res_blast, 'rxn_groundtruth', 'rxn_blast', 'lb_rxn_groundtruth', 'lb_rxn_blast', dict_rxn2id)


res_blast_metrics =calculate_metrics_parallel(res_unirep=res_blast, ground_truth_col='lb_rxn_groundtruth', pred_col='lb_rxn_blast', max_workers=15)
# The per-fold results all come back with the same index; renumber the rows so
# they line up with df_blast_no_pred in the column-wise concat.
res_blast_metrics = res_blast_metrics.reset_index(drop=True)
res_blast_metrics = pd.concat([res_blast_metrics, df_blast_no_pred], axis=1)
# Item assignment, not attribute assignment: `df.baselineName = ...` only
# writes the column when it already exists and otherwise just sets a Python
# attribute on the object.
res_blast_metrics['baselineName'] = 'Blast_direct'

# NOTE(review): absolute, user-specific output path; consider deriving it from cfg.
res_blast_metrics.to_feather('/hpcfs/fhome/shizhenkun/codebase/RXNRECer/evaluation/data/res_blast_direct_metrics.feather')
res_blast_fold_std = res_blast_metrics[['mAccuracy', 'mPrecision', 'mRecall', 'mF1', 'no_prediction']].agg(['mean', 'std'])
display_html_results(res_blast_metrics, res_blast_fold_std, 'Blast')

Unnamed: 0,baselineName,mAccuracy,mPrecision,mRecall,mF1,test_size,no_prediction
0,Blast_direct,0.863876,0.90778,0.949859,0.914526,50858,2424
1,Blast_direct,0.863542,0.90567,0.948931,0.912737,50858,2386
2,Blast_direct,0.864544,0.905471,0.950634,0.913331,50858,2379
3,Blast_direct,0.865449,0.907955,0.951079,0.915531,50858,2301
4,Blast_direct,0.865705,0.908164,0.949453,0.915109,50858,2418
5,Blast_direct,0.865488,0.909392,0.949686,0.915556,50858,2391
6,Blast_direct,0.864367,0.906941,0.949399,0.914027,50858,2369
7,Blast_direct,0.864033,0.909656,0.950459,0.916263,50858,2381
8,Blast_direct,0.861595,0.90514,0.949163,0.913027,50858,2449
9,Blast_direct,0.862421,0.905756,0.949344,0.913372,50858,2431

Unnamed: 0,mAccuracy,mPrecision,mRecall,mF1,no_prediction
mean,0.864102,0.907192,0.949801,0.914348,2392.9
std,0.001335,0.001646,0.000702,0.001228,41.631852


### 4.2 Unirep

In [None]:
embd_methd = 'unirep'
file_res_unirep = [
    f'{cfg.RESULTS_DIR}simi/fold_{fold_num}_{embd_methd}_results.h5'
    for fold_num in range(1, 11)
]
res_unirep = [read_h5_file(h5_path) for h5_path in tqdm(file_res_unirep)]

# Look up the reaction ID for each row: element [0][0] of the `euclidean` /
# `cosine` entry is used as the key into uniprot_rxn_dict
# (presumably the id of the closest hit — TODO confirm against the result files).
for fold_df in tqdm(res_unirep):
    fold_df['rxn_euclidean'] = fold_df['euclidean'].apply(lambda hits: uniprot_rxn_dict.get(hits[0][0]))
    fold_df['rxn_cosine'] = fold_df['cosine'].apply(lambda hits: uniprot_rxn_dict.get(hits[0][0]))

# Turn the reaction IDs into label columns
res_unirep = apply_labels(
    res_unirep,
    'reaction_id',
    'rxn_euclidean',
    'lb_rxn_groundtruth',
    'lb_rxn_unirep_euclidean',
    dict_rxn2id,
)
# NOTE(review): unlike make_labels, the cosine path passes the reaction id to
# make_label without str(); confirm both paths treat missing ids the same way.
for fold_df in tqdm(res_unirep):
    fold_df['lb_rxn_unirep_cosine'] = fold_df['rxn_cosine'].parallel_apply(
        lambda rxn: btools.make_label(reaction_id=rxn, rxn_label_dict=dict_rxn2id)
    )

# Compute the evaluation metrics
res_unirep_euclidean_metrics = calculate_metrics_parallel(
    res_unirep=res_unirep,
    ground_truth_col='lb_rxn_groundtruth',
    pred_col='lb_rxn_unirep_euclidean',
    max_workers=15,
)
res_unirep_cosine_metrics = calculate_metrics_parallel(
    res_unirep=res_unirep,
    ground_truth_col='lb_rxn_groundtruth',
    pred_col='lb_rxn_unirep_cosine',
    max_workers=15,
)

# NOTE(review): absolute, user-specific output paths; consider deriving them from cfg.
res_unirep_euclidean_metrics.to_feather('/hpcfs/fhome/shizhenkun/codebase/RXNRECer/evaluation/data/res_unirep_euclidean_metrics.feather')
res_unirep_cosine_metrics.to_feather('/hpcfs/fhome/shizhenkun/codebase/RXNRECer/evaluation/data/res_unirep_cosine_metrics.feather')

# Mean and standard deviation of each metric across the folds
metric_columns = ['mAccuracy', 'mPrecision', 'mRecall', 'mF1']
res_unirep_euclidean_fold_std = res_unirep_euclidean_metrics[metric_columns].agg(['mean', 'std'])
res_unirep_cosine_fold_std = res_unirep_cosine_metrics[metric_columns].agg(['mean', 'std'])


  0%|          | 0/10 [00:00<?, ?it/s]

In [141]:
# Show per-fold details and the mean/std overview for both distance metrics.
# Each heading is written once in full; previously the suffix was appended a
# second time after an interpolated literal that already contained it.
HTML(f"""
         <div style="float:left; width:600px;">
              <h2 style='color:blue'>Unirep (Euclidean) Evaluation 10 Fold Details</h2>
              {res_unirep_euclidean_metrics.to_html()}
         </div>
         <div  style="float:left; width:600px;" >
              <h2 style='color:blue' >Unirep (Euclidean) Evaluation 10 Fold Overview</h2>
                   {res_unirep_euclidean_fold_std.to_html()}
         </div>
         
        <div style="float:left; display:block; width:600px;">
              <h2 style='color:blue'>Unirep (Cosine) Evaluation 10 Fold Details</h2>
              {res_unirep_cosine_metrics.to_html()}
         </div>
         <div  style="float:left; width:600px;" >
              <h2 style='color:blue' >Unirep (Cosine) Evaluation 10 Fold Overview</h2>
                   {res_unirep_cosine_fold_std.to_html()}
         </div>
         """)

Unnamed: 0,baselineName,mAccuracy,mPrecision,mRecall,mF1
0,fold1,0.950568,0.95972,0.946329,0.94465
0,fold2,0.950844,0.958338,0.944755,0.942069
0,fold3,0.950352,0.960056,0.945331,0.943888
0,fold4,0.950647,0.959441,0.943913,0.942439
0,fold5,0.949231,0.957969,0.943381,0.941592
0,fold6,0.94988,0.959917,0.946225,0.944221
0,fold7,0.951512,0.958198,0.945601,0.942734
0,fold8,0.951374,0.960227,0.946564,0.94468
0,fold9,0.948366,0.956707,0.944714,0.942375
0,fold10,0.950411,0.958943,0.945935,0.943439

Unnamed: 0,mAccuracy,mPrecision,mRecall,mF1
mean,0.950319,0.958952,0.945275,0.943209
std,0.000953,0.001133,0.001067,0.001116

Unnamed: 0,baselineName,mAccuracy,mPrecision,mRecall,mF1
0,fold1,0.951001,0.960014,0.947409,0.9456
0,fold2,0.951355,0.958806,0.945655,0.943051
0,fold3,0.950922,0.960457,0.945877,0.94436
0,fold4,0.951394,0.95968,0.944798,0.94303
0,fold5,0.950195,0.959141,0.944507,0.942749
0,fold6,0.95047,0.960158,0.946945,0.944663
0,fold7,0.952259,0.958553,0.946114,0.94328
0,fold8,0.95224,0.961079,0.947896,0.945963
0,fold9,0.948582,0.957557,0.945057,0.942993
0,fold10,0.950883,0.959947,0.94648,0.9443

Unnamed: 0,mAccuracy,mPrecision,mRecall,mF1
mean,0.95093,0.959539,0.946074,0.943999
std,0.001062,0.001033,0.001124,0.001155


### 4.3 ESM

In [6]:
embd_methd = 'esm'
file_res_esm = [
    f'{cfg.RESULTS_DIR}simi/fold_{fold_num}_{embd_methd}_results.h5'
    for fold_num in range(1, 11)
]
res_esm = [read_h5_file(h5_path) for h5_path in tqdm(file_res_esm)]

# Look up the reaction ID for each row: element [0][0] of the `euclidean` /
# `cosine` entry is used as the key into uniprot_rxn_dict
# (presumably the id of the closest hit — TODO confirm against the result files).
for fold_df in tqdm(res_esm):
    fold_df['rxn_euclidean'] = fold_df['euclidean'].apply(lambda hits: uniprot_rxn_dict.get(hits[0][0]))
    fold_df['rxn_cosine'] = fold_df['cosine'].apply(lambda hits: uniprot_rxn_dict.get(hits[0][0]))

# Turn the reaction IDs into label columns
res_esm = apply_labels(
    res_esm,
    'reaction_id',
    'rxn_euclidean',
    'lb_rxn_groundtruth',
    'lb_rxn_esm_euclidean',
    dict_rxn2id,
)
for fold_df in tqdm(res_esm):
    fold_df['lb_rxn_esm_cosine'] = fold_df['rxn_cosine'].parallel_apply(
        lambda rxn: btools.make_label(reaction_id=rxn, rxn_label_dict=dict_rxn2id)
    )

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

In [7]:
# Compute the evaluation metrics for both distance metrics
res_esm_euclidean_metrics = calculate_metrics_parallel(res_unirep=res_esm, ground_truth_col='lb_rxn_groundtruth', pred_col='lb_rxn_esm_euclidean', max_workers=15)
res_esm_cosine_metrics = calculate_metrics_parallel(res_unirep=res_esm, ground_truth_col='lb_rxn_groundtruth', pred_col='lb_rxn_esm_cosine', max_workers=15)

# NOTE(review): absolute, user-specific output paths; consider deriving them from cfg.
res_esm_euclidean_metrics.to_feather('/hpcfs/fhome/shizhenkun/codebase/RXNRECer/evaluation/data/res_esm_euclidean_metrics.feather')
res_esm_cosine_metrics.to_feather('/hpcfs/fhome/shizhenkun/codebase/RXNRECer/evaluation/data/res_esm_cosine_metrics.feather')
# Mean and standard deviation of each metric across the folds
res_esm_euclidean_fold_std = res_esm_euclidean_metrics[['mAccuracy', 'mPrecision', 'mRecall', 'mF1']].agg(['mean', 'std'])
res_esm_cosine_fold_std = res_esm_cosine_metrics[['mAccuracy', 'mPrecision', 'mRecall', 'mF1']].agg(['mean', 'std'])

# Each heading is written once in full; previously the suffix was appended a
# second time after an interpolated literal that already contained it.
HTML(f"""
         <div style="float:left; width:600px;">
              <h2 style='color:blue'>ESM (Euclidean) Evaluation 10 Fold Details</h2>
              {res_esm_euclidean_metrics.to_html()}
         </div>
         <div  style="float:left; width:600px;" >
              <h2 style='color:blue' >ESM (Euclidean) Evaluation 10 Fold Overview</h2>
                   {res_esm_euclidean_fold_std.to_html()}
         </div>
         
        <div style="float:left; display:block; width:600px;">
              <h2 style='color:blue'>ESM (Cosine) Evaluation 10 Fold Details</h2>
              {res_esm_cosine_metrics.to_html()}
         </div>
         <div  style="float:left; width:600px;" >
              <h2 style='color:blue' >ESM (Cosine) Evaluation 10 Fold Overview</h2>
                   {res_esm_cosine_fold_std.to_html()}
         </div>
         """)

Unnamed: 0,baselineName,mAccuracy,mPrecision,mRecall,mF1
0,fold1,0.974714,0.981384,0.972891,0.971203
0,fold2,0.973377,0.980362,0.969627,0.968123
0,fold3,0.973082,0.979898,0.969392,0.968083
0,fold4,0.974989,0.981705,0.969956,0.968921
0,fold5,0.974026,0.979469,0.971316,0.969259
0,fold6,0.973829,0.981042,0.97109,0.969537
0,fold7,0.975343,0.981847,0.971227,0.96954
0,fold8,0.975382,0.981557,0.97294,0.971142
0,fold9,0.973062,0.98054,0.970501,0.969261
0,fold10,0.973809,0.980722,0.970666,0.969195

Unnamed: 0,mAccuracy,mPrecision,mRecall,mF1
mean,0.974161,0.980853,0.970961,0.969426
std,0.000889,0.000797,0.00122,0.001054

Unnamed: 0,baselineName,mAccuracy,mPrecision,mRecall,mF1
0,fold1,0.974891,0.981897,0.973387,0.971992
0,fold2,0.974006,0.981142,0.970136,0.968783
0,fold3,0.973357,0.980386,0.969289,0.968235
0,fold4,0.975736,0.982739,0.97108,0.97021
0,fold5,0.97438,0.980406,0.97181,0.970091
0,fold6,0.974557,0.981685,0.971913,0.970467
0,fold7,0.97613,0.982155,0.972014,0.970389
0,fold8,0.975638,0.982414,0.973162,0.971655
0,fold9,0.973377,0.980821,0.970894,0.969541
0,fold10,0.974026,0.980953,0.970701,0.969277

Unnamed: 0,mAccuracy,mPrecision,mRecall,mF1
mean,0.97461,0.98146,0.971439,0.970064
std,0.000976,0.000837,0.00128,0.001175


### 4.4 T5

In [5]:
embd_methd = 't5'
file_res_t5 = [f'{cfg.RESULTS_DIR}simi/fold_{fold_num}_{embd_methd}_results.h5' for fold_num in range(1,11)]
print('Loading T5 results...')
res_t5 = [read_h5_file(item) for item in tqdm(file_res_t5)]

print('Adding reaction ID labels to T5 results...')
# Look up the reaction ID for each row: element [0][0] of the `euclidean` /
# `cosine` entry is used as the key into uniprot_rxn_dict
# (presumably the id of the closest hit — TODO confirm against the result files).
for i in tqdm(range(10)):
    res_t5[i]['rxn_euclidean'] = res_t5[i].euclidean.apply(lambda x : uniprot_rxn_dict.get(x[0][0]))
    res_t5[i]['rxn_cosine'] = res_t5[i].cosine.apply(lambda x : uniprot_rxn_dict.get(x[0][0]))

# Turn the reaction IDs into label columns
res_t5 = apply_labels(res_t5, 'reaction_id', 'rxn_euclidean', 'lb_rxn_groundtruth', 'lb_rxn_t5_euclidean', dict_rxn2id)
for i in tqdm(range(10)):
    res_t5[i]['lb_rxn_t5_cosine'] = res_t5[i].rxn_cosine.parallel_apply(lambda x :btools.make_label(reaction_id=x, rxn_label_dict=dict_rxn2id))


print('Calculating metrics for T5 (Euclidean)...')
# Compute the evaluation metrics
res_t5_euclidean_metrics = calculate_metrics_parallel(res_unirep=res_t5, ground_truth_col='lb_rxn_groundtruth', pred_col='lb_rxn_t5_euclidean', max_workers=15)
print('Calculating metrics for T5 (Cosine)...')
res_t5_cosine_metrics = calculate_metrics_parallel(res_unirep=res_t5, ground_truth_col='lb_rxn_groundtruth', pred_col='lb_rxn_t5_cosine', max_workers=15)

# NOTE(review): absolute, user-specific output paths; consider deriving them from cfg.
res_t5_euclidean_metrics.to_feather('/hpcfs/fhome/shizhenkun/codebase/RXNRECer/evaluation/data/res_t5_euclidean_metrics.feather')
res_t5_cosine_metrics.to_feather('/hpcfs/fhome/shizhenkun/codebase/RXNRECer/evaluation/data/res_t5_cosine_metrics.feather')
# Mean and standard deviation of each metric across the folds
res_t5_euclidean_fold_std = res_t5_euclidean_metrics[['mAccuracy', 'mPrecision', 'mRecall', 'mF1']].agg(['mean', 'std'])
res_t5_cosine_fold_std = res_t5_cosine_metrics[['mAccuracy', 'mPrecision', 'mRecall', 'mF1']].agg(['mean', 'std'])

# Each heading is written once in full; previously the suffix was appended a
# second time after an interpolated literal that already contained it.
HTML(f"""
         <div style="float:left; width:600px;">
              <h2 style='color:blue'>T5 (Euclidean) Evaluation 10 Fold Details</h2>
              {res_t5_euclidean_metrics.to_html()}
         </div>
         <div  style="float:left; width:600px;" >
              <h2 style='color:blue' >T5 (Euclidean) Evaluation 10 Fold Overview</h2>
                   {res_t5_euclidean_fold_std.to_html()}
         </div>
         
        <div style="float:left; display:block; width:600px;">
              <h2 style='color:blue'>T5 (Cosine) Evaluation 10 Fold Details</h2>
              {res_t5_cosine_metrics.to_html()}
         </div>
         <div  style="float:left; width:600px;" >
              <h2 style='color:blue' >T5 (Cosine) Evaluation 10 Fold Overview</h2>
                   {res_t5_cosine_fold_std.to_html()}
         </div>
         """)

Loading T5 results...


  0%|          | 0/10 [00:00<?, ?it/s]

Adding reaction ID labels to T5 results...


  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

Calculating metrics for T5 (Euclidean)...
Calculating metrics for T5 (Cosine)...


Unnamed: 0,baselineName,mAccuracy,mPrecision,mRecall,mF1
0,fold1,0.97959,0.986376,0.977209,0.976097
0,fold2,0.979394,0.986146,0.976809,0.975303
0,fold3,0.979236,0.986456,0.976468,0.975564
0,fold4,0.97959,0.986906,0.97588,0.974801
0,fold5,0.979354,0.985921,0.977046,0.975649
0,fold6,0.979512,0.986872,0.977002,0.976042
0,fold7,0.980495,0.986233,0.976102,0.974534
0,fold8,0.980357,0.987147,0.977809,0.97665
0,fold9,0.978607,0.985792,0.976815,0.975636
0,fold10,0.979433,0.986659,0.977399,0.976313

Unnamed: 0,mAccuracy,mPrecision,mRecall,mF1
mean,0.979557,0.986451,0.976854,0.975659
std,0.000539,0.000444,0.000583,0.000657

Unnamed: 0,baselineName,mAccuracy,mPrecision,mRecall,mF1
0,fold1,0.979453,0.986158,0.977534,0.976353
0,fold2,0.979826,0.986524,0.977148,0.975671
0,fold3,0.979335,0.986373,0.976536,0.975612
0,fold4,0.980062,0.987317,0.976561,0.975488
0,fold5,0.979472,0.98594,0.977199,0.975701
0,fold6,0.979767,0.986739,0.977379,0.976168
0,fold7,0.980573,0.986503,0.976222,0.974706
0,fold8,0.980514,0.987045,0.977997,0.97677
0,fold9,0.978784,0.985607,0.97702,0.975666
0,fold10,0.979472,0.986422,0.977467,0.97621

Unnamed: 0,mAccuracy,mPrecision,mRecall,mF1
mean,0.979726,0.986463,0.977106,0.975835
std,0.000549,0.000501,0.000538,0.000569
