<!-- du4://thèse/cai/results.ipynb?d=20251024?loc=ttum?hPa=1020 -->

# Confidential Artificial Intelligence: What's the Catch?
### _Performance and costs_

In [None]:
import json
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import re
import seaborn as sns

from dataclasses import dataclass
from enum import Enum
from functools import cached_property
from pathlib import Path
from typing import List, Tuple

In [None]:
# Global seaborn styling for the figures below: "ticks" style at "paper" context scale.
sns.set_theme(style="ticks", context="paper")
# Replace the default color cycle with seaborn's "colorblind" palette.
sns.set_palette("colorblind")

In [None]:
class TEE_Mode(Enum):
    """TEE setting under which a condition was measured.

    The member values double as the directory names that
    ``Experiment.conditions`` looks for on disk.
    """
    TEE_ON = "tee_on"
    TEE_OFF = "tee_off"

class Sorting(Enum):
    """Ordering strategies accepted by ``Experiment.list_condition_names``."""

    # Historical (misspelled) member name, kept as the canonical member so that
    # existing references and ``.name`` lookups keep working.
    LEXICOGRAPHGIC = "lexicographic"
    # Correctly spelled alias; ``Sorting.LEXICOGRAPHIC is Sorting.LEXICOGRAPHGIC``.
    LEXICOGRAPHIC = "lexicographic"
    NATURAL = "natural"
    MODEL_SIZE = "model_size"

    @classmethod
    def _missing_(cls, value):
        """Keep accepting the former misspelled value ``"narutal"`` for ``NATURAL``."""
        if value == "narutal":
            return cls.NATURAL
        return None

@dataclass
class Experiment:
    path: Path

    @property
    def name(self) -> str:
        return self.path.stem

    @cached_property
    def conditions(self) -> List["Condition"]:
        return [
            Condition(
                q, 
                q.parent.name, 
                TEE_Mode.TEE_ON if q.name == TEE_Mode.TEE_ON.value else TEE_Mode.TEE_OFF
            )
            for q in self.path.glob("*/*")
            if q.is_dir() and q.name in {TEE_Mode.TEE_ON.value, TEE_Mode.TEE_OFF.value}
        ]
    
    def list_condition_names(self, sorting: Sorting):
        match sorting:
            case Sorting.MODEL_SIZE:
                all_conditions_sorted = sorted(self.conditions, key=lambda c: int(c.model_size))
                return list(dict.fromkeys(c.name for c in all_conditions_sorted))
            case Sorting.NATURAL:
                condition_names = list(dict.fromkeys(c.name for c in self.conditions))
                return sorted(condition_names, key=self._natural_sort_key)
            case _:
                return list(dict.fromkeys(c.name for c in self.conditions))

    def get_conditions(self, tee_mode: TEE_Mode):
        return [c for c in self.conditions if c.tee_mode == tee_mode]

    def get_condition(self, name: str, tee_mode: TEE_Mode):
        return next(
            filter(lambda c: c.name == name and c.tee_mode == tee_mode, self.conditions),
            None,
        )
    
    def get_all_runs(self):
        return [r for c in self.conditions for r in c.runs]
    
    def get_runs(self, tee_mode: TEE_Mode):
        return [r for c in self.conditions for r in c.runs if c.tee_mode == tee_mode]
  
    def _natural_sort_key(self, text: str):
        return [int(c) if c.isdigit() else c.lower() for c in re.split(r'(\d+)', text)]

    def __str__(self):
        name = self.name
        nb_conditions = len(self.conditions)
        nb_total_runs = len(self.get_all_runs())
        return f"Experiment: {name}, Conditions: {nb_conditions} ({nb_total_runs} total measurements)"

@dataclass
class Condition:
    """One measured configuration of an experiment under a single TEE mode."""

    path: Path
    name: str
    tee_mode: TEE_Mode

    @property
    def model_name(self) -> List[str]:
        """Name of the parent directory, split on ``"_"``.

        Despite the singular name this is a list of name parts.
        """
        return self.path.parent.name.split("_")

    @property
    def model_size(self) -> str:
        """Digits directly preceding the first ``b``/``B`` in the parent directory name.

        Raises:
            AttributeError: if the directory name contains no such token
                (``re.search`` returns ``None``).
        """
        return re.search(r"(\d+)[bB]", "_".join(self.model_name)).group(1)

    @cached_property
    def runs(self) -> List["Run"]:
        """Runs found in ``path``, one per ``*repetition_*`` JSON file.

        JSON files are sorted by path and numbered from 0. Each run is paired
        with ``<json stem>_power_metrics.csv`` in the same directory; only the
        number of CSV files is checked here, not that this exact file exists.

        Raises:
            AssertionError: if no matching file exists or the number of JSON
                and CSV files differs.
        """
        run_paths = list(self.path.glob("*repetition_*"))
        json_files = sorted(r for r in run_paths if r.suffix == ".json")
        csv_files = [r for r in run_paths if r.suffix == ".csv"]

        assert len(run_paths) > 0, "Empty results"
        assert len(json_files) == len(csv_files), f"Mismatch: {len(json_files)} .json vs. {len(csv_files)} .csv: {run_paths}"

        return [
            Run(
                idx, json_file, self.path / f"{json_file.stem}_power_metrics.csv"
            )
            for idx, json_file in enumerate(json_files)
        ]

    def get_all_runs(self) -> List["Run"]:
        """Return every run of this condition."""
        return self.runs

    def get_run(self, index: int) -> "Run":
        """Return the run at position ``index``."""
        return self.runs[index]

    def get_median_throughput_with_std(self) -> Tuple[float, float]:
        """Median and population standard deviation of ``output_throughput`` over the runs."""
        output_throughputs = [
            rep.get_vllm_key("output_throughput") for rep in self.runs
        ]
        return np.median(output_throughputs), np.std(output_throughputs)

    def get_median_ttft_with_std_and_p95(self):
        """Return the raw ``ttfts`` entry of every run, one element per run.

        NOTE(review): despite the name, no median/std/p95 is computed here yet;
        callers receive the unaggregated per-run values.
        """
        # TODO: aggregate (median, std, p95) and compare to VLLM output
        return [rep.get_vllm_key("ttfts") for rep in self.runs]

    def get_median_itl_with_std_and_p95(self) -> Tuple[float, float, float]:
        """Median, population standard deviation and 95th percentile of all ``itls`` values.

        The per-run ``itls`` entries are flattened before aggregation so that
        nested lists of unequal length (presumably one list per request —
        TODO confirm against the benchmark output) do not break the numpy
        reductions. For rectangular input the result is unchanged.
        """
        latencies = list(self._flatten(rep.get_vllm_key("itls") for rep in self.runs))
        return np.median(latencies), np.std(latencies), np.percentile(latencies, 95) # TODO: compare to VLLM output

    @staticmethod
    def _flatten(values):
        """Yield the scalar leaves of arbitrarily nested lists/tuples, in order."""
        for value in values:
            if isinstance(value, (list, tuple)):
                yield from Condition._flatten(value)
            else:
                yield value

@dataclass
class Run:
    """A single repetition: one vLLM result JSON plus one GPU metrics CSV.

    Both files are read lazily on first access and cached on the instance.
    """

    index: int
    path_vllm_json: Path
    path_power_csv: Path

    @cached_property
    def vllm_metrics(self) -> dict:
        """Parsed content of ``path_vllm_json``."""
        # Decode explicitly as UTF-8 instead of the platform's locale encoding,
        # so non-ASCII content is read identically on every machine.
        return json.loads(self.path_vllm_json.read_text(encoding="utf-8"))

    @cached_property
    def gpu_metrics(self) -> pd.DataFrame:
        """Content of ``path_power_csv`` as a DataFrame (column names untouched)."""
        return pd.read_csv(self.path_power_csv)

    # Input

    @property
    def model_id(self) -> str:
        return self.vllm_metrics["model_id"]

    @property
    def num_prompts(self) -> int:
        return self.vllm_metrics["num_prompts"]

    @property
    def input_len(self) -> list[int]:
        return self.vllm_metrics["input_lens"]

    @property
    def output_len(self) -> list[int]:
        return self.vllm_metrics["output_lens"]

    @property
    def max_concurrency(self) -> int:
        return self.vllm_metrics["max_concurrency"]

    @property
    def request_rate(self) -> str:
        return self.vllm_metrics["request_rate"]

    @property
    def burstiness(self) -> float:
        return self.vllm_metrics["burstiness"]

    # Output

    @property
    def duration(self) -> float:
        return self.vllm_metrics["duration"]

    @property
    def max_output_tokens_per_second(self) -> float:
        return self.vllm_metrics["max_output_tokens_per_s"]

    @property
    def max_concurrent_requests(self) -> int:
        return self.vllm_metrics["max_concurrent_requests"]

    def get_vllm_key(self, key: str):
        """Return ``vllm_metrics[key]``; raises ``KeyError`` for an unknown key."""
        return self.vllm_metrics[key]

    def __str__(self):
        # Only the first entry of the input/output length lists is shown.
        return f"[{self.index}] {self.model_id}, num prompts: {self.num_prompts}, input length: {self.input_len[0]}, output length: {self.output_len[0]}, concurrency: {self.max_concurrent_requests}/{self.max_concurrency}, request rate: {self.request_rate}, burstiness: {self.burstiness}"

## 0. Data summary

In [None]:
# Root directory that holds the results of every experiment
data_path = Path("data") / "calibration"  # ← FIXME
# One Experiment handle per results sub-directory
exp_throughput_latency = Experiment(data_path / "throughput_latency")
sweep_1 = Experiment(data_path / "saturation_point")
exp_sequence_overhead = Experiment(data_path / "sequence_overhead")
exp_energy = Experiment(data_path / "energy")
# Tuple of all handles defined above
all_exps = (exp_throughput_latency, sweep_1, exp_sequence_overhead, exp_energy)

In [None]:
def format_seconds_long(seconds: float) -> str:
    """Format a duration given in seconds as zero-padded ``HH:MM:SS``.

    The fractional part is truncated; hours are not wrapped at 24.
    """
    whole_seconds = int(seconds)
    hours = whole_seconds // 3600
    minutes = (whole_seconds % 3600) // 60
    secs = whole_seconds % 60
    return f"{hours:02d}:{minutes:02d}:{secs:02d}"

# Running totals over every experiment in `all_exps`
nb_total_runs, duration_total = 0, 0
print(f"• Number of experiments: {len(all_exps)}")
for exp in all_exps:
    print(f"  • {exp}")
    all_runs = exp.get_all_runs()
    nb_total_runs = nb_total_runs + len(all_runs)
    for run in all_runs:
        duration_total = duration_total + run.get_vllm_key("duration")
print(f"• Total measurements: {nb_total_runs}")
print(f"• Total duration: {format_seconds_long(duration_total)}")
# Price estimate: measured hours times a flat rate of 7 per hour
print(f"• Estimated Azure price: {round(duration_total / 3600 * 7, 3)} €")


## 1. Throughput and Latency

### 1.1. Data summary

In [None]:
# Handle on the throughput/latency results directory
exp_throughput_latency = Experiment(data_path / "throughput_latency")

In [None]:
rows = []  # One record per run; rendered as a table by the last expression
for c in exp_throughput_latency.conditions:
    for run in c.runs:
        metric = run.get_vllm_key  # shorthand for reading one key of the run's vLLM JSON

        # Refuse to tabulate runs that reported request errors or did not complete
        assert not any(e != "" for e in metric("errors")), (
            f"vLLM reported an error during measurement. Check .json {run.path_vllm_json}"
        )
        assert metric("completed") == metric("num_prompts"), "Run crashed."

        run_duration = metric("duration")
        row = {
            # Condition
            "condition": c.name,
            "tee_mode": c.tee_mode.value,
            # Measurement
            "Measurement #": run.index,
            "duration (s)": round(run_duration),
            # Throughput
            "output throughput (tok/s)": metric("output_throughput"),
            "total token throughput (tok/s)": metric("total_token_throughput"),
            # Latency TODO
            # Cloud
            "Azure cost (€)": round(run_duration / 3600 * 7, 3),
        }
        rows.append(row)

pd.DataFrame(rows)

### (GPU calibration)

In [None]:
# Handle on the "mns" calibration results directory
exp_mns_calibration = Experiment(data_path / "throughput_latency_mns-calibration")

In [None]:
# Clean columns for plotting
def pre_process_gpu_metrics(run: "Run", max_watts: float = 700) -> pd.DataFrame:
    """Add numeric, plot-ready columns to a run's GPU metrics table.

    The raw columns ``" power.draw [W]"``, ``" utilization.gpu [%]"`` and
    ``" utilization.memory [%]"`` hold strings with a trailing unit; they are
    parsed to floats. The temperature column is copied as is.

    Args:
        run: Run whose ``gpu_metrics`` DataFrame is processed.
        max_watts: Power draw that corresponds to 100 % in
            ``power_draw_percent``. The default of 700 is presumably the power
            limit of the GPU used for the measurements — TODO confirm.

    Returns:
        The same DataFrame object as ``run.gpu_metrics``, extended in place
        with ``power_draw_watts``, ``power_draw_percent``,
        ``utilization_gpu_percent``, ``utilization_memory_percent`` and
        ``temperature_gpu_celsius``. Calling the function again recomputes
        the same columns.
    """
    gpu_metrics = run.gpu_metrics

    def as_number(column: str, unit: str) -> pd.Series:
        # Drop the trailing unit and surrounding blanks, then convert to float.
        return gpu_metrics[column].str.rstrip(unit).str.strip().astype(float)

    gpu_metrics["power_draw_watts"] = as_number(" power.draw [W]", "W")
    gpu_metrics["power_draw_percent"] = gpu_metrics["power_draw_watts"] / max_watts * 100
    gpu_metrics["utilization_gpu_percent"] = as_number(" utilization.gpu [%]", "%")
    gpu_metrics["utilization_memory_percent"] = as_number(" utilization.memory [%]", "%")
    gpu_metrics["temperature_gpu_celsius"] = gpu_metrics[" temperature.gpu"]
    return gpu_metrics

In [None]:
def plot_gpu_metrics_for_condition(condition: Condition):
    """Draw one subplot per run of ``condition`` with its GPU metrics over the sample index.

    Runs are laid out on a grid with at most three columns; grid cells without
    a run are hidden. Each subplot shows power draw, GPU utilization, memory
    utilization and temperature on a shared 0-100 axis.
    """
    runs = condition.get_all_runs()
    run_count = len(runs)

    # Grid: up to three columns, as many rows as needed (ceiling division)
    n_cols = min(3, run_count)
    n_rows = -(-run_count // n_cols)
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 4 * n_rows))

    # plt.subplots returns a bare Axes for a 1x1 grid, an array otherwise
    axes = [axes] if run_count == 1 else axes.flatten()

    # (column, legend label) pairs in drawing order
    curves = (
        ("power_draw_percent", "Power Draw (%)"),
        ("utilization_gpu_percent", "GPU Utilization (%)"),
        ("utilization_memory_percent", "Memory Utilization (%)"),
        ("temperature_gpu_celsius", "Temperature (°C)"),
    )

    for ax, run in zip(axes, runs):
        gpu_metrics = pre_process_gpu_metrics(run)
        for column, label in curves:
            ax.plot(gpu_metrics.index, gpu_metrics[column], linewidth=1.5, alpha=0.8, label=label)

        ax.set_ylim(0, 100)
        ax.set_yticks(range(0, 101, 10))
        ax.set_xlabel("Sample Index")
        ax.set_ylabel("Value")
        ax.set_title(f"GPU Usage Over Time - Measurement {run.index}")
        ax.legend(fontsize=8)
        ax.grid(True, alpha=0.3)

    # Grid cells beyond the last run stay empty; hide them
    for spare_ax in axes[run_count:]:
        spare_ax.set_visible(False)

    fig.suptitle(f"GPU Metrics: {condition.name} ({condition.tee_mode.value})", fontsize=14, fontweight="bold", y=0.995)

    plt.tight_layout()
    plt.show()

In [None]:
# Number of runs over all conditions of the calibration experiment (shown as the cell output)
len(exp_mns_calibration.get_all_runs())

In [None]:
# Calibration conditions of interest, all measured with the TEE enabled
_tee_on = TEE_Mode.TEE_ON
gemma_512 = exp_mns_calibration.get_condition("Gemma3 1B-mns-512", _tee_on)
gemma_1024 = exp_mns_calibration.get_condition("Gemma3 1B-mns-1024", _tee_on)
llama_256 = exp_mns_calibration.get_condition("Llama3.1 8B-mns-256", _tee_on)
mistral_128 = exp_mns_calibration.get_condition("Mistral 24B-mns-128", _tee_on)
mistral_32 = exp_mns_calibration.get_condition("Mistral 24B-mns-32", _tee_on)
qwen_64 = exp_mns_calibration.get_condition("Qwen3 32B-mns-64", _tee_on)
qwen_16 = exp_mns_calibration.get_condition("Qwen3 32B-mns-16", _tee_on)

# One figure per condition, in the order listed
for _condition in (
    gemma_512,
    gemma_1024,
    llama_256,
    mistral_128,
    mistral_32,
    qwen_64,
    qwen_16,
):
    plot_gpu_metrics_for_condition(_condition)

In [None]:
# Duration per prompt of each run

def get_run_durations(condition: Condition):
    """Return ``duration / num_prompts`` for every run of ``condition``, in run order."""
    per_prompt_durations = []
    for run in condition.get_all_runs():
        per_prompt_durations.append(
            run.get_vllm_key("duration") / run.get_vllm_key("num_prompts")
        )
    return per_prompt_durations

# Print the per-prompt durations of each calibration condition, one line per condition
for _label, _condition in (
    ("gemma_512", gemma_512),  # ← You
    ("gemma_1024", gemma_1024),
    ("llama_256", llama_256),  # ← You
    ("mistral_128", mistral_128),  # ← You
    ("mistral_32", mistral_32),
    ("qwen_64", qwen_64),  # ← You
    ("qwen_16", qwen_16),
):
    print(f"{_label}: {get_run_durations(_condition)}")


### 1.2. Throughput

In [None]:
condition_labels = []
throughput_medians_tee_on, tee_on_stds = [], []
throughput_medians_tee_off, tee_off_stds = [], []

# (TEE mode, medians list, stds list) triples, filled in this order per condition
_targets = (
    (TEE_Mode.TEE_ON, throughput_medians_tee_on, tee_on_stds),
    (TEE_Mode.TEE_OFF, throughput_medians_tee_off, tee_off_stds),
)

for c in exp_throughput_latency.list_condition_names(sorting=Sorting.MODEL_SIZE):
    condition_labels.append(c)
    for _mode, _medians, _stds in _targets:
        _median, _std = exp_throughput_latency.get_condition(
            c, _mode
        ).get_median_throughput_with_std()
        _medians.append(_median)
        _stds.append(_std)

x = np.arange(len(condition_labels))
width = 0.20  # Width of the bars

# Grouped bars: TEE off to the left of each tick, TEE on to the right
fig, ax = plt.subplots(figsize=(10, 6))
_bar_style = {"capsize": 5, "alpha": 0.8}
bars1 = ax.bar(
    x - width / 2,
    throughput_medians_tee_off,
    width,
    yerr=tee_off_stds,
    label="TEE off",
    **_bar_style,
)
bars2 = ax.bar(
    x + width / 2,
    throughput_medians_tee_on,
    width,
    yerr=tee_on_stds,
    label="TEE On",
    **_bar_style,
)

# Axis labels, title, ticks and legend
ax.set(
    xlabel="Model",
    ylabel="Output throughput (tok/s)",
    title="Throughput Comparison: TEE Off vs TEE On",
)
ax.set_xticks(x)
ax.set_xticklabels(condition_labels, rotation=45, ha="right")
ax.legend()

plt.show()

### 1.3. Latency

In [None]:
def get_ttft_latency_metrics(condition: Condition):
    """Aggregate the ``ttfts`` values of a condition across its runs.

    For every run the median and the 95th percentile of its ``ttfts`` are
    taken; those per-run values are then summarised over the runs.

    Returns:
        ``(median of run medians, sample std of run medians,
        median of run p95s, sample std of run p95s)``; both standard
        deviations use ``ddof=1``.
    """
    ttfts_per_run = [run.get_vllm_key("ttfts") for run in condition.get_all_runs()]
    run_medians = [np.median(ttfts) for ttfts in ttfts_per_run]
    run_p95s = [np.percentile(ttfts, 95) for ttfts in ttfts_per_run]
    return (
        np.median(run_medians),
        np.std(run_medians, ddof=1),
        np.median(run_p95s),
        np.std(run_p95s, ddof=1),
    )

#### 1.3.1. Prefill: Time to First Token (TTFT)

In [None]:
# Per-(condition, TEE mode) TTFT statistics in discovery order.
# NOTE(review): these five lists are not read by the plot further down in this cell.
models = []
medians = []
stds = []
p95s = []
std_p95s = []

for c in exp_throughput_latency.conditions:
    models.append(c.model_name)
    median, std, p95, std_p95 = get_ttft_latency_metrics(c)
    medians.append(median)
    stds.append(std)
    p95s.append(p95)
    std_p95s.append(std_p95)

# The plot axis is labelled in milliseconds, but the raw per-request ``ttfts``
# written by the vLLM benchmark are in seconds (only its ``*_ttft_ms``
# aggregates are pre-scaled). Convert before plotting so values and label agree.
# NOTE(review): confirm the unit against a result file.
MS_PER_S = 1000.0

condition_labels = []
latency_medians_tee_on, tee_on_stds = [], []
latency_medians_tee_off, tee_off_stds = [], []
p95_latencies_tee_on = []
p95_latencies_tee_off = []

for c in exp_throughput_latency.list_condition_names(sorting=Sorting.MODEL_SIZE):
    condition_labels.append(c)

    # TEE ON
    median_tee_on, std_tee_on, p95_tee_on, std_p95_tee_on = get_ttft_latency_metrics(
        exp_throughput_latency.get_condition(c, TEE_Mode.TEE_ON)
    )
    latency_medians_tee_on.append(median_tee_on * MS_PER_S)
    tee_on_stds.append(std_tee_on * MS_PER_S)
    p95_latencies_tee_on.append(p95_tee_on * MS_PER_S)

    # TEE OFF
    median_tee_off, std_tee_off, p95_tee_off, std_p95_tee_off = get_ttft_latency_metrics(
        exp_throughput_latency.get_condition(c, TEE_Mode.TEE_OFF)
    )
    latency_medians_tee_off.append(median_tee_off * MS_PER_S)
    tee_off_stds.append(std_tee_off * MS_PER_S)
    p95_latencies_tee_off.append(p95_tee_off * MS_PER_S)

y = np.arange(len(condition_labels))
height = 0.20  # Height of the bars

# Grouped horizontal bars: TEE off below each tick position, TEE on above
fig, ax = plt.subplots(figsize=(10, 6))
bars1 = ax.barh(
    y - height / 2,
    latency_medians_tee_off,
    height,
    xerr=tee_off_stds,
    label="TEE off",
    capsize=5,
    alpha=0.8,
)
bars2 = ax.barh(
    y + height / 2,
    latency_medians_tee_on,
    height,
    xerr=tee_on_stds,
    label="TEE On",
    capsize=5,
    alpha=0.8,
)

# Add whiskers for p95 latency (drawn as thin lines extending rightward)
for i, (median_off, p95_off) in enumerate(zip(latency_medians_tee_off, p95_latencies_tee_off)):
    ax.plot([median_off, p95_off], [i - height / 2, i - height / 2], color="black", lw=1.5)
    ax.scatter(p95_off, i - height / 2, color="black", s=20, zorder=3)

for i, (median_on, p95_on) in enumerate(zip(latency_medians_tee_on, p95_latencies_tee_on)):
    ax.plot([median_on, p95_on], [i + height / 2, i + height / 2], color="black", lw=1.5)
    ax.scatter(p95_on, i + height / 2, color="black", s=20, zorder=3)

# Add labels, title and legend
ax.set_xlabel("Time to First Token (ms)")
ax.set_ylabel("Model")
ax.set_title("Latency Comparison: TEE Off vs TEE On")
ax.set_yticks(y)
ax.set_yticklabels(condition_labels)
ax.legend()
#ax.invert_yaxis()  # highest on top

plt.show()

In [None]:
# Look up a single condition by name and TEE mode.
# NOTE(review): get_decode_latency_metrics, defined right after this line, takes its
# own `condition` argument and does not read this module-level variable.
condition = exp_throughput_latency.get_condition("Gemma3 1B", TEE_Mode.TEE_ON)


def get_decode_latency_metrics(condition: Condition):
    """Estimate the decode latency of a condition from its per-token timings.

    For every run, ``median_tpot_ms`` and ``p95_tpot_ms`` are multiplied by the
    mean of the run's ``output_lens``; the per-run estimates are then
    summarised over the runs.

    Returns:
        ``(median of run medians, sample std of run medians,
        median of run p95s, sample std of run p95s)``; both standard
        deviations use ``ddof=1``.
    """
    run_medians, run_p95s = [], []

    for run in condition.get_all_runs():
        # Aggregated per-token statistics reported for this run
        median_tpot = run.get_vllm_key("median_tpot_ms")
        p95_tpot = run.get_vllm_key("p95_tpot_ms")

        # Scale by the run's mean output length to get a whole-response estimate (ms)
        mean_output_len = np.mean(np.array(run.get_vllm_key("output_lens")))
        run_medians.append(median_tpot * mean_output_len)
        run_p95s.append(p95_tpot * mean_output_len)

    # Summarise the per-run estimates across runs
    return (
        np.median(run_medians),
        np.std(run_medians, ddof=1),
        np.median(run_p95s),
        np.std(run_p95s, ddof=1),
    )

#### 1.3.2. Decode latency

In [None]:
condition_labels_decode = []
decode_latency_medians_tee_on, decode_tee_on_stds = [], []
decode_latency_medians_tee_off, decode_tee_off_stds = [], []
p95_decode_latencies_tee_on = []
p95_decode_latencies_tee_off = []

# (TEE mode, medians, stds, p95s) accumulators, filled in this order per condition
_targets = (
    (TEE_Mode.TEE_ON, decode_latency_medians_tee_on, decode_tee_on_stds, p95_decode_latencies_tee_on),
    (TEE_Mode.TEE_OFF, decode_latency_medians_tee_off, decode_tee_off_stds, p95_decode_latencies_tee_off),
)

for c in exp_throughput_latency.list_condition_names(sorting=Sorting.MODEL_SIZE):
    condition_labels_decode.append(c)
    for _mode, _medians, _stds, _p95s in _targets:
        _median, _std, _p95, _std_p95 = get_decode_latency_metrics(
            exp_throughput_latency.get_condition(c, _mode)
        )
        _medians.append(_median)
        _stds.append(_std)
        _p95s.append(_p95)

y = np.arange(len(condition_labels_decode))
height = 0.20  # Height of the bars

# Grouped horizontal bars: TEE off below each tick position, TEE on above
fig, ax = plt.subplots(figsize=(10, 6))
_bar_style = {"capsize": 5, "alpha": 0.8}
bars1 = ax.barh(
    y - height / 2,
    decode_latency_medians_tee_off,
    height,
    xerr=decode_tee_off_stds,
    label="TEE off",
    **_bar_style,
)
bars2 = ax.barh(
    y + height / 2,
    decode_latency_medians_tee_on,
    height,
    xerr=decode_tee_on_stds,
    label="TEE On",
    **_bar_style,
)

# p95 whiskers: a thin line from each median to its p95, ending in a dot.
# TEE off whiskers are drawn first, then TEE on.
for _offset, _medians, _p95s in (
    (-height / 2, decode_latency_medians_tee_off, p95_decode_latencies_tee_off),
    (height / 2, decode_latency_medians_tee_on, p95_decode_latencies_tee_on),
):
    for i, (_median, _p95) in enumerate(zip(_medians, _p95s)):
        _row = i + _offset
        ax.plot([_median, _p95], [_row, _row], color="black", lw=1.5)
        ax.scatter(_p95, _row, color="black", s=20, zorder=3)

# Axis labels, title, ticks and legend
ax.set(
    xlabel="Decode Latency (ms)",
    ylabel="Model",
    title="Decode Latency Comparison: TEE Off vs TEE On",
)
ax.set_yticks(y)
ax.set_yticklabels(condition_labels_decode)
ax.legend()
#ax.invert_yaxis()  # highest on top

plt.show()

## 2. Saturation point

In [None]:
def get_saturation_metrics(exp_saturation_point):
    """Collect the per-condition series needed for the saturation plot.

    Conditions are visited in natural name order. The x value is the integer
    after the last ``"_"`` of the condition name. For each TEE mode the median
    and population std of ``max_concurrent_requests`` and the median of
    ``p95_e2el_ms`` over the runs are recorded.

    Returns:
        Dict of equally long lists with the keys ``x_qps``,
        ``y_qps_tee_on_median``, ``y_qps_tee_on_std``, ``y_qps_tee_off_median``,
        ``y_qps_tee_off_std``, ``y_p95_tee_on`` and ``y_p95_tee_off``.
    """

    def summarize(condition_name, tee_mode):
        # (median, std) of max_concurrent_requests and median of p95_e2el_ms over the runs
        runs = exp_saturation_point.get_condition(condition_name, tee_mode).get_all_runs()
        achieved = [run.get_vllm_key("max_concurrent_requests") for run in runs]
        p95s = [run.get_vllm_key("p95_e2el_ms") for run in runs]
        return np.median(achieved), np.std(achieved), np.median(p95s)

    result = {
        "x_qps": [],
        "y_qps_tee_on_median": [],
        "y_qps_tee_on_std": [],
        "y_qps_tee_off_median": [],
        "y_qps_tee_off_std": [],
        "y_p95_tee_on": [],
        "y_p95_tee_off": [],
    }

    for condition_name in exp_saturation_point.list_condition_names(Sorting.NATURAL):
        result["x_qps"].append(int(condition_name.split("_")[-1]))

        on_median, on_std, on_p95 = summarize(condition_name, TEE_Mode.TEE_ON)
        result["y_qps_tee_on_median"].append(on_median)
        result["y_qps_tee_on_std"].append(on_std)
        result["y_p95_tee_on"].append(on_p95)

        off_median, off_std, off_p95 = summarize(condition_name, TEE_Mode.TEE_OFF)
        result["y_qps_tee_off_median"].append(off_median)
        result["y_qps_tee_off_std"].append(off_std)
        result["y_p95_tee_off"].append(off_p95)

    return result

In [None]:
def plot_saturation_metrics(metrics: dict, *, show_tee_off: bool = False):
    """Plot the saturation sweep: achieved load and P95 latency per requested concurrency.

    The median series (with std error bars) go on the primary y-axis, the P95
    series (dashed) on a secondary y-axis; both share the x-axis.

    Args:
        metrics: Dict as returned by ``get_saturation_metrics``. The keys
            ``x_qps``, ``y_qps_tee_<mode>_median``, ``y_qps_tee_<mode>_std`` and
            ``y_p95_tee_<mode>`` are read for each plotted mode.
        show_tee_off: Also draw the TEE-off curves. Defaults to ``False``,
            which draws the TEE-on curves only, as before.
    """
    x_qps = metrics["x_qps"]

    # (legend prefix, key infix) per plotted TEE mode
    modes = [("TEE On", "on")]
    if show_tee_off:
        modes.append(("TEE Off", "off"))

    fig, ax = plt.subplots(figsize=(12, 6))
    # Secondary y-axis for the P95 values. Each axes has its own color cycle,
    # so the median and P95 curve of one mode share a color.
    ax2 = ax.twinx()

    for label, key in modes:
        ax.errorbar(
            x_qps,
            metrics[f"y_qps_tee_{key}_median"],
            yerr=metrics[f"y_qps_tee_{key}_std"],
            label=f"{label} (Median QPS)",
            capsize=5,
            capthick=2,
            linewidth=2,
            markersize=8,
            alpha=0.8,
            elinewidth=1.5,
        )
        ax2.plot(
            x_qps,
            metrics[f"y_p95_tee_{key}"],
            label=f"{label} (P95)",
            linestyle="--",
            linewidth=1,
            markersize=1,
            alpha=0.6,
        )

    ax.set_xlabel("Requested Concurrency")
    # NOTE(review): get_saturation_metrics fills these series from
    # "max_concurrent_requests"; confirm that "QPS" is the intended wording.
    ax.set_ylabel("Achieved QPS")

    # The y_p95_* series hold the "p95_e2el_ms" values, i.e. a latency in ms.
    ax2.set_ylabel("P95 End-to-End Latency (ms)")

    # Only mention TEE Off in the title when it is actually drawn.
    if show_tee_off:
        ax.set_title("Saturation Point: TEE On vs TEE Off (Median + P95)")
    else:
        ax.set_title("Saturation Point: TEE On (Median + P95)")

    # Combine legends from both axes
    lines1, labels1 = ax.get_legend_handles_labels()
    lines2, labels2 = ax2.get_legend_handles_labels()
    ax.legend(lines1 + lines2, labels1 + labels2, loc='upper left')

    ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

In [None]:
# First saturation sweep: collect, plot, then list every run
sweep_1 = Experiment(data_path / "saturation_point_sweep_1")
metrics = get_saturation_metrics(sweep_1)
plot_saturation_metrics(metrics)
for run in sweep_1.get_all_runs():
    print(f"{run}")

In [None]:
# Second saturation sweep: collect and plot
sweep_2 = Experiment(data_path / "saturation_point_sweep_2")
metrics = get_saturation_metrics(sweep_2)
plot_saturation_metrics(metrics)

In [None]:
# Print the one-line summary (Run.__str__) of every run of the first sweep
runs = sweep_1.get_all_runs()
for run in runs:
    print(run)

In [None]:
# Total measured duration of the second sweep
runs = sweep_2.get_all_runs()

duration = 0.0
for run in runs:
    duration += run.get_vllm_key("duration")

# Reuse the notebook's HH:MM:SS formatter so the fields are zero-padded
# (the previous output could read e.g. "1:5:7").
print(format_seconds_long(duration))

## 3. Sequence length overhead

In [None]:
# Handle on the sequence-overhead results directory
exp_sequence_overhead = Experiment(data_path / "sequence_overhead")

In [None]:
input_lengths, output_lengths = [], []
tee_on_throughputs_medians, tee_on_throughputs_stds = [], []
tee_off_throughputs_medians, tee_off_throughputs_stds = [], []

for condition in exp_sequence_overhead.list_condition_names(Sorting.NATURAL):
    # The condition name carries the input length as its second "_"-separated
    # field and the output length as its last one
    name_parts = condition.split("_")
    input_len = int(name_parts[1])
    output_len = int(name_parts[-1])
    input_lengths.append(input_len)
    output_lengths.append(output_len)

    # TEE_ON
    runs_tee_on = exp_sequence_overhead.get_condition(condition, TEE_Mode.TEE_ON).get_all_runs()
    tee_on_throughputs = [r.get_vllm_key("total_token_throughput") for r in runs_tee_on]
    tee_on_throughputs_median = np.median(tee_on_throughputs)
    tee_on_throughputs_std = np.std(tee_on_throughputs)
    tee_on_throughputs_medians.append(tee_on_throughputs_median)
    tee_on_throughputs_stds.append(tee_on_throughputs_std)

    # TEE_OFF
    runs_tee_off = exp_sequence_overhead.get_condition(condition, TEE_Mode.TEE_OFF).get_all_runs()
    tee_off_throughputs = [r.get_vllm_key("total_token_throughput") for r in runs_tee_off]
    tee_off_throughputs_median = np.median(tee_off_throughputs)
    tee_off_throughputs_std = np.std(tee_off_throughputs)
    tee_off_throughputs_medians.append(tee_off_throughputs_median)
    tee_off_throughputs_stds.append(tee_off_throughputs_std)

    # Difference of the medians, relative to the TEE-on median, in percent
    overhead_percent = (
        (tee_off_throughputs_median - tee_on_throughputs_median)
        / tee_on_throughputs_median
        * 100
    )

    print(f"Input length: {input_len}, Output length: {output_len}")
    print(f"TEE On throughputs: {tee_on_throughputs}")
    print(f"TEE Off throughputs: {tee_off_throughputs}")
    print(f"Overhead: {overhead_percent}%")
    

In [None]:
# Heatmap of the TEE-on median throughput, built from the lists filled above
unique_input_lengths = sorted(set(input_lengths))
unique_output_lengths = sorted(set(output_lengths))

# (input_len, output_len) -> throughput
heatmap_data = dict(zip(zip(input_lengths, output_lengths), tee_on_throughputs_medians))

# Rows = output lengths, columns = input lengths; combinations without data stay NaN
heatmap_matrix = np.full((len(unique_output_lengths), len(unique_input_lengths)), np.nan)
for (in_len, out_len), throughput in heatmap_data.items():
    row = unique_output_lengths.index(out_len)
    col = unique_input_lengths.index(in_len)
    heatmap_matrix[row, col] = throughput

fig, ax = plt.subplots(figsize=(12, 8))
im = ax.imshow(heatmap_matrix, cmap="viridis", aspect="auto", interpolation="nearest", origin="lower")

# One tick per distinct length, labelled with the length itself
ax.set_xticks(np.arange(len(unique_input_lengths)))
ax.set_yticks(np.arange(len(unique_output_lengths)))
ax.set_xticklabels(unique_input_lengths)
ax.set_yticklabels(unique_output_lengths)

# Axis labels and title
ax.set_xlabel("Input Length (tokens)", fontsize=12, fontweight="bold")
ax.set_ylabel("Output Length (tokens)", fontsize=12, fontweight="bold")
ax.set_title("Median Total Token Throughput (TEE ON) - Sequence Overhead", fontsize=14, fontweight="bold")

# Colorbar
cbar = plt.colorbar(im, ax=ax)
cbar.set_label("Total Token Throughput (tok/s)", fontsize=11)

# Write each value into its cell; white text on the lower half of the value range
for (i, j), value in np.ndenumerate(heatmap_matrix):
    if np.isnan(value):
        continue
    text_color = "white" if value < np.nanmax(heatmap_matrix) / 2 else "black"
    text = ax.text(j, i, f"{value:.1f}", ha="center", va="center", color=text_color, fontsize=8)

plt.tight_layout()
plt.show()

## 4. Energy efficiency

In [None]:
# Handle on the energy results directory
exp_energy = Experiment(data_path / "energy")

In [None]:
def get_J_per_token(run: Run, sample_interval_s: float = 1.0):
    """Return the GPU energy spent per generated token for one run, in joules.

    The power samples of the run's GPU metrics are integrated with the
    trapezoidal rule and divided by the run's ``total_output_tokens``.

    Args:
        run: Run whose GPU metrics and vLLM result are evaluated.
        sample_interval_s: Time between two consecutive rows of the GPU
            metrics, in seconds. The default of one second matches the
            previously hard-coded spacing.
    """
    gpu_metrics = pre_process_gpu_metrics(run)
    power_watts = gpu_metrics['power_draw_watts'].values
    # ``np.trapezoid`` only exists from NumPy 2.0 on; older releases provide the
    # same routine as ``np.trapz``. The short-circuit avoids touching the
    # deprecated name on new NumPy versions.
    integrate = getattr(np, "trapezoid", None) or np.trapz
    energy_joules = integrate(power_watts, dx=sample_interval_s)
    J_per_token = energy_joules / run.get_vllm_key("total_output_tokens")
    return J_per_token

In [None]:
models = exp_energy.list_condition_names(Sorting.NATURAL)
J_per_token_medians_tee_on, J_per_token_stds_tee_on = [], []
J_per_token_medians_tee_off, J_per_token_stds_tee_off = [], []

for model in models:
    # TEE_ON: energy per token of every run, then median and std over the runs
    J_per_token_tee_on = [
        get_J_per_token(r)
        for r in exp_energy.get_condition(model, TEE_Mode.TEE_ON).get_all_runs()
    ]
    J_per_token_medians_tee_on.append(np.median(J_per_token_tee_on))
    J_per_token_stds_tee_on.append(np.std(J_per_token_tee_on))

    # TEE_OFF: same aggregation
    J_per_token_tee_off = [
        get_J_per_token(r)
        for r in exp_energy.get_condition(model, TEE_Mode.TEE_OFF).get_all_runs()
    ]
    J_per_token_medians_tee_off.append(np.median(J_per_token_tee_off))
    J_per_token_stds_tee_off.append(np.std(J_per_token_tee_off))


In [None]:
fig, ax = plt.subplots(figsize=(12, 6))

# One error-bar line per TEE mode (median ± std), TEE_ON drawn first
_positions = range(len(models))
for _medians, _stds, _marker, _label in (
    (J_per_token_medians_tee_on, J_per_token_stds_tee_on, "o", "TEE_ON"),
    (J_per_token_medians_tee_off, J_per_token_stds_tee_off, "s", "TEE_OFF"),
):
    ax.errorbar(
        _positions,
        _medians,
        yerr=_stds,
        marker=_marker,
        label=_label,
        linewidth=2,
        capsize=5,
        markersize=8,
    )

ax.set_xlabel("Model", fontsize=12, fontweight="bold")
ax.set_ylabel("Energy per Token (J/token)", fontsize=12, fontweight="bold")
ax.set_title("GPU Energy Consumption: TEE ON vs TEE OFF", fontsize=14, fontweight="bold")
ax.set_xticks(_positions)
ax.set_xticklabels(models, rotation=45, ha="right")
ax.legend(fontsize=11, loc="best")
ax.grid(True, alpha=0.3, linestyle="--")

plt.tight_layout()
plt.show()

# TODO: Scale X by the jump in model size

## 5. Price of operations