In [35]:
from pathlib import Path

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from tabulate import tabulate

In [36]:
# load data: per-run results with columns dataset, seed, model and one column per metric
df = pd.read_csv('./summary_results.csv')
print(df.columns)

# Metrics reported in every summary table, in display order
metric_names = ['accuracy', 'f1_score', 'precision', 'recall', 'roc_auc']

# Fail fast with a readable message instead of a KeyError deep inside the summary loop
missing_columns = {'dataset', 'model', *metric_names} - set(df.columns)
assert not missing_columns, f"summary_results.csv is missing columns: {sorted(missing_columns)}"

models = df['model'].unique()

Index(['dataset', 'seed', 'model', 'accuracy', 'f1_score', 'precision',
       'recall', 'roc_auc'],
      dtype='object')


In [37]:
# Group and summarize: mean ± std per model per dataset
def format_metric(series: pd.Series, decimals: int = 2) -> str:
    """Format a numeric series as a ``"<mean> ± <std>"`` string.

    Intended as a groupby aggregation (``groupby('model')[metric].agg(format_metric)``),
    so ``series`` holds one model's values of one metric across runs.

    Args:
        series: Numeric values to summarize. NaNs are skipped by pandas' mean/std.
        decimals: Digits after the decimal point for both mean and std
            (default 2, matching the original tables).

    Returns:
        The formatted string, e.g. ``"0.72 ± 0.03"``. ``std`` is the pandas
        sample standard deviation (ddof=1), so a single-value series renders
        its std as ``nan``.
    """
    return f"{series.mean():.{decimals}f} ± {series.std():.{decimals}f}"

# Output directory for the per-dataset tables; create it so `to_csv` does not raise
# FileNotFoundError on a fresh checkout where ./csv was never made.
CSV_DIR = Path('./csv')
CSV_DIR.mkdir(parents=True, exist_ok=True)

# Get unique datasets (in order of first appearance)
datasets = df['dataset'].unique()

# Build, print and save one "mean ± std" table per dataset
for dataset in datasets:
    # Rows belonging to this dataset only
    sub_df = df[df['dataset'] == dataset]

    # Group once, then reduce each metric column to a "mean ± std" string per model.
    # The resulting index is the (sorted) model names, named 'model'.
    by_model = sub_df.groupby('model')
    formatted = pd.DataFrame(
        {metric: by_model[metric].agg(format_metric) for metric in metric_names}
    )

    # Display in console as a GitHub-flavored markdown table
    print(f"\n=== Dataset: {dataset} ===")
    print(tabulate(formatted.reset_index(), headers='keys', tablefmt='github', showindex=False))

    # Store into a csv file (the 'model' index is written as the first column)
    formatted.to_csv(CSV_DIR / f'summary_{dataset}.csv')


=== Dataset: ALL ===
| model            | accuracy    | f1_score    | precision   | recall      | roc_auc     |
|------------------|-------------|-------------|-------------|-------------|-------------|
| logistic         | 0.72 ± 0.00 | 0.73 ± 0.00 | 0.70 ± 0.00 | 0.76 ± 0.01 | 0.80 ± 0.00 |
| lstm             | 0.74 ± 0.00 | 0.74 ± 0.00 | 0.73 ± 0.01 | 0.75 ± 0.01 | 0.80 ± 0.00 |
| lstm-features    | 0.73 ± 0.01 | 0.73 ± 0.01 | 0.71 ± 0.01 | 0.76 ± 0.02 | 0.81 ± 0.00 |
| lstmfcn          | 0.72 ± 0.00 | 0.73 ± 0.00 | 0.71 ± 0.01 | 0.76 ± 0.02 | 0.80 ± 0.00 |
| lstmfcn-features | 0.72 ± 0.01 | 0.73 ± 0.01 | 0.70 ± 0.01 | 0.76 ± 0.02 | 0.80 ± 0.00 |
| randomforest     | 0.71 ± 0.01 | 0.74 ± 0.01 | 0.67 ± 0.01 | 0.83 ± 0.01 | 0.79 ± 0.00 |
| svm              | 0.72 ± 0.00 | 0.75 ± 0.00 | 0.67 ± 0.01 | 0.86 ± 0.00 | 0.80 ± 0.00 |
| vivit            | 0.69 ± 0.01 | 0.71 ± 0.01 | 0.67 ± 0.02 | 0.77 ± 0.04 | 0.76 ± 0.01 |

=== Dataset: CF ===
| model            | accuracy    | f1_score    