In [1]:
import os
from pathlib import Path

import pandas as pd
import plotly.express as px
from sklearn.metrics import balanced_accuracy_score, f1_score

In [2]:
# Root folder holding one sub-directory per trial (named like "trial-1").
data_dir = os.path.join("..", "assets", "inference-results")

# Keep only real trial directories: hidden entries (".DS_Store",
# ".ipynb_checkpoints") and stray files would otherwise break the numeric
# sort key below or the per-trial directory listing done in stack_dfs.
# Sorting on the integer after the last "-" puts "trial-10" after "trial-9";
# a visible directory without a numeric suffix still raises ValueError.
trial_nums = sorted(
    (
        entry
        for entry in os.listdir(data_dir)
        if not entry.startswith(".") and os.path.isdir(os.path.join(data_dir, entry))
    ),
    key=lambda x: int(x.split("-")[-1]),
)

In [3]:
def stack_dfs(data_dir, trial_nums):
    """Load every result CSV of every trial into one DataFrame.

    Parameters
    ----------
    data_dir : str or os.PathLike
        Directory that contains one sub-directory per trial.
    trial_nums : iterable of str
        Names of the trial sub-directories to read, in the desired order.

    Returns
    -------
    pandas.DataFrame
        Row-wise concatenation of all result files. Two columns are added:
        ``trial`` (the sub-directory name) and ``model`` (``"MIL"`` when the
        file stem contains ``"mil"``, otherwise ``"CNN"``). Each file keeps
        its own index, so index labels repeat across files.

    Raises
    ------
    ValueError
        If no result file is found under any of the requested trials.
    """
    dfs = []

    for trial in trial_nums:
        trial_path = os.path.join(data_dir, trial)

        # os.listdir returns entries in arbitrary order; sort for a stable row order.
        for file_name in sorted(os.listdir(trial_path)):
            result = os.path.join(trial_path, file_name)

            # Skip hidden entries and sub-directories (e.g. ".DS_Store",
            # ".ipynb_checkpoints"); they are not result files.
            if file_name.startswith(".") or not os.path.isfile(result):
                continue

            df = pd.read_csv(result)
            # A scalar is broadcast to every row, so no per-row list is needed.
            df["trial"] = trial
            df["model"] = "MIL" if "mil" in Path(result).stem else "CNN"

            dfs.append(df)

    if not dfs:
        # pd.concat would raise an opaque "No objects to concatenate" here.
        raise ValueError(f"No result files found under {data_dir!r} for the requested trials")

    stacked_df = pd.concat(dfs, axis=0)

    return stacked_df


def compute_metrics(group):
    """Score one group of predictions against its ground-truth labels.

    Reads the ``target`` and ``prediction`` columns of ``group`` and returns
    a Series holding the balanced accuracy and the weighted-average F1 score.
    """
    y_true = group['target']
    y_pred = group['prediction']

    scores = {
        'balanced_accuracy': balanced_accuracy_score(y_true, y_pred),
        'weighted_f1': f1_score(y_true, y_pred, average='weighted'),
    }

    return pd.Series(scores)

In [4]:
# Combine the result files of every trial into a single frame.
df = stack_dfs(data_dir=data_dir, trial_nums=trial_nums)

# Preview the first rows.
df.head(n=5)

Unnamed: 0,patient_id,loss,prediction,target,trial,model
0,11189,0.699046,0,1,trial-1,MIL
1,11783,0.312435,1,1,trial-1,MIL
2,12186,0.193176,0,0,trial-1,MIL
3,15513,0.358032,1,1,trial-1,MIL
4,11749,0.720617,0,1,trial-1,MIL


In [5]:
# Score each (trial, model) pair. Selecting the two label columns explicitly
# keeps the grouping columns out of what compute_metrics receives, which
# silences the pandas deprecation warning about apply operating on them.
results = (
    df.groupby(['trial', 'model'])[['target', 'prediction']]
    .apply(compute_metrics)
    .reset_index()
)

results

  results = df.groupby(['trial', 'model']).apply(compute_metrics).reset_index()


Unnamed: 0,trial,model,balanced_accuracy,weighted_f1
0,trial-1,CNN,0.761905,0.741265
1,trial-1,MIL,0.64881,0.637771
2,trial-2,CNN,0.702381,0.731984
3,trial-2,MIL,0.672619,0.715577
4,trial-3,CNN,0.60119,0.643609
5,trial-3,MIL,0.630952,0.67004
6,trial-4,CNN,0.785714,0.829346
7,trial-4,MIL,0.672619,0.715577
8,trial-5,CNN,0.630952,0.67004
9,trial-5,MIL,0.619048,0.63585


In [9]:
# One bar per (trial, model). The metrics are already aggregated, so a bar
# chart plots them directly instead of having a histogram sum them per bin.
fig = px.bar(
    results,
    x="trial",
    y="balanced_accuracy",
    color="model",
    barmode="group",
    height=500,
    width=800,
    color_discrete_sequence=px.colors.qualitative.Prism,
    text_auto=".3f",
    title="Balanced accuracy per trial",
)

fig.update_yaxes(range=[0, 0.85], title_text="Balanced Accuracy")

fig.show()

In [10]:
# One bar per (trial, model). The metrics are already aggregated, so a bar
# chart plots them directly instead of having a histogram sum them per bin.
fig = px.bar(
    results,
    x="trial",
    y="weighted_f1",
    color="model",
    barmode="group",
    height=500,
    width=800,
    color_discrete_sequence=px.colors.qualitative.Prism,
    text_auto=".3f",
    title="Weighted F1 per trial",
)

fig.update_yaxes(range=[0, 0.85], title_text="Weighted F1")

fig.show()