# Summarize Results

In [114]:
from functools import reduce

In [115]:
import polars as pl
import pandas as pd

## Read BDS and HS-DS results

In [116]:
# Result CSVs for the BDS / HS-DS runs, listed in the order they are stacked.
files = ["../results/bds_hsds.csv", "../results/bds_human.csv"]

In [117]:
# Read every results file and stack the rows into a single frame.
# One pl.concat call avoids re-copying the accumulated frame on each iteration.
# It also fails loudly when `files` is empty, instead of leaving `df` undefined.
df = pl.concat([pl.read_csv(p) for p in files])

In [118]:
# Split on the method name: HSDS-MCMC runs are scored separately in the next
# section; every other BDS / HS-DS row goes into the general comparison table.
is_hsds_mcmc = pl.col("method").str.contains("HSDS-MCMC", literal=True)
hsds_mcmc_only = df.filter(is_hsds_mcmc)
bds_hsds_em = df.filter(is_hsds_mcmc.not_())

In [119]:
# `uc_count` is only needed to score the MCMC runs. Dropping it here presumably
# makes the schema match the other methods' result files concatenated below.
bds_hsds_em = bds_hsds_em.drop("uc_count")

## Calculate the scores for $\mbox{HS-DS}_{MCMC}$.
For each `num_ai`, the score is the mean accuracy and recall over the run(s) with the smallest `uc_count`, i.e. the fewest non-convergent (unconverged) variables.

In [120]:
# Score HSDS-MCMC per (num_ai, method). Within each group, keep only the run(s)
# with the smallest `uc_count` and average their accuracy and recall.
# The min() is evaluated per group because it runs inside the group_by context.
at_min_uc = pl.col("uc_count") == pl.col("uc_count").min()
hsds_mcmc_result = (
    hsds_mcmc_only.group_by("num_ai", "method")
    .agg(
        pl.col("accuracy").filter(at_min_uc).mean().alias("min_uc_accuracy_mean"),
        pl.col("recall").filter(at_min_uc).mean().alias("min_uc_recall_mean"),
    )
    # group_by does not preserve row order, so sort on both keys. This keeps the
    # table deterministic when several methods share the same num_ai.
    .sort("num_ai", "method")
)

In [121]:
# Display the HSDS-MCMC scores (Jupyter renders the cell's last expression).
hsds_mcmc_result

num_ai,method,min_uc_accuracy_mean,min_uc_recall_mean
i64,str,f64,f64
5,"""HSDS-MCMC(iter_sampling=3000)""",0.905074,0.980583
10,"""HSDS-MCMC(iter_sampling=3000)""",0.905074,0.980583


## Calculate the scores for other methods.

In [122]:
# Seed the combined frame with the non-MCMC BDS / HS-DS rows.
df = bds_hsds_em
# Result CSVs for the remaining methods, appended to `df` in the next cell.
files = [
    "../results/catd_la_lfc_minmax_pmcrh_zc.csv",
    "../results/cbcc.csv",
    "../results/emds_glad_mace_mmsr_mv_onecoin.csv",
]

In [123]:
# Append every remaining results file to `df` in a single concatenation.
# Concatenating once per file would re-copy the growing frame each time.
df = pl.concat([df, *(pl.read_csv(p) for p in files)])

In [124]:
# For each (num_ai, method), summarise accuracy and recall across runs:
# - mean
# - std (polars default ddof=1)
# - standard error = std / sqrt(number of non-null values)
summary_aggs = [
    expr
    for metric in ("accuracy", "recall")
    for expr in (
        pl.col(metric).mean().alias(f"{metric}_mean"),
        pl.col(metric).std().alias(f"{metric}_std"),
        (pl.col(metric).std() / pl.col(metric).count().sqrt()).alias(f"{metric}_se"),
    )
]
df_tab = df.group_by(["num_ai", "method"]).agg(summary_aggs).sort(["num_ai", "method"])

In [125]:
# NOTE(review): `cols` is not referenced by any cell shown here; kept in case a
# later cell relies on it.
cols = [
    ('accuracy_mean', 'num_ai=0_mean_acc'),
    ('accuracy_se', 'num_ai=0_se_acc'),
    ('recall_mean', 'num_ai=0_mean_recall'),
    ('recall_se', 'num_ai=0_se_recall'),
    ('accuracy_mean', 'num_ai=5_mean_acc'),
    ('accuracy_se', 'num_ai=5_se_acc'),
    ('recall_mean', 'num_ai=5_mean_recall'),
    ('recall_se', 'num_ai=5_se_recall'),
    ('accuracy_mean', 'num_ai=10_mean_acc'),
    ('accuracy_se', 'num_ai=10_se_acc'),
    ('recall_mean', 'num_ai=10_mean_recall'),
    ('recall_se', 'num_ai=10_se_recall'),
]

# Build one frame per num_ai whose score columns are renamed to
# "num_ai=<k>_{mean,se}_{acc,recall}". They are then full-joined on `method`,
# so each row holds one method's scores for every num_ai.
dfs = []
for num_ai in [5, 10, 0]:
    df_pivot = (
        df_tab.filter(pl.col("num_ai") == num_ai)
        .select(
            pl.col("method"),
            pl.col("accuracy_mean").alias(f"num_ai={num_ai}_mean_acc"),
            pl.col("accuracy_se").alias(f"num_ai={num_ai}_se_acc"),
            pl.col("recall_mean").alias(f"num_ai={num_ai}_mean_recall"),
            pl.col("recall_se").alias(f"num_ai={num_ai}_se_recall"),
        )
    )
    dfs.append(df_pivot)

# coalesce=True merges the join keys into a single `method` column.
# Without it, a full join keeps the right-hand keys in a separate
# `method_right` column. A method present only on the right side would then be
# left with a null `method` once that column was dropped.
# Methods missing for some num_ai get nulls in that num_ai's score columns.
tab = reduce(
    lambda left, right: left.join(right, on="method", how="full", coalesce=True),
    dfs,
)
tab = tab.sort("method")

In [126]:
# Every column except the method label holds a numeric score to be formatted.
num_cols = [name for name in tab.columns if name not in ("method",)]

In [127]:
# Render each score with two decimals, and show missing scores as "-".
# A score is missing, e.g., when a method has no rows for some num_ai after the
# full join.
# map_elements skips nulls by default, so a `None` check inside the lambda is
# never reached. fill_null applies the "-" fallback explicitly instead.
tab_num_fmt = tab.with_columns([
    pl.col(c).map_elements(lambda x: f"{x:.2f}", return_dtype=pl.Utf8).fill_null("-")
    for c in num_cols
])

In [128]:
# For each num_ai, combine every mean/SE pair into a LaTeX cell of the form
# "$<mean> \pm <se>$".
# Raw strings keep the backslash literal. A plain " \pm " relies on Python's
# invalid-escape fallback, which warns: DeprecationWarning before 3.12 and
# SyntaxWarning from 3.12 on.
for i in [0, 5, 10]:
    tab_num_fmt = tab_num_fmt.with_columns(
        (
            pl.lit("$") + pl.col(f"num_ai={i}_mean_acc")
            + pl.lit(r" \pm ") + pl.col(f"num_ai={i}_se_acc") + pl.lit("$")
        ).alias(f"num_ai={i}_acc_se"),
        (
            pl.lit("$") + pl.col(f"num_ai={i}_mean_recall")
            + pl.lit(r" \pm ") + pl.col(f"num_ai={i}_se_recall") + pl.lit("$")
        ).alias(f"num_ai={i}_recall_se"),
    )

In [129]:
# Keep the method label plus the formatted "mean ± SE" cells.
# Columns are ordered by num_ai (0, 5, 10); within each num_ai, accuracy
# comes before recall.
tab_num_fmt = tab_num_fmt.select(
    "method",
    *[f"num_ai={k}_{metric}_se" for k in (0, 5, 10) for metric in ("acc", "recall")],
)

In [130]:
# A negative row limit makes polars print every row of a DataFrame.
pl.Config.set_tbl_rows(n=-1)

polars.config.Config

In [131]:
# Row order for the final table:
# - EMDS and BDS first;
# - then the remaining methods, in case-insensitive alphabetical order;
# - HSDS-EM last.
# Used below as the Enum categories that drive the sort.
order = [
    "EMDS",
    "BDS(iter_sampling=3000)",
    # remaining methods
    "CATD", "CBCC_M=2", "CBCC_M=4", "CBCC_M=8",
    "GLAD", "LA1", "LA2", "LFC",
    "MACE-EM", "minmax", "MMSR", "MV",
    "OneCoin", "PM-CRH", "ZC",
    "HSDS-EM",
]

In [132]:
# Cast `method` to an Enum over `order`, so sorting follows that list rather
# than lexical order. The cast raises if a method is not listed in `order`.
method_enum = pl.Enum(categories=order)
tab_num_fmt = tab_num_fmt.with_columns(pl.col("method").cast(method_enum))
tab_num_fmt = tab_num_fmt.sort("method")

In [133]:
# Display the final formatted comparison table (Jupyter renders the cell's last expression).
tab_num_fmt

method,num_ai=0_acc_se,num_ai=0_recall_se,num_ai=5_acc_se,num_ai=5_recall_se,num_ai=10_acc_se,num_ai=10_recall_se
enum,str,str,str,str,str,str
"""EMDS""","""$0.84 \pm 0.00$""","""$0.50 \pm 0.00$""","""$0.76 \pm 0.00$""","""$0.02 \pm 0.00$""","""$0.77 \pm 0.00$""","""$0.07 \pm 0.00$"""
"""BDS(iter_sampling=3000)""","""$0.87 \pm 0.00$""","""$0.54 \pm 0.00$""","""$0.84 \pm 0.04$""","""$0.58 \pm 0.24$""","""$0.77 \pm 0.10$""","""$0.75 \pm 0.19$"""
"""CATD""","""$0.84 \pm 0.01$""","""$0.45 \pm 0.02$""","""$0.76 \pm 0.00$""","""$0.11 \pm 0.01$""","""$0.74 \pm 0.00$""","""$0.02 \pm 0.00$"""
"""CBCC_M=2""","""$0.89 \pm 0.00$""","""$0.73 \pm 0.00$""","""$0.91 \pm 0.00$""","""$1.00 \pm 0.00$""","""$0.91 \pm 0.00$""","""$0.98 \pm 0.00$"""
"""CBCC_M=4""","""$0.89 \pm 0.00$""","""$0.73 \pm 0.00$""","""$0.91 \pm 0.00$""","""$1.00 \pm 0.00$""","""$0.91 \pm 0.00$""","""$0.98 \pm 0.00$"""
"""CBCC_M=8""","""$0.89 \pm 0.00$""","""$0.73 \pm 0.00$""","""$0.91 \pm 0.00$""","""$1.00 \pm 0.00$""","""$0.91 \pm 0.00$""","""$0.98 \pm 0.00$"""
"""GLAD""","""$0.85 \pm 0.00$""","""$0.50 \pm 0.00$""","""$0.74 \pm 0.00$""","""$0.00 \pm 0.00$""","""$0.74 \pm 0.00$""","""$0.01 \pm 0.00$"""
"""LA1""","""$0.83 \pm 0.01$""","""$0.34 \pm 0.02$""","""$0.76 \pm 0.00$""","""$0.12 \pm 0.00$""","""$0.75 \pm 0.00$""","""$0.04 \pm 0.00$"""
"""LA2""","""$0.84 \pm 0.00$""","""$0.35 \pm 0.02$""","""$0.76 \pm 0.00$""","""$0.12 \pm 0.00$""","""$0.75 \pm 0.00$""","""$0.04 \pm 0.00$"""
"""LFC""","""$0.84 \pm 0.00$""","""$0.52 \pm 0.00$""","""$0.84 \pm 0.00$""","""$0.28 \pm 0.00$""","""$0.82 \pm 0.00$""","""$0.11 \pm 0.00$"""
