In [1]:
import os
from pathlib import Path

import pandas as pd
import plotly.express as px
from sklearn.metrics import balanced_accuracy_score, f1_score

In [2]:
data_dir = os.path.join("..", "assets", "inference-results")

embedding_types = [d for d in os.listdir(data_dir) if d != "beta"]

In [3]:
def stack_dfs(data_dir):
    dfs = []

    for embedding_type in embedding_types:
        embedding_dir = os.path.join(data_dir, embedding_type)
        training_methods = os.listdir(embedding_dir)

        for training_method in training_methods:
            method_dir = os.path.join(embedding_dir, training_method)
            trial_nums = sorted(os.listdir(method_dir), key=lambda x: int(x.split("-")[-1]))

            for trial in trial_nums:
                trial_path = os.path.join(method_dir, trial)
                result_files = os.listdir(trial_path)
                result_paths = [os.path.join(trial_path, f) for f in result_files]

                for result in result_paths:
                    df = pd.read_csv(result)
                    df["trial"] = [trial for _ in range(len(df))]
                    df["model"] = [Path(result).stem for _ in range(len(df))]
                    df["model"] = df["model"].map(lambda x: "-".join(Path(x).stem.split("-")[:-1]))

                    dfs.append(df)

    stacked_df = pd.concat(dfs, axis=0)

    return stacked_df


def compute_metrics(group):
    balanced_acc = balanced_accuracy_score(group['target'], group['prediction'])
    weighted_f1 = f1_score(group['target'], group['prediction'], average='weighted')
    
    return pd.Series({'balanced_accuracy': balanced_acc, 'weighted_f1': weighted_f1})

In [4]:
df = stack_dfs(data_dir)

df.head()

Unnamed: 0,patient_id,loss,prediction,target,trial,model
0,14917C,0.316789,1,1,split-1,attention-mil
1,16421,0.307176,0,0,split-1,attention-mil
2,11293,0.728464,1,0,split-1,attention-mil
3,14696B,1.368097,0,1,split-1,attention-mil
4,13645,1.185069,1,0,split-1,attention-mil


In [5]:
results = df.groupby(['trial', 'model']).apply(compute_metrics).reset_index()

results

  results = df.groupby(['trial', 'model']).apply(compute_metrics).reset_index()


Unnamed: 0,trial,model,balanced_accuracy,weighted_f1
0,split-1,attention-mil,0.627976,0.61104
1,split-1,max-mil,0.559524,0.601808
2,split-1,mean-mil,0.660714,0.684211
3,split-1,resnet18,0.738095,0.761189
4,split-2,attention-mil,0.866071,0.86929
5,split-2,max-mil,0.529762,0.561
6,split-2,mean-mil,0.886905,0.894737
7,split-2,resnet18,0.732143,0.739893
8,split-3,attention-mil,0.690476,0.689593
9,split-3,max-mil,0.702381,0.731984


In [6]:
fig = px.histogram(
    results, x="trial", 
    y="balanced_accuracy", 
    color="model", 
    barmode="group", 
    height=800,
    width=1500,
    color_discrete_sequence=px.colors.qualitative.Prism,
    text_auto=".3f"
    )

fig.update_yaxes(range=[0, 0.95], title_text="Balanced Accuracy")

fig.show()

In [7]:
fig = px.histogram(
    results, x="trial", 
    y="weighted_f1", 
    color="model", 
    barmode="group", 
    height=800,
    width=1500,
    color_discrete_sequence=px.colors.qualitative.Prism,
    text_auto=".3f"
    )

fig.update_yaxes(range=[0, 0.95], title_text="Weighted F1")

fig.show()

In [8]:
avg_balanced_accuracy = results.groupby("model")["balanced_accuracy"].mean().reset_index()
avg_f1 = results.groupby("model")["weighted_f1"].mean().reset_index()

In [9]:
avg_balanced_accuracy

Unnamed: 0,model,balanced_accuracy
0,attention-mil,0.788095
1,max-mil,0.616667
2,mean-mil,0.784524
3,resnet18,0.75119


In [10]:
fig = px.histogram(
    avg_balanced_accuracy, x="model", 
    y="balanced_accuracy", 
    height=500,
    width=800,
    color_discrete_sequence=px.colors.qualitative.Prism,
    text_auto=".3f"
    )

fig.update_yaxes(range=[0, 0.95], title_text="Balanced Accuracy")

fig.show()

In [11]:
fig = px.histogram(
    avg_f1, x="model", 
    y="weighted_f1", 
    height=500,
    width=800,
    color_discrete_sequence=px.colors.qualitative.Prism,
    text_auto=".3f"
    )

fig.update_yaxes(range=[0, 0.95], title_text="Weighted F1")

fig.show()