# Transformer Theoretical Analysis on Estimation of FLOPs, Parameters, Peak Memory Footprint, and Checkpoint Size

[![Twitter Handle](https://img.shields.io/badge/Twitter-@gaohongnan-blue?style=social&logo=twitter)](https://twitter.com/gaohongnan)
[![LinkedIn Profile](https://img.shields.io/badge/@gaohongnan-blue?style=social&logo=linkedin)](https://linkedin.com/in/gao-hongnan)
[![GitHub Profile](https://img.shields.io/badge/GitHub-gao--hongnan-lightgrey?style=social&logo=github)](https://github.com/gao-hongnan)

```{contents}
:local:
```

This notebook is based on
[Andrej Karpathy's NanoGPT](https://github.com/karpathy/nanoGPT/blob/master/transformer_sizing.ipynb)
sizing notebook, which collects several back-of-the-envelope analyses of a
Transformer: estimates of the number of parameters, FLOPs, peak memory
footprint, checkpoint size, and so on.

In [8]:
from collections import OrderedDict
from dataclasses import dataclass
from enum import Enum, IntEnum
from typing import Literal

import pandas as pd
from rich.pretty import pprint
from tabulate import tabulate
from torch import nn
from transformers import GPT2LMHeadModel

## Configurations, Constants and Enums

In [9]:
@dataclass
class GPTConfig:
    num_decoder_blocks: int = 12
    context_length: int = 1024
    n_embd: int = 768
    ffw_size: int = 3072  # note, this is 4 * n_embd
    n_head: int = 12
    vocab_size: int = 50257
    bias: Literal[False] = False

    def __post_init__(self) -> None:
        assert self.ffw_size == 4 * self.n_embd, "ffw_size must be 4 * n_embd"
        assert self.bias is False, "bias must be False in this experiment."

class GPT2ModelType(Enum):
    GPT2 = "gpt2"
    GPT2_MEDIUM = "gpt2-medium"
    GPT2_LARGE = "gpt2-large"
    GPT2_XL = "gpt2-xl"

class ByteUnits(IntEnum):
    B = 1           # Byte = 1 byte
    KB = 1000       # Kilobyte = 10^3 bytes
    MB = 1000**2    # Megabyte = 10^6 bytes
    GB = 1000**3    # Gigabyte = 10^9 bytes


class FloatingPointPrecision(IntEnum):
    FP32 = 4        # 32-bit floating-point, 4 bytes
    FP16 = 2        # 16-bit floating-point, 2 bytes
    BFLOAT16 = 2    # bfloat16, 16-bit, 2 bytes

class GPUMemory(Enum):
    A100_40GB = 40e9  # 40 GB for NVIDIA A100
    V100_16GB = 16e9  # 16 GB for NVIDIA V100
    V100_32GB = 32e9  # 32 GB for NVIDIA V100
    T4_16GB = 16e9    # 16 GB for NVIDIA T4
    P100_16GB = 16e9  # 16 GB for NVIDIA P100

## Total Trainable Parameters

In [10]:
def total_trainable_parameters(model: nn.Module, include_bias: bool = True) -> int:
    """Returns the number of trainable parameters in the model."""
    if not include_bias:
        return sum(p.numel() for name, p in model.named_parameters() if p.requires_grad and 'bias' not in name)
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

In [11]:
gpt2 = GPT2LMHeadModel.from_pretrained(GPT2ModelType.GPT2.value)

In [12]:
gpt2_params_no_bias = total_trainable_parameters(gpt2, include_bias=False)
gpt2_params_with_bias = total_trainable_parameters(gpt2, include_bias=True)

print(f"Number of trainable parameters in GPT2 model: {gpt2_params_no_bias} (excluding bias) and {gpt2_params_with_bias} (including bias).")

Number of trainable parameters in GPT2 model: 124337664 (excluding bias) and 124439808 (including bias).


Since Karpathy's notebook assumes no bias for simplicity, we also assume that
there is no bias in the linear layers. The count above confirms that the number
of params (`124337664`) for the smallest GPT-2 model matches the number given
by Karpathy.

In what follows, we assume the smallest GPT-2 model and work out the
theoretical model for the Transformer.
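
As a rough sketch of where the parameter count comes from, the total can be written in closed form. The symbols are shorthand introduced here for illustration, not notation from the original notebook: $d$ is `n_embd`, $V$ is `vocab_size`, $T$ is `context_length`, and $L$ is `num_decoder_blocks`. The sketch assumes no biases, a feed-forward width of $4d$, and an output layer that reuses the token embedding weights (so it adds no parameters).

$$
\begin{aligned}
P_{\text{embedding}} &= (V + T)\,d \\
P_{\text{attention}} &= d + 3d^2 + d^2 = 4d^2 + d \\
P_{\text{mlp}} &= d + 4d^2 + 4d^2 = 8d^2 + d \\
P_{\text{block}} &= P_{\text{attention}} + P_{\text{mlp}} = 12d^2 + 2d \\
P_{\text{total}} &= (V + T)\,d + L\,(12d^2 + 2d) + d
\end{aligned}
$$

Plugging in $V = 50257$, $T = 1024$, $d = 768$, $L = 12$ gives $39{,}383{,}808 + 12 \times 7{,}079{,}424 + 768 = 124{,}337{,}664$, which agrees with the count measured from the pretrained model.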

In [13]:
# config_args = {
#     'gpt2':         dict(n_layer=12, n_head=12, n_embd=768),  # 124M params
#     'gpt2-medium':  dict(n_layer=24, n_head=16, n_embd=1024), # 350M params
#     'gpt2-large':   dict(n_layer=36, n_head=20, n_embd=1280), # 774M params
#     'gpt2-xl':      dict(n_layer=48, n_head=25, n_embd=1600), # 1558M params
# }[model_type]

def params(
    num_decoder_blocks: int = 12,
    context_length: int = 1024,
    n_embd: int = 768,
    ffw_size: int = 3072,
    vocab_size: int = 50257,
) -> OrderedDict[str, int]:
    """estimates the number of parameters in the model"""
    out = OrderedDict()

    # token and position embeddings
    out["embedding/position"] = n_embd * context_length
    out["embedding/token"] = n_embd * vocab_size
    out["embedding"] = out["embedding/position"] + out["embedding/token"]

    # attention blocks
    out["attention/ln"] = n_embd  # note, bias=False in our LN
    out["attention/kqv"] = n_embd * 3 * n_embd
    out["attention/proj"] = n_embd**2
    out["attention"] = out["attention/ln"] + out["attention/kqv"] + out["attention/proj"]

    # MLP blocks
    assert ffw_size == 4 * n_embd, "ffw_size must be 4 * n_embd"
    out["mlp/ln"] = n_embd
    out["mlp/ffw"] = n_embd * ffw_size
    out["mlp/proj"] = ffw_size * n_embd
    out["mlp"] = out["mlp/ln"] + out["mlp/ffw"] + out["mlp/proj"]

    # the transformer and the rest of it
    out["block"] = out["attention"] + out["mlp"]
    out["transformer"] = num_decoder_blocks * out["block"]
    out["ln_f"] = n_embd  # final layernorm
    out["dense"] = 0  # 0 because of parameter sharing. This layer uses the weights from the embedding layer

    # total
    out["total"] = out["embedding"] + out["transformer"] + out["ln_f"] + out["dense"]

    return out


In [14]:
params_dict = params()
gpt2_params_no_bias_manual = params_dict["total"]

# Compare to expected PyTorch model parameter count
expected_params = gpt2_params_no_bias
comparison_result = gpt2_params_no_bias_manual == expected_params
comparison_msg = f"We see: {gpt2_params_no_bias_manual}, Expected: {expected_params}, Match: {comparison_result}"

data = {
    "Name": params_dict.keys(),
    "Parameters": params_dict.values(),
    "Ratio (%)": [value / gpt2_params_no_bias_manual * 100 for value in params_dict.values()],
}
df = pd.DataFrame(data)

# Printing comparison result and parameter distribution table
print(comparison_msg + "\n")
print(tabulate(df, headers="keys", tablefmt="pretty", showindex=False, numalign="right", floatfmt=".4f"))


We see: 124337664, Expected: 124337664, Match: True

+--------------------+------------+-----------------------+
|        Name        | Parameters |       Ratio (%)       |
+--------------------+------------+-----------------------+
| embedding/position |   786432   |  0.6324970042866496   |
|  embedding/token   |  38597376  |  31.042384711361475   |
|     embedding      |  39383808  |  31.674881715648123   |
|    attention/ln    |    768     | 0.0006176728557486812 |
|   attention/kqv    |  1769472   |  1.4231182596449616   |
|   attention/proj   |   589824   |  0.47437275321498723  |
|     attention      |  2360064   |  1.8981086857156975   |
|       mlp/ln       |    768     | 0.0006176728557486812 |
|      mlp/ffw       |  2359296   |   1.897491012859949   |
|      mlp/proj      |  2359296   |   1.897491012859949   |
|        mlp         |  4719360   |   3.795599698575646   |
|       block        |  7079424   |   5.693708384291344   |
|    transformer     |  84953088  |   68.324500

## Calculating Checkpoint Size and Fluff Ratio 

The functions below compute the size of a GPT-2 model checkpoint in two ways,
measured and estimated, and compare the two through a "fluff ratio": the
measured size expressed as a percentage of the estimated size. The purpose is
to evaluate how closely the estimated size of a GPT-2 model checkpoint matches
the actual measured size, and to quantify any overhead or additional data in the
checkpoint file (a ratio of exactly 100% would mean no overhead at all).

In [15]:
def calculate_checkpoint_size(params_count: int, precision: FloatingPointPrecision, units: ByteUnits) -> float:
    """
    Calculate the estimated checkpoint size in specified units.

    This function estimates the checkpoint size for a model given the number
    of parameters, the precision of these parameters, and
    the desired units for the result. It accounts for the AdamW optimizer's
    storage requirements by adding two times the parameter bytes to cover
    the optimizer's first- and second-moment estimates.

    Parameters
    ----------
    params_count : int
        The number of parameters excluding biases.
    precision : FloatingPointPrecision
        The floating point precision of the parameters.
    units : ByteUnits
        The units for the resulting checkpoint size.

    Returns
    -------
    float
        The estimated checkpoint size in the specified units.

    Notes
    -----
    The AdamW optimizer stores two extra values per parameter (running
    estimates of the first and second moments of the gradient), hence the
    calculation includes 2 * params_bytes to accommodate these.
    """
    params_bytes = params_count * precision.value
    params_and_buffers_bytes = params_bytes + 2 * params_bytes  # AdamW optimizer buffers
    return params_and_buffers_bytes / units.value


def calculate_fluff_ratio(measured_bytes: int, estimated_bytes: float, units: ByteUnits) -> float:
    """
    Calculate the fluff ratio between measured and estimated checkpoint sizes.

    The fluff ratio is the measured checkpoint size expressed as a percentage
    of the estimated size, so anything above 100% is overhead or additional
    data in the checkpoint file. This function converts the estimated size
    from the specified units to bytes before calculating the ratio to ensure
    consistency in units.

    Parameters
    ----------
    measured_bytes : int
        The actual size of the checkpoint file, in bytes.
    estimated_bytes : float
        The estimated size of the checkpoint file, in the specified units.
    units : ByteUnits
        The units in which the estimated bytes are provided.

    Returns
    -------
    float
        The fluff ratio, expressed as a percentage.
    """
    estimated_bytes_in_bytes = estimated_bytes * units.value
    return (measured_bytes / estimated_bytes_in_bytes) * 100


1. **Measured Checkpoint Size in Bytes**:

    - `gpt2_checkpoint_size_measured_in_bytes` is assigned a numerical value
      that represents the actual size of a GPT-2 model checkpoint file in bytes.
      This value is obtained from the output of the Unix command
      `wc -c ckpt.pt`, which counts the number of bytes in the file `ckpt.pt`.

2. **Estimated Checkpoint Size in Bytes**:

    - The `calculate_checkpoint_size` function is called with the number of
      parameters excluding biases (`gpt2_params_no_bias`), the precision of the
      model's parameters (`FloatingPointPrecision.FP32`), and the unit of
      measurement (`ByteUnits.B` for bytes). This function calculates the
      estimated total size of the checkpoint in bytes, taking into account the
      parameters and the additional storage required for the AdamW optimizer's
      buffers.
    - It is worth noting we are assuming floating-point precision of 32 bits (4
      bytes) for the model's parameters, and hence we are multiplying the number
      of parameters by 4 to obtain the size in bytes.

    - The AdamW optimizer, which is commonly used in training deep learning
      models such as GPT-2, maintains two additional values (buffers) for each
      parameter: a running average of the gradient (the first moment, `m`) and
      a running average of the squared gradient (the second moment, `v`).
      These buffers are used to adapt the step size for each parameter during
      training. This is why the storage requirement triples
      (`params_bytes + 2*params_bytes`), accounting for the original
      parameters plus the two buffers.

3. **Fluff Ratio Calculation**:

    - The `calculate_fluff_ratio` function is called with the measured size in
      bytes, the estimated size in bytes, and the unit of measurement for the
      estimated size (bytes). It returns the measured size as a percentage of
      the estimated size, so anything above 100% is overhead or additional
      data in the checkpoint file beyond the estimate.
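
The three steps can be worked through by hand with the numbers used in this notebook. Here $N = 124{,}337{,}664$ is the no-bias parameter count, and the sketch assumes that both the weights and the two AdamW buffers are stored in FP32 (4 bytes each):

$$
\begin{aligned}
\text{estimated bytes} &= \underbrace{4N}_{\text{weights}} + \underbrace{2 \times 4N}_{\text{AdamW } m,\ v} = 12N = 1{,}492{,}051{,}968 \approx 1.49\ \text{GB} \\
\text{fluff ratio} &= \frac{\text{measured bytes}}{\text{estimated bytes}} \times 100 = \frac{1{,}542{,}470{,}366}{1{,}492{,}051{,}968} \times 100 \approx 103.38\%
\end{aligned}
$$

In other words, the measured file is about 3.4% larger than the estimate.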

In [16]:
gpt2_checkpoint_size_measured_in_bytes = 1542470366  # from 'wc -c ckpt.pt'
gpt2_checkpoint_size_measured_in_gb = gpt2_checkpoint_size_measured_in_bytes / ByteUnits.GB

gpt2_checkpoint_size_estimated_in_bytes = calculate_checkpoint_size(
    params_count=gpt2_params_no_bias,
    precision=FloatingPointPrecision.FP32,
    units=ByteUnits.B,
)
gpt2_checkpoint_size_estimated_in_gb = gpt2_checkpoint_size_estimated_in_bytes / ByteUnits.GB


fluff_ratio = calculate_fluff_ratio(
    measured_bytes=gpt2_checkpoint_size_measured_in_bytes,
    estimated_bytes=gpt2_checkpoint_size_estimated_in_bytes,
    units=ByteUnits.B,
)

data = [
    ["Measured Checkpoint Size (bytes)", gpt2_checkpoint_size_measured_in_bytes],
    ["Measured Checkpoint Size (GB)", gpt2_checkpoint_size_measured_in_gb],
    ["Estimated Checkpoint Size (bytes)", gpt2_checkpoint_size_estimated_in_bytes],
    ["Estimated Checkpoint Size (GB)", gpt2_checkpoint_size_estimated_in_gb],
    ["Fluff Ratio", fluff_ratio]
]

print(tabulate(data, headers=["Metric", "Value"], tablefmt="pretty"))


+-----------------------------------+-------------------+
|              Metric               |       Value       |
+-----------------------------------+-------------------+
| Measured Checkpoint Size (bytes)  |    1542470366     |
|   Measured Checkpoint Size (GB)   |    1.542470366    |
| Estimated Checkpoint Size (bytes) |   1492051968.0    |
|  Estimated Checkpoint Size (GB)   |    1.492051968    |
|            Fluff Ratio            | 103.3791314968461 |
+-----------------------------------+-------------------+


## GPU Memory Footprint of Loading Model and Optimizer

Roughly speaking, the checkpoint size tells us how much memory is needed to
store not just the model itself (its weights) but also the optimizer state
that accompanies it when training on GPUs.

When loading a model from a checkpoint for further training or inference, the
GPU memory must accommodate the model weights and the optimizer state (if
continuing training).

Below, we estimate the ratio of our GPU memory that will be taken up by
the model and optimizer state when loading a GPT-2 model from a checkpoint.

In [17]:
def calculate_memory_ratio(checkpoint_size: float, gpu_memory: GPUMemory) -> str:
    memory_ratio = checkpoint_size / gpu_memory.value * 100
    return f"Memory ratio taken up just for parameters: {memory_ratio:.2f}%"

print(calculate_memory_ratio(checkpoint_size=gpt2_checkpoint_size_estimated_in_bytes, gpu_memory=GPUMemory.A100_40GB))

Memory ratio taken up just for parameters: 3.73%


Assuming an A100 GPU with roughly 40 GB of memory, the code calculates the
percentage of the GPU memory that the estimated checkpoint size (in bytes)
occupies. This gives an insight into how much of the GPU's memory is
dedicated to storing the model's weights and the optimizer's buffers, without
considering other memory usage such as activations during the forward and
backward passes.

This percentage is relatively small, implying that in practice most of the GPU
memory ends up being used for activations. Activations are the intermediate
outputs of layers during the forward pass and their gradients during the
backward pass, which can consume significant amounts of memory, especially in
deep models and with large batch sizes.
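
To get a feel for how the same estimate compares across the cards listed in `GPUMemory`, here is a minimal, self-contained sketch. The dictionary `gpu_capacity_bytes` and its labels are illustrative names introduced here, and the estimate again assumes FP32 weights plus two FP32 AdamW buffers.

```python
# Estimated bytes for GPT-2 (124M, no bias): weights + 2 AdamW buffers, all FP32.
n_params = 124_337_664
bytes_per_param = 4  # FP32
estimated_bytes = 3 * n_params * bytes_per_param

# Nominal capacities in bytes (decimal GB), mirroring the GPUMemory enum.
# The 16 GB entry stands for the V100, T4 and P100 variants alike.
gpu_capacity_bytes = {
    "A100 40GB": 40e9,
    "V100 32GB": 32e9,
    "16GB cards": 16e9,
}

# Percentage of each card's memory taken by weights and optimizer state.
ratios = {
    name: estimated_bytes / capacity * 100
    for name, capacity in gpu_capacity_bytes.items()
}

for name, ratio in ratios.items():
    print(f"{name}: {ratio:.2f}% of memory for weights + optimizer state")
```

Even on a 16 GB card the weights and optimizer state stay under a tenth of the memory, which is consistent with the observation that activations dominate.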

## Estimating FLOPs for a Single Forward Pass

To estimate the FLOPs for a single forward pass, we first need to define what
a FLOP is.

### Basics of Floating Point Numbers

-   **Floating Point Representation**: In computers, numbers can be represented
    in various formats, and one common format is floating point. This format is
    used to represent real numbers (numbers with fractions) using a fixed amount
    of memory, allowing for a wide range of values. A floating point number is
    composed of a sign, an exponent, and a mantissa (or significand). This
    representation can handle very large numbers, very small numbers, and
    fractions.
-   **Operations on Floating Point Numbers**: Operations on floating point
    numbers include addition, subtraction, multiplication, and division. Each of
    these operations takes one or more floating point numbers as input and
    produces a floating point number as output.

### Floating Point Operations (FLOPs)

Floating Point Operations, or FLOPs, refer to individual mathematical operations
(additions, subtractions, multiplications, divisions) performed on
[floating point numbers](https://en.wikipedia.org/wiki/Floating-point_arithmetic).
Each operation counts as one FLOP.

### Counting FLOPs of Matrix Multiplications

Deep learning, particularly in neural networks, relies heavily on matrix
multiplications, so we next look at how to count FLOPs for them. A single
matrix multiplication involves many floating point multiplications and
additions.

Consider two matrices $\mathbf{A}$ and $\mathbf{B}$ of size $m \times n$ and
$n \times p$:

$$
\mathbf{A} = \begin{bmatrix}
a_{11} & a_{12} & \cdots & a_{1n} \\
a_{21} & a_{22} & \cdots & a_{2n} \\
\vdots & \vdots & \ddots & \vdots \\
a_{m1} & a_{m2} & \cdots & a_{mn}
\end{bmatrix}_{m \times n} \quad \mathbf{B} = \begin{bmatrix}
b_{11} & b_{12} & \cdots & b_{1p} \\
b_{21} & b_{22} & \cdots & b_{2p} \\
\vdots & \vdots & \ddots & \vdots \\
b_{n1} & b_{n2} & \cdots & b_{np}
\end{bmatrix}_{n \times p}
$$

It is easy to see that if we want to compute the product
$\mathbf{C} = \mathbf{A} \mathbf{B}$, the element $c_{ij}$ of $\mathbf{C}$ is
given by:

$$
c_{ij} = \sum_{k=1}^{n} a_{ik} b_{kj}
$$

and therefore there are a total of $m \times n \times p$ multiplications and
$m \times (n-1) \times p$ additions. This amounts to roughly:

$$
m \times n \times p + m \times (n-1) \times p \approx 2 \times m \times n \times p
$$

FLOPs. Note this is basically because matrix multiplication is a series of dot
products, and each dot product involves $n$ multiplications and $n-1$ additions.
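
The count above can be checked empirically with a small sketch. The function `matmul_with_flop_count` is a hypothetical helper written for this illustration (it is not part of the original notebook): it performs a naive matrix product on nested lists and tallies every multiplication and addition it executes.

```python
def matmul_with_flop_count(a, b):
    """Naive (m x n) @ (n x p) product that also counts multiplies and adds."""
    m, n, p = len(a), len(b), len(b[0])
    c = [[0.0] * p for _ in range(m)]
    mults = adds = 0
    for i in range(m):
        for j in range(p):
            acc = a[i][0] * b[0][j]  # first product, nothing to add yet
            mults += 1
            for k in range(1, n):
                acc += a[i][k] * b[k][j]  # one multiply and one add
                mults += 1
                adds += 1
            c[i][j] = acc
    return c, mults, adds


m, n, p = 2, 3, 4
a = [[1.0] * n for _ in range(m)]  # m x n matrix of ones
b = [[1.0] * p for _ in range(n)]  # n x p matrix of ones
c, mults, adds = matmul_with_flop_count(a, b)

exact = mults + adds     # m*n*p multiplications + m*(n-1)*p additions
approx = 2 * m * n * p   # the usual 2*m*n*p rule of thumb
print(f"multiplications={mults}, additions={adds}, exact={exact}, approx={approx}")
```

The exact count and the approximation differ by $m \times p$, which becomes negligible relative to the total as the inner dimension $n$ grows.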

In [18]:
def flops(
    num_decoder_blocks: int = 12,
    context_length: int = 1024,
    n_embd: int = 768,
    n_head: int = 12,
    ffw_size: int = 3072,
    vocab_size: int = 50257,
) -> OrderedDict[str, int]:
    # we only count Weight FLOPs, all other layers (LayerNorm, Softmax, etc) are effectively irrelevant
    # we count actual FLOPs, not MACs. Hence 2* all over the place
    # basically for any matrix multiply A (BxC) @ B (CxD) -> (BxD) flops are 2*B*C*D

    out = OrderedDict()
    head_size = n_embd // n_head

    # attention blocks
    # 1) the projection to key, query, values
    out["attention/kqv"] = 2 * context_length * (n_embd * 3 * n_embd)
    # 2) calculating the attention scores
    out["attention/scores"] = 2 * context_length * context_length * n_embd
    # 3) the reduction of the values (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
    out["attention/reduce"] = 2 * n_head * (context_length * context_length * head_size)
    # 4) the final linear projection
    out["attention/proj"] = 2 * context_length * (n_embd * n_embd)
    out["attention"] = sum(out["attention/" + k] for k in ["kqv", "scores", "reduce", "proj"])

    # MLP blocks
    ffw_size = 4 * n_embd  # feed forward size
    out["mlp/ffw1"] = 2 * context_length * (n_embd * ffw_size)
    out["mlp/ffw2"] = 2 * context_length * (ffw_size * n_embd)
    out["mlp"] = out["mlp/ffw1"] + out["mlp/ffw2"]

    # the transformer and the rest of it
    out["block"] = out["attention"] + out["mlp"]
    out["transformer"] = num_decoder_blocks * out["block"]
    out["dense"] = 2 * context_length * (n_embd * vocab_size)

    # forward,backward,total
    out["forward_total"] = out["transformer"] + out["dense"]
    out["backward_total"] = 2 * out["forward_total"]  # use common estimate of bwd = 2*fwd
    out["total"] = out["forward_total"] + out["backward_total"]

    return out


f = flops()
flops_total = f["forward_total"]

table = [("name", "flops", "ratio (%)")]
for k, v in f.items():
    table.append((k, v, v / flops_total * 100))

print(tabulate(table, headers="firstrow", tablefmt="pretty", numalign="right"))


+------------------+--------------+---------------------+
|       name       |    flops     |      ratio (%)      |
+------------------+--------------+---------------------+
|  attention/kqv   |  3623878656  | 1.2425508965889174  |
| attention/scores |  1610612736  | 0.5522448429284077  |
| attention/reduce |  1610612736  | 0.5522448429284077  |
|  attention/proj  |  1207959552  | 0.41418363219630583 |
|    attention     |  8053063680  | 2.7612242146420387  |
|     mlp/ffw1     |  4831838208  | 1.6567345287852233  |
|     mlp/ffw2     |  4831838208  | 1.6567345287852233  |
|       mlp        |  9663676416  | 3.3134690575704466  |
|      block       | 17716740096  |  6.074693272212485  |
|   transformer    | 212600881152 |  72.89631926654981  |
|      dense       | 79047426048  |  27.10368073345018  |
|  forward_total   | 291648307200 |        100.0        |
|  backward_total  | 583296614400 |        200.0        |
|      total       | 874944921600 |        300.0        |
+------------------+--------------+---------------------+

In [7]:
# now here is an estimate copy pasted from the PaLM paper
# this formula is often used to calculate MFU (model flops utilization)
def palm_flops(
    num_decoder_blocks: int = 12,
    context_length: int = 1024,
    n_embd: int = 768,
    n_head: int = 12,
) -> int:
    """estimate of the model flops following PaLM paper formula"""
    # non-embedding model parameters. note that we do not subtract the
    # embedding/token params because those are tied and get used in the last layer.
    p = params(
        num_decoder_blocks=num_decoder_blocks,
        context_length=context_length,
        n_embd=n_embd,
        ffw_size=4 * n_embd,
    )
    N = p["total"] - p["embedding/position"]
    L, H, Q, T = num_decoder_blocks, n_head, n_embd // n_head, context_length
    mf_per_token = 6 * N + 12 * L * H * Q * T
    mf = mf_per_token * context_length
    return mf

print(f"palm_flops: {palm_flops():d}, flops: {flops()['total']:d}, ratio: {palm_flops()/flops()['total']:.4f}")

palm_flops: 875062886400, flops: 874944921600, ratio: 1.0001


OK, they are quite similar, giving some confidence that my math in the `flops()` function was ~ok. Now, the A100 is cited at 312 TFLOPS bfloat16 on tensor cores. So what is our model FLOPs utilization (MFU)? I trained the model above with a batch_size of 20 and grad_accum of 5, which runs in about 755 ms per iteration on a single A100 GPU. We get:

In [8]:
# here is what we currently roughly measure
batch_size = 20 * 5 # 5 is grad_accum, so total batch size is 100
measured_time = 0.755 # in seconds per iteration
measured_throughput = batch_size / measured_time
flops_achieved = f['total'] * measured_throughput

# A100 is cited to be 312 TFLOPS of bfloat16 running on tensor cores
a100_flops_promised = 312e12

# the fraction of the A100 that we are using:
print(f"fraction of A100 used: {flops_achieved / a100_flops_promised * 100:.2f}%")

fraction of A100 used: 37.14%


For reference, we'd prefer to be somewhere around 50%+, and not just for a single GPU but for an entire DDP run. So we still have some work to do, but at least we're within a factor of ~2X of what is achievable with this GPU.

In [9]:
# Finally let's check out the 6ND approximation as total cost of training in FLOPs
model_size = params()['total'] # this is number of parameters, N
tokens_num = 300e9 # 300B tokens, this is dataset size in tokens, D
a100_flops = 312e12 # 312 TFLOPS
assumed_mfu = 0.3 # assume this model flops utilization (take the current 37% from above and add some DDP overhead)
flops_throughput = a100_flops * 8 * assumed_mfu # assume an 8XA100 node at 30% utilization
flops_needed = 6 * model_size * tokens_num # 6ND
time_needed_s = flops_needed / flops_throughput # in seconds
print(f"time needed to train the model: {time_needed_s/3600/24:.2f} days")

time needed to train the model: 3.46 days


This is not a bad estimate at all. I trained this model and it converged in roughly 4 days. Btw as a good reference for where 6ND comes from and some intuition around it I recommend [Dzmitry's post](https://medium.com/@dzmitrybahdanau/the-flops-calculus-of-language-model-training-3b19c1f025e4).

Now, FLOPs are just one constraint, the other that we have to keep a close track of is the memory bandwidth. TODO estimate LOAD/STORE costs of our model later.