In [11]:
from lib.utils import generate_mean_ensemble_metrics, aggregate_pred_dataframe, generate_mean_ensemble_metrics_auto
import os
import pandas as pd
import re

# Directory holding the per-iteration prediction CSVs for fold 0 of the
# esm-33-gearnet-ensemble run.
INTERMEDIATE_DIR = 'result_cv/esm-33-gearnet-ensemble/fold_0/intermediate'


def _natural_key(name):
    """Sort key that compares embedded integers numerically ("iter_2" < "iter_10").

    Plain string sorting would place ``iter_10_*`` between ``iter_1_*`` and
    ``iter_2_*``. ``re.split`` with a capturing group always alternates
    text/number, so str is only ever compared with str and int with int.
    """
    return [int(tok) if tok.isdigit() else tok for tok in re.split(r'(\d+)', name)]


def _list_split_csvs(directory, split):
    """Return ``directory/<name>`` paths for entries whose name contains ``split``.

    Args:
        directory: Directory to list.
        split: Substring to match in the file name (e.g. ``'test'`` or ``'valid'``).

    Returns:
        List of path strings in natural (numeric-aware) order. The same ordering
        is used for every split so that valid/test lists pair up index by index.
    """
    names = sorted((name for name in os.listdir(directory) if split in name), key=_natural_key)
    return [f'{directory}/{name}' for name in names]


# NOTE: the "fliles" spelling is kept because later cells reference these names.
test_csv_fliles = _list_split_csvs(INTERMEDIATE_DIR, 'test')
valid_csv_fliles = _list_split_csvs(INTERMEDIATE_DIR, 'valid')


In [14]:
# Evaluate each iteration's test predictions on its own with
# `generate_mean_ensemble_metrics`, then summarise the spread across iterations.
records = []
for csv_file in test_csv_fliles:
    df = aggregate_pred_dataframe([csv_file])
    me = generate_mean_ensemble_metrics(df)
    # Only MCC is logged; the full metric dict is kept in `records`.
    print(f'Processing {csv_file}, mcc: {me["mcc"]}')
    records.append(me)

# Build the frame once; calling pd.concat inside the loop re-copies every row each iteration.
agg = pd.DataFrame(records)
agg.aggregate(['mean', 'std'])


Processing result_cv/esm-33-gearnet-ensemble/fold_0/intermediate/iter_0_test.csv, aggregated metrics: 0.6399)
Processing result_cv/esm-33-gearnet-ensemble/fold_0/intermediate/iter_1_test.csv, aggregated metrics: 0.6538)
Processing result_cv/esm-33-gearnet-ensemble/fold_0/intermediate/iter_2_test.csv, aggregated metrics: 0.63)
Processing result_cv/esm-33-gearnet-ensemble/fold_0/intermediate/iter_3_test.csv, aggregated metrics: 0.642)
Processing result_cv/esm-33-gearnet-ensemble/fold_0/intermediate/iter_4_test.csv, aggregated metrics: 0.6531)
Processing result_cv/esm-33-gearnet-ensemble/fold_0/intermediate/iter_5_test.csv, aggregated metrics: 0.6393)
Processing result_cv/esm-33-gearnet-ensemble/fold_0/intermediate/iter_6_test.csv, aggregated metrics: 0.6407)
Processing result_cv/esm-33-gearnet-ensemble/fold_0/intermediate/iter_7_test.csv, aggregated metrics: 0.6202)
Processing result_cv/esm-33-gearnet-ensemble/fold_0/intermediate/iter_8_test.csv, aggregated metrics: 0.6681)
Processing re

Unnamed: 0,sensitivity,specificity,accuracy,precision,mcc
mean,0.58135,0.98917,0.96803,0.75053,0.64346
std,0.033306,0.003279,0.001833,0.050536,0.013272


In [13]:
# Per-iteration metrics from `generate_mean_ensemble_metrics_auto`, which also
# receives that iteration's validation predictions.
# Build a fresh frame instead of appending to `agg` from the previous cell:
# appending would mix both result sets into one mean/std on Restart & Run All.
assert len(valid_csv_fliles) == len(test_csv_fliles), (
    f'{len(valid_csv_fliles)} valid vs {len(test_csv_fliles)} test files; '
    'zip() would silently drop the unmatched ones'
)

auto_records = []
for valid_csv, test_csv in zip(valid_csv_fliles, test_csv_fliles):
    valid_df = aggregate_pred_dataframe([valid_csv])
    test_df = aggregate_pred_dataframe([test_csv])
    auto_records.append(generate_mean_ensemble_metrics_auto(valid_df, test_df))

agg_auto = pd.DataFrame(auto_records)
agg_auto.aggregate(['mean', 'std'])

Unnamed: 0,sensitivity,specificity,accuracy,precision,mcc
mean,0.59028,0.988565,0.967925,0.74177,0.644695
std,0.029573,0.002798,0.001656,0.041619,0.013061


In [15]:
def format_mean_std(row, metric):
    """Render one metric of an aggregated row as ``"<mean> ± <std>"``.

    Args:
        row: Mapping/row indexable by ``(metric, 'mean')`` and ``(metric, 'std')``,
            e.g. a row of ``groupby(...).agg({metric: ['mean', 'std']})``.
        metric: Top-level column name to format (e.g. ``'mcc'``).

    Returns:
        Mean and standard deviation, each to three decimals, joined by ``" ± "``.
    """
    mean_text = format(row[(metric, 'mean')], '.3f')
    std_text = format(row[(metric, 'std')], '.3f')
    return " ± ".join((mean_text, std_text))


# Metrics summarised per model as mean ± std across the rows of result_cv.csv.
SUMMARY_METRICS = ['mcc', 'sensitivity', 'precision']

result_df = pd.read_csv('result_cv/result_cv.csv')
grouped = (
    result_df
    .groupby('model_key')
    .agg({metric: ['mean', 'std'] for metric in SUMMARY_METRICS})
    # Sort on the numeric mean MCC *before* formatting. Sorting the formatted
    # "0.123 ± 0.004" strings is lexicographic, which misorders negative
    # values (e.g. "-0.1..." < "-0.5...").
    .sort_values(('mcc', 'mean'), ascending=True)
)

# One formatted string column per metric, keeping the model_key index and sorted order.
formatted_df = pd.DataFrame(index=grouped.index)
for metric in SUMMARY_METRICS:
    formatted_df[metric] = grouped.apply(lambda row: format_mean_std(row, metric), axis=1)

formatted_df = formatted_df.reset_index()
formatted_df

Unnamed: 0,model_key,mcc,sensitivity,precision
0,gearnet,0.458 ± 0.024,0.374 ± 0.038,0.619 ± 0.071
1,bert,0.487 ± 0.013,0.400 ± 0.038,0.649 ± 0.061
2,bert-gearnet,0.521 ± 0.028,0.441 ± 0.095,0.680 ± 0.111
3,bert-gearnet-ensemble,0.538 ± 0.023,0.403 ± 0.034,0.764 ± 0.029
4,esm-33-gearnet,0.636 ± 0.020,0.551 ± 0.069,0.777 ± 0.061
5,esm-t33,0.655 ± 0.016,0.582 ± 0.063,0.775 ± 0.052
6,esm-33-gearnet-ensemble,0.666 ± 0.013,0.564 ± 0.025,0.819 ± 0.014
7,esm-33-gearnet-ensemble-rus,0.674 ± 0.011,0.759 ± 0.006,0.632 ± 0.018
8,esm-33-gearnet-resiboost,0.679 ± 0.010,0.597 ± 0.011,0.805 ± 0.018


In [16]:
from tabulate import tabulate

# Pipe-style (markdown) rendering of the per-model summary, without the
# positional index column, printed as plain text for copy/paste.
markdown_table = tabulate(
    formatted_df,
    headers='keys',
    tablefmt='pipe',
    showindex=False,
)
print(markdown_table)


| model_key                   | mcc           | sensitivity   | precision     |
|:----------------------------|:--------------|:--------------|:--------------|
| gearnet                     | 0.458 ± 0.024 | 0.374 ± 0.038 | 0.619 ± 0.071 |
| bert                        | 0.487 ± 0.013 | 0.400 ± 0.038 | 0.649 ± 0.061 |
| bert-gearnet                | 0.521 ± 0.028 | 0.441 ± 0.095 | 0.680 ± 0.111 |
| bert-gearnet-ensemble       | 0.538 ± 0.023 | 0.403 ± 0.034 | 0.764 ± 0.029 |
| esm-33-gearnet              | 0.636 ± 0.020 | 0.551 ± 0.069 | 0.777 ± 0.061 |
| esm-t33                     | 0.655 ± 0.016 | 0.582 ± 0.063 | 0.775 ± 0.052 |
| esm-33-gearnet-ensemble     | 0.666 ± 0.013 | 0.564 ± 0.025 | 0.819 ± 0.014 |
| esm-33-gearnet-ensemble-rus | 0.674 ± 0.011 | 0.759 ± 0.006 | 0.632 ± 0.018 |
| esm-33-gearnet-resiboost    | 0.679 ± 0.010 | 0.597 ± 0.011 | 0.805 ± 0.018 |
