In [None]:
#not for commercial use
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from time import sleep
from sklearn.metrics import (
    roc_curve,
    auc,
    balanced_accuracy_score,
    confusion_matrix,
    jaccard_score,
    cohen_kappa_score
)
from lifelines.utils import concordance_index

In [None]:
# Every rater's export used its own file-naming scheme, so the CSV paths are
# generated from a per-rater template. Article keys are listed once and keep
# the same order for every rater (downstream cells rely on dict key order).
_ARTICLE_KEYS = ("colorectal", "dvt", "pad", "pet", "spect", "stent")

rater_files = {
    "gprater1": {a: f"gprater1/{a}_gprater1.csv" for a in _ARTICLE_KEYS},
    "gprater2": {a: f"gprater2/{a}_gprater2.csv" for a in _ARTICLE_KEYS},
    # gprater3 saved every file except colorectal/dvt with an upper-case article name
    "gprater3": {
        a: f"gprater3/{a if a in ('colorectal', 'dvt') else a.upper()}_gprater3.csv"
        for a in _ARTICLE_KEYS
    },
    "erater1": {a: f"erater1/erater1_{a}.csv" for a in _ARTICLE_KEYS},
    "erater2": {a: f"erater2/erater2_{a}.csv" for a in _ARTICLE_KEYS},
    "erater3": {a: f"erater3/{a}.csv" for a in _ARTICLE_KEYS},
    "gpt": {a: f"GPT3.5/gpt_{a}.csv" for a in _ARTICLE_KEYS},
}

In [None]:
raters = [*rater_files]
articles = [*rater_files[raters[0]]]

# A rater key containing "gpt" is the language model; every other key is a person.
gpt_raters = [r for r in raters if "gpt" in r]
human_raters = [r for r in raters if r not in gpt_raters]

# Explicit human sub-groups (listed in the order given here).
gp_raters = ["gprater2", "gprater1", "gprater3"]
expert_raters = ["erater1", "erater2", "erater3"]

for label, group in (
    ("all raters:", raters),
    ("human raters:", human_raters),
    ("gpt raters:", gpt_raters),
    ("gp raters:", gp_raters),
    ("expert raters:", expert_raters),
    ("article types:", articles),
):
    print(label, group)


In [None]:
# ChatGPT's raw score is binarised per rater: included when score > cutoff.
gpt_cutoff_dict = dict(gpt=2)
gp_range = (3, 5)  # NOTE(review): not referenced in the visible cells; purpose unclear

_CONSENSUS_LABELS = {
    "vote": "Voting Consensus",
    "spec_con": "Specific Consensus",
    "sens_con": "Sensitive Consensus",
}

# Display names used in plot titles, axis labels and printed reports.
name_dict = {
    "gpt": "ChatGPT",
    **{f"gprater{k}": f"GP {k}" for k in range(1, 4)},
    **{f"erater{k}": f"Expert {k}" for k in range(1, 4)},
    **_CONSENSUS_LABELS,
    **{f"{key}_gp": f"{label} (GPs)" for key, label in _CONSENSUS_LABELS.items()},
}

consensus_types = list(_CONSENSUS_LABELS)
consensus_types_gp = [f"{key}_gp" for key in consensus_types]
# The expert panel is the rater group pooled into the (non-GP) consensus labels.
consensus_raters = list(expert_raters)


In [None]:
# Column holding the binary include(1)/exclude(0) decision.
inclusion_field = "inclusion"
# Column where the model's raw, un-thresholded score is preserved.
inclusion_original_field = f"{inclusion_field}_original"

In [None]:
# Diagnostic-accuracy metrics computed from confusion-matrix counts.
# Every function is called positionally as f(tn, fp, fn, tp, epsilon); only the
# likelihood ratios use `epsilon`, which is ADDED to their denominators so they
# stay finite and positive when specificity is 1 (PLR) or 0 (NLR).
metric_dict = {
    "sen": lambda tn, fp, fn, tp, epsilon: tp / (tp + fn),
    "spec": lambda tn, fp, fn, tp, epsilon: tn / (tn + fp),
    "ppv": lambda tn, fp, fn, tp, epsilon: tp / (tp + fp),
    "npv": lambda tn, fp, fn, tp, epsilon: tn / (tn + fn),
    # PLR = sensitivity / (1 - specificity). Previously epsilon was subtracted,
    # which made the denominator -epsilon (a huge negative PLR) when fp == 0.
    "plr": lambda tn, fp, fn, tp, epsilon: (tp / (tp + fn))
    / ((1 - (tn / (tn + fp))) + epsilon),
    # NLR = (1 - sensitivity) / specificity
    "nlr": lambda tn, fp, fn, tp, epsilon: (
        (1 - (tp / (tp + fn))) / ((tn / (tn + fp)) + epsilon)
    ),
}
metrics = list(metric_dict.keys())
metric_name_dict = {
    "sen": "Sensitivity",
    "spec": "Specificity",
    "ppv": "Positive Predictive Value",
    "npv": "Negative Predictive Value",
    "plr": "Positive Likelihood Ratio",
    "nlr": "Negative Likelihood Ratio",
}

In [None]:
def _load_rater_frames(rater):
    """Read every available article CSV for one rater, keyed by article name."""
    paths = rater_files[rater]
    return {article: pd.read_csv(paths[article]) for article in articles if article in paths}


# dfs[rater][article] -> raw DataFrame read from that rater's CSV
dfs = {rater: _load_rater_frames(rater) for rater in raters}

In [None]:
def apply_cutoff(
    cutoff: dict[str, int],
    gpts: list[str] = ["gpt"],
    articles: list[str] = articles,
    inclusion_field: str = inclusion_field,
    inclusion_original_field: str = inclusion_original_field,
):
    """Binarise each model rater's raw scores in the module-level ``dfs``.

    For every article the rater has a file for (per ``rater_files``), the
    ``inclusion_field`` column is overwritten with 1 where the raw score in
    ``inclusion_original_field`` is strictly greater than ``cutoff[rater]``,
    and 0 otherwise.

    Args:
        cutoff (dict[str, int]): per-rater score threshold.
        gpts (list[str], optional): model raters to process. Defaults to ["gpt"].
        articles (list[str], optional): article keys to process.
        inclusion_field (str, optional): column to write the 0/1 decision to.
        inclusion_original_field (str, optional): column holding the raw score.
    """
    for rater in gpts:
        available = [a for a in articles if a in rater_files[rater]]
        for article in available:
            frame = dfs[rater][article]
            frame[inclusion_field] = (
                frame[inclusion_original_field].gt(cutoff[rater]).astype(int)
            )

In [None]:
# Preserve each model rater's raw score in `inclusion_original_field`, then
# binarise `inclusion_field` with the per-rater cutoff.
# The raw copy is made only once: re-running this cell used to copy the
# already-thresholded 0/1 values over the raw scores, so a second cutoff pass
# operated on 0/1 instead of the original scores.
for rater in gpt_raters:
    for article in articles:
        if article not in rater_files[rater]:
            continue
        frame = dfs[rater][article]
        if inclusion_original_field not in frame.columns:
            frame[inclusion_original_field] = frame[inclusion_field]
apply_cutoff(gpt_cutoff_dict, gpt_raters)

In [None]:
# Human raters recorded decisions as the strings "include"/"exclude"; map them
# to 1/0 (any other value becomes NaN, which the next cell reports).
# Columns that are already numeric -- i.e. this cell has run before -- are left
# alone: mapping 0/1 through the string dict again would turn them all into NaN.
for rater in human_raters:
    for article in articles:
        if article not in rater_files[rater]:
            continue
        frame = dfs[rater][article]
        if pd.api.types.is_numeric_dtype(frame[inclusion_field]):
            continue
        frame[inclusion_field] = frame[inclusion_field].map({"include": 1, "exclude": 0})

In [None]:
# Report, for every rater/article, whether the inclusion column contains NaN
# (or "not found" when the rater has no file for that article).
for rater in raters:
    for article in articles:
        frames = dfs[rater]
        if article in frames:
            has_nan = frames[article][inclusion_field].isna().to_numpy().any()
        else:
            has_nan = "not found"
        print(rater, article, has_nan)

In [None]:
# Print each rater's row count per article so mismatched exports stand out.
for rater in raters:
    frames = dfs[rater]
    for article in articles:
        row_count = len(frames[article]) if article in frames else "not found"
        print(rater, article, row_count)

In [None]:
# Order every rater's records by Title so that rows line up across raters.
for rater in raters:
    frames = dfs[rater]
    for article in [a for a in articles if a in frames]:
        frames[article] = frames[article].sort_values("Title")


In [None]:
# Renumber rows 0..n-1 after sorting, discarding the pre-sort index.
for rater, article in ((r, a) for r in raters for a in articles):
    if article in dfs[rater]:
        dfs[rater][article] = dfs[rater][article].reset_index(drop=True)


In [None]:
# For each article, check that every rater's (sorted) titles are identical to
# the first rater's, so row-wise comparisons between raters are meaningful.
title_reference_rater = raters[0]
for article in articles:
    print("\n", article)
    for rater in (r for r in raters if article in dfs[r]):
        same_titles = dfs[rater][article]["Title"].equals(
            dfs[title_reference_rater][article]["Title"]
        )
        print(rater, "OK" if same_titles else "Not Identical")

In [None]:
# Stack each rater's per-article frames (in `articles` order) into a single
# frame per rater, so raters can be compared row by row.
# ignore_index=True gives every stacked frame a fresh 0..N-1 index. Without it
# each article restarts at 0, leaving duplicate index labels that make the
# label-aligned arithmetic between raters (the consensus `+=` below) fragile.
dfs_concat = {
    rater: pd.concat(
        [dfs[rater][article] for article in articles if article in dfs[rater]],
        ignore_index=True,
    )
    for rater in raters
}

In [None]:
# Sensitive consensus of the expert panel: a record is included (1) when at
# least one expert included it. The frame is copied from the first expert so
# all other columns (e.g. Title) are carried over unchanged.
sens_con_frame = dfs_concat[consensus_raters[0]].copy()
for rater in (r for r in consensus_raters if r != consensus_raters[0]):
    sens_con_frame[inclusion_field] += dfs_concat[rater][inclusion_field]
sens_con_frame[inclusion_field] = sens_con_frame[inclusion_field].gt(0).astype(int)
dfs_concat["sens_con"] = sens_con_frame

In [None]:
# Sensitive consensus of the GP raters: a record is included (1) when at least
# one GP included it. Other columns are carried over from the first GP's frame.
sens_con_gp_frame = dfs_concat[gp_raters[0]].copy()
for rater in (r for r in gp_raters if r != gp_raters[0]):
    sens_con_gp_frame[inclusion_field] += dfs_concat[rater][inclusion_field]
sens_con_gp_frame[inclusion_field] = sens_con_gp_frame[inclusion_field].gt(0).astype(int)
dfs_concat["sens_con_gp"] = sens_con_gp_frame

In [None]:
# Specific consensus of the expert panel: a record is included (1) only when
# every expert included it, i.e. the summed votes equal the panel size.
# Other columns are carried over from the first expert's frame.
spec_con_frame = dfs_concat[consensus_raters[0]].copy()
for rater in (r for r in consensus_raters if r != consensus_raters[0]):
    spec_con_frame[inclusion_field] += dfs_concat[rater][inclusion_field]
spec_con_frame[inclusion_field] = (
    spec_con_frame[inclusion_field].eq(len(consensus_raters)).astype(int)
)
dfs_concat["spec_con"] = spec_con_frame

In [None]:
# Specific consensus of the GP raters: a record is included (1) only when every
# GP included it, i.e. the summed votes equal the number of GPs.
# Other columns are carried over from the first GP's frame.
spec_con_gp_frame = dfs_concat[gp_raters[0]].copy()
for rater in (r for r in gp_raters if r != gp_raters[0]):
    spec_con_gp_frame[inclusion_field] += dfs_concat[rater][inclusion_field]
spec_con_gp_frame[inclusion_field] = (
    spec_con_gp_frame[inclusion_field].eq(len(gp_raters)).astype(int)
)
dfs_concat["spec_con_gp"] = spec_con_gp_frame

In [None]:
# Majority-vote consensus of the expert panel: a record is included (1) when
# strictly more than half of the experts included it.
# Other columns are carried over from the first expert's frame.
vote_frame = dfs_concat[consensus_raters[0]].copy()
for rater in (r for r in consensus_raters if r != consensus_raters[0]):
    vote_frame[inclusion_field] += dfs_concat[rater][inclusion_field]
expert_vote_share = vote_frame[inclusion_field] / len(consensus_raters)
vote_frame[inclusion_field] = (expert_vote_share > 0.5).astype(int)
dfs_concat["vote"] = vote_frame


In [None]:
# Majority-vote consensus of the GP raters: a record is included (1) when
# strictly more than half of the GPs included it.
# Other columns are carried over from the first GP's frame.
vote_gp_frame = dfs_concat[gp_raters[0]].copy()
for rater in (r for r in gp_raters if r != gp_raters[0]):
    vote_gp_frame[inclusion_field] += dfs_concat[rater][inclusion_field]
gp_vote_share = vote_gp_frame[inclusion_field] / len(gp_raters)
vote_gp_frame[inclusion_field] = (gp_vote_share > 0.5).astype(int)
dfs_concat["vote_gp"] = vote_gp_frame

In [None]:
def get_kappas(raters: list[str], dfs_concat: dict[str, pd.DataFrame] = dfs_concat):
    """Pairwise Cohen's kappa between raters on the concatenated data.

    Only the strictly lower triangle is filled: entry ``[i, j]`` with ``i > j``
    is ``cohen_kappa_score(raters[i] labels, raters[j] labels)``. The diagonal
    and upper triangle are 0. Pairs whose frames differ in length are left at 0.

    Args:
        raters (list[str]): raters to get kappas for
        dfs_concat (dict[str, pd.DataFrame], optional): dataframes dict to get kappas from. Defaults to dfs_concat.

    Returns:
        numpy.ndarray: kappas between raters of shape (len(raters), len(raters))
    """
    n_raters = len(raters)
    rater_kappas = np.zeros((n_raters, n_raters))
    # Fill only the lower triangle directly. The previous approach (compute the
    # full matrix, then multiply by a triangular mask) let NaN kappas -- e.g. a
    # constant rater compared with itself on the diagonal -- survive the mask,
    # because NaN * 0 is still NaN.
    for i, rater1 in enumerate(raters):
        for j, rater2 in enumerate(raters[:i]):
            if len(dfs_concat[rater1]) == len(dfs_concat[rater2]):
                rater_kappas[i, j] = cohen_kappa_score(
                    dfs_concat[rater1][inclusion_field],
                    dfs_concat[rater2][inclusion_field],
                )
    return rater_kappas


In [None]:
def get_kappas_article(
    raters: list[str], articles: list[str], dfs: dict[str, dict[str, pd.DataFrame]]
):
    """Pairwise Cohen's kappa between raters, computed separately per article.

    Entry ``[a, i, j]`` with ``i > j`` is the kappa between ``raters[i]`` and
    ``raters[j]`` on ``articles[a]``. The diagonal, the upper triangle, and
    pairs where either rater has no data for the article are 0.

    Args:
        raters (list[str]): raters to get kappas for
        articles (list[str]): articles to get kappas for
        dfs (dict[str, dict[str, pd.DataFrame]]): dataframes dict to get kappas from

    Returns:
          numpy.ndarray: kappas between raters of shape (len(articles), len(raters), len(raters))
    """
    article_rater_kappas = np.zeros((len(articles), len(raters), len(raters)))
    # Only the lower triangle is computed. Masking a full matrix by
    # multiplication would leave NaN kappas (common on small per-article samples
    # where a rater's labels are all identical) in the masked cells, since
    # NaN * 0 is NaN.
    for a, article in enumerate(articles):
        for i, rater1 in enumerate(raters):
            for j, rater2 in enumerate(raters[:i]):
                if article in dfs[rater1] and article in dfs[rater2]:
                    article_rater_kappas[a, i, j] = cohen_kappa_score(
                        dfs[rater1][article][inclusion_field],
                        dfs[rater2][article][inclusion_field],
                    )
    return article_rater_kappas  # shape: (num_articles, num_raters, num_raters)

In [None]:
def plot_kappas(
    raters: list[str],
    dfs_concat: dict[str, pd.DataFrame] = dfs_concat,
    name_dict: dict[str, str] = name_dict,
    figsize: tuple[float, float] = (10, 10),
    save_path: str = None,
):
    """Draw the pairwise kappa matrix from ``get_kappas`` as an annotated heatmap.

    Colours use a fixed 0-1 scale; tick labels are display names from ``name_dict``.

    Args:
        raters (list[str]): raters to plot kappas for
        dfs_concat (dict[str, pd.DataFrame], optional): dataframes dict to get kappas from. Defaults to dfs_concat.
        name_dict (dict[str, str], optional): dict to map raters to names. Defaults to name_dict.
        figsize (tuple[float, float], optional): figure size. Defaults to (10, 10).
        save_path (str, optional): path to save the figure to; if None the figure is shown. Defaults to None.
    """
    kappa_matrix = get_kappas(raters, dfs_concat)
    _, ax = plt.subplots(figsize=figsize)
    display_names = [name_dict[rater] for rater in raters]
    heatmap_style = {
        "annot": True,
        "cmap": "Blues",
        "vmin": 0,
        "vmax": 1,
        "cbar": False,
        "square": True,
    }
    sns.heatmap(
        kappa_matrix,
        xticklabels=display_names,
        yticklabels=display_names,
        ax=ax,
        **heatmap_style,
    )
    ax.set_title("Kappa scores for raters")
    if save_path is None:
        plt.show()
        return
    plt.savefig(save_path, dpi=300)


In [None]:
def plot_kappas_article(
    articles: list[str],
    raters: list[str],
    dfs: dict[str, dict[str, pd.DataFrame]] = dfs,
    name_dict: dict[str, str] = name_dict,
    n_cols: int = 2,
    size: float = 5,
    save_path: str = None,
):
    """Plot one pairwise-kappa heatmap per article, laid out on a grid.

    Args:
        articles (list[str]): the articles to plot
        raters (list[str]): the raters to plot
        dfs (dict[str, dict[str, pd.DataFrame]], optional): dfs containing the data. Defaults to dfs.
        name_dict (dict[str, str], optional): dictionary mapping rater names to the names to be displayed on the plot. Defaults to name_dict.
        n_cols (int, optional): number of columns in the plot. Defaults to 2.
        size (float, optional): size of each plot. Defaults to 5.
        save_path (str, optional): path to save the figure to; if None the figure is shown. Defaults to None.
    """
    kappas_matrix = get_kappas_article(raters, articles, dfs)
    # Ceiling division. The previous `n // c + n % c` only equals ceil(n / c)
    # for c <= 2 and allocated too many rows for wider grids.
    nrows = max(1, (len(articles) + n_cols - 1) // n_cols)
    _, axs = plt.subplots(
        nrows=nrows,
        ncols=n_cols,
        figsize=(size * n_cols, size * nrows),
        # Always return a 2-D array so axs[row, col] also works for one row/column.
        squeeze=False,
    )
    rater_names = [name_dict[rater] for rater in raters]
    for i, article in enumerate(articles):
        ax = axs[i // n_cols, i % n_cols]
        sns.heatmap(
            kappas_matrix[i],
            annot=True,
            xticklabels=rater_names,
            yticklabels=rater_names,
            cmap="Blues",
            # Same fixed 0-1 colour scale as plot_kappas, so panels are comparable.
            vmin=0,
            vmax=1,
            square=True,
            cbar=False,
            ax=ax,
        )
        ax.set_title(f"{article.upper()} Kappa")
    # Hide grid cells left over when len(articles) is not a multiple of n_cols.
    for ax in axs.flat[len(articles):]:
        ax.axis("off")
    if save_path is not None:
        plt.savefig(save_path, dpi=300)
    else:
        plt.show()


In [None]:
def print_kappa_ci(
    raters: list[str],
    name_dict: dict[str, str],
    dfs_concat: dict[str, pd.DataFrame] = dfs_concat,
    n_bootstraps: int = 1000,
    ci_level: float = 0.95,
    seed: int = 42,
):
    """Print Cohen's kappa with a percentile-bootstrap CI for every rater pair.

    Each unordered pair is reported once, in ``raters`` order. Rows are
    resampled with replacement ``n_bootstraps`` times, and the CI bounds are the
    ``(1 - ci_level) / 2`` and ``1 - (1 - ci_level) / 2`` quantiles of the
    resampled kappas.

    Args:
        raters (list[str]): raters to get kappas for
        name_dict (dict[str, str]): dict to map raters to names
        dfs_concat (dict[str, pd.DataFrame], optional): dataframes dict to get kappas from. Defaults to dfs_concat.
        n_bootstraps (int, optional): number of bootstraps to use. Defaults to 1000.
        ci_level (float, optional): confidence interval level. Defaults to 0.95.
        seed (int, optional): seed for the random number generator. Defaults to 42.
    """
    tail = (1 - ci_level) / 2
    for position, first in enumerate(raters):
        for second in raters[position + 1:]:
            print(f"{name_dict[first]} vs {name_dict[second]}")
            first_labels = dfs_concat[first][inclusion_field]
            second_labels = dfs_concat[second][inclusion_field]
            observed_kappa = cohen_kappa_score(first_labels, second_labels)

            # The generator is re-created from `seed` for every pair.
            rng = np.random.default_rng(seed)
            n_rows = len(dfs_concat[first])
            boot_kappas = np.empty(n_bootstraps)
            for b in range(n_bootstraps):
                rows = rng.choice(n_rows, n_rows, replace=True)
                boot_kappas[b] = cohen_kappa_score(
                    first_labels.iloc[rows], second_labels.iloc[rows]
                )

            lower_ci = np.quantile(boot_kappas, tail)
            upper_ci = np.quantile(boot_kappas, 1 - tail)
            print(
                f"Kappa [{ci_level*100:.0f}% CI]: {observed_kappa:.2f} [{lower_ci:.2f}, {upper_ci:.2f}]".rjust(
                    20
                )
            )
            print("\n")

In [None]:
def plot_roc(
    gold_standards: list[str],
    comparators: list[str],
    dfs_concat: dict[str, pd.DataFrame] = dfs_concat,
    name_dict: dict[str, str] = name_dict,
    size: float = 3,
    save_path: str = None,
):
    """Plot ROC curves of each comparator's raw score against each gold standard.

    Rows of the grid are gold standards, columns are comparators. Each panel
    shows the ROC curve (built from the comparator's ``inclusion_original_field``
    score versus the gold standard's binary ``inclusion_field``) with its AUC,
    the chance diagonal, and the point maximising Youden's J (tpr - fpr),
    labelled with its score threshold.

    Args:
        gold_standards (list[str]): gold standards to evaluate against
        comparators (list[str]): comparators to evaluate
        dfs_concat (dict[str, pd.DataFrame], optional): dfs to get data from. Defaults to dfs_concat.
        name_dict (dict[str, str], optional): dict to map names to display names. Defaults to name_dict.
        size (float, optional): size of each plot. Defaults to 3.
        save_path (str, optional): path to save the plot to. Defaults to None.
    """
    _, axs = plt.subplots(
        nrows=len(gold_standards),
        ncols=len(comparators),
        sharex=True,
        sharey=True,
        # Width follows the number of columns (comparators), height the rows
        # (gold standards); these were previously swapped.
        figsize=(len(comparators) * size, len(gold_standards) * size * 1.2),
        # Always 2-D, so axs[row, col] also works with a single row or column.
        squeeze=False,
    )
    for j, gold_standard in enumerate(gold_standards):
        for i, comparator in enumerate(comparators):
            ax = axs[j, i]
            fpr, tpr, thresholds = roc_curve(
                dfs_concat[gold_standard][inclusion_field],
                dfs_concat[comparator][inclusion_original_field],
            )
            roc_auc = auc(fpr, tpr)
            youden = np.argmax(tpr - fpr)
            ax.plot(fpr, tpr, label=f"ROC curve (AUC = {roc_auc:.2f})")
            ax.plot([0, 1], [0, 1], linestyle="--")
            ax.plot(
                fpr[youden],
                tpr[youden],
                marker="o",
                markersize=10,
                label=f"Youden threshold = {thresholds[youden]:.0f}",
            )
            ax.set_xlabel("False Positive Rate")
            ax.set_ylabel("True Positive Rate")
            ax.set_title(f"{name_dict[comparator]} vs {name_dict[gold_standard]}")
            ax.legend(loc="lower right")
    if save_path is not None:
        plt.savefig(save_path, dpi=300)
    else:
        plt.show()

In [None]:
def _bootstrap_roc_ci(y_true, y_score, bootstrap_samples, ci_level, rng, fpr_grid):
    """Percentile-bootstrap bands for a ROC curve and its AUC.

    Rows are resampled with replacement; each resample's TPR is interpolated
    onto ``fpr_grid`` so percentiles can be taken point-wise. Resamples whose
    ``y_true`` contains a single class have an undefined ROC curve (NaN TPR or
    FPR) and are skipped; a random draw is still consumed for each, so results
    are unchanged whenever no resample is degenerate.

    Returns:
        tuple: (tpr_lower, tpr_upper, auc_lower, auc_upper); all NaN if no
        resample was usable.
    """
    n_rows = len(y_true)
    tprs, aucs = [], []
    for _ in range(bootstrap_samples):
        idx = rng.choice(np.arange(n_rows), size=n_rows)
        y_true_bs = y_true.iloc[idx]
        if y_true_bs.nunique() < 2:
            continue
        fpr_bs, tpr_bs, _thresholds = roc_curve(y_true_bs, y_score.iloc[idx])
        aucs.append(auc(fpr_bs, tpr_bs))
        tprs.append(np.interp(fpr_grid, fpr_bs, tpr_bs))

    if not aucs:
        nan_band = np.full(len(fpr_grid), np.nan)
        return nan_band, nan_band, np.nan, np.nan

    lower_pct = ((1.0 - ci_level) / 2) * 100
    upper_pct = (1.0 - ((1.0 - ci_level) / 2)) * 100
    tprs = np.array(tprs)
    aucs = np.array(aucs)
    return (
        np.percentile(tprs, lower_pct, axis=0),
        np.percentile(tprs, upper_pct, axis=0),
        np.percentile(aucs, lower_pct, axis=0),
        np.percentile(aucs, upper_pct, axis=0),
    )


def plot_roc_ci(
    gold_standards: list[str],
    comparators: list[str],
    dfs_concat: dict[str, pd.DataFrame] = dfs_concat,
    name_dict: dict[str, str] = name_dict,
    size: float = 5,
    bootstrap_samples: int = 1000,
    ci_level: float = 0.95,
    save_path: str = None,
    seed: int = 42,
):
    """Plot the ROC curves for the comparators against the gold standards with confidence intervals

    Rows of the grid are gold standards, columns are comparators. Each panel
    shows the ROC curve of the comparator's raw score with its AUC and
    bootstrap CI, a shaded bootstrap CI band for the TPR, the chance diagonal,
    and the Youden-optimal point (argmax of tpr - fpr) labelled with its
    threshold.

    Args:
        gold_standards (list[str]): gold standards to evaluate against
        comparators (list[str]): comparators to evaluate
        dfs_concat (dict[str, pd.DataFrame], optional): dict of dfs to get data from. Defaults to dfs_concat.
        name_dict (dict[str, str], optional): dict to map names to display names. Defaults to name_dict.
        size (float, optional): size of each plot. Defaults to 5.
        bootstrap_samples (int, optional): number of bootstrap samples to use. Defaults to 1000.
        ci_level (float, optional): confidence interval level. Defaults to 0.95.
        save_path (str, optional): path to save the plot to. Defaults to None.
        seed (int, optional): random seed. Defaults to 42.
    """
    _, axs = plt.subplots(
        nrows=len(gold_standards),
        ncols=len(comparators),
        sharex=True,
        sharey=True,
        # Width follows the number of columns (comparators), height the rows
        # (gold standards); these were previously swapped.
        figsize=(len(comparators) * size, len(gold_standards) * size * 1.5),
        # Always 2-D: one drawing path covers 1x1, 1xN, Nx1 and NxM grids.
        squeeze=False,
    )
    fpr_grid = np.linspace(0, 1, 100)
    for j, gold_standard in enumerate(gold_standards):
        for i, comparator in enumerate(comparators):
            ax = axs[j, i]
            y_true = dfs_concat[gold_standard][inclusion_field]
            y_score = dfs_concat[comparator][inclusion_original_field]

            fpr, tpr, thresholds = roc_curve(y_true, y_score)
            roc_auc = auc(fpr, tpr)
            youden = np.argmax(tpr - fpr)

            # Re-seeded per panel, so every panel uses the same resampling stream.
            rng = np.random.default_rng(seed)
            tprs_lower, tprs_upper, aucs_lower, aucs_upper = _bootstrap_roc_ci(
                y_true, y_score, bootstrap_samples, ci_level, rng, fpr_grid
            )

            ax.plot(
                fpr,
                tpr,
                label=f"ROC (AUC = {roc_auc:.2f} [{aucs_lower:.2f},{aucs_upper:.2f}])",
            )
            ax.fill_between(
                fpr_grid,
                tprs_lower,
                tprs_upper,
                color="lightblue",
                alpha=0.3,
                label=f"{ci_level*100:.0f}% CI",
            )
            ax.plot([0, 1], [0, 1], linestyle="--")
            ax.plot(
                fpr[youden],
                tpr[youden],
                marker="o",
                markersize=10,
                label=f"Youden threshold = {thresholds[youden]:.0f}",
            )
            ax.set_xlabel("False Positive Rate")
            ax.set_ylabel("True Positive Rate")
            ax.set_title(f"{name_dict[comparator]} vs {name_dict[gold_standard]}")
            ax.set_xlim([0, 1])
            ax.set_ylim([0, 1])
            # square axes
            ax.set_aspect("equal", adjustable="box")
            ax.legend(loc="lower right")
    if save_path is not None:
        plt.savefig(save_path, dpi=300)
    else:
        plt.show()

In [None]:
def plot_confusion(
    gold_standards: list[str],
    comparators: list[str],
    dfs_concat: dict[str, pd.DataFrame] = dfs_concat,
    name_dict: dict[str, str] = name_dict,
    size: float = 5,
    save_path: str = None,
):
    """Print agreement metrics and plot a row-normalised confusion matrix per pair.

    Rows of the grid are comparators, columns are gold standards. For each pair
    the raw counts and NPV, PPV, sensitivity, specificity, PLR and NLR are
    printed. The confusion matrix (rows = gold standard, columns = comparator)
    is normalised by its row sums, so each row shows the fraction of that
    gold-standard class the comparator assigned to each label.

    Args:
        gold_standards (list[str]): gold standard names to evaluate against
        comparators (list[str]): comparator names to evaluate
        dfs_concat (dict[str, pd.DataFrame], optional): dictionary of dataframes. Defaults to dfs_concat.
        name_dict (dict[str, str], optional): dictionary of names. Defaults to name_dict.
        size (float, optional): size of each plot. Defaults to 5.
        save_path (str, optional): path to save figure. Defaults to None.
    """
    _, axs = plt.subplots(
        len(comparators),
        len(gold_standards),
        sharex=True,
        sharey=True,
        # Width follows the number of columns (gold standards), height the rows
        # (comparators); these were previously swapped.
        figsize=(len(gold_standards) * size, len(comparators) * size),
        # Always 2-D: one drawing path covers every grid shape.
        squeeze=False,
    )
    tick_labels = ["Excluded", "Included"]

    for i, comparator in enumerate(comparators):
        print(name_dict[comparator])
        for j, gold_standard in enumerate(gold_standards):
            print(name_dict[gold_standard])
            # labels=[0, 1] forces a 2x2 matrix even when one class is absent
            # from both columns, so the 4-way unpack below cannot fail.
            cm = confusion_matrix(
                dfs_concat[gold_standard][inclusion_field],
                dfs_concat[comparator][inclusion_field],
                labels=[0, 1],
            )
            tn, fp, fn, tp = cm.ravel()
            npv = tn / (tn + fn)
            ppv = tp / (tp + fp)
            sen = tp / (tp + fn)
            spec = tn / (tn + fp)
            plr = sen / (1 - spec)
            nlr = (1 - sen) / spec
            print(f"tn: {tn}, fp: {fp}, fn: {fn}, tp: {tp}")
            print(
                f"NPV: {npv:.2f}, PPV: {ppv:.2f}, \nSensitivity: {sen:.2f}, Specificity: {spec:.2f}\nPLR: {plr:.2f}, NLR: {nlr:.2f}"
            )
            # normalise each gold-standard row to sum to 1
            cm = cm.astype("float") / cm.sum(axis=1)[:, np.newaxis]
            ax = axs[i, j]
            sns.heatmap(
                cm,
                annot=True,
                xticklabels=tick_labels,
                yticklabels=tick_labels,
                cmap="Blues",
                vmin=0,
                vmax=1,
                square=True,
                cbar=False,
                ax=ax,
            )
            ax.set_xlabel(name_dict[comparator])
            ax.set_ylabel(name_dict[gold_standard])
            ax.set_title(f"{name_dict[comparator]}")
            print("\n")
            # NOTE(review): short pause kept from the original; presumably lets
            # the printed text flush before figure output -- confirm.
            sleep(0.1)
    if save_path is not None:
        plt.savefig(save_path, dpi=300)
    else:
        # show the plot
        plt.show()

In [None]:
def print_metrics_ci(
    gold_standards: list[str],
    comparators: list[str],
    dfs_concat: dict[str, pd.DataFrame] = dfs_concat,
    name_dict: dict[str, str] = name_dict,
    bootstrap_samples: int = 1000,
    ci_level: float = 0.95,
    epsilon: float = 1e-6,
    seed: int = 42,
):
    """Print diagnostic metrics with percentile-bootstrap confidence intervals.

    For every comparator and every gold standard, the sensitivity, specificity,
    PPV, NPV, PLR and NLR of the comparator's ``inclusion_field`` column against
    the gold standard's ``inclusion_field`` column are printed as
    ``point (lower-upper)``. Point estimates use all rows. The interval bounds
    are percentiles of the same metric over ``bootstrap_samples`` row resamples
    drawn with replacement. Rows of the two frames are paired by position, so
    both frames are assumed to list the same records in the same order.

    Args:
        gold_standards (list[str]): gold standard names to evaluate against
        comparators (list[str]): comparator names to evaluate
        dfs_concat (dict[str, pd.DataFrame], optional): dictionary of dataframes. Defaults to dfs_concat.
        name_dict (dict[str, str], optional): dictionary of display names. Defaults to name_dict.
        bootstrap_samples (int, optional): number of bootstrap resamples. Defaults to 1000.
        ci_level (float, optional): confidence interval level. Defaults to 0.95.
        epsilon (float, optional): epsilon to avoid division by zero, passed to every
            function in ``metric_dict``. Defaults to 1e-6.
        seed (int, optional): random seed. The resampling RNG is re-created from it for
            every (comparator, gold standard) pair. Defaults to 42.
    """
    # (printed label, metric_dict key), in print order
    report = [
        ("Sensitivity", "sen"),
        ("Specificity", "spec"),
        ("PPV", "ppv"),
        ("NPV", "npv"),
        ("PLR", "plr"),
        ("NLR", "nlr"),
    ]
    metric_keys = [key for _label, key in report]
    lower_pct = ((1.0 - ci_level) / 2) * 100
    upper_pct = (1.0 - ((1.0 - ci_level) / 2)) * 100

    for comparator in comparators:
        print(name_dict[comparator])
        print(f"{bootstrap_samples} bootstrapped samples, {ci_level*100:.0f}% CI")
        for gold_standard in gold_standards:
            print(name_dict[gold_standard])
            y_true = dfs_concat[gold_standard][inclusion_field].to_numpy()
            y_pred = dfs_concat[comparator][inclusion_field].to_numpy()
            # Pin the label set on the full data. A resample that happens to
            # miss one class then still yields a 2x2 matrix. Without it,
            # confusion_matrix returns 1x1 and the 4-way unpacking fails.
            labels = np.union1d(y_true, y_pred)

            # point estimates on all rows
            tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=labels).ravel()
            point = {key: metric_dict[key](tn, fp, fn, tp, epsilon) for key in metric_keys}

            # bootstrap distribution of every metric
            rng = np.random.default_rng(seed=seed)
            n_rows = len(dfs_concat[gold_standard])
            boot = {key: [] for key in metric_keys}
            for _ in range(bootstrap_samples):
                idx = rng.choice(n_rows, size=n_rows, replace=True)
                b_tn, b_fp, b_fn, b_tp = confusion_matrix(
                    y_true[idx], y_pred[idx], labels=labels
                ).ravel()
                for key in metric_keys:
                    boot[key].append(metric_dict[key](b_tn, b_fp, b_fn, b_tp, epsilon))

            # print metrics
            for label, key in report:
                lower, upper = np.percentile(boot[key], [lower_pct, upper_pct])
                print(f"{label}: {point[key]:.2f} ({lower:.2f}-{upper:.2f})".rjust(40))
            print("\n")

In [None]:
def compare_metrics(
    gold_standards: list[str],
    comparators: list[str],
    dfs_concat: dict[str, pd.DataFrame] = dfs_concat,
    name_dict: dict[str, str] = name_dict,
    metrics: list[str] = ["sen", "spec", "ppv", "npv", "plr", "nlr"],
    bootstrap_samples: int = 1000,
    ci_level: float = 0.95,
    plot_histograms: bool = False,
    figsize: tuple[float, float] = (8, 5),
    epsilon: float = 1e-6,
    seed: int = 42,
):
    """Compare metrics between every pair of comparators against each gold standard.

    For each gold standard and each unordered comparator pair, every metric in
    ``metrics`` is computed for both comparators. The printed Δ is
    ``comparator1 - comparator2``, shown with a percentile-bootstrap interval
    and a bootstrap tail proportion reported as the p-value. Rows of the
    frames are paired by position, so all frames are assumed to list the same
    records in the same order.

    Args:
        gold_standards (list[str]): gold standard names to compare against
        comparators (list[str]): comparator names to compare against each other
        dfs_concat (dict[str, pd.DataFrame], optional): dict of dfs. Defaults to dfs_concat.
        name_dict (dict[str, str], optional): dict of names. Defaults to name_dict.
        metrics (list[str], optional): keys of ``metric_dict`` to compare. Defaults to ["sen", "spec", "ppv", "npv", "plr", "nlr"].
        bootstrap_samples (int, optional): bootstrap samples. Defaults to 1000.
        ci_level (float, optional): confidence interval level. Defaults to 0.95.
        plot_histograms (bool, optional): whether to plot histograms. Defaults to False.
        figsize (tuple[float, float], optional): figure size. Defaults to (8, 5).
        epsilon (float, optional): epsilon to avoid division by zero. Defaults to 1e-6.
        seed (int, optional): random seed. Every pair is resampled from a fresh RNG with this seed. Defaults to 42.
    """
    from itertools import combinations

    lower_pct = ((1.0 - ci_level) / 2) * 100
    upper_pct = (1.0 - ((1.0 - ci_level) / 2)) * 100

    for gold_standard in gold_standards:
        print(f"Gold Standard: {name_dict[gold_standard]}")
        # pairs (c_i, c_j) with i < j, in list order
        for comparator1, comparator2 in combinations(comparators, 2):
            print(
                f"\tComparing {name_dict[comparator1]} to {name_dict[comparator2]}:"
            )
            y_true = dfs_concat[gold_standard][inclusion_field].to_numpy()
            y_pred1 = dfs_concat[comparator1][inclusion_field].to_numpy()
            y_pred2 = dfs_concat[comparator2][inclusion_field].to_numpy()
            n_rows = len(dfs_concat[gold_standard])
            # Pin label sets on the full data so that a resample missing one
            # class still yields a 2x2 matrix (4 values to unpack).
            labels1 = np.union1d(y_true, y_pred1)
            labels2 = np.union1d(y_true, y_pred2)
            observed1 = confusion_matrix(y_true, y_pred1, labels=labels1).ravel()
            observed2 = confusion_matrix(y_true, y_pred2, labels=labels2).ravel()

            # Every metric is evaluated on the same resamples (same seed),
            # so the confusion counts are drawn once per pair and reused.
            rng = np.random.default_rng(seed=seed)
            boot1, boot2 = [], []
            for _ in range(bootstrap_samples):
                idx = rng.choice(n_rows, size=n_rows, replace=True)
                boot1.append(
                    confusion_matrix(y_true[idx], y_pred1[idx], labels=labels1).ravel()
                )
                boot2.append(
                    confusion_matrix(y_true[idx], y_pred2[idx], labels=labels2).ravel()
                )

            for metric in metrics:
                metric_fn = metric_dict[metric]
                # observed difference on all rows
                observed_metric1 = metric_fn(*observed1, epsilon)
                observed_metric2 = metric_fn(*observed2, epsilon)
                observed_diff = observed_metric1 - observed_metric2

                # bootstrapped differences
                bootstrapped_diffs = np.array(
                    [
                        metric_fn(*counts1, epsilon) - metric_fn(*counts2, epsilon)
                        for counts1, counts2 in zip(boot1, boot2)
                    ]
                )

                # p-value: the smaller tail proportion of the bootstrapped
                # differences. Draws exactly equal to 0 belong to both tails.
                # Counting them in neither (strict > / <) reported p = 0 when
                # the two comparators never differed.
                # NOTE(review): this is a one-sided tail proportion. A
                # two-sided p-value consistent with the two-sided CI would be
                # min(1, 2 * p_value). Confirm which the analysis intends.
                p_value = min(
                    np.mean(bootstrapped_diffs >= 0),
                    np.mean(bootstrapped_diffs <= 0),
                )

                # percentile confidence interval of the difference
                lower_diff, upper_diff = np.percentile(
                    bootstrapped_diffs, [lower_pct, upper_pct]
                )

                # Optional histogram of the bootstrapped differences. It shows
                # an orange line at the observed difference and a green band
                # for the CI. A red band covers the side of 0 opposite the
                # observed difference.
                if plot_histograms:
                    plt.figure(figsize=figsize)
                    sns.histplot(
                        bootstrapped_diffs,
                        kde=True,
                        stat="density",
                        color="grey",
                        label="bootstrapped differences",
                        alpha=0.25,
                    )
                    plt.axvline(
                        x=observed_diff, color="orange", label="observed difference"
                    )
                    plt.axvspan(
                        xmin=lower_diff,
                        xmax=upper_diff,
                        color="green",
                        alpha=0.2,
                        label=f"{ci_level*100:.0f}% CI",
                    )
                    if observed_diff >= 0:  # highlight values below 0
                        plt.axvspan(
                            xmin=-1000,
                            xmax=0,
                            color="red",
                            alpha=0.3,
                            label="H0 true",
                        )
                    else:  # highlight values above 0
                        plt.axvspan(
                            xmin=0,
                            xmax=1000,
                            color="red",
                            alpha=0.3,
                            label="H0 true",
                        )
                    x_lo = bootstrapped_diffs.min() - 0.25 * np.abs(
                        bootstrapped_diffs.min()
                    )
                    x_hi = bootstrapped_diffs.max() + 0.25 * np.abs(
                        bootstrapped_diffs.max()
                    )
                    if x_lo == x_hi:
                        # all bootstrapped differences are 0: avoid a zero-width axis
                        x_lo, x_hi = -1.0, 1.0
                    plt.xlim(x_lo, x_hi)
                    plt.xlabel(
                        f"\u0394 {metric_name_dict[metric]} ({name_dict[comparator1]} - {name_dict[comparator2]})"
                    )
                    # stat="density" above, so the y axis is a density, not a count
                    plt.ylabel("Density")
                    plt.legend()
                    plt.show()

                print(f"\t\t{metric_name_dict[metric]}:")
                print(f"\t\t{name_dict[comparator1]}: {observed_metric1:.2f}")
                print(f"\t\t{name_dict[comparator2]}: {observed_metric2:.2f}")
                # greek letter delta = \u0394
                print(
                    f"\t\t\u0394: {observed_diff:.2f} [{lower_diff:.2f},{upper_diff:.2f}]"
                )
                print(f"\t\tp-value : {p_value:.3f}")
                print("\n")

In [None]:
def compare_metrics_article(
    gold_standards: list[str],
    comparators: list[str],
    dfs_concat: dict[str, pd.DataFrame] = dfs_concat,
    name_dict: dict[str, str] = name_dict,
    articles: list[str] = articles,
    metrics: list[str] = ["sen", "spec", "ppv", "npv", "plr", "nlr"],
    bootstrap_samples: int = 1000,
    ci_level: float = 0.95,
    plot_histograms: bool = False,
    figsize: tuple[float, float] = (8, 5),
    epsilon: float = 1e-6,
    seed: int = 42,
):
    """Compare metrics between every pair of comparators, separately for each article.

    Rows are restricted to ``df["article"] == article`` for the gold standard
    and both comparators. The filtered subsets are then paired by position, so
    all frames are assumed to list the same records in the same order. For
    every metric, the printed Δ is ``comparator1 - comparator2``, shown with a
    percentile-bootstrap interval and a bootstrap tail proportion reported as
    the p-value.

    Args:
        gold_standards (list[str]): gold standard names to compare against
        comparators (list[str]): comparator names to compare to each other
        dfs_concat (dict[str, pd.DataFrame], optional): dictionary of dataframes. Defaults to dfs_concat.
        name_dict (dict[str, str], optional): dictionary of names. Defaults to name_dict.
        articles (list[str], optional): list of articles to compare. Defaults to articles.
        metrics (list[str], optional): keys of ``metric_dict`` to compare. Defaults to ["sen", "spec", "ppv", "npv", "plr", "nlr"].
        bootstrap_samples (int, optional): bootstrap samples. Defaults to 1000.
        ci_level (float, optional): confidence interval level. Defaults to 0.95.
        plot_histograms (bool, optional): whether to plot histograms. Defaults to False.
        figsize (tuple[float, float], optional): figure size. Defaults to (8, 5).
        epsilon (float, optional): epsilon to avoid division by zero. Defaults to 1e-6.
        seed (int, optional): random seed. Every (pair, article) is resampled from a fresh RNG with this seed. Defaults to 42.
    """
    from itertools import combinations

    lower_pct = ((1.0 - ci_level) / 2) * 100
    upper_pct = (1.0 - ((1.0 - ci_level) / 2)) * 100

    for gold_standard in gold_standards:
        print(f"Gold Standard: {name_dict[gold_standard]}")
        # pairs (c_i, c_j) with i < j, in list order
        for comparator1, comparator2 in combinations(comparators, 2):
            print(
                f"\tComparing {name_dict[comparator1]} to {name_dict[comparator2]}:"
            )
            gold_df = dfs_concat[gold_standard]
            comp1_df = dfs_concat[comparator1]
            comp2_df = dfs_concat[comparator2]
            # Label sets come from the full columns. Per-article matrices, and
            # their resamples, then stay 2x2 even when an article or a
            # resample lacks one class.
            labels1 = np.union1d(gold_df[inclusion_field], comp1_df[inclusion_field])
            labels2 = np.union1d(gold_df[inclusion_field], comp2_df[inclusion_field])
            for article in articles:
                print(f"\t\tArticle: {article.upper()}")
                # filter once per article instead of inside every bootstrap iteration
                y_true = gold_df.loc[gold_df["article"] == article][
                    inclusion_field
                ].to_numpy()
                y_pred1 = comp1_df.loc[comp1_df["article"] == article][
                    inclusion_field
                ].to_numpy()
                y_pred2 = comp2_df.loc[comp2_df["article"] == article][
                    inclusion_field
                ].to_numpy()
                n_rows = len(y_true)
                observed1 = confusion_matrix(y_true, y_pred1, labels=labels1).ravel()
                observed2 = confusion_matrix(y_true, y_pred2, labels=labels2).ravel()

                # Every metric is evaluated on the same resamples (same seed),
                # so the confusion counts are drawn once per article and reused.
                rng = np.random.default_rng(seed=seed)
                boot1, boot2 = [], []
                for _ in range(bootstrap_samples):
                    idx = rng.choice(n_rows, size=n_rows, replace=True)
                    boot1.append(
                        confusion_matrix(
                            y_true[idx], y_pred1[idx], labels=labels1
                        ).ravel()
                    )
                    boot2.append(
                        confusion_matrix(
                            y_true[idx], y_pred2[idx], labels=labels2
                        ).ravel()
                    )

                for metric in metrics:
                    metric_fn = metric_dict[metric]
                    # observed difference on all rows of this article
                    observed_metric1 = metric_fn(*observed1, epsilon)
                    observed_metric2 = metric_fn(*observed2, epsilon)
                    observed_diff = observed_metric1 - observed_metric2

                    # bootstrapped differences
                    bootstrapped_diffs = np.array(
                        [
                            metric_fn(*counts1, epsilon) - metric_fn(*counts2, epsilon)
                            for counts1, counts2 in zip(boot1, boot2)
                        ]
                    )

                    # p-value: the smaller tail proportion of the bootstrapped
                    # differences. Draws exactly equal to 0 belong to both
                    # tails. Counting them in neither (strict > / <) reported
                    # p = 0 when the two comparators never differed.
                    # NOTE(review): this is a one-sided tail proportion. A
                    # two-sided p-value consistent with the two-sided CI would
                    # be min(1, 2 * p_value). Confirm which the analysis intends.
                    p_value = min(
                        np.mean(bootstrapped_diffs >= 0),
                        np.mean(bootstrapped_diffs <= 0),
                    )

                    # percentile confidence interval of the difference
                    lower_diff, upper_diff = np.percentile(
                        bootstrapped_diffs, [lower_pct, upper_pct]
                    )

                    # Optional histogram of the bootstrapped differences. It
                    # shows an orange line at the observed difference and a
                    # green band for the CI. A red band covers the side of 0
                    # opposite the observed difference.
                    if plot_histograms:
                        plt.figure(figsize=figsize)
                        sns.histplot(
                            bootstrapped_diffs,
                            kde=True,
                            stat="density",
                            color="grey",
                            label="bootstrapped differences",
                            alpha=0.25,
                        )
                        plt.axvline(
                            x=observed_diff,
                            color="orange",
                            label="observed difference",
                        )
                        plt.axvspan(
                            xmin=lower_diff,
                            xmax=upper_diff,
                            color="green",
                            alpha=0.2,
                            label=f"{ci_level*100:.0f}% CI",
                        )
                        if observed_diff >= 0:  # highlight values below 0
                            plt.axvspan(
                                xmin=-1000,
                                xmax=0,
                                color="red",
                                alpha=0.3,
                                label="H0 true",
                            )
                        else:  # highlight values above 0
                            plt.axvspan(
                                xmin=0,
                                xmax=1000,
                                color="red",
                                alpha=0.3,
                                label="H0 true",
                            )
                        x_lo = bootstrapped_diffs.min() - 0.25 * np.abs(
                            bootstrapped_diffs.min()
                        )
                        x_hi = bootstrapped_diffs.max() + 0.25 * np.abs(
                            bootstrapped_diffs.max()
                        )
                        if x_lo == x_hi:
                            # all bootstrapped differences are 0: avoid a zero-width axis
                            x_lo, x_hi = -1.0, 1.0
                        plt.xlim(x_lo, x_hi)
                        plt.xlabel(
                            f"\u0394 {metric_name_dict[metric]} ({name_dict[comparator1]} - {name_dict[comparator2]}) in {article.upper()}"
                        )
                        # stat="density" above, so the y axis is a density, not a count
                        plt.ylabel("Density")
                        plt.legend()
                        plt.show()

                    print(f"\t\t{metric_name_dict[metric]}:")
                    print(f"\t\t{name_dict[comparator1]}: {observed_metric1:.2f}")
                    print(f"\t\t{name_dict[comparator2]}: {observed_metric2:.2f}")
                    # greek letter delta = \u0394
                    print(
                        f"\t\t\u0394: {observed_diff:.2f} [{lower_diff:.2f},{upper_diff:.2f}]"
                    )
                    print(f"\t\tp-value: {p_value:.3f}")
                    print("\n")

In [None]:
def plot_confusion_article(
    gold_standards: list[str],
    comparators: list[str],
    articles: list[str] = articles,
    name_dict: dict[str, str] = name_dict,
    dfs_concat: dict[str, pd.DataFrame] = dfs_concat,
    size: float = 5,
    save_path: str = None,
):
    """Plot row-normalised confusion matrices, one figure per article.

    Each figure is a grid with one row per comparator and one column per gold
    standard. For every grid cell, the raw counts and the NPV, PPV,
    sensitivity, specificity, PLR and NLR are printed. The confusion matrix,
    normalised by its gold-standard (row) totals, is then drawn as a heatmap.

    Args:
        gold_standards (list[str]): gold standard names to compare against (grid columns)
        comparators (list[str]): comparator names (grid rows)
        articles (list[str], optional): articles to plot. Defaults to articles.
        name_dict (dict[str, str], optional): dictionary of names. Defaults to name_dict.
        dfs_concat (dict[str, pd.DataFrame], optional): dictionary of dataframes. Defaults to dfs_concat.
        size (float, optional): size of one grid cell. Defaults to 5.
        save_path (str, optional): path to save each figure to instead of showing it.
            With more than one article, ``_<article>`` is inserted before the file
            extension so every article gets its own file. Defaults to None.
    """
    from pathlib import Path

    articles = list(articles)
    tick_labels = ["Excluded", "Included"]
    for article in articles:
        print(f"Article: {article.upper()}")
        fig, axs = plt.subplots(
            len(comparators),
            len(gold_standards),
            sharex=True,
            sharey=True,
            # figsize is (width, height): width follows the number of columns
            # (gold standards), height the number of rows (comparators)
            figsize=(len(gold_standards) * size, len(comparators) * size),
            # always a 2-D array of axes, so axs[i, j] works for any grid shape
            squeeze=False,
        )
        for i, comparator in enumerate(comparators):
            print(f"{name_dict[comparator]}")
            for j, gold_standard in enumerate(gold_standards):
                print(f"\t{name_dict[gold_standard]}")
                gold_df = dfs_concat[gold_standard]
                comp_df = dfs_concat[comparator]
                # Labels from the full columns keep the matrix 2x2 even when
                # this article lacks one class.
                labels = np.union1d(gold_df[inclusion_field], comp_df[inclusion_field])
                cm = confusion_matrix(
                    gold_df.loc[gold_df["article"] == article][inclusion_field],
                    comp_df.loc[comp_df["article"] == article][inclusion_field],
                    labels=labels,
                )
                tn, fp, fn, tp = cm.ravel()
                npv = tn / (tn + fn)
                ppv = tp / (tp + fp)
                sen = tp / (tp + fn)
                spec = tn / (tn + fp)
                plr = sen / (1 - spec)
                nlr = (1 - sen) / spec
                print(f"tn: {tn}, fp: {fp}, fn: {fn}, tp: {tp}")
                print(
                    f"NPV: {npv:.2f}, PPV: {ppv:.2f}, \nSensitivity: {sen:.2f}, Specificity: {spec:.2f}\nPLR: {plr:.2f}, NLR: {nlr:.2f}"
                )
                # normalize each row (gold-standard class) of the confusion matrix
                cm = cm.astype("float") / cm.sum(axis=1)[:, np.newaxis]
                ax = axs[i, j]
                sns.heatmap(
                    cm,
                    annot=True,
                    xticklabels=tick_labels,
                    yticklabels=tick_labels,
                    cmap="Blues",
                    vmin=0,
                    vmax=1,
                    cbar=False,
                    square=True,
                    ax=ax,
                )
                ax.set_xlabel(f"{name_dict[comparator]}")
                ax.set_ylabel(name_dict[gold_standard])
                ax.set_title(
                    f"{article.upper()}: {name_dict[comparator]}",
                    fontsize=7.5,
                )
                sleep(0.1)
        if save_path is not None:
            out_path = Path(save_path)
            if len(articles) > 1:
                # a single shared path would be overwritten by every article,
                # keeping only the last figure
                out_path = out_path.with_name(
                    f"{out_path.stem}_{article}{out_path.suffix}"
                )
            fig.savefig(out_path, dpi=300)
        else:
            plt.show()


In [None]:
def print_c_index(
    gold_standards: list[str],
    comparators: list[str],
    articles: list[str] = articles,
    name_dict: dict[str, str] = name_dict,
    dfs_concat: dict[str, pd.DataFrame] = dfs_concat,
):
    """Print the concordance index of each comparator against each gold standard.

    ``lifelines.utils.concordance_index`` is called with the gold standard's
    ``inclusion_field`` column as the first argument and the comparator's
    ``inclusion_original_field`` column as the second. It is evaluated once per
    article (rows where ``df["article"] == article``) and once more over all
    rows ("Total").
    """
    print("C-index ")
    for comparator in comparators:
        print(f"\n{name_dict[comparator]}")
        for gold_standard in gold_standards:
            print(f"\t{name_dict[gold_standard]}")
            gold_df = dfs_concat[gold_standard]
            comp_df = dfs_concat[comparator]
            for article in articles:
                truth = gold_df.loc[gold_df["article"] == article][inclusion_field]
                scores = comp_df.loc[comp_df["article"] == article][
                    inclusion_original_field
                ]
                per_article = concordance_index(truth, scores)
                print(f"{article.upper()}: {per_article:.2f}".rjust(20))
            # c-index over every row, regardless of article
            overall = concordance_index(
                gold_df[inclusion_field], comp_df[inclusion_original_field]
            )
            print(f"Total: {overall:.2f}".rjust(20), end="\n\n")

In [None]:
def print_jaccard(
    gold_standards: list[str],
    comparators: list[str],
    articles: list[str] = articles,
    name_dict: dict[str, str] = name_dict,
    dfs_concat: dict[str, pd.DataFrame] = dfs_concat,
    ci_level: float = 0.95,
    n_bootstraps: int = 1000,
    seed: int = 42,
):
    """Print the Jaccard index of each comparator against each gold standard,
    per article and pooled, with percentile-bootstrap confidence intervals.

    Both raters are scored on the module-level ``inclusion_field`` column. Each
    bootstrap sample draws row positions with replacement (sized by the gold
    standard's rows) and applies the same positions to both raters.

    Args:
        gold_standards: keys of ``dfs_concat`` used as reference ratings.
        comparators: keys of ``dfs_concat`` whose ratings are evaluated.
        articles: values of the ``"article"`` column to report on individually.
        name_dict: maps rater keys to display names.
        dfs_concat: maps rater keys to their concatenated rating frames.
        ci_level: confidence level of the interval (0.95 -> 2.5%/97.5% quantiles).
        n_bootstraps: number of bootstrap resamples per printed line.
        seed: seed for ``np.random.default_rng``; reapplied for every printed
            line so each interval is reproducible on its own.
    """

    def _jaccard_with_ci(
        y_true: pd.Series, y_pred: pd.Series
    ) -> tuple[float, float, float]:
        """Return (observed Jaccard, lower CI bound, upper CI bound)."""
        # Fresh generator per call: matches reseeding once per article / total.
        rng = np.random.default_rng(seed)
        n_rows = len(y_true)
        bs_scores = np.empty(n_bootstraps)
        for b in range(n_bootstraps):
            # choice(n) samples like choice(np.arange(n)), i.e. row positions.
            bs_idx = rng.choice(n_rows, size=n_rows, replace=True)
            bs_scores[b] = jaccard_score(y_true.iloc[bs_idx], y_pred.iloc[bs_idx])
        lower = np.quantile(bs_scores, (1 - ci_level) / 2)
        upper = np.quantile(bs_scores, (1 + ci_level) / 2)
        return jaccard_score(y_true, y_pred), lower, upper

    print("Jaccard index")
    for comparator in comparators:
        print(f"\n{name_dict[comparator]}")
        for gold_standard in gold_standards:
            print(f"\t{name_dict[gold_standard]}")
            gold_df = dfs_concat[gold_standard]
            comp_df = dfs_concat[comparator]
            for article in articles:
                # Filter once per article rather than inside every bootstrap draw.
                observed_jaccard, lower_bound, upper_bound = _jaccard_with_ci(
                    gold_df.loc[gold_df["article"] == article, inclusion_field],
                    comp_df.loc[comp_df["article"] == article, inclusion_field],
                )
                print(
                    f"{article.upper()}: {observed_jaccard:.2f} [{lower_bound:.2f}, {upper_bound:.2f}]".rjust(
                        20
                    )
                )
            # pooled jaccard index over all rows
            observed_jaccard, lower_bound, upper_bound = _jaccard_with_ci(
                gold_df[inclusion_field], comp_df[inclusion_field]
            )
            print(
                f"Total: {observed_jaccard:.2f} [{lower_bound:.2f}, {upper_bound:.2f}]".rjust(
                    20
                ),
                end="\n\n",
            )

In [None]:
def print_balanced_accuracy(
    gold_standards: list[str],
    comparators: list[str],
    articles: list[str] = articles,
    name_dict: dict[str, str] = name_dict,
    dfs_concat: dict[str, pd.DataFrame] = dfs_concat,
):
    """Print the balanced accuracy of each comparator against each gold standard.

    One line is printed per article, then a pooled "Total" line over all rows.
    Both raters are scored on the module-level ``inclusion_field`` column.
    Assumes both frames list the same records in the same order. TODO confirm.

    Args:
        gold_standards: keys of ``dfs_concat`` used as reference ratings.
        comparators: keys of ``dfs_concat`` whose ratings are evaluated.
        articles: values of the ``"article"`` column to report on individually.
        name_dict: maps rater keys to display names.
        dfs_concat: maps rater keys to their concatenated rating frames.
    """
    print("Balanced Accuracy Score")
    for comp_key in comparators:
        print(f"\n{name_dict[comp_key]}")
        for gold_key in gold_standards:
            print(f"\t{name_dict[gold_key]}")
            gold_df = dfs_concat[gold_key]
            comp_df = dfs_concat[comp_key]
            for art in articles:
                art_score = balanced_accuracy_score(
                    gold_df.loc[gold_df["article"] == art][inclusion_field],
                    comp_df.loc[comp_df["article"] == art][inclusion_field],
                )
                print(f"{art.upper()}: {art_score:.2f}".rjust(20))
            # pooled balanced accuracy over every row
            pooled_score = balanced_accuracy_score(
                gold_df[inclusion_field], comp_df[inclusion_field]
            )
            print(f"Total: {pooled_score:.2f}".rjust(20), end="\n\n")

In [None]:
def print_balanced_accuracy_ci(
    gold_standards: list[str],
    comparators: list[str],
    articles: list[str] = articles,
    name_dict: dict[str, str] = name_dict,
    dfs_concat: dict[str, pd.DataFrame] = dfs_concat,
    ci_level: float = 0.95,
    n_bootstraps: int = 1000,
    seed: int = 42,
):
    """prints the balanced accuracy score for each article with confidence intervals

    Percentile-bootstrap intervals are used. Each resample draws row positions with
    replacement (sized by the gold standard's rows) and applies the same positions
    to both raters' module-level ``inclusion_field`` column.

    Args:
        gold_standards (list[str]): gold standard names
        comparators (list[str]): comparator names
        articles (list[str], optional): article names. Defaults to articles.
        name_dict (dict[str, str], optional): name dictionary. Defaults to name_dict.
        dfs_concat (dict[str, pd.DataFrame], optional): dictionary of dataframes. Defaults to dfs_concat.
        ci_level (float, optional): confidence interval level. Defaults to 0.95.
        n_bootstraps (int, optional): number of bootstraps. Defaults to 1000.
        seed (int, optional): random seed, reapplied for every printed line so
            each interval is reproducible on its own. Defaults to 42.
    """

    def _balanced_accuracy_with_ci(
        y_true: pd.Series, y_pred: pd.Series
    ) -> tuple[float, float, float]:
        """Return (observed balanced accuracy, lower CI bound, upper CI bound)."""
        # Fresh generator per call: matches reseeding once per article / total.
        rng = np.random.default_rng(seed)
        n_rows = len(y_true)
        bs_scores = np.empty(n_bootstraps)
        for b in range(n_bootstraps):
            # choice(n) samples like choice(np.arange(n)), i.e. row positions.
            bs_idx = rng.choice(n_rows, size=n_rows, replace=True)
            bs_scores[b] = balanced_accuracy_score(
                y_true.iloc[bs_idx], y_pred.iloc[bs_idx]
            )
        lower = np.quantile(bs_scores, (1 - ci_level) / 2)
        upper = np.quantile(bs_scores, (1 + ci_level) / 2)
        return balanced_accuracy_score(y_true, y_pred), lower, upper

    print("Balanced Accuracy Score")
    for comparator in comparators:
        print(f"\n{name_dict[comparator]}")
        for gold_standard in gold_standards:
            print(f"\t{name_dict[gold_standard]}")
            gold_df = dfs_concat[gold_standard]
            comp_df = dfs_concat[comparator]
            for article in articles:
                # Filter once per article rather than inside every bootstrap draw.
                observed_accuracy_score, lower_bound, upper_bound = (
                    _balanced_accuracy_with_ci(
                        gold_df.loc[gold_df["article"] == article, inclusion_field],
                        comp_df.loc[comp_df["article"] == article, inclusion_field],
                    )
                )
                print(
                    f"{article.upper()}: {observed_accuracy_score:.2f} [{lower_bound:.2f}, {upper_bound:.2f}]".rjust(
                        20
                    )
                )
            # print the total accuracy score
            observed_accuracy_score, lower_bound, upper_bound = (
                _balanced_accuracy_with_ci(
                    gold_df[inclusion_field], comp_df[inclusion_field]
                )
            )
            print(
                f"Total: {observed_accuracy_score:.2f} [{lower_bound:.2f}, {upper_bound:.2f}]".rjust(
                    20
                ),
                end="\n\n",
            )

In [None]:
# Kappa statistics with confidence intervals for the human raters plus the GPT
# raters. print_kappa_ci is defined earlier in the notebook; presumably it uses
# Cohen's kappa (cohen_kappa_score is imported at the top). Confirm there.
print_kappa_ci(human_raters + gpt_raters, name_dict)


In [None]:
# Per-article kappa plot for the same raters. size=7 is forwarded to the
# plotting helper, presumably as a figure size. TODO confirm in its definition.
plot_kappas_article(articles, human_raters + gpt_raters, size=7)


In [None]:
# Kappa plot over all `raters`; figsize presumably forwarded to matplotlib.
plot_kappas(raters, figsize=(6.5,6.5))

In [None]:
# compare gpt-scores to consensus by experts: ROC curve of the "gpt" rater
# against each expert consensus type (ci_level sets the interval level).
plot_roc_ci(
    gold_standards=consensus_types,
    comparators=["gpt"],
    size=12,
    ci_level=0.95,
)

In [None]:
# compare GPT raters, individual GP raters and GP consensus ratings to each
# expert consensus type
plot_confusion(
    consensus_types,
    gpt_raters+gp_raters+consensus_types_gp,
    size=8,
)

In [None]:
# Metrics with bootstrap CIs (1000 resamples) of GPT, the GP raters and the GP
# consensus ratings against the expert "vote" consensus.
print_metrics_ci(
    ["vote"],
    ["gpt"] + gp_raters + consensus_types_gp,
    name_dict,
    dfs_concat,
    ci_level=0.95,
    bootstrap_samples=1000,
)


In [None]:
# Compare GPT with the two GP consensus ratings on the listed metrics.
# The abbreviations presumably mean sensitivity, specificity, PPV and NPV;
# confirm in compare_metrics.
compare_metrics(
    ["vote"],
    ["gpt", "vote_gp", "sens_con_gp"],
    metrics=["sen", "spec", "ppv", "npv"],
    name_dict=name_dict,
    bootstrap_samples=1000,
)

In [None]:
# Same comparison restricted to the colorectal article, for GPT and the
# sens_con_gp consensus only ("sen"/"spec" presumably sensitivity/specificity).
compare_metrics_article(
    ["vote"],
    ["gpt", "sens_con_gp"],
    name_dict,
    dfs_concat,
    ["colorectal"],
    ["sen", "spec"],
)


In [None]:
# compare GPT raters, GP raters and GP consensus ratings to each expert
# consensus type - article level (one call per article)
for article in articles:
    plot_confusion_article( 
        gold_standards=consensus_types,
        comparators=gpt_raters+gp_raters+consensus_types_gp,
        articles=[article],
        size=8,
    )

In [None]:
# Pass by keyword so each value binds to the right parameter. print_c_index takes
# (gold_standards, comparators, articles, name_dict, dfs_concat); a positional call
# that repeats name_dict supplies six values and raises TypeError.
print_c_index(
    gold_standards=["vote", "sens_con"],
    comparators=gpt_raters,
    articles=articles,
    name_dict=name_dict,
    dfs_concat=dfs_concat,
)

In [None]:
print_jaccard(
    gold_standards=["vote", "sens_con"],
    comparators=gp_raters + gpt_raters,
)

In [None]:
print_balanced_accuracy(
    ["vote", "sens_con"],
    gp_raters + gpt_raters,
)


In [None]:
print_balanced_accuracy_ci(["vote"], gp_raters + consensus_types_gp)


In [None]:
print_jaccard(["vote"], gp_raters + consensus_types_gp)


In [None]:
# Inclusion rate of each human rater over all their rows; accumulates the
# running sum of inclusions in `total_inclusions`.
total_inclusions = 0
for rater in human_raters:
    rater_df = dfs_concat[rater]
    n_included = rater_df[inclusion_field].sum()
    n_rows = len(rater_df)
    print(
        f"{name_dict[rater]} included {n_included} articles out of {n_rows} ({n_included/n_rows:.2%})"
    )
    total_inclusions += n_included

In [None]:
# Per-article inclusion counts and rates for every rater.
for rater in raters:
    print(f"{name_dict[rater]}")
    rater_df = dfs_concat[rater]
    for article in articles:
        # Select this article's rows once instead of re-filtering the full frame
        # for every number in the message (previously four times per line).
        article_rows = rater_df.loc[rater_df["article"] == article]
        n_included = article_rows[inclusion_field].sum()
        n_rows = len(article_rows)
        print(
            f"{article.upper()}: {n_included} out of {n_rows} ({n_included/n_rows:.2%})"
        )