# How to Scale your Model

Link: https://jax-ml.github.io/scaling-book/

### Types of time

There are two types of time:

- **computation time** - how much time the accelerator spends on compute

![image.png](attachment:image.png)

Computation FLOPs - how many FLOPs are needed to perform the computation

Accelerator FLOPs/s - how many FLOPs the accelerator can perform in one second

- **communication time** - how much time the accelerator spends sending and receiving data

![image-2.png](attachment:image-2.png)

Communication Bytes - how many bytes must be read, written, or sent, either between blocks inside one accelerator or to another accelerator

Network/Memory Bandwidth - how many bytes per second the memory system or network link can transfer

### Upper and Lower bounds

Typically (but not always), computation within a single chip can be overlapped with communication within a chip and between chips. This means we can lower-bound training and inference time by the maximum of the computation and communication times. The upper bound is the sum of the two:

![image-3.png](attachment:image-3.png)

We usually optimize against the maximum, since the algebra is simpler and we can usually come close to this bound by overlapping communication with computation.
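The two bounds can be sketched in a few lines of Python. The function name and the sample workload are illustrative assumptions, not values from the book:

```python
def time_bounds(flops, comm_bytes, flops_per_s, bytes_per_s):
    """Return (lower, upper) bounds on the runtime in seconds."""
    t_math = flops / flops_per_s         # computation time
    t_comms = comm_bytes / bytes_per_s   # communication time
    lower = max(t_math, t_comms)         # perfect overlap of the two
    upper = t_math + t_comms             # no overlap at all
    return lower, upper


# Hypothetical workload: 1e12 FLOPs and 1e9 bytes on a chip with
# 2e14 FLOPs/s of compute and 1e12 bytes/s of bandwidth.
lower, upper = time_bounds(1e12, 1e9, 2e14, 1e12)
print(f"lower bound: {lower:.4f} s, upper bound: {upper:.4f} s")
```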

### Compute and communication bound

T<sub>math</sub> > T<sub>comms</sub> - the hardware is fully utilized; we say the operation is **“compute-bound”**, which is the desired situation

T<sub>math</sub> < T<sub>comms</sub> - the hardware is underutilized because some fraction of the accelerator's FLOPs/s is wasted waiting for data to be passed around; the operation is **“communication-bound”**

### Arithmetic Intensity

 One way to tell if an operation will be compute or communication-bound is to look at its “arithmetic intensity” or “operational intensity”.

 **Definition**: the arithmetic intensity of an algorithm is given by the ratio of the total FLOPs it performs to the number of bytes it needs to communicate — either within a chip or between chips.

 ![image-4.png](attachment:image-4.png)

The accelerator has its own arithmetic intensity too (its peak FLOPs/s divided by its bandwidth), and we can derive the relation between the two intensities from the relation between the times:

![image-5.png](attachment:image-5.png)

So, if we want an algorithm to fully utilize our accelerator, the algorithm's intensity must be greater than the accelerator's intensity.
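As a minimal sketch of this comparison, the helpers below (hypothetical names) compute both intensities and compare them. The hardware numbers are the TPU figures used later in these notes. The dot-product workload is an assumed example, with roughly 2n FLOPs over 4n + 2 bytes for two bf16 vectors of length n:

```python
def arithmetic_intensity(flops, comm_bytes):
    """FLOPs performed per byte moved by the algorithm."""
    return flops / comm_bytes


def accelerator_intensity(flops_per_s, bytes_per_s):
    """Peak FLOPs the hardware can do per byte it can move."""
    return flops_per_s / bytes_per_s


def is_compute_bound(flops, comm_bytes, flops_per_s, bytes_per_s):
    """True when the algorithm's intensity exceeds the hardware's."""
    return (arithmetic_intensity(flops, comm_bytes)
            > accelerator_intensity(flops_per_s, bytes_per_s))


PEAK_FLOPS = 1.97e14   # FLOPs/s
HBM_BW = 8.2e11        # bytes/s

n = 1000
# Dot product: intensity is about 0.5, far below the hardware's ~240.
print(is_compute_bound(2 * n, 4 * n + 2, PEAK_FLOPS, HBM_BW))
```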


### Visualizing rooflines

We can visualize the tradeoff between memory and compute using a roofline plot, which plots the peak achievable FLOPs/s (throughput) of an algorithm on our hardware (the y-axis) against the arithmetic intensity of that algorithm (the x-axis). Here’s an example log-log plot:

![image-6.png](attachment:image-6.png)
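The roofline itself is just the minimum of two lines: the bandwidth-limited slope and the flat compute ceiling. The sketch below assumes the TPU numbers used in the examples that follow; the function name is illustrative:

```python
def roofline(intensity, peak_flops_per_s, bytes_per_s):
    """Peak achievable FLOPs/s at a given arithmetic intensity."""
    # Left of the ridge point we are limited by bandwidth,
    # right of it by the compute ceiling.
    return min(peak_flops_per_s, intensity * bytes_per_s)


PEAK_FLOPS = 1.97e14   # FLOPs/s
HBM_BW = 8.2e11        # bytes/s

for intensity in (1, 10, 100, 1000):
    achieved = roofline(intensity, PEAK_FLOPS, HBM_BW)
    print(f"intensity={intensity:5d} -> {achieved:.3e} FLOPs/s")
```

The two regimes meet at the ridge point, where the intensity equals `PEAK_FLOPS / HBM_BW`.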

### Examples

1. Matrix multiplication, one TPU

X∗Y→Z where X has shape bf16[B,D], Y has shape bf16[D,F], and Z has shape bf16[B,F]

To do the matmul we need to load 2DF+2BD bytes, perform 2BDF FLOPs, and write 2BF bytes back.

![image.png](attachment:image.png)

Assume that B (batch size) << D and F. Then, for a TPU with 1.97e14 FLOPs/s of compute and 8.2e11 bytes/s of memory bandwidth:

![image-2.png](attachment:image-2.png)

This is a reasonable assumption for Transformer matmuls since we typically have a local (per-replica) batch size B < 1024 tokens (not sequences) but D and F > 8000. Thus we generally become compute-bound when our per-replica batch size is greater than 240 tokens.
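To check the approximation numerically, the sketch below computes the exact intensity from the byte and FLOP counts above (the function name and the sample sizes are assumptions). The exact value sits slightly below B because the B-dependent byte terms are small but not zero:

```python
def matmul_intensity(B, D, F, bytes_per_elem=2):
    """Arithmetic intensity of bf16[B,D] @ bf16[D,F] -> bf16[B,F]."""
    flops = 2 * B * D * F
    # Load X and Y, write Z back.
    comm_bytes = bytes_per_elem * (B * D + D * F + B * F)
    return flops / comm_bytes


critical = 1.97e14 / 8.2e11   # accelerator intensity, about 240

for B in (128, 256, 512):
    intensity = matmul_intensity(B, 8192, 8192)
    regime = "compute-bound" if intensity > critical else "memory-bound"
    print(f"B={B}: intensity={intensity:.1f} ({regime})")
```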

2. Matrix multiplication, two TPUs

Now we split matrices X and Y in half along the D axis:

Z<sub>0</sub> = X[:, :D // 2] @ Y[:D // 2, :] on TPU 0 and Z<sub>1</sub> = X[:, D // 2:] @ Y[D // 2:, :] on TPU 1

Then Z = Z<sub>0</sub> + Z<sub>1</sub>, but we won't count the final addition in our calculation.

T<sub>math</sub> is clearly half of what it was before, since each TPU is doing half the work, i.e.:

![image-3.png](attachment:image-3.png)

Now what about T<sub>comms</sub>? This now refers to the communication time between chips. This is just the total bytes sent divided by the network bandwidth, i.e.:

![image-4.png](attachment:image-4.png)

The matmul intensity is then BDF / 2BF = D / 2, so to become compute-bound we need D / 2 > 1.97e14 / 4.5e10 ≈ 4377, i.e. D > 8755.

Note that, unlike before, the critical threshold now depends on D and not B.
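The same conclusion can be reproduced numerically. The sketch below assumes the two times are BDF / C (half of the single-chip 2BDF / C) and 2BF / W, with C = 1.97e14 FLOPs/s and W = 4.5e10 bytes/s as in the arithmetic above; the function name is hypothetical:

```python
def two_chip_matmul_times(B, D, F, flops_per_s, link_bytes_per_s):
    """Return (t_math, t_comms) for a matmul split along D over two chips."""
    t_math = 2 * B * D * F / (2 * flops_per_s)   # each chip does half the FLOPs
    t_comms = 2 * B * F / link_bytes_per_s       # one bf16 partial sum of shape [B, F]
    return t_math, t_comms


C = 1.97e14   # FLOPs/s
W = 4.5e10    # bytes/s between chips

for D in (8192, 16384):
    t_math, t_comms = two_chip_matmul_times(512, D, 8192, C, W)
    regime = "compute-bound" if t_math > t_comms else "communication-bound"
    print(f"D={D}: {regime}")
```

Changing B scales both times by the same factor, so the regime only flips when D crosses the threshold.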

## Sharding

### Sharding notation

![image.png](attachment:image.png)

 - Mesh - the shape of the array of devices, for example a 2 x 2 grid: Mesh(devices=((0, 1), (2, 3)), axis_names=('X', 'Y'))
 - Sharding: a data array A[ I<sub>X</sub>, J<sub>Y</sub> ] tells us to shard the first axis, I, along the mesh axis X, and the second axis, J, along the mesh axis Y. Each shard then holds 1/(∣X∣⋅∣Y∣) of the array. Array axes and mesh axes are independent, e.g. axis I can be sharded along both the X and Y mesh axes, which is written A[I<sub>XY</sub>, J].
 
 Here are all possible shardings for a 2D mesh of 4x4 devices and a 2D data array:
 
![image-2.png](attachment:image-2.png)
 
 I<sub>XY</sub> and I<sub>YX</sub> are different shardings
 
![image-3.png](attachment:image-3.png)
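To make the notation concrete, here is a small sketch that computes the per-device shard shape from a sharding spec. The mesh is modelled as a plain dict of axis sizes and the spec as tuples of mesh axis names; both are simplifications chosen for illustration, not the JAX API:

```python
from math import prod


def shard_shape(global_shape, spec, mesh):
    """Per-device shape of an array under a sharding spec.

    spec[i] is a tuple of mesh axis names that array axis i is sharded over
    (an empty tuple means the axis is replicated).
    """
    shape = []
    for dim, axes in zip(global_shape, spec):
        n = prod(mesh[a] for a in axes)   # product over no axes is 1
        if dim % n != 0:
            raise ValueError(f"axis of size {dim} is not divisible by {n}")
        shape.append(dim // n)
    return tuple(shape)


mesh = {"X": 2, "Y": 2}
print(shard_shape((8, 16), (("X",), ("Y",)), mesh))   # A[I_X, J_Y]
print(shard_shape((8, 16), (("X", "Y"), ()), mesh))   # A[I_XY, J]
print(shard_shape((8, 16), ((), ()), mesh))           # A[I, J], fully replicated
```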
  
  ### Types of communication primitives that are used in sharding
  
  1. **AllGather** - removes the sharding along an axis and reassembles the shards spread across devices onto each device along that axis. Using the notation above, an AllGather removes a subscript from a set of axes, e.g.:
  
![image-4.png](attachment:image-4.png)
  
  **AllGather visualization**

```python
from IPython.display import HTML

# Embed GIF using HTML
HTML('<img src="https://jax-ml.github.io/scaling-book/assets/img/all-gather.gif" width="450" height="250"/>')
```

  AllGather time. Let V be the number of bytes in the array, and ∣X∣ be the number of shards on the contracting dimension. Then from the above diagram, each hop sends V/∣X∣ bytes in each direction, so each hop takes
  
  ![image.png](attachment:image.png)
  
  where W<sub>ici</sub> is the bidirectional ICI (inter-chip interconnect) bandwidth. We need a total of ∣X∣/2 hops to reach every TPU, so the whole AllGather takes
  
  ![image-2.png](attachment:image-2.png)
  
  Note that AllGather time doesn’t depend on X!
  
  Here is an empirical measurement of AllGather bandwidth on a TPU v5e 8x16 slice. The array is sharded across the second (16) axis so it has a full bidirectional ring.
  
  ![image-3.png](attachment:image-3.png)
  
  What happens when we AllGather over multiple axes? When we gather over multiple axes, we have multiple dimensions of ICI over which to perform the gather. For instance, 
  AllGather<sub>XY</sub>([B, D<sub>XY</sub>]) operates over two hardware mesh axes. This increases the available bandwidth by a factor of N<sub>axes</sub>.

  When considering latency, we end up with the general rule:
  
  ![image-4.png](attachment:image-4.png)
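A rough sketch of the bandwidth-bound estimate follows. It assumes a per-hop time of 2V / (∣X∣ · W<sub>ici</sub>), since the bidirectional bandwidth is shared by the two directions, and it ignores the per-hop latency floor. The function name and sample numbers are hypothetical:

```python
def all_gather_time(total_bytes, num_shards, w_ici, n_axes=1):
    """Bandwidth-bound AllGather time on a bidirectional ring, in seconds."""
    # Each hop moves total_bytes / num_shards in each direction; gathering over
    # n_axes mesh axes multiplies the available bandwidth by n_axes.
    hop_time = 2 * total_bytes / (num_shards * w_ici * n_axes)
    hops = num_shards / 2
    return hop_time * hops


V = 1e9        # bytes in the full array (assumed)
W_ICI = 1e11   # bidirectional bytes/s (assumed)

print(all_gather_time(V, 4, W_ICI))
print(all_gather_time(V, 16, W_ICI))            # same time: the shard count cancels
print(all_gather_time(V, 16, W_ICI, n_axes=2))  # two axes halve the time
```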
  
  2. **ReduceScatter** - just as an AllGather removes a subscript (F<sub>Y</sub>→F above), a ReduceScatter sums an unreduced (partially summed) array and then scatters (shards) a different logical axis along the same mesh axis: [F]{U<sub>Y</sub>}→[F<sub>Y</sub>]. The animation shows how this is done: it is very similar to an AllGather, but instead of retaining each shard, we sum them together. Thus, its latency is roughly the same, excluding the time taken to perform the reduction.
  
  **ReduceScatter visualization**

```python
from IPython.display import HTML

# Embed GIF using HTML
HTML('<img src="https://jax-ml.github.io/scaling-book/assets/img/reduce-scatter.gif" width="450" height="250"/>')
```

  ReduceScatter is the opposite operation to AllGather. Likewise, ![image.png](attachment:image.png) in the forward pass implies ![image-2.png](attachment:image-2.png) in the backwards pass.
  
  ReduceScatter time. The communication time for each hop is simply the per-shard bytes V/∣Y∣ divided by the bandwidth W<sub>ici</sub>, as it was for an AllGather, so we have:
  
  ![image-3.png](attachment:image-3.png)
  
  3. **AllReduce** - removes partial sums, resulting in each device along the axis having the same fully-summed value. AllReduce takes an array with an unreduced (partially summed) axis and performs the sum by passing those shards around the unreduced axis and accumulating the result. The signature is:
  
  ![image-4.png](attachment:image-4.png)

  AllReduce time. One mental model for how an AllReduce is performed is that every device sends its shard to its neighbors, and sums up all the shards that it receives. Clearly, this is more expensive than an AllGather because each “shard” has the same shape as the full array. Generally, an AllReduce is twice as expensive as an AllGather. One way to see this is to note that an AllReduce can be expressed as a composition of two other primitives: a ReduceScatter followed by an AllGather.
  
  4. **AllToAll** - the special case of a sharded transposition or resharding operation. e.g.
  
  ![image-5.png](attachment:image-5.png)
  
  AllToAlls are typically required to rearrange sharded layouts between different regions of a sharded computation that don’t have compatible layout schemes. They arise naturally when considering sharded mixture-of-experts models. You can think of an AllToAll as moving a subscript from one axis to another.
  
  **AllToAll visualization**

```python
from IPython.display import HTML

# Embed GIF using HTML
HTML('<img src="https://jax-ml.github.io/scaling-book/assets/img/all-to-all.gif" width="750" height="250"/>')
```

  AllToAll time. Because an AllToAll doesn’t need to replicate all of the data of each shard across the ring, it’s actually cheaper than an AllGather: it costs about ¼ as much.
  
  Explanation:
  
![image.png](attachment:image.png)
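The semantics of the four primitives can be sketched on plain Python lists, where `shards[d]` is whatever device `d` holds. This only models the results, not the ring hops or their cost, and the function names are illustrative rather than any library's API:

```python
def all_gather(shards):
    """Every device ends up with the concatenation of all shards."""
    full = [x for shard in shards for x in shard]
    return [list(full) for _ in shards]


def reduce_scatter(partials):
    """partials[d] is a full-length partial sum; device d keeps slice d of the total."""
    n = len(partials)
    total = [sum(vals) for vals in zip(*partials)]
    size = len(total) // n
    return [total[d * size:(d + 1) * size] for d in range(n)]


def all_reduce(partials):
    """AllReduce expressed as a ReduceScatter followed by an AllGather."""
    return all_gather(reduce_scatter(partials))


def all_to_all(shards):
    """Move the sharding from rows to columns: A[I_X, J] -> A[I, J_X]."""
    n = len(shards)
    cols = len(shards[0][0]) // n
    return [[row[d * cols:(d + 1) * cols] for shard in shards for row in shard]
            for d in range(n)]


partials = [[1, 2, 3, 4], [10, 20, 30, 40]]   # two devices, unreduced
print(reduce_scatter(partials))
print(all_reduce(partials))

row_shards = [[[1, 2, 3, 4]], [[5, 6, 7, 8]]]  # each device holds one row
print(all_to_all(row_shards))
```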
  
  ### Computation With Sharded Arrays
  
 The most common operation in machine learning is matrix multiplication.
 
 There are 4 cases of sharded matrix multiplication:
 
 1. **Case 1**: neither input is sharded along the contracting dimension. We can multiply local shards without any communication.
 
  ![image-2.png](attachment:image-2.png)
 
 2. **Case 2**: one input has a sharded contracting dimension. We typically “AllGather” the sharded input along the contracting dimension, because local shards can only be multiplied when both operands hold the same slice of the contracting dimension.
 
  ![image-3.png](attachment:image-3.png)
 
 3. **Case 3**: both inputs are sharded along the contracting dimension. We can multiply the local shards, then “AllReduce” the result.
 
  ![image-4.png](attachment:image-4.png)
 
 4. **Case 4**: both inputs have a non-contracting dimension sharded along the same axis. We cannot proceed without AllGathering one of the two inputs first.
 
  ![image-6.png](attachment:image-6.png)
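Case 3 is easy to verify by hand. In the sketch below (pure Python, with made-up small matrices), each of two devices multiplies its half of the contracting dimension, and the elementwise sum that an AllReduce would produce matches the unsharded result:

```python
def matmul(a, b):
    """Plain list-of-lists matrix multiplication."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)]
            for row in a]


def add(a, b):
    """Elementwise sum of two equally shaped matrices."""
    return [[x + y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]


X = [[1, 2, 3, 4], [5, 6, 7, 8]]        # shape [B=2, D=4]
Y = [[1, 0], [0, 1], [1, 1], [2, 0]]    # shape [D=4, F=2]
half = 2

# Device 0 holds the first half of D, device 1 the second half.
partial0 = matmul([row[:half] for row in X], Y[:half])
partial1 = matmul([row[half:] for row in X], Y[half:])

# The AllReduce leaves this sum on every device.
Z = add(partial0, partial1)
print(Z == matmul(X, Y))
```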

# Transformers

![image.png](attachment:image.png)

## Matmul forward and backward FLOPs calculation

Let’s start with vectors x,y and matrices A,B of the following shapes:

![image.png](attachment:image.png)

![image-2.png](attachment:image-2.png)

```python
from IPython.display import HTML

# Embed GIF using HTML
HTML('<img src="https://jax-ml.github.io/scaling-book/assets/img/matmul-flops.gif" width="600" height="300"/>')
```

During the forward pass we have 2NPM FLOPs. During the backward pass we have to calculate gradients for both matrices A and B:

![image.png](attachment:image.png)

![image-2.png](attachment:image-2.png)

where L is the loss and C = A @ B.

Both of these operations take 2NPM FLOPs too: dL/dC has shape [N, M], and the contracting dimensions are N (for dL/dB) and M (for dL/dA).

Adding these up, we see that **during training, we have a total of 6NPM FLOPs**, compared to 2NPM during inference: 2NPM in the forward pass, 4NPM in the backward pass. 

Since PM is the number of parameters in the matrix, this is the simplest form of the famous 6 \* num parameters \* num tokens rule: each token requires 6 \* num parameters FLOPs.
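The counting above can be written down directly. This is a small sketch with hypothetical function names, taking A as [N, P] (N tokens) and B as [P, M] (PM parameters):

```python
def matmul_train_flops(N, P, M):
    """Return (forward, backward, total) FLOPs for C = A @ B during training."""
    forward = 2 * N * P * M
    backward = 2 * (2 * N * P * M)   # one matmul each for dL/dA and dL/dB
    return forward, backward, forward + backward


def approx_train_flops(num_params, num_tokens):
    """The 6 * parameters * tokens rule of thumb."""
    return 6 * num_params * num_tokens


N, P, M = 4, 3, 5
print(matmul_train_flops(N, P, M))
print(approx_train_flops(P * M, N))
```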

## Transformer FLOPs and Params Calculation

### MLP

The MLPs of a Transformer typically consist of 2 input matmuls that are element-wise combined and a single output matmul:

![image.png](attachment:image.png)

### Attention

Estimated cost of the QKVO matmuls:

![image-2.png](attachment:image-2.png)

Estimated cost of the dot-product attention:

![image-3.png](attachment:image-3.png)

### Other operations (layer norm and vocab)

![image-4.png](attachment:image-4.png)

Total vocab train FLOPs are 12BTDV

### General rule of thumb for Transformer FLOPs

If we neglect the cost of dot-product attention for shorter-context training, then the total FLOPs across all layers is:

![image-5.png](attachment:image-5.png)

### Fractional cost of attention with context length

If we do account for dot-product attention above and assume F = 4D, D = NH and N = K:

![image-6.png](attachment:image-6.png)

So the takeaway is that dot-product attention FLOPs only become dominant during training once T>8D. 

For D ~ 8k, this would be ~64K tokens. This makes some sense, since it means as the MLP size increases, the attention FLOPs become less critical. 

For large models, the quadratic cost of attention is not actually a huge obstacle to longer context training. 

However, for smaller models this matters sooner: even for e.g. Gemma-27B, D=4608, which means attention becomes dominant at around 32k sequence length.

Flash Attention also helps alleviate the cost of long-context.
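A quick numeric check of the crossover follows. It assumes the attention-to-matmul FLOPs ratio is T / (8D), which is what the T > 8D takeaway implies under F = 4D, D = NH and N = K; the function names are hypothetical:

```python
def attention_to_matmul_ratio(T, D):
    """Assumed ratio of dot-product attention FLOPs to matmul FLOPs."""
    return T / (8 * D)


def attention_crossover_tokens(D):
    """Sequence length at which attention FLOPs match the matmul FLOPs."""
    return 8 * D


for D in (8192, 4608):
    print(f"D={D}: attention dominates beyond T={attention_crossover_tokens(D)}")

print(attention_to_matmul_ratio(8192, 8192))   # short context: a small fraction
```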

## Key-Value (KV) caching

Each KV cache is then effectively an array of size [2,S,L,K,H]. This is quite large! 

The total size of the Key-Value cache in int8 is 2SLKH bytes. 

For a moderately-sized model with 8k context length, 64 layers, and KH=NH=D=8192, this is 2⋅8192⋅64⋅8192 bytes = 8GiB. 

You can see why we would want to use Grouped Query Attention with K≪N.
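The arithmetic is easy to reproduce. In the sketch below the split of KH = 8192 into K = 64 heads of size H = 128 is an assumption made for illustration, as is the K = 8 grouped-query variant:

```python
def kv_cache_bytes(S, L, K, H, bytes_per_elem=1):
    """Size of a [2, S, L, K, H] KV cache; bytes_per_elem=1 means int8."""
    return 2 * S * L * K * H * bytes_per_elem


GIB = 2**30

full = kv_cache_bytes(S=8192, L=64, K=64, H=128)   # K = N, as in the example above
gqa = kv_cache_bytes(S=8192, L=64, K=8, H=128)     # hypothetical GQA with K << N
print(f"K=64: {full / GIB:.0f} GiB, K=8: {gqa / GIB:.0f} GiB")
```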

# How to Parallelize a Transformer for Training

We’ll use the following notation to simplify calculations throughout this section.

![image.png](attachment:image.png)

For simplicity’s sake, we’ll approximate a Transformer as a stack of MLP blocks, since attention is a comparatively small fraction of the FLOPs for larger models.

![image-2.png](attachment:image-2.png)

Here’s the full algorithm for our little Transformer with no parallelism.

![image-3.png](attachment:image-3.png)

## Parallelism schemes

### Data parallelism

Activations sharded along batch, parameters and optimizer state are replicated on each device. Communication only occurs during the backwards pass.

![image-4.png](attachment:image-4.png)

When your model fits on a single chip with even a tiny batch size (>240 tokens, so as to be compute-bound), **you should always use simple data parallelism**.

Pure data parallelism splits our activations across any number of TPUs so long as the number of TPUs is smaller than our batch size. 

 The forward pass involves no communication, but at the end of every step, **each TPU performs an AllReduce on its local gradients to synchronize them before updating the parameters.**

 ![image-5.png](attachment:image-5.png)

 Here’s the full algorithm for the forward and backwards pass. We abuse notation to write dL/dOut as dOut, purely for compactness.

 ![image-6.png](attachment:image-6.png)

Note that the forward pass has no communication — it’s all in the backward pass! The backward pass also has the great property that the AllReduces aren’t in the “critical path”, meaning that each AllReduce can be performed whenever it’s convenient and doesn’t block you from performing subsequent operations. The overall communication cost can still bottleneck us if it exceeds our total compute cost, but it is much more forgiving from an implementation standpoint. We’ll see that model/tensor parallelism doesn’t have this property.

**Why do this?** Pure data parallelism reduces activation memory pressure by splitting our activations over the batch dimension, allowing us to almost arbitrarily increase batch size as long as we have more chips to split the batch dimension over. Especially during training when our activations often dominate our memory usage, this is very helpful.

**Why not do this?** Pure data parallelism does nothing to reduce memory pressure from model parameters or optimizer states, which means pure data parallelism is rarely useful for interesting models at scale where our parameters + optimizer state don’t fit in a single TPU. To give a sense of scale, if we train with parameters in bf16 and optimizer state in fp32 with Adam, the largest model we can fit has about (TPU memory / 10) parameters (2 bytes per bf16 parameter plus 8 bytes of fp32 Adam state), so e.g. on a TPUv5p chip with 96GB of HBM and pure data parallelism this is about 9B parameters.
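As a back-of-the-envelope sketch of that limit, the function below assumes 2 bytes per bf16 parameter and 8 bytes of fp32 Adam state (two moments), and ignores activations and gradients. The function name is hypothetical:

```python
def max_params_pure_dp(hbm_bytes, param_bytes=2, optimizer_bytes=8):
    """Largest parameter count whose weights and Adam state fit on one chip."""
    return hbm_bytes // (param_bytes + optimizer_bytes)


hbm = 96 * 10**9   # TPUv5p HBM in bytes
print(f"{max_params_pure_dp(hbm) / 1e9:.1f}B parameters")
```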

#### When do we become bottlenecked by communication?

![image-8.png](attachment:image-8.png)

![image-9.png](attachment:image-9.png)

Total time:

![image-10.png](attachment:image-10.png)

We become compute-bound when T_math > T_comms, or when

![image-7.png](attachment:image-7.png)

The upshot is that, to remain compute-bound with data parallelism, we need the per-device batch size B / X to exceed the ICI operational intensity, C / W_ici.

For example, for TPUv5p with C=4.6e14 and W_ici=2 * 9e10, with 1D data parallelism our batch size per chip **must be at least 2,550 to avoid being communication-bound**.

Since we can do data parallelism over multiple axes, if we dedicate all three axes of a TPUv5p pod to pure data parallelism, we 3x our bandwidth W_ici and can scale down to only BS=850 per TPU or 7.6M tokens per batch per pod (of 8960 chips). **This tells us that it’s fairly hard to become bottlenecked by pure data parallelism!**
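These thresholds follow directly from C / W_ici. The sketch below uses the TPUv5p numbers quoted above; the function name is hypothetical:

```python
def min_per_device_batch(flops_per_s, w_ici, n_axes=1):
    """Per-device batch size above which data parallelism is compute-bound."""
    return flops_per_s / (w_ici * n_axes)


C = 4.6e14        # FLOPs/s
W_ICI = 2 * 9e10  # bidirectional bytes/s per axis

one_axis = min_per_device_batch(C, W_ICI)
three_axes = min_per_device_batch(C, W_ICI, n_axes=3)
pod_tokens = three_axes * 8960   # chips in a pod

print(f"1 axis: {one_axis:.0f} tokens per chip")
print(f"3 axes: {three_axes:.0f} tokens per chip, {pod_tokens / 1e6:.1f}M per pod")
```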

### Fully-Sharded Data Parallelism (FSDP)

Fully-sharded data parallelism (often called FSDP or ZeRO-sharding because of related work: "ZeRO: Memory optimizations toward training Trillion parameter models") splits the model's optimizer states and weights across the data-parallel shards and efficiently gathers and scatters them as needed. **Compared to pure data parallelism, FSDP drastically reduces per-device memory usage and saves on backward pass FLOPs, with very minimal overhead.**

![image.png](attachment:image.png)

Here’s the full algorithm for FSDP:

![image-2.png](attachment:image-2.png)

This is also called “ZeRO sharding” (ZeRO stands for Zero Redundancy Optimizer), since we don’t perform any unnecessary compute or store any unnecessary state. ZeRO-{1,2,3} are used to refer to sharding the optimizer states, gradients, and weights in this way, respectively. Since all have the same communication cost, we can basically always do ZeRO-3 sharding, which shards the parameters, gradients, and optimizer states across a set of devices.

**Why do this?** Standard data parallelism involves a lot of duplicated work. Each TPU AllReduces the full gradient, then updates the full optimizer state (identical work on all TPUs), then updates the parameters (again, fully duplicated). For ZeRO sharding (sharding the gradients/optimizer state), instead of an AllReduce, you can ReduceScatter the gradients, update only your shard of the optimizer state, update a shard of the parameters, then AllGather the parameters as needed for your forward pass.
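The equivalence of the two update schemes can be sketched on plain lists. To keep it short this uses SGD instead of Adam (an assumption made for brevity) and models only the results of the collectives, not their cost; the function names are illustrative:

```python
def dp_step(params, local_grads, lr):
    """Pure data parallelism: AllReduce the gradients, update everything everywhere."""
    total = [sum(g) for g in zip(*local_grads)]
    return [p - lr * g for p, g in zip(params, total)]


def fsdp_step(params, local_grads, lr):
    """ZeRO-style step: ReduceScatter, update one shard per device, AllGather."""
    n = len(local_grads)
    size = len(params) // n
    new_shards = []
    for d in range(n):
        lo, hi = d * size, (d + 1) * size
        # ReduceScatter: device d receives only its slice of the summed gradient.
        grad_shard = [sum(g[i] for g in local_grads) for i in range(lo, hi)]
        # Device d updates only its own parameter shard.
        new_shards.append([p - lr * g for p, g in zip(params[lo:hi], grad_shard)])
    # AllGather: reassemble the full parameters for the next forward pass.
    return [p for shard in new_shards for p in shard]


params = [1.0, 2.0, 3.0, 4.0]
local_grads = [[0.5, 0.5, 1.0, 1.0], [1.5, 0.5, 0.0, 1.0]]   # two devices
print(dp_step(params, local_grads, 0.5) == fsdp_step(params, local_grads, 0.5))
```

Both schemes produce the same parameters, but in the sharded version each device only ever touches 1/n of the gradient sum and of the update.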

**When do we become bottlenecked by communication?** Our relative FLOPs and comms costs are exactly the same as pure data parallelism, since each AllReduce in the backward pass has become an AllGather + ReduceScatter.

Here we model the forward pass since it has the same FLOPs-to-comms ratio as the backward pass:

![image-3.png](attachment:image-3.png)

Therefore, as with pure data-parallelism, we are compute bound when

![image-4.png](attachment:image-4.png)

 i.e. when the per-device batch size B / X exceeds the “ICI operational intensity” C / W_ici (~ 2550 for TPUv5p).

 This is great for us, because it means if our per-device batch size is big enough to be compute-bound for pure data-parallelism, we can — without worrying about leaving the compute-bound regime — simply upgrade to FSDP, saving ourselves a massive amount of parameter and optimizer state memory.

 Though we did have to add communication to the forward pass, this cost is immaterial since it just overlaps with forward-pass FLOPs.