# Simulations for the MESS+ estimator

In [1]:
import time

import numpy as np
import pandas as pd
import wandb
import yaml

from pathlib import Path

from classifier.model import MultilabelBERTClassifier
from classifier.file_reader import read_files_from_folder
from classifier.dataset import create_bert_datasets, preprocess_dataframe

from utils.mess_plus import sample_from_bernoulli


INFO 05-05 13:10:10 [__init__.py:239] Automatically detected platform cuda.


In [2]:
BENCHMARK_NAME = "boolq"
# this refers to whether we want to use a pre-trained classifier or learn the classifier online while benchmarking.
APPROACH = "pretrained"  # alt: online
# This only has an effect when APPROACH = pretrained. Make sure to adjust the minibatch size accordingly!
NUM_PRETRAINING_STEPS = 400
SEEDS = [42]

PROJECT_ROOT_PATH = Path("mess_plus_simulator").parent

## Load benchmark config

In [3]:

if APPROACH == "pretrained":
	config_path = Path(f"{PROJECT_ROOT_PATH}/config/pretrained/{BENCHMARK_NAME}.yaml")
elif APPROACH == "online":
	config_path = Path(f"{PROJECT_ROOT_PATH}/config/online/{BENCHMARK_NAME}.yaml")
	NUM_PRETRAINING_STEPS = 0
else:
	raise NotImplementedError(f"Approach {APPROACH} not implemented.")

with config_path.open("r") as f:
	CONFIG = yaml.safe_load(f)
	display(CONFIG)


{'run_name': 'baseline',
 'seed': 43,
 'model_zoo': {'meta-llama/Llama-3.2-1B-Instruct': {'category': 'small',
   'gpu_indices': [0],
   'max_seq_len': 2048,
   'gpu_memory_utilization': 0.12,
   'quantization': None},
  'unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit': {'category': 'medium',
   'gpu_indices': [0],
   'max_seq_len': 2048,
   'gpu_memory_utilization': 0.15,
   'quantization': 'bitsandbytes'},
  'unsloth/Llama-3.3-70B-Instruct-bnb-4bit': {'category': 'large',
   'gpu_indices': [0],
   'max_seq_len': 2048,
   'gpu_memory_utilization': 0.68,
   'quantization': 'bitsandbytes'}},
 'classifier_model': {'model_id': 'answerdotai/ModernBERT-base',
  'epochs': 3,
  'learning_rate': 0.0767,
  'weight_decay': 0.01,
  'momentum': 0.95,
  'batch_size': 128,
  'max_length': 128,
  'warmup_ratio': 0.1,
  'threshold': 0.5,
  'dropout_rate': 0.1,
  'freeze_bert_layers': True,
  'memory_size': 0,
  'memory_strategy': 'random',
  'reset_optimizer': False,
  'regularization_lambda': 0.0,
  'gp

## Load dataset

In [4]:
input_df = read_files_from_folder(folder_path=f"{PROJECT_ROOT_PATH}/data/inference_outputs/{BENCHMARK_NAME}")
input_df["idx_original"] = input_df.index
input_df = input_df.sample(frac=1).reset_index(drop=True)

display(f"Loaded dataframe with {input_df.shape[0]} rows and {input_df.shape[1]} columns")
display(f"{len(input_df.columns.tolist())} available columns: {input_df.columns.tolist()}")
display(input_df.head())

'Loaded dataframe with 3270 rows and 15 columns'

"15 available columns: ['input_text', 'benchmark_name', 'label_small', 'acc_small', 'energy_consumption_small', 'inference_time_small', 'label_medium', 'acc_medium', 'energy_consumption_medium', 'inference_time_medium', 'label_large', 'acc_large', 'energy_consumption_large', 'inference_time_large', 'idx_original']"

Unnamed: 0,input_text,benchmark_name,label_small,acc_small,energy_consumption_small,inference_time_small,label_medium,acc_medium,energy_consumption_medium,inference_time_medium,label_large,acc_large,energy_consumption_large,inference_time_large,idx_original
0,is a dachshund the same as a sausage dog,boolq,0.0,0.0,38.187,0.129817,1.0,1.0,125.057,0.396348,1.0,1.0,1025.541,2.083502,1214
1,has a school shooting happened in a private sc...,boolq,1.0,1.0,31.284,0.132452,1.0,1.0,132.079,0.399106,1.0,1.0,1048.761,2.090807,381
2,can law of sines be used on any triangle,boolq,1.0,1.0,36.066,0.126652,0.0,0.0,126.523,0.392874,1.0,1.0,1023.494,2.077417,1008
3,does colorado school of mines have a football ...,boolq,1.0,1.0,25.088,0.147779,1.0,1.0,132.694,0.410065,1.0,1.0,1060.18,2.110664,2763
4,has anyone ever escaped from alcatraz and lived,boolq,0.0,0.0,24.779,0.152949,1.0,1.0,126.964,0.408143,1.0,1.0,1045.153,2.092828,873


## Load, configure, and train classifier

In [5]:
text_col = "input_text"
label_cols = ["label_small", "label_medium", "label_large"]

classifier = MultilabelBERTClassifier(num_labels=3, **CONFIG["classifier_model"])
training_df = input_df.loc[:NUM_PRETRAINING_STEPS]
training_df = preprocess_dataframe(training_df, label_cols=label_cols)

train_dataset, val_dataset, tokenizer = create_bert_datasets(
	training_df,
	text_col,
	label_cols,
	model_name=CONFIG["classifier_model"]["model_id"],
	max_length=CONFIG["classifier_model"]["max_length"],
	val_ratio=CONFIG["classifier_model"]["validation_dataset_size"],
	random_seed=SEEDS[0],
)

training_stats = classifier.fit(train_dataset, val_dataset, epochs=CONFIG["classifier_model"]["epochs"], early_stopping_patience=2)

display(training_stats)


INFO:classifier.model:Using device: cuda
INFO:classifier.model:Initializing custom BERTClassifier: answerdotai/ModernBERT-base with 3 labels
Epoch 1/3 [Training]: 100%|██████████| 3/3 [00:01<00:00,  2.13it/s, loss=0.7039, batch_loss=0.5349, lr=0.0767]
Epoch 1/3 [Validation]: 100%|██████████| 1/1 [00:00<00:00,  2.38it/s, val_loss=0.8923, avg_val_loss=0.8923]
INFO:classifier.model:Epoch 1/3 - Time: 1.83s
INFO:classifier.model:  Train Loss: 0.5912 - Val Loss: 0.8923
INFO:classifier.model:  Val Metrics - Accuracy: 0.7417, F1: 0.8517, F1(macro): 0.8421
INFO:classifier.model:  Per-label metrics:
INFO:classifier.model:    Label 0: F1=0.6885, Prec=0.5250, Rec=1.0000
INFO:classifier.model:    Label 1: F1=0.9189, Prec=0.8500, Rec=1.0000
INFO:classifier.model:    Label 2: F1=0.9189, Prec=0.8500, Rec=1.0000
INFO:classifier.model:Directory /home/woi/code/mess-plus/classifier/checkpoints/boolq existis. Reusing...
INFO:classifier.model:Best model saved
Epoch 2/3 [Training]: 100%|██████████| 3/3 [00:0

{'epoch': 3,
 'train/loss': 0.5819895267486572,
 'val/loss': 0.6913987994194031,
 'val/accuracy': 0.7416666666666667,
 'val/precision_micro': 0.7416666666666667,
 'val/recall_micro': 1.0,
 'val/f1_micro': 0.8516746411483254,
 'val/f1_macro': 0.8421208093339242,
 'time/epoch_seconds': 0.3879730701446533,
 'val/0_f1score': 0.6885245901639344,
 'val/0_recall': 0.525,
 'val/0_precision': 1.0,
 'val/1_f1score': 0.918918918918919,
 'val/1_recall': 0.85,
 'val/1_precision': 1.0,
 'val/2_f1score': 0.918918918918919,
 'val/2_recall': 0.85,
 'val/2_precision': 1.0}

In [6]:
# Dataset statistics
display(input_df[NUM_PRETRAINING_STEPS:]["label_small"].mean())
display(input_df[NUM_PRETRAINING_STEPS:]["label_medium"].mean())
display(input_df[NUM_PRETRAINING_STEPS:]["label_large"].mean())


0.6898954703832753

0.8432055749128919

0.8895470383275261

In [8]:
algorithm_config = CONFIG["algorithm"]

model_categories = [i for i in CONFIG["model_zoo"].keys()]
sample_cols = input_df.columns.tolist()

ALPHA_VALUES = [0.75]
C_VALUES = [1.0]
V_VALUES = [1.0]
R_VALUES = [2, 5, 10, 15]

for alpha in ALPHA_VALUES:
	for c in C_VALUES:
		for v in V_VALUES:
			for r in R_VALUES:
				algorithm_config["V"] = v
				algorithm_config["alpha"] = alpha
				algorithm_config["c"] = c
				algorithm_config["r"] = r

				ACCURACY_LIST = []
				EXPLORATION_STEP_LIST = []
				ENERGY_CONSUMPTION_LIST = []
				INFERENCE_TIME_LIST = []
				ENERGY_PER_MODEL = {
					"small": [0.0],
					"medium": [0.0],
					"large": [0.0],
				}

				model_category_list = [i for i in ENERGY_PER_MODEL.keys()]

				Q = 0.0
				ctr = 0

				run = wandb.init(
					project=f"mess-plus_runs_vTEST",
					name=f"{BENCHMARK_NAME}_V={algorithm_config['V']}_a={algorithm_config['alpha']}_c={algorithm_config['c']}_r={r}",
					config=CONFIG
				)

				if wandb.run is not None:
					run.summary.update({**{f"classifier/{k}": v for k, v in training_stats.items()}})

				monitoring_dict = {}
				for idx, sample in input_df[NUM_PRETRAINING_STEPS:].iterrows():
					p_t, x_t = sample_from_bernoulli(c=algorithm_config["c"], timestamp=idx)
					EXPLORATION_STEP_LIST.append(x_t)

					if x_t == 1:
						result = sample["label_large"]
						ACCURACY_LIST.append(result)
						step_energy = sum([sample[i] for i in sample_cols if "energy" in i])
						step_time = sum([sample[i] for i in sample_cols if "inference" in i])
						ENERGY_CONSUMPTION_LIST.append(step_energy)
						INFERENCE_TIME_LIST.append(step_time)
						for i in ENERGY_PER_MODEL.keys():
							ENERGY_PER_MODEL[i] = sample[f"energy_consumption_{i}"]

						monitoring_dict[f"mess_plus/energy"] = step_energy
						monitoring_dict[f"mess_plus/chosen_model"] = len(model_category_list) - 1

					else:
						preds, probs = classifier.predict(texts=[sample["input_text"]])
						energy = pd.DataFrame(ENERGY_PER_MODEL, index=[0]).to_numpy()

						energy = np.array(energy).reshape(-1, 1)
						probs = probs.reshape(-1, 1)

						cost_fn = algorithm_config["V"] * energy + np.power(Q * algorithm_config["alpha"] - probs, r)
						cost_fn = cost_fn.reshape(1, -1)
						chosen_model_id = np.argmin(cost_fn)
						model_category_chosen = model_category_list[chosen_model_id]

						result = sample[f"label_{model_category_chosen}"]
						step_energy = sample[f"energy_consumption_{model_category_chosen}"]
						step_time = sample[f"inference_time_{model_category_chosen}"]

						INFERENCE_TIME_LIST.append(step_time)
						ENERGY_CONSUMPTION_LIST.append(step_energy)

						monitoring_dict[f"mess_plus/energy"] = step_energy
						monitoring_dict[f"mess_plus/chosen_model"] = chosen_model_id

						ACCURACY_LIST.append(result)

					Q = max(0.0, Q + algorithm_config["alpha"] - result)

					monitoring_dict.update({
						"mess_plus/p_t": p_t,
						"mess_plus/x_t": x_t,
						"mess_plus/exploration_step_ratio": sum(EXPLORATION_STEP_LIST) / (ctr + 1),
						"mess_plus/q_length": Q,
						"mess_plus/accuracy": sum(ACCURACY_LIST) / (ctr + 1),
						"mess_plus/step_time": step_time,
						"mess_plus/total_runtime": sum(INFERENCE_TIME_LIST),
						"mess_plus/step_energy_consumption": step_energy,
					})

					ctr += 1

					if wandb.run is not None:
						wandb.log(monitoring_dict, step=ctr)

                if wandb.run is not None:
                    wandb.finish()
					time.sleep(2)

print(f"DONE")

TabError: inconsistent use of tabs and spaces in indentation (<string>, line 104)