In [None]:
import warnings
from pathlib import Path

import matplotlib.pyplot as plt
import matplotlib.ticker as mtick
import numpy as np
import pandas as pd
import seaborn as sns
import yaml
from rich.jupyter import display

from classifier.file_reader import read_files_from_folder
from evaluations.utils.wandb_loader import download_log_data, load_all_histories_to_dataframe
from plots.utils.plotting import write_figure_to_disk

# Path.absolute() resolves the relative notebook file name against the current
# working directory, so this is the directory the kernel was started in.
NOTEBOOK_PATH = Path("experiments.ipynb").absolute().parent

# Directory handed to the W&B loader helpers for the online runs.
DATA_DIR = f"{NOTEBOOK_PATH}/data/online"

# Benchmarks covered by the analysis. Shorten the list (e.g. to ["winogrande"])
# for a quick iteration on a single benchmark.
BENCHMARK_NAMES = [
    "arc_challenge",
    "arc_easy",
    "boolq",
    "lambada_standard",
    "logiqa",
    "logiqa2",
    "piqa",
    "sciq",
    "social_iqa",
    "winogrande",
]

# Render all figure text through LaTeX, using Helvetica as the font family.
plt.rcParams["text.usetex"] = True
plt.rcParams["font.family"] = "Helvetica"



In [None]:
# Download the logs of the online runs into DATA_DIR. The returned object is
# displayed in the next cell (presumably a per-run summary table).
run_summary_df = download_log_data(
    project_name="mess-plus_3-models_online_vFINAL",
    entity="tum-i13",
    batch_size=50,
    save_dir=DATA_DIR,
)

In [None]:
# NOTE(review): `display` is imported from rich.jupyter at the top of the file;
# confirm it accepts a DataFrame (IPython.display.display is the usual choice).
display(run_summary_df)
run_df = load_all_histories_to_dataframe(DATA_DIR)

# Run names are parsed as "<benchmark>_V=..._a=..._c=..._seed=...".
# Match the literal "<benchmark>_" prefix (the exact text that is stripped)
# instead of a regex substring of the bare name: "logiqa" would otherwise also
# match "logiqa2" runs and the result would depend on the list order.
for name in BENCHMARK_NAMES:
    benchmark_prefix = f"{name}_"
    is_benchmark_run = run_df["run_name"].str.contains(benchmark_prefix, regex=False)
    run_df.loc[is_benchmark_run, "benchmark_name"] = name
    run_df.loc[is_benchmark_run, "run_name"] = (
        run_df.loc[is_benchmark_run, "run_name"].str.replace(benchmark_prefix, "", regex=False)
    )

# Split the remaining "V=..._a=..._c=..._seed=..." into one column per
# hyperparameter, then strip the "key=" prefix and cast to a numeric dtype.
run_df[["V", "alpha", "c", "seed"]] = run_df["run_name"].str.split("_", expand=True)
hyperparameter_columns = {
    "V": ("V=", float),
    "alpha": ("a=", float),
    "c": ("c=", float),
    "seed": ("seed=", int),
}
for column, (key_prefix, dtype) in hyperparameter_columns.items():
    run_df[column] = run_df[column].str.replace(key_prefix, "", regex=False).astype(dtype)

# The model-choice indicators are averaged later on, so make them floats.
model_choice_columns = ["models/small_chosen", "models/medium_chosen", "models/large_chosen"]
run_df[model_choice_columns] = run_df[model_choice_columns].astype(float)

display(run_df.head())

In [None]:
# Aggregate the winogrande runs with c == 1.0 per (benchmark, alpha, V, c):
# mean accuracy and queue length, summed energy, and the maximum total runtime.
analysis_df = run_df.loc[
    (run_df["c"] == 1.0) & (run_df["benchmark_name"] == "winogrande")
].pivot_table(
    index=["benchmark_name", "alpha", "V", "c"],
    values=["avg_accuracy", "mess_plus/energy", "mess_plus/q_length", "total_runtime"],
    aggfunc={
        "avg_accuracy": "mean",
        "mess_plus/energy": "sum",
        "mess_plus/q_length": "mean",
        "total_runtime": "max",
    },
)

In [None]:
def add_value_labels(axx, spacing=5, scale=1_000_000, fmt="{:.2f}"):
    """Annotate every bar of a bar chart with its (rescaled) height.

    Arguments:
        axx (matplotlib.axes.Axes): Axes whose ``patches`` (the bars) are
            annotated.
        spacing (int): Distance in points between the end of a bar and its
            label.
        scale (float): Divisor applied to the bar height before formatting.
            The default of 1,000,000 is the MJ conversion used by the plots.
        fmt (str): ``str.format`` template used to render the scaled value;
            the default prints two decimal places.
    """
    for rect in axx.patches:
        # The label is anchored at the horizontal centre of the bar's end.
        y_value = rect.get_height()
        x_value = rect.get_x() + rect.get_width() / 2

        # Negative bars get their label below the bar, positive ones above.
        if y_value < 0:
            offset, va = -spacing, "top"
        else:
            offset, va = spacing, "bottom"

        axx.annotate(
            fmt.format(y_value / scale),
            (x_value, y_value),
            xytext=(0, offset),          # vertical shift, in points
            textcoords="offset points",
            ha="center",
            va=va,
        )

def fmt_to_megajoules(x, pos):
    """Tick formatter that renders a value divided by 1,000,000 with no decimals.

    Takes the ``(value, position)`` pair that the notebook's ``FuncFormatter``
    usage passes in; ``pos`` is not used.
    """
    megajoules = x / 1_000_000
    return "{:.0f}".format(megajoules)


In [None]:
# Load raw inference data

def get_inference_data(benchmark_name, random_state=None):
    """Read and shuffle the raw inference outputs of one benchmark.

    Arguments:
        benchmark_name (str): Sub-folder of ``data/inference_outputs`` (next to
            the notebook's parent directory) to read.
        random_state: Optional seed forwarded to ``DataFrame.sample`` so the
            shuffle can be made reproducible. ``None`` keeps the previous
            unseeded behaviour.

    Returns:
        pandas.DataFrame: The shuffled rows with their pre-shuffle index kept
        in ``idx_original``, or an empty frame when ``read_files_from_folder``
        raises ``ValueError`` for the benchmark.
    """
    try:
        input_df = read_files_from_folder(folder_path=f"{NOTEBOOK_PATH.parent}/data/inference_outputs/{benchmark_name}")
    except ValueError as exc:
        # Best effort: a missing benchmark must not abort the notebook, but it
        # should not disappear silently either.
        warnings.warn(f"No inference data loaded for '{benchmark_name}': {exc}")
        return pd.DataFrame()

    input_df["idx_original"] = input_df.index
    return input_df.sample(frac=1, random_state=random_state).reset_index(drop=True)


# Collect the per-benchmark frames first and concatenate once, instead of
# re-copying a growing frame on every iteration.
inference_frames = [get_inference_data(benchmark) for benchmark in BENCHMARK_NAMES]
inference_frames = [frame for frame in inference_frames if not frame.empty]
infer_df = pd.concat(inference_frames, ignore_index=True) if inference_frames else pd.DataFrame()

# Keeps the positional index as an explicit "index" column, as before.
infer_df = infer_df.reset_index()

# Get baseline dataframe
BASELINE_DATA_DIR = f"{NOTEBOOK_PATH}/data/random_baseline"
baseline_summary_df = download_log_data(
    entity="tum-i13",
    project_name="mess_plus_random_baseline_with_constraint_v01",
    save_dir=BASELINE_DATA_DIR,
    batch_size=5
)

baseline_df = load_all_histories_to_dataframe(BASELINE_DATA_DIR)

# Baseline run names are parsed as "<benchmark>_alpha=<value>". Matching the
# full literal prefix keeps "logiqa" from also claiming "logiqa2" runs, whose
# unstripped name would otherwise end up in the alpha column.
for benchmark in BENCHMARK_NAMES:
    run_prefix = f"{benchmark}_alpha="
    is_benchmark_run = baseline_df["run_name"].str.contains(run_prefix, regex=False, na=False)
    alpha_text = baseline_df.loc[is_benchmark_run, "run_name"].str.replace(run_prefix, "", regex=False)

    baseline_df.loc[is_benchmark_run, "benchmark_name"] = benchmark
    baseline_df.loc[is_benchmark_run, "run_name"] = alpha_text
    baseline_df.loc[is_benchmark_run, "alpha"] = alpha_text

# Cast once, after every benchmark has been processed.
baseline_df["alpha"] = baseline_df["alpha"].astype(float)


print(baseline_df.head())


In [None]:
# Set the style for all plots
sns.set_style("whitegrid")
sns.set_palette(palette="dark:#5A9_r")
# NOTE(review): sns.set() is an alias of sns.set_theme(), whose defaults
# (style="darkgrid", palette="deep") presumably override the two calls above.
# Confirm the intended look before changing the order.
sns.set(font_scale=1.3)

# Grid of subplots: 3 rows (one per alpha value), 7 columns
# (accuracy, Q length, energy, one model-call-ratio panel per V value,
# and the model-call-ratio panel of the random baseline).
fig, axes = plt.subplots(nrows=3, ncols=7, figsize=(24, 8.5), gridspec_kw={'width_ratios': [2.3, 2.3, 2.3, 1, 1, 1, 1]})

name = "arc_challenge"
# NOTE(review): v_values_per_benchmark is only defined in the "Plot generator"
# cell further down, so that cell has to be executed before this one.
subset = run_df.loc[
    (run_df["benchmark_name"] == name)
    & (run_df["c"] == 0.1)
    & (run_df["V"].isin(v_values_per_benchmark[name]))
    & (run_df["_step"] > 10)
]

model_choice_columns = ["models/small_chosen", "models/medium_chosen", "models/large_chosen"]
sorted_v_values = sorted(v_values_per_benchmark[name])

# Mean of the per-model label columns over the raw inference data; used as the
# horizontal single-model reference lines in the accuracy panel.
raw_inference_accuracies_per_model = infer_df[["benchmark_name", "label_small", "label_medium", "label_large"]].groupby("benchmark_name").mean().loc[name]

for iterator, alpha in enumerate(subset["alpha"].unique().tolist()):
    alpha_subset = subset.loc[subset["alpha"] == alpha]
    baseline_alpha_df = baseline_df.loc[(baseline_df["benchmark_name"] == name) & (baseline_df["alpha"] == alpha)]

    accuracy_ax = axes[iterator][0]
    q_ax = axes[iterator][1]
    energy_ax = axes[iterator][2]
    baseline_ax = axes[iterator][6]

    # Accuracy Plot
    accuracy_ax.text(s="Llama 3.1 1B", x=subset["_step"].min() + 20, y=raw_inference_accuracies_per_model["label_small"] + 0.01, color='gray', fontsize=12, ha="left")
    accuracy_ax.text(s="Llama 3.1 8B", x=(subset["_step"].min() + 1/2 * subset["_step"].max()), y=raw_inference_accuracies_per_model["label_medium"] + 0.01, color='gray', fontsize=12, ha="center")
    accuracy_ax.text(s="Llama 3.3 70B", x=subset["_step"].max() - 20, y=raw_inference_accuracies_per_model["label_large"] + 0.01, color='gray', fontsize=12, ha="right")
    accuracy_ax.axhline(y=raw_inference_accuracies_per_model["label_small"], color='gray', linestyle='--')
    accuracy_ax.axhline(y=raw_inference_accuracies_per_model["label_medium"], color='gray', linestyle='--')
    accuracy_ax.axhline(y=raw_inference_accuracies_per_model["label_large"], color='gray', linestyle='--')

    sns.lineplot(
        data=alpha_subset,
        x="_step",
        y="avg_accuracy",
        hue="V",
        errorbar=None,
        ax=accuracy_ax,
        legend=(iterator == 0),
        palette="dark:#5A9_r",
    )

    accuracy_ax.plot(
        baseline_alpha_df["_step"],
        baseline_alpha_df["avg_accuracy"],
        color="violet", linestyle="dotted", label="Rand."
    )

    accuracy_ax.axhline(y=alpha, color='red', linestyle='-', label="alpha")
    accuracy_ax.text(s=r"$ \alpha = {alpha_val} $ ".format(alpha_val=alpha), x=subset["_step"].max() - 20, y=alpha + 0.01, color='red', fontsize=12, ha="right")

    accuracy_ax.set(ylim=[0.97 * raw_inference_accuracies_per_model["label_small"], 1.15 * raw_inference_accuracies_per_model["label_large"]])
    accuracy_ax.yaxis.set_major_formatter(mtick.PercentFormatter(1.0, decimals=0))

    # Q Plot for SLA violations
    sns.lineplot(
        data=alpha_subset,
        x="_step",
        y="mess_plus/q_length",
        hue="V",
        errorbar=None,
        ax=q_ax,
        legend=(iterator == 0),
        palette="dark:#5A9_r",
    )

    # Energy consumption plot: one bar for the random baseline, one per single
    # model (raw inference data) and one per V value.
    random_baseline_energy = baseline_df.loc[baseline_df["alpha"] == alpha, ["benchmark_name", "mess_plus/energy"]].groupby("benchmark_name").sum().loc[name].to_frame()
    random_baseline_energy["V"] = "Rand."
    random_baseline_energy["mess_plus/energy"] = random_baseline_energy[name]
    random_baseline_energy.reset_index(inplace=True)

    raw_inference_energy_data = infer_df[["benchmark_name", "energy_consumption_large", "energy_consumption_medium", "energy_consumption_small"]].groupby("benchmark_name").sum().loc[name].to_frame()
    raw_inference_energy_data["V"] = raw_inference_energy_data.index
    raw_inference_energy_data["mess_plus/energy"] = raw_inference_energy_data[name]
    raw_inference_energy_data.reset_index(inplace=True)
    raw_inference_energy_data["V"] = raw_inference_energy_data["V"].replace({"energy_consumption_large": "Llama 70B", "energy_consumption_medium": "Llama 8B", "energy_consumption_small": "Llama 1B"}, inplace=False)
    raw_inference_energy_data.drop([name, "index"], inplace=True, axis=1)

    # Average over runs per step first, then sum over steps per V value.
    energy_data = alpha_subset.groupby(["_step", "V"]).agg({"mess_plus/energy": "mean"}).groupby("V")["mess_plus/energy"].sum().reset_index()
    energy_data["V"] = energy_data["V"].apply(lambda sample: f"V={sample}")

    energy_data = pd.concat([random_baseline_energy, raw_inference_energy_data, energy_data], ignore_index=True)
    energy_data.reset_index(inplace=True)
    energy_data = energy_data.sort_values(by=["mess_plus/energy"], ascending=False)

    sns.barplot(
        data=energy_data,
        x="V",
        y="mess_plus/energy",
        ax=energy_ax,
        errorbar=("ci", 95),  # seaborn expects the confidence level in percent
    )

    add_value_labels(energy_ax)
    energy_ax.yaxis.set_major_formatter(plt.FuncFormatter(fmt_to_megajoules))
    energy_ax.set(ylim=[0, 1.4 * energy_data["mess_plus/energy"].max()])
    energy_ax.tick_params(axis='x', labelrotation=45)

    # Stackplot for Model Call Ratio, one panel per V value (ascending).
    for jdx, V in enumerate(sorted_v_values):
        stack_ax = axes[iterator][3 + jdx]

        # Index with masks built from the same frame that is being indexed.
        stack_df = alpha_subset.loc[
            alpha_subset["V"] == V,
            ["_step", *model_choice_columns]
        ].groupby(["_step"]).mean().reset_index()

        x = stack_df["_step"]
        y_stack = np.cumsum(stack_df[model_choice_columns], axis=1)

        stack_ax.fill_between(x, 0, y_stack.iloc[:, 0], color="#2f364d", alpha=1.0)
        stack_ax.fill_between(x, y_stack.iloc[:, 0], y_stack.iloc[:, 1], color="#3f758a", alpha=1.0)
        stack_ax.fill_between(x, y_stack.iloc[:, 1], y_stack.iloc[:, 2], color="#69cf81", alpha=1.0)
        stack_ax.set(xlabel=f"Requests @ V={V}", xlim=[0, stack_df["_step"].max()], ylim=[0, 1])
        stack_ax.yaxis.set_major_formatter(mtick.PercentFormatter(1.0))

    # Add area plot for random baseline with constraint.
    baseline_stack_df = baseline_alpha_df[["_step", *model_choice_columns]].groupby(["_step"]).mean().reset_index()

    x_base = baseline_stack_df["_step"]
    y_stack_base = np.cumsum(baseline_stack_df[model_choice_columns], axis=1)

    baseline_ax.fill_between(x_base, 0, y_stack_base.iloc[:, 0], color="#2f364d", alpha=0.95)
    baseline_ax.fill_between(x_base, y_stack_base.iloc[:, 0], y_stack_base.iloc[:, 1], color="#3f758a", alpha=0.95)
    baseline_ax.fill_between(x_base, y_stack_base.iloc[:, 1], y_stack_base.iloc[:, 2], color="#69cf81", alpha=0.95)
    baseline_ax.set(xlabel=f"Requests (Rand.)", xlim=[0, baseline_stack_df["_step"].max()], ylim=[0, 1])
    baseline_ax.yaxis.set_major_formatter(mtick.PercentFormatter(1.0))

    if iterator == 0:
        baseline_ax.legend(["Llama 3.1 1B", "Llama 3.1 8B", "Llama 3.3 70B"], ncols=1, loc='upper center', fontsize=10.5, title_fontsize=10.5, title="Model")

    accuracy_ax.set(xlabel="Requests", xlim=[0, alpha_subset["_step"].max()])
    q_ax.set(xlabel="Requests", xlim=[0, alpha_subset["_step"].max()])
    energy_ax.set(xlabel="")

    # Row labels. The model-call-ratio label is only written on the middle row;
    # zip() stops after the six labels, so the baseline panel is not touched.
    row_labels = [r"USF ($\alpha = {alpha_val}$)".format(alpha_val=alpha), "Q Length", "Cost (in MJ)", "Model Call Ratio (MCR)", "", ""]
    for ax, col in zip(axes[iterator], row_labels):
        if col != "Model Call Ratio (MCR)" or iterator == 1:
            ax.set_ylabel(col, rotation=90, size=18)

fig.tight_layout()
write_figure_to_disk(plt, file_name=f"{name}_all_alpha", chapter_name="evaluations")


In [None]:
# Sanity check of the inference frame: list its columns, then print the mean
# energy consumption per benchmark for the large, medium and small model.
print(infer_df.columns)
for energy_column in ("energy_consumption_large", "energy_consumption_medium", "energy_consumption_small"):
    print(infer_df.groupby("benchmark_name")[energy_column].mean())

In [None]:
# Plot generator
# V values to plot per benchmark; benchmarks that are commented out are skipped.
v_values_per_benchmark = {
    "arc_challenge": [0.001, 0.0001, 0.00001],
    "arc_easy": [0.01, 0.001, 0.0001],
    "boolq": [0.01, 0.001, 0.0001],
    # "lambada_standard": [0.01, 0.001, 0.0001],
    "logiqa": [0.001, 0.0001, 0.00001],
    # "logiqa2": [0.01, 0.001, 0.0001],
    "piqa": [0.01, 0.001, 0.0001],
    "sciq": [0.0001, 0.00001, 0.000001],
    "social_iqa": [0.001, 0.0001, 0.00001],
    "winogrande": [0.01, 0.001, 0.0001],
}

# Display names used in the column titles.
BENCHMARK_NAME_DICT = {
    "arc_challenge": "ARC Challenge",
    "arc_easy": "ARC Easy",
    "boolq": "BoolQ",
    "lambada_standard": "Lambada",
    "logiqa": "LogiQA",
    "logiqa2": "LogiQA2",
    "piqa": "PiQA",
    "sciq": "SciQ",
    "social_iqa": "SocialIQA",
    "winogrande": "WinoGrande",
}

# Create a list of all benchmark-alpha combinations; the alpha values come from
# the "algorithm" section of each benchmark's online config file.
benchmark_alpha_combinations = []
for name in v_values_per_benchmark:
    config_path = Path(f"{NOTEBOOK_PATH.parent}/config/online/{name}.yaml")
    with config_path.open("r") as f:
        CONFIG = yaml.safe_load(f)

    for alpha in CONFIG["algorithm"]["alpha_values"]:
        benchmark_alpha_combinations.append((name, alpha))

model_choice_columns = ["models/small_chosen", "models/medium_chosen", "models/large_chosen"]

# Initialize plotting variables
plot_num = 0
col_count = 0

# Iterate through all benchmark-alpha combinations; each one fills one column.
for combo_idx, (name, alpha) in enumerate(benchmark_alpha_combinations):

    # Create new figure every 6 columns
    if col_count == 0:
        sns.set(style="whitegrid")
        fig, axes = plt.subplots(nrows=7, ncols=6, figsize=(20, 12))
        plot_num += 1

    # Get current column index
    col_idx = col_count

    # Filter data for current benchmark and alpha
    subset = run_df.loc[(run_df["benchmark_name"] == name) &
                        (run_df["c"] == 0.1) &
                        (run_df["V"].isin(v_values_per_benchmark[name])) &
                        (run_df["_step"] > 10) &
                        (run_df["alpha"] == alpha)]
    baseline_alpha_df = baseline_df.loc[(baseline_df["benchmark_name"] == name) & (baseline_df["alpha"] == alpha)]

    accuracy_ax = axes[0][col_idx]
    q_ax = axes[1][col_idx]
    energy_ax = axes[2][col_idx]
    baseline_ax = axes[6][col_idx]

    # Accuracy Plot; the dashed reference lines are the mean of the per-model
    # label columns over the raw inference data.
    raw_inference_accuracies_per_model = infer_df[["benchmark_name", "label_small", "label_medium", "label_large"]].groupby("benchmark_name").mean().loc[name]

    accuracy_ax.text(s="Llama 3.1 1B", x=subset["_step"].min() + 20, y=raw_inference_accuracies_per_model["label_small"] + 0.025, color='gray', fontsize=8, ha="left")
    accuracy_ax.text(s="Llama 3.1 8B", x=(subset["_step"].min() + 1/2 * subset["_step"].max()), y=raw_inference_accuracies_per_model["label_medium"] + 0.025, color='gray', fontsize=8, ha="center")
    accuracy_ax.text(s="Llama 3.3 70B", x=subset["_step"].max() - 20, y=raw_inference_accuracies_per_model["label_large"] + 0.025, color='gray', fontsize=8, ha="right")
    accuracy_ax.axhline(y=raw_inference_accuracies_per_model["label_small"], color='gray', linestyle='--')
    accuracy_ax.axhline(y=raw_inference_accuracies_per_model["label_medium"], color='gray', linestyle='--')
    accuracy_ax.axhline(y=raw_inference_accuracies_per_model["label_large"], color='gray', linestyle='--')

    sns.lineplot(
        data=subset,
        x="_step",
        y="avg_accuracy",
        hue="V",
        errorbar=None,
        ax=accuracy_ax,
        legend=(col_idx == 0),
        palette=["#2f364d", "#3f758a", "#69cf81"]
    )

    accuracy_ax.plot(
        baseline_alpha_df["_step"],
        baseline_alpha_df["avg_accuracy"],
        color="violet", linestyle="dotted", label="Rand."
    )

    accuracy_ax.axhline(y=alpha, color='red', linestyle='-')
    accuracy_ax.text(s=r"$ \alpha = {alpha_val} $ ".format(alpha_val=alpha), x=subset["_step"].max() - 20, y=alpha + 0.01, color='red', fontsize=8, ha="right")

    accuracy_ax.set(ylim=[0.97 * raw_inference_accuracies_per_model["label_small"], 1.15 * raw_inference_accuracies_per_model["label_large"]])

    if col_idx == 0:
        accuracy_ax.legend(ncols=2)

    # Q Plot for SLA violations
    sns.lineplot(
        data=subset,
        x="_step",
        y="mess_plus/q_length",
        hue="V",
        errorbar=None,
        ax=q_ax,
        legend=(col_idx == 0),
        palette=["#2f364d", "#3f758a", "#69cf81"]
    )

    if col_idx == 0:
        q_ax.legend(ncols=2)

    # Energy consumption plot: one bar for the random baseline, one per single
    # model (raw inference data) and one per V value.
    random_baseline_energy = baseline_df.loc[baseline_df["alpha"] == alpha, ["benchmark_name", "mess_plus/energy"]].groupby("benchmark_name").sum().loc[name].to_frame()
    random_baseline_energy["V"] = "Rand."
    random_baseline_energy["mess_plus/energy"] = random_baseline_energy[name]
    random_baseline_energy.reset_index(inplace=True)

    raw_inference_energy_data = infer_df[["benchmark_name", "energy_consumption_large", "energy_consumption_medium", "energy_consumption_small"]].groupby("benchmark_name").sum().loc[name].to_frame()
    raw_inference_energy_data["V"] = raw_inference_energy_data.index
    raw_inference_energy_data["mess_plus/energy"] = raw_inference_energy_data[name]
    raw_inference_energy_data.reset_index(inplace=True)
    raw_inference_energy_data["V"] = raw_inference_energy_data["V"].replace({"energy_consumption_large": "70B", "energy_consumption_medium": "8B", "energy_consumption_small": "1B"}, inplace=False)
    raw_inference_energy_data.drop([name, "index"], inplace=True, axis=1)

    # Average over runs per step first, then sum over steps per V value.
    energy_data = subset.groupby(["_step", "V"]).agg({"mess_plus/energy": "mean"}).groupby("V")["mess_plus/energy"].sum().reset_index()
    energy_data["V"] = energy_data["V"].apply(lambda sample: f"V={sample}")

    energy_data = pd.concat([random_baseline_energy, raw_inference_energy_data, energy_data], ignore_index=True)
    energy_data.reset_index(inplace=True)
    energy_data = energy_data.sort_values(by=["mess_plus/energy"], ascending=False)

    sns.barplot(
        data=energy_data,
        x="V",
        y="mess_plus/energy",
        ax=energy_ax,
        errorbar=("ci", 95),  # seaborn expects the confidence level in percent
    )

    add_value_labels(energy_ax)
    energy_ax.yaxis.set_major_formatter(plt.FuncFormatter(fmt_to_megajoules))
    energy_ax.set(ylim=[0, 2 * energy_data["mess_plus/energy"].max()])
    energy_ax.tick_params(axis='x', labelrotation=45)

    # Stackplot for Model Call Ratio, one row per configured V value.
    for jdx, V in enumerate(v_values_per_benchmark[name]):
        stack_ax = axes[3 + jdx][col_idx]

        # Index with a mask built from the same frame that is being indexed.
        stack_df = subset.loc[
            subset["V"] == V,
            ["_step", *model_choice_columns]
        ].groupby(["_step"]).mean().reset_index()

        x = stack_df["_step"]
        y_stack = np.cumsum(stack_df[model_choice_columns], axis=1)

        stack_ax.fill_between(x, 0, y_stack.iloc[:, 0], color="#2f364d", alpha=0.95)
        stack_ax.fill_between(x, y_stack.iloc[:, 0], y_stack.iloc[:, 1], color="#3f758a", alpha=0.95)
        stack_ax.fill_between(x, y_stack.iloc[:, 1], y_stack.iloc[:, 2], color="#69cf81", alpha=0.95)
        stack_ax.set(xlabel=f"Request @ V={V}", xlim=[0, subset["_step"].max()], ylim=[0, 1])
        stack_ax.yaxis.set_major_formatter(mtick.PercentFormatter(1.0))

        if jdx == 0 and col_idx == 0:
            stack_ax.legend(["Llama 3.1 1B", "Llama 3.1 8B", "Llama 3.3 70B"])

    # Add area plot for random baseline with constraint.
    baseline_stack_df = baseline_alpha_df[["_step", *model_choice_columns]].groupby(["_step"]).mean().reset_index()

    x_base = baseline_stack_df["_step"]
    y_stack_base = np.cumsum(baseline_stack_df[model_choice_columns], axis=1)

    baseline_ax.fill_between(x_base, 0, y_stack_base.iloc[:, 0], color="#2f364d", alpha=0.95)
    baseline_ax.fill_between(x_base, y_stack_base.iloc[:, 0], y_stack_base.iloc[:, 1], color="#3f758a", alpha=0.95)
    baseline_ax.fill_between(x_base, y_stack_base.iloc[:, 1], y_stack_base.iloc[:, 2], color="#69cf81", alpha=0.95)
    baseline_ax.set(xlabel=f"Requests (Rand.)", xlim=[0, baseline_stack_df["_step"].max()], ylim=[0, 1])
    baseline_ax.yaxis.set_major_formatter(mtick.PercentFormatter(1.0))

    # Set axis properties
    accuracy_ax.set(xlabel="Request", xlim=[0, subset["_step"].max()])
    q_ax.set(xlabel="Request", xlim=[0, subset["_step"].max()])
    energy_ax.set(xlabel="")

    # Remove y-labels for columns after the first
    if col_idx > 0:
        accuracy_ax.set(ylabel=None)
        q_ax.set(ylabel=None)
        energy_ax.set(ylabel=None)

    # Set title for each column
    accuracy_ax.set_title(r"{bm_name} ($\alpha = {alpha_val} $)".format(bm_name=BENCHMARK_NAME_DICT[name], alpha_val=alpha))

    # Increment column counter
    col_count += 1

    # Save the figure once it is full or the last combination has been drawn.
    if col_count == 6 or combo_idx == len(benchmark_alpha_combinations) - 1:
        # Add row labels to the first column. The four model-call-ratio rows
        # share one label, written as figure text next to them.
        for idx, (ax, row) in enumerate(zip(axes[:, 0], ["User Satisfaction", "Q Length", "Cost (in MJ energy)", "", "", "", ""])):
            if idx == 5:
                fig.text(0.003, 0.225, "Model Call Ratio (MCR)", ha="center", rotation='vertical', fontsize=plt.rcParams['axes.labelsize'])
            else:
                ax.set_ylabel(row, rotation=90, size='large')

        # Save the figure
        fig.tight_layout()
        write_figure_to_disk(plt, file_name=f"benchmark_performance_plot_{plot_num}", chapter_name="evaluations")

        # Reset column counter for next figure
        col_count = 0