Here we aggregate and plot the results of the functional benchmarks.

In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

In [2]:
benchmark_csvs = [
    "_output/correlations_shane.csv",
]
# load and stack the data
dfs = [pd.read_csv(csv) for csv in benchmark_csvs]
df = pd.concat(dfs)
# sort by dataset and model
df = df.sort_values(['Dataset', 'Model']).reset_index(drop=True)
# Define custom sort order for Model
model_order = ['AbLang2', 'ESM', 'ProGen2', 'DASM']

# Sort by Dataset and then Model using the custom order
df = df.sort_values(['Dataset', 'Model'], 
                   key=lambda x: x.map({v: i for i, v in enumerate(model_order)} if x.name == 'Model' else x)
                  ).reset_index(drop=True)
df

Unnamed: 0,Dataset,Model,Correlation
0,Shane. Trast. zero 119,AbLang2,0.262539
1,Shane. Trast. zero 120,AbLang2,0.166075
2,Shane. Trast. zero 119,ESM,0.248102
3,Shane. Trast. zero 120,ESM,0.336645
4,Shane. Trast. zero 119,ProGen2,0.074493
5,Shane. Trast. zero 120,ProGen2,0.051981
6,Shane. Trast. zero 119,DASM,0.458353
7,Shane. Trast. zero 120,DASM,0.517775


In [4]:
# Pivot the dataframe to get models as columns
df_wide = df.pivot(index='Dataset', columns='Model', values='Correlation')

# Reorder columns
df_wide = df_wide[['AbLang2', 'DASM', 'ESM', 'ProGen2']]

# Create a formatter that only bolds the DASM column
def format_value(x, column):
    if pd.isna(x):
        return '-'
    if column == 'DASM':
        return f"\\textbf{{{x:.3f}}}"
    return f"{x:.3f}"

formatters = {col: lambda x, col=col: format_value(x, col) for col in df_wide.columns}

# Format the table
latex_table = df_wide.to_latex(
    formatters=formatters,
    caption="Correlation of models with functional predictions",
    label="tab:model-comparison",
    position="t",
    column_format="lccccc",
    escape=False,
    bold_rows=False
)

print(latex_table)

\begin{table}[t]
\caption{Correlation of models with functional predictions}
\label{tab:model-comparison}
\begin{tabular}{lccccc}
\toprule
Model & AbLang2 & DASM & ESM & ProGen2 \\
Dataset &  &  &  &  \\
\midrule
Shane. Trast. zero 119 & 0.263 & \textbf{0.458} & 0.248 & 0.074 \\
Shane. Trast. zero 120 & 0.166 & \textbf{0.518} & 0.337 & 0.052 \\
\bottomrule
\end{tabular}
\end{table}

