# Simulations for the MESS+ estimator

In [None]:
import time

import numpy as np
import pandas as pd
import wandb
import yaml

from pathlib import Path

from classifier.model import MultilabelBERTClassifier
from classifier.file_reader import read_files_from_folder
from classifier.dataset import create_bert_datasets, preprocess_dataframe

from utils.mess_plus import sample_from_bernoulli


In [None]:
BENCHMARK_NAME = "winogrande"
# Whether to use a classifier pretrained on the leading NUM_PRETRAINING_STEPS samples ("pretrained")
# or to learn the classifier online while benchmarking ("online").
APPROACH = "online"  # alt: pretrained
# Number of leading (shuffled) samples reserved for pretraining the classifier; the simulation runs
# on the remaining samples. Only has an effect when APPROACH = "pretrained" (the config cell resets it
# to 0 for "online"). Make sure to adjust the minibatch size accordingly!
NUM_PRETRAINING_STEPS = 400
SEEDS = [42]
# One output per model size (small, medium, large).
NUM_CLASSIFIER_LABELS = 3

# Path("mess_plus_simulator").parent == Path("."), so all project paths below are resolved
# relative to the current working directory.
PROJECT_ROOT_PATH = Path("mess_plus_simulator").parent

## Load benchmark config

In [None]:

# Resolve the benchmark config for the selected approach and load it.
# Unknown approaches are rejected before any path is built.
if APPROACH not in ("pretrained", "online"):
	raise NotImplementedError(f"Approach {APPROACH} not implemented.")

# Configs live under config/<approach>/<benchmark>.yaml.
config_path = PROJECT_ROOT_PATH / "config" / APPROACH / f"{BENCHMARK_NAME}.yaml"

if APPROACH == "online":
	# The online approach does not pretrain the classifier.
	NUM_PRETRAINING_STEPS = 0

with config_path.open("r") as config_file:
	CONFIG = yaml.safe_load(config_file)
display(CONFIG)


## Load dataset

In [None]:
# Load the per-sample inference outputs for the benchmark and shuffle them.
input_df = read_files_from_folder(folder_path=f"{PROJECT_ROOT_PATH}/data/inference_outputs/{BENCHMARK_NAME}")
# Keep the pre-shuffle index so shuffled rows can be traced back to the loaded data.
input_df["idx_original"] = input_df.index
# Seed the shuffle: the first NUM_PRETRAINING_STEPS shuffled rows become classifier training data
# and the rest are simulated, so an unseeded shuffle makes that split non-reproducible.
input_df = input_df.sample(frac=1, random_state=SEEDS[0]).reset_index(drop=True)

display(f"Loaded dataframe with {input_df.shape[0]} rows and {input_df.shape[1]} columns")
display(f"{input_df.shape[1]} available columns: {input_df.columns.tolist()}")
display(input_df.head())

## Load, configure, and train classifier

In [None]:
# Build the multilabel classifier over the per-model label_* columns (the same columns the
# simulation uses as each model's result) and pretrain it on the leading samples.
text_col = "input_text"
label_cols = ["label_small", "label_medium", "label_large"]

classifier = MultilabelBERTClassifier(num_labels=NUM_CLASSIFIER_LABELS, **CONFIG["classifier_model"])

# Positional slice: `.loc[:N]` is label-inclusive and would select N + 1 rows, the last of which is
# also the first row of the evaluation slice `input_df[NUM_PRETRAINING_STEPS:]` (train/eval leak).
training_df = input_df.iloc[:NUM_PRETRAINING_STEPS]

if training_df.empty:
	# APPROACH == "online" sets NUM_PRETRAINING_STEPS = 0, so there is nothing to pretrain on.
	# NOTE(review): assumes MultilabelBERTClassifier.predict works without a prior fit — confirm.
	training_stats = {}
else:
	training_df = preprocess_dataframe(training_df, label_cols=label_cols)

	train_dataset, val_dataset, tokenizer = create_bert_datasets(
		training_df,
		text_col,
		label_cols,
		model_name=CONFIG["classifier_model"]["model_id"],
		max_length=CONFIG["classifier_model"]["max_length"],
		val_ratio=CONFIG["classifier_model"]["validation_dataset_size"],
		random_seed=SEEDS[0],
	)

	training_stats = classifier.fit(
		train_dataset,
		val_dataset,
		epochs=CONFIG["classifier_model"]["epochs"],
		early_stopping_patience=2,
	)

display(training_stats)


In [None]:
# Dataset statistics: mean of each model's label column over the samples the simulation runs on
# (everything after the pretraining prefix).
for label_column in ("label_small", "label_medium", "label_large"):
	display(input_df.iloc[NUM_PRETRAINING_STEPS:][label_column].mean())


In [None]:
# Run the MESS+ routing simulation for every (alpha, c, V, r) combination and log each run to wandb.
from itertools import product

algorithm_config = CONFIG["algorithm"]

model_categories = list(CONFIG["model_zoo"].keys())
sample_cols = input_df.columns.tolist()
# Exploration steps run every model, so their cost is the sum over all matching per-model columns.
# These filters do not depend on the sample, so compute them once.
energy_cols = [col for col in sample_cols if "energy" in col]
inference_time_cols = [col for col in sample_cols if "inference" in col]

ALPHA_VALUES = algorithm_config["alpha_values"]
C_VALUES = [1.0]
V_VALUES = [0.01, 0.001, 0.0001, 0.00001, 0.000001]
R_VALUES = [1]
# Cap on simulated samples per run. The previous hard-coded check stopped every run after 3 samples;
# that default is kept for test runs — set to None to simulate the whole benchmark.
MAX_STEPS_PER_RUN = 3

# Same iteration order as nested loops over alpha -> c -> V -> r.
for alpha, c, v, r in product(ALPHA_VALUES, C_VALUES, V_VALUES, R_VALUES):
	# algorithm_config is CONFIG["algorithm"], so these values are also passed to wandb via config=CONFIG.
	algorithm_config["V"] = v
	algorithm_config["alpha"] = alpha
	algorithm_config["c"] = c

	ACCURACY_LIST = []
	EXPLORATION_STEP_LIST = []
	ENERGY_CONSUMPTION_LIST = []
	INFERENCE_TIME_LIST = []
	MODEL_CHOSEN_LIST = []

	# Initial per-model energy estimates; replaced by measured values on every exploration step.
	ENERGY_PER_MODEL = {
		"small": [0.01],
		"medium": [0.1],
		"large": [1.0],
	}
	model_category_list = list(ENERGY_PER_MODEL)

	# Queue length: grows by (alpha - result) each step and is clipped at 0.
	Q = 0.0
	ctr = 0

	run = wandb.init(
		project="mess-plus_runs_vTEST2",
		name=f"{BENCHMARK_NAME}_V={v}_a={alpha}_c={c}_r={r}",
		config=CONFIG,
	)

	try:
		if wandb.run is not None:
			run.summary.update({f"classifier/{key}": value for key, value in training_stats.items()})

		monitoring_dict = {}
		for idx, sample in input_df[NUM_PRETRAINING_STEPS:].iterrows():
			p_t, x_t = sample_from_bernoulli(c=c, timestamp=idx)
			EXPLORATION_STEP_LIST.append(x_t)

			if x_t == 1:
				# Exploration: every model is run (costs summed) and the large model's result is used.
				result = sample["label_large"]
				step_energy = sum([sample[col] for col in energy_cols])
				step_time = sum([sample[col] for col in inference_time_cols])
				for model_category in ENERGY_PER_MODEL:
					ENERGY_PER_MODEL[model_category] = sample[f"energy_consumption_{model_category}"]
				chosen_model_id = len(model_category_list) - 1
				# NOTE(review): exploration steps are not added to MODEL_CHOSEN_LIST, so the
				# models/*_chosen ratios only cover classifier-driven choices — confirm intended.
			else:
				_, probs = classifier.predict(texts=[sample["input_text"]])
				# Column vectors ordered like model_category_list.
				# Assumes probs holds one probability per model — TODO confirm.
				energy = np.array([ENERGY_PER_MODEL[m] for m in model_category_list], dtype=float).reshape(-1, 1)
				probs = probs.reshape(-1, 1)

				# Weigh energy (scaled by V) against the gap between the accuracy target alpha and each
				# model's predicted probability (scaled by Q); choose the model with the lowest cost.
				cost_fn = (v * energy + Q * (alpha - probs)).reshape(1, -1)
				chosen_model_id = np.argmin(cost_fn)
				model_category_chosen = model_category_list[chosen_model_id]

				result = sample[f"label_{model_category_chosen}"]
				step_energy = sample[f"energy_consumption_{model_category_chosen}"]
				step_time = sample[f"inference_time_{model_category_chosen}"]
				MODEL_CHOSEN_LIST.append(chosen_model_id)

			ACCURACY_LIST.append(result)
			ENERGY_CONSUMPTION_LIST.append(step_energy)
			INFERENCE_TIME_LIST.append(step_time)
			monitoring_dict["mess_plus/energy"] = step_energy
			monitoring_dict["mess_plus/chosen_model"] = chosen_model_id

			Q = max(0.0, Q + alpha - result)

			chosen = np.array(MODEL_CHOSEN_LIST)
			# Epsilon keeps the ratios defined before the first classifier-driven step.
			num_chosen = len(chosen) + 1e-8
			monitoring_dict.update({
				"mess_plus/p_t": p_t,
				"mess_plus/x_t": x_t,
				"mess_plus/exploration_step_ratio": sum(EXPLORATION_STEP_LIST) / (ctr + 1),
				"mess_plus/q_length": Q,
				"mess_plus/accuracy": sum(ACCURACY_LIST) / (ctr + 1),
				"mess_plus/step_time": step_time,
				"mess_plus/total_runtime": sum(INFERENCE_TIME_LIST),
				"mess_plus/step_energy_consumption": step_energy,
				"models/small_chosen": len(np.where(chosen == 0)[0]) / num_chosen,
				"models/medium_chosen": len(np.where(chosen == 1)[0]) / num_chosen,
				"models/large_chosen": len(np.where(chosen == 2)[0]) / num_chosen,
			})

			print(monitoring_dict)

			ctr += 1
			if wandb.run is not None:
				wandb.log(monitoring_dict, step=ctr)

			if MAX_STEPS_PER_RUN is not None and ctr >= MAX_STEPS_PER_RUN:
				break
	finally:
		# Close the run even if the simulation raises, so a failed run is not left open.
		if wandb.run is not None:
			wandb.finish()
			time.sleep(2)


print("DONE")