In [1]:
import pandas as pd
import numpy as np
import seaborn as sns
from matplotlib import pyplot as plt
import scipy.stats

In [2]:
def barplot_with_CI_errorbars_colored_by_label(df, x_label, y_label, y_lower_label, y_upper_label, color_label, figsize=False):
    """Draw a bar plot with asymmetric y error bars, coloring each bar by a label column.

    One bar is drawn per row of ``df``. The error bar of each bar extends from
    the value in ``y_lower_label`` up to the value in ``y_upper_label``. The
    figure created here becomes the current pyplot figure, so callers can keep
    adjusting it through ``plt`` (e.g. ``plt.ylim``, ``plt.savefig``).
    The input dataframe is only read, never modified.

    Args:
        df: Pandas Dataframe that has the columns named by the other arguments.
        x_label: str, column name of x axis categories (used as tick labels).
        y_label: str, column name of y axis values (bar heights).
        y_lower_label: str, column name of the lower bound of y.
        y_upper_label: str, column name of the upper bound of y.
        color_label: str, column name of the label that determines bar color.
            Supported values are "category", "reassigned_category" and "type".
        figsize: tuple, figure size in inches. The default ``False`` (or
            ``None``) falls back to Matplotlib's default figure size.

    Returns:
        None.

    Raises:
        ValueError: if ``color_label`` is not one of the supported column names.
        KeyError: if a value found in the ``color_label`` column is not one of
            the categories known for that column.
    """
    # Known categories for each supported color column. The order fixes both
    # the palette assignment and the order of the legend entries.
    categories_by_label = {
        "category": ["Physical", "Empirical", "Mixed", "Other"],
        "reassigned_category": ["Physical (MM)", "Empirical", "Mixed", "Physical (QM)"],
        "type": ["Standard", "Reference"],
    }
    # Validate before touching any plotting state so a bad label fails fast
    # with a clear message.
    if color_label not in categories_by_label:
        raise ValueError("Error: Unsupported label used for coloring")
    category_list = categories_by_label[color_label]

    # Zesty colorblind-friendly color palette
    current_palette = ["#0F2080", "#F5793A", "#A95AA1", "#85C0F9"]
    error_color = 'gray'

    # zip stops at the shorter sequence, so labels with fewer categories than
    # palette entries simply use the first colors.
    bar_color_dict = dict(zip(category_list, current_palette))
    print("bar_color_dict:\n", bar_color_dict)

    # Error bar lengths relative to y, kept in local variables so that the
    # caller's dataframe is left untouched.
    y = df[y_label]
    lower_yerr = y - df[y_lower_label]
    upper_yerr = df[y_upper_label] - y
    x = range(len(y))

    # Plot style
    plt.close()
    # The seaborn styles bundled with Matplotlib were renamed to
    # "seaborn-v0_8-*"; prefer the new name when this Matplotlib provides it
    # and fall back to the legacy name otherwise.
    style_names = []
    for legacy_name in ("seaborn-talk", "seaborn-whitegrid"):
        versioned_name = legacy_name.replace("seaborn-", "seaborn-v0_8-", 1)
        style_names.append(versioned_name if versioned_name in plt.style.available else legacy_name)
    plt.style.use(style_names)
    plt.rcParams['axes.labelsize'] = 20
    plt.rcParams['xtick.labelsize'] = 14
    plt.rcParams['ytick.labelsize'] = 18
    plt.rcParams['legend.fontsize'] = 16
    plt.rcParams['legend.handlelength'] = 2

    # Create exactly one figure. The False/None sentinel means "use the default
    # size", so figsize is only forwarded when the caller actually gave one.
    fig_kwargs = {} if figsize is False or figsize is None else {"figsize": figsize}
    _, ax = plt.subplots(**fig_kwargs)

    # Plot
    barlist = ax.bar(x, y)
    ax.set_xticks(x)
    ax.set_xticklabels(df[x_label], rotation=90)
    ax.errorbar(x, y, yerr=(lower_yerr, upper_yerr),
                fmt="none", ecolor=error_color, capsize=3, elinewidth=2, capthick=True)
    ax.set_xlabel(x_label)
    ax.set_ylabel(y_label)

    # Reset color of bars based on color label
    for bar, c_label in zip(barlist, df[color_label]):
        bar.set_color(bar_color_dict[c_label])

    # Create legend: one thick line per category, in category_list order.
    from matplotlib.lines import Line2D
    custom_lines = [Line2D([0], [0], color=bar_color_dict[cat], lw=5) for cat in category_list]
    ax.legend(custom_lines, category_list)

In [7]:
# Output directory for the figure and the statistics table to plot.
directory_path = "."
df_statistics = pd.read_csv("statistics.csv")

# RMSE comparison plot: one bar per "ID", error bars spanning the
# RMSE_lower_bound..RMSE_upper_bound columns, bars colored by the
# "reassigned_category" column.
barplot_with_CI_errorbars_colored_by_label(
    df=df_statistics,
    x_label="ID",
    y_label="RMSE",
    y_lower_label="RMSE_lower_bound",
    y_upper_label="RMSE_upper_bound",
    color_label="reassigned_category",
    figsize=(28, 10),
)

# Fix the y-axis range, then write the figure with a tight bounding box.
plt.ylim(0.0, 7.0)
plt.savefig(f"{directory_path}/test_tight_plot.pdf", bbox_inches="tight")

bar_color_dict:
 {'Physical (MM)': '#0F2080', 'Empirical': '#F5793A', 'Mixed': '#A95AA1', 'Physical (QM)': '#85C0F9'}
