In [4]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

### View classification results

In [6]:
def combine_stats(list_of_paths):
    """Summarise results files as a "mean (± SEM)" table with one row per model.

    Parameters
    ----------
    list_of_paths : iterable of str or path-like
        CSV files readable by ``pd.read_csv``. Every file must contain a
        ``model`` column; the rows of all files are stacked before grouping.

    Returns
    -------
    pandas.DataFrame
        Indexed by ``model``. Numeric columns whose name contains ``"test"``
        hold strings of the form ``"<mean> (± <sem>)"`` with both numbers
        rounded to 4 decimals. The other numeric columns hold the plain
        per-model mean. Non-numeric columns are left out.

    Raises
    ------
    ValueError
        If ``list_of_paths`` is empty.
    """
    paths = list(list_of_paths)
    if not paths:
        raise ValueError("combine_stats() needs at least one results file")

    # read all results files and combine
    df = pd.concat([pd.read_csv(path) for path in paths]).reset_index(drop=True)

    # get means and errors; numeric_only stops stray text columns from raising
    grouped = df.groupby("model")
    means = grouped.mean(numeric_only=True)
    sem = grouped.sem(numeric_only=True)

    # combine in a readable table format
    combined = means.copy()
    for col in [c for c in means.columns if "test" in c]:
        # Replace the whole column in one go: writing strings cell by cell
        # into a float column is deprecated in recent pandas.
        combined[col] = [
            f"{str(round(m, 4))} (\xB1 {str(round(s, 4))})"
            for m, s in zip(means[col], sem[col])
        ]

    return combined

In [15]:
# models to compare, in the order they should appear in tables
models = [
    'IgBERT',
    'IgT5',
    'AbLang2',
    'AntiBERTa2',
    'CurrAb',
]

In [None]:
# paired classification results
paired3_res = (
    combine_stats(
        [f'./results/{m}_HD-Flu-CoV-paired_5fold-5ep_results.csv' for m in models]
    )
    # keep only the metrics shown in the table
    .drop(columns=[
        'itr',
        'test_loss',
        'test_macro-precision',
        'test_micro-precision',
        'test_macro-recall',
        'test_micro-recall',
        'test_micro-f1',
    ])
    # order the rows as listed in `models`
    .sort_values(by="model", key=lambda names: names.map(models.index))
)
paired3_res

In [None]:
# unpaired classification results
# Stored under its own name so the paired table computed earlier is not
# overwritten by (and mislabelled as) the unpaired one.
unpaired3_res = (
    combine_stats(
        [f'./results/{m}_HD-Flu-CoV-unpaired_5fold-5ep_results.csv' for m in models]
    )
    # keep only the metrics shown in the table
    .drop(columns=[
        'itr',
        'test_loss',
        'test_macro-precision',
        'test_micro-precision',
        'test_macro-recall',
        'test_micro-recall',
        'test_micro-f1',
    ])
    # order the rows as listed in `models`
    .sort_values(by="model", key=lambda names: names.map(models.index))
)
unpaired3_res

### Accuracy bar plot

In [11]:
def stats_for_plot(list_of_paths):
    """Read every CSV in ``list_of_paths`` and stack the rows into one DataFrame.

    Rows keep the order of the files and the result carries a fresh
    0..n-1 index.
    """
    frames = []
    for csv_path in list_of_paths:
        frames.append(pd.read_csv(csv_path))

    stacked = pd.concat(frames)
    return stacked.reset_index(drop=True)

In [13]:
# load the paired-classification results of every model for plotting
p3r = stats_for_plot(
    [f'./results/{m}_HD-Flu-CoV-paired_5fold-5ep_results.csv' for m in models]
)

In [14]:
# define color palette: each model is pinned to one slot of an 8-colour "hls" palette
color_palette = sns.color_palette("hls", 8)
color_mapping = {
    model_name: color_palette[slot]
    for model_name, slot in [
        ('IgBERT', 3),
        ('IgT5', 4),
        ('AbLang2', 0),
        ('AntiBERTa2', 6),
        ('CurrAb', 5),
    ]
}

In [15]:
# sort models: rows follow the key order of `color_mapping`
p3r = (
    p3r
    .assign(model=pd.Categorical(p3r['model'], categories=list(color_mapping), ordered=True))
    .sort_values(by='model')
)

In [None]:
# bar plot of mean test accuracy per model, with a truncated y-axis
fig, ax = plt.subplots(figsize=(4.5, 5.8))
sns.barplot(
    data=p3r,
    x="model", y="test_accuracy",
    errorbar="se",
    hue="model",
    palette=color_mapping,  # dict: colours are matched to models by name, not position
    ax=ax,                  # draw on the axes created above
)
ax.set_ylim(0.3, 0.75)

# random guessing line
random_guess_accuracy = 1 / 3  # 33% for a three-way classification
ax.axhline(y=random_guess_accuracy, color='black', linestyle='--', label='Random Guessing')
ax.legend(loc='upper left', fontsize=12)

# labels and ticks
ax.set_xlabel("")  # the model names on the ticks are self-explanatory
ax.set_ylabel("Average Accuracy", fontsize=14)
ax.tick_params(axis="x", labelsize=13, labelrotation=60)
ax.tick_params(axis="y", labelsize=11)

fig.tight_layout()

fig.savefig("./results/3-paired-class_ylim.png", dpi=300)