In [1]:
import pandas as pd
from tabulate import tabulate
def format_mean_std(row, metric):
    mean_val = row[(metric, 'mean')]
    std_val = row[(metric, 'std')]
    return f"{mean_val:.3f} ± {std_val:.3f}"


def get_stat_df(result_df, verbose=True, metrics=['mcc', 'micro_auprc', 'micro_auroc',
               'sensitivity', 'precision']):
    # Sample loading of your CSV - this would be replaced with your actual loading

    grouped = result_df.groupby('model_key').agg({
        k: ['mean', 'std'] for k in metrics
    })

    record_counts = result_df.groupby('model_key').size()

    for metric in metrics:
        grouped[f'{metric}_formatted'] = grouped.apply(
            lambda row: format_mean_std(row, metric), axis=1)

    # Add the record count column to the grouped dataframe
    grouped['record_count'] = record_counts.values

    formatted_df = grouped[[
        k + '_formatted' for k in metrics] + ['record_count']]
    formatted_df.columns = metrics + ['record_count']
    formatted_df = formatted_df.sort_values(metrics[0], ascending=True)

    formatted_df = formatted_df.reset_index()

    if verbose:
        markdown_table = tabulate(formatted_df[[
                                  'model_key'] + metrics], headers='keys', tablefmt='pipe', showindex=False)
        print(markdown_table)

    return formatted_df

In [2]:
import pandas as pd

df = pd.read_csv('atpbind3d_stats.csv')

df_gn = df[
    # ((df['model_key'].str.startswith('esm-t33-gearnet-640')) |
    (df['model_key'].str.startswith('esm-t33-gearnet'))
    # (df['model_key'] == 'esm-t33-gearnet') | 
    # (df['model_key'] == 'esm-t33')) & 
]

get_stat_df(df_gn, metrics=['mcc','micro_auprc', 'micro_auroc',
               'sensitivity', 'precision', ])

| model_key                              | mcc           | micro_auprc   | micro_auroc   | sensitivity   | precision     |
|:---------------------------------------|:--------------|:--------------|:--------------|:--------------|:--------------|
| esm-t33-gearnet                        | 0.668 ± 0.012 | 0.687 ± 0.011 | 0.924 ± 0.007 | 0.623 ± 0.023 | 0.747 ± 0.014 |
| esm-t33-gearnet-640-2                  | 0.671 ± 0.009 | 0.683 ± 0.013 | 0.923 ± 0.008 | 0.622 ± 0.023 | 0.753 ± 0.021 |
| esm-t33-gearnet-320-4                  | 0.672 ± 0.008 | 0.683 ± 0.004 | 0.920 ± 0.006 | 0.623 ± 0.024 | 0.755 ± 0.015 |
| esm-t33-gearnet-800-4                  | 0.672 ± 0.011 | 0.694 ± 0.018 | 0.932 ± 0.009 | 0.640 ± 0.033 | 0.736 ± 0.027 |
| esm-t33-gearnet-800-3                  | 0.672 ± 0.016 | 0.685 ± 0.014 | 0.925 ± 0.008 | 0.623 ± 0.024 | 0.753 ± 0.022 |
| esm-t33-gearnet-20-cycle               | 0.675 ± 0.015 | 0.695 ± 0.011 | 0.923 ± 0.004 | 0.619 ± 0.015 | 0.766 ± 0.019 |
| esm-t33-gearne

Unnamed: 0,model_key,mcc,micro_auprc,micro_auroc,sensitivity,precision,record_count
0,esm-t33-gearnet,0.668 ± 0.012,0.687 ± 0.011,0.924 ± 0.007,0.623 ± 0.023,0.747 ± 0.014,6
1,esm-t33-gearnet-640-2,0.671 ± 0.009,0.683 ± 0.013,0.923 ± 0.008,0.622 ± 0.023,0.753 ± 0.021,10
2,esm-t33-gearnet-320-4,0.672 ± 0.008,0.683 ± 0.004,0.920 ± 0.006,0.623 ± 0.024,0.755 ± 0.015,5
3,esm-t33-gearnet-800-4,0.672 ± 0.011,0.694 ± 0.018,0.932 ± 0.009,0.640 ± 0.033,0.736 ± 0.027,5
4,esm-t33-gearnet-800-3,0.672 ± 0.016,0.685 ± 0.014,0.925 ± 0.008,0.623 ± 0.024,0.753 ± 0.022,15
5,esm-t33-gearnet-20-cycle,0.675 ± 0.015,0.695 ± 0.011,0.923 ± 0.004,0.619 ± 0.015,0.766 ± 0.019,5
6,esm-t33-gearnet-960-2,0.676 ± 0.013,0.691 ± 0.015,0.926 ± 0.010,0.631 ± 0.014,0.753 ± 0.023,10
7,esm-t33-gearnet-960-1,0.676 ± 0.016,0.684 ± 0.016,0.923 ± 0.006,0.617 ± 0.030,0.769 ± 0.029,10
8,esm-t33-gearnet-640,0.677 ± 0.009,0.685 ± 0.017,0.923 ± 0.008,0.640 ± 0.014,0.745 ± 0.015,15
9,esm-t33-gearnet-800-3-adaboost-r10,0.685 ± 0.005,0.705 ± 0.003,0.913 ± 0.005,0.639 ± 0.011,0.762 ± 0.014,5


In [3]:
import pandas as pd

df = pd.read_csv('atpbind3d_stats.csv')

df.groupby('model_key').agg({'mcc': ['mean', 'std'], 'micro_auprc': [
    'mean', 'std']}).sort_values(('micro_auprc', 'mean'), ascending=False)

Unnamed: 0_level_0,mcc,mcc,micro_auprc,micro_auprc
Unnamed: 0_level_1,mean,std,mean,std
model_key,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2
esm-t33-gearnet-20-cycle-resiboost-r50,0.694733,0.011511,0.727867,0.011971
esm-t33-gearnet-adaboost-r50,0.69442,0.005516,0.7255,0.011629
esm-t33-gearnet-resiboost-r90,0.69656,0.011627,0.7248,0.007234
esm-t33-gearnet-resiboost-r50,0.69716,0.009042,0.72448,0.008971
esm-t33-gearnet-ensemble,0.69238,0.00925,0.7238,0.009424
esm-t33-gearnet-20-cycle-adaboost-r50,0.69428,0.008225,0.72322,0.006511
esm-t33-gearnet-adaboost-r80,0.69846,0.007369,0.72304,0.008454
esm-t33-gearnet-800-3-ensemble,0.69268,0.004945,0.72286,0.008355
esm-t33-gearnet-adaboost-r90,0.69486,0.006749,0.72234,0.008256
esm-t33-gearnet-resiboost-r80,0.69628,0.010193,0.72228,0.007619


In [4]:
df_final = df[
    (df['model_key'].str.startswith('esm-t33-gvp'))
]

get_stat_df(df_final)

| model_key          | mcc           | micro_auprc   | micro_auroc   | sensitivity   | precision     |
|:-------------------|:--------------|:--------------|:--------------|:--------------|:--------------|
| esm-t33-gvp-c100   | 0.457 ± 0.182 | 0.433 ± 0.200 | 0.860 ± 0.050 | 0.484 ± 0.036 | 0.542 ± 0.322 |
| esm-t33-gvp-c10-v1 | 0.483 ± 0.256 | 0.476 ± 0.286 | 0.868 ± 0.077 | 0.540 ± 0.102 | 0.522 ± 0.335 |
| esm-t33-gvp-c50    | 0.535 ± 0.199 | 0.557 ± 0.233 | 0.895 ± 0.070 | 0.570 ± 0.112 | 0.633 ± 0.301 |
| esm-t33-gvp-c10-v2 | 0.573 ± 0.219 | 0.565 ± 0.243 | 0.893 ± 0.072 | 0.613 ± 0.029 | 0.624 ± 0.285 |
| esm-t33-gvp-c20    | 0.645 ± 0.048 | 0.675 ± 0.042 | 0.936 ± 0.012 | 0.583 ± 0.042 | 0.747 ± 0.065 |
| esm-t33-gvp-c10    | 0.676 ± 0.019 | 0.679 ± 0.028 | 0.926 ± 0.009 | 0.614 ± 0.019 | 0.772 ± 0.035 |


Unnamed: 0,model_key,mcc,micro_auprc,micro_auroc,sensitivity,precision,record_count
0,esm-t33-gvp-c100,0.457 ± 0.182,0.433 ± 0.200,0.860 ± 0.050,0.484 ± 0.036,0.542 ± 0.322,5
1,esm-t33-gvp-c10-v1,0.483 ± 0.256,0.476 ± 0.286,0.868 ± 0.077,0.540 ± 0.102,0.522 ± 0.335,5
2,esm-t33-gvp-c50,0.535 ± 0.199,0.557 ± 0.233,0.895 ± 0.070,0.570 ± 0.112,0.633 ± 0.301,5
3,esm-t33-gvp-c10-v2,0.573 ± 0.219,0.565 ± 0.243,0.893 ± 0.072,0.613 ± 0.029,0.624 ± 0.285,5
4,esm-t33-gvp-c20,0.645 ± 0.048,0.675 ± 0.042,0.936 ± 0.012,0.583 ± 0.042,0.747 ± 0.065,5
5,esm-t33-gvp-c10,0.676 ± 0.019,0.679 ± 0.028,0.926 ± 0.009,0.614 ± 0.019,0.772 ± 0.035,5


In [5]:
df_final = df[
    (df['model_key'].str.startswith('esm-t33-gearnet') & df['model_key'].str.endswith('r50') &
     ~df['model_key'].str.startswith('esm-t33-gearnet-800') & 
     ~df['model_key'].str.startswith('esm-t33-gearnet-320') & 
     ~df['model_key'].str.startswith('esm-t33-gearnet-640') & 
     ~df['model_key'].str.startswith('esm-t33-gearnet-960')) |
    (df['model_key'] == 'esm-t33-gearnet') |
    (df['model_key'].str.startswith('esm-t33-gvp')) |
    (df['model_key'] == 'esm-t33-gearnet-20-cycle') |
    (df['model_key'].str.startswith('gearnet')) |
    (df['model_key'] == 'bert') |
    (df['model_key'] == 'bert-gearnet') |
    (df['model_key'] == 'esm-t33') |
    (df['model_key'].str.startswith('gvp'))
]

get_stat_df(df_final)

| model_key                              | mcc           | micro_auprc   | micro_auroc   | sensitivity   | precision     |
|:---------------------------------------|:--------------|:--------------|:--------------|:--------------|:--------------|
| gvp                                    | 0.236 ± 0.026 | 0.229 ± 0.032 | 0.796 ± 0.011 | 0.405 ± 0.091 | 0.205 ± 0.042 |
| gvp-20-cycle                           | 0.359 ± 0.039 | 0.363 ± 0.030 | 0.851 ± 0.008 | 0.317 ± 0.075 | 0.483 ± 0.129 |
| esm-t33-gvp-c100                       | 0.457 ± 0.182 | 0.433 ± 0.200 | 0.860 ± 0.050 | 0.484 ± 0.036 | 0.542 ± 0.322 |
| gearnet                                | 0.470 ± 0.012 | 0.490 ± 0.012 | 0.879 ± 0.009 | 0.385 ± 0.057 | 0.627 ± 0.065 |
| bert                                   | 0.481 ± 0.031 | 0.467 ± 0.014 | 0.871 ± 0.007 | 0.399 ± 0.014 | 0.626 ± 0.061 |
| gearnet-50-cycle                       | 0.483 ± 0.019 | 0.499 ± 0.015 | 0.878 ± 0.004 | 0.373 ± 0.015 | 0.671 ± 0.057 |
| esm-t33-gvp-c1

Unnamed: 0,model_key,mcc,micro_auprc,micro_auroc,sensitivity,precision,record_count
0,gvp,0.236 ± 0.026,0.229 ± 0.032,0.796 ± 0.011,0.405 ± 0.091,0.205 ± 0.042,5
1,gvp-20-cycle,0.359 ± 0.039,0.363 ± 0.030,0.851 ± 0.008,0.317 ± 0.075,0.483 ± 0.129,5
2,esm-t33-gvp-c100,0.457 ± 0.182,0.433 ± 0.200,0.860 ± 0.050,0.484 ± 0.036,0.542 ± 0.322,5
3,gearnet,0.470 ± 0.012,0.490 ± 0.012,0.879 ± 0.009,0.385 ± 0.057,0.627 ± 0.065,5
4,bert,0.481 ± 0.031,0.467 ± 0.014,0.871 ± 0.007,0.399 ± 0.014,0.626 ± 0.061,5
5,gearnet-50-cycle,0.483 ± 0.019,0.499 ± 0.015,0.878 ± 0.004,0.373 ± 0.015,0.671 ± 0.057,5
6,esm-t33-gvp-c10-v1,0.483 ± 0.256,0.476 ± 0.286,0.868 ± 0.077,0.540 ± 0.102,0.522 ± 0.335,5
7,gearnet-100-cycle,0.488 ± 0.021,0.512 ± 0.021,0.882 ± 0.009,0.370 ± 0.015,0.688 ± 0.035,5
8,gearnet-20-cycle,0.490 ± 0.029,0.508 ± 0.021,0.883 ± 0.008,0.389 ± 0.049,0.663 ± 0.045,5
9,gvp-50-cycle,0.503 ± 0.023,0.526 ± 0.034,0.889 ± 0.008,0.449 ± 0.043,0.609 ± 0.034,5


In [6]:
# Group by hyperparameters and calculate average metrics
df = pd.read_csv('atpbind3d_esm-t33-gvp_stats.csv')

# Define the hyperparameters to group by
hyperparameters = ['model_kwargs.lm_freeze_layer_count', 'model_kwargs.node_h_dim', 
                   'model_kwargs.num_layers', 'model_kwargs.residual', 'cycle_size', 'max_lr']

# Define the metrics to average
metrics = ['mcc', 'micro_auprc', 'micro_auroc', 'sensitivity', 'precision']

# Group by hyperparameters and calculate mean of metrics
grouped_df = df.groupby(hyperparameters)[metrics].agg(['mean', 'std']).reset_index()

# Sort by MCC (you can change this to any other metric)
grouped_df = grouped_df.sort_values(('mcc', 'mean'), ascending=False)

# Display the results
print("Average metrics grouped by hyperparameters:")
grouped_df


Average metrics grouped by hyperparameters:


Unnamed: 0_level_0,model_kwargs.lm_freeze_layer_count,model_kwargs.node_h_dim,model_kwargs.num_layers,model_kwargs.residual,cycle_size,max_lr,mcc,mcc,micro_auprc,micro_auprc,micro_auroc,micro_auroc,sensitivity,sensitivity,precision,precision
Unnamed: 0_level_1,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,mean,std,mean,std,mean,std,mean,std,mean,std
2,30,"(256, 16)",3,False,20,0.002,0.67234,0.013974,0.68515,0.020854,0.92466,0.012061,0.62568,0.026225,0.75225,0.020278
27,30,"(512, 16)",4,True,20,0.002,0.66924,0.01741,0.6964,0.008604,0.93,0.016658,0.61664,0.027383,0.75606,0.013958
6,30,"(256, 16)",3,True,20,0.002,0.66803,0.013869,0.68244,0.022596,0.9314,0.009458,0.63294,0.026669,0.73551,0.027415
20,30,"(512, 16)",3,True,10,0.002,0.666,0.011365,0.68596,0.00869,0.9345,0.010743,0.63024,0.020691,0.73388,0.016475
22,30,"(512, 16)",3,True,20,0.002,0.66586,0.018281,0.6672,0.012739,0.9129,0.023342,0.60534,0.030487,0.76244,0.010941
12,30,"(256, 16)",4,True,10,0.002,0.66514,0.010518,0.69002,0.017658,0.93844,0.008436,0.61706,0.026711,0.74823,0.040812
39,31,"(512, 16)",3,True,20,0.002,0.66456,0.013554,0.67296,0.020929,0.91418,0.038265,0.61336,0.027262,0.7505,0.023483
43,31,"(512, 16)",4,True,20,0.002,0.66378,0.005257,0.68016,0.012995,0.93078,0.009352,0.59762,0.034352,0.7686,0.037765
4,30,"(256, 16)",3,True,10,0.002,0.66352,0.011653,0.67868,0.020474,0.93279,0.025663,0.62567,0.024296,0.73431,0.02224
34,31,"(256, 16)",4,True,10,0.002,0.66352,0.019222,0.68722,0.012269,0.93934,0.006211,0.60534,0.04059,0.75816,0.017461


### ESM-T33-GEARNET

In [7]:
# Group by hyperparameters and calculate average metrics
df = pd.read_csv('atpbind3d_esm-t33-gearnet_stats.csv')

# Define the hyperparameters to group by
hyperparameters = ['model_kwargs.lm_freeze_layer_count', 'max_slice_length', 'padding']

# Define the metrics to average
metrics = ['mcc', 'micro_auprc', 'micro_auroc', 'sensitivity', 'precision']

# Group by hyperparameters and calculate mean of metrics
grouped_df = df.groupby(hyperparameters)[metrics].agg(
    ['mean', 'std']).reset_index()

# Sort by MCC (you can change this to any other metric)
grouped_df = grouped_df.sort_values(('mcc', 'mean'), ascending=False)

# Display the results
print("Average metrics grouped by hyperparameters:")
grouped_df

Average metrics grouped by hyperparameters:


Unnamed: 0_level_0,model_kwargs.lm_freeze_layer_count,max_slice_length,padding,mcc,mcc,micro_auprc,micro_auprc,micro_auroc,micro_auroc,sensitivity,sensitivity,precision,precision
Unnamed: 0_level_1,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,mean,std,mean,std,mean,std,mean,std,mean,std
16,28.0,600.0,50.0,0.680100,0.009383,0.695340,0.011962,0.91848,0.006943,0.633220,0.017891,0.759620,0.030512
37,29.0,600.0,75.0,0.679940,0.007546,0.683640,0.015803,0.92028,0.010449,0.619580,0.010406,0.774800,0.020189
2,27.0,600.0,50.0,0.678500,0.007970,0.689280,0.010896,0.91452,0.005949,0.625800,0.018640,0.764820,0.027292
64,30.0,900.0,50.0,0.678230,0.013720,0.692360,0.021266,0.92380,0.009686,0.636630,0.019221,0.751590,0.021000
12,28.0,500.0,50.0,0.676800,0.016594,0.690680,0.011410,0.92416,0.008964,0.619280,0.025742,0.768760,0.022014
...,...,...,...,...,...,...,...,...,...,...,...,...,...
102,32.0,700.0,25.0,0.620100,,0.646600,,0.90970,,0.563800,,0.715600,
98,32.0,600.0,25.0,0.615800,,0.633500,,0.90630,,0.556400,,0.715600,
107,33.0,500.0,50.0,0.603700,0.015584,0.648140,0.016888,0.93248,0.003867,0.564100,0.033823,0.682660,0.036961
108,33.0,600.0,50.0,0.601850,0.035603,0.635775,0.029424,0.93235,0.006203,0.577875,0.015034,0.663875,0.066945


In [8]:
def analyze_metrics(csv_path, group_by_param, filter_conditions={}, sort_by=('mcc', 'mean'), ascending=False):
    """
    Analyze metrics by grouping on specified hyperparameters and filtering the DataFrame.

    Parameters:
    - group_by_param: str or list, the hyperparameter(s) to group by.
    - filter_conditions: dict, conditions to filter the DataFrame.

    Returns:
    - grouped_df: DataFrame, the grouped and averaged metrics.
    """
    import pandas as pd
    df = pd.read_csv(csv_path)

    # Apply filter conditions
    for column, value in filter_conditions.items():
        if isinstance(value, list):
            df = df[df[column].isin(value)]
        else:
            df = df[df[column] == value]

    # Define the hyperparameters to group by
    hyperparameters = group_by_param if isinstance(group_by_param, list) else [group_by_param]

    # Define the metrics to average
    metrics = ['mcc', 'sensitivity', 'precision', 'micro_auroc', 'micro_auprc']

    # Group by hyperparameters and calculate mean of metrics, including a single count column
    
    grouped_df = df.groupby(hyperparameters)[metrics].agg(['mean', 'std']).reset_index()
    grouped_df['count'] = df.groupby(hyperparameters).size().values

    # Round the final metrics to the fourth digit
    for metric in metrics:
        grouped_df[(metric, 'mean')] = grouped_df[(metric, 'mean')].round(4)
        grouped_df[(metric, 'std')] = grouped_df[(metric, 'std')].round(4)

    # Sort by MCC (you can change this to any other metric)
    grouped_df = grouped_df.sort_values(sort_by, ascending=ascending)

    # Display the results
    return grouped_df


def print_markdown_table(grouped_df, metrics=['mcc', 'sensitivity', 'precision', 'micro_auroc', 'micro_auprc']):
    """
    Print the markdown formatted table from the grouped DataFrame.

    Parameters:
    - grouped_df: DataFrame, the grouped and averaged metrics.
    - metrics: list, the metrics to include in the table.
    """
    from tabulate import tabulate

    # Format the metrics with mean and std
    for metric in metrics:
        grouped_df[f'{metric}_formatted'] = grouped_df.apply(
            lambda row: f"{row[(metric, 'mean')]:.3f} ± {row[(metric, 'std')]:.3f}", axis=1)
    # Drop the original metric columns
    for metric in metrics:
        grouped_df.drop(columns=[(metric, 'mean'), (metric, 'std')], inplace=True)
    
    # Rename the formatted metric columns to the original metric names
    for metric in metrics:
        grouped_df.rename(columns={f'{metric}_formatted': metric}, inplace=True)

    # Select the columns to display, including non-metric keys
    non_metric_keys = [col for col in grouped_df.columns.levels[0] if col not in metrics and col != 'count' and col != 'index']
    formatted_df = grouped_df[non_metric_keys + [f'{metric}' for metric in metrics] + ['count']]
    formatted_df.columns = non_metric_keys + metrics + ['record_count']
    # Print the markdown table
    markdown_table = tabulate(formatted_df, headers='keys', tablefmt='pipe', showindex=False)
    print(markdown_table)
    

# Review 1-1: Large Dataset

First, check normal esm-t33-gearnet. There is some hyperparameter search.

In [9]:
patp1930_df = analyze_metrics(
    'atpbind3d-1930_esm-t33_stats.csv',
    ['model_key'],
    sort_by=('mcc', 'mean'),
    ascending=False
)
print_markdown_table(patp1930_df)
patp1930_df

| model_key   | mcc           | sensitivity   | precision     | micro_auroc   | micro_auprc   |   record_count |
|:------------|:--------------|:--------------|:--------------|:--------------|:--------------|---------------:|
| esm-t33     | 0.701 ± 0.007 | 0.628 ± 0.015 | 0.810 ± 0.011 | 0.949 ± 0.002 | 0.735 ± 0.003 |              5 |


Unnamed: 0,model_key,count,mcc,sensitivity,precision,micro_auroc,micro_auprc
,,,,,,,
0.0,esm-t33,5.0,0.701 ± 0.007,0.628 ± 0.015,0.810 ± 0.011,0.949 ± 0.002,0.735 ± 0.003


In [10]:
patp1930_df = analyze_metrics(
    'atpbind3d-1930_esm-t33-gearnet_stats.csv',
    ['model_key', 'model_kwargs.lm_freeze_layer_count', 'max_slice_length'],
    sort_by=('mcc', 'mean'),
    ascending=False
)
print_markdown_table(patp1930_df)
patp1930_df

|   max_slice_length | model_key           |   model_kwargs.lm_freeze_layer_count | mcc           | sensitivity   | precision     | micro_auroc   | micro_auprc   |   record_count |
|-------------------:|:--------------------|-------------------------------------:|:--------------|:--------------|:--------------|:--------------|:--------------|---------------:|
|                600 | esm-t33-gearnet     |                                   28 | 0.727 ± 0.013 | 0.663 ± 0.011 | 0.821 ± 0.017 | 0.936 ± 0.003 | 0.750 ± 0.008 |              5 |
|                300 | esm-t33-gearnet     |                                   30 | 0.722 ± nan   | 0.639 ± nan   | 0.839 ± nan   | 0.932 ± nan   | 0.743 ± nan   |              1 |
|                500 | esm-t33-gearnet     |                                   30 | 0.721 ± 0.008 | 0.668 ± 0.015 | 0.803 ± 0.017 | 0.943 ± 0.005 | 0.745 ± 0.010 |              5 |
|                500 | esm-t33-gearnet     |                                   28 | 0.720 ± 0.0

Unnamed: 0,model_key,model_kwargs.lm_freeze_layer_count,max_slice_length,count,mcc,sensitivity,precision,micro_auroc,micro_auprc
,,,,,,,,,
1.0,esm-t33-gearnet,28.0,600.0,5.0,0.727 ± 0.013,0.663 ± 0.011,0.821 ± 0.017,0.936 ± 0.003,0.750 ± 0.008
2.0,esm-t33-gearnet,30.0,300.0,1.0,0.722 ± nan,0.639 ± nan,0.839 ± nan,0.932 ± nan,0.743 ± nan
4.0,esm-t33-gearnet,30.0,500.0,5.0,0.721 ± 0.008,0.668 ± 0.015,0.803 ± 0.017,0.943 ± 0.005,0.745 ± 0.010
0.0,esm-t33-gearnet,28.0,500.0,5.0,0.720 ± 0.007,0.666 ± 0.021,0.804 ± 0.020,0.939 ± 0.003,0.748 ± 0.008
5.0,esm-t33-gearnet,30.0,600.0,5.0,0.716 ± 0.006,0.638 ± 0.024,0.830 ± 0.026,0.938 ± 0.006,0.751 ± 0.011
3.0,esm-t33-gearnet,30.0,400.0,1.0,0.700 ± nan,0.623 ± nan,0.812 ± nan,0.933 ± nan,0.730 ± nan
6.0,esm-t33-gearnet-bs2,30.0,300.0,1.0,0.668 ± nan,0.596 ± nan,0.778 ± nan,0.946 ± nan,0.708 ± nan


In [11]:
patp1930_df = analyze_metrics(
    'atpbind3d-1930_esm-t33-gearnet-resiboost_stats.csv',
    ['boost_negative_use_ratio', 'boost_mask_positive'],
    sort_by=('mcc', 'mean'),
    filter_conditions={'boost_negative_use_ratio': [0.1]},
    ascending=False
)
print_markdown_table(patp1930_df)
patp1930_df

| boost_mask_positive   |   boost_negative_use_ratio | mcc           | sensitivity   | precision     | micro_auroc   | micro_auprc   |   record_count |
|:----------------------|---------------------------:|:--------------|:--------------|:--------------|:--------------|:--------------|---------------:|
| False                 |                        0.1 | 0.736 ± 0.005 | 0.685 ± 0.008 | 0.815 ± 0.017 | 0.924 ± 0.007 | 0.757 ± 0.007 |              5 |
| True                  |                        0.1 | 0.732 ± 0.004 | 0.673 ± 0.016 | 0.820 ± 0.011 | 0.926 ± 0.003 | 0.752 ± 0.006 |              5 |


Unnamed: 0,boost_negative_use_ratio,boost_mask_positive,count,mcc,sensitivity,precision,micro_auroc,micro_auprc
,,,,,,,,
0.0,0.1,False,5.0,0.736 ± 0.005,0.685 ± 0.008,0.815 ± 0.017,0.924 ± 0.007,0.757 ± 0.007
1.0,0.1,True,5.0,0.732 ± 0.004,0.673 ± 0.016,0.820 ± 0.011,0.926 ± 0.003,0.752 ± 0.006


# Review 1-2: Hyperparameter search

In [12]:
df = analyze_metrics('atpbind3d_esm-t33-gearnet_stats.csv', 
                'model_kwargs.lm_freeze_layer_count', 
                filter_conditions={'padding': 50, 'max_slice_length': [500],},
                sort_by='model_kwargs.lm_freeze_layer_count',
                ascending=True
)
print_markdown_table(df)
df

|   model_kwargs.lm_freeze_layer_count | mcc           | sensitivity   | precision     | micro_auroc   | micro_auprc   |   record_count |
|-------------------------------------:|:--------------|:--------------|:--------------|:--------------|:--------------|---------------:|
|                                   27 | 0.673 ± 0.015 | 0.626 ± 0.023 | 0.752 ± 0.019 | 0.923 ± 0.010 | 0.696 ± 0.017 |              5 |
|                                   28 | 0.677 ± 0.017 | 0.619 ± 0.026 | 0.769 ± 0.022 | 0.924 ± 0.009 | 0.691 ± 0.011 |              5 |
|                                   29 | 0.670 ± 0.007 | 0.611 ± 0.013 | 0.763 ± 0.008 | 0.915 ± 0.004 | 0.675 ± 0.006 |              5 |
|                                   30 | 0.675 ± 0.015 | 0.632 ± 0.021 | 0.751 ± 0.023 | 0.922 ± 0.008 | 0.689 ± 0.011 |             10 |
|                                   31 | 0.668 ± 0.013 | 0.615 ± 0.022 | 0.755 ± 0.033 | 0.927 ± 0.004 | 0.686 ± 0.005 |              5 |
|                                 

Unnamed: 0,model_kwargs.lm_freeze_layer_count,count,mcc,sensitivity,precision,micro_auroc,micro_auprc
,,,,,,,
0.0,27.0,5.0,0.673 ± 0.015,0.626 ± 0.023,0.752 ± 0.019,0.923 ± 0.010,0.696 ± 0.017
1.0,28.0,5.0,0.677 ± 0.017,0.619 ± 0.026,0.769 ± 0.022,0.924 ± 0.009,0.691 ± 0.011
2.0,29.0,5.0,0.670 ± 0.007,0.611 ± 0.013,0.763 ± 0.008,0.915 ± 0.004,0.675 ± 0.006
3.0,30.0,10.0,0.675 ± 0.015,0.632 ± 0.021,0.751 ± 0.023,0.922 ± 0.008,0.689 ± 0.011
4.0,31.0,5.0,0.668 ± 0.013,0.615 ± 0.022,0.755 ± 0.033,0.927 ± 0.004,0.686 ± 0.005
5.0,32.0,5.0,0.652 ± 0.019,0.578 ± 0.028,0.766 ± 0.022,0.918 ± 0.003,0.670 ± 0.014
6.0,33.0,5.0,0.604 ± 0.016,0.564 ± 0.034,0.683 ± 0.037,0.932 ± 0.004,0.648 ± 0.017


In [19]:
df = analyze_metrics('atpbind3d_esm-t33-gearnet_stats.csv',
                     'max_slice_length',
                     filter_conditions={'padding': 50,
                                        'model_kwargs.lm_freeze_layer_count': [30],
                                        'valid_fold': [0,1,2,3,4],
                                        'max_slice_length': [300, 400, 500, 600, 700, 800],
                                        },
                     sort_by='max_slice_length',
                     ascending=True
                     )
print_markdown_table(df)
df

|   max_slice_length | mcc           | sensitivity   | precision     | micro_auroc   | micro_auprc   |   record_count |
|-------------------:|:--------------|:--------------|:--------------|:--------------|:--------------|---------------:|
|                300 | 0.658 ± 0.008 | 0.619 ± 0.021 | 0.729 ± 0.019 | 0.908 ± 0.007 | 0.667 ± 0.010 |             10 |
|                400 | 0.666 ± 0.016 | 0.619 ± 0.022 | 0.746 ± 0.038 | 0.916 ± 0.008 | 0.679 ± 0.011 |             10 |
|                500 | 0.675 ± 0.015 | 0.632 ± 0.021 | 0.751 ± 0.023 | 0.922 ± 0.008 | 0.689 ± 0.011 |             10 |
|                600 | 0.666 ± 0.014 | 0.625 ± 0.042 | 0.743 ± 0.046 | 0.923 ± 0.009 | 0.686 ± 0.011 |             10 |
|                700 | 0.666 ± 0.016 | 0.617 ± 0.031 | 0.750 ± 0.021 | 0.919 ± 0.006 | 0.678 ± 0.014 |             10 |
|                800 | 0.664 ± 0.011 | 0.626 ± 0.030 | 0.736 ± 0.021 | 0.919 ± 0.009 | 0.677 ± 0.015 |             10 |


Unnamed: 0,max_slice_length,count,mcc,sensitivity,precision,micro_auroc,micro_auprc
,,,,,,,
0.0,300.0,10.0,0.658 ± 0.008,0.619 ± 0.021,0.729 ± 0.019,0.908 ± 0.007,0.667 ± 0.010
1.0,400.0,10.0,0.666 ± 0.016,0.619 ± 0.022,0.746 ± 0.038,0.916 ± 0.008,0.679 ± 0.011
2.0,500.0,10.0,0.675 ± 0.015,0.632 ± 0.021,0.751 ± 0.023,0.922 ± 0.008,0.689 ± 0.011
3.0,600.0,10.0,0.666 ± 0.014,0.625 ± 0.042,0.743 ± 0.046,0.923 ± 0.009,0.686 ± 0.011
4.0,700.0,10.0,0.666 ± 0.016,0.617 ± 0.031,0.750 ± 0.021,0.919 ± 0.006,0.678 ± 0.014
5.0,800.0,10.0,0.664 ± 0.011,0.626 ± 0.030,0.736 ± 0.021,0.919 ± 0.009,0.677 ± 0.015


# Review 1-3: GVP

In [14]:
# All Data
df = analyze_metrics('atpbind3d_gvp_stats.csv',
                     ['cycle_size'],
                     filter_conditions={'valid_fold': [0, 1, 2, 3, 4]},
                     sort_by=('mcc', 'mean'),
                     ascending=False
                     )
print_markdown_table(df)
df

|   cycle_size | mcc           | sensitivity   | precision     | micro_auroc   | micro_auprc   |   record_count |
|-------------:|:--------------|:--------------|:--------------|:--------------|:--------------|---------------:|
|          100 | 0.509 ± 0.022 | 0.409 ± 0.032 | 0.677 ± 0.024 | 0.804 ± 0.034 | 0.484 ± 0.021 |              5 |
|           50 | 0.504 ± 0.032 | 0.418 ± 0.049 | 0.653 ± 0.034 | 0.879 ± 0.008 | 0.527 ± 0.037 |              5 |
|           20 | 0.377 ± 0.034 | 0.357 ± 0.075 | 0.464 ± 0.080 | 0.850 ± 0.015 | 0.380 ± 0.031 |              5 |


Unnamed: 0,cycle_size,count,mcc,sensitivity,precision,micro_auroc,micro_auprc
,,,,,,,
2.0,100.0,5.0,0.509 ± 0.022,0.409 ± 0.032,0.677 ± 0.024,0.804 ± 0.034,0.484 ± 0.021
1.0,50.0,5.0,0.504 ± 0.032,0.418 ± 0.049,0.653 ± 0.034,0.879 ± 0.008,0.527 ± 0.037
0.0,20.0,5.0,0.377 ± 0.034,0.357 ± 0.075,0.464 ± 0.080,0.850 ± 0.015,0.380 ± 0.031


In [15]:
# Selected Data
df = analyze_metrics('atpbind3d_bert-gvp_stats.csv',
                     ['cycle_size', 'max_lr'],
                     sort_by=('mcc', 'mean'),
                     ascending=False
                     )
print_markdown_table(df)
df

|   cycle_size |   max_lr | mcc           | sensitivity   | precision     | micro_auroc   | micro_auprc   |   record_count |
|-------------:|---------:|:--------------|:--------------|:--------------|:--------------|:--------------|---------------:|
|           20 |    0.001 | 0.538 ± 0.016 | 0.470 ± 0.030 | 0.658 ± 0.051 | 0.915 ± 0.007 | 0.563 ± 0.019 |              5 |
|           40 |    0.001 | 0.537 ± 0.014 | 0.441 ± 0.034 | 0.696 ± 0.043 | 0.895 ± 0.015 | 0.538 ± 0.023 |              5 |
|           10 |    0.001 | 0.536 ± 0.022 | 0.453 ± 0.047 | 0.678 ± 0.052 | 0.914 ± 0.011 | 0.549 ± 0.027 |              5 |
|           40 |    0.003 | 0.524 ± 0.024 | 0.431 ± 0.039 | 0.680 ± 0.028 | 0.894 ± 0.010 | 0.536 ± 0.026 |              5 |
|           20 |    0.003 | 0.505 ± 0.017 | 0.431 ± 0.051 | 0.638 ± 0.042 | 0.882 ± 0.021 | 0.512 ± 0.019 |              5 |
|           10 |    0.003 | 0.475 ± 0.015 | 0.410 ± 0.068 | 0.610 ± 0.102 | 0.897 ± 0.008 | 0.464 ± 0.020 |              5 |


Unnamed: 0,cycle_size,max_lr,count,mcc,sensitivity,precision,micro_auroc,micro_auprc
,,,,,,,,
2.0,20.0,0.001,5.0,0.538 ± 0.016,0.470 ± 0.030,0.658 ± 0.051,0.915 ± 0.007,0.563 ± 0.019
4.0,40.0,0.001,5.0,0.537 ± 0.014,0.441 ± 0.034,0.696 ± 0.043,0.895 ± 0.015,0.538 ± 0.023
0.0,10.0,0.001,5.0,0.536 ± 0.022,0.453 ± 0.047,0.678 ± 0.052,0.914 ± 0.011,0.549 ± 0.027
5.0,40.0,0.003,5.0,0.524 ± 0.024,0.431 ± 0.039,0.680 ± 0.028,0.894 ± 0.010,0.536 ± 0.026
3.0,20.0,0.003,5.0,0.505 ± 0.017,0.431 ± 0.051,0.638 ± 0.042,0.882 ± 0.021,0.512 ± 0.019
1.0,10.0,0.003,5.0,0.475 ± 0.015,0.410 ± 0.068,0.610 ± 0.102,0.897 ± 0.008,0.464 ± 0.020


In [16]:
# Selected Data
df = analyze_metrics('atpbind3d_esm-t33-gvp_stats.csv',
                     ['model_kwargs.lm_freeze_layer_count', 'model_kwargs.num_layers', 'model_kwargs.residual', 'cycle_size'],

                     filter_conditions={
                        'max_lr': [0.002],
                        'cycle_size': [10],
                        'model_kwargs.lm_freeze_layer_count': [30],
                        'model_kwargs.node_h_dim': ["(256, 16)"],
                    },
                     sort_by=('mcc', 'mean'),
                     ascending=False
                     )
print_markdown_table(df)
df

|   cycle_size |   model_kwargs.lm_freeze_layer_count |   model_kwargs.num_layers | model_kwargs.residual   | mcc           | sensitivity   | precision     | micro_auroc   | micro_auprc   |   record_count |
|-------------:|-------------------------------------:|--------------------------:|:------------------------|:--------------|:--------------|:--------------|:--------------|:--------------|---------------:|
|           10 |                                   30 |                         4 | True                    | 0.665 ± 0.011 | 0.617 ± 0.027 | 0.748 ± 0.041 | 0.938 ± 0.008 | 0.690 ± 0.018 |             10 |
|           10 |                                   30 |                         3 | True                    | 0.663 ± 0.012 | 0.626 ± 0.024 | 0.734 ± 0.022 | 0.933 ± 0.026 | 0.679 ± 0.021 |             10 |
|           10 |                                   30 |                         4 | False                   | 0.657 ± 0.012 | 0.623 ± 0.022 | 0.724 ± 0.028 | 0.933 ± 0.009 

Unnamed: 0,model_kwargs.lm_freeze_layer_count,model_kwargs.num_layers,model_kwargs.residual,cycle_size,count,mcc,sensitivity,precision,micro_auroc,micro_auprc
,,,,,,,,,,
3.0,30.0,4.0,True,10.0,10.0,0.665 ± 0.011,0.617 ± 0.027,0.748 ± 0.041,0.938 ± 0.008,0.690 ± 0.018
1.0,30.0,3.0,True,10.0,10.0,0.663 ± 0.012,0.626 ± 0.024,0.734 ± 0.022,0.933 ± 0.026,0.679 ± 0.021
2.0,30.0,4.0,False,10.0,10.0,0.657 ± 0.012,0.623 ± 0.022,0.724 ± 0.028,0.933 ± 0.009,0.675 ± 0.012
0.0,30.0,3.0,False,10.0,10.0,0.605 ± 0.195,0.610 ± 0.055,0.675 ± 0.218,0.900 ± 0.112,0.625 ± 0.200


# Review 1-4: WeightedEntropyLoss
Params to check: 
Weighted BCE:
- pos_weight_factor
Focal Loss:
- task_kwargs.criterion
- task_kwargs.focal_loss_gamma
- task_kwargs.focal_loss_alpha

In [17]:
df = analyze_metrics('atpbind3d_esm-t33-gearnet_stats.csv',
                     ['pos_weight_factor'],
                     filter_conditions={'pos_weight_factor': [0.25, 1, 4, 16]},
                     sort_by=('mcc', 'mean'),
                     ascending=False
                     )
print_markdown_table(df)
df

|   pos_weight_factor | mcc           | sensitivity   | precision     | micro_auroc   | micro_auprc   |   record_count |
|--------------------:|:--------------|:--------------|:--------------|:--------------|:--------------|---------------:|
|                4    | 0.671 ± 0.011 | 0.635 ± 0.014 | 0.740 ± 0.026 | 0.927 ± 0.011 | 0.681 ± 0.015 |              5 |
|                1    | 0.669 ± 0.012 | 0.628 ± 0.024 | 0.743 ± 0.017 | 0.927 ± 0.008 | 0.689 ± 0.017 |              5 |
|                0.25 | 0.662 ± 0.020 | 0.611 ± 0.023 | 0.749 ± 0.046 | 0.917 ± 0.004 | 0.675 ± 0.014 |              5 |
|               16    | 0.658 ± 0.004 | 0.645 ± 0.016 | 0.702 ± 0.014 | 0.931 ± 0.008 | 0.669 ± 0.012 |              5 |


Unnamed: 0,pos_weight_factor,count,mcc,sensitivity,precision,micro_auroc,micro_auprc
,,,,,,,
2.0,4.0,5.0,0.671 ± 0.011,0.635 ± 0.014,0.740 ± 0.026,0.927 ± 0.011,0.681 ± 0.015
1.0,1.0,5.0,0.669 ± 0.012,0.628 ± 0.024,0.743 ± 0.017,0.927 ± 0.008,0.689 ± 0.017
0.0,0.25,5.0,0.662 ± 0.020,0.611 ± 0.023,0.749 ± 0.046,0.917 ± 0.004,0.675 ± 0.014
3.0,16.0,5.0,0.658 ± 0.004,0.645 ± 0.016,0.702 ± 0.014,0.931 ± 0.008,0.669 ± 0.012


In [18]:
df = analyze_metrics('atpbind3d_esm-t33-gearnet_stats.csv',
                     ['task_kwargs.focal_loss_gamma', 'task_kwargs.focal_loss_alpha'],
                     {'task_kwargs.criterion': ['focal']},
                     sort_by=('mcc', 'mean'),
                     ascending=False
                     )
print_markdown_table(df)
df

|   task_kwargs.focal_loss_alpha |   task_kwargs.focal_loss_gamma | mcc           | sensitivity   | precision     | micro_auroc   | micro_auprc   |   record_count |
|-------------------------------:|-------------------------------:|:--------------|:--------------|:--------------|:--------------|:--------------|---------------:|
|                           0.25 |                              2 | 0.674 ± 0.007 | 0.625 ± 0.031 | 0.757 ± 0.023 | 0.926 ± 0.010 | 0.675 ± 0.004 |              5 |
|                           0.3  |                              1 | 0.672 ± 0.014 | 0.617 ± 0.024 | 0.762 ± 0.015 | 0.931 ± 0.008 | 0.681 ± 0.014 |              5 |
|                           0.2  |                              2 | 0.671 ± 0.011 | 0.639 ± 0.008 | 0.735 ± 0.027 | 0.929 ± 0.010 | 0.685 ± 0.013 |              5 |
|                           0.25 |                              1 | 0.670 ± 0.008 | 0.632 ± 0.011 | 0.741 ± 0.015 | 0.929 ± 0.005 | 0.690 ± 0.008 |              5 |
|         

Unnamed: 0,task_kwargs.focal_loss_gamma,task_kwargs.focal_loss_alpha,count,mcc,sensitivity,precision,micro_auroc,micro_auprc
,,,,,,,,
4.0,2.0,0.25,5.0,0.674 ± 0.007,0.625 ± 0.031,0.757 ± 0.023,0.926 ± 0.010,0.675 ± 0.004
2.0,1.0,0.3,5.0,0.672 ± 0.014,0.617 ± 0.024,0.762 ± 0.015,0.931 ± 0.008,0.681 ± 0.014
3.0,2.0,0.2,5.0,0.671 ± 0.011,0.639 ± 0.008,0.735 ± 0.027,0.929 ± 0.010,0.685 ± 0.013
1.0,1.0,0.25,5.0,0.670 ± 0.008,0.632 ± 0.011,0.741 ± 0.015,0.929 ± 0.005,0.690 ± 0.008
5.0,2.0,0.3,5.0,0.669 ± 0.011,0.627 ± 0.019,0.745 ± 0.017,0.934 ± 0.006,0.689 ± 0.019
7.0,3.0,0.25,5.0,0.667 ± 0.018,0.631 ± 0.036,0.736 ± 0.022,0.930 ± 0.011,0.684 ± 0.030
6.0,3.0,0.2,5.0,0.666 ± 0.010,0.621 ± 0.027,0.746 ± 0.023,0.934 ± 0.010,0.684 ± 0.007
8.0,3.0,0.3,5.0,0.663 ± 0.008,0.609 ± 0.025,0.751 ± 0.017,0.934 ± 0.005,0.684 ± 0.010
0.0,1.0,0.2,5.0,0.662 ± 0.019,0.621 ± 0.036,0.738 ± 0.024,0.931 ± 0.013,0.690 ± 0.020
