In [2]:
import pandas as pd
from tabulate import tabulate

def format_mean_std(row, metric):
    """Render one metric as a ``"<mean> ± <std>"`` string with 3 decimals.

    ``row`` must be indexable by the tuples ``(metric, 'mean')`` and
    ``(metric, 'std')`` -- e.g. a row of a frame built with
    ``groupby(...).agg({metric: ['mean', 'std']})``.
    """
    mean, std = (row[(metric, stat)] for stat in ('mean', 'std'))
    return f"{mean:.3f} ± {std:.3f}"


def get_stat_df(filename, verbose=True, metrics=None):
    """Summarise per-model metrics from a results CSV as "mean ± std" strings.

    Parameters
    ----------
    filename : str or file-like
        Passed straight to ``pd.read_csv``. The CSV must contain a
        ``model_key`` column plus one numeric column per entry in ``metrics``.
    verbose : bool, default True
        If True, also print the table in pipe (Markdown) format.
    metrics : list of str, optional
        Metric columns to summarise; the first one is the sort key.
        Defaults to ``['mcc', 'sensitivity', 'precision', 'micro_auroc']``.

    Returns
    -------
    pd.DataFrame
        One row per ``model_key`` with columns ``model_key``, one
        ``"mean ± std"`` string per metric (3 decimals; the std is ``nan`` for
        a model with a single record), and ``record_count`` (rows per model).
        Rows are ordered by the numeric mean of the first metric, ascending.
    """
    if metrics is None:
        # The commented-out variant of this list also had 'specificity' and 'accuracy'.
        metrics = ['mcc', 'sensitivity', 'precision', 'micro_auroc']

    result_df = pd.read_csv(filename)
    by_model = result_df.groupby('model_key')

    stats = by_model.agg({k: ['mean', 'std'] for k in metrics})
    # Sort on the numeric mean *before* formatting. Sorting the formatted
    # strings is lexicographic: negative values come out reversed
    # ('-0.100' < '-0.500') and equal means are tie-broken by the std text.
    stats = stats.sort_values((metrics[0], 'mean'), kind='stable')

    def fmt(values):
        # Same "{:.3f}" rendering as format_mean_std, vectorised per column.
        return values.map('{:.3f}'.format)

    formatted_df = pd.DataFrame(
        {m: fmt(stats[(m, 'mean')]) + ' ± ' + fmt(stats[(m, 'std')]) for m in metrics},
        index=stats.index,
    )
    # Assign by model_key index instead of by position.
    formatted_df['record_count'] = by_model.size()
    formatted_df = formatted_df.rename_axis('model_key').reset_index()

    # Experiment notes (ESM variants, numbered by the author):
    #   2: esm-t33-pretrained
    #   4: esm-33-gearnet-resiboost -> vs. 2: adding GearNet is better;
    #      adding ResiBoost also helps.
    #   9: esm-33-gearnet-pretrained-freezelm-ensemble -> vs. 4: better,
    #      attributed to pretraining.

    if verbose:
        markdown_table = tabulate(formatted_df[['model_key'] + metrics], headers='keys', tablefmt='pipe', showindex=False)
        print(markdown_table)

    return formatted_df



In [30]:
# get_stat_df_summary: best model per group of related model_keys.

def get_stat_df_summary(filename, groups, verbose=True):
    '''
    Keep only the best model (highest mean MCC) from each group of model_keys.

    Parameters
    ----------
    filename : str or file-like
        Results CSV, forwarded to ``get_stat_df``.
    groups : dict
        ``{group_name: [model_key1, model_key2, ...]}``. Groups with no
        model_key present in the file are skipped.
    verbose : bool, default True
        If True, print the summary as a pipe (Markdown) table.

    Returns
    -------
    pd.DataFrame
        One row per matched group: ``group_name`` followed by the columns
        returned by ``get_stat_df`` for the chosen model, ordered by mean MCC
        ascending. Empty (same columns) if no group matched.
    '''
    stat_df = get_stat_df(filename, verbose=False)
    # Every formatted metric column, i.e. everything except the key and count.
    metrics = [c for c in stat_df.columns if c not in ('model_key', 'record_count')]

    # get_stat_df returns "mean ± std" strings; rank on the numeric mean.
    # Comparing the strings directly is lexicographic, which misorders
    # negative MCC values and ranks 'nan' above every number.
    mcc_mean = stat_df['mcc'].str.split(' ± ').str[0].astype(float)

    picked_names, picked_rows = [], []
    for group_name, model_keys in groups.items():
        candidates = mcc_mean[stat_df['model_key'].isin(model_keys)]
        if candidates.empty:
            continue
        picked_names.append(group_name)
        # Highest mean MCC wins; NaN sorts last so it is only picked as a last resort.
        picked_rows.append(candidates.sort_values(ascending=False, kind='stable').index[0])

    # Build the frame once instead of concatenating inside the loop.
    new_df = stat_df.loc[picked_rows].copy()
    new_df.insert(0, 'group_name', picked_names)
    new_df = (
        new_df.assign(_mcc_mean=mcc_mean.loc[picked_rows].to_numpy())
        .sort_values('_mcc_mean', kind='stable')
        .drop(columns='_mcc_mean')
        .reset_index(drop=True)
    )

    if verbose:
        markdown_table = tabulate(new_df[['group_name', 'model_key'] + metrics], headers='keys', tablefmt='pipe', showindex=False)
        print(markdown_table)

    return new_df
        
    
# Display label -> candidate model_keys. get_stat_df_summary keeps the
# candidate with the highest MCC for each label and skips labels whose
# candidates do not appear in the CSV.
groups = {
    'ESM': ['esm-t33'],
    'ESM + Pretrained': ['esm-t33-pretrained', 'esm-t33-pretrained-freezelm'],
    'ESM + GearNet': ['esm-33-gearnet'],
    'ESM + GearNet + Ensemble': ['esm-33-gearnet-ensemble'],
    'ESM + GearNet + Pretrained': ['esm-33-gearnet-pretrained', 'esm-33-gearnet-pretrained-freezelm', 'esm-33-gearnet-pretrained-freezeall'],
    'ESM + GearNet + Resiboost': [f'esm-33-gearnet-rboost{ratio:02d}' for ratio in [50]],
    'ESM + GearNet + Pretrained + Resiboost': (
        [f'esm-33-gearnet-pretrained-freezelm-rboost{ratio:02d}' for ratio in [5, 10, 25, 50]]
        + [f'esm-33-gearnet-pretrained-rboost{ratio:02d}' for ratio in [10, 50]]
    ),
    'ESM + GearNet + Pretrained + Ensemble': ['esm-33-gearnet-pretrained-ensemble', 'esm-33-gearnet-pretrained-freezelm-ensemble'],
}


# record.csv: best model per group (prints a Markdown table, then displays the frame).
get_stat_df_summary(filename='record.csv', groups=groups, verbose=True)


| group_name                 | model_key                   | mcc           | sensitivity   | precision     | micro_auroc   |
|:---------------------------|:----------------------------|:--------------|:--------------|:--------------|:--------------|
| ESM                        | esm-t33                     | 0.662 ± 0.032 | 0.702 ± 0.043 | 0.675 ± 0.079 | 0.940 ± 0.023 |
| ESM + Pretrained           | esm-t33-pretrained-freezelm | 0.674 ± 0.019 | 0.684 ± 0.028 | 0.709 ± 0.021 | 0.857 ± 0.025 |
| ESM + GearNet + Ensemble   | esm-33-gearnet-ensemble     | 0.709 ± 0.011 | 0.709 ± 0.009 | 0.749 ± 0.017 | 0.913 ± 0.006 |
| ESM + GearNet              | esm-33-gearnet              | 0.719 ± 0.015 | 0.725 ± 0.018 | 0.752 ± 0.026 | 0.910 ± 0.018 |
| ESM + GearNet + Pretrained | esm-33-gearnet-pretrained   | 0.721 ± 0.016 | 0.736 ± 0.021 | 0.743 ± 0.019 | 0.893 ± 0.022 |


Unnamed: 0,group_name,model_key,mcc,sensitivity,precision,micro_auroc,record_count
0,ESM,esm-t33,0.662 ± 0.032,0.702 ± 0.043,0.675 ± 0.079,0.940 ± 0.023,5
1,ESM + Pretrained,esm-t33-pretrained-freezelm,0.674 ± 0.019,0.684 ± 0.028,0.709 ± 0.021,0.857 ± 0.025,5
2,ESM + GearNet + Ensemble,esm-33-gearnet-ensemble,0.709 ± 0.011,0.709 ± 0.009,0.749 ± 0.017,0.913 ± 0.006,5
3,ESM + GearNet,esm-33-gearnet,0.719 ± 0.015,0.725 ± 0.018,0.752 ± 0.026,0.910 ± 0.018,5
4,ESM + GearNet + Pretrained,esm-33-gearnet-pretrained,0.721 ± 0.016,0.736 ± 0.021,0.743 ± 0.019,0.893 ± 0.022,5


In [31]:
# Imatinib: per-model table first, then the best model of each group.
get_stat_df(filename='imatinib.csv', verbose=True)
print('--------GROUPED VERSION--------')
get_stat_df_summary(filename='imatinib.csv', groups=groups, verbose=True)

| model_key                                   | mcc           | sensitivity   | precision     | micro_auroc   |
|:--------------------------------------------|:--------------|:--------------|:--------------|:--------------|
| esm-33-gearnet-pretrained-freezeall         | 0.612 ± 0.044 | 0.580 ± 0.058 | 0.701 ± 0.077 | 0.866 ± 0.040 |
| esm-t33-pretrained                          | 0.664 ± 0.039 | 0.707 ± 0.030 | 0.674 ± 0.091 | 0.913 ± 0.057 |
| esm-t33-pretrained-freezelm                 | 0.674 ± 0.019 | 0.684 ± 0.028 | 0.709 ± 0.021 | 0.857 ± 0.025 |
| esm-33-gearnet-pretrained-freezelm-rboost05 | 0.688 ± 0.015 | 0.747 ± 0.017 | 0.676 ± 0.024 | 0.894 ± 0.013 |
| esm-33-gearnet-pretrained-rboost10          | 0.688 ± 0.022 | 0.729 ± 0.032 | 0.691 ± 0.014 | 0.896 ± 0.039 |
| esm-33-gearnet-pretrained-rboost50          | 0.696 ± 0.007 | 0.707 ± 0.018 | 0.728 ± 0.017 | 0.908 ± 0.041 |
| esm-t33                                     | 0.697 ± 0.009 | 0.705 ± 0.036 | 0.731 ± 0.042 | 0.944 ± 

Unnamed: 0,group_name,model_key,mcc,sensitivity,precision,micro_auroc,record_count
0,ESM + Pretrained,esm-t33-pretrained-freezelm,0.674 ± 0.019,0.684 ± 0.028,0.709 ± 0.021,0.857 ± 0.025,5
1,ESM,esm-t33,0.697 ± 0.009,0.705 ± 0.036,0.731 ± 0.042,0.944 ± 0.023,5
2,ESM + GearNet + Resiboost,esm-33-gearnet-rboost50,0.698 ± 0.015,0.705 ± 0.008,0.733 ± 0.027,0.929 ± 0.013,5
3,ESM + GearNet + Pretrained + Resiboost,esm-33-gearnet-pretrained-freezelm-rboost50,0.710 ± 0.012,0.711 ± 0.015,0.748 ± 0.025,0.913 ± 0.007,5
4,ESM + GearNet + Pretrained,esm-33-gearnet-pretrained,0.716 ± 0.025,0.745 ± 0.033,0.728 ± 0.043,0.897 ± 0.045,5
5,ESM + GearNet + Pretrained + Ensemble,esm-33-gearnet-pretrained-freezelm-ensemble,0.717 ± 0.012,0.705 ± 0.017,0.768 ± 0.027,0.886 ± 0.010,5
6,ESM + GearNet,esm-33-gearnet,0.718 ± 0.040,0.733 ± 0.040,0.743 ± 0.044,0.897 ± 0.017,5
7,ESM + GearNet + Ensemble,esm-33-gearnet-ensemble,0.721 ± 0.014,0.705 ± 0.010,0.776 ± 0.030,0.905 ± 0.015,5


In [32]:
# Dasatinib: per-model table first, then the best model of each group.
get_stat_df(filename='dasatinib.csv', verbose=True)
print('--------GROUPED VERSION--------')
get_stat_df_summary(filename='dasatinib.csv', groups=groups, verbose=True)

| model_key                                    | mcc           | sensitivity   | precision     | micro_auroc   |
|:---------------------------------------------|:--------------|:--------------|:--------------|:--------------|
| esm-33-gearnet-pretrained-freezeall          | 0.607 ± 0.090 | 0.553 ± 0.096 | 0.743 ± 0.176 | 0.881 ± 0.018 |
| esm-t33-pretrained                           | 0.723 ± 0.014 | 0.594 ± 0.041 | 0.922 ± 0.061 | 0.967 ± 0.010 |
| esm-33-gearnet-pretrained                    | 0.726 ± 0.044 | 0.633 ± 0.097 | 0.886 ± 0.124 | 0.950 ± 0.013 |
| esm-t33-pretrained-freezelm                  | 0.727 ± 0.051 | 0.597 ± 0.088 | 0.930 ± 0.023 | 0.892 ± 0.009 |
| esm-33-gearnet                               | 0.732 ± 0.034 | 0.616 ± 0.032 | 0.909 ± 0.050 | 0.935 ± 0.017 |
| esm-t33                                      | 0.747 ± 0.012 | 0.619 ± 0.026 | 0.939 ± 0.026 | 0.956 ± 0.013 |
| esm-33-gearnet-ensemble                      | 0.754 ± 0.024 | 0.627 ± 0.049 | 0.945 ± 0.062 |

Unnamed: 0,group_name,model_key,mcc,sensitivity,precision,micro_auroc,record_count
0,ESM + Pretrained,esm-t33-pretrained-freezelm,0.727 ± 0.051,0.597 ± 0.088,0.930 ± 0.023,0.892 ± 0.009,5
1,ESM + GearNet,esm-33-gearnet,0.732 ± 0.034,0.616 ± 0.032,0.909 ± 0.050,0.935 ± 0.017,5
2,ESM,esm-t33,0.747 ± 0.012,0.619 ± 0.026,0.939 ± 0.026,0.956 ± 0.013,5
3,ESM + GearNet + Ensemble,esm-33-gearnet-ensemble,0.754 ± 0.024,0.627 ± 0.049,0.945 ± 0.062,0.955 ± 0.006,5
4,ESM + GearNet + Resiboost,esm-33-gearnet-rboost50,0.764 ± 0.022,0.647 ± 0.015,0.938 ± 0.041,0.951 ± 0.015,5
5,ESM + GearNet + Pretrained,esm-33-gearnet-pretrained-freezelm,0.772 ± 0.024,0.649 ± 0.036,0.953 ± 0.026,0.939 ± 0.007,5
6,ESM + GearNet + Pretrained + Ensemble,esm-33-gearnet-pretrained-freezelm-ensemble,0.783 ± 0.026,0.657 ± 0.041,0.964 ± 0.021,0.941 ± 0.010,5
7,ESM + GearNet + Pretrained + Resiboost,esm-33-gearnet-pretrained-freezelm-rboost25,0.793 ± 0.010,0.679 ± 0.016,0.955 ± 0.035,0.940 ± 0.013,5


In [33]:
# Bosutinib: per-model table first, then the best model of each group.
get_stat_df(filename='bosutinib.csv', verbose=True)
print('--------GROUPED VERSION--------')
get_stat_df_summary(filename='bosutinib.csv', groups=groups, verbose=True)

| model_key                                    | mcc           | sensitivity   | precision     | micro_auroc   |
|:---------------------------------------------|:--------------|:--------------|:--------------|:--------------|
| esm-33-gearnet-pretrained-freezeall          | 0.682 ± 0.061 | 0.630 ± 0.043 | 0.784 ± 0.094 | 0.950 ± 0.019 |
| esm-t33-pretrained-freezelm                  | 0.693 ± 0.033 | 0.670 ± 0.057 | 0.769 ± 0.123 | 0.964 ± 0.003 |
| esm-t33-pretrained                           | 0.743 ± 0.024 | 0.670 ± 0.046 | 0.860 ± 0.028 | 0.966 ± 0.015 |
| esm-33-gearnet-pretrained-freezelm           | 0.760 ± 0.019 | 0.713 ± 0.032 | 0.845 ± 0.042 | 0.968 ± 0.008 |
| esm-33-gearnet-pretrained-freezelm-resiboost | 0.761 ± 0.038 | 0.743 ± 0.080 | 0.824 ± 0.127 | 0.958 ± 0.010 |
| esm-33-gearnet-pretrained-freezelm-rboost50  | 0.762 ± 0.018 | 0.737 ± 0.089 | 0.832 ± 0.108 | 0.960 ± 0.008 |
| esm-33-gearnet-resiboost                     | 0.765 ± 0.015 | 0.750 ± 0.077 | 0.819 ± 0.068 |

Unnamed: 0,group_name,model_key,mcc,sensitivity,precision,micro_auroc,record_count
0,ESM + Pretrained,esm-t33-pretrained,0.743 ± 0.024,0.670 ± 0.046,0.860 ± 0.028,0.966 ± 0.015,5
1,ESM + GearNet,esm-33-gearnet,0.767 ± 0.026,0.700 ± 0.057,0.874 ± 0.035,0.953 ± 0.026,5
2,ESM,esm-t33,0.777 ± 0.033,0.733 ± 0.046,0.856 ± 0.049,0.943 ± 0.039,5
3,ESM + GearNet + Pretrained,esm-33-gearnet-pretrained,0.780 ± 0.035,0.727 ± 0.069,0.870 ± 0.024,0.954 ± 0.015,5
4,ESM + GearNet + Resiboost,esm-33-gearnet-rboost50,0.786 ± 0.019,0.757 ± 0.063,0.850 ± 0.073,0.959 ± 0.016,5
5,ESM + GearNet + Pretrained + Resiboost,esm-33-gearnet-pretrained-freezelm-rboost05,0.800 ± 0.022,0.830 ± 0.014,0.800 ± 0.044,0.972 ± 0.007,5


In [34]:
# Methotrexate: per-model table first, then the best model of each group.
get_stat_df(filename='methotrexate.csv', verbose=True)
print('--------GROUPED VERSION--------')
get_stat_df_summary(filename='methotrexate.csv', groups=groups, verbose=True)

| model_key                                    | mcc           | sensitivity   | precision     | micro_auroc   |
|:---------------------------------------------|:--------------|:--------------|:--------------|:--------------|
| esm-33-gearnet-pretrained-freezeall          | 0.530 ± 0.036 | 0.429 ± 0.034 | 0.697 ± 0.098 | 0.809 ± 0.024 |
| esm-t33-pretrained-freezelm                  | 0.607 ± 0.034 | 0.529 ± 0.067 | 0.730 ± 0.031 | 0.897 ± 0.042 |
| esm-33-gearnet                               | 0.619 ± 0.043 | 0.550 ± 0.079 | 0.735 ± 0.090 | 0.869 ± 0.053 |
| esm-t33-pretrained                           | 0.621 ± 0.069 | 0.616 ± 0.099 | 0.666 ± 0.115 | 0.944 ± 0.027 |
| esm-33-gearnet-ensemble                      | 0.646 ± 0.050 | 0.593 ± 0.116 | 0.742 ± 0.074 | 0.879 ± 0.074 |
| esm-33-gearnet-pretrained-freezelm           | 0.648 ± 0.035 | 0.564 ± 0.066 | 0.775 ± 0.049 | 0.882 ± 0.058 |
| esm-33-gearnet-pretrained-freezelm-rboost05  | 0.648 ± 0.056 | 0.691 ± 0.090 | 0.637 ± 0.057 |

Unnamed: 0,group_name,model_key,mcc,sensitivity,precision,micro_auroc,record_count
0,ESM + GearNet,esm-33-gearnet,0.619 ± 0.043,0.550 ± 0.079,0.735 ± 0.090,0.869 ± 0.053,10
1,ESM + Pretrained,esm-t33-pretrained,0.621 ± 0.069,0.616 ± 0.099,0.666 ± 0.115,0.944 ± 0.027,10
2,ESM + GearNet + Ensemble,esm-33-gearnet-ensemble,0.646 ± 0.050,0.593 ± 0.116,0.742 ± 0.074,0.879 ± 0.074,5
3,ESM + GearNet + Pretrained,esm-33-gearnet-pretrained,0.651 ± 0.041,0.593 ± 0.080,0.749 ± 0.085,0.869 ± 0.068,10
4,ESM,esm-t33,0.652 ± 0.051,0.621 ± 0.090,0.725 ± 0.118,0.951 ± 0.020,10
5,ESM + GearNet + Pretrained + Ensemble,esm-33-gearnet-pretrained-ensemble,0.667 ± 0.041,0.570 ± 0.075,0.810 ± 0.075,0.886 ± 0.069,5
6,ESM + GearNet + Resiboost,esm-33-gearnet-rboost50,0.682 ± 0.062,0.606 ± 0.104,0.799 ± 0.046,0.884 ± 0.069,5
7,ESM + GearNet + Pretrained + Resiboost,esm-33-gearnet-pretrained-rboost10,0.687 ± 0.053,0.659 ± 0.102,0.744 ± 0.035,0.889 ± 0.057,5


In [35]:
# atpbind3d: per-model table only (no grouped summary for this dataset).
get_stat_df(filename='atpbind3d.csv', verbose=True)

| model_key               | mcc           | sensitivity   | precision     | micro_auroc   |
|:------------------------|:--------------|:--------------|:--------------|:--------------|
| esm-33-gearnet          | 0.566 ± 0.065 | 0.510 ± 0.043 | 0.676 ± 0.109 | 0.914 ± 0.024 |
| esm-33-gearnet-ensemble | 0.642 ± 0.023 | 0.563 ± 0.035 | 0.768 ± 0.017 | 0.947 ± 0.007 |
| esm-33-gearnet-b8       | 0.676 ± 0.022 | 0.636 ± 0.009 | 0.751 ± 0.035 | 0.927 ± 0.003 |
| esm-33-gearnet-rboost10 | 0.682 ± 0.009 | 0.661 ± 0.026 | 0.737 ± 0.028 | 0.937 ± 0.006 |


Unnamed: 0,model_key,mcc,sensitivity,precision,micro_auroc,record_count
0,esm-33-gearnet,0.566 ± 0.065,0.510 ± 0.043,0.676 ± 0.109,0.914 ± 0.024,5
1,esm-33-gearnet-ensemble,0.642 ± 0.023,0.563 ± 0.035,0.768 ± 0.017,0.947 ± 0.007,5
2,esm-33-gearnet-b8,0.676 ± 0.022,0.636 ± 0.009,0.751 ± 0.035,0.927 ± 0.003,2
3,esm-33-gearnet-rboost10,0.682 ± 0.009,0.661 ± 0.026,0.737 ± 0.028,0.937 ± 0.006,5
