```python
import os
import sys
import numpy as np
import matplotlib.pyplot as plt
```


Let's define some variables:

* $B$: batch size,
* $n_L$: number of layers per "node",
* $N$: total number of model layers,
* $S_\text{kv}$: maximum size of the KV-cache per prompt per layer.

Assume network latency is negligible.

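We define $t(B, n)$ as the time to push a batch of size $B$ through $n$ layers. Since every layer does roughly the same amount of work, $t(B, n) \approx n \cdot t(B, 1)$; below we write $t(B) \equiv t(B, 1)$ for the per-layer time.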

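Once the pipeline is full, each node finishes one batch every $t(B, n_L)$, so the pipeline emits $B$ tokens per $t(B, n_L)$ and the throughput is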
$$T = \frac{B}{t(B, n_L)} =\frac{B}{n_L \cdot t(B)}$$

The latency of one token passing through all $N/n_L$ nodes (i.e., all $N$ layers) is
$$l = \frac{N}{n_L} \cdot t(B, n_L) = N \cdot t(B)$$
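A minimal sketch of the cost model, throughput, and latency in Python. The per-layer time `t_layer(B)` is a hypothetical affine model `a + b*B` whose coefficients are made-up placeholders for measured timings; only the structure of the formulas comes from the text above.

```python
def t_layer(B: int, a: float = 2e-3, b: float = 5e-5) -> float:
    """Hypothetical per-layer time t(B) = t(B, 1) in seconds (assumed affine in B)."""
    return a + b * B


def t(B: int, n: int) -> float:
    """Time to push a batch of B prompts through n layers: t(B, n) ~= n * t(B, 1)."""
    return n * t_layer(B)


def throughput(B: int, n_L: int) -> float:
    """Pipeline throughput T = B / t(B, n_L), in tokens per second."""
    return B / t(B, n_L)


def latency(N: int, B: int, n_L: int) -> float:
    """Per-token latency l = (N / n_L) * t(B, n_L) = N * t(B), in seconds."""
    return (N / n_L) * t(B, n_L)


# Illustrative sweep over n_L values that divide N.
N, B = 80, 32
for n_L in (1, 2, 4):
    print(n_L, throughput(B, n_L), latency(N, B, n_L))
```

Running it shows the two scalings derived above: latency stays at $N \cdot t(B)$ for every $n_L$, while throughput falls as $1/n_L$.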

Combining the two, the number of prompts that must be in flight to keep the pipeline full is throughput times latency:
$$T \cdot l = \frac{B}{n_L \cdot t(B)} \cdot N \cdot t(B) = B\cdot \frac{N}{n_L}\quad\text{(Little's law)}$$ 
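The Little's law result can be checked with a toy discrete-time simulation, assuming each node finishes exactly one batch per tick of length $t(B, n_L)$ and a fresh batch of $B$ prompts enters the first node every tick. The function name `simulate_in_flight` is illustrative.

```python
from collections import deque


def simulate_in_flight(B: int, N: int, n_L: int, ticks: int = 1000) -> int:
    """Count prompts in flight in a pipeline of N / n_L nodes (toy model).

    Assumes every node processes one batch of B prompts per tick, a new batch
    enters node 0 every tick, and a batch leaves after visiting all nodes.
    """
    assert N % n_L == 0, "this sketch assumes n_L divides N"
    num_nodes = N // n_L
    pipeline = deque()  # each entry is the tick at which that batch entered
    for tick in range(ticks):
        pipeline.append(tick)  # a new batch enters node 0
        # Drop batches that have already passed through all num_nodes nodes.
        while pipeline and tick - pipeline[0] >= num_nodes:
            pipeline.popleft()
    # In steady state every node holds exactly one batch.
    return len(pipeline) * B


print(simulate_in_flight(32, 80, 2), 32 * 80 // 2)  # → 1280 1280
```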

Each node stores the KV-cache of all $B \cdot N/n_L$ in-flight prompts for its own $n_L$ layers, so the KV-cache memory needed per node is
$$M_\text{kv} = B \cdot \frac{N}{n_L} \cdot (n_L \cdot S_\text{kv}) = B \cdot N \cdot S_\text{kv}$$

Therefore, the KV-cache memory required on each node depends only on the batch size $B$, the total number of model layers $N$, and $S_\text{kv}$; it does not depend on $n_L$.
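A small sketch of the per-node memory count, using illustrative numbers rather than measurements, to show that $n_L$ cancels:

```python
def kv_mem_per_node(B: int, N: int, n_L: int, S_kv: float) -> float:
    """KV-cache memory one node needs, in the same units as S_kv.

    B * N / n_L prompts are in flight (Little's law), and each keeps its
    KV-cache for the n_L layers hosted on this node.
    """
    in_flight = B * N / n_L
    return in_flight * n_L * S_kv


# Illustrative values; S_kv in MB.
B, N, S_kv = 16, 80, 16.0
for n_L in (1, 2, 4, 8):
    print(n_L, kv_mem_per_node(B, N, n_L, S_kv))  # 20480.0 for every n_L
```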



The goal is to maximize $T$ subject to $M_\text{kv} \le \text{node mem}$:

$$B_{opt} = \underset{B}{\mathrm{argmax}} \left\{ \frac{B}{n_L \cdot t(B)} \right\}, \quad \text{s.t. }\,\, B \le \frac{\text{node mem}}{N\cdot S_\text{kv}}$$ 
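Because $B$ is a small integer, the constrained argmax can simply be brute-forced. The sketch below assumes a hypothetical affine per-layer time `t_layer(B) = a + b*B` with made-up coefficients; in practice $t(B)$ would come from measurements. `node_mem` and `S_kv` only need to share units.

```python
import math


def t_layer(B: int, a: float = 2e-3, b: float = 5e-5) -> float:
    """Hypothetical per-layer time t(B) in seconds (assumed affine in B)."""
    return a + b * B


def best_batch_size(node_mem: float, N: int, S_kv: float, n_L: int) -> int:
    """Brute-force argmax of T(B) = B / (n_L * t(B)) over feasible integer B.

    Feasible means M_kv = B * N * S_kv <= node_mem.
    """
    B_max = math.floor(node_mem / (N * S_kv))
    if B_max < 1:
        raise ValueError("not even one prompt's KV-cache fits on the node")
    return max(range(1, B_max + 1), key=lambda B: B / (n_L * t_layer(B)))


# 100 GB (= 102400 MB in binary units), N = 80, S_kv = 16 MB
print(best_batch_size(100 * 1024, 80, 16, n_L=1))  # → 80
```

With an affine $t(B)$, $\frac{B}{t(B)} = \frac{B}{a + bB}$ increases with $B$, so the search lands on the memory bound; a measured $t(B)$ with a different shape could place the optimum below it.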

## Some conclusions

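* For Llama2-70B ($N=80$, $S_\text{kv}=16\,\text{MB}$) on a node with $100\,\text{GB}$ of memory available for the KV-cache, $B \le \frac{100\,\text{GB}}{80 \cdot 16\,\text{MB}} = 80$ (using binary units, $1\,\text{GB} = 1024\,\text{MB}$).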

* Based on empirical data, $\frac{B}{t(B)}$ increases monotonically with $B$, so the memory constraint is binding and $$B_{opt} = \left\lfloor \frac{\text{node mem}}{N\cdot S_\text{kv}} \right\rfloor$$

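* For a fixed number of nodes, the total throughput does not depend on the number of layers per node $n_L$. For example, with $80$ nodes and $n_L = 2$, one model instance occupies $40$ nodes and has half the throughput of an $n_L = 1$ instance, but two instances can run concurrently, so the total is unchanged. In general, $K$ nodes host $K \cdot n_L / N$ model instances, each with throughput $\frac{B}{n_L \cdot t(B)}$, for a total of $\frac{K \cdot B}{N \cdot t(B)}$; since $B_{opt}$ does not depend on $n_L$ either, the optimal total throughput is the same for every choice of $n_L$.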
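The arithmetic behind these conclusions, as a sketch. The Llama 2 70B shape used to derive $S_\text{kv}$ (8 KV heads from grouped-query attention, head dimension 128, a 4096-token context, 2-byte fp16 entries) is an assumption about where the 16 MB figure comes from; it reproduces that figure exactly in binary units. The per-layer time `t_B` is a hypothetical placeholder.

```python
MiB = 2**20
GiB = 2**30


def kv_bytes_per_prompt_per_layer(n_kv_heads: int, head_dim: int,
                                  context_len: int, bytes_per_elem: int) -> int:
    """S_kv: keys plus values for every position of one prompt in one layer."""
    return 2 * n_kv_heads * head_dim * context_len * bytes_per_elem


def total_throughput(num_nodes: int, N: int, n_L: int, B: int, t_B: float) -> float:
    """Aggregate tokens/s when num_nodes nodes run as many full pipelines as fit.

    Assumes n_L divides N. Each instance spans N / n_L nodes and has
    throughput B / (n_L * t(B)).
    """
    instances = num_nodes // (N // n_L)
    return instances * B / (n_L * t_B)


# Assumed Llama 2 70B shape: 80 layers, 8 KV heads, head_dim 128, 4096 ctx, fp16.
N = 80
S_kv = kv_bytes_per_prompt_per_layer(8, 128, 4096, 2)
print(S_kv / MiB)  # → 16.0
B_max = (100 * GiB) // (N * S_kv)
print(B_max)  # → 80

# Fixed budget of 80 nodes: the total is the same for every n_L that divides N.
t_B = 3.6e-3  # hypothetical per-layer time t(B_max) in seconds
for n_L in (1, 2, 4):
    print(n_L, total_throughput(80, N, n_L, B_max, t_B))
```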