In [2]:
import pandas as pd
from tabulate import tabulate
def format_mean_std(row, metric):
    mean_val = row[(metric, 'mean')]
    std_val = row[(metric, 'std')]
    return f"{mean_val:.3f} ± {std_val:.3f}"


def get_stat_df(result_df, verbose=True, metrics=['mcc', 'micro_auprc', 'micro_auroc',
               'sensitivity', 'precision']):
    # Sample loading of your CSV - this would be replaced with your actual loading

    grouped = result_df.groupby('model_key').agg({
        k: ['mean', 'std'] for k in metrics
    })

    record_counts = result_df.groupby('model_key').size()

    for metric in metrics:
        grouped[f'{metric}_formatted'] = grouped.apply(
            lambda row: format_mean_std(row, metric), axis=1)

    # Add the record count column to the grouped dataframe
    grouped['record_count'] = record_counts.values

    formatted_df = grouped[[
        k + '_formatted' for k in metrics] + ['record_count']]
    formatted_df.columns = metrics + ['record_count']
    formatted_df = formatted_df.sort_values(metrics[0], ascending=True)

    formatted_df = formatted_df.reset_index()

    if verbose:
        markdown_table = tabulate(formatted_df[[
                                  'model_key'] + metrics], headers='keys', tablefmt='pipe', showindex=False)
        print(markdown_table)

    return formatted_df

In [3]:
import pandas as pd

df = pd.read_csv('atpbind3d_stats.csv')

df_gn = df[
    # ((df['model_key'].str.startswith('esm-t33-gearnet-640')) |
    (df['model_key'].str.startswith('esm-t33-gearnet'))
    # (df['model_key'] == 'esm-t33-gearnet') | 
    # (df['model_key'] == 'esm-t33')) & 
]

get_stat_df(df_gn, metrics=['mcc','micro_auprc', 'micro_auroc',
               'sensitivity', 'precision', ])

| model_key                              | mcc           | micro_auprc   | micro_auroc   | sensitivity   | precision     |
|:---------------------------------------|:--------------|:--------------|:--------------|:--------------|:--------------|
| esm-t33-gearnet                        | 0.668 ± 0.012 | 0.687 ± 0.011 | 0.924 ± 0.007 | 0.623 ± 0.023 | 0.747 ± 0.014 |
| esm-t33-gearnet-640-2                  | 0.671 ± 0.009 | 0.683 ± 0.013 | 0.923 ± 0.008 | 0.622 ± 0.023 | 0.753 ± 0.021 |
| esm-t33-gearnet-320-4                  | 0.672 ± 0.008 | 0.683 ± 0.004 | 0.920 ± 0.006 | 0.623 ± 0.024 | 0.755 ± 0.015 |
| esm-t33-gearnet-800-4                  | 0.672 ± 0.011 | 0.694 ± 0.018 | 0.932 ± 0.009 | 0.640 ± 0.033 | 0.736 ± 0.027 |
| esm-t33-gearnet-800-3                  | 0.672 ± 0.016 | 0.685 ± 0.014 | 0.925 ± 0.008 | 0.623 ± 0.024 | 0.753 ± 0.022 |
| esm-t33-gearnet-20-cycle               | 0.675 ± 0.015 | 0.695 ± 0.011 | 0.923 ± 0.004 | 0.619 ± 0.015 | 0.766 ± 0.019 |
| esm-t33-gearne

Unnamed: 0,model_key,mcc,micro_auprc,micro_auroc,sensitivity,precision,record_count
0,esm-t33-gearnet,0.668 ± 0.012,0.687 ± 0.011,0.924 ± 0.007,0.623 ± 0.023,0.747 ± 0.014,6
1,esm-t33-gearnet-640-2,0.671 ± 0.009,0.683 ± 0.013,0.923 ± 0.008,0.622 ± 0.023,0.753 ± 0.021,10
2,esm-t33-gearnet-320-4,0.672 ± 0.008,0.683 ± 0.004,0.920 ± 0.006,0.623 ± 0.024,0.755 ± 0.015,5
3,esm-t33-gearnet-800-4,0.672 ± 0.011,0.694 ± 0.018,0.932 ± 0.009,0.640 ± 0.033,0.736 ± 0.027,5
4,esm-t33-gearnet-800-3,0.672 ± 0.016,0.685 ± 0.014,0.925 ± 0.008,0.623 ± 0.024,0.753 ± 0.022,15
5,esm-t33-gearnet-20-cycle,0.675 ± 0.015,0.695 ± 0.011,0.923 ± 0.004,0.619 ± 0.015,0.766 ± 0.019,5
6,esm-t33-gearnet-960-2,0.676 ± 0.013,0.691 ± 0.015,0.926 ± 0.010,0.631 ± 0.014,0.753 ± 0.023,10
7,esm-t33-gearnet-960-1,0.676 ± 0.016,0.684 ± 0.016,0.923 ± 0.006,0.617 ± 0.030,0.769 ± 0.029,10
8,esm-t33-gearnet-640,0.677 ± 0.009,0.685 ± 0.017,0.923 ± 0.008,0.640 ± 0.014,0.745 ± 0.015,15
9,esm-t33-gearnet-800-3-adaboost-r10,0.685 ± 0.005,0.705 ± 0.003,0.913 ± 0.005,0.639 ± 0.011,0.762 ± 0.014,5


In [4]:
import pandas as pd

df = pd.read_csv('atpbind3d_stats.csv')

df.groupby('model_key').agg({'mcc': ['mean', 'std'], 'micro_auprc': [
    'mean', 'std']}).sort_values(('micro_auprc', 'mean'), ascending=False)

Unnamed: 0_level_0,mcc,mcc,micro_auprc,micro_auprc
Unnamed: 0_level_1,mean,std,mean,std
model_key,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2
esm-t33-gearnet-20-cycle-resiboost-r50,0.694733,0.011511,0.727867,0.011971
esm-t33-gearnet-adaboost-r50,0.69442,0.005516,0.7255,0.011629
esm-t33-gearnet-resiboost-r90,0.69656,0.011627,0.7248,0.007234
esm-t33-gearnet-resiboost-r50,0.69716,0.009042,0.72448,0.008971
esm-t33-gearnet-ensemble,0.69238,0.00925,0.7238,0.009424
esm-t33-gearnet-20-cycle-adaboost-r50,0.69428,0.008225,0.72322,0.006511
esm-t33-gearnet-adaboost-r80,0.69846,0.007369,0.72304,0.008454
esm-t33-gearnet-800-3-ensemble,0.69268,0.004945,0.72286,0.008355
esm-t33-gearnet-adaboost-r90,0.69486,0.006749,0.72234,0.008256
esm-t33-gearnet-resiboost-r80,0.69628,0.010193,0.72228,0.007619


In [5]:
df_final = df[
    (df['model_key'].str.startswith('esm-t33-gvp'))
]

get_stat_df(df_final)

| model_key          | mcc           | micro_auprc   | micro_auroc   | sensitivity   | precision     |
|:-------------------|:--------------|:--------------|:--------------|:--------------|:--------------|
| esm-t33-gvp-c100   | 0.457 ± 0.182 | 0.433 ± 0.200 | 0.860 ± 0.050 | 0.484 ± 0.036 | 0.542 ± 0.322 |
| esm-t33-gvp-c10-v1 | 0.483 ± 0.256 | 0.476 ± 0.286 | 0.868 ± 0.077 | 0.540 ± 0.102 | 0.522 ± 0.335 |
| esm-t33-gvp-c50    | 0.535 ± 0.199 | 0.557 ± 0.233 | 0.895 ± 0.070 | 0.570 ± 0.112 | 0.633 ± 0.301 |
| esm-t33-gvp-c10-v2 | 0.573 ± 0.219 | 0.565 ± 0.243 | 0.893 ± 0.072 | 0.613 ± 0.029 | 0.624 ± 0.285 |
| esm-t33-gvp-c20    | 0.645 ± 0.048 | 0.675 ± 0.042 | 0.936 ± 0.012 | 0.583 ± 0.042 | 0.747 ± 0.065 |
| esm-t33-gvp-c10    | 0.676 ± 0.019 | 0.679 ± 0.028 | 0.926 ± 0.009 | 0.614 ± 0.019 | 0.772 ± 0.035 |


Unnamed: 0,model_key,mcc,micro_auprc,micro_auroc,sensitivity,precision,record_count
0,esm-t33-gvp-c100,0.457 ± 0.182,0.433 ± 0.200,0.860 ± 0.050,0.484 ± 0.036,0.542 ± 0.322,5
1,esm-t33-gvp-c10-v1,0.483 ± 0.256,0.476 ± 0.286,0.868 ± 0.077,0.540 ± 0.102,0.522 ± 0.335,5
2,esm-t33-gvp-c50,0.535 ± 0.199,0.557 ± 0.233,0.895 ± 0.070,0.570 ± 0.112,0.633 ± 0.301,5
3,esm-t33-gvp-c10-v2,0.573 ± 0.219,0.565 ± 0.243,0.893 ± 0.072,0.613 ± 0.029,0.624 ± 0.285,5
4,esm-t33-gvp-c20,0.645 ± 0.048,0.675 ± 0.042,0.936 ± 0.012,0.583 ± 0.042,0.747 ± 0.065,5
5,esm-t33-gvp-c10,0.676 ± 0.019,0.679 ± 0.028,0.926 ± 0.009,0.614 ± 0.019,0.772 ± 0.035,5


In [6]:
df_final = df[
    (df['model_key'].str.startswith('esm-t33-gearnet') & df['model_key'].str.endswith('r50') &
     ~df['model_key'].str.startswith('esm-t33-gearnet-800') & 
     ~df['model_key'].str.startswith('esm-t33-gearnet-320') & 
     ~df['model_key'].str.startswith('esm-t33-gearnet-640') & 
     ~df['model_key'].str.startswith('esm-t33-gearnet-960')) |
    (df['model_key'] == 'esm-t33-gearnet') |
    (df['model_key'].str.startswith('esm-t33-gvp')) |
    (df['model_key'] == 'esm-t33-gearnet-20-cycle') |
    (df['model_key'].str.startswith('gearnet')) |
    (df['model_key'] == 'bert') |
    (df['model_key'] == 'bert-gearnet') |
    (df['model_key'] == 'esm-t33') |
    (df['model_key'].str.startswith('gvp'))
]

get_stat_df(df_final)

| model_key                              | mcc           | micro_auprc   | micro_auroc   | sensitivity   | precision     |
|:---------------------------------------|:--------------|:--------------|:--------------|:--------------|:--------------|
| gvp                                    | 0.236 ± 0.026 | 0.229 ± 0.032 | 0.796 ± 0.011 | 0.405 ± 0.091 | 0.205 ± 0.042 |
| gvp-20-cycle                           | 0.359 ± 0.039 | 0.363 ± 0.030 | 0.851 ± 0.008 | 0.317 ± 0.075 | 0.483 ± 0.129 |
| esm-t33-gvp-c100                       | 0.457 ± 0.182 | 0.433 ± 0.200 | 0.860 ± 0.050 | 0.484 ± 0.036 | 0.542 ± 0.322 |
| gearnet                                | 0.470 ± 0.012 | 0.490 ± 0.012 | 0.879 ± 0.009 | 0.385 ± 0.057 | 0.627 ± 0.065 |
| bert                                   | 0.481 ± 0.031 | 0.467 ± 0.014 | 0.871 ± 0.007 | 0.399 ± 0.014 | 0.626 ± 0.061 |
| gearnet-50-cycle                       | 0.483 ± 0.019 | 0.499 ± 0.015 | 0.878 ± 0.004 | 0.373 ± 0.015 | 0.671 ± 0.057 |
| esm-t33-gvp-c1

Unnamed: 0,model_key,mcc,micro_auprc,micro_auroc,sensitivity,precision,record_count
0,gvp,0.236 ± 0.026,0.229 ± 0.032,0.796 ± 0.011,0.405 ± 0.091,0.205 ± 0.042,5
1,gvp-20-cycle,0.359 ± 0.039,0.363 ± 0.030,0.851 ± 0.008,0.317 ± 0.075,0.483 ± 0.129,5
2,esm-t33-gvp-c100,0.457 ± 0.182,0.433 ± 0.200,0.860 ± 0.050,0.484 ± 0.036,0.542 ± 0.322,5
3,gearnet,0.470 ± 0.012,0.490 ± 0.012,0.879 ± 0.009,0.385 ± 0.057,0.627 ± 0.065,5
4,bert,0.481 ± 0.031,0.467 ± 0.014,0.871 ± 0.007,0.399 ± 0.014,0.626 ± 0.061,5
5,gearnet-50-cycle,0.483 ± 0.019,0.499 ± 0.015,0.878 ± 0.004,0.373 ± 0.015,0.671 ± 0.057,5
6,esm-t33-gvp-c10-v1,0.483 ± 0.256,0.476 ± 0.286,0.868 ± 0.077,0.540 ± 0.102,0.522 ± 0.335,5
7,gearnet-100-cycle,0.488 ± 0.021,0.512 ± 0.021,0.882 ± 0.009,0.370 ± 0.015,0.688 ± 0.035,5
8,gearnet-20-cycle,0.490 ± 0.029,0.508 ± 0.021,0.883 ± 0.008,0.389 ± 0.049,0.663 ± 0.045,5
9,gvp-50-cycle,0.503 ± 0.023,0.526 ± 0.034,0.889 ± 0.008,0.449 ± 0.043,0.609 ± 0.034,5


In [7]:
# Group by hyperparameters and calculate average metrics
df = pd.read_csv('atpbind3d_esm-t33-gvp_stats.csv')

# Define the hyperparameters to group by
hyperparameters = ['model_kwargs.lm_freeze_layer_count', 'model_kwargs.node_h_dim', 
                   'model_kwargs.num_layers', 'model_kwargs.residual', 'cycle_size', 'max_lr']

# Define the metrics to average
metrics = ['mcc', 'micro_auprc', 'micro_auroc', 'sensitivity', 'precision']

# Group by hyperparameters and calculate mean of metrics
grouped_df = df.groupby(hyperparameters)[metrics].agg(['mean', 'std']).reset_index()

# Sort by MCC (you can change this to any other metric)
grouped_df = grouped_df.sort_values(('mcc', 'mean'), ascending=False)

# Display the results
print("Average metrics grouped by hyperparameters:")
grouped_df


Average metrics grouped by hyperparameters:


Unnamed: 0_level_0,model_kwargs.lm_freeze_layer_count,model_kwargs.node_h_dim,model_kwargs.num_layers,model_kwargs.residual,cycle_size,max_lr,mcc,mcc,micro_auprc,micro_auprc,micro_auroc,micro_auroc,sensitivity,sensitivity,precision,precision
Unnamed: 0_level_1,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,mean,std,mean,std,mean,std,mean,std,mean,std
2,30,"(256, 16)",3,False,20,0.002,0.6748,0.014721,0.69068,0.017769,0.921,0.011412,0.62138,0.018788,0.76192,0.016466
27,30,"(512, 16)",4,True,20,0.002,0.66924,0.01741,0.6964,0.008604,0.93,0.016658,0.61664,0.027383,0.75606,0.013958
4,30,"(256, 16)",3,True,10,0.002,0.66762,0.012302,0.68628,0.008423,0.94252,0.005644,0.61988,0.023457,0.74902,0.017671
0,30,"(256, 16)",3,False,10,0.002,0.66616,0.022167,0.68806,0.023196,0.93206,0.00651,0.62226,0.04757,0.74448,0.028639
6,30,"(256, 16)",3,True,20,0.002,0.66608,0.012708,0.67666,0.021485,0.93404,0.007257,0.6184,0.022505,0.74748,0.012786
20,30,"(512, 16)",3,True,10,0.002,0.666,0.011365,0.68596,0.00869,0.9345,0.010743,0.63024,0.020691,0.73388,0.016475
22,30,"(512, 16)",3,True,20,0.002,0.66586,0.018281,0.6672,0.012739,0.9129,0.023342,0.60534,0.030487,0.76244,0.010941
39,31,"(512, 16)",3,True,20,0.002,0.66456,0.013554,0.67296,0.020929,0.91418,0.038265,0.61336,0.027262,0.7505,0.023483
12,30,"(256, 16)",4,True,10,0.002,0.66392,0.008846,0.6931,0.019136,0.94198,0.005701,0.61394,0.0309,0.74968,0.047116
43,31,"(512, 16)",4,True,20,0.002,0.66378,0.005257,0.68016,0.012995,0.93078,0.009352,0.59762,0.034352,0.7686,0.037765


### ESM-T33-GEARNET

In [8]:
# Group by hyperparameters and calculate average metrics
df = pd.read_csv('atpbind3d_esm-t33-gearnet_stats.csv')

# Define the hyperparameters to group by
hyperparameters = ['model_kwargs.lm_freeze_layer_count', 'max_slice_length', 'padding']

# Define the metrics to average
metrics = ['mcc', 'micro_auprc', 'micro_auroc', 'sensitivity', 'precision']

# Group by hyperparameters and calculate mean of metrics
grouped_df = df.groupby(hyperparameters)[metrics].agg(
    ['mean', 'std']).reset_index()

# Sort by MCC (you can change this to any other metric)
grouped_df = grouped_df.sort_values(('mcc', 'mean'), ascending=False)

# Display the results
print("Average metrics grouped by hyperparameters:")
grouped_df

Average metrics grouped by hyperparameters:


Unnamed: 0_level_0,model_kwargs.lm_freeze_layer_count,max_slice_length,padding,mcc,mcc,micro_auprc,micro_auprc,micro_auroc,micro_auroc,sensitivity,sensitivity,precision,precision
Unnamed: 0_level_1,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,mean,std,mean,std,mean,std,mean,std,mean,std
16,28.0,600.0,50.0,0.680100,0.009383,0.695340,0.011962,0.918480,0.006943,0.633220,0.017891,0.759620,0.030512
37,29.0,600.0,75.0,0.679940,0.007546,0.683640,0.015803,0.920280,0.010449,0.619580,0.010406,0.774800,0.020189
2,27.0,600.0,50.0,0.678500,0.007970,0.689280,0.010896,0.914520,0.005949,0.625800,0.018640,0.764820,0.027292
12,28.0,500.0,50.0,0.676800,0.016594,0.690680,0.011410,0.924160,0.008964,0.619280,0.025742,0.768760,0.022014
31,29.0,500.0,25.0,0.675780,0.016257,0.689780,0.010991,0.922800,0.004794,0.644820,0.011920,0.737440,0.025490
...,...,...,...,...,...,...,...,...,...,...,...,...,...
99,32.0,700.0,25.0,0.620100,,0.646600,,0.909700,,0.563800,,0.715600,
95,32.0,600.0,25.0,0.615800,,0.633500,,0.906300,,0.556400,,0.715600,
104,33.0,500.0,50.0,0.607875,0.014409,0.648975,0.019381,0.932825,0.004376,0.568250,0.037556,0.686800,0.041319
105,33.0,600.0,50.0,0.601850,0.035603,0.635775,0.029424,0.932350,0.006203,0.577875,0.015034,0.663875,0.066945


In [9]:
def analyze_metrics(csv_path, group_by_param, filter_conditions={}, sort_by=('mcc', 'mean'), ascending=False):
    """
    Analyze metrics by grouping on specified hyperparameters and filtering the DataFrame.

    Parameters:
    - group_by_param: str or list, the hyperparameter(s) to group by.
    - filter_conditions: dict, conditions to filter the DataFrame.

    Returns:
    - grouped_df: DataFrame, the grouped and averaged metrics.
    """
    import pandas as pd
    df = pd.read_csv(csv_path)

    # Apply filter conditions
    for column, value in filter_conditions.items():
        if isinstance(value, list):
            df = df[df[column].isin(value)]
        else:
            df = df[df[column] == value]

    # Define the hyperparameters to group by
    hyperparameters = group_by_param if isinstance(group_by_param, list) else [group_by_param]

    # Define the metrics to average
    metrics = ['mcc', 'micro_auprc', 'sensitivity', 'precision']

    # Group by hyperparameters and calculate mean of metrics, including a single count column
    grouped_df = df.groupby(hyperparameters)[metrics].agg(['mean', 'std']).reset_index()
    grouped_df['count'] = df.groupby(hyperparameters).size().values

    # Round the final metrics to the fourth digit
    for metric in metrics:
        grouped_df[(metric, 'mean')] = grouped_df[(metric, 'mean')].round(4)
        grouped_df[(metric, 'std')] = grouped_df[(metric, 'std')].round(4)

    # Sort by MCC (you can change this to any other metric)
    grouped_df = grouped_df.sort_values(sort_by, ascending=ascending)

    # Display the results
    return grouped_df

def print_markdown_table(grouped_df, metrics=['mcc', 'micro_auprc', 'sensitivity', 'precision']):
    """
    Print the markdown formatted table from the grouped DataFrame.

    Parameters:
    - grouped_df: DataFrame, the grouped and averaged metrics.
    - metrics: list, the metrics to include in the table.
    """
    from tabulate import tabulate

    # Format the metrics with mean and std
    for metric in metrics:
        grouped_df[f'{metric}_formatted'] = grouped_df.apply(
            lambda row: f"{row[(metric, 'mean')]:.3f} ± {row[(metric, 'std')]:.3f}", axis=1)
    # Drop the original metric columns
    for metric in metrics:
        grouped_df.drop(columns=[(metric, 'mean'), (metric, 'std')], inplace=True)
    
    # Rename the formatted metric columns to the original metric names
    for metric in metrics:
        grouped_df.rename(columns={f'{metric}_formatted': metric}, inplace=True)

    # Select the columns to display, including non-metric keys
    non_metric_keys = [col for col in grouped_df.columns.levels[0] if col not in metrics and col != 'count' and col != 'index']
    formatted_df = grouped_df[non_metric_keys + [f'{metric}' for metric in metrics] + ['count']]
    formatted_df.columns = non_metric_keys + metrics + ['record_count']
    # Print the markdown table
    markdown_table = tabulate(formatted_df, headers='keys', tablefmt='pipe', showindex=False)
    print(markdown_table)
    

# Review 1-1: Large Dataset

First, check normal esm-t33-gearnet. There is some hyperparameter search.

In [10]:
patp1930_df = analyze_metrics(
    'atpbind3d-1930_esm-t33-gearnet_stats.csv',
    ['model_kwargs.lm_freeze_layer_count', 'max_slice_length'],
    sort_by=('mcc', 'mean'),
    ascending=False
)
print_markdown_table(patp1930_df)
patp1930_df

|   max_slice_length |   model_kwargs.lm_freeze_layer_count | mcc           | micro_auprc   | sensitivity   | precision     |   record_count |
|-------------------:|-------------------------------------:|:--------------|:--------------|:--------------|:--------------|---------------:|
|                600 |                                   28 | 0.727 ± 0.013 | 0.750 ± 0.008 | 0.663 ± 0.011 | 0.821 ± 0.017 |              5 |
|                500 |                                   30 | 0.721 ± 0.008 | 0.745 ± 0.010 | 0.668 ± 0.015 | 0.803 ± 0.017 |              5 |
|                500 |                                   28 | 0.720 ± 0.007 | 0.748 ± 0.008 | 0.666 ± 0.021 | 0.804 ± 0.020 |              5 |
|                600 |                                   30 | 0.716 ± 0.006 | 0.751 ± 0.011 | 0.638 ± 0.024 | 0.830 ± 0.026 |              5 |
|                400 |                                   30 | 0.700 ± nan   | 0.730 ± nan   | 0.623 ± nan   | 0.812 ± nan   |              1 |

Unnamed: 0,model_kwargs.lm_freeze_layer_count,max_slice_length,count,mcc,micro_auprc,sensitivity,precision
,,,,,,,
1.0,28.0,600.0,5.0,0.727 ± 0.013,0.750 ± 0.008,0.663 ± 0.011,0.821 ± 0.017
4.0,30.0,500.0,5.0,0.721 ± 0.008,0.745 ± 0.010,0.668 ± 0.015,0.803 ± 0.017
0.0,28.0,500.0,5.0,0.720 ± 0.007,0.748 ± 0.008,0.666 ± 0.021,0.804 ± 0.020
5.0,30.0,600.0,5.0,0.716 ± 0.006,0.751 ± 0.011,0.638 ± 0.024,0.830 ± 0.026
3.0,30.0,400.0,1.0,0.700 ± nan,0.730 ± nan,0.623 ± nan,0.812 ± nan
2.0,30.0,300.0,2.0,0.695 ± 0.038,0.726 ± 0.024,0.618 ± 0.030,0.808 ± 0.043


In [11]:
patp1930_df = analyze_metrics(
    'atpbind3d-1930_esm-t33-gearnet-resiboost_stats.csv',
    ['boost_negative_use_ratio', 'boost_mask_positive'],
    sort_by=('mcc', 'mean'),
    ascending=False
)
print_markdown_table(patp1930_df)
patp1930_df

| boost_mask_positive   |   boost_negative_use_ratio | mcc           | micro_auprc   | sensitivity   | precision     |   record_count |
|:----------------------|---------------------------:|:--------------|:--------------|:--------------|:--------------|---------------:|
| True                  |                        0.5 | 0.733 ± 0.006 | 0.777 ± 0.005 | 0.674 ± 0.017 | 0.821 ± 0.021 |              5 |
| False                 |                        0.5 | 0.731 ± 0.004 | 0.775 ± 0.004 | 0.665 ± 0.013 | 0.829 ± 0.014 |              5 |


Unnamed: 0,boost_negative_use_ratio,boost_mask_positive,count,mcc,micro_auprc,sensitivity,precision
,,,,,,,
1.0,0.5,True,5.0,0.733 ± 0.006,0.777 ± 0.005,0.674 ± 0.017,0.821 ± 0.021
0.0,0.5,False,5.0,0.731 ± 0.004,0.775 ± 0.004,0.665 ± 0.013,0.829 ± 0.014


# Review 1-2: Hyperparameter search

In [31]:
df = analyze_metrics('atpbind3d_esm-t33-gearnet_stats.csv', 
                'model_kwargs.lm_freeze_layer_count', 
                filter_conditions={'padding': 50, 'max_slice_length': [500],},
                sort_by='model_kwargs.lm_freeze_layer_count',
                ascending=True
)
print_markdown_table(df)
df

|   model_kwargs.lm_freeze_layer_count | mcc           | micro_auprc   | sensitivity   | precision     |   record_count |
|-------------------------------------:|:--------------|:--------------|:--------------|:--------------|---------------:|
|                                   27 | 0.673 ± 0.015 | 0.696 ± 0.017 | 0.626 ± 0.023 | 0.752 ± 0.019 |              5 |
|                                   28 | 0.677 ± 0.017 | 0.691 ± 0.011 | 0.619 ± 0.026 | 0.769 ± 0.022 |              5 |
|                                   29 | 0.670 ± 0.007 | 0.675 ± 0.006 | 0.611 ± 0.013 | 0.763 ± 0.008 |              5 |
|                                   30 | 0.675 ± 0.009 | 0.686 ± 0.007 | 0.628 ± 0.013 | 0.755 ± 0.028 |              5 |
|                                   31 | 0.668 ± 0.013 | 0.686 ± 0.005 | 0.615 ± 0.022 | 0.755 ± 0.033 |              5 |
|                                   32 | 0.652 ± 0.019 | 0.670 ± 0.014 | 0.578 ± 0.028 | 0.766 ± 0.022 |              5 |
|                       

Unnamed: 0,model_kwargs.lm_freeze_layer_count,count,mcc,micro_auprc,sensitivity,precision
,,,,,,
0.0,27.0,5.0,0.673 ± 0.015,0.696 ± 0.017,0.626 ± 0.023,0.752 ± 0.019
1.0,28.0,5.0,0.677 ± 0.017,0.691 ± 0.011,0.619 ± 0.026,0.769 ± 0.022
2.0,29.0,5.0,0.670 ± 0.007,0.675 ± 0.006,0.611 ± 0.013,0.763 ± 0.008
3.0,30.0,5.0,0.675 ± 0.009,0.686 ± 0.007,0.628 ± 0.013,0.755 ± 0.028
4.0,31.0,5.0,0.668 ± 0.013,0.686 ± 0.005,0.615 ± 0.022,0.755 ± 0.033
5.0,32.0,5.0,0.652 ± 0.019,0.670 ± 0.014,0.578 ± 0.028,0.766 ± 0.022
6.0,33.0,4.0,0.608 ± 0.014,0.649 ± 0.019,0.568 ± 0.038,0.687 ± 0.041


In [23]:
df = analyze_metrics('atpbind3d_esm-t33-gearnet_stats.csv',
                     'max_slice_length',
                     filter_conditions={'padding': 50,
                                        'model_kwargs.lm_freeze_layer_count': [30],
                                        'valid_fold': [0,1,2,3,4],
                                        },
                     sort_by='max_slice_length',
                     ascending=True
                     )
print_markdown_table(df)
df

|   max_slice_length | mcc           | micro_auprc   | sensitivity   | precision     |   record_count |
|-------------------:|:--------------|:--------------|:--------------|:--------------|---------------:|
|                300 | 0.650 ± nan   | 0.662 ± nan   | 0.583 ± nan   | 0.756 ± nan   |              1 |
|                400 | 0.673 ± 0.011 | 0.682 ± 0.003 | 0.619 ± 0.026 | 0.761 ± 0.033 |              5 |
|                500 | 0.675 ± 0.009 | 0.686 ± 0.007 | 0.628 ± 0.013 | 0.755 ± 0.028 |              5 |
|                600 | 0.673 ± 0.009 | 0.681 ± 0.008 | 0.614 ± 0.020 | 0.768 ± 0.019 |              5 |
|                700 | 0.640 ± nan   | 0.659 ± nan   | 0.592 ± nan   | 0.724 ± nan   |              1 |


Unnamed: 0,max_slice_length,count,mcc,micro_auprc,sensitivity,precision
,,,,,,
0.0,300.0,1.0,0.650 ± nan,0.662 ± nan,0.583 ± nan,0.756 ± nan
1.0,400.0,5.0,0.673 ± 0.011,0.682 ± 0.003,0.619 ± 0.026,0.761 ± 0.033
2.0,500.0,5.0,0.675 ± 0.009,0.686 ± 0.007,0.628 ± 0.013,0.755 ± 0.028
3.0,600.0,5.0,0.673 ± 0.009,0.681 ± 0.008,0.614 ± 0.020,0.768 ± 0.019
4.0,700.0,1.0,0.640 ± nan,0.659 ± nan,0.592 ± nan,0.724 ± nan


# Review 1-3: GVP

In [28]:
df = analyze_metrics('atpbind3d_esm-t33-gvp_stats.csv',
                     ['model_kwargs.lm_freeze_layer_count', 'model_kwargs.node_h_dim', 'model_kwargs.num_layers', 'model_kwargs.residual', 'cycle_size', 'max_lr'],
                     filter_conditions={'valid_fold': [0, 1, 2, 3, 4]},
                     sort_by=('mcc', 'mean'),
                     ascending=False
                     )
print_markdown_table(df)
df

|   cycle_size |   max_lr |   model_kwargs.lm_freeze_layer_count | model_kwargs.node_h_dim   |   model_kwargs.num_layers | model_kwargs.residual   | mcc           | micro_auprc   | sensitivity   | precision     |   record_count |
|-------------:|---------:|-------------------------------------:|:--------------------------|--------------------------:|:------------------------|:--------------|:--------------|:--------------|:--------------|---------------:|
|           20 |    0.002 |                                   30 | (256, 16)                 |                         3 | False                   | 0.675 ± 0.015 | 0.691 ± 0.018 | 0.621 ± 0.019 | 0.762 ± 0.017 |              5 |
|           20 |    0.002 |                                   30 | (512, 16)                 |                         4 | True                    | 0.669 ± 0.017 | 0.696 ± 0.009 | 0.617 ± 0.027 | 0.756 ± 0.014 |              5 |
|           10 |    0.002 |                                   30 | (256, 16)    

Unnamed: 0,model_kwargs.lm_freeze_layer_count,model_kwargs.node_h_dim,model_kwargs.num_layers,model_kwargs.residual,cycle_size,max_lr,count,mcc,micro_auprc,sensitivity,precision
,,,,,,,,,,,
2.0,30.0,"(256, 16)",3.0,False,20.0,0.002,5.0,0.675 ± 0.015,0.691 ± 0.018,0.621 ± 0.019,0.762 ± 0.017
27.0,30.0,"(512, 16)",4.0,True,20.0,0.002,5.0,0.669 ± 0.017,0.696 ± 0.009,0.617 ± 0.027,0.756 ± 0.014
4.0,30.0,"(256, 16)",3.0,True,10.0,0.002,5.0,0.668 ± 0.012,0.686 ± 0.008,0.620 ± 0.024,0.749 ± 0.018
0.0,30.0,"(256, 16)",3.0,False,10.0,0.002,5.0,0.666 ± 0.022,0.688 ± 0.023,0.622 ± 0.048,0.745 ± 0.029
6.0,30.0,"(256, 16)",3.0,True,20.0,0.002,5.0,0.666 ± 0.013,0.677 ± 0.021,0.618 ± 0.022,0.748 ± 0.013
20.0,30.0,"(512, 16)",3.0,True,10.0,0.002,5.0,0.666 ± 0.011,0.686 ± 0.009,0.630 ± 0.021,0.734 ± 0.017
22.0,30.0,"(512, 16)",3.0,True,20.0,0.002,5.0,0.666 ± 0.018,0.667 ± 0.013,0.605 ± 0.030,0.762 ± 0.011
39.0,31.0,"(512, 16)",3.0,True,20.0,0.002,5.0,0.665 ± 0.014,0.673 ± 0.021,0.613 ± 0.027,0.750 ± 0.024
12.0,30.0,"(256, 16)",4.0,True,10.0,0.002,5.0,0.664 ± 0.009,0.693 ± 0.019,0.614 ± 0.031,0.750 ± 0.047


# Review 1-4: WeightedEntropyLoss
Params to check: 
Weighted BCE:
- pos_weight_factor
Focal Loss:
- task_kwargs.criterion
- task_kwargs.focal_loss_gamma
- task_kwargs.focal_loss_alpha

In [33]:
df = analyze_metrics('atpbind3d_esm-t33-gearnet_stats.csv',
                     ['pos_weight_factor'],
                     filter_conditions={'pos_weight_factor': [0.25, 1, 4, 16]},
                     sort_by=('mcc', 'mean'),
                     ascending=False
                     )
print_markdown_table(df)
df

|   pos_weight_factor | mcc           | micro_auprc   | sensitivity   | precision     |   record_count |
|--------------------:|:--------------|:--------------|:--------------|:--------------|---------------:|
|                4    | 0.671 ± 0.011 | 0.681 ± 0.015 | 0.635 ± 0.014 | 0.740 ± 0.026 |              5 |
|                1    | 0.669 ± 0.012 | 0.689 ± 0.017 | 0.628 ± 0.024 | 0.743 ± 0.017 |              5 |
|                0.25 | 0.662 ± 0.020 | 0.675 ± 0.014 | 0.611 ± 0.023 | 0.749 ± 0.046 |              5 |
|               16    | 0.658 ± 0.004 | 0.669 ± 0.012 | 0.645 ± 0.016 | 0.702 ± 0.014 |              5 |


Unnamed: 0,pos_weight_factor,count,mcc,micro_auprc,sensitivity,precision
,,,,,,
2.0,4.0,5.0,0.671 ± 0.011,0.681 ± 0.015,0.635 ± 0.014,0.740 ± 0.026
1.0,1.0,5.0,0.669 ± 0.012,0.689 ± 0.017,0.628 ± 0.024,0.743 ± 0.017
0.0,0.25,5.0,0.662 ± 0.020,0.675 ± 0.014,0.611 ± 0.023,0.749 ± 0.046
3.0,16.0,5.0,0.658 ± 0.004,0.669 ± 0.012,0.645 ± 0.016,0.702 ± 0.014


In [36]:
df = analyze_metrics('atpbind3d_esm-t33-gearnet_stats.csv',
                     ['task_kwargs.focal_loss_gamma', 'task_kwargs.focal_loss_alpha'],
                     {'task_kwargs.criterion': ['focal']},
                     sort_by=('mcc', 'mean'),
                     ascending=False
                     )
print_markdown_table(df)
df

|   task_kwargs.focal_loss_alpha |   task_kwargs.focal_loss_gamma | mcc           | micro_auprc   | sensitivity   | precision     |   record_count |
|-------------------------------:|-------------------------------:|:--------------|:--------------|:--------------|:--------------|---------------:|
|                           0.25 |                              2 | 0.674 ± 0.007 | 0.675 ± 0.004 | 0.625 ± 0.031 | 0.757 ± 0.023 |              5 |
|                           0.3  |                              1 | 0.672 ± 0.014 | 0.681 ± 0.014 | 0.617 ± 0.024 | 0.762 ± 0.015 |              5 |
|                           0.2  |                              2 | 0.671 ± 0.011 | 0.685 ± 0.013 | 0.639 ± 0.008 | 0.735 ± 0.027 |              5 |
|                           0.25 |                              1 | 0.670 ± 0.008 | 0.690 ± 0.008 | 0.632 ± 0.011 | 0.741 ± 0.015 |              5 |
|                           0.3  |                              2 | 0.669 ± 0.011 | 0.689 ± 0.019 | 0.627 

Unnamed: 0,task_kwargs.focal_loss_gamma,task_kwargs.focal_loss_alpha,count,mcc,micro_auprc,sensitivity,precision
,,,,,,,
4.0,2.0,0.25,5.0,0.674 ± 0.007,0.675 ± 0.004,0.625 ± 0.031,0.757 ± 0.023
2.0,1.0,0.3,5.0,0.672 ± 0.014,0.681 ± 0.014,0.617 ± 0.024,0.762 ± 0.015
3.0,2.0,0.2,5.0,0.671 ± 0.011,0.685 ± 0.013,0.639 ± 0.008,0.735 ± 0.027
1.0,1.0,0.25,5.0,0.670 ± 0.008,0.690 ± 0.008,0.632 ± 0.011,0.741 ± 0.015
5.0,2.0,0.3,5.0,0.669 ± 0.011,0.689 ± 0.019,0.627 ± 0.019,0.745 ± 0.017
7.0,3.0,0.25,5.0,0.667 ± 0.018,0.684 ± 0.030,0.631 ± 0.036,0.736 ± 0.022
6.0,3.0,0.2,5.0,0.666 ± 0.010,0.684 ± 0.007,0.621 ± 0.027,0.746 ± 0.023
8.0,3.0,0.3,5.0,0.663 ± 0.008,0.684 ± 0.010,0.609 ± 0.025,0.751 ± 0.017
0.0,1.0,0.2,5.0,0.662 ± 0.019,0.690 ± 0.020,0.621 ± 0.036,0.738 ± 0.024
