# 集成方法评估

> author: Shizhenkun   
> email: zhenkun.shi@tib.cas.cn   
> date: 2025-09-08  


## 1. Import packages

In [1]:
import sys,os
sys.path.insert(0, os.path.dirname(os.path.realpath('__file__')))
sys.path.insert(1,'../../')
from config import conf as cfg
from tqdm import tqdm
import numpy as np
from pandarallel import pandarallel # 导入pandaralle
pandarallel.initialize(progress_bar=False)
from evaluation import evTools
from joblib import parallel_backend
from multiprocessing import Pool
import itertools
import pandas as pd
from tqdm.notebook import tqdm

FIRST_TIME_RUN = False # For the initial run, please set this flag to True. This will allow the program to download data from UniProt and RHEA, which may take longer depending on your internet speed.

%load_ext autoreload
%autoreload 2

INFO: Pandarallel will run on 192 workers.
INFO: Pandarallel will use Memory file system to transfer data between the main process and workers.
The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## 2. Loading test data and predictor results

In [11]:
# Ground-truth test set, trimmed to the id/label columns and renamed to the
# rxn_* / ec_* naming used by the predictor result tables.
test_set = (
    pd.read_feather(cfg.FILE_DS_TEST)
    .loc[:, ['uniprot_id', 'reaction_id', 'ec_number', 'label']]
    .rename(columns={
        'reaction_id': 'rxn_groundtruth',
        'ec_number': 'ec_groundtruth',
        'label': 'rxn_groundtruth_label',
    })
)

# Show every column when a frame is displayed (the merged table is wide).
pd.set_option('display.max_columns', None)

# EC-based predictors: harmonise their reaction_* columns to the rxn_* prefix.
# NOTE(review): 'reaction_ecpred' is not part of this map, so it keeps its
# original name in the merged table.
res_ec = pd.read_feather(f'{cfg.CASE_2018LATER}res/res_methods_ec.feather').rename(
    columns={
        'reaction_ecblast': 'rxn_ecblast',
        'reaction_deepec': 'rxn_deepec',
        'reaction_clean': 'rxn_clean',
        'reaction_ecrecer': 'rxn_ecrecer',
        'reaction_catfam': 'rxn_catfam',
        'reaction_priam': 'rxn_priam',
    }
)

# Predictors that output reactions directly.
res_direct = pd.read_feather(f'{cfg.CASE_2018LATER}res/res_methods_direct.feather')

# Left join on the protein id: every row of res_direct is kept. Columns present
# in both tables come out with pandas' default _x (direct) / _y (EC) suffixes.
res = res_direct.merge(res_ec, on='uniprot_id', how='left')
res.head(2)

Unnamed: 0,uniprot_id,rxn_groundtruth_x,enz_groundtruth,lb_rxn_groundtruth_x,rxn_blast,lb_rxn_blast,rxn_eu_esm,enz_eu_esm,rxn_cos_esm,enz_cos_esm,lb_rxn_eu_esm,lb_rxn_cos_esm,rxn_eu_unirep,enz_eu_unirep,rxn_cos_unirep,enz_cos_unirep,lb_rxn_eu_unirep,lb_rxn_cos_unirep,rxn_eu_t5,enz_eu_t5,rxn_cos_t5,enz_cos_t5,lb_rxn_eu_t5,lb_rxn_cos_t5,enz_RXNRECer,rxn_RXNRECer,lb_rxn_RXNRECer,rxn_groundtruth_y,isenzyme_groundtruth,ec_groundtruth,ec_specific_level,lb_rxn_groundtruth_y,ec_ecblast,rxn_ecblast,ec_deepec,rxn_deepec,ec_clean,rxn_clean,ec_ecrecer,rxn_ecrecer,ec_ecpred,reaction_ecpred,ec_catfam,rxn_catfam,ec_priam,rxn_priam,lb_rxn_ecblast,lb_rxn_deepec,lb_rxn_clean,lb_rxn_ecrecer,lb_rxn_ecpred,lb_rxn_catfam,lb_rxn_priam
0,A9JLI2,-,0,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",-,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",-,0,-,0,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",-,0,-,0,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",-,0,-,0,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",0,-,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",-,False,-,0,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",-,-,NO-PREDICTION,NO-PREDICTION,3.2.2.6;1.4.3.2;4.2.3.81,RHEA:31427;RHEA:16301;RHEA:13781,-,-,-,-,-,-,NO-PREDICTION,NO-PREDICTION,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
1,A9JLI3,-,0,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",-,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",-,0,-,0,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",-,0,-,0,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",-,0,-,0,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",0,-,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",-,False,-,0,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",-,-,NO-PREDICTION,NO-PREDICTION,4.6.1.18,EC-WITHOUT-REACTION,-,-,-,-,-,-,1.14.11.51;2.3.2.27,RHEA:49524,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."


In [12]:
# Metric computation helpers
def calculate_single_metric(method, avg_type, df, ground_truth_col):
    """Evaluate a single prediction column against the ground truth.

    Delegates to ``evTools.calculate_metrics``, passing ``method`` both as the
    prediction column (``pred_col``) and as the evaluation name (``eva_name``).

    Parameters:
    - method: name of the prediction column in ``df``
    - avg_type: averaging method forwarded as ``avg_method``
    - df: data frame forwarded as ``eva_df``
    - ground_truth_col: name of the ground-truth column

    Returns:
    - whatever ``evTools.calculate_metrics`` returns, or ``None`` when it
      raises; in that case the error is printed and not propagated, so one
      failing method does not abort a batch run.
    """
    try:
        return evTools.calculate_metrics(
            eva_df=df,
            ground_truth_col=ground_truth_col,
            pred_col=method,
            eva_name=method,
            avg_method=avg_type,
        )
    except Exception as e:
        print(f"计算 {method} 的 {avg_type} 指标时出错: {e}")
    return None

def get_metrics(pred_cols, avg_types, df, ground_truth_col='rxn_groundtruth', n_jobs=20):
    """
    Compute metrics for several prediction columns concurrently.

    One job is created for every (prediction column, averaging type) pair and
    executed through joblib's threading backend.

    Parameters:
    - pred_cols: names of the prediction columns to evaluate
    - avg_types: averaging types to evaluate for each column
    - df: data frame holding the prediction and ground-truth columns
    - ground_truth_col: name of the ground-truth column
    - n_jobs: number of concurrent joblib workers

    Returns:
    - the per-job results concatenated into one DataFrame (index reset);
      an empty DataFrame when no job produced a result
    """
    from joblib import Parallel, delayed

    # Columns are the outer loop, averaging types the inner one.
    jobs = (
        delayed(calculate_single_metric)(col, avg, df, ground_truth_col)
        for col in pred_cols
        for avg in avg_types
    )
    outcomes = Parallel(n_jobs=n_jobs, backend='threading')(jobs)

    # Jobs that failed return None (see calculate_single_metric); skip them.
    frames = [frame for frame in outcomes if frame is not None]
    if not frames:
        return pd.DataFrame()
    return pd.concat(frames, ignore_index=True)
    
# Ensemble methods
# Voting helpers
# recall_boosted_ensemble: union of predictions
def recall_boosted_ensemble(res_array):
    """
    Element-wise union of several one-hot encoded arrays.

    A position is 1 in the result if it is 1 in any input and 0 only when it
    is 0 in all of them (element-wise maximum across the inputs).

    Parameters:
    - res_array: list of equally shaped one-hot arrays

    Returns:
    - integer array holding the union
    """
    stacked = np.asarray(res_array)
    return stacked.max(axis=0).astype(int)

# Majority-vote ensemble
def majority_vote_with_priority(arrays, priority_array, s1first=False):
    """
    Majority vote over one-hot arrays, optionally anchored on a priority array.

    The vote for a position is 1 when at least half of ``arrays`` have a 1
    there (ties therefore resolve to 1).

    Parameters:
    - arrays: sequence of equally shaped one-hot arrays (or lists) that vote
    - priority_array: baseline one-hot array (or list), e.g. ESMwithCLF; it is
      not counted as a voter
    - s1first: if True, positions where ``priority_array`` is non-zero keep the
      baseline value and only its zero positions are decided by the vote.
      If False (default), every position is decided by the vote and the
      baseline only matters when there are no voters.

    Returns:
    - new integer array with the ensemble result; the inputs are not modified.
      When ``arrays`` is empty the baseline (cast to int) is returned.
    """
    baseline = np.asarray(priority_array)
    # astype returns a fresh array, so the caller's data is never written to.
    result = baseline.astype(int)

    # Positions that are decided by the vote.
    if s1first:
        vote_mask = (baseline == 0)
    else:
        vote_mask = np.ones(baseline.shape, dtype=bool)

    # Nothing to vote on, or nobody to vote: the baseline stands.
    if len(arrays) == 0 or not np.any(vote_mask):
        return result

    ones_count = np.sum(np.stack([np.asarray(a) for a in arrays], axis=0), axis=0)
    majority = ones_count >= len(arrays) / 2
    result[vote_mask] = majority[vote_mask].astype(int)
    return result

def get_combinations_without_order(items, r=None):
    """
    Build the unordered combinations of ``items``.

    Parameters:
    - items: sized sequence (list, array, ...)
    - r: requested combination size(s)
         - None: all sizes from 2 up to len(items)
         - int: only that size
         - iterable: every entry convertible to int and within 1..len(items);
           other entries are ignored, duplicates are used once, order is kept

    Returns:
    - list of lists, one inner list per combination; empty when ``items`` is
      empty or no requested size lies within 1..len(items)
    """
    total = len(items)
    if total == 0:
        return []

    def _valid(size):
        return 1 <= size <= total

    if r is None:
        wanted = list(range(2, total + 1))
    elif isinstance(r, int):
        wanted = [r]
    else:
        wanted = []
        for raw in r:
            try:
                size = int(raw)
            except Exception:
                # entries that cannot be read as an int are skipped
                continue
            if _valid(size) and size not in wanted:
                wanted.append(size)
        if not wanted:
            return []

    return [
        list(combo)
        for size in wanted
        if _valid(size)
        for combo in itertools.combinations(items, size)
    ]


# Generic ensemble evaluation

def calculate_ensemble_metrics(cb_methods, baseline_method='ESMwithCLF', df=None,
                              ground_truth_col='rxn_groundtruth', avg_types=['weighted'],
                              combination_sizes=[2, 3, 4], ensemble_types=['recall_boosted', 'majority']):
    """
    Build ensembles of candidate predictors with a baseline and score them.

    For every requested combination size, each combination of candidate
    methods is merged with the baseline prediction row by row, stored in ``df``
    as a new column named ``<m1>_<m2>..._<baseline>_<ensemble_type>``, and then
    evaluated with ``get_metrics``.

    Parameters:
    - cb_methods: candidate method (column) names
    - baseline_method: column name of the baseline predictor; it is added to
      every combination
    - df: data frame with the prediction columns; NOTE it is modified in
      place (one new column per ensemble)
    - ground_truth_col: name of the ground-truth column
    - avg_types: averaging types forwarded to ``get_metrics``
    - combination_sizes: numbers of candidates per combination, e.g. [2, 3, 4];
      the baseline is not counted, so size k yields ensembles of k + 1 methods
    - ensemble_types: any of 'recall_boosted' (element-wise union) and
      'majority' (baseline kept where it is non-zero, majority vote of the
      candidates elsewhere)

    Returns:
    - DataFrame with the metrics of all ensembles (empty if nothing was computed)

    Raises:
    - ValueError: if ``df`` is None or ``ensemble_types`` contains an unknown name
    """
    # The list defaults above are only read, never mutated.
    if df is None:
        raise ValueError("df must be a DataFrame holding the prediction columns; got None")

    def _union(row, combo):
        return recall_boosted_ensemble([row[method] for method in combo] + [row[baseline_method]])

    def _majority(row, combo):
        return majority_vote_with_priority([row[method] for method in combo], row[baseline_method], s1first=True)

    combiners = {'recall_boosted': _union, 'majority': _majority}

    # Fail early: an unknown type would otherwise create no columns and every
    # metric job would fail later with its error only printed, not raised.
    ensemble_types = list(ensemble_types)
    unknown = [name for name in ensemble_types if name not in combiners]
    if unknown:
        raise ValueError(
            f"Unsupported ensemble type(s): {unknown}; expected a subset of {list(combiners)}"
        )

    all_metrics = []

    for size in combination_sizes:
        combinations = get_combinations_without_order(cb_methods, r=size)
        print(f"计算 {size + 1} 个方法的组合, 共 {len(combinations)} 个")

        for ensemble_type in ensemble_types:
            print(f"  计算 {ensemble_type} 集成...")
            combine = combiners[ensemble_type]

            col_names = []
            for combo in combinations:
                col_name = '_'.join(combo) + f'_{baseline_method}_{ensemble_type}'
                # Row-wise merge of the selected candidates with the baseline.
                df[col_name] = df.apply(lambda row, combo=combo: combine(row, combo), axis=1)
                col_names.append(col_name)

            metrics = get_metrics(pred_cols=col_names, avg_types=avg_types, df=df, ground_truth_col=ground_truth_col)
            all_metrics.append(metrics)

    if all_metrics:
        return pd.concat(all_metrics, ignore_index=True)
    return pd.DataFrame()

print("通用集成方法计算函数定义完成！")


通用集成方法计算函数定义完成！


## 3. 整理计算表格

In [13]:
# Shared rename map: raw result-column suffix -> display name used in the tables.
_display_names = {
    'groundtruth_x': 'rxn_groundtruth',
    'ecblast': 'MSA-via-EC',
    'deepec': 'DeepEC',
    'clean': 'CLEAN',
    'ecrecer': 'ECRECer',
    'catfam': 'CatFam',
    'priam': 'PRIAM',
    'blast': 'MSA-via-RXN',
    'cos_esm': 'ESM',
    'cos_unirep': 'UniRep',
    'cos_t5': 'T5',
    'RXNRECer': 'ESMwithCLF',
}

# Every suffix exists twice in the merged results: as 'lb_rxn_<suffix>' and as
# 'rxn_<suffix>'. Both variants map to the same display name.
rename_dict = {
    f'{column_prefix}{suffix}': display_name
    for column_prefix in ('lb_rxn_', 'rxn_')
    for suffix, display_name in _display_names.items()
}

# Predictor display names, in the column order used for the result tables.
methods = [
    'MSA-via-EC',
    'DeepEC',
    'CLEAN',
    'ECRECer',
    'CatFam',
    'PRIAM',
    'MSA-via-RXN',
    'ESM',
    'UniRep',
    'T5',
    'ESMwithCLF',
]

# Build a per-method table from the merged results
def create_dataframe(prefix, df_data, rename_map=None, method_names=None):
    """
    Extract the columns of ``df_data`` that start with ``prefix`` and return
    them under their display names.

    Parameters:
    - prefix: column-name prefix to keep, e.g. 'lb_' or 'rxn_'
    - df_data: data frame to read from; must contain 'uniprot_id'
    - rename_map: raw column name -> display name; defaults to the
      module-level ``rename_dict``
    - method_names: display names of the method columns to return, in order;
      defaults to the module-level ``methods``

    Returns:
    - new DataFrame with columns 'uniprot_id', 'rxn_groundtruth' and the
      method columns; it is an independent copy, so adding columns to it
      later does not touch ``df_data``

    Raises:
    - KeyError: if a required column is missing after renaming
    """
    if rename_map is None:
        rename_map = rename_dict
    if method_names is None:
        method_names = methods

    # Select from the frame that was passed in, not from a global variable.
    selected_columns = ['uniprot_id'] + [col for col in df_data.columns if col.startswith(prefix)]
    df = df_data[selected_columns].rename(columns=rename_map)
    return df[['uniprot_id', 'rxn_groundtruth'] + list(method_names)].copy()

In [14]:
df_cp = create_dataframe(prefix='lb_', df_data=res)
df_cp.head(2)

Unnamed: 0,uniprot_id,rxn_groundtruth,MSA-via-EC,DeepEC,CLEAN,ECRECer,CatFam,PRIAM,MSA-via-RXN,ESM,UniRep,T5,ESMwithCLF
0,A9JLI2,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
1,A9JLI3,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."


## 4. 计算指标
### 4.1 单个预测器

In [8]:
# Averaging types to evaluate
# avg_types = ['weighted', 'micro', 'macro']
avg_types = ['weighted']

# Score every individual predictor and list them from best to worst mF1.
eva_metrics_initial = get_metrics(
    pred_cols=methods,
    avg_types=avg_types,
    df=df_cp,
    ground_truth_col='rxn_groundtruth',
)
eva_metrics_initial.sort_values(by='mF1', ascending=False)

Unnamed: 0,evaName,mAccuracy,mPrecision,mRecall,mF1,avgType
10,ESMwithCLF,0.59852,0.749473,0.930456,0.802831,weighted
9,T5,0.746948,0.893889,0.7448,0.778224,weighted
7,ESM,0.727414,0.884853,0.72375,0.760845,weighted
3,ECRECer,0.6899,0.967961,0.651723,0.714824,weighted
8,UniRep,0.689752,0.863138,0.67004,0.712641,weighted
6,MSA-via-RXN,0.501295,0.919604,0.545669,0.642326,weighted
4,CatFam,0.770477,0.870005,0.670351,0.620349,weighted
0,MSA-via-EC,0.43337,0.907868,0.40919,0.504355,weighted
2,CLEAN,0.078949,0.879031,0.145483,0.097747,weighted
5,PRIAM,0.017832,0.853372,0.160137,0.075407,weighted


In [10]:
methods = ['MSA-via-EC', 'DeepEC', 'CLEAN', 'ECRECer', 'CatFam', 'PRIAM', 'MSA-via-RXN', 'ESM', 'UniRep', 'T5', 'ESMwithCLF']
df_hum = create_dataframe('rxn_', res)

# A prediction is flagged as correct when every ground-truth reaction id occurs
# among the predicted ids (both are cfg.SPLITER-separated strings); additional
# predicted ids do not make it incorrect.
for item in methods:
    df_hum[f'pred_true_{item}'] = df_hum.apply(
        lambda x: set(x.rxn_groundtruth.split(cfg.SPLITER)) <= set(x[item].split(cfg.SPLITER)),
        axis=1,
    )

# Fraction of proteins flagged as correct, per method.
for item in methods:
    correct = int(df_hum[f'pred_true_{item}'].sum())
    accuracy = round(correct / len(df_hum), 6)
    print(f'{item:12s}: {accuracy}')
    
# evTools.calculate_metrics(eva_df=df_cp, ground_truth_col='rxn_groundtruth', pred_col='MSA-via-EC', eva_name='m', avg_method='macro')

MSA-via-EC  : 0.43744
DeepEC      : 0.036108
CLEAN       : 0.10899
ECRECer     : 0.695154
CatFam      : 0.771143
PRIAM       : 0.118461
MSA-via-RXN : 0.516907
ESM         : 0.739327
UniRep      : 0.698705
T5          : 0.760858
ESMwithCLF  : 0.957085


### 4.2 两个预测器组合

In [29]:
# Candidate predictors that can be combined with the ESMwithCLF baseline.
cb_methods = [
    'MSA-via-EC', 'DeepEC', 'CLEAN', 'ECRECer', 'CatFam',
    'PRIAM', 'MSA-via-RXN', 'ESM', 'UniRep', 'T5',
]

# combination_sizes counts candidates only; the baseline is always added,
# so size 1 gives ensembles of two predictors.
ensemble_metrics_2 = calculate_ensemble_metrics(
    cb_methods=cb_methods,
    baseline_method='ESMwithCLF',
    df=df_cp,
    ground_truth_col='rxn_groundtruth',
    avg_types=['weighted'],
    combination_sizes=[1],
    ensemble_types=['recall_boosted', 'majority'],
)

计算 2 个方法的组合, 共 10 个
  计算 recall_boosted 集成...
  计算 majority 集成...


### 4.3 三个预测器组合

In [30]:
# Two candidates plus the always-included baseline: ensembles of three predictors.
ensemble_metrics_3 = calculate_ensemble_metrics(
    cb_methods=cb_methods,
    baseline_method='ESMwithCLF',
    df=df_cp,
    ground_truth_col='rxn_groundtruth',
    avg_types=['weighted'],
    combination_sizes=[2],
    ensemble_types=['recall_boosted', 'majority'],
)

计算 3 个方法的组合, 共 45 个
  计算 recall_boosted 集成...
  计算 majority 集成...


### 4.4 四个预测器组合

In [31]:
# Three candidates plus the always-included baseline: ensembles of four predictors.
ensemble_metrics_4 = calculate_ensemble_metrics(
    cb_methods=cb_methods,
    baseline_method='ESMwithCLF',
    df=df_cp,
    ground_truth_col='rxn_groundtruth',
    avg_types=['weighted'],
    combination_sizes=[3],
    ensemble_types=['recall_boosted', 'majority'],
)

计算 4 个方法的组合, 共 120 个
  计算 recall_boosted 集成...
  计算 majority 集成...


### 4.5 五个预测器组合

In [None]:
# Four candidates plus the always-included baseline: ensembles of five predictors.
ensemble_metrics_5 = calculate_ensemble_metrics(
    cb_methods=cb_methods,
    baseline_method='ESMwithCLF',
    df=df_cp,
    ground_truth_col='rxn_groundtruth',
    avg_types=['weighted'],
    combination_sizes=[4],
    ensemble_types=['recall_boosted', 'majority'],
)

计算 5 个方法的组合, 共 210 个
  计算 recall_boosted 集成...


  计算 majority 集成...


In [None]:
# NOTE(review): this cell repeats the loading steps of section 2 and rebinds
# test_set, res_ec, res_direct and res.
test_set = (
    pd.read_feather(cfg.FILE_DS_TEST)
    .loc[:, ['uniprot_id', 'reaction_id', 'ec_number', 'label']]
    .rename(columns={
        'reaction_id': 'rxn_groundtruth',
        'ec_number': 'ec_groundtruth',
        'label': 'rxn_groundtruth_label',
    })
)

# Show every column when a frame is displayed.
pd.set_option('display.max_columns', None)

# EC-based predictors, with reaction_* columns renamed to the rxn_* prefix.
res_ec = pd.read_feather(f'{cfg.CASE_2018LATER}res/res_methods_ec.feather').rename(
    columns={
        'reaction_ecblast': 'rxn_ecblast',
        'reaction_deepec': 'rxn_deepec',
        'reaction_clean': 'rxn_clean',
        'reaction_ecrecer': 'rxn_ecrecer',
        'reaction_catfam': 'rxn_catfam',
        'reaction_priam': 'rxn_priam',
    }
)

# Predictors that output reactions directly.
res_direct = pd.read_feather(f'{cfg.CASE_2018LATER}res/res_methods_direct.feather')

# Left join on the protein id: every row of res_direct is kept.
res = res_direct.merge(res_ec, on='uniprot_id', how='left')

# Displayed value of the cell: a preview of the direct-prediction table.
res_direct.head(3)

Unnamed: 0,uniprot_id,rxn_groundtruth,enz_groundtruth,lb_rxn_groundtruth,rxn_blast,lb_rxn_blast,rxn_eu_esm,enz_eu_esm,rxn_cos_esm,enz_cos_esm,lb_rxn_eu_esm,lb_rxn_cos_esm,rxn_eu_unirep,enz_eu_unirep,rxn_cos_unirep,enz_cos_unirep,lb_rxn_eu_unirep,lb_rxn_cos_unirep,rxn_eu_t5,enz_eu_t5,rxn_cos_t5,enz_cos_t5,lb_rxn_eu_t5,lb_rxn_cos_t5,enz_RXNRECer,rxn_RXNRECer,lb_rxn_RXNRECer
0,A9JLI2,-,0,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",-,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",-,0,-,0,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",-,0,-,0,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",-,0,-,0,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",0,-,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
1,A9JLI3,-,0,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",-,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",-,0,-,0,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",-,0,-,0,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",-,0,-,0,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",0,-,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
2,A9JLI5,-,0,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",-,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",-,0,-,0,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",-,0,-,0,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",-,0,-,0,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",0,-,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."


## 5. 组合

In [16]:
cb_methods = ['MSA-via-EC', 'DeepEC', 'CLEAN', 'ECRECer', 'CatFam', 'PRIAM', 'MSA-via-RXN', 'ESM', 'UniRep', 'T5']

In [29]:


# 用法示例：
# get_combinations_without_order(cb_methods, r=2)
# get_combinations_without_order(cb_methods, r=[2,3])
# get_combinations_without_order(cb_methods)  # 返回从2到len(items)的所有组合


总共找到 45 个两两组合


In [None]:
# NOTE(review): `cb_methods_3m` and `eva_metrics` are not defined in any cell visible in this
# notebook, so this cell presumably relies on leftover kernel state and would fail after
# Restart & Run All -- confirm, then define them in an earlier cell or remove this cell.
eva_metrics_cb3m = get_metrics(pred_cols=cb_methods_3m, avg_types=avg_types, df=df_cp, ground_truth_col='rxn_groundtruth')
# Appends onto `eva_metrics` itself: re-running this cell adds the same rows again.
eva_metrics = pd.concat([eva_metrics, eva_metrics_cb3m], ignore_index=True)
# Display all evaluated predictors/ensembles, best mF1 first.
eva_metrics.sort_values(by=['mF1'], ascending=[False])

Unnamed: 0,evaName,mAccuracy,mPrecision,mRecall,mF1,avgType
47,ECRECer_MSA-via-RXN_ESMwithCLF,0.598890,0.753592,0.947532,0.814141,weighted
50,ECRECer_T5_ESMwithCLF,0.598742,0.752647,0.947594,0.813548,weighted
48,ECRECer_ESM_ESMwithCLF,0.598076,0.752462,0.947283,0.813235,weighted
14,ECRECer_union_ESMwithCLF,0.598520,0.753080,0.946166,0.813171,weighted
23,MSA-via-EC_ECRECer_ESMwithCLF,0.598594,0.752516,0.946600,0.812953,weighted
...,...,...,...,...,...,...
4,CatFam,0.770477,0.870005,0.670351,0.620349,weighted
0,MSA-via-EC,0.433370,0.907868,0.409190,0.504355,weighted
2,CLEAN,0.078949,0.879031,0.145483,0.097747,weighted
5,PRIAM,0.017832,0.853372,0.160137,0.075407,weighted


In [None]:
pd.set_option('display.max_rows', 200)   # show at most 200 rows


通用集成方法计算函数定义完成！


In [52]:
# Compute ensemble metrics with the generic helper

# Candidate methods to combine with the baseline
cb_methods = ['MSA-via-EC', 'DeepEC', 'CLEAN', 'ECRECer', 'CatFam', 'PRIAM', 'MSA-via-RXN', 'ESM', 'UniRep', 'T5']

# Evaluate all requested ensembles in one call
print("开始计算所有集成方法...")
ensemble_metrics_1 = calculate_ensemble_metrics(
    cb_methods=cb_methods,
    baseline_method='ESMwithCLF',
    df=df_cp,
    ground_truth_col='rxn_groundtruth',
    avg_types=['weighted'],
    combination_sizes=[2, 3, 4],  # number of candidates per combination; adjustable
    ensemble_types=['recall_boosted', 'majority']  # ensemble strategies; adjustable
)

print(f"\n计算完成！共得到 {len(ensemble_metrics_1)} 个集成方法的指标")
print("\n前10个结果：")
print(ensemble_metrics_1.head(10))

# Rank by F1 score and print the ten best ensembles
print("\n按F1分数排序（前10名）：")
f1_sorted = ensemble_metrics_1.sort_values(by=['mF1'], ascending=[False])
print(f1_sorted[['evaName', 'mAccuracy', 'mPrecision', 'mRecall', 'mF1']].head(10))


开始计算所有集成方法...
计算 2 个方法的组合...
  找到 45 个组合
  计算 recall_boosted 集成...
  计算 majority 集成...
计算 3 个方法的组合...
  找到 120 个组合
  计算 recall_boosted 集成...
  计算 majority 集成...
计算 4 个方法的组合...
  找到 210 个组合
  计算 recall_boosted 集成...
  计算 majority 集成...


: 

In [21]:
ensemble_metrics_1

Unnamed: 0,evaName,mAccuracy,mPrecision,mRecall,mF1,avgType
0,MSA-via-EC_DeepEC_ESMwithCLF_recall_boosted,0.595930,0.746985,0.931450,0.801477,weighted
1,MSA-via-EC_CLEAN_ESMwithCLF_recall_boosted,0.164854,0.704334,0.938777,0.766769,weighted
2,MSA-via-EC_ECRECer_ESMwithCLF_recall_boosted,0.598594,0.752516,0.946600,0.812953,weighted
3,MSA-via-EC_CatFam_ESMwithCLF_recall_boosted,0.598224,0.748373,0.931139,0.802340,weighted
4,MSA-via-EC_PRIAM_ESMwithCLF_recall_boosted,0.355235,0.685540,0.942999,0.750495,weighted
...,...,...,...,...,...,...
85,MSA-via-RXN_UniRep_ESMwithCLF_majority,0.557307,0.727527,0.933499,0.789119,weighted
86,MSA-via-RXN_T5_ESMwithCLF_majority,0.598742,0.749319,0.932319,0.803591,weighted
87,ESM_UniRep_ESMwithCLF_majority,0.556641,0.726646,0.933437,0.788570,weighted
88,ESM_T5_ESMwithCLF_majority,0.597928,0.748486,0.932381,0.802940,weighted
