In [None]:
%load_ext autoreload
%autoreload 2

import pandas as pd
import seaborn as sns
import matplotlib.patches as mpatches

from classes.classes import MODEL_CONFIGS, NO_COSTS_MODEL_CONFIGS
from classes.paths import LocalPaths
from classes.workloads import EvalWorkloads
from cross_db_benchmark.datasets.datasets import Database
from evaluation.eval import Evaluator
from evaluation.evaluation_metrics import QError, PickRate, SelectedRuntime
from evaluation.utils import get_model_results, draw_metric, draw_predictions
from training.dataset.dataset_creation import read_workload_runs
from classes.classes import ColorManager
import seaborn
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
from classes.classes import DACEModelConfig, QueryFormerModelConfig, ZeroShotModelConfig, E2EModelConfig, \
    ScaledPostgresModelConfig, FlatModelConfig, QPPNetModelConfig
from classes.classes import MSCNModelConfig
from matplotlib.patches import Rectangle


# Global seaborn style applied to every figure in this notebook.
sns.set_theme(style="whitegrid", font_scale=1.8)
# Base font size for labels/titles of the plots below.
# NOTE(review): the breakdown-plot cell rebinds this to 16, so re-running earlier plot cells after it changes their look.
fontsize=12

## 1. Evaluate Plans

In [None]:
# Evaluate physical-plan selection on IMDB, TPC-H and Baseball with the default model set.
evaluator = Evaluator()
metrics = [
    QError(),
    PickRate(),
    SelectedRuntime(display_name="Selected\nRuntime"),
]
workloads = [
    EvalWorkloads.PhysicalPlan.imdb,
    EvalWorkloads.PhysicalPlan.tpc_h_pk,
    EvalWorkloads.PhysicalPlan.baseball,
]
databases = [
    Database("scale", display_name="IMDB"),
    Database("tpc_h_pk", display_name="TPC-H"),
    Database("baseball", display_name="Baseball"),
]
model_confs = MODEL_CONFIGS  # alternative: MODEL_CONFIGS + NO_COSTS_MODEL_CONFIGS
seeds = [0, 1, 2]

for workload in workloads:
    evaluator.eval(
        workloads=workload,
        metrics=metrics,
        plot_single_workloads=False,
        plot_limit=5,
        seeds=list(seeds),
        model_configs=model_confs,
    )

## 2. Create Anecdote Plots

In [None]:
path = LocalPaths().data / "plots" / "physical_plan_anecdote.pdf"
mosaic = """AAAABB\nAAAABB"""
model_confs = MODEL_CONFIGS
workload = workloads[0][1]

figure = plt.figure(figsize=(6, 2.5), dpi=100)
results = get_model_results(workload, model_confs)

# Sort results first by query_index and then by the position of the model in model_confs
# (unknown models last). sort_values applies `key` to every column in `by`, so only the
# 'model' column is mapped to its rank; query_index is left untouched so it really sorts.
model_rank = {c.name.DISPLAY_NAME: i for i, c in enumerate(model_confs)}
results = results.sort_values(
    by=['query_index', 'model'],
    key=lambda col: col.map(lambda name: model_rank.get(name, len(model_confs))) if col.name == 'model' else col)
prediction, runtime = figure.subplot_mosaic(mosaic, gridspec_kw={'height_ratios': [1,1], 'wspace': 0.75, 'hspace': 0.2}).values()

# Predicted runtime per query, one bar per model
seaborn.barplot(x="query_index",
                hue="model",
                y="prediction",
                data=results,
                ax=prediction,
                palette=ColorManager.COLOR_PALETTE,
                errorbar=None,
                edgecolor='black',
                width=0.9)

# Draw the real (minimal measured) runtime of every query as a dashed line spanning its bar group.
# groupby returns the queries sorted, matching seaborn's order for the numeric x categories;
# each of the n categories covers 1/n of the axes width.
min_labels = results.groupby('query_index')['runtime'].min()
n_queries = len(min_labels)
for i, min_label in enumerate(min_labels):
    prediction.axhline(y=min_label, xmin=i / n_queries, xmax=(i + 1) / n_queries,
                       linestyle='--', color='black', linewidth=2, zorder=100)
    
# Configure plot
prediction.set_ylabel("Runtime (s)", fontsize=fontsize)
prediction.grid(axis="y", which='both', linestyle='--', linewidth=0.5)
prediction.set_xlabel("")
prediction.tick_params(axis='x', rotation=0, labelsize=fontsize * 0.8)
prediction.tick_params(axis='y', rotation=0, labelsize=fontsize * 0.8)

# Update xticklabels (after tick_params, so these labels keep the full fontsize)
prediction.set_xticklabels(workload.yticks, fontsize=fontsize)

# Draw metrics: selected runtime per model next to the predictions, on the same y range
draw_metric(results, model_confs, runtime, SelectedRuntime(), fontsize=fontsize)
optimal_runtime = results['runtime'].min()
runtime.axhline(y=optimal_runtime, linestyle='solid', color='black', linewidth=2, zorder=100)
runtime.set_ylim(prediction.get_ylim())
runtime.annotate(f"Optimal: {optimal_runtime:.0f}s", xycoords='axes fraction', xy=(0.25, 0.15), fontsize=fontsize * 0.8, backgroundcolor='white')
prediction.annotate("Optimal Join", xy=(1, 10), xytext=(0.65, 1.1), fontsize=fontsize * 0.7, backgroundcolor='white', rotation=0)

# Draw common legend
# Start from the seaborn bar handles (one per model) and append the two line styles drawn above:
# dashed = per-query real runtime, solid = overall optimal runtime. No sorting happens here;
# `labels` is not used further.
handles, labels = prediction.get_legend_handles_labels()
handles.append(Line2D([0], [0], color='black', lw=2, linestyle='--', label='Real Runtime'))
handles.append(Line2D([0], [0], color='black', lw=2, linestyle='-', label='Optimal Runtime'))

legend = prediction.legend(
    handles=handles,
    fontsize=fontsize,
    ncol=3, 
    loc='center left', 
    bbox_to_anchor=(-0.2, -0.5),
    labelspacing=0.1,
    edgecolor='white')

# Show the anecdote query above the figure in a monospace font.
figure.suptitle("SELECT COUNT(*) FROM title, movie_info\nWHERE title.id=movie_info.movie_id\nAND movie_info.info_type_id<8;", fontsize=fontsize*0.8, fontproperties={'family':'monospace'}, y=1.13, horizontalalignment='center')
# Save plot
figure.savefig(path, bbox_inches='tight')

In [None]:
# NOTE(review): same output file as the previous anecdote cell; whichever cell runs last wins — confirm intended.
path = LocalPaths().data / "plots" / "physical_plan_anecdote.pdf"
mosaic = """AAAABB\nAAAABB"""
model_confs = MODEL_CONFIGS
workload = workloads[0][1]

figure = plt.figure(figsize=(5, 2), dpi=100)
results = get_model_results(workload, model_confs)
# Hatch per query (hue position); the legend below labels them HJ / SMJ (optimal) / INLJ.
hatch_patterns = {0: "--", 1: "//", 2: ".."}

# Sort results first by query_index and then by the position of the model in model_confs
# (unknown models last). sort_values applies `key` to every column in `by`, so only the
# 'model' column is mapped to its rank; query_index is left untouched so it really sorts.
model_rank = {c.name.DISPLAY_NAME: i for i, c in enumerate(model_confs)}
results = results.sort_values(
    by=['query_index', 'model'],
    key=lambda col: col.map(lambda name: model_rank.get(name, len(model_confs))) if col.name == 'model' else col)

# Add one synthetic "Real Runtime" row per query that carries its minimal measured runtime in
# every value column, so it is drawn as an extra bar group next to the models.
real_runtime = results.groupby("query_index")["runtime"].min().reset_index()
real_runtime["model"] = "Real Runtime"
real_runtime["prediction"] = real_runtime["runtime"]
real_runtime["label"] = real_runtime["runtime"]
real_runtime["selected_runtime"] = real_runtime["runtime"]
results = pd.concat([results, real_runtime], ignore_index=True)

# "Real Runtime" first, then the models in model_confs order
sort_map = {model.name.DISPLAY_NAME: i for i, model in enumerate(model_confs)}
sort_map["Real Runtime"] = -1

# Stable sort keeps the query_index order within each model
results = results.sort_values(by=['model'], key=lambda x: x.map(sort_map), kind='stable')

prediction, runtime = figure.subplot_mosaic(mosaic, gridspec_kw={'height_ratios': [1,1], 'wspace': 0.75, 'hspace': 0.2}).values()

# One bar group per model (x), one bar per query (hue) inside each group
bars = seaborn.barplot(x="model",
                       hue="query_index",
                       y="prediction",
                       data=results,
                       ax=prediction,
                       errorbar=None,
                       edgecolor='black',
                       width=0.9)

prediction.set_xticklabels(["Real Runtime"] + [m.name.DISPLAY_NAME for m in model_confs], fontsize=fontsize * 0.8, rotation=45)

# seaborn emits the patches hue-major: patch i belongs to x category i % n_categories and to
# hue (query) position i // n_categories. Colour each bar by its model, hatch it by its query.
tick_labels = [t.get_text() for t in bars.get_xticklabels()]
n_categories = len(tick_labels)
n_queries = results["query_index"].nunique()
for i, bar in enumerate(bars.patches[:n_categories * n_queries]):
    model_label = tick_labels[i % n_categories]
    bar.set_facecolor(ColorManager.COLOR_PALETTE.get(model_label, 'white'))
    bar.set_hatch(hatch_patterns[i // n_categories])

# Configure plot
prediction.set_ylabel("Runtime (s)", fontsize=fontsize * 0.8)
prediction.grid(axis="y", which='both', linestyle='--', linewidth=0.5)
prediction.set_xlabel("")
prediction.tick_params(axis='x', rotation=0, labelsize=fontsize * 0.8, pad=-4)
prediction.tick_params(axis='y', rotation=0, labelsize=fontsize * 0.8)
# Hide the model names on this axis; bar colours and the legend identify the models.
# (A former set_xticklabels(workload.yticks) here was overwritten immediately by this call.)
prediction.set_xticklabels([])
prediction.get_legend().remove()

# Draw metrics
# NOTE(review): `results` now also contains the synthetic "Real Runtime" rows added above — confirm draw_metric handles that model.
draw_metric(results, model_confs, runtime, SelectedRuntime(), fontsize=fontsize)
# Solid line at the overall best runtime; share the y range with the prediction panel.
runtime.axhline(y=results['runtime'].min(), linestyle='solid', color='black', linewidth=2, zorder=100)
runtime.set_ylim(prediction.get_ylim())

# Hatch every selected-runtime bar with the pattern of the query whose minimal runtime it equals.
# NOTE(review): relies on exact float equality; any bar matching neither query 0 nor 1 gets the query-2 hatch.
for bar in runtime.patches:
    y = bar.get_height()
    if y == results[results['query_index'] == 0]['runtime'].min():
        bar.set_hatch(hatch_patterns[0])
    elif y == results[results['query_index'] == 1]['runtime'].min():
        bar.set_hatch(hatch_patterns[1])
    else:
        bar.set_hatch(hatch_patterns[2])

# Right/top-align the x tick labels.
# NOTE(review): the x tick labels were set to [] above, so this loop has no visible effect.
for tick in prediction.xaxis.get_majorticklabels():
    tick.set_horizontalalignment("right")
    tick.set_verticalalignment("top")

    
# Draw common legend
# Start from the handles of the runtime panel and strip their hatches so they show colour only.
handles, labels = runtime.get_legend_handles_labels()
for handle in handles:
    handle.set_hatch(None)

# Every insert(0, ...) prepends, so the final order is:
# HJ, SMJ (optimal), INLJ, blank spacer, Real Runtime, <runtime-panel handles>.
new_rectangle = Rectangle((0, 0), 1, 1, facecolor='white', edgecolor='black', label='Real Runtime')
handles.insert(0, new_rectangle)

blank_line = plt.Line2D([], [], linewidth=0)
handles.insert(0, blank_line)

# Hatch swatches for the three queries (hatch_patterns keys 2, 1, 0)
new_rectangle = Rectangle((0, 0), 1, 1, facecolor='white', edgecolor='black', label='INLJ')
new_rectangle.set_hatch(hatch_patterns[2])
handles.insert(0, new_rectangle)

new_rectangle = Rectangle((0, 0), 1, 1, facecolor='white', edgecolor='black', label='SMJ (optimal)')
new_rectangle.set_hatch(hatch_patterns[1])
handles.insert(0, new_rectangle)

new_rectangle = Rectangle((0, 0), 1, 1, facecolor='white', edgecolor='black', label='HJ')
new_rectangle.set_hatch(hatch_patterns[0])
handles.insert(0, new_rectangle)

# Single-column legend anchored far to the left of the runtime panel (bbox x = -4.5).
legend = runtime.legend(
    handles=handles,
    fontsize=fontsize*0.8,
    ncol=1, 
    loc='center left', 
    columnspacing=1, 
    bbox_to_anchor=(-4.5, 0.5),
    labelspacing=0.05,
    edgecolor='white')

# Show the anecdote query above the figure in a monospace font.
figure.suptitle("SELECT COUNT(*) FROM title, movie_info\nWHERE title.id=movie_info.movie_id AND movie_info.info_type_id<8;", fontsize=fontsize*0.85, fontproperties={'family':'monospace'}, y=1.15, x=0.4, horizontalalignment='center')
# Save plot
figure.savefig(path, bbox_inches='tight')

In [None]:
# Re-evaluate the same three workloads with NO_COSTS_MODEL_CONFIGS
# (presumably the model variants without cost inputs — confirm).
evaluator = Evaluator()
workloads = [
    EvalWorkloads.PhysicalPlan.imdb,
    EvalWorkloads.PhysicalPlan.tpc_h_pk,
    EvalWorkloads.PhysicalPlan.baseball,
]
model_confs = NO_COSTS_MODEL_CONFIGS
seeds = [0, 1, 2]
metrics = [QError(), PickRate(), SelectedRuntime()]

for workload in workloads:
    evaluator.eval(
        workloads=workload,
        metrics=metrics,
        plot_single_workloads=False,
        plot_limit=5,
        seeds=list(seeds),
        model_configs=model_confs,
    )

In [None]:
path = LocalPaths().data / "plots" / "physical_plan_selection_full.pdf"
databases = [
    Database("scale", display_name="IMDB"),
    Database("tpc_h", display_name="TPC-H"),
    Database("baseball", display_name="Baseball"),
]

# Metrics gathered by the evaluator of the previous cell; split each workload name at its
# last underscore into the database part and the workload suffix.
results = pd.DataFrame.from_dict(evaluator.metric_collection)
workload_parts = results["workload"].str.rsplit('_', n=1)
results["database"] = workload_parts.str[0]
results["workload"] = workload_parts.str[-1]

fig, (upper_axs, lower_axs) = plt.subplots(2, 3, figsize=(14, 3.5), sharex="col", dpi=100)

# Invisible spacer bar separating the first `insert_position` models from the rest.
empty_entry = 'empty'
empty_color = 'white'  # Use the background color of your plot
insert_position = 8  # x position (and bar-container index) of the spacer
# Work on a copy so the shared ColorManager.COLOR_PALETTE is not mutated as a side effect.
palette = {**ColorManager.COLOR_PALETTE, empty_entry: empty_color}

for i, database in enumerate(databases):
    # Optimal runtime = sum of the minimal runtimes recorded by the evaluator;
    # IMDB ("scale") and TPC-H are stored under different keys.
    if database.db_name == "scale":
        min_runtime = sum(evaluator.minimal_runtimes["imdb"])
    elif database.db_name == "tpc_h":
        min_runtime = sum(evaluator.minimal_runtimes["tpc_h_pk"])
    else:
        min_runtime = sum(evaluator.minimal_runtimes[database.db_name])
    pick_rate, runtime = upper_axs[i], lower_axs[i]

    # Prepare dataframes
    result_df = results[results["database"] == database.db_name]
    # Drop every workload in which some model has a q-error > 100,000 and remove its runtime
    # from the optimum. Only this database's rows are used: the workload column holds just the
    # name suffix, which can repeat across databases.
    # NOTE(review): this subtracts the summed selected runtime of all models for the dropped
    # workloads rather than their minimal runtime — confirm that is intended.
    high_q_error_workloads = result_df[result_df["qerror"] > 100000]["workload"].unique()
    dropped = result_df["workload"].isin(high_q_error_workloads)
    min_runtime = min_runtime - result_df.loc[dropped, "runtime"].sum()
    result_df = result_df[~dropped]

    # Pick rate (%) and total selected runtime per model, in model_confs order
    model_names = [c.name.DISPLAY_NAME for c in model_confs]
    percentage_true = result_df.groupby('model_name')['pick_rate'].mean() * 100
    percentage_true = percentage_true.reindex(model_names)
    runtimes = result_df.groupby('model_name')['runtime'].sum()
    runtimes = runtimes.reindex(model_names)

    # Add the spacer entry and move it to `insert_position`
    percentage_true[empty_entry] = 0
    runtimes[empty_entry] = 0
    model_names.insert(insert_position, empty_entry)
    percentage_true = percentage_true.reindex(model_names)
    runtimes = runtimes.reindex(model_names)

    # Long format for seaborn
    percentage_true_df = percentage_true.reset_index().melt(id_vars='model_name', var_name='variable', value_name='value')
    runtimes_df = runtimes.reset_index().melt(id_vars='model_name', var_name='variable', value_name='value')

    sns.barplot(data=percentage_true_df,
                x="model_name",
                y="value",
                ax=pick_rate,
                palette=palette,
                width=1.0,
                log_scale=(False, False),
                hue="model_name",
                edgecolor='black')

    # Add bar labels (one container per hue level; none on the spacer)
    for idx, c in enumerate(pick_rate.containers):
        if idx != insert_position:
            pick_rate.bar_label(c, fontsize=0.8 * fontsize, fmt='%.0f', padding=1, label_type='edge')

    pick_rate.set_ylim(0, 100)
    pick_rate.grid(False, axis='x', which='both')
    pick_rate.grid(True, axis='y', which='both')
    pick_rate.set_xlabel('')
    pick_rate.set_ylabel('')
    # Shade the models after the spacer
    pick_rate.axvspan(xmin=insert_position, xmax=11, alpha=0.1, color='gray')
    pick_rate.set_title(database.display_name, fontsize=fontsize)
    pick_rate.set_yticklabels(pick_rate.get_yticklabels(), fontsize=fontsize)
    # NOTE(review): x is shared per column (sharex="col"), so the runtime panel's xlim below (right=11) is what finally applies.
    pick_rate.set_xlim(left=-1, right=10)

    # Plot the runtimes
    sns.barplot(data=runtimes_df,
                x="model_name",
                y="value",
                ax=runtime,
                width=1.0,
                palette=palette,
                hue="model_name",
                edgecolor='black')

    # Add bar labels (none on the spacer)
    for idx, c in enumerate(runtime.containers):
        if idx != insert_position:
            runtime.bar_label(c, fontsize=0.8 * fontsize, fmt='%.0f', padding=1, label_type='edge')

    runtime.set_xlabel('')
    runtime.set_ylabel('')
    runtime.set_xlim(left=-1, right=11)
    runtime.axhline(min_runtime, color='black', linestyle='--')
    # Put the "Optimal" note low for IMDB/TPC-H and high for the others
    if database.db_name in ["scale", "tpc_h"]:
        xy = (0.33, 0.1)
    else:
        xy = (0.33, 0.8)
    runtime.annotate(f"Optimal: {min_runtime:.0f}s",
                     xy=xy,
                     xytext=xy,
                     fontsize=fontsize * 0.9,
                     xycoords="axes fraction",
                     backgroundcolor='white')

    runtime.grid(False, axis='x', which='both')
    runtime.grid(True, axis='y', which='both')
    runtime.axvspan(xmin=insert_position, xmax=11, alpha=0.1, color='gray')
    runtime.set_xticklabels([], fontsize=fontsize)
    runtime.yaxis.set_tick_params(labelsize=fontsize, which='both')

# Set axis labels (only the left-most column carries y labels)
upper_axs[0].set_ylabel('Pick Rate(%)', fontsize=fontsize)
lower_axs[0].set_ylabel('Selected\nRuntime(s)', fontsize=fontsize)

# Draw legend
# One colour patch per model, in model_confs order, with a black edge like the bars.
legend_patches = [mpatches.Patch(color=model_config.color(), label=model_config.name.DISPLAY_NAME) for model_config in model_confs]

for p in legend_patches:
    p.set_edgecolor('black')

# Invisible entry at index 8 mirrors the empty spacer bar placed at x position 8 in the bar charts.
blank_line = plt.Line2D([], [], linewidth=0)
legend_patches.insert(8, blank_line)
# Legend sits to the left of the first column (negative bbox x).
upper_axs[0].legend(handles=legend_patches, 
                    loc='upper left', 
                    bbox_to_anchor=(-0.9, 0.8), 
                    fontsize=fontsize,
                    labelspacing=0.1, 
                    facecolor='white', 
                    edgecolor='white')

# Generate plot
plt.subplots_adjust(hspace=0.10)
fig.align_labels()
plt.savefig(path, bbox_inches='tight')

In [None]:
# Display the metrics frame built in the previous cells (one row per model/workload metric record).
results

In [None]:
def get_nested_loop_join_type(plan):
    """Classify the first Nested Loop join found in a plan tree.

    The tree is searched depth-first, children in order, for a node whose
    ``plan_parameters.op_name`` is ``"Nested Loop"``. If its first child has
    ``est_card == 1``, that child is treated as the inner side; otherwise the
    second child is.

    Args:
        plan: plan node exposing ``plan_parameters.op_name``,
            ``plan_parameters.est_card`` and a ``children`` list.

    Returns:
        "Index Nested Loop" if the inner child's operator name contains "Index",
        "Nested Loop" otherwise, or None if the tree contains no Nested Loop node.

    Raises:
        AssertionError: if the chosen outer child has ``est_card <= 1``.
    """
    if plan.plan_parameters.op_name == "Nested Loop":
        join_tables = plan.children
        if join_tables[0].plan_parameters.est_card == 1:
            inner_table, outer_table = join_tables[0], join_tables[1]
        else:
            inner_table, outer_table = join_tables[1], join_tables[0]

        assert outer_table.plan_parameters.est_card > 1
        if "Index" in inner_table.plan_parameters.op_name:
            return "Index Nested Loop"
        return "Nested Loop"

    # Search every child, not only the first: the join can sit below a later sibling.
    for subplan in plan.children:
        join_type = get_nested_loop_join_type(subplan)
        if join_type is not None:
            return join_type
    return None

def get_join_type(plan):
    """Return the join operator used in ``plan``, detected from its string form.

    "Merge Join" takes precedence over "Hash Join", which takes precedence over
    "Nested Loop"; nested loops are further classified by
    ``get_nested_loop_join_type``. If none occurs, the plan is printed for
    inspection and None is returned.
    """
    plan_text = str(plan)
    for join_name in ("Merge Join", "Hash Join"):
        if join_name in plan_text:
            return join_name
    if "Nested Loop" in plan_text:
        return get_nested_loop_join_type(plan)
    print(plan)
    return None

In [None]:
# Evaluate IMDB with and without additional indexes to break down which join type each model selects.
evaluator = Evaluator()
metrics = [QError(), PickRate(), SelectedRuntime()]
workloads =[EvalWorkloads.PhysicalPlan.imdb, EvalWorkloads.PhysicalPlan.imdb_with_indexes]
databases = [Database("scale", display_name="IMDB")]
# Model order here is the order of the pies/bars in the breakdown figures.
model_confs = [
    ScaledPostgresModelConfig(),
    FlatModelConfig(),
    #MSCNModelConfig(),
    E2EModelConfig(),
    ZeroShotModelConfig(),
    QPPNetModelConfig(),
    QueryFormerModelConfig(),
    DACEModelConfig(),
]
seeds = [0, 1, 2]

# Start from an empty evaluator state.
evaluator.minimal_runtimes = {}
evaluator.metric_collection = []

# Per-workload outputs consumed by the breakdown plots:
collected_results = []  # join-type selection counts per model (one DataFrame per workload)
collected_runtimes = []  # summed selected runtime per model (long format)
collected_percentages = []  # pick rate per model in percent (long format)
collected_optimal_runtimes = []  # sum over single workloads of the fastest plan runtime

for workload in workloads:
    evaluator.eval(workloads=workload,
                   metrics=metrics,
                   plot_single_workloads=False,
                   plot_limit=5,
                   seeds=seeds,
                   model_configs=model_confs)

    # Join type chosen by the optimum and by every model, one record per (model, single workload)
    selections = []
    optimal_runtimes = []
    for single_wl in workload:
        plans, database_statistics = read_workload_runs(workload_run_paths=[single_wl.get_workload_path(LocalPaths().parsed_plans)])
        # Join type and measured runtime (truncated to int) of every candidate plan
        join_types = [get_join_type(plan) for plan in plans]
        plan_runtimes = [int(plan.plan_runtime) for plan in plans]
        optimal_runtimes.append(min(plan_runtimes))
        # Combine model predictions to common dataframe
        predictions = Evaluator.combine_predictions(model_confs, single_wl, seeds)
        # Map query_index -> join type of the plan at that position, for any number of plans
        # (assumes query_index i is the i-th parsed plan — TODO confirm).
        predictions["join_type"] = predictions['query_index'].map(dict(enumerate(join_types)))

        # Aggregate and reduce the predictions over the given seeds
        predictions = predictions.groupby(['model', 'query_index', 'label', 'join_type'])['prediction'].mean().reset_index()

        workload_name = single_wl.get_workload_name()
        # Optimal choice: the plan with the lowest label (its measured runtime)
        optimal_join = predictions.loc[predictions["label"] == predictions["label"].min(), "join_type"].values[0]
        selections.append({"model": "Optimal", "join_type": optimal_join, "workload": workload_name})
        for model in model_confs:
            model_predictions = predictions[predictions['model'] == model.name.DISPLAY_NAME]
            # Join type of the plan this model predicts to be fastest
            fastest = model_predictions["prediction"] == model_predictions["prediction"].min()
            selected_join = model_predictions.loc[fastest, "join_type"].values[0]
            selections.append({"model": model.name.DISPLAY_NAME,
                               "join_type": selected_join,
                               "workload": workload_name})
    collected_optimal_runtimes.append(sum(optimal_runtimes))

    # How often each model selected each join type (one column per join type); "Optimal" last
    selection_counts = (pd.DataFrame.from_dict(selections)
                        .groupby(['model', 'join_type']).size()
                        .unstack()
                        .fillna(0)
                        .reindex([c.name.DISPLAY_NAME for c in model_confs] + ["Optimal"]))
    collected_results.append(selection_counts)

    # Compute pick rates and runtimes from this run's metrics, then reset the evaluator so the
    # next workload starts from a clean state.
    results = pd.DataFrame.from_dict(evaluator.metric_collection)
    evaluator.metric_collection = []
    evaluator.minimal_runtimes = {}
    results["database"] = results["workload"].str.rsplit('_', n=1).str[0]
    results["workload"] = results["workload"].str.rsplit('_', n=1).str[-1]

    runtimes = results.groupby('model_name')['runtime'].sum()
    runtimes = runtimes.reindex([c.name.DISPLAY_NAME for c in model_confs])
    runtimes = runtimes.reset_index().melt(id_vars='model_name', var_name='variable', value_name='value')
    collected_runtimes.append(runtimes)

    percentage_true = results.groupby('model_name')['pick_rate'].mean() * 100
    percentage_true = percentage_true.reindex([c.name.DISPLAY_NAME for c in model_confs])
    percentage_true = percentage_true.reset_index().melt(id_vars='model_name', var_name='variable', value_name='value')
    collected_percentages.append(percentage_true)

In [None]:
# Summed fastest-plan runtime per evaluated workload (used for the dashed line in the runtime panel).
collected_optimal_runtimes

In [None]:
# Output files, one per workload evaluated above (without / with additional indexes).
paths = [LocalPaths().data / "plots" /"physical_plan_breakdown.pdf", 
         LocalPaths().data / "plots" /"physical_plan_breakdown_add_index.pdf"]
# NOTE(review): rebinds the global `fontsize` (12 before), so re-running earlier plot cells afterwards uses 16.
fontsize = 16

# Optimal pie + one pie per model + runtime bar chart + pick-rate bar chart.
NUM_SUBPLOTS =  len(model_confs) + 3
# Three grey shades, one per join-type column of the pie charts.
COLOR_PALETTE = [sns.color_palette("Grays")[0], sns.color_palette("Grays")[2], sns.color_palette("Grays")[3]]

def custom_autopct(pct, threshold=10):
    """Format a pie-wedge percentage for matplotlib's ``autopct``, hiding small wedges.

    Args:
        pct: wedge size in percent, as passed by ``Axes.pie``.
        threshold: wedges below this percentage get no label (default 10).

    Returns:
        The rounded percentage such as ``"42%"``, or ``""`` when ``pct < threshold``.
    """
    # The former `'' % pct` raised TypeError ("not all arguments converted during string
    # formatting") for every wedge below the threshold.
    return '%1.0f%%' % pct if pct >= threshold else ''
    
for index, (results, pick_rates, runtimes, actual_runtime, path) in enumerate(zip(collected_results, collected_percentages, collected_runtimes, collected_optimal_runtimes, paths)):
    fig, axes = plt.subplots(1, NUM_SUBPLOTS, figsize=(17, 2), dpi=100)
    optimal_ax, prediction_axs, runtime_ax, pick_rate_ax = axes[0], axes[1:-2], axes[-2], axes[-1]
    # For the second workload (index 1), QPP-Net is shown as "not supported" and dropped from
    # the bar charts (presumably it has no predictions there — TODO confirm).
    qppnet_unsupported = index == 1

    # Draw optimal results: share of each join type among the optimal plans
    optimal_ax.pie(results.loc['Optimal'],
                   autopct=custom_autopct,
                   startangle=90,
                   counterclock=True,
                   labels=['', '', ''],
                   textprops={'fontsize': 11},
                   pctdistance=0.6,
                   colors=COLOR_PALETTE,
                   wedgeprops={"edgecolor": "k"},
                   radius=1.2)
    optimal_ax.set_ylabel('')
    optimal_ax.legend(loc=3, labels=results.columns, fontsize=fontsize * 0.8, bbox_to_anchor=(-2.2, 0.1))
    optimal_ax.set_title('Optimal', fontsize=fontsize, y=1.05)

    # Draw model selections: one pie per model with the join types it picked
    for prediction_ax, model in zip(prediction_axs, model_confs):
        # isinstance does not depend on the config class implementing __eq__
        if qppnet_unsupported and isinstance(model, QPPNetModelConfig):
            prediction_ax.axis('off')
            prediction_ax.text(0.5, 0.5, "not\nsupported", fontsize=fontsize, ha='center', va='center')
            prediction_ax.set_title(model.name.DISPLAY_NAME, fontsize=fontsize, backgroundcolor=model.color(), y=0.9)
            continue

        res = results.loc[model.name.DISPLAY_NAME]
        prediction_ax.pie(res,
                          autopct=custom_autopct,
                          startangle=90,
                          counterclock=True,
                          labels=[" " for _ in range(len(res))],
                          textprops={'fontsize': 11},
                          pctdistance=0.6,
                          colors=COLOR_PALETTE,
                          radius=1.15,
                          wedgeprops={"edgecolor": "k"})
        prediction_ax.set_ylabel('')
        prediction_ax.set_title(model.name.DISPLAY_NAME, fontsize=fontsize, backgroundcolor=model.color(), y=1.05)

    if qppnet_unsupported:
        pick_rates = pick_rates[pick_rates["model_name"] != "QPP-Net"]
        runtimes = runtimes[runtimes["model_name"] != "QPP-Net"]

    # Draw runtimes for each workload
    sns.barplot(data=runtimes,
                x="model_name",
                y="value",
                ax=runtime_ax,
                width=1.0,
                palette=ColorManager.COLOR_PALETTE,
                hue="model_name",
                edgecolor='black')
    runtime_ax.set_title("Runtime (s)", fontsize=fontsize, y=1.05)
    runtime_ax.set_ylabel('', fontsize=fontsize)
    runtime_ax.set_xlabel('')
    runtime_ax.set_xticklabels([], fontsize=fontsize)
    runtime_ax.set_ylim(0, 660)  # hard-coded upper limit
    runtime_ax.set_yticklabels(runtime_ax.get_yticklabels(), fontsize=fontsize * 0.8, rotation=45)
    runtime_ax.tick_params(axis="y", pad=-10)
    # Dashed line at the optimal runtime, divided by 1000 for the seconds axis (assumes ms input — TODO confirm).
    # Formatted directly to avoid double rounding (round(x, 2) followed by :.0f).
    optimal_seconds = actual_runtime / 1000
    runtime_ax.axhline(y=optimal_seconds, color='black', linestyle='--', linewidth=2)
    runtime_ax.annotate(f"{optimal_seconds:.0f}s", xy=(0.3, 0.08), xycoords='axes fraction', fontsize=fontsize * 0.8, backgroundcolor='white')

    # Draw pick rates
    sns.barplot(data=pick_rates,
                x="model_name",
                y="value",
                ax=pick_rate_ax,
                width=1.0,
                palette=ColorManager.COLOR_PALETTE,
                hue="model_name",
                edgecolor='black')
    pick_rate_ax.set_ylabel('', fontsize=fontsize)
    pick_rate_ax.set_ylim(0, 100)
    pick_rate_ax.set_xlabel('')
    pick_rate_ax.set_xticklabels([], fontsize=fontsize)
    pick_rate_ax.yaxis.tick_right()
    pick_rate_ax.set_yticklabels(pick_rate_ax.get_yticklabels(), fontsize=fontsize * 0.8, rotation=45)
    pick_rate_ax.set_title("Pick Rate (%)", fontsize=fontsize, y=1.05)

    # Adjust layout and save this figure (explicit figure instead of pyplot's current figure)
    fig.subplots_adjust(hspace=0.10)
    fig.align_labels()
    fig.savefig(path, bbox_inches='tight')