In [None]:
import os
import pandas as pd
import numpy as np
import pickle
from scipy.stats import norm
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from upsetplot import UpSet, from_indicators
from sklearn.metrics import roc_curve, roc_auc_score
%config InlineBackend.figure_format='retina'

# Load data

In [None]:
data_folder = "XXX"

# Models in plotting order; "Random Forest" and "RAG" are the non-LLM baselines handled separately downstream
models = ['Random Forest', 'RAG', 'gpt-oss-20b', 'gpt-oss-120b', 'llama3.1-70b', 'deepseek-r1-70b']

# Main color for each model, aligned with `models` by position
main_colors = ["#4a6741", "#818181", "#3182bd", "#8856a7", "darkorange", "#f03b20"]
# zip(main_colors, models) later silently drops models on a length mismatch, so fail loudly here
assert len(main_colors) == len(models), "main_colors must provide exactly one color per model"

In [None]:
data_date = "XXX"
data_path = "XXX"
data_df = pd.read_csv(data_path, low_memory=False)
# Show only the size: the table holds patient-level records, which should not be persisted in saved notebook output
data_df.shape

In [None]:
pathogen = "all-viral"
train_test_data_folder = f"{data_folder}/test_train_splits"

test_data_path = f"{train_test_data_folder}/X_test_{pathogen}_rf.pkl"
test_df = pd.read_pickle(test_data_path)

# Add ground-truth labels. validate="many_to_one" raises if data_df repeats a record_id, which would
# otherwise silently duplicate test rows and break the row alignment with the prediction files.
test_df = test_df.merge(
    data_df[['record_id', 'all-viral_label']],
    on='record_id',
    how='left',
    validate='many_to_one',
)
# roc_auc_score rejects NaN labels later on; fail here with a clearer message instead
assert test_df['all-viral_label'].notna().all(), "some test records have no 'all-viral_label'"

Y_TRUE = test_df['all-viral_label']

test_df.head()

In [None]:
# Get model breakdowns for LLMs
temp = 0.5
prompt = "short"

def get_result_paths(model):
    """Result folders of one LLM used in the main figures.

    Returns, in order, the folders for: zero-shot, medical context, medical context + RF
    (stored under ``data_test_rf``) and RAG. Uses the notebook-level ``data_folder``,
    ``prompt`` and ``temp`` settings.
    """
    # Single-quoted inner literals keep this valid before Python 3.12 (PEP 701)
    temp_dir = f"temp_{str(temp).replace('.', '-')}"

    def _path(split, knowledge_type):
        return f"{data_folder}/model_llm/{model}/{split}/system_prompt_{prompt}/{temp_dir}/{knowledge_type}"

    return [
        _path("data_test", "no_knowledge"),
        _path("data_test", "llm_training_summary_subset_v3"),
        _path("data_test_rf", "llm_training_summary_subset_rf_v3"),
        _path("data_test", "no_knowledge_rag"),
    ]

# For supplementary figures
def get_result_paths_supp(model):
    """Result folders of one LLM used in the supplementary (natural-language) figures.

    Returns, in order: zero-shot, zero-shot (json2text), RAG and RAG (json2text), all under
    ``data_test``. Uses the notebook-level ``data_folder``, ``prompt`` and ``temp`` settings.
    """
    # Single-quoted inner literals keep this valid before Python 3.12 (PEP 701)
    temp_dir = f"temp_{str(temp).replace('.', '-')}"
    base = f"{data_folder}/model_llm/{model}/data_test/system_prompt_{prompt}/{temp_dir}"
    knowledge_types = (
        "no_knowledge",
        "no_knowledge_json2text",
        "no_knowledge_rag",
        "no_knowledge_rag_json2text",
    )
    return [f"{base}/{knowledge_type}" for knowledge_type in knowledge_types]

def get_result_paths_all(model):
    """Every result folder of one LLM (main and supplementary knowledge types).

    The "medical context + RF" folder lives under ``data_test_rf``; all others under
    ``data_test``. Uses the notebook-level ``data_folder``, ``prompt`` and ``temp`` settings.
    """
    # Single-quoted inner literals keep this valid before Python 3.12 (PEP 701)
    temp_dir = f"temp_{str(temp).replace('.', '-')}"
    splits_and_types = [
        ("data_test", "no_knowledge"),
        ("data_test", "no_knowledge_json2text"),
        ("data_test", "llm_training_summary_subset_v3"),
        ("data_test_rf", "llm_training_summary_subset_rf_v3"),
        ("data_test", "no_knowledge_rag"),
        ("data_test", "no_knowledge_rag_json2text"),
    ]
    return [
        f"{data_folder}/model_llm/{model}/{split}/system_prompt_{prompt}/{temp_dir}/{knowledge_type}"
        for split, knowledge_type in splits_and_types
    ]

# Display names used in plots, keyed by knowledge-type folder (last component of a result path)
k_naming_dict = dict(
    no_knowledge="zero-shot",
    no_knowledge_json2text="zero-shot\n(natural language)",
    llm_training_summary_subset_v2="medical context",
    llm_training_summary_subset_rf_v2="medical context + RF",
    llm_training_summary_subset_v3="medical context",
    llm_training_summary_subset_rf_v3="medical context + RF",
    no_knowledge_rag="RAG",
    no_knowledge_rag_json2text="RAG\n(natural language)",
)

# Display name of the retrieval-only baseline (no LLM in the loop)
rag_label = "RAG (w/o LLM)"

# Plot NaN/unknown/yes/no 'viral' results returned by LLMs

In [None]:
# label -> {run_id -> value counts Series} for every (model, knowledge type) with data
value_counts_dict = {}
# x-axis labels, kept in the order they are first seen
labels = []
RUN_IDS = list(range(1, 4))

def normalize_label(val):
    """Canonical form of a raw 'viral' answer: its string form, trimmed and lowercased.

    NaN becomes "nan" and None becomes "none", so missing answers stay countable.
    """
    text = str(val)
    return text.strip().lower()

for model in models:
    # The baselines produce no LLM 'viral' answers to count
    if model in ("Random Forest", "RAG"):
        continue

    for path in get_result_paths_all(model):
        per_run_counts = {}  # run_id -> value counts of the normalized 'viral' answers

        for run_id in RUN_IDS:
            csv_path = f"{path}/{model}_run0{run_id}_predictions.csv"
            if not os.path.exists(csv_path):
                continue
            run_df = pd.read_csv(csv_path)
            if "viral" not in run_df.columns:
                continue
            normalized = [normalize_label(v) for v in run_df["viral"].tolist()]
            if normalized:
                per_run_counts[run_id] = pd.Series(normalized).value_counts(dropna=False)

        # Register a label only when at least one run produced counts
        if not per_run_counts:
            continue

        knowledge_type = path.split("/")[-1]
        label = f"{model} {k_naming_dict.get(knowledge_type, knowledge_type)}"
        value_counts_dict[label] = per_run_counts
        labels.append(label)

# --- Fixed class order ---
CLASS_ORDER = ["yes", "no", "unknown", "nan"]  # desired stack order (bottom -> top)

# Build tensor: (n_runs, n_labels, n_classes) in the fixed order; runs without data stay all-zero
n_labels = len(labels)
n_classes = len(CLASS_ORDER)
n_runs = len(RUN_IDS)

plot_data = np.zeros((n_runs, n_labels, n_classes), dtype=float)

for li, label in enumerate(labels):
    per_run = value_counts_dict[label]  # dict run_id -> Series
    for ri, run_id in enumerate(RUN_IDS):
        vc = per_run.get(run_id, pd.Series(dtype=float))
        # Normalize keys again, summing (not overwriting) counts whose keys collapse to the same class
        vc_norm = vc.groupby([normalize_label(k) for k in vc.index]).sum() if len(vc) else vc
        for ci, cls in enumerate(CLASS_ORDER):
            plot_data[ri, li, ci] = float(vc_norm.get(cls, 0))
        # Answers outside CLASS_ORDER (e.g. free text) are not drawn; report them so short bars are explained
        unplotted = {k: int(v) for k, v in vc_norm.items() if k not in CLASS_ORDER}
        if unplotted:
            print(f"WARNING: {label!r} run {run_id}: values not plotted: {unplotted}")

# Colors for classes (keep as provided)
custom_colors = [
    "darkblue",
    "cornflowerblue",
    "grey",
    "tab:red",
]
COLOR_MAP = {cls: custom_colors[i % len(custom_colors)] for i, cls in enumerate(CLASS_ORDER)}

# Plot: one group of stacked bars per label, one bar per run (RUN_IDS order, left to right)
fig, ax = plt.subplots(figsize=(12, 5))

x = np.arange(n_labels)
bar_width = 0.25
gap = 0.03  # small space between the bars within each label group

# Offsets to center the run bars around each label position
offsets = np.linspace(
    -((n_runs - 1) * (bar_width + gap)) / 2,
    ((n_runs - 1) * (bar_width + gap)) / 2,
    n_runs
)

# Draw one stack per run; classes are stacked bottom -> top in CLASS_ORDER
for ri, run_id in enumerate(RUN_IDS):
    bottoms = np.zeros(n_labels)
    for ci, cls in enumerate(CLASS_ORDER):
        ax.bar(
            x + offsets[ri],
            plot_data[ri, :, ci],
            width=bar_width,
            bottom=bottoms,
            label=cls if (ri == 0) else None,  # legend once
            color=COLOR_MAP[cls],
            edgecolor="none",
        )
        bottoms += plot_data[ri, :, ci]

ax.margins(x=0.01)
ax.set_ylabel("Count")
ax.set_title("Model Output for Viral Predictions", fontweight="bold")
ax.set_xticks(x)
ax.set_xticklabels(labels, rotation=45, ha="right")

# Legend for classes (runs are encoded by the side-by-side bars)
handles, class_labels = ax.get_legend_handles_labels()
legend = ax.legend(handles, class_labels, title="Value", loc="lower left", ncol=2)
legend.get_frame().set_alpha(0.9)

plt.tight_layout()
# savefig does not create missing folders (fresh checkout / new working directory)
os.makedirs("figures", exist_ok=True)
fig.savefig("figures/LLM_viral_output_value_counts.png", dpi=300, bbox_inches="tight")
plt.show()


# Prediction histograms 

In [None]:
def plot_hist(y_scores, y_labels, ax, title, legend=True, fontsize=12, include_unknowns=True, bins=20):
    """Overlay step histograms of scores split by predicted label, one linestyle per run.

    Parameters
    ----------
    y_scores : list of sequences
        One sequence of scores per run; paired with ``y_labels`` element by element.
    y_labels : list of sequences
        Predicted labels per run. "yes" is drawn in red, "no" in black and, when
        ``include_unknowns`` is True, "unknown" as a filled grey histogram. Other values are ignored.
    ax : matplotlib Axes
        Axes to draw on; the x-range is fixed to 0-100.
    title : str
        Unused (callers place row titles themselves); kept for backward compatibility.
    legend : bool
        Draw the "Predicted Viral" legend (entries come from the first run only).
    fontsize : int
        Font size for tick labels and legend.
    include_unknowns : bool
        Also draw the "unknown" histogram.
    bins : int or sequence, default 20
        Forwarded to ``ax.hist``. An int bins each series over its own range; pass explicit
        edges (e.g. ``np.linspace(0, 100, 21)``) to give every run and class identical bins.
    """
    linestyles = ['solid', 'dashed', 'dotted']
    # (label value, color, histtype), drawn in this order for every run
    classes = [("yes", "red", "step"), ("no", "black", "step")]
    if include_unknowns:
        classes.append(("unknown", "grey", "stepfilled"))

    for run_idx, (score_list, label_list) in enumerate(zip(y_scores, y_labels)):
        # Cycle linestyles so more than three runs do not raise IndexError
        linestyle = linestyles[run_idx % len(linestyles)]
        pairs = list(zip(score_list, label_list))
        for cls, color, histtype in classes:
            # Only the first run contributes legend entries
            legend_kwargs = {"label": cls} if run_idx == 0 else {}
            ax.hist(
                [score for score, label in pairs if label == cls],
                bins=bins,
                color=color,
                histtype=histtype,
                linestyle=linestyle,
                **legend_kwargs,
            )

    if legend:
        ax.legend(title="Predicted Viral", fontsize=fontsize, title_fontsize=fontsize, loc="upper right")
    ax.tick_params(axis='both', which='major', labelsize=fontsize)
    ax.margins(x=0.01)
    ax.grid(True, color='grey', linestyle='--', linewidth=0.7, alpha=0.5)
    ax.set_axisbelow(True)
    ax.set_xlim(0, 100)

In [None]:
def plot_model_prediction_histograms(
    models,
    include_unknowns=False,
    test_df=None,
    train_test_data_folder=None,
    pathogen=None,
    get_result_paths=None,
    k_naming_dict=None,
    Y_TRUE=None,
    figsize=None,
    save_path="figures/LLM_test_data_preds.png",
    y_title_padding=0.047
):
    """
    Plots histograms of predicted viral probabilities for a list of models.

    One row per plotted model / knowledge type. The left column shows test cases with
    ``all-viral_label == 1`` (viral); the right column shows label 0 (non-viral). LLM rows
    without any usable prediction file are skipped, and the leftover empty axes are hidden.

    Parameters:
        models (list): Model names. "Random Forest" and "RAG" are read from pickles in
            ``train_test_data_folder``. Any other name is treated as an LLM whose runs are read
            from the folders returned by ``get_result_paths(model)``.
        include_unknowns (bool): Whether to include 'unknown' predictions in the histograms.
        test_df (pd.DataFrame): Test set with 'all-viral_label' and 'record_id' columns.
        train_test_data_folder (str): Folder with the RF/RAG prediction pickles.
        pathogen (str): Pathogen name used in the pickle file names.
        get_result_paths (callable): model -> list of LLM result folders.
        k_naming_dict (dict): Knowledge-type folder name -> display name.
        Y_TRUE (sequence): Ground truth for the test set. Only its length is used, to skip
            prediction files of the wrong size.
        figsize (tuple, optional): Defaults to (14, 2 * number_of_rows).
        save_path (str): Path to save the resulting figure (its folder is created if needed).
        y_title_padding (float): Offset, in figure coordinates, of each row title above the
            bottom edge of its row.
    """
    fontsize = 12

    def _is_skipped(model, path):
        # No "medical context + RF" row for deepseek. Test only the knowledge-type folder, so a
        # "_rf" elsewhere in the path (e.g. in data_folder) cannot trigger the skip.
        return "deepseek" in model and "_rf" in path.split("/")[-1]

    # One row per baseline plus one per (LLM, knowledge type) that is not skipped
    nrows = 0
    for model in models:
        if model in ("Random Forest", "RAG"):
            nrows += 1
        else:
            nrows += sum(1 for path in get_result_paths(model) if not _is_skipped(model, path))
    nrows = max(nrows, 1)

    if figsize is None:
        figsize = (14, 2 * nrows)
    # squeeze=False keeps axs 2-D so axs[row, col] also works with a single row
    fig, axs = plt.subplots(figsize=figsize, nrows=nrows, ncols=2, sharex=True, squeeze=False)
    fig.subplots_adjust(hspace=0.4, wspace=0.20)

    is_pos = test_df['all-viral_label'] == 1
    is_neg = test_df['all-viral_label'] == 0
    # Baselines are matched to the ground truth by record_id, LLM runs by row index
    pos_record_ids = set(test_df.loc[is_pos, 'record_id'])
    neg_record_ids = set(test_df.loc[is_neg, 'record_id'])
    pos_indices = test_df.index[is_pos]
    neg_indices = test_df.index[is_neg]

    def _scores_and_labels(subset, score_col, labels, scale=1):
        # Coerce scores to numbers and drop failures. The same rows are dropped from the labels,
        # because plot_hist pairs scores and labels by position.
        scores = pd.to_numeric(subset[score_col], errors='coerce') * scale
        keep = scores.notna().to_numpy()
        return scores[keep], labels[keep]

    def _plot_row(row_idx, scores_pos, labels_pos, scores_neg, labels_neg, title, legend_right=False):
        # Draw both columns of one row and center its title above them
        plot_hist(scores_pos, labels_pos, axs[row_idx, 0], title=title, legend=False,
                  fontsize=fontsize, include_unknowns=include_unknowns)
        plot_hist(scores_neg, labels_neg, axs[row_idx, 1], title=title, legend=legend_right,
                  fontsize=fontsize, include_unknowns=include_unknowns)
        bbox0 = axs[row_idx, 0].get_position()
        bbox1 = axs[row_idx, 1].get_position()
        # Offset from the row's bottom edge (y0); y_title_padding lifts the title above the row
        y_bottom = max(bbox0.y0, bbox1.y0)
        x_center = (bbox0.x0 + bbox1.x1) / 2
        fig.text(x_center, y_bottom + y_title_padding, title, ha='center', va='bottom', fontsize=fontsize)

    row_idx = 0
    for model in models:

        if model == "Random Forest":
            y_scores_pos, y_scores_neg, y_labels_pos, y_labels_neg = [], [], [], []
            for seed in ["", "_02", "_03"]:
                path_to_pkl = f"{train_test_data_folder}/X_test_{pathogen}_rf{seed}.pkl"
                try:
                    with open(path_to_pkl, "rb") as f:
                        pred_df = pickle.load(f)  # locally produced file; never unpickle untrusted data
                except FileNotFoundError:
                    print(f"Warning: File not found: {path_to_pkl}")
                    continue

                if len(test_df) != len(pred_df):
                    continue

                subset_pos = pred_df[pred_df['record_id'].isin(pos_record_ids)]
                subset_neg = pred_df[pred_df['record_id'].isin(neg_record_ids)]

                scores, labels = _scores_and_labels(subset_pos, 'probability_of_viral_rf', subset_pos['viral_rf'])
                y_scores_pos.append(scores)
                y_labels_pos.append(labels)
                scores, labels = _scores_and_labels(subset_neg, 'probability_of_viral_rf', subset_neg['viral_rf'])
                y_scores_neg.append(scores)
                y_labels_neg.append(labels)

            # The RF row carries the legend (right column)
            _plot_row(row_idx, y_scores_pos, y_labels_pos, y_scores_neg, y_labels_neg, model, legend_right=True)
            row_idx += 1

        elif model == "RAG":
            pred_df = pd.read_pickle(f"{train_test_data_folder}/X_test_{pathogen}_rag.pkl")

            subset_pos = pred_df[pred_df['record_id'].isin(pos_record_ids)]
            subset_neg = pred_df[pred_df['record_id'].isin(neg_record_ids)]

            round_to_label = {0: "no", 1: "yes"}
            # weighted_averages is scaled by 100 to match the fixed 0-100 x-range of plot_hist
            scores_pos, labels_pos = _scores_and_labels(
                subset_pos, 'weighted_averages', subset_pos['weighted_averages_round'].map(round_to_label), scale=100)
            scores_neg, labels_neg = _scores_and_labels(
                subset_neg, 'weighted_averages', subset_neg['weighted_averages_round'].map(round_to_label), scale=100)

            _plot_row(row_idx, [scores_pos], [labels_pos], [scores_neg], [labels_neg], rag_label)
            row_idx += 1

        else:
            allowed_labels = ['yes', 'no', 'unknown'] if include_unknowns else ['yes', 'no']
            for path in get_result_paths(model):
                if _is_skipped(model, path):
                    continue

                y_scores_pos, y_scores_neg, y_labels_pos, y_labels_neg = [], [], [], []
                for run_id in [1, 2, 3]:
                    path_to_csv = f"{path}/{model}_run0{run_id}_predictions.csv"
                    if not os.path.exists(path_to_csv):
                        continue
                    pred_df = pd.read_csv(path_to_csv)
                    if len(pred_df) != len(Y_TRUE):
                        continue

                    keep = pred_df['viral'].isin(allowed_labels) & pred_df['probability_of_viral'].notna()
                    if not include_unknowns:
                        keep &= pred_df['probability_of_viral'] != 'unknown'
                    filtered_pred_df = pred_df[keep]

                    # Assumes CSV row i corresponds to test_df row i (only the row count is checked)
                    subset_pos = filtered_pred_df.loc[filtered_pred_df.index.intersection(pos_indices)]
                    subset_neg = filtered_pred_df.loc[filtered_pred_df.index.intersection(neg_indices)]

                    scores, labels = _scores_and_labels(subset_pos, 'probability_of_viral', subset_pos['viral'])
                    y_scores_pos.append(scores)
                    y_labels_pos.append(labels)
                    scores, labels = _scores_and_labels(subset_neg, 'probability_of_viral', subset_neg['viral'])
                    y_scores_neg.append(scores)
                    y_labels_neg.append(labels)

                if not y_scores_pos:
                    continue

                knowledge_type = path.split("/")[-1]
                label = f"{model} ({k_naming_dict.get(knowledge_type, knowledge_type)})"
                _plot_row(row_idx, y_scores_pos, y_labels_pos, y_scores_neg, y_labels_neg, label)
                row_idx += 1

    # Skipped LLM rows leave empty axes at the bottom. Hide them, and restore the x tick labels
    # that sharex draws only on the (now hidden) last row.
    if row_idx < nrows:
        for unused_row in range(row_idx, nrows):
            for col in range(2):
                axs[unused_row, col].set_visible(False)
        if row_idx > 0:
            for col in range(2):
                axs[row_idx - 1, col].tick_params(axis='x', labelbottom=True)

    # Add labels for the whole plot
    fig.text(0.05, 0.5, 'Count', va='center', rotation='vertical', fontsize=fontsize+2)
    fig.text(0.5, 0.075, 'Predicted Viral Probability', ha='center', va='center', fontsize=fontsize+2)
    fig.text(0.3, 0.91, 'Viral', ha='center', fontsize=fontsize+2)
    fig.text(0.7, 0.91, 'Non-viral', ha='center', fontsize=fontsize+2)

    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    fig.savefig(save_path, dpi=300, bbox_inches="tight")
    plt.show()

In [None]:
# plot_model_prediction_histograms(
#     models,
#     True,  # include unknowns
#     test_df,
#     train_test_data_folder,
#     pathogen,
#     get_result_paths,
#     k_naming_dict,
#     Y_TRUE,
#     save_path="figures/LLM_test_data_preds_all.png"
# )

In [None]:
# Histograms for every model and knowledge type, including 'unknown' answers
plot_model_prediction_histograms(
    models,
    include_unknowns=True,
    test_df=test_df,
    train_test_data_folder=train_test_data_folder,
    pathogen=pathogen,
    get_result_paths=get_result_paths,
    k_naming_dict=k_naming_dict,
    Y_TRUE=Y_TRUE,
    figsize=(6.5, 26),
    y_title_padding=0.045,
    save_path="figures/LLM_test_data_preds_all.png",
)

In [None]:
# models_short = ['Random Forest', 'RAG', 'gpt-oss-120b']

# plot_model_prediction_histograms(
#     models_short,
#     False, # include unknowns
#     test_df,
#     train_test_data_folder,
#     pathogen,
#     get_result_paths,
#     k_naming_dict,
#     Y_TRUE,
#     figsize=(6.5,13),
#     y_title_padding=0.105,
#     save_path="figures/LLM_test_data_preds.png"
# )

# ROC curve

In [None]:
# For each model, derive one shade of its main color per result path
def get_shades(color, n=3):
    """Return ``n`` hex tints of ``color``, from the palest tint up to the color itself.

    Each RGB channel ``c`` is blended toward white as ``1 - (1 - c) * w``, with ``w`` evenly
    spaced from 0.25 (palest) to 1 (the unchanged color). With ``n == 1`` only the palest
    tint is returned.
    """
    rgb = mcolors.to_rgb(color)
    weights = np.linspace(0.25, 1, n)
    return [mcolors.to_hex(tuple(1 - (1 - channel) * w for channel in rgb)) for w in weights]

In [None]:
def plot_roc_curve(y_trues, y_scores, ax, color, label, ls=None, show_auc=False):
    """Plot a ROC curve, or the mean ROC curve with a min-max band over several runs.

    Parameters
    ----------
    y_trues : array-like or list of array-like
        Binary ground truth (0/1). A single vector is reused for every score vector.
    y_scores : array-like or list of array-like
        One score vector, or one per run (e.g. ``[scores1, scores2, scores3]``).
    ax : matplotlib Axes
    color, label, ls :
        Line color, legend label and linestyle.
    show_auc : bool, default False
        Append the AUC (single run) or mean ± std AUC (several runs) to the legend label.
        AUCs are only computed when requested.
    """
    def is_1d_array_like(x):
        # True for one vector of values; False for a list of vectors
        if isinstance(x, pd.Series):
            return True
        if isinstance(x, np.ndarray) and x.ndim == 1:
            return True
        if isinstance(x, list) and (len(x) == 0 or not isinstance(x[0], (list, np.ndarray, pd.Series))):
            return True
        return False

    # Case 1: a single run
    if is_1d_array_like(y_scores) and is_1d_array_like(y_trues):
        fpr, tpr, _ = roc_curve(y_trues, y_scores)
        if show_auc:
            label = f'{label}\n(AUC = {roc_auc_score(y_trues, y_scores):.2f})'
        ax.plot(fpr, tpr, label=label, lw=2, color=color, ls=ls)
        return

    # Case 2: several runs, interpolated on a common FPR grid
    fpr_grid = np.linspace(0.0, 1.0, 101)
    y_trues_list = [y_trues] * len(y_scores) if is_1d_array_like(y_trues) else y_trues

    tpr_stack = []
    aucs = []
    for yt, ys in zip(y_trues_list, y_scores):
        fpr, tpr, _ = roc_curve(yt, ys)
        tpr_i = np.interp(fpr_grid, fpr, tpr)
        # Anchor every interpolated curve at (0, 0) and (1, 1)
        tpr_i[0] = 0.0
        tpr_i[-1] = 1.0
        tpr_stack.append(tpr_i)
        if show_auc:
            aucs.append(roc_auc_score(yt, ys))

    tpr_stack = np.vstack(tpr_stack)
    if show_auc:
        label = f'{label}\n(Avg AUC = {np.mean(aucs):.2f} ± {np.std(aucs):.2f})'

    ax.plot(fpr_grid, tpr_stack.mean(axis=0), lw=2, label=label, color=color, ls=ls)
    # Band spans the min and max TPR across runs
    ax.fill_between(fpr_grid, tpr_stack.min(axis=0), tpr_stack.max(axis=0), alpha=0.2, color=color)

In [None]:
def _drop_missing_scores(y_score_list, y_true, source):
    """Remove missing or non-numeric LLM scores and the matching ground-truth entries.

    Parameters
    ----------
    y_score_list : pd.Series
        Raw 'probability_of_viral' column of one run (may hold NaN or the string 'unknown').
    y_true : pd.Series, list or np.ndarray
        Ground truth aligned by position with ``y_score_list``.
    source : str
        File name reported in the warnings.

    Returns
    -------
    (scores, y_true_kept)
        Numeric scores and the ground truth for the same rows. When nothing is dropped,
        ``y_true_kept`` is ``y_true.copy()``.

    NOTE(review): dropped cases are excluded from the ROC/AUC rather than imputed, so the metric
    only covers cases the model answered. Whether imputing would be fairer is still open (TODO).
    """
    numeric = pd.to_numeric(y_score_list, errors="coerce")
    keep = numeric.notna().to_numpy()
    if keep.all():
        return numeric, y_true.copy()

    nan_count = int(y_score_list.isna().sum())
    unknown_count = int((y_score_list == 'unknown').sum())
    other_count = int((~keep).sum()) - nan_count - unknown_count
    if nan_count > 0:
        print(f"WARNING: {nan_count} NaN values found in 'probability_of_viral' for {source}")
    if unknown_count > 0:
        print(f"WARNING: {unknown_count} 'unknown' values found in 'probability_of_viral' for {source}")
    if other_count > 0:
        print(f"WARNING: {other_count} other non-numeric values found in 'probability_of_viral' for {source}")

    # Apply the mask positionally so scores and truths stay row-aligned whatever their index labels
    y_true_kept = y_true[keep] if isinstance(y_true, pd.Series) else np.asarray(y_true)[keep]
    return numeric[keep], y_true_kept


def plot_roc_curve_panel(get_result_paths, figname="ROC_curves_all-viral_predictions", skip_deepseek=False):
    """Draw a 2x2 panel of ROC curves, one LLM per panel, with the RF and RAG baselines on every panel.

    Reads the notebook-level ``models``, ``main_colors``, ``train_test_data_folder``, ``pathogen``,
    ``Y_TRUE``, ``rag_label`` and ``k_naming_dict``.

    Parameters
    ----------
    get_result_paths : callable
        ``model -> list of result folders``; each folder may hold ``{model}_run0{1,2,3}_predictions.csv``.
    figname : str
        File stem; the figure is saved to ``figures/{figname}.png``.
    skip_deepseek : bool
        If True, no llama/deepseek curves are drawn. Their panels keep only the baselines.

    Returns
    -------
    list of dict
        One entry per plotted (model, knowledge type), with keys 'model', 'path', 'auc_mean',
        'auc_std', 'n_runs', 'color' and 'hatch'. Suitable for an AUC barplot.
    """
    fig, axs = plt.subplots(figsize=(13, 13), ncols=2, nrows=2)
    fontsize = 16
    axs_flat = axs.flatten()

    auc_barplot_entries = []

    col_row_idx = 0  # panel index of the next LLM
    for color, model in zip(main_colors, models):

        if model == "Random Forest":
            y_scores = []
            aucs_rf = []
            for seed in ["", "_02", "_03"]:
                path_to_pkl = f"{train_test_data_folder}/X_test_{pathogen}_rf{seed}.pkl"
                with open(path_to_pkl, "rb") as f:
                    pred_df = pickle.load(f)  # locally produced file; never unpickle untrusted data
                y_score_list = pred_df['probability_of_viral_rf']
                y_scores.append(y_score_list)
                aucs_rf.append(roc_auc_score(Y_TRUE, y_score_list))

            for ax in axs_flat:
                plot_roc_curve(Y_TRUE, y_scores, ax, color=color, label=model)

            auc_barplot_entries.append({
                'model': model,
                'path': "RF",
                'auc_mean': np.mean(aucs_rf),
                'auc_std': np.std(aucs_rf, ddof=1) if len(aucs_rf) > 1 else 0,
                'n_runs': len(aucs_rf),
                'color': color,
                'hatch': None
            })

        elif model == "RAG":
            pred_df = pd.read_pickle(f"{train_test_data_folder}/X_test_{pathogen}_rag.pkl")
            y_scores_weighted = pred_df['weighted_averages']
            for ax in axs_flat:
                plot_roc_curve(Y_TRUE, y_scores_weighted, ax, color=color, label=rag_label)

            auc_barplot_entries.append({
                'model': model,
                'path': rag_label,
                'auc_mean': roc_auc_score(Y_TRUE, y_scores_weighted),
                'auc_std': 0,
                'n_runs': 1,
                'color': color,
                'hatch': None
            })

        else:
            if skip_deepseek and ("deepseek" in model or "llama" in model):
                # Keep the panel slot so the remaining LLMs stay in their usual positions
                # NOTE(review): these panels then show only the baselines and have no title
                col_row_idx += 1
                continue

            paths = get_result_paths(model)
            # RAG variants reuse the shade of the preceding entry (dashed), hence one shade fewer
            path_shades = get_shades(color, n=len(paths) - 1)
            for i, path in enumerate(paths):
                # Match on the knowledge-type folder only: a substring test on the full path would
                # also fire on parent directories (e.g. "rag" inside "storage").
                knowledge_type = path.split("/")[-1]
                is_rag = "rag" in knowledge_type

                if "deepseek" in model and "_rf" in knowledge_type:
                    continue

                y_scores = []
                y_trues = []
                aucs_this_path = []
                for run_id in [1, 2, 3]:
                    path_to_csv = f"{path}/{model}_run0{run_id}_predictions.csv"
                    if not os.path.exists(path_to_csv):
                        continue
                    raw_scores = pd.read_csv(path_to_csv)['probability_of_viral']
                    if len(raw_scores) != len(Y_TRUE):
                        continue

                    y_score_list, y_true_filtered = _drop_missing_scores(raw_scores, Y_TRUE, path_to_csv)
                    y_trues.append(y_true_filtered)
                    y_scores.append(y_score_list)

                    try:
                        auc_val = roc_auc_score(y_true_filtered, y_score_list)
                    except Exception as err:  # e.g. only one class left after dropping missing scores
                        print(f"WARNING: AUC not computed for {path_to_csv}: {err}")
                    else:
                        aucs_this_path.append(auc_val)

                if not y_scores:
                    continue
                # A single run is plotted as a plain curve rather than a mean with a band
                if len(y_scores) == 1:
                    y_scores = y_scores[0]
                if len(y_trues) == 1:
                    y_trues = y_trues[0]

                if is_rag:
                    ls, color_to_use = "--", path_shades[i - 1]
                else:
                    ls, color_to_use = None, path_shades[i]

                display_name = k_naming_dict.get(knowledge_type, knowledge_type)
                ax = axs_flat[col_row_idx]
                ax.set_title(model, fontsize=fontsize+2, fontweight="bold", color=color_to_use)
                plot_roc_curve(y_trues, y_scores, ax, color=color_to_use, label=display_name, ls=ls)

                if aucs_this_path:
                    auc_barplot_entries.append({
                        'model': model,
                        'path': display_name,
                        'auc_mean': np.mean(aucs_this_path),
                        'auc_std': np.std(aucs_this_path, ddof=1) if len(aucs_this_path) > 1 else 0,
                        'n_runs': len(aucs_this_path),
                        'color': color_to_use,
                        'hatch': "/" if is_rag else None
                    })

            col_row_idx += 1

    # Diagonal reference line and shared formatting on all four panels
    for ax in axs_flat:
        ax.plot([0, 1], [0, 1], color='lightgrey', linestyle='-', label='Random guess', lw=2)
        ax.set_xlim([0.0, 1.0])
        ax.set_ylim([0.0, 1.02])
        ax.legend(loc='lower right', fontsize=fontsize)
        ax.tick_params(axis='both', which='major', labelsize=fontsize)
        ax.grid(True, color='grey', linestyle='--', linewidth=0.7, alpha=0.5)
        ax.set_axisbelow(True)

    fig.text(0.5, -0.01, 'False Positive Rate', ha='center', va='center', fontsize=fontsize+4)
    fig.text(-0.01, 0.5, 'True Positive Rate', ha='center', va='center', rotation='vertical', fontsize=fontsize+4)

    plt.tight_layout()

    os.makedirs("figures", exist_ok=True)
    fig.savefig(f"figures/{figname}.png", dpi=300, bbox_inches="tight")
    plt.show()

    return auc_barplot_entries

# Main-figure ROC panel; the returned per-(model, knowledge type) AUC summaries are kept in auc_barplot_entries
auc_barplot_entries = plot_roc_curve_panel(get_result_paths=get_result_paths)

Same ROC panel for the supplementary figure: zero-shot and RAG prompts, each in its default and natural-language (json2text) variant:

In [None]:
# Keep the returned AUC summaries under a descriptive name; assigning to `_` would clobber IPython's last-output variable
auc_barplot_entries_json2text = plot_roc_curve_panel(
    get_result_paths_supp,
    figname="ROC_curves_all-viral_predictions_json2text",
    skip_deepseek=True,
)

# Plot accuracy, sensitivity, specificity

In [None]:
def compute_metrics(y_true, y_pred):
    """Compute accuracy, sensitivity and specificity for binary classification.

    Parameters
    ----------
    y_true : array-like
        Ground-truth labels, 1 = positive, 0 = negative.
    y_pred : array-like
        Predicted labels. Strings count as positive when they equal "yes"
        (case-insensitive, surrounding whitespace ignored); booleans and
        numbers count as positive when equal to 1. Anything else is negative.

    Returns
    -------
    tuple
        ``(accuracy, sensitivity, specificity)``. Sensitivity / specificity are
        NaN when ``y_true`` contains no positives / negatives; all three are
        NaN for empty input.

    Raises
    ------
    ValueError
        If ``y_true`` and ``y_pred`` have different lengths.
    """
    y_true = np.asarray(y_true)
    # Binarize per element so string ("yes"/"no") and numeric (0/1) predictions
    # are both handled; the original exact-"yes" match scored numeric 1 as negative.
    y_pred_bin = np.array([_is_positive_prediction(p) for p in y_pred], dtype=int)

    if len(y_true) != len(y_pred_bin):
        raise ValueError(
            f"y_true and y_pred differ in length ({len(y_true)} vs {len(y_pred_bin)})"
        )
    if len(y_true) == 0:
        return np.nan, np.nan, np.nan

    acc = np.mean(y_pred_bin == y_true)

    positives = y_true == 1
    negatives = y_true == 0
    # Sensitivity (recall / true-positive rate)
    sens = np.sum(y_pred_bin[positives] == 1) / np.sum(positives) if positives.any() else np.nan
    # Specificity (true-negative rate)
    spec = np.sum(y_pred_bin[negatives] == 0) / np.sum(negatives) if negatives.any() else np.nan
    return acc, sens, spec


def _is_positive_prediction(value):
    """Return True when a single prediction denotes the positive class."""
    if isinstance(value, str):
        return value.strip().lower() == "yes"
    if isinstance(value, (bool, int, float, np.bool_, np.number)):
        return bool(value == 1)
    return False

def mean_std(values):
    """Return the NaN-ignoring mean and sample standard deviation of ``values``.

    NaN entries are dropped first. The mean is NaN when nothing remains; the
    standard deviation (ddof=1) is NaN when fewer than two values remain, e.g.
    for a single run. These cases return NaN explicitly instead of going
    through ``np.nanmean``/``np.nanstd``, which emit RuntimeWarnings for them.

    Parameters
    ----------
    values : array-like of float
        Per-run metric values (may contain NaN).

    Returns
    -------
    tuple of float
        ``(mean, std)``.
    """
    arr = np.asarray(values, dtype=float).ravel()
    kept = arr[~np.isnan(arr)]
    mean = kept.mean() if kept.size > 0 else np.nan
    std = kept.std(ddof=1) if kept.size > 1 else np.nan
    return mean, std

def plot_metrics_with_std(y_trues, y_labels, axes, x_idx, color, fontsize=12, hatch=None):
    """Draw mean ± std bars for accuracy, sensitivity and specificity at ``x_idx``.

    ``y_trues`` and ``y_labels`` are paired per run; each run is scored with
    ``compute_metrics`` and the runs are summarised with ``mean_std``. The
    first three entries of ``axes`` receive accuracy, sensitivity and
    specificity respectively, each styled with a 0-1 y-range and a dashed grid.
    """
    per_run = [compute_metrics(truth, pred) for truth, pred in zip(y_trues, y_labels)]
    # Transpose run-wise (acc, sens, spec) tuples into one list per metric.
    by_metric = [[run[k] for run in per_run] for k in range(3)]
    names = ("Accuracy", "Sensitivity", "Specificity")
    bar_width = 0.7

    for values, ax, name in zip(by_metric, axes, names):
        centre, spread = mean_std(values)
        ax.bar(
            [x_idx],
            [centre],
            yerr=[[spread], [spread]],
            color=color,
            alpha=0.8,
            capsize=4,
            width=bar_width,
            hatch=hatch,
        )
        ax.set_ylim(0, 1)
        ax.set_ylabel(name, fontsize=fontsize + 2)
        ax.grid(axis='y', linestyle='--', alpha=0.5)
        ax.set_axisbelow(True)
        ax.tick_params(axis='both', which='major', labelsize=fontsize)
        ax.margins(x=0.02)

In [None]:
fig, axs_all = plt.subplots(figsize=(10, 13), nrows=4, sharex=True)
fontsize = 16

# Configuration for the metric panels (rows 2-4 of the figure).
METRIC_NAMES = ["Accuracy", "Sensitivity", "Specificity"]
METRIC_BAR_WIDTH = 0.7
RF_SEED_SUFFIXES = ["", "_02", "_03"]
LLM_RUN_IDS = [1, 2, 3]

x_idx = 0
xticklabels = []
xtick_positions = []


def annotate_bars(ax, xpos, mean, std, fontsize=10):
    """Write ``mean`` (2 d.p.) above the bar centred at ``xpos`` on ``ax``.

    The label sits on top of the upper error-bar whisker when ``std`` is not
    NaN, otherwise directly on the bar. Only the first matching bar is used.
    """
    for bar in ax.patches:
        centre = bar.get_x() + bar.get_width() / 2
        if abs(centre - xpos) < 1e-6:
            height = bar.get_height()
            ax.annotate(f"{mean:.2f}",
                        xy=(centre, height + std if not np.isnan(std) else height),
                        xytext=(0, 2),  # 2 points vertical offset
                        textcoords="offset points",
                        ha='center', va='bottom', fontsize=fontsize)
            break


def _draw_metric_bars(axes, xpos, y_trues, y_preds, bar_color, hatch=None, label_fontsize=fontsize - 4):
    """Draw one annotated mean ± std bar per metric (acc, sens, spec) at ``xpos``.

    ``y_trues``/``y_preds`` are paired per run; metrics are computed per run
    with ``compute_metrics`` and summarised with ``mean_std``.
    """
    per_run = [compute_metrics(truth, pred) for truth, pred in zip(y_trues, y_preds)]
    for k, ax in zip(range(len(METRIC_NAMES)), axes):
        mean, std = mean_std([run[k] for run in per_run])
        ax.bar([xpos], [mean], yerr=[[std], [std]], color=bar_color, alpha=0.8,
               capsize=4, width=METRIC_BAR_WIDTH, hatch=hatch)
        annotate_bars(ax, xpos, mean, std, fontsize=label_fontsize)


def _load_rf_predictions():
    """Return the Random Forest 'viral_rf' test predictions, one Series per seed split."""
    # Locally produced pickles; only load trusted files.
    return [
        pd.read_pickle(f"{train_test_data_folder}/X_test_{pathogen}_rf{seed}.pkl")['viral_rf']
        for seed in RF_SEED_SUFFIXES
    ]


def _load_rag_predictions():
    """Return RAG yes/no labels derived from the rounded weighted-average score (single run)."""
    rag_df = pd.read_pickle(f"{train_test_data_folder}/X_test_{pathogen}_rag.pkl")
    # Alternative: rag_df['averages_round'].map({0: "no", 1: "yes"})
    return rag_df['weighted_averages_round'].map({0: "no", 1: "yes"})


def _load_llm_runs(path, model):
    """Load the per-run 'viral' predictions of one LLM configuration.

    Returns ``(y_trues, y_preds)`` with one entry per usable run. A run is
    skipped when its CSV is missing or its row count differs from ``Y_TRUE``.
    NaN and 'unknown' predictions are DROPPED (not imputed) together with the
    ground-truth rows at the same positions, and a warning is printed.
    """
    y_true_all = np.asarray(Y_TRUE)
    y_trues, y_preds = [], []
    for run_id in LLM_RUN_IDS:
        path_to_csv = f"{path}/{model}_run0{run_id}_predictions.csv"
        if not os.path.exists(path_to_csv):
            continue
        preds = pd.read_csv(path_to_csv)['viral']
        if len(preds) != len(y_true_all):
            continue

        is_nan = preds.isna()
        is_unknown = preds == 'unknown'
        nan_count = is_nan.sum()
        unknown_count = is_unknown.sum()
        if nan_count > 0:
            print(f"WARNING: {nan_count} NaN values found in 'viral' for {path_to_csv}")
        if unknown_count > 0:
            print(f"WARNING: {unknown_count} 'unknown' values found in 'viral' for {path_to_csv}")

        if nan_count > 0 or unknown_count > 0:
            # Positional mask: rows correspond 1:1 because lengths were checked above.
            keep = ~(is_nan | is_unknown)
            preds = preds[keep].astype(str)
            y_true = y_true_all[keep.to_numpy()]
        else:
            y_true = y_true_all.copy()

        y_trues.append(y_true)
        y_preds.append(preds)
    return y_trues, y_preds


## Barplot of AUCs for all (model, path) combinations
ax_bar = axs_all[0]

bar_labels = [entry['model'] if entry['model'] == "Random Forest" else entry['path']
              for entry in auc_barplot_entries]
bar_aucs = [entry['auc_mean'] for entry in auc_barplot_entries]
bar_stds = [entry['auc_std'] for entry in auc_barplot_entries]
bar_colors = [entry['color'] for entry in auc_barplot_entries]
bar_hatches = [entry.get('hatch', None) for entry in auc_barplot_entries]

bar_positions = np.arange(len(bar_labels))
bars = ax_bar.bar(bar_positions, bar_aucs, color=bar_colors, alpha=0.7, yerr=bar_stds, capsize=5)

# Set hatch for RAG bars
for bar, hatch in zip(bars, bar_hatches):
    if hatch is not None:
        bar.set_hatch(hatch)

# Write the mean AUC above each bar (above the error bar when there is one)
for bar, auc, std in zip(bars, bar_aucs, bar_stds):
    height = bar.get_height() + std if std > 0 else bar.get_height()
    ax_bar.text(
        bar.get_x() + bar.get_width() / 2,
        height + 0.01,
        f"{auc:.2f}",
        ha='center',
        va='bottom',
        fontsize=fontsize-4
    )

ax_bar.axhline(0.5, color='lightgrey', linestyle='-', lw=2)
ax_bar.tick_params(axis='both', which='major', labelsize=fontsize)
ax_bar.set_xticks(bar_positions)
ax_bar.set_xticklabels(bar_labels, rotation=45, ha='right', fontsize=fontsize-1)
ax_bar.set_ylabel("AUC", fontsize=fontsize+2, fontweight="bold")
ax_bar.set_ylim(0, 1)
ax_bar.grid(True, axis='y', color='grey', linestyle='--', linewidth=0.7, alpha=0.5)
ax_bar.set_axisbelow(True)
ax_bar.margins(x=0.02)


## Plot acc, sens, spec
axs = axs_all[1:]
for model_color, model in zip(main_colors, models):

    if model == "Random Forest":
        rf_preds = _load_rf_predictions()
        _draw_metric_bars(axs, x_idx, [Y_TRUE] * len(rf_preds), rf_preds, model_color)
        xticklabels.append(model)
        xtick_positions.append(x_idx)
        x_idx += 1

    elif model == "RAG":
        _draw_metric_bars(axs, x_idx, [Y_TRUE], [_load_rag_predictions()], model_color)
        # NOTE(review): rag_label is not defined in this section; presumably set in an earlier cell.
        xticklabels.append(rag_label)
        xtick_positions.append(x_idx)
        x_idx += 1

    else:
        paths = get_result_paths(model)
        path_shades = get_shades(model_color, n=len(paths)-1)
        for i, path in enumerate(paths):
            if "deepseek" in model and "_rf" in path:
                continue

            y_trues, y_preds = _load_llm_runs(path, model)
            if not y_trues:
                continue

            # RAG paths are hatched and reuse the previous path's shade
            if "rag" in path:
                hatch, bar_color = "/", path_shades[i-1]
            else:
                hatch, bar_color = None, path_shades[i]

            label = k_naming_dict[path.split("/")[-1]]
            _draw_metric_bars(axs, x_idx, y_trues, y_preds, bar_color, hatch=hatch)
            xticklabels.append(label)
            xtick_positions.append(x_idx)
            x_idx += 1

# Style each metric panel once (previously repeated for every bar)
for ax, metric_name in zip(axs, METRIC_NAMES):
    ax.set_ylabel(metric_name, fontsize=fontsize+2, fontweight="bold")
    ax.grid(axis='y', linestyle='--', alpha=0.5)
    ax.set_axisbelow(True)
    ax.tick_params(axis='both', which='major', labelsize=fontsize)
    ax.margins(x=0.02)

for ax in axs_all:
    ax.set_ylim(0, 1.05)

# Set both the tick positions and the labels to ensure all are shown and aligned
axs[-1].set_xticks(xtick_positions)
axs[-1].set_xticklabels(xticklabels, rotation=45, ha="right")

axs[0].axhline(y=0.5, color='lightgray', linestyle='-', linewidth=2)

plt.tight_layout()

os.makedirs("figures", exist_ok=True)
fig.savefig("figures/acc_sens_spec_all-viral_predictions.png", dpi=300, bbox_inches="tight")
plt.show()

# Plot the breakdown of positive/negative samples and pathogens in the train/test datasets

In [None]:
# Reload the raw test split (without the label columns merged into test_df)
# and the matching training split.
test_data_path = f"{train_test_data_folder}/X_test_{pathogen}_rf.pkl"
train_data_path = f"{train_test_data_folder}/X_train_{pathogen}_rf.pkl"

test_df2 = pd.read_pickle(test_data_path)
train_df = pd.read_pickle(train_data_path)

In [None]:
# Add pathogen labels to train and test data.
# validate="many_to_one" makes pandas raise if a record_id occurs more than once
# in data_df; otherwise the left merge would silently duplicate split rows and
# inflate every count plotted below.
label_cols = [col for col in data_df.columns if col.endswith('_label')]
cols_to_keep = ['record_id'] + label_cols

test_df_labels = test_df2.merge(data_df[cols_to_keep], on='record_id', how='left', validate='many_to_one')
train_df_labels = train_df.merge(data_df[cols_to_keep], on='record_id', how='left', validate='many_to_one')

In [None]:
def plot_label_value_counts(df, title_suffix="Test Set"):
    """
    Draw a grouped bar chart of the value counts of every *_label column in ``df``.

    One group per *_label column and one bar per distinct value. NaN is shown
    as 'nan', and integral numbers are shown without a decimal part. Each bar
    is annotated with its count.

    Parameters:
    - df: pandas DataFrame containing *_label columns
    - title_suffix: str, appended to the plot title (e.g. 'Test Set' or 'Train Set')
    """
    def _as_key(value):
        # 1.0 -> '1', 0 -> '0'; everything else (including NaN) via str()
        if pd.notnull(value) and isinstance(value, (int, float)) and float(value).is_integer():
            return str(int(value))
        return str(value)

    label_cols = [c for c in df.columns if c.endswith('_label')]

    counts_by_col = {}
    for c in label_cols:
        vc = df[c].value_counts(dropna=False).sort_index()
        vc.index = vc.index.map(_as_key)
        counts_by_col[c] = vc

    seen_values = set()
    for vc in counts_by_col.values():
        seen_values.update(vc.index)
    # Consistent bar order: lexicographic, with the 'nan' bar last
    ordered_values = sorted(seen_values, key=lambda v: (v == 'nan', v))

    n_values = len(ordered_values)
    width = 0.2
    group_x = np.arange(len(label_cols))
    palette = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd']

    fig, ax = plt.subplots(figsize=(max(8, len(label_cols)*0.8), 6))

    for k, val in enumerate(ordered_values):
        heights = [counts_by_col[c].get(val, 0) for c in label_cols]
        ax.bar(
            group_x + (k - (n_values-1)/2)*width,
            heights,
            width=width,
            label=f"Label {val}",
            color=palette[k % len(palette)]
        )

    ax.set_xticks(group_x)
    ax.set_xticklabels(label_cols, rotation=45, ha='right')
    ax.set_ylabel("Count")
    ax.set_title(f"Label Value Counts per *_label Column ({title_suffix})")
    ax.legend(title="Label Value")

    for k, val in enumerate(ordered_values):
        for j, c in enumerate(label_cols):
            vc = counts_by_col[c]
            height = vc.get(val, 0)
            # Lift the text by 1% of the column's tallest bar (at least 1 count)
            pad = max(1, 0.01*max(vc.values)) if len(vc.values) > 0 else 1
            ax.text(
                j + (k - (n_values-1)/2)*width,
                height + pad,
                str(int(height)),
                ha='center',
                va='bottom',
                fontsize=8
            )

    plt.tight_layout()
    plt.show()

for split_df, suffix in [(test_df_labels, "Test Set"), (train_df_labels, "Train Set")]:
    plot_label_value_counts(split_df, title_suffix=suffix)

In [None]:
def plot_all_viral_label_bar(test_df, train_df, fontsize=10):
    """
    Plot viral vs non-viral sample counts for the training and testing data.

    For each split, draws a bar for all-viral_label==1 next to a stacked bar
    for all-viral_label==0. The stacked bar is split by Malaria_label==0
    (bottom) and Malaria_label==1 (top). Non-viral rows whose Malaria_label is
    neither 0 nor 1 (e.g. NaN) are not counted. Each non-empty segment is
    labelled with its count, and the figure is saved to
    figures/train_test_data_viral_breakdown.png.

    Parameters:
    - test_df: pandas DataFrame for test data, must contain 'all-viral_label' and 'Malaria_label'
    - train_df: pandas DataFrame for train data, must contain 'all-viral_label' and 'Malaria_label'
    - fontsize: int, controls the fontsize of the title, axes labels, ticklabels and count labels
    """

    def get_nonviral_malaria_counts(df):
        """Return [Malaria_label==0, Malaria_label==1] counts among all-viral_label==0 rows."""
        nonviral = df[df['all-viral_label'] == 0]
        malaria_counts = nonviral['Malaria_label'].value_counts().sort_index()
        # Ensure both 0 and 1 are present
        return [malaria_counts.get(0, 0), malaria_counts.get(1, 0)]

    # x = 0: training data, x = 1: testing data
    splits = [train_df, test_df]
    bars_viral = [df['all-viral_label'].value_counts().get(1, 0) for df in splits]
    nonviral_counts = [get_nonviral_malaria_counts(df) for df in splits]
    bars_nonviral_0 = [counts[0] for counts in nonviral_counts]
    bars_nonviral_1 = [counts[1] for counts in nonviral_counts]

    bar_width = 0.4
    x = np.arange(len(splits))

    fig, ax = plt.subplots(figsize=(7, 5))

    # Stacked non-viral bar (Malaria_label==0 below, ==1 on top), viral bar to its right
    ax.bar(x, bars_nonviral_0, width=bar_width, label="Non-viral", color='grey')
    ax.bar(x, bars_nonviral_1, width=bar_width, bottom=bars_nonviral_0, label="Non-viral (Malaria)", color='lightgrey')
    ax.bar(x + bar_width, bars_viral, width=bar_width, label="Viral", color="#FF0000")

    # Count labels centred inside each non-empty segment
    for i in range(len(splits)):
        if bars_nonviral_0[i] > 0:
            ax.text(x[i], bars_nonviral_0[i] / 2, str(bars_nonviral_0[i]), ha='center', va='center', fontsize=fontsize, color='black')
        if bars_nonviral_1[i] > 0:
            ax.text(x[i], bars_nonviral_0[i] + bars_nonviral_1[i] / 2, str(bars_nonviral_1[i]), ha='center', va='center', fontsize=fontsize, color='black')
        if bars_viral[i] > 0:
            ax.text(x[i] + bar_width, bars_viral[i] / 2, str(bars_viral[i]), ha='center', va='center', fontsize=fontsize, color='black')

    ax.set_xticks(x + bar_width / 2)
    ax.set_xticklabels(['Training Data', 'Testing Data'], fontsize=fontsize)
    ax.set_ylabel("Count", fontsize=fontsize)
    ax.set_title("Viral vs Non-viral Counts in the Training and Testing Data", fontsize=fontsize, fontweight="bold")
    ax.legend(loc="upper center", fontsize=fontsize)
    ax.tick_params(axis='y', labelsize=fontsize)

    plt.tight_layout()
    os.makedirs("figures", exist_ok=True)
    fig.savefig("figures/train_test_data_viral_breakdown.png", dpi=300, bbox_inches="tight")
    plt.show()

plot_all_viral_label_bar(test_df_labels, train_df_labels, fontsize=12)

In [None]:
# Model: GPT-4, version 2024-06-13

def plot_label_upset_and_malaria_bar(df, title_suffix="Test Set", fontsize=12):
    """
    For a given DataFrame, creates:
    - An UpSet plot for rows where all-viral_label == 1, showing (co-)occurrences of *_label columns == 1 (excluding 'all-viral_label').
    - For rows where all-viral_label == 0, a bar plot of Malaria_label==0 versus Malaria_label==1 counts,
      after checking (and printing) whether any other *_label column is == 1 in these rows.

    Both figures are saved under figures/ with ``title_suffix`` in the file name.
    Plotting failures of the UpSet figure are caught and reported with print().

    Parameters:
    - df: pandas DataFrame containing *_label columns, including 'all-viral_label'
    - title_suffix: str, suffix added to the plot titles and file names (e.g. 'Testing Data')
    - fontsize: int, controls the fontsize of the titles, axes labels, and ticklabels of both figures
    """
    from collections import Counter

    os.makedirs("figures", exist_ok=True)

    # Select *_label columns, excluding 'all-viral_label'
    all_label_cols = [col for col in df.columns if col.endswith('_label') and col != 'all-viral_label']

    # Split the dataframe by all-viral_label
    for_viral = df[df['all-viral_label'] == 1]
    for_nonviral = df[df['all-viral_label'] == 0]

    # --- UpSet plot for all-viral_label == 1 ---
    def get_nonzero_label_cols(subset_df, label_cols):
        """Keep only the *_label columns that have at least one value == 1."""
        return [col for col in label_cols if subset_df[col].eq(1).any()]

    viral_label_cols = get_nonzero_label_cols(for_viral, all_label_cols)
    label_bool_viral = for_viral[viral_label_cols] == 1 if viral_label_cols else pd.DataFrame()
    label_bool_viral = label_bool_viral.fillna(False)

    def bool_df_to_upset_series(bool_df, label_cols):
        """Count identical boolean membership rows; None when there is nothing to plot."""
        if bool_df.empty:
            return None
        if len(label_cols) == 1:
            # Single label: Series indexed by True only (not a MultiIndex)
            counts = Counter(bool_df.iloc[:, 0])
            s = pd.Series([counts.get(True, 0), counts.get(False, 0)], index=[True, False], name=label_cols[0])
            s = s[s.index == True]
            return s if not s.empty else None
        tuples = [tuple(row) for row in bool_df.values]
        counts = Counter(tuples)
        index = pd.MultiIndex.from_tuples(counts.keys(), names=label_cols)
        s = pd.Series(list(counts.values()), index=index)
        return s if not s.empty else None

    upset_data_viral = bool_df_to_upset_series(label_bool_viral, viral_label_cols) if not for_viral.empty and viral_label_cols else None

    def is_upset_compatible(series):
        """Cheap sanity check that ``series`` is a non-empty, indexed Series."""
        if series is None or series.empty:
            return False
        idx = series.index
        if isinstance(idx, pd.MultiIndex):
            return idx.nlevels > 0
        elif isinstance(idx, pd.Index):
            return True
        return False

    # Plot for all-viral_label == 1
    if (
        upset_data_viral is not None
        and not label_bool_viral.empty
        and label_bool_viral.sum(axis=1).max() > 0
        and is_upset_compatible(upset_data_viral)
    ):
        try:
            upset_viral = UpSet(upset_data_viral, show_counts=True, sort_by='cardinality')
            upset_viral.plot()
            plt.title(f"{title_suffix} (Viral)", fontsize=fontsize+2, fontweight="bold")
            # Set ticklabel font sizes for all axes in the UpSet plot
            for ax in plt.gcf().axes:
                ax.tick_params(axis='both', labelsize=fontsize)
                if ax.get_xlabel():
                    ax.set_xlabel(ax.get_xlabel(), fontsize=fontsize)
                if ax.get_ylabel():
                    ax.set_ylabel(ax.get_ylabel(), fontsize=fontsize)
                # Remove "_label" from y axis ticklabels
                yticklabels = [label.get_text() for label in ax.get_yticklabels()]
                if any("_label" in label for label in yticklabels):
                    new_yticklabels = [label.replace("_label", "") for label in yticklabels]
                    ax.set_yticklabels(new_yticklabels, fontsize=fontsize)
            plt.savefig(f"figures/UpSet_plot_viral_positive_breakdown_{title_suffix}.png", dpi=300, bbox_inches="tight")
            plt.show()
        except Exception as e:
            # Deliberate best-effort: report and continue with the Malaria plot
            print(f"Could not plot UpSet for {title_suffix} (all-viral_label == 1): {e}")
    else:
        print(f"{title_suffix} (all-viral_label == 1): No rows with any *_label == 1 found or only one label present.")

    # --- Bar plot for Malaria_label in all-viral_label == 0 ---
    if for_nonviral.empty:
        print(f"{title_suffix} (all-viral_label == 0): No rows found.")
        return

    # Confirm that no other *_label column is == 1 except Malaria_label
    malaria_col = "Malaria_label"
    other_label_cols = [col for col in all_label_cols if col != malaria_col]
    if other_label_cols:
        # Find rows where any other *_label column is 1
        mask_other_1 = (for_nonviral[other_label_cols].fillna(0).astype(int) == 1)
        rows_with_other_1 = mask_other_1.any(axis=1)
        if rows_with_other_1.any():
            print(f"WARNING: Some rows in {title_suffix} (all-viral_label == 0) have other *_label columns == 1 besides Malaria_label.")
            # Show which columns are ==1 for these rows
            for idx in rows_with_other_1[rows_with_other_1].index:
                row = for_nonviral.loc[idx, other_label_cols]
                cols_with_1 = row[row.fillna(0).astype(int) == 1].index.tolist()
                print(f"  Row index {idx}: columns == 1: {cols_with_1}")
        else:
            print(f"Confirmed: In {title_suffix} (all-viral_label == 0), only Malaria_label is 1; all other *_label columns are 0 or NaN.")
    else:
        print("No other *_label columns found except Malaria_label.")

    # Without the Malaria column there is nothing to break down (previously a KeyError)
    if malaria_col not in for_nonviral.columns:
        print(f"{title_suffix} (all-viral_label == 0): no {malaria_col} column found; skipping Malaria bar plot.")
        return

    # Count Malaria_label==0 and Malaria_label==1
    malaria_counts = for_nonviral[malaria_col].value_counts().sort_index()
    malaria_bar = [malaria_counts.get(0, 0), malaria_counts.get(1, 0)]

    fig, ax = plt.subplots(figsize=(4, 4))
    ax.bar([0, 1], malaria_bar, color=['#1f77b4', '#ff7f0e'], width=0.6)
    ax.set_xticks([0, 1])
    ax.set_xticklabels(['Malaria_label=0', 'Malaria_label=1'], fontsize=fontsize)
    ax.tick_params(axis='y', labelsize=fontsize)
    ax.set_ylabel("Count", fontsize=fontsize)
    ax.set_title(f"{title_suffix} (all-viral_label == 0): Malaria_label breakdown", fontsize=fontsize)

    # Add value labels
    for i, val in enumerate(malaria_bar):
        ax.text(i, val + max(1, 0.001 * max(malaria_bar)), str(val), ha='center', va='bottom', fontsize=fontsize-2)

    plt.tight_layout()
    fig.savefig(f"figures/Malaria_label_barplot_viral_negative_{title_suffix}.png", dpi=300, bbox_inches="tight")
    plt.show()

# Plot the label breakdowns for both splits
plot_label_upset_and_malaria_bar(test_df_labels, title_suffix="Testing Data")
plot_label_upset_and_malaria_bar(train_df_labels, title_suffix="Training Data")