<a href="https://colab.research.google.com/github/neilw4/CFCompiler/blob/master/load_balancing_sim.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### Dependencies

In [None]:
%pip install --quiet pyinstrument
%pip install --quiet kaleido
%pip install --quiet binarytree
%load_ext pyinstrument

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m145.4/145.4 kB[0m [31m26.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m79.9/79.9 MB[0m [31m30.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.9/43.9 kB[0m [31m8.4 MB/s[0m eta [36m0:00:00[0m
[?25h

In [None]:
from dataclasses import dataclass, field
from enum import Enum
from typing import cast, Optional, Callable, TypeVar, Generic, Generator
from collections import defaultdict
import random
import traceback
import math

# Used to make png charts
import kaleido
import bisect
import plotly.express as px
from plotly.subplots import make_subplots
# from tqdm import trange
from tqdm.notebook import trange, tqdm
import pandas as pd
import numpy as np
from scipy.stats import gamma
# from tqdm.contrib.concurrent import process_map, thread_map
from google.colab import auth
import gspread
from google.auth import default
from concurrent.futures import ProcessPoolExecutor, as_completed

## Helpers

In [None]:
# Gamma distribution never goes below 0 but can extend infinitely high. This makes it a useful approximation for many
# factors that vary, tend to have a low mean, don't go below 0, and can, on occasion, be very high.
def randGamma(mean, variance=1.5, min=1):
  # mean is always `a × scale`
  return int(gamma.rvs(a=((variance*mean)-min)/variance, scale=variance) / variance + min)

one = [randGamma(1) for _ in range(1000)]
assert round(sum(one) / len(one)) == 1
assert min(one)  == 1
assert max(one) > 1

two = [randGamma(2) for _ in range(1000)]
assert round(sum(two) / len(two)) == 2
assert min(one) == 1
assert max(one) > 2

def RandomRound(x: float) -> int:
  if random.random() < x % 1:
    return math.ceil(x)
  else:
    return math.floor(x)

In [121]:
# Out-of-order version of tqdm's process_map. Better for seeing progress.
def ooo_process_map(f, xs):
  with ProcessPoolExecutor(max_workers=8) as p:
    fs = {p.submit(f, x): x for x in xs}
    with tqdm(total=len(fs)) as pbar:
      for f in as_completed(fs.keys()):
        try:
          r = f.result()
          yield r
        except:
          print(f"error processing {fs[f]}")
          traceback.print_exc()
        pbar.update(1)

# from time import sleep
# list(tqdm_process_map(sleep, [3,5,4, 1, 2]))

In [None]:
T = TypeVar('T')

@dataclass(frozen=True)
class LastN(Generic[T]):
  n: int
  buffer: list[T]

  def __init__(self, n: int, buffer: list[T] | None = None):
    object.__setattr__(self, 'n', n)
    object.__setattr__(self, 'buffer', buffer or [])

  def is_empty(self):
    return not bool(self.buffer)

  def insert(self, value: T) -> 'LastN[T]':
    if len(self.buffer) < self.n:
      buffer = self.buffer[:]
    else:
      buffer = self.buffer[1:]
    buffer.append(value)
    return self.__class__(self.n, buffer)

  def __iter__(self):
    return iter(self.buffer)

  def latest(self, translator: Callable[[T], float]) -> float | None:
    if self.is_empty():
      return None
    return self.buffer[-1]

  def mean(self, translator: Callable[[T], float]) -> float | None:
    if self.is_empty():
      return None
    return sum(translator(x) for x in self.buffer) / len(self.buffer)

  def ewma(self, translator: Callable[[T], float], weight=0.2) -> float | None:
    # Exponential weighted moving average
    # TODO: use weight from test manager (can set on initialisation).
    # We could even use data from before the window if we keep a rolling window.
    av = None
    for x in self.buffer:
      if av is None:
        av = translator(x)
      else:
        av = weight * translator(x) + (1 - weight) * av
    return av

  def __repr__(self) -> str:
    return f'LastN(data_points={len(self.buffer)}, latest={self.latest(lambda x: x)}'

  def __str__(self) -> str:
    return repr(self)

class LastNFloats(LastN[float]):

  def insert(self, value: float) -> 'LastNFloats':
    return cast(LastNFloats, super().insert(value))

  def latest(self) -> float | None:
    return super().latest(translator=lambda x: x)

  def mean(self) -> float | None:
    return super().mean(translator=lambda x: x)

  def ewma(self,  weight: float=0.2) -> float | None:
    return super().ewma(translator=lambda x: x, weight=weight)

  def __repr__(self) -> str:
    return f'LastNFloats(data_points={len(self.buffer)}, latest={self.latest()}, mean={self.mean()}, ewma={self.ewma()})'

  def __str__(self) -> str:
    return repr(self)

c = LastNFloats(3)
assert c.is_empty()
assert c.latest() is None
assert c.mean() is None
assert list(c) == []
c_orig = c
c = c.insert(5)
assert c_orig.is_empty()
assert c_orig.latest() is None
assert c_orig.mean() is None
assert list(c_orig) == []

assert not c.is_empty()
assert c.latest() == 5
assert c.mean() == 5
assert c.ewma(weight=0.5) == 5
assert list(c) == [5]
c = c.insert(10)
assert not c.is_empty()
assert c.latest() == 10
assert c.mean() == 7.5
assert c.ewma(weight=0.5) == 7.5
assert c.ewma(weight=1) == 10
assert c.ewma(weight=0) == 5
assert list(c) == [5, 10]
c = c.insert(15)
assert not c.is_empty()
assert c.latest() == 15
assert c.mean() == 10
assert list(c) == [5, 10, 15]
c = c.insert(20)
assert not c.is_empty()
assert c.latest() == 20
assert c.mean() == 15
assert list(c) == [10, 15, 20]

In [None]:
@dataclass
class PidController:
  kp: float
  ki: float
  kd: float
  prev_error: float = 0
  integral: float = 0

  def update(self, value: float, target: float) -> float:
    error = target - value
    self.integral += error
    derivative = error - self.prev_error
    self.prev_error = error
    return self.kp * error + self.ki * self.integral + self.kd * derivative

In [None]:
# TODO: we can improve performance using a sorted data structure such as a BST.
def interpolate(d: dict[int, float], x: int) -> float | None:
  closest_below = None
  closest_above = None
  for key in d:
    if key <= x and (closest_below is None or key > closest_below):
      closest_below = key
    if key >= x and (closest_above is None or key < closest_above):
      closest_above = key
    if closest_below == closest_above:
      return d[closest_below]
  match closest_below, closest_above:
    case None, None:
      return None
    case None, _:
      return d[closest_above]
    case _, None:
      return d[closest_below]
    case _, _:
      return d[closest_below] + (d[closest_above] - d[closest_below]) * (x - closest_below) / (closest_above - closest_below)

d = {1: 1.0, 2: 2.0, 4: 4.0, 8: 8.0}

assert interpolate({}, 0) is None
assert interpolate({}, 1) is None

assert interpolate(d, 1) == 1.0
assert interpolate(d, 2) == 2.0
assert interpolate(d, 3) == 3.0
assert interpolate(d, 4) == 4.0
assert interpolate(d, 5) == 5.0
assert interpolate(d, 6) == 6.0
assert interpolate(d, 7) == 7.0
assert interpolate(d, 8) == 8.0
assert interpolate(d, 9) == 8.0
assert interpolate(d, 90) == 8.0

In [None]:
def queries_for_utilization(threads, processing_time, probes_per_request=0.0, probe_processing_time=1, target_utilization = 0.8):
  return target_utilization * threads / (processing_time + probes_per_request * probe_processing_time)

## Implementation

In [None]:
@dataclass()
class TestManager:

  class Stage(Enum):
    STOPPED = 0
    # First few requests, caches will be cold
    WARMING = 1
    # (Ideally) steady state
    RUNNING = 2
    # Waiting for straggling requests to finish
    COOLING = 3

  # Architecture
  channel_picker_factory: 'ChannelPickerFactory'
  servers: int = 40
  clients: int = 40
  subset_size: int = 12
  # Ticks
  mean_requests_per_server_per_tick: float = 2.2
  total_ticks: int = 500
  current_tick: int = 0
  # Latency
  mean_processing_time: int = 6
  probe_processing_time: int = 1
  mean_channel_rtt = 2
  channel_rtt_variance = 3
  mean_server_rtt = 2
  server_rtt_variance = 3
  # Throughput
  mean_threads: int = 100
  # Resets
  network_reset_probability: float = 0.002
  server_reset_probability: float = 0.002
  cold_channel_penalty: int = 10
  # Lookback
  server_state_buffer: int = 100
  client_requests_buffer: int = 100
  ewma_weight: float = 0.2
  # Debug
  debug: bool = False
  logs: list = field(default_factory=list)
  completed_requests: list = field(default_factory=list)
  stage: Stage = Stage.STOPPED
  max_cooldown: int = 300
  data_frequency: int = 30

  def Log(self, obj, msg):
    if self.debug:
      name = obj.__class__.__name__
      if hasattr(obj, 'name'):
        name = obj.name
      self.logs.append((self.current_tick, name, msg))
      if self.debug and 'MyFavouriteClass' in obj.__class__.__name__:
        print(f'tick {self.current_tick} {name}: {msg}')

  def RequestCompleted(self, request: 'Request'):
    if self.current_tick % self.data_frequency == 0:
      assert request.channel is not None
      assert request.completed_on_tick is not None
      assert request.server_state is not None
      self.completed_requests.append({
            'test_stage': self.stage.name,
            'client': request.channel.client.name,
            'server': request.channel.server.name,
            'is_probe': request.is_probe,
            'started_on_tick': request.started_on_tick,
            'completed_on_tick': request.completed_on_tick,
            'tick_at_send': request.tick_at_state_start[Request.State.SENDING],
            'tick_at_server_queue': request.tick_at_state_start[Request.State.SERVER_QUEUE],
            'tick_at_processing': request.tick_at_state_start[Request.State.PROCESSING],
            'tick_at_receive': request.tick_at_state_start[Request.State.RECEIVING],
            'tick_at_completion': request.tick_at_state_start[Request.State.COMPLETED],
            'ticks_between_start_and_completion': request.completed_on_tick - request.started_on_tick,
            'total_ticks': sum(request.ticks_in_state.values()),
            'picking_ticks': request.ticks_in_state[Request.State.PICKING],
            'send_ticks': request.ticks_in_state[Request.State.SENDING],
            'server_queue_ticks': request.ticks_in_state[Request.State.SERVER_QUEUE],
            'processing_ticks': request.ticks_in_state[Request.State.PROCESSING],
            'receive_ticks': request.ticks_in_state[Request.State.RECEIVING],
            'server_utilization_latest': request.server_state.utilization.latest(),
            'server_utilization_mean': request.server_state.utilization.mean(),
            'server_utilization_ewma': request.server_state.utilization.ewma(weight=self.ewma_weight),
            'server_queries_per_tick_latest': request.server_state.queries_per_tick.latest(),
            'server_queries_per_tick_mean': request.server_state.queries_per_tick.mean(),
            'server_queries_per_tick_ewma': request.server_state.queries_per_tick.ewma(weight=self.ewma_weight),
            'server_requests_in_flight_latest': request.server_state.requests_in_flight.latest(),
            'server_requests_in_flight_mean': request.server_state.requests_in_flight.mean(),
            'server_requests_in_flight_ewma': request.server_state.requests_in_flight.ewma(weight=self.ewma_weight),
            'server_queued_requests_latest': request.server_state.queued_requests.latest(),
            'server_queued_requests_mean': request.server_state.queued_requests.mean(),
            'server_queued_requests_ewma': request.server_state.queued_requests.ewma(weight=self.ewma_weight),
            'server_pid_weight': request.server_state.pid_weight,
          })

  def tick(self, servers: list['Server'], clients: list['Client'], requests: int):
    for _ in range(requests):
        client = random.choice(clients)
        client.send(Request(self))
    for server in servers:
      server.tick()
    for client in clients:
      client.tick()
    self.current_tick += 1


  def runTest(self, scenario_name: str) -> pd.DataFrame:
    servers = [Server(f'Server {n}', self) for n in range(self.servers)]
    self.servers_impl = servers
    clients = [Client(f'Client {n}', self, servers) for n in range(self.clients)]
    self.clients_impl = clients
    self.stage = TestManager.Stage.WARMING
    warm_period = int(min((self.mean_channel_rtt + self.mean_server_rtt + self.mean_processing_time) * 3, self.total_ticks / 2))
    # Incoming request rate isn't even, it's variable. Mimic this by picking a random time to send each request.
    # https://www.usenix.org/sites/default/files/conference/protected-files/srecon19apac_slides_plenz.pdf
    request_tick = [random.randrange(0, self.total_ticks) for _ in range(round(self.mean_requests_per_server_per_tick * self.servers * self.total_ticks))]
    for i in trange(self.total_ticks, desc=scenario_name):
      requests = request_tick.count(i)
      if i > warm_period:
        self.stage = TestManager.Stage.RUNNING
      self.tick(servers, clients, requests)
    self.stage = TestManager.Stage.COOLING
    for i in range(self.max_cooldown):
      if all(client.requests_in_flight() == 0 for client in clients):
        break
      self.tick(servers, clients, requests=0)
    if self.current_tick >= self.total_ticks + self.max_cooldown:
      print("Cooldown failed, there are still requests in flight")
    return pd.DataFrame.from_records(self.completed_requests)


In [None]:
@dataclass(frozen=True)
class ServerState:
  """Tracks the state of a server at a moment in time."""
  pid_weight: float
  utilization: LastNFloats
  queries_per_tick: LastNFloats
  requests_in_flight: LastNFloats
  requests_in_queue: LastNFloats
  latencies: LastNFloats
  queued_requests: LastNFloats = field(repr=False)
  latency_with_rif: dict[int, float] = field(default_factory=dict, repr=False)

  def expectedLatency(self) -> float | None:
    requests_in_flight = None if self.requests_in_flight.is_empty() else self.requests_in_flight.mean()
    if requests_in_flight is None:
      return None
    return interpolate(self.latency_with_rif, requests_in_flight)

  def update(self,
             pid_weight_adjustment: float,
             utilization: float,
             queries_per_tick: int,
             requests_in_flight: int,
             requests_in_queue: int,
             queued_requests: int,
             latency: float | None,
             ) -> 'ServerState':
    return ServerState(pid_weight=self.pid_weight + pid_weight_adjustment,
                       utilization=self.utilization.insert(utilization),
                       queries_per_tick=self.queries_per_tick.insert(queries_per_tick),
                       requests_in_flight=self.requests_in_flight.insert(requests_in_flight),
                       requests_in_queue=self.requests_in_queue.insert(requests_in_queue),
                       queued_requests=self.queued_requests.insert(queued_requests),
                       latencies=self.latencies.insert(latency) if latency else self.latencies,
                       latency_with_rif=(self.latency_with_rif | ({requests_in_flight: latency} if latency else {}))
                       )

@dataclass
class Request:
  """
  Requests are sent from client to server and back, through a network channel.
  They have a simple state machine. Requests don't (yet) support errors or
  timeouts.
  """
  class State(Enum):
    PICKING = 0
    SENDING = 1
    SERVER_QUEUE = 2
    PROCESSING = 3
    RECEIVING = 4
    COMPLETED = 5

  name: str
  state: State
  test_manager: TestManager = field(repr=False)
  started_on_tick: int
  completed_on_tick: Optional[int]
  ticks_in_state: dict[State, int]
  ticks_left_in_current_state: int
  tick_at_state_start: dict[State, int]
  server_state: Optional[ServerState]
  channel: Optional['Channel'] = field(repr=False)
  is_probe: bool = False

  STALLABLE_STATES = (State.SERVER_QUEUE, State.PICKING, State.COMPLETED)


  def __init__(self, test_manager: TestManager, is_probe: bool = False):
    self.name = ''.join([random.choice('0123456789abcdefghijklmnopqrstuvwxyz') for _ in range(5)])
    self.is_probe = is_probe
    self.test_manager = test_manager
    self.state = Request.State.PICKING
    self.started_on_tick = test_manager.current_tick
    self.completed_on_tick = None
    self.ticks_left_in_current_state = 0
    self.ticks_in_state = defaultdict(int)
    self.tick_at_state_start = dict()
    self.server_state = None
    self.channel = None

  def setState(self, originator, state, expected_state, ticks_in_state=0, server_state = None, channel = None):
    if self.state != expected_state:
      raise ValueError(f'Cannot set state of request in unexpected state {self.state} (expected {expected_state}). Request {self}')
    # Can happen when a request is demoted from Processing to Queued if a server ends up with a smaller number of threads.
    # if self.ticks_left_in_current_state > 0:
    #   raise ValueError(f'Cannot set state of request to {state} from {expected_state} with {self.ticks_left_in_current_state} left in state. Request {self}')
    if ticks_in_state > 100:
      print(f"T-- {self.name} moving to new state {state} from {self.state} with too many ticks in current state: {ticks_in_state}")
    self.state = state
    self.ticks_left_in_current_state = ticks_in_state
    if server_state:
      self.server_state = server_state
    if channel:
      self.channel = channel
      self.name += f" from {channel.client.name}"
    if state == Request.State.COMPLETED:
      self.completed_on_tick = self.test_manager.current_tick
    if state not in self.tick_at_state_start.keys():
      self.tick_at_state_start[state] = self.test_manager.current_tick
    self.test_manager.Log(self, f'Request state updated from {expected_state} to {state} with {ticks_in_state} ticks left by {originator.name}')
    self.test_manager.Log(originator, f'Request {self.name} updated from {expected_state} to {state} with {ticks_in_state} ticks left')

  def isFinishedInState(self):
    return self.ticks_left_in_current_state <= 0

  def tick(self, progress=True):
    if self.isFinishedInState() and self.state not in self.STALLABLE_STATES:
      raise ValueError(f'Cannot tick request left with {self.ticks_left_in_current_state} ticks left in state. Request {self}')
    self.ticks_in_state[self.state] += 1
    if self.ticks_left_in_current_state > 100:
      print(f"T-- {self.name} ticking, reducing ticks in state from {self.ticks_left_in_current_state} to {self.ticks_left_in_current_state-1}")
    self.ticks_left_in_current_state -= 1
    self.test_manager.Log(self, f'Request in state {self.state} ticked. {self.ticks_in_state[self.state]} ticks taken, {self.ticks_left_in_current_state} ticks left.')


@dataclass
class Server:
  """
  Processes requests. Has a fixed size thread pool and, when this is exhausted, requests are queued.
  Thread pool and processing time are set randomly, and are occasionally reset to simulate a server being replaced.
  Servers don't (yet) implement any kind of pushback mechanism, so queues can grow unbounded.
  """
  name: str
  test_manager: TestManager = field(repr=False)
  queued_requests: list[Request]
  processing_requests: list[Request]
  queries_in_tick: int
  threads: int
  mean_processing_time: int
  server_state: ServerState
  warm_channels: dict[str, 'Channel']
  pid_controller: PidController

  def __init__(self, name: str, test_manager: TestManager):
    self.name = name
    self.test_manager = test_manager
    self.queued_requests = []
    self.processing_requests = []
    self.queries_in_tick = 0
    self.reset()

  def reset(self):
    if hasattr(self, 'threads') and hasattr(self, 'mean_processing_time'):
      self.test_manager.Log(self, f'Resetting from {self.threads=}, {self.mean_processing_time=}')
    self.threads = randGamma(self.test_manager.mean_threads)
    self.mean_processing_time = randGamma(self.test_manager.mean_processing_time)
    self.mean_rtt = randGamma(self.test_manager.mean_server_rtt, self.test_manager.server_rtt_variance)
    self.warm_channels = dict()

    # https://tlk-energy.de/blog-en/practical-pid-tuning-guide
    # P: just a constant since the weight is going to be compared to other weights
    kp = 1
    # I: P/Ti where Ti is the integration time - how long we want it to take for the system to asymptotically meet the target.
    #    In our case, it isn't a big deal if the value doesn't quite hit the target so we can use a large time.
    kti = 50
    ki = kp/kti
    # D: P*Td, where Td is the derivative time. The recommended starting derivative time is Ti/10.
    ktd = kti/10
    kd = kp * ktd
    self.pid_controller = PidController(kp, ki, kd)
    self.server_state = ServerState(pid_weight = 0,
                                    utilization=LastNFloats(self.test_manager.server_state_buffer),
                                    queries_per_tick=LastNFloats(self.test_manager.server_state_buffer),
                                    requests_in_flight=LastNFloats(self.test_manager.server_state_buffer),
                                    requests_in_queue=LastNFloats(self.test_manager.server_state_buffer),
                                    queued_requests=LastNFloats(self.test_manager.server_state_buffer),
                                    latencies=LastNFloats(self.test_manager.server_state_buffer),
                                    latency_with_rif=dict(),
                                    )
    self.test_manager.Log(self, f'Resetting to {self.threads=}, {self.mean_processing_time=}')

    if len(self.processing_requests) > self.threads:
      # Put some requests back to the front of the queue.
      starved_requests = self.processing_requests[self.threads:]
      self.processing_requests = self.processing_requests[:self.threads]
      for request in starved_requests:
        request.setState(self, Request.State.SERVER_QUEUE, expected_state=Request.State.PROCESSING)
        self.queued_requests.insert(0, request)

    # for request in self.processing_requests:
    #   # Simulate a cache flush causing requests to temporarily take longer
    #   if request.ticks_left_in_current_state > 100:
    #     print(f"T-- server reset of {self.name} increased ticks left for request {request.name} from {request.ticks_left_in_current_state} to {request.ticks_left_in_current_state * self.test_manager.cold_channel_penalty}")
    #   request.ticks_left_in_current_state *= self.test_manager.cold_channel_penalty

  def enqueue(self, request: Request, channel: 'Channel'):
    request.setState(self, Request.State.SERVER_QUEUE, expected_state=Request.State.SENDING, channel=channel)
    self.queries_in_tick += 1
    self.queued_requests.append(request)
    self.test_manager.Log(self, f'enqueued request from {channel.client.name}, increasing queue length to {len(self.queued_requests)}')

  def tick(self):
    if random.random() < self.test_manager.server_reset_probability:
      self.reset()

    # dequeue
    requests_to_dequeue = min(max(0, self.threads - len(self.processing_requests)), len(self.queued_requests))
    prev_processing_requests = len(self.processing_requests)
    for request in self.queued_requests[:requests_to_dequeue]:
      processing_time = randGamma(self.mean_processing_time)
      if request.channel.name not in self.warm_channels:
        self.warm_channels[request.channel.name] = request.channel
        processing_time += self.test_manager.cold_channel_penalty
      if request.is_probe:
        processing_time = self.test_manager.probe_processing_time
      request.setState(self, Request.State.PROCESSING, expected_state=Request.State.SERVER_QUEUE, ticks_in_state=processing_time)
      self.processing_requests.append(request)
    prev_queued_requests = len(self.queued_requests)
    self.queued_requests = [request for request in self.queued_requests if request.state == Request.State.SERVER_QUEUE]
    self.test_manager.Log(self, f'dequeued {prev_queued_requests - len(self.queued_requests)} requests, queue length changed from {prev_queued_requests} to {len(self.queued_requests)}, processing requests changed from {prev_processing_requests} to {len(self.processing_requests)}')

    # Tick
    for request in self.queued_requests:
      request.tick()

    # Process
    latencies = []
    for request in self.processing_requests:
      request.tick()
      if request.isFinishedInState():
        assert request.channel
        request.channel.receive(request, self.server_state)
        latencies.append(request.ticks_in_state[Request.State.SERVER_QUEUE] + request.ticks_in_state[Request.State.PROCESSING])
    requests_processed = len(self.processing_requests)
    self.processing_requests = [request for request in self.processing_requests if request.state == Request.State.PROCESSING]
    self.test_manager.Log(self, f'processed {requests_processed} requests, {requests_processed - len(self.processing_requests)} finished, leaving {len(self.processing_requests)} to continue processing next time')

    # Record
    utilization = 100*requests_processed / self.threads

    # TODO: this is a bit of a hack. We should probably move this logic to the client and use the most recent data.
    other_utilizations = {server.name: server.server_state.utilization.mean() for server in self.test_manager.servers_impl if not server.server_state.utilization.is_empty()}
    # other_utilizations = {channel.server.name: channel.server.server_state.utilization.latest() for channel in self.warm_channels.values() if not channel.server.server_state.utilization.is_empty()}
    if other_utilizations:
      average_utilization = sum(other_utilizations.values()) / len(other_utilizations)
      # This seems to help.
      target_utilization = average_utilization + 10 if average_utilization < 80 else average_utilization
      pid_adjustment = self.pid_controller.update(utilization, target_utilization)
      self.test_manager.Log(self, f'PID adjusting weight from {self.server_state.pid_weight} to {self.server_state.pid_weight + pid_adjustment} (change: {pid_adjustment}) based on utilization of {utilization} compared to average utilization of {average_utilization} ({other_utilizations})')
    else:
      self.test_manager.Log(self, f'Not adjusting PID weight from {self.server_state.pid_weight} because no other servers have utilization or there are no warm channels to pick servers from')
      pid_adjustment = 0
    if requests_processed > self.threads:
      raise ValueError(f'{self.name} processed {requests_processed} requests, more than the {self.threads} available threads')
    old_server_state = self.server_state
    self.server_state = self.server_state.update(pid_weight_adjustment=pid_adjustment,
                                                 utilization=utilization,
                                                 queries_per_tick=self.queries_in_tick,
                                                 requests_in_flight=len(self.queued_requests) + len(self.processing_requests),
                                                 requests_in_queue=len(self.queued_requests),
                                                 queued_requests=len(self.queued_requests),
                                                 latency=sum(latencies)/len(latencies) if latencies else None)
    self.test_manager.Log(self, f'server state updated from {old_server_state} to {self.server_state} based on '
          f'utilization of {requests_processed} / {self.threads} = {utilization}%,'
          f'{len(self.queued_requests) + len(self.processing_requests)} requests in flight'
          f'{len(self.queued_requests)} queued requests'
          f'and mean latency of {sum(latencies)/len(latencies) if latencies else None} ticks ({latencies})'
        )
    self.queries_in_tick = 0


@dataclass
class Channel:
  """Network connection between a client and server. Introduces some latency in both directions."""
  test_manager: TestManager = field(repr=False)
  mean_ticks: float
  client: 'Client' = field(repr=False)
  server: Server
  send_queue: list[Request]
  receive_queue: list[Request]
  completed: LastN[Request]
  requests_in_flight: int
  name: str

  def __init__(self, client: 'Client', server: Server, test_manager: TestManager):
    self.test_manager = test_manager
    self.name = f'{client.name} -> {server.name}'
    self.client = client
    self.server = server
    self.send_queue = []
    self.receive_queue = []
    self.requests_in_flight = 0
    self.completed = LastN(self.test_manager.client_requests_buffer)
    self.reset()

  def send(self, request: Request):
    self.requests_in_flight += 1
    request.setState(self, Request.State.SENDING,
                     expected_state=Request.State.PICKING,
                     ticks_in_state=randGamma(self.mean_ticks))
    self.send_queue.append(request)

  def receive(self, request: Request, server_state: Optional[ServerState]):
    self.requests_in_flight -= 1
    request.setState(self,
    Request.State.RECEIVING,
                     expected_state=Request.State.PROCESSING,
                     ticks_in_state=randGamma(self.mean_ticks),
                     server_state=server_state)
    self.receive_queue.append(request)

  def reset(self):
    if hasattr(self, 'mean_ticks'):
      self.test_manager.Log(self, f'resetting from {self.mean_ticks} RTT')
    self.mean_ticks = (self.server.mean_rtt + randGamma(self.test_manager.mean_channel_rtt, variance=self.test_manager.channel_rtt_variance))/2
    self.test_manager.Log(self, f'resetting to {self.mean_ticks} RTT')

  def tick(self):
    if random.random() < self.test_manager.network_reset_probability:
      self.reset()
    self.test_manager.Log(self, f'ticking with {len(self.send_queue)} in send queue, {len(self.receive_queue)} in recieve queue, {self.requests_in_flight} in flight')
    for request in self.send_queue:
      request.tick()
      if request.isFinishedInState():
        self.server.enqueue(request, self)
    self.send_queue = [request for request in self.send_queue if request.state == Request.State.SENDING]

    for request in self.receive_queue:
      request.tick()
      if request.isFinishedInState():
        self.client.complete(request)
        self.completed = self.completed.insert(request)
    self.receive_queue = [request for request in self.receive_queue if request.state == Request.State.RECEIVING]


@dataclass
class Client:
  """Sends requests. Must make decision to balance load based on incomplete information."""
  name: str
  test_manager: TestManager = field(repr=False)
  channels: list[Channel]
  pending_requests: list[tuple[Request, 'ChannelPicker']]

  def __init__(self, name: str, test_manager: TestManager, servers: list[Server]):
    self.name = name
    self.test_manager = test_manager
    if test_manager.subset_size is not None:
      servers = random.choices(servers, k=test_manager.subset_size)
    self.channels = [Channel(self, server, self.test_manager) for server in servers]
    self.pending_requests = []

  def requests_in_flight(self):
    return sum(channel.requests_in_flight for channel in self.channels)

  def send(self, request: Request):
    channel_picker = self.test_manager.channel_picker_factory(request, self.name, self.channels, self.test_manager)
    self.pending_requests.append((request, channel_picker))

  def complete(self, request: Request):
    request.setState(self, Request.State.COMPLETED, expected_state=Request.State.RECEIVING)
    self.test_manager.RequestCompleted(request)
    self.test_manager.Log(self, f'!!!request completed {request}')

  def tick(self):
    for request, channel_picker in self.pending_requests:
      request.tick()
      self.test_manager.Log(self, f'tick {request}')
      channel = next(channel_picker)
      if channel is not None:
        channel.send(request)
    self.pending_requests = [(request, channel_picker)
                             for request, channel_picker in self.pending_requests
                             if request.state == Request.State.PICKING]

    for channel in self.channels:
      channel.tick()

## Channel pickers

In [None]:
# todo: look at both utilization and latency together.
# todo: pid based balancer. I think this should be on the server because the client doesn't have
# enough info to know whether it was the cause of a utilization change. The target should be the
# mean utilisation across all servers because if the target is an upper bound we might
# assign too much weight during periods of low traffic. This is a bir annoying as it involves
# inter-server cooperation or possibly peoxying average utilisation data through the clients,
# with the extra complexity and latency that would bring.

# yield None until a channel has been picked, then yield the channel.
ChannelPicker = Generator[Channel | None, None, None]
ChannelPickerFactory = Callable[[Request, str, list[Channel], TestManager], ChannelPicker]
Aggregator = Callable[[LastN[T], Callable[[T], float]], T]
FloatAggregator = Callable[[LastNFloats], float]
# higher is better, 0 is best. None if no data is available.
Cost = float | None
ClientCostFn = Callable[[Channel], Cost]
ServerCostFn = Callable[[ServerState], Cost]

def shuffle(xs: list[T]) -> list[T]:
  return random.sample(xs, k=len(xs))

# TODO: https://engineering.atspotify.com/2015/12/els-part-2/ suggests cost of
# (success latency + (failure latency + 800) * (1 / success rate – 1)) * (in flight + 1)
# However we'd need to implement success/failure first. And I'm a bit unsure about this algorithm
# as it seems to assume that servers are single-threaded.

def PickFirst(request: Request, client: str, channels: list[Channel], test_manager: TestManager) -> ChannelPicker:
  yield channels[0]

def PickRandom(request: Request, client: str, channels: list[Channel], test_manager: TestManager) -> ChannelPicker:
  yield random.choice(channels)

class RoundRobin: # ChannelPickerFactory
  next_channel: dict[str, int] = defaultdict(int)
  def __call__(self, request: Request, client: str, channels: list[Channel], test_manager: TestManager) -> ChannelPicker:
    i = self.next_channel[client]
    self.next_channel[client] += 1
    yield channels[i % len(channels)]

@dataclass(init=True, frozen=True)
class Subset: # ChannelPickerFactory
  channel_picker_factory: ChannelPickerFactory
  subset_size: int=2

  def __call__(self, request: Request, client: str, channels: list[Channel], test_manager: TestManager) -> ChannelPicker:
    return self.channel_picker_factory(request, client, random.choices(channels, k=self.subset_size), test_manager)

def AlwaysTrue(probe: Request) -> bool:
  return True

@dataclass(init=True, frozen=True)
class FastestProbe: # ChannelPickerFactory
  validate: Callable[[Request], bool] = AlwaysTrue

  def __call__(self, request: Request, client: str, channels: list[Channel], test_manager: TestManager) -> ChannelPicker:
    # Note: we don't bother cancelling probes because they spend so little time being processed (the only part that can be cancelled)
    test_manager.Log(request, "starting probes")
    probes = [(Request(test_manager, is_probe=True), channel) for channel in shuffle(channels)]
    for probe, channel in probes:
      channel.send(probe)

    while True:
      for probe, channel in probes:
        if probe.state == Request.State.COMPLETED and self.validate(probe):
          other_options = ", ".join(c.server.name for c in channels)
          test_manager.Log(channel, f"first probe {probe.name} completed in {sum(probe.ticks_in_state.values())}, picking channel to server {channel.server.name} (all options: {','.join(c.server.name for c in channels)})")
          yield channel
          break
      if all(probe.state == Request.State.COMPLETED for probe, _ in probes):
        latencies = {sum(probe.ticks_in_state.values()): channel for probe, channel in probes}
        latency, channel = min(latencies.items(), key=lambda x: x[0])
        test_manager.Log(channel, f"no probe was valid, picking channel to server {channel.server.name} based on latency (all latencies: {latencies})")
        return channel
      test_manager.Log(request, "no probe completed, yielding")
      yield

def IsServerEmpty(probe: Request) -> bool:
  return probe.server_state.requests_in_queue == 0

FastestProbeToEmptyServer = FastestProbe(IsServerEmpty)

@dataclass(init=True, frozen=True)
class WithProbes: # ChannelPickerFactory
  channel_picker_factory: ChannelPickerFactory
  probes_per_request: float=2.0

  def __call__(self, request: Request, client: str, channels: list[Channel], test_manager: TestManager) -> ChannelPicker:
    # TODO: consider having a requirement that the last probe or request was sent or completed >N ticks ago
    empty_channels = [channel for channel in channels if channel.requests_in_flight == 0]
    for channel in random.sample(channels, k=min(len(empty_channels), RandomRound(self.probes_per_request))):
      channel.send(Request(test_manager, is_probe=True))
    return self.channel_picker_factory(request, client, channels, test_manager)

@dataclass(init=True, frozen=True)
class LeastCost: # ChannelPickerFactory
  cost_fn: ClientCostFn
  choices: int | None = None

  def cost(self, channel: Channel) -> float:
    return max(0, self.cost_fn(channel) or 0.0)

  def __call__(self, request: Request, client: str, channels: list[Channel], test_manager: TestManager) -> ChannelPicker:
    if self.choices is not None:
      channels = random.sample(channels, k=min(self.choices, len(channels)))
    costs = {channel.name: self.cost_fn(channel) for channel in channels}
    channel = min(shuffle(channels), key=self.cost)
    test_manager.Log(client, f"picked {channel.name} with cost {self.cost_fn(channel)} based on options: {costs} and possibly based on data from {None if channel.completed.is_empty() else channel.completed.latest(lambda x: x.tick_at_state_start[Request.State.RECEIVING])}")
    yield channel

# https://tlk-energy.de/blog-en/practical-pid-tuning-guide
# P: just a constant since the weight is going to be compared to other weights
kp = 1
# I: P/Ti where Ti is the integration time - how long it takes for the system to asymptotically meet the target.
#    In our case, it isn't a big deal if the value doesn't quite hit the target so we can use a large time.
kti = 30
ki = kp/kti
# D: P*Td, where Td is the derivative time. The recommended starting derivative time is Ti/10.
ktd = kti/10
kd = kp * ktd

def PidUtilizationWeighted(request: Request, client: str, channels: list[Channel], test_manager: TestManager) -> ChannelPicker:
  utilizations = {}
  for channel in channels:
    if not hasattr(channel, 'pid_controller'):
      channel.pid_controller = PidController(kp, ki, kd)
      channel.pid_weight = 0.0
    latest_request = channel.completed.latest(id)
    if latest_request is not None:
      server_state = latest_request.server_state
      if server_state is not None:
        utilization = server_state.utilization
        if not utilization.is_empty():
          utilizations[channel.name] = utilization.latest()
  if not utilizations:
    test_manager.Log(client, f"no valid choices, picking random channel")
    yield random.choice(channels)
  average_utilization = sum(utilizations.values()) / len(utilizations)
  target_utilization = average_utilization + 10 if average_utilization < 80 else average_utilization
  weights = []
  debug_data = [["weights", "utilization", "previous_weight", "adjustment", "time of latest data"]]
  for channel in channels:
    utilization = utilizations.get(channel.name, average_utilization)
    pid_adjustment = channel.pid_controller.update(utilization, target_utilization)
    channel.pid_weight += pid_adjustment
    weights.append(channel.pid_weight)
    debug_data.append([channel.pid_weight,
                       utilization,
                       channel.pid_weight - pid_adjustment,
                       pid_adjustment,
                       None if channel.completed.is_empty() else channel.completed.latest(lambda x: x.tick_at_state_start[Request.State.RECEIVING])])
  debug_data = "\n".join(",".join(str(x) for x in l) for l in debug_data)
  min_weight = min(weights)
  weights = [weight - min_weight + 0.1 for weight in weights]
  # average_weight = sum(weights) / len(weights)
  # normalised_weights = [weight / average_weight / len(weights) for weight in weights]
  channel = random.choices(channels, weights=weights, k=1)[0]
  test_manager.Log(channel, f"picked {channel.server.name} with weight {weights[channels.index(channel)]} based on \n{debug_data}")
  yield channel


@dataclass(init=True, frozen=True)
class Weighted: # ChannelPickerFactory
  cost_fn: ClientCostFn
  allow_negative_cost: bool=False

  @classmethod
  def average(cls, maybe_costs: list[float | None]) -> float | None:
    costs: list[float] = [c for c in maybe_costs if c]
    if not costs:
      return None
    return sum(costs) / len(costs)

  def cost(self, channel: Channel) -> float | None:
    cost = self.cost_fn(channel)
    if cost is None:
      return None
    if self.allow_negative_cost:
      return cost
    return max(0, cost)

  def costs(self, channels: list[Channel]) -> list[float | None]:
    return [self.cost(channel) for channel in channels]

  def __call__(self, request: Request, client: str, channels: list[Channel], test_manager: TestManager) -> ChannelPicker:
    maybe_costs = self.costs(channels)
    average_cost = self.average(maybe_costs)
    if not average_cost:
      test_manager.Log(channels[0].client, f"no valid choices, got {maybe_costs} costs for servers {[channel.server.name for channel in channels]}")
      yield random.choice(channels)
    costs: list[float] = [cost or average_cost for cost in maybe_costs]
    weights = [1/cost for cost in costs]
    if weights and self.allow_negative_cost and min(weights) <= 0:
      # TODO: should we perform this normalisation on weights or costs?
      min_weight = min(weights)
      weights = [weight - min_weight + 0.1 for weight in weights]
    average_weight = sum(weights) / len(weights)
    weights = [weight / average_weight / len(weights) for weight in weights]
    debug_data = [["weights", "costs", "original_costs", "original_weights", "server", "time of latest data"]] + list(zip(
        weights,
        costs,
        maybe_costs,
        (1/c if c else None for c in maybe_costs),
          (channel.server.name for channel in channels),
          (None if channel.completed.is_empty() else channel.completed.latest(lambda x: x.tick_at_state_start[Request.State.RECEIVING]) for channel in channels)))
    debug_data = "\n".join(",".join(str(x) for x in l) for l in debug_data)

    if sum(weights) < 0.9 or sum(weights) > 1.1:
      raise ValueError(f"unexpected sum of weights: {sum(weights)}. All data: \n{debug_data}")
    channel = random.choices(channels, weights=weights, k=1)[0]
    test_manager.Log(channel, f"picked {channel.server.name} with weight {weights[channels.index(channel)]} based on \n{debug_data}")
    yield channel

def Aggregate(method: str) -> Aggregator:
  return {
      'mean': LastN.mean,
      'latest': LastN.latest,
      'ewma': LastN.ewma,
      }[method]

def AggregateFloat(method: str) -> FloatAggregator:
  return {
      'mean': LastNFloats.mean,
      'latest': LastNFloats.latest,
      'ewma': LastNFloats.ewma,
      }[method]

@dataclass(init=True, frozen=True)
class ServerReportedCost: # ClientCostFn
  cost_fn: ServerCostFn

  def __call__(self, channel: Channel) -> Cost:
    latest_request = channel.completed.latest(id)
    if latest_request is None:
      return None
    server_state = latest_request.server_state
    if server_state is None:
      return None
    return self.cost_fn(server_state)

@dataclass(init=True, frozen=True)
class ServerQuantumTunneledCost: # ClientCostFn
  # Get latest data from the server, breaking the laws of physics as we know it
  cost_fn: ServerCostFn

  def __call__(self, channel: Channel) -> Cost:
    server_state = channel.server.server_state
    return self.cost_fn(server_state)

def ServerRequestsInFlightCost(server_state: ServerState) -> Cost:
  return server_state.requests_in_flight.latest()

def ServerExpectedLatencyCost(server_state: ServerState) -> Cost:
  return server_state.expectedLatency()

@dataclass(init=True, frozen=True)
class ServerRequestRateCost: # ServerCostFn
  aggregator: FloatAggregator

  def __call__(self, server_state: ServerState) -> Cost:
    return self.aggregator(server_state.queries_per_tick)

@dataclass(init=True, frozen=True)
class ServerUtilizationCost: # ServerCostFn
  # Another reasonably common option for WRR algorithms.
  aggregator: FloatAggregator

  def __call__(self, server_state: ServerState) -> Cost:
    return self.aggregator(server_state.utilization)


@dataclass(init=True, frozen=True)
class ServerCapacityCost: # ServerCostFn
  # Traditionally used for WRR algorithms.
  request_rate_aggregator: FloatAggregator
  utilization_aggregator: FloatAggregator

  def __call__(self, server_state: ServerState) -> Cost:
    request_rate = ServerRequestRateCost(self.request_rate_aggregator)(server_state)
    utilization = ServerUtilizationCost(self.utilization_aggregator)(server_state)
    if not request_rate or not utilization:
      return None
    capacity = request_rate / (utilization / 100)
    if capacity <= 0:
      return None
    return 1 / capacity

@dataclass(init=True, frozen=True)
class ServerAvailableCapacityCost: # ServerCostFn
  # My suggestion for an improvement using the same data used by most WRR algorithms.
  request_rate_aggregator: FloatAggregator
  utilization_aggregator: FloatAggregator

  def __call__(self, server_state: ServerState) -> Cost:
    request_rate = ServerRequestRateCost(self.request_rate_aggregator)(server_state)
    utilization = ServerUtilizationCost(self.utilization_aggregator)(server_state)
    if not request_rate or not utilization:
      return None
    capacity = request_rate / (utilization / 100)
    if capacity <= 0:
      return None
    available_capacity = capacity - request_rate
    if available_capacity <= 0:
      return None
    return 1/available_capacity

def ServerPidCost(server_state: ServerState) -> Cost: # ServerCostFn
  weight = server_state.pid_weight
  # Hack around the fact that we have to convert to cost, which means dividing by a value that could be zero
  epsilon  = 0.0001
  if weight > -0.0001 and weight < 0.0001:
    if weight >= 0:
      weight = epsilon
    else:
      weight = -epsilon
  return 1 / weight

CheatServerPidFactory = Weighted(ServerQuantumTunneledCost(ServerPidCost), allow_negative_cost=True)

@dataclass(init=True, frozen=True)
class ServerLatencyCost: # ServerCostFn
  aggregator: FloatAggregator

  def __call__(self, server_state: ServerState) -> Cost:
    return self.aggregator(server_state.latencies)

def ClientRequestsInFlightCost(channel: Channel) -> Cost:
  return channel.requests_in_flight

@dataclass(init=True, frozen=True)
class ClientLatencyCost: # ClientCostFn
  aggregator: Aggregator

  def __call__(self, channel: Channel) -> Cost:
    if channel.completed.is_empty():
      return None
    return self.aggregator(channel.completed, lambda request: sum(request.ticks_in_state.values()))


# hot/cold lexicographical selection, as described by PLB/PreQual
# This implements most of PreQual as I understand it, although missing the exact probing mechanism and outlier detection.
@dataclass(init=True)
class Hcl: # ChannelPickerFactory
  ranking_fn: ClientCostFn = ServerReportedCost(ServerExpectedLatencyCost)
  health_fn: ClientCostFn =  ServerReportedCost(ServerRequestsInFlightCost)
  health_cutoff_pct: int = 80
  health_history: dict[str, LastN] = field(default_factory=dict, repr=False)

  @classmethod
  def average(cls, maybe_costs: list[float | None]) -> float | None:
    costs: list[float] = [c for c in maybe_costs if c]
    if not costs:
      return None
    return sum(costs) / len(costs)

  def __call__(self, request: Request, client: str, channels: list[Channel], test_manager: TestManager) -> ChannelPicker:
    maybe_health = [self.health_fn(channel) for channel in channels]
    average_health = self.average(maybe_health)
    if average_health is None:
      test_manager.Log(client, f"no valid choices due to missing health, picking random channel")
      yield random.choice(channels)
    healths = [average_health if health is None else health for health in maybe_health]
    historical_health = []
    for channel, health in zip(channels, maybe_health):
      if channel.name not in self.health_history:
        # Ideally we should rate-limit additions to this to prevent adding to it on every request. But it's good enough.
        self.health_history[channel.name] = LastN(test_manager.client_requests_buffer)
      if health is not None:
        self.health_history[channel.name] = self.health_history[channel.name].insert(health)
      historical_health.extend(self.health_history[channel.name])
    if not historical_health:
      test_manager.Log(client, f"no valid choices due to missing historical health, picking random channel")
      yield random.choice(channels)
    health_cutoff = np.percentile(historical_health, self.health_cutoff_pct)

    maybe_ranks = [self.ranking_fn(channel) for channel in channels]
    average_rank = self.average(maybe_ranks)
    if average_rank is None:
      test_manager.Log(client, f"no valid choices due to missing ranks, picking random channel")
      yield random.choice(channels)
    ranks = [average_rank if rank is None else rank for rank in maybe_ranks]
    healthy_ranks = [(channel, rank, health) for channel, rank, health in zip(channels, ranks, healths) if health <= health_cutoff]
    if healthy_ranks:
      best_channel, rank, health = min(healthy_ranks, key=lambda t: t[1])
      test_manager.Log(client, f"picking channel {best_channel.name} with rank {rank} and health {health} based on selection of {len(healthy_ranks)} healthy servers")
      yield best_channel
    else:
      best_channel, health = min(zip(channels, healths), key=lambda t: t[1])
      health_dict = dict(list(zip((c.name for c in channels), healths)))
      test_manager.Log(client, f"picking channel {best_channel.name} with health {health} due to no healthy servers (cutoff {health_cutoff}, healths: {health_dict})")
      yield best_channel


## Sample run

In [None]:
# raise Exception("Are you sure you want to do a sample run?")

In [None]:
qps_per_server = queries_for_utilization(TestManager.mean_threads,
                                        TestManager.mean_processing_time,
                                        probes_per_request=0,
                                        probe_processing_time=TestManager.probe_processing_time,
                                        target_utilization = 0.8)
qps_per_server

13.333333333333334

In [None]:
test_manager = TestManager(total_ticks=80, debug=True,
                           data_frequency=1,
                           mean_requests_per_server_per_tick=qps_per_server,
                 channel_picker_factory=WithProbes(Hcl(
                                                    ranking_fn=ServerQuantumTunneledCost(ServerExpectedLatencyCost),
                                                    health_fn=ServerQuantumTunneledCost(ServerRequestsInFlightCost)))) #WithProbes(LeastCost(ServerReportedCost(ServerRequestsInFlightCost)), probes_per_request=3))#Subset(FastestProbe(), subset_size=3))
df = test_manager.runTest("demo")
df

demo:   0%|          | 0/80 [00:00<?, ?it/s]

Unnamed: 0,test_stage,client,server,is_probe,started_on_tick,completed_on_tick,tick_at_send,tick_at_server_queue,tick_at_processing,tick_at_receive,...,server_queries_per_tick_latest,server_queries_per_tick_mean,server_queries_per_tick_ewma,server_requests_in_flight_latest,server_requests_in_flight_mean,server_requests_in_flight_ewma,server_queued_requests_latest,server_queued_requests_mean,server_queued_requests_ewma,server_pid_weight
0,WARMING,Client 0,Server 11,True,0,1,0,0,1,1,...,0.0,0.000000,0.000000,0.0,0.000000,0.000000,0.0,0.000000,0.000000,60.200000
1,WARMING,Client 0,Server 11,True,0,1,0,0,1,1,...,0.0,0.000000,0.000000,0.0,0.000000,0.000000,0.0,0.000000,0.000000,60.200000
2,WARMING,Client 0,Server 11,True,0,1,0,0,1,1,...,0.0,0.000000,0.000000,0.0,0.000000,0.000000,0.0,0.000000,0.000000,60.200000
3,WARMING,Client 0,Server 16,True,0,1,0,0,1,1,...,0.0,0.000000,0.000000,0.0,0.000000,0.000000,0.0,0.000000,0.000000,60.200000
4,WARMING,Client 0,Server 16,True,0,1,0,0,1,1,...,0.0,0.000000,0.000000,0.0,0.000000,0.000000,0.0,0.000000,0.000000,60.200000
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
69165,COOLING,Client 12,Server 33,False,77,105,77,81,91,100,...,0.0,22.390000,1.598423,12.0,139.850000,72.970853,0.0,70.160000,22.753278,-2943.435409
69166,COOLING,Client 19,Server 33,False,76,105,76,82,92,99,...,0.0,22.616162,1.998029,21.0,141.141414,88.213566,0.0,70.868687,28.441598,-2970.938909
69167,COOLING,Client 32,Server 33,False,78,105,78,81,92,102,...,0.0,22.310000,1.022991,5.0,139.920000,48.501346,0.0,70.160000,14.562098,-2844.576429
69168,COOLING,Client 35,Server 31,False,74,105,74,75,89,105,...,0.0,14.000000,0.110063,2.0,102.740000,21.348027,0.0,46.300000,3.579125,-3280.747301


In [None]:
latency = df[(~df['is_probe']) & (df['test_stage'] != 'WARMING')]['total_ticks']
utilization = df[df['test_stage'] == 'RUNNING'][['tick_at_receive', 'server', 'server_utilization_latest']].groupby(['server_utilization_latest'])['server_utilization_latest'].mean()

stddev_utilization_per_server = df[df['test_stage'] == 'RUNNING'][['server', 'server_utilization_latest']].groupby(['server']).std()['server_utilization_latest']
cooling = df[df['test_stage'] == 'COOLING']

iles = [0.5, 0.8, 0.9, 0.95, 0.99, 0.995, 0.999, 1.0]
# TODO: work out cooldown period calculation, bearing in mind that we have multiple runs,
# and the per-request data only includes the test stage at completion.
summary = {
    # 'cooldown_period': cooling['completed_on_tick'].max() - cooling['started_on_tick'].min(),
    'mean_latency': latency.mean(),
    'stdev_latency': latency.std(),
  }
summary |= {f"p{100*p}_latency": latency.quantile(p) for p in iles}
summary |= {
      'mean_utilization': utilization.mean(),
      'stdev_utilization': utilization.std(),
      'mean_stdev_utilization': stddev_utilization_per_server.mean(),
  }
summary |= {f"p{100*p}_utilization_ratio": utilization.quantile(p) / utilization.mean() for p in iles}
summary

{'mean_latency': 12.10755355885607,
 'stdev_latency': 5.887501659648994,
 'p50.0_latency': 11.0,
 'p80.0_latency': 16.0,
 'p90.0_latency': 19.0,
 'p95.0_latency': 23.0,
 'p99.0_latency': 31.0,
 'p99.5_latency': 37.0,
 'p99.9_latency': 52.0,
 'p100.0_latency': 61.0,
 'mean_utilization': 40.744155382588374,
 'stdev_utilization': 27.042215496548028,
 'mean_stdev_utilization': 27.694444039138908,
 'p50.0_utilization_ratio': 0.8662375576483647,
 'p80.0_utilization_ratio': 1.6684106226540136,
 'p90.0_utilization_ratio': 2.0491041621416817,
 'p95.0_utilization_ratio': 2.2320967781810515,
 'p99.0_utilization_ratio': 2.3952410254685477,
 'p99.5_utilization_ratio': 2.42853600591098,
 'p99.9_utilization_ratio': 2.4384073378063884,
 'p100.0_utilization_ratio': 2.4543397466703665}

In [None]:
# [f"{t} {n}: {m}" for t, n, m in test_manager.logs if 'zaqra' in n or 'zaqra' in m]

In [None]:
probes = df[df['is_probe']]
probes[[col for col in df.columns if "_ticks" in col]].mean()

Unnamed: 0,0
total_ticks,6.609688
picking_ticks,0.0
send_ticks,2.110084
server_queue_ticks,1.40544
processing_ticks,1.0
receive_ticks,2.094164


In [None]:
non_probes = df[~df['is_probe']]
non_probes[[col for col in df.columns if "_ticks" in col]].mean()

Unnamed: 0,0
total_ticks,11.316855
picking_ticks,1.0
send_ticks,2.139254
server_queue_ticks,1.631484
processing_ticks,4.386096
receive_ticks,2.160022


In [None]:
def chart(df, y: str, aggregation: str, x="tick_at_receive", grouping="server", probes=None, p_chart=px.line, **kwargs):
  if probes is not None:
    if probes:
      df = df[~df['is_probe']]
    else:
      df = df[df['is_probe']]
  grouping_list = [grouping] if grouping else []
  df = df[grouping_list+[x, y]].groupby(grouping_list+[x]).aggregate(aggregation).reset_index()
  return p_chart(df, x=x, y=y, color=grouping, height=300, **kwargs)

def qps_chart(df, x="tick_at_send", grouping="server", probes=None, p_chart=px.area, **kwargs):
  return chart(df, y='total_ticks', aggregation='count', x=x, grouping=grouping, probes=probes, p_chart=p_chart, **kwargs)

chart(df, y='server_utilization_ewma', aggregation='mean', range_y=(0,100)).show()
chart(df, y='server_utilization_latest', aggregation='mean', grouping='test_stage', range_y=(0,100)).show()
chart(df, y='server_queued_requests_latest', aggregation='mean').show()

# chart(sdf, y='server_utilization_latest', aggregation='mean', grouping='server').show()
# chart(sdf, y='server_queries_per_tick_latest', aggregation='mean', grouping='server').show()
# chart(sdf, y='server_requests_in_flight_latest', aggregation='mean', grouping='server').show()
# chart(sdf, y='server_queued_requests_latest', aggregation='mean', grouping='server').show()

In [None]:
sdf = df[(df['server'] == 'Server 15')]
chart(sdf, y='server_pid_weight', aggregation='mean').show()
chart(sdf, y='server_utilization_latest', aggregation='mean', range_y=(0,100)).show()
chart(sdf, y='server_queued_requests_latest', aggregation='mean').show()
qps_chart(sdf, x='tick_at_send', grouping='is_probe').show()

In [None]:
# print("\n".join([f"{t} {n}: {m}" for t, n, m in test_manager.logs if ("picked Server 15" in m) and "pick" in m and t > 55 and t < 75][:50]))
# print("\n".join([f"{t} {n}: {m}" for t, n, m in test_manager.logs if ("Server 26" in n or "Server 25" in n) and "PID" in m and t >= 55 and t < 75]))

print("\n".join([f"{t} {n}: {m}" for t, n, m in test_manager.logs if (n == 'Server 15') and t > 48 and t < 52][:50]))

49 Server 15: Request tk9v4 from Client 3 updated from State.SERVER_QUEUE to State.PROCESSING with 1 ticks left
49 Server 15: Request 0i81a from Client 3 updated from State.SERVER_QUEUE to State.PROCESSING with 1 ticks left
49 Server 15: Request i36oi from Client 3 updated from State.SERVER_QUEUE to State.PROCESSING with 1 ticks left
49 Server 15: Request t8s9f from Client 3 updated from State.SERVER_QUEUE to State.PROCESSING with 1 ticks left
49 Server 15: Request 8scpj from Client 3 updated from State.SERVER_QUEUE to State.PROCESSING with 1 ticks left
49 Server 15: Request in44b from Client 5 updated from State.SERVER_QUEUE to State.PROCESSING with 1 ticks left
49 Server 15: Request 08jzb from Client 15 updated from State.SERVER_QUEUE to State.PROCESSING with 1 ticks left
49 Server 15: Request lyoqb from Client 18 updated from State.SERVER_QUEUE to State.PROCESSING with 1 ticks left
49 Server 15: Request mkd11 from Client 18 updated from State.SERVER_QUEUE to State.PROCESSING with 1 

In [None]:
chart(df, y='server_queue_ticks', x='tick_at_server_queue', aggregation='mean', probes=False, grouping='server').show()
chart(df, y='server_queue_ticks', x='tick_at_processing', aggregation='mean', probes=False, grouping='server').show()
chart(df, y='server_queue_ticks', x='tick_at_receive', aggregation='mean', probes=False, grouping='server').show()

In [None]:
def latency_chart(df, x="tick_at_send", probes=None, p_chart=px.area, **kwargs):
  per_stage = pd.melt(df,
          value_vars=['picking_ticks', 'send_ticks', 'server_queue_ticks', 'processing_ticks', 'receive_ticks'],
          var_name='stage',
          value_name= 'ticks_in_stage',
          id_vars=['is_probe', 'server', 'client', x]).reset_index()
  return chart(per_stage, grouping='stage', probes=probes, y='ticks_in_stage', x=x, aggregation='mean', p_chart=px.area)

latency_chart(df, probes=False, x="started_on_tick").show()
# latency_chart(df, probes=False, x="tick_at_send").show()
# latency_chart(df, probes=False, x="tick_at_server_queue").show()
# latency_chart(df, probes=False, x="tick_at_processing").show()
# latency_chart(df, probes=False, x="tick_at_receive").show()
# latency_chart(df, probes=False, x="completed_on_tick").show()
chart(df, probes=False, grouping=None, y='total_ticks', x='tick_at_send', aggregation='mean').show()

## Real run

In [None]:
# raise Exception("Are you sure you want to continue with the real run?")

In [None]:
# TODO: scenarios to test:
# standard
# negligible server processing latency
# negligible network latency
# big spread in server network latency
# many low-qps clients/ (low enough that client-measured RIF is 0)
# low utilization
# high utilization
# very long running test

In [200]:
TOTAL_TICKS = 170
TEST_RUNS = 6
TARGET_UTILIZATION = 0.7 # Actual utilization tends to be lower, as some requests will be built up in a queue and won't complete by the time the test finishes
request_rate = queries_for_utilization(TestManager.mean_threads,
                                       TestManager.mean_processing_time,
                                       target_utilization = TARGET_UTILIZATION)
scenarios =  {
    'Random': {'total_ticks':TOTAL_TICKS,
                   'mean_requests_per_server_per_tick': request_rate,
                   'channel_picker_factory': PickRandom},
    'RoundRobin': {'total_ticks':TOTAL_TICKS,
                   'mean_requests_per_server_per_tick': request_rate,
                   'channel_picker_factory': RoundRobin()},
    'RoundRobinLittleBigJobs': {'total_ticks':TOTAL_TICKS,
                   'mean_requests_per_server_per_tick': request_rate * 4,
                   'channel_picker_factory': RoundRobin(),
                    'servers': int(math.ceil(TestManager.servers / 4)),
                    'mean_threads': TestManager.mean_threads * 4,
                      },
    'LeastMeanClientLatency': {'total_ticks':TOTAL_TICKS,
                               'mean_requests_per_server_per_tick': request_rate,
                               'channel_picker_factory': LeastCost(ClientLatencyCost(Aggregate('mean')))},

    'LeastServerRequestsInFlight': {'total_ticks':TOTAL_TICKS,
                                    'mean_requests_per_server_per_tick': request_rate,
                                    'channel_picker_factory': LeastCost(ServerReportedCost(ServerRequestsInFlightCost))},
    'WeightedServerRequestsInFlight': {'total_ticks':TOTAL_TICKS,
                                    'mean_requests_per_server_per_tick': request_rate,
                                    'channel_picker_factory': LeastCost(ServerReportedCost(ServerRequestsInFlightCost))},

    'LeastServerExpectedLatency': {'total_ticks':TOTAL_TICKS,
                                    'mean_requests_per_server_per_tick': request_rate,
                                    'channel_picker_factory': LeastCost(ServerReportedCost(ServerExpectedLatencyCost))},

    'WeightedServerExpectedLatency': {'total_ticks':TOTAL_TICKS,
                                      'mean_requests_per_server_per_tick': request_rate,
                                      'channel_picker_factory': Weighted(ServerReportedCost(ServerExpectedLatencyCost))},


    'ProbingLeastServerRequestsInFlight': {'total_ticks':TOTAL_TICKS,
                                    'mean_requests_per_server_per_tick': request_rate,
                                    'channel_picker_factory': WithProbes(LeastCost(ServerReportedCost(ServerRequestsInFlightCost)))},
    'ProbingWeightedServerRequestsInFlight': {'total_ticks':TOTAL_TICKS,
                                    'mean_requests_per_server_per_tick': request_rate,
                                    'channel_picker_factory': WithProbes(LeastCost(ServerReportedCost(ServerRequestsInFlightCost)))},

    'ProbingLeastServerExpectedLatency': {'total_ticks':TOTAL_TICKS,
                                    'mean_requests_per_server_per_tick': request_rate,
                                    'channel_picker_factory': WithProbes(LeastCost(ServerReportedCost(ServerExpectedLatencyCost)))},

    'ProbingWeightedServerExpectedLatency': {'total_ticks':TOTAL_TICKS,
                                      'mean_requests_per_server_per_tick': request_rate,
                                      'channel_picker_factory': WithProbes(Weighted(ServerReportedCost(ServerExpectedLatencyCost)))},


    'ProbingLeastServerRequestsInFlightLowerRequestRate': {'total_ticks':TOTAL_TICKS,
                                    'mean_requests_per_server_per_tick': queries_for_utilization(TestManager.mean_threads,
                                                                                TestManager.mean_processing_time,
                                                                                probes_per_request=6, # Not super-accurate because the probe frequency isn't trivially predictable
                                                                                probe_processing_time=TestManager.probe_processing_time,
                                                                                target_utilization = TARGET_UTILIZATION),
                                    'channel_picker_factory': WithProbes(LeastCost(ServerReportedCost(ServerRequestsInFlightCost)))},
    'ProbingWeightedServerRequestsInFlightLowerRequestRate': {'total_ticks':TOTAL_TICKS,
                                    'mean_requests_per_server_per_tick': queries_for_utilization(TestManager.mean_threads,
                                                                                TestManager.mean_processing_time,
                                                                                probes_per_request=6, # Not super-accurate because the probe frequency isn't trivially predictable
                                                                                probe_processing_time=TestManager.probe_processing_time,
                                                                                target_utilization = TARGET_UTILIZATION),
                                    'channel_picker_factory': WithProbes(LeastCost(ServerReportedCost(ServerRequestsInFlightCost)))},

    'ProbingLeastServerExpectedLatencyLowerRequestRate': {'total_ticks':TOTAL_TICKS,
                                    'mean_requests_per_server_per_tick': queries_for_utilization(TestManager.mean_threads,
                                                                                TestManager.mean_processing_time,
                                                                                probes_per_request=6, # Not super-accurate because the probe frequency isn't trivially predictable
                                                                                probe_processing_time=TestManager.probe_processing_time,
                                                                                target_utilization = TARGET_UTILIZATION),
                                    'channel_picker_factory': WithProbes(LeastCost(ServerReportedCost(ServerExpectedLatencyCost)))},

    'ProbingWeightedServerExpectedLatencyLowerRequestRate': {'total_ticks':TOTAL_TICKS,
                                      'mean_requests_per_server_per_tick': queries_for_utilization(TestManager.mean_threads,
                                                                                TestManager.mean_processing_time,
                                                                                probes_per_request=6, # Not super-accurate because the probe frequency isn't trivially predictable
                                                                                probe_processing_time=TestManager.probe_processing_time,
                                                                                target_utilization = TARGET_UTILIZATION),
                                      'channel_picker_factory': WithProbes(Weighted(ServerReportedCost(ServerExpectedLatencyCost)))},



    'LeastClientRequestsInFlight': {'total_ticks':TOTAL_TICKS,
                                    'mean_requests_per_server_per_tick': request_rate,
                                    'channel_picker_factory': LeastCost(ClientRequestsInFlightCost)},
    'LeastClientRequestsInFlightPickTwo': {'total_ticks':TOTAL_TICKS,
                                    'mean_requests_per_server_per_tick': request_rate,
                                    'channel_picker_factory': LeastCost(ClientRequestsInFlightCost, choices=2)},

    'Hcl': {'total_ticks':TOTAL_TICKS,
                                    'mean_requests_per_server_per_tick': request_rate,
                                    'channel_picker_factory': Hcl(),
    },
    'ProbingHcl': {'total_ticks':TOTAL_TICKS,
                                    'mean_requests_per_server_per_tick': request_rate,
                                    'channel_picker_factory': WithProbes(Hcl()),
    },
    'ProbingHclLowerRequestRate': {'total_ticks':TOTAL_TICKS,
                                    'mean_requests_per_server_per_tick': queries_for_utilization(TestManager.mean_threads,
                                                                                TestManager.mean_processing_time,
                                                                                probes_per_request=6, # Not super-accurate because the probe frequency isn't trivially predictable
                                                                                probe_processing_time=TestManager.probe_processing_time,
                                                                                target_utilization = TARGET_UTILIZATION),
                                    'channel_picker_factory': WithProbes(Hcl()),
    },

    'ClientMetricsHcl': {'total_ticks':TOTAL_TICKS,
                                    'mean_requests_per_server_per_tick': request_rate,
                                    'channel_picker_factory': Hcl(health_fn=ClientRequestsInFlightCost,
                                                                ranking_fn=ClientLatencyCost(Aggregate('mean')))
    },
    'ProbingClientMetricsHcl': {'total_ticks':TOTAL_TICKS,
                                    'mean_requests_per_server_per_tick': request_rate,
                                    'channel_picker_factory': WithProbes(Hcl(health_fn=ClientRequestsInFlightCost,
                                                                ranking_fn=ClientLatencyCost(Aggregate('mean')))),
    },
    'ProbingClientMetricsHclLowerRequestRate': {'total_ticks':TOTAL_TICKS,
                                    'mean_requests_per_server_per_tick': queries_for_utilization(TestManager.mean_threads,
                                                                                TestManager.mean_processing_time,
                                                                                probes_per_request=6, # Not super-accurate because the probe frequency isn't trivially predictable
                                                                                probe_processing_time=TestManager.probe_processing_time,
                                                                                target_utilization = TARGET_UTILIZATION),
                                    'channel_picker_factory': WithProbes(Hcl(health_fn=ClientRequestsInFlightCost,
                                                                ranking_fn=ClientLatencyCost(Aggregate('mean')))),
    },

    'LeastClientRequestsInFlightPickThree': {'total_ticks':TOTAL_TICKS,
                                    'mean_requests_per_server_per_tick': request_rate,
                                    'channel_picker_factory': LeastCost(ClientRequestsInFlightCost, choices=3)},

    'WeightedMeanClientLatency': {'total_ticks':TOTAL_TICKS,
                                       'mean_requests_per_server_per_tick': request_rate,
                                  'channel_picker_factory': Weighted(ClientLatencyCost(Aggregate('mean')))},
    'WeightedClientRequestsInFlight': {'total_ticks':TOTAL_TICKS,
                                       'mean_requests_per_server_per_tick': request_rate,
                                       'channel_picker_factory': Weighted(ClientRequestsInFlightCost)},


    'WeightedMeanUtilization': {'total_ticks':TOTAL_TICKS,
                                'mean_requests_per_server_per_tick': request_rate,
                                'channel_picker_factory': Weighted(ServerReportedCost(ServerUtilizationCost(AggregateFloat('mean'))))},
    'WeightedMeanCapacity': {'total_ticks':TOTAL_TICKS,
                             'mean_requests_per_server_per_tick': request_rate,
                             'channel_picker_factory': Weighted(ServerReportedCost(ServerCapacityCost(AggregateFloat('mean'), AggregateFloat('mean'))))},
    'WeightedMeanAvailableCapacity': {'total_ticks': TOTAL_TICKS,
                                      'mean_requests_per_server_per_tick': request_rate,
                                      'channel_picker_factory': Weighted(ServerReportedCost(ServerAvailableCapacityCost(AggregateFloat('mean'), AggregateFloat('mean'))))},

    'CheatWeightedMeanUtilization': {'total_ticks':TOTAL_TICKS,
                                'mean_requests_per_server_per_tick': request_rate,
                                'channel_picker_factory': Weighted(ServerQuantumTunneledCost(ServerUtilizationCost(AggregateFloat('mean'))))},
    'CheatWeightedMeanCapacity': {'total_ticks':TOTAL_TICKS,
                             'mean_requests_per_server_per_tick': request_rate,
                             'channel_picker_factory': Weighted(ServerQuantumTunneledCost(ServerCapacityCost(AggregateFloat('mean'), AggregateFloat('mean'))))},
    'CheatWeightedMeanAvailableCapacity': {'total_ticks': TOTAL_TICKS,
                                      'mean_requests_per_server_per_tick': request_rate,
                                      'channel_picker_factory': Weighted(ServerQuantumTunneledCost(ServerAvailableCapacityCost(AggregateFloat('mean'), AggregateFloat('mean'))))},
    'ProbingWeightedMeanUtilization': {'total_ticks':TOTAL_TICKS,
                                'mean_requests_per_server_per_tick': request_rate,
                                'channel_picker_factory': WithProbes(Weighted(ServerReportedCost(ServerUtilizationCost(AggregateFloat('mean')))))},
    'ProbingWeightedMeanCapacity': {'total_ticks':TOTAL_TICKS,
                             'mean_requests_per_server_per_tick': request_rate,
                             'channel_picker_factory': WithProbes(Weighted(ServerReportedCost(ServerCapacityCost(AggregateFloat('mean'), AggregateFloat('mean')))))},
    'ProbingWeightedMeanAvailableCapacity': {'total_ticks': TOTAL_TICKS,
                                      'mean_requests_per_server_per_tick': request_rate,
                                      'channel_picker_factory': WithProbes(Weighted(ServerReportedCost(ServerAvailableCapacityCost(AggregateFloat('mean'), AggregateFloat('mean')))))},


    'ProbingWeightedMeanUtilizationLowerRequestRate': {'total_ticks':TOTAL_TICKS,
                                'mean_requests_per_server_per_tick': queries_for_utilization(TestManager.mean_threads,
                                                                                TestManager.mean_processing_time,
                                                                                probes_per_request=6, # Not super-accurate because the probe frequency isn't trivially predictable
                                                                                probe_processing_time=TestManager.probe_processing_time,
                                                                                target_utilization = TARGET_UTILIZATION),
                                'channel_picker_factory': WithProbes(Weighted(ServerReportedCost(ServerUtilizationCost(AggregateFloat('mean')))))},
    'ProbingWeightedMeanCapacityLowerRequestRate': {'total_ticks':TOTAL_TICKS,
                                'mean_requests_per_server_per_tick': queries_for_utilization(TestManager.mean_threads,
                                                                                TestManager.mean_processing_time,
                                                                                probes_per_request=6, # Not super-accurate because the probe frequency isn't trivially predictable
                                                                                probe_processing_time=TestManager.probe_processing_time,
                                                                                target_utilization = TARGET_UTILIZATION),
                             'channel_picker_factory': WithProbes(Weighted(ServerReportedCost(ServerCapacityCost(AggregateFloat('mean'), AggregateFloat('mean')))))},
    'ProbingWeightedMeanAvailableCapacityLowerRequestRate': {'total_ticks': TOTAL_TICKS,
                                'mean_requests_per_server_per_tick': queries_for_utilization(TestManager.mean_threads,
                                                                                TestManager.mean_processing_time,
                                                                                probes_per_request=6, # Not super-accurate because the probe frequency isn't trivially predictable
                                                                                probe_processing_time=TestManager.probe_processing_time,
                                                                                target_utilization = TARGET_UTILIZATION),
                                      'channel_picker_factory': WithProbes(Weighted(ServerReportedCost(ServerAvailableCapacityCost(AggregateFloat('mean'), AggregateFloat('mean')))))},


    'CheatPidControlledUtilization': {'total_ticks': TOTAL_TICKS,
                                      'mean_requests_per_server_per_tick': request_rate,
                                      'channel_picker_factory': CheatServerPidFactory},
    'ClientPidControlledUtilization': {'total_ticks': TOTAL_TICKS,
                                      'mean_requests_per_server_per_tick': request_rate,
                                      'channel_picker_factory': PidUtilizationWeighted},


    'FastestProbe2': {'total_ticks': TOTAL_TICKS,
                      'mean_requests_per_server_per_tick': request_rate,
                      'channel_picker_factory': Subset(FastestProbe(), subset_size=2)},
    'FastestProbe3': {'total_ticks': TOTAL_TICKS,
                      'mean_requests_per_server_per_tick': request_rate,
                     'channel_picker_factory': Subset(FastestProbe(), subset_size=3)},
    # Causes too much overload
    # 'FastestProbeAll': {'total_ticks': TOTAL_TICKS,
    #                     'mean_requests_per_server_per_tick': request_rate,
    #                    'channel_picker_factory': FastestProbe()},


    'FastestProbe2LowerRequestRate': {'total_ticks': TOTAL_TICKS,
                      'mean_requests_per_server_per_tick': queries_for_utilization(TestManager.mean_threads,
                                                                                TestManager.mean_processing_time,
                                                                                probes_per_request=2 * 55/63, # Fudge factor, not sure why this is necessary to keep utilization on target
                                                                                probe_processing_time=TestManager.probe_processing_time,
                                                                                target_utilization = TARGET_UTILIZATION),
                      'channel_picker_factory': Subset(FastestProbe(), subset_size=2)},
    'FastestProbe3LowerRequestRate': {'total_ticks': TOTAL_TICKS,
                      'mean_requests_per_server_per_tick': queries_for_utilization(TestManager.mean_threads,
                                                                                TestManager.mean_processing_time,
                                                                                probes_per_request=3 * 55/63, # Fudge factor, not sure why this is necessary to keep utilization on target
                                                                                probe_processing_time=TestManager.probe_processing_time,
                                                                                target_utilization = TARGET_UTILIZATION),
                     'channel_picker_factory': Subset(FastestProbe(), subset_size=3)},
    'FastestProbeAllLowerRequestRate': {'total_ticks': TOTAL_TICKS,
                        'mean_requests_per_server_per_tick': queries_for_utilization(TestManager.mean_threads,
                                                                                TestManager.mean_processing_time,
                                                                                probes_per_request=TestManager.subset_size * 55/63, # Fudge factor, not sure why this is necessary to keep utilization on target
                                                                                probe_processing_time=TestManager.probe_processing_time,
                                                                                target_utilization = TARGET_UTILIZATION),
                       'channel_picker_factory': FastestProbe()},
    }

In [201]:
# If we already have some results from a previous run, strip out any duplicate work
run_again = {'MyAlgorithm'}
try:
  previously_run_scenarios = {result['algorithm']: result.get('kwargs', None) for result in results}
  new_scenarios = {}
  for algorithm, kwargs in scenarios.items():
    if algorithm not in previously_run_scenarios.keys():
      print(f'{algorithm} not seen before')
      new_scenarios[algorithm] = kwargs
    elif previously_run_scenarios[algorithm] is not None and previously_run_scenarios[algorithm] != kwargs:
      print(f'{algorithm} kwargs changed from {previously_run_scenarios[algorithm]} to {kwargs}')
      new_scenarios[algorithm] = kwargs
    elif algorithm in run_again:
      print(f'{algorithm} unchanged but running again')
      new_scenarios[algorithm] = kwargs
    # else:
    #   print(f'{algorithm} unchanged (old kwargs {previously_run_scenarios.get(algorithm, "algorithm not present")} new kwargs {kwargs})')
  if new_scenarios:
    print(f'updating scenarios: {new_scenarios}')
    scenarios = new_scenarios
    results = [result for result in results if result['algorithm'] not in new_scenarios.keys()]
except NameError:
  pass

RoundRobin kwargs changed from {'total_ticks': 170, 'mean_requests_per_server_per_tick': 11.666666666666666, 'channel_picker_factory': <__main__.RoundRobin object at 0x7a8390b8f100>} to {'total_ticks': 170, 'mean_requests_per_server_per_tick': 11.666666666666666, 'channel_picker_factory': <__main__.RoundRobin object at 0x7a831cbca8c0>}
RoundRobinLittleBigJobs kwargs changed from {'total_ticks': 170, 'mean_requests_per_server_per_tick': 46.666666666666664, 'channel_picker_factory': <__main__.RoundRobin object at 0x7a8390b8f130>, 'servers': 10, 'mean_threads': 400} to {'total_ticks': 170, 'mean_requests_per_server_per_tick': 46.666666666666664, 'channel_picker_factory': <__main__.RoundRobin object at 0x7a8308b9ae90>, 'servers': 10, 'mean_threads': 400}
Hcl kwargs changed from {'total_ticks': 170, 'mean_requests_per_server_per_tick': 11.666666666666666, 'channel_picker_factory': Hcl(ranking_fn=ServerReportedCost(cost_fn=<function ServerExpectedLatencyCost at 0x7a832364dd80>), health_fn=Se

In [202]:
# raw_data = {}
iles = [0.5, 0.9, 0.99, 0.999]

def runTest(args):
  scenario, kwargs = args
  try:
    test_manager = TestManager(**kwargs)
    df = test_manager.runTest(scenario)
    latency = df[(~df['is_probe']) & (df['test_stage'] != 'WARMING')]['total_ticks']
    utilization = df[df['test_stage'] == 'RUNNING'][['tick_at_receive', 'server', 'server_utilization_latest']].groupby(['server_utilization_latest'])['server_utilization_latest'].mean()

    stddev_utilization_per_server = df[df['test_stage'] == 'RUNNING'][['server', 'server_utilization_latest']].groupby(['server']).std()['server_utilization_latest']
    cooling = df[df['test_stage'] == 'COOLING']
    # TODO: work out cooldown period calculation, bearing in mind that we have multiple runs,
    # and the per-request data only includes the test stage at completion.
    summary = {
        'algorithm': scenario,
        'kwargs': kwargs,
        # 'cooldown_period': cooling['completed_on_tick'].max() - cooling['started_on_tick'].min(),
        'mean_latency': latency.mean(),
        'stdev_latency': latency.std(),
      }
    summary |= {f"p{100*p}_latency": latency.quantile(p) for p in iles}
    summary |= {
          'mean_utilization': utilization.mean(),
          'stdev_utilization': utilization.std(),
          'mean_stdev_utilization': stddev_utilization_per_server.mean(),
      }
    summary |= {f"p{100*p}_utilization_ratio": utilization.quantile(p) / utilization.mean() for p in iles}
    return [summary]
  except Exception as e:
    print(f"Failed {scenario}: {e}")

previous_results = []
try:
  previous_results = results
except NameError:
  pass
results = list(ooo_process_map(runTest, zip(list(scenarios.keys()) * TEST_RUNS, list(scenarios.values()) * TEST_RUNS)))
results = previous_results + [item for sublist in results for item in sublist]

df = pd.DataFrame.from_dict(results).set_index('algorithm')
# df

  0%|          | 0/48 [00:00<?, ?it/s]

In [203]:
df = df.drop('kwargs', axis=1)

In [215]:
def mean_plus_std(s):
  return s.mean() + s.std()

adf = df.groupby(by='algorithm').agg([('mean across test runs', np.mean), ('variance across test runs', np.var), ('reasonably bad run', mean_plus_std), ('worst run', np.max)])
# adf = adf.drop('FastestProbeAll')


The provided callable <function mean at 0x7a83915015a0> is currently using SeriesGroupBy.mean. In a future version of pandas, the provided callable will be used directly. To keep current behavior pass the string "mean" instead.


The provided callable <function var at 0x7a83915017e0> is currently using SeriesGroupBy.var. In a future version of pandas, the provided callable will be used directly. To keep current behavior pass the string "var" instead.


The provided callable <function max at 0x7a8391500ca0> is currently using SeriesGroupBy.max. In a future version of pandas, the provided callable will be used directly. To keep current behavior pass the string "max" instead.


The provided callable <function mean at 0x7a83915015a0> is currently using SeriesGroupBy.mean. In a future version of pandas, the provided callable will be used directly. To keep current behavior pass the string "mean" instead.



In [216]:
scores = []
n = 5
def stringify(x):
  return {1: 'very bad', 2: 'bad', 3: 'ok', 4: 'good', 5: 'very good'}.get(x, x)

for col in [
    'mean_latency',
    'stdev_latency', 'p99.0_latency',
    'mean_stdev_utilization',
    'stdev_utilization'
    ]:
  scores.append(pd.qcut(adf[(col, 'reasonably bad run')], n, labels=list(range(1, n+1))[::-1]).astype(int))

scores = pd.concat(scores, axis=1)
scores.columns = scores.columns.droplevel(1)
scores['total_score'] = scores.sum(axis=1)
scores = scores.sort_values('total_score', ascending=False)
scores.map(stringify).reset_index().style.background_gradient(cmap='Blues_r')

Unnamed: 0,algorithm,mean_latency,stdev_latency,p99.0_latency,mean_stdev_utilization,stdev_utilization,total_score
0,ProbingWeightedMeanUtilizationLowerRequestRate,very good,very good,very good,very good,very good,25
1,ProbingWeightedMeanCapacityLowerRequestRate,very good,very good,very good,very good,very good,25
2,CheatWeightedMeanAvailableCapacity,very good,very good,very good,ok,very good,23
3,LeastClientRequestsInFlightPickTwo,good,very good,very good,very good,good,23
4,LeastClientRequestsInFlightPickThree,very good,very good,good,very good,ok,22
5,LeastClientRequestsInFlight,good,very good,very good,very good,ok,22
6,ProbingWeightedMeanAvailableCapacity,very good,good,very good,ok,very good,22
7,ProbingClientMetricsHclLowerRequestRate,very good,very good,very good,good,ok,22
8,ProbingWeightedMeanAvailableCapacityLowerRequestRate,good,good,good,good,very good,21
9,WeightedMeanAvailableCapacity,very good,good,very good,bad,very good,21


In [217]:
print("## mean latency for test run")
# print("mean=average latency")
# print("std=variation in latency between test runs - higher means that some runs were much better than others")
# print("mean_plus_std=latency of reasonably bad test run")
# print("mean_plus_std=latency of worst test run")
px.box(df.reset_index().sort_values('mean_latency', ascending=False), y="algorithm", x="mean_latency", orientation='h', height=1000).show()
# px.box(df.reset_index().sort_values('mean_latency', ascending=False), y="algorithm", x="mean_latency", orientation='h').show()
# adf['mean_latency'].sort_values('reasonably bad run').reset_index().style.background_gradient(cmap='Blues').format(precision=1)

## mean latency for test run


In [218]:
print("## latency variability during test run")
# print("mean=variability (stdev) of latency across all test runs")
# print("std=variation in test run variability - higher means that some runs had a much bigger latency distribution")
# print("mean_plus_std=latency variability of reasonably bad test run")
# print("mean_plus_std=latency variability of worst test run")
px.box(df.reset_index().sort_values('stdev_latency', ascending=False), y="algorithm", x="stdev_latency", orientation='h', height=1000).show()
# adf['stdev_latency'].sort_values('reasonably bad run').reset_index().style.background_gradient(cmap='Blues').format(precision=1)

## latency variability during test run


In [226]:
print("## p99 latency during test run")
# print("mean=p99 latency across all test runs")
# print("std=variation in p99 latency between test runs - higher means that some runs had very different p99 latency")
# print("mean_plus_std=p99 latency of reasonably bad test run")
# print("mean_plus_std=p99 latency of worst test run")
px.box(df.reset_index().sort_values('p99.0_latency', ascending=False), y="algorithm", x="p99.0_latency", orientation='h', log_x=True, height=1000).show()
# adf['p99.0_latency'].sort_values('reasonably bad run').reset_index().style.background_gradient(cmap='Blues').format(precision=1)

## p99 latency during test run


In [220]:
print("## average utilization during test run")
print("This does not necessarily reflect performance of algorithms.")
print("Lower can mean that the algorithm is letting some servers get overloaded and letting others go idle.")
print("Higher can mean that the algorithm is adding load using probing requests.")

px.box(df.reset_index().sort_values('mean_utilization', ascending=False), y="algorithm", x="mean_utilization", orientation='h', height=1000).show()
# adf[('mean_utilization', 'mean across test runs')].sort_values().reset_index().style.background_gradient(cmap='Blues').format(precision=1)

## average utilization during test run
This does not necessarily reflect performance of algorithms.
Lower can mean that the algorithm is letting some servers get overloaded and letting others go idle.
Higher can mean that the algorithm is adding load using probing requests.


In [221]:
print("## utilization variability within servers during test run")
print("This is high if utilization for a server changes often (e.g cycles of overload then backoff)")
print("This is low if each server has consistent utilization, even if each server has very different utilization")
# print("mean=average utilization across all test runs")
# print("std=variation in utilization between test runs - higher means that some runs had very different mean utilization")
# print("mean_plus_std=utilization of reasonably bad test run")
# print("mean_plus_std=utilization of worst test run")
px.box(df.reset_index().sort_values('mean_stdev_utilization', ascending=False), y="algorithm", x="mean_stdev_utilization", orientation='h', height=1000).show()
# adf['mean_stdev_utilization'].sort_values('reasonably bad run').reset_index().style.background_gradient(cmap='Blues').format(precision=1)

## utilization variability within servers during test run
This is high if utilization for a server changes often (e.g cycles of overload then backoff)
This is low if each server has consistent utilization, even if each server has very different utilization


In [222]:
print("## utilization variability of all servers during test run")
print("This is high if some servers are getting an unfair amount of traffic. High variability means that overall utilization needs to be lowered to avoid the impact of the small number of serves with high utilization")
print("Unlike the above metric, it's high if the traffic is consistently unfair")
# print("mean=utilization variability across all test runs")
# print("std=variation in variability between test runs - higher means that some runs had a very different distribution of utilization")
# print("mean_plus_std=utilization of reasonably bad test run")
# print("mean_plus_std=utilization of worst test run")
px.box(df.reset_index().sort_values('stdev_utilization', ascending=False), y="algorithm", x="stdev_utilization", orientation='h', height=1000).show()
# adf['stdev_utilization'].sort_values('reasonably bad run').reset_index().style.background_gradient(cmap='Blues').format(precision=1)

## utilization variability of all servers during test run
This is high if some servers are getting an unfair amount of traffic. High variability means that overall utilization needs to be lowered to avoid the impact of the small number of serves with high utilization
Unlike the above metric, it's high if the traffic is consistently unfair


In [223]:
print(df.to_csv())

algorithm,mean_latency,stdev_latency,p50.0_latency,p90.0_latency,p99.0_latency,p99.9_latency,mean_utilization,stdev_utilization,mean_stdev_utilization,p50.0_utilization_ratio,p90.0_utilization_ratio,p99.0_utilization_ratio,p99.9_utilization_ratio
Random,14.43457300275482,12.01156071987014,11.0,21.0,81.49000000000001,87.5490000000002,54.489359512042746,23.74289685024239,4.878221398763135,0.9808938174480893,1.613600825473676,1.8170995843602307,1.8319066028600866
LeastServerRequestsInFlight,12.28186968838527,7.284723579912673,11.0,19.0,46.0,49.0,52.51386692430213,24.384802402467226,18.958811718025437,0.9685454898766848,1.6574105495316025,1.864461836152653,1.9002195921264262
WeightedServerRequestsInFlight,11.791836734693877,8.038212940972812,10.0,18.0,52.0,65.59300000000053,51.64574765590668,25.564177087606936,19.005452201831115,0.9585484331620023,1.720756031897668,1.8922315722258205,1.9318637748135488
LeastServerExpectedLatency,21.816204051012754,14.169500776196065,16.0,43.0,58.0,75.0,49.

In [224]:
print(adf.to_csv())

,mean_latency,mean_latency,mean_latency,mean_latency,stdev_latency,stdev_latency,stdev_latency,stdev_latency,p50.0_latency,p50.0_latency,p50.0_latency,p50.0_latency,p90.0_latency,p90.0_latency,p90.0_latency,p90.0_latency,p99.0_latency,p99.0_latency,p99.0_latency,p99.0_latency,p99.9_latency,p99.9_latency,p99.9_latency,p99.9_latency,mean_utilization,mean_utilization,mean_utilization,mean_utilization,stdev_utilization,stdev_utilization,stdev_utilization,stdev_utilization,mean_stdev_utilization,mean_stdev_utilization,mean_stdev_utilization,mean_stdev_utilization,p50.0_utilization_ratio,p50.0_utilization_ratio,p50.0_utilization_ratio,p50.0_utilization_ratio,p90.0_utilization_ratio,p90.0_utilization_ratio,p90.0_utilization_ratio,p90.0_utilization_ratio,p99.0_utilization_ratio,p99.0_utilization_ratio,p99.0_utilization_ratio,p99.0_utilization_ratio,p99.9_utilization_ratio,p99.9_utilization_ratio,p99.9_utilization_ratio,p99.9_utilization_ratio
,mean across test runs,variance across test runs,re