*More details in this article: [Estimating Memory Usage for LLMs During Inference (V2)](https://kaitchup.substack.com/p/estimating-memory-usage-for-llms)*

This notebook estimates the memory consumption of transformer models for inference.

This is only an approximation of the total memory consumed by a model that uses optimizations such as KV caching, FlashAttention, and grouped-query attention (GQA).

To get the estimate, run all the cells.

First, if you want to estimate the memory consumption of recent models, make sure you are using the latest version of Hugging Face Transformers.



In [None]:
!pip install --upgrade transformers

Collecting transformers
  Downloading transformers-4.45.1-py3-none-any.whl.metadata (44 kB)
Collecting tokenizers<0.21,>=0.20 (from transformers)
  Downloading tokenizers-0.20.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.7 kB)
Downloading transformers-4.45.1-py3-none-any.whl (9.9 MB)
Downloading tokenizers-0.20.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.9 MB)
Installing collected packages: tokenizers, transformers
  Attempting uninstall: tokenizers
    Found existing installation: tokenizers 0.19.1
    Uninstalling tokenizers-0.19.1:
      Successfully uninstalled tokenizers-0.1


In the following interactive cell, enter the name of the model. It can be the name of the repository on the Hugging Face Hub or a local path.
This cell retrieves the architecture of the model.

In [None]:
from transformers import AutoConfig

model_name = "meta-llama/Llama-3.3-70B-Instruct" # @param {type:"string"}

model_config = AutoConfig.from_pretrained(model_name)

#Architecture values used by the estimates below
hidden_layers = model_config.num_hidden_layers
hidden_size = model_config.hidden_size
attention_heads = model_config.num_attention_heads
#Models without GQA may not define num_key_value_heads; 0 means "no GQA" here
kv_heads = getattr(model_config, "num_key_value_heads", 0) or 0

print("Model: "+str(model_name))
print("Hidden layers (L): "+str(hidden_layers))
print("Hidden size (h): "+str(hidden_size))
print("Attention heads (a): "+str(attention_heads))
if kv_heads > 0:
  print("Key-value heads (g): "+str(kv_heads))

Model: meta-llama/Llama-3.3-70B-Instruct
Hidden layers (L): 80
Hidden size (h): 8192
Attention heads (a): 64
Key-value heads (g): 8
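
The printed values are enough to derive the per-head dimension and how much GQA narrows the key/value projections. The sketch below hard-codes the numbers from the output above and assumes the head dimension equals h / a, which holds for Llama-style models but is not guaranteed for every architecture. The names `head_dim`, `kv_width`, and `kv_ratio` are illustrative and do not appear elsewhere in the notebook.

```python
# Values copied from the output above (Llama-3.3-70B-Instruct)
hidden_size = 8192      # h
attention_heads = 64    # a
kv_heads = 8            # g

# Assumption: each attention head has dimension h / a
head_dim = hidden_size // attention_heads

# Width of the key (or value) vector stored per token and per layer with GQA
kv_width = kv_heads * head_dim

# Fraction of the full-width keys/values that GQA keeps
kv_ratio = kv_heads / attention_heads

print("Head dimension:", head_dim)
print("K/V width with GQA:", kv_width)
print("K/V width ratio (g/a):", kv_ratio)
```

For this model, keys and values are 8 times narrower than the hidden size, which is why GQA lowers both the activation and the KV cache estimates.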


In the following interactive cell, enter:
- nb_billion_parameters: The number of parameters in the model, in billions. For instance, for Llama 3 8B, enter 8.03 since the model has 8.03 billion parameters.
- bitwidth_model: The number of bits per parameter. For instance, 16 if you load the model with float16 or bfloat16.
- seqlen: The maximum sequence length in your batches.
- batch_size: The number of instances in one batch.
- Flash_Attention: Whether to use the FlashAttention estimate for the total.
- Use_KV_Cache: Whether to add the memory consumed by a KV cache.
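
As a quick illustration of how the parameter count and the bitwidth combine, the sketch below computes the weight memory as parameters times bits divided by 8. The helper name `weights_gb` is hypothetical, GB means 10^9 bytes, and the figure ignores any overhead a quantization format may add (scales, zero points).

```python
def weights_gb(nb_billion_parameters: float, bitwidth: int) -> float:
  """Weight memory in GB (10^9 bytes): billions of parameters * bits / 8."""
  return round(nb_billion_parameters * bitwidth / 8, 2)

# Llama 3 8B (8.03 billion parameters) at several precisions
for bits in (16, 8, 4):
  print(f"8.03B at {bits}-bit: {weights_gb(8.03, bits)} GB")
```

Halving the bitwidth halves the weight memory, but it does not change the activation or KV cache estimates computed later in this notebook.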

In [None]:
#Number of parameters in the model (in billions)
nb_billion_parameters = 70.6 # @param {type:"number"}
print("Number of parameters in the model (n): "+str(nb_billion_parameters)+"B")

#Precision of the parameters in the model
bitwidth_model = 16 # @param {type:"integer"}
print("Bitwidth of the model's parameters (p): "+str(bitwidth_model)+"-bit")

#The maximum number of tokens in a sequence
seqlen = 8192 # @param {type:"integer"}
print("Sequence length (s): "+str(seqlen))

#The batch size
batch_size = 16 # @param {type:"integer"}
print("Batch size (b): "+str(batch_size))



#Use FlashAttention
Flash_Attention = True # @param {type:"boolean"}
tile_size = 128
if Flash_Attention:
  print("Use FlashAttention: Yes")
else:
  print("Use FlashAttention: No")

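#Use a KV cache (if yes, its length equals seqlen)
#Named kv_cache_length so it does not share a name with the kv_cache() function defined in the estimation cell
Use_KV_Cache = True # @param {type:"boolean"}
kv_cache_length = 0
if Use_KV_Cache:
  print("Use a KV cache: Yes")
  kv_cache_length = seqlen
else:
  print("Use a KV cache: No")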



Number of parameters in the model (n): 70.6B
Bitwidth of the model's parameters (p): 16-bit
Sequence length (s): 8192
Batch size (b): 16
Use FlashAttention: Yes
Use a KV cache: Yes


Run the following cell to get the estimation given the information provided in the previous cells.
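
For reference, the activation, model size, and plain KV cache estimates computed in the next cell can be written compactly with the symbols printed above (n, p, L, h, a, g, s, b). Here t stands for `tile_size`, a label introduced only for this summary. The leading factor of 2 is read as bytes per 16-bit value, which is an interpretation of the code rather than something the notebook states. All quantities are in bytes and are divided by 10^9 to report GB.

$$
\begin{aligned}
M_{\text{model}} &= \frac{n \cdot 10^{9} \cdot p}{8} \\
M_{\text{vanilla}} &= 2\left(32\,sbh + 4\,a\,s^{2}b\right) \\
M_{\text{GQA}} &= 2\left(28\,sbh + \frac{2g}{a}\,sbh + 4\,g\,s^{2}b\right) \\
M_{\text{FA}} &= 2\left(32\,sbh + 4\,t\,sb\right) \\
M_{\text{KV}} &= 2 \cdot 2\,L\,sbh
\end{aligned}
$$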

In [None]:

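#Activation estimates, in GB (10^9 bytes); the trailing *2 assumes 2 bytes (16 bits) per value
def estimate_consumption_inference():
  #Vanilla attention: the second term grows quadratically with seqlen, once per attention head
  return round((32*seqlen*batch_size*hidden_size + 4*attention_heads*seqlen*seqlen*batch_size)*2/(1000**3),2)
def estimate_consumption_inference_gqa():
  #GQA: the quadratic term and part of the linear term scale with kv_heads instead of attention_heads
  return round((28*seqlen*batch_size*hidden_size + ((2*kv_heads)/attention_heads)*seqlen*batch_size*hidden_size + 4*kv_heads*seqlen*seqlen*batch_size)*2/(1000**3),2)

def estimate_consumption_inference_FA(): #Ignoring GQA for simplicity; will be slightly lower with GQA
  #FlashAttention: the quadratic term is replaced by one that depends on tile_size
  return round((32*seqlen*batch_size*hidden_size + 4*tile_size*seqlen*batch_size)*2/(1000**3),2)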


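def kv_cache():
  #Keys and values (first factor 2) for every layer, 2 bytes per value (last factor 2)
  return round(2*hidden_layers*seqlen*batch_size*hidden_size*2/(1000**3),2)

def kv_cache_gqa():
  #With GQA, keys and values keep kv_heads heads of size hidden_size/attention_heads each
  return round(2*hidden_layers*seqlen*batch_size*(hidden_size*kv_heads/attention_heads)*2/(1000**3),2)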

def estimate_model_size():
  #Billions of parameters * bits per parameter / 8 gives the size directly in GB (10^9 bytes)
  return round(nb_billion_parameters*bitwidth_model/8,2)


activation_consumption_inference = estimate_consumption_inference()
activation_consumption_inference_gqa = estimate_consumption_inference_gqa()
activation_consumption_inference_FA = estimate_consumption_inference_FA()
model_consumption = estimate_model_size()

print("Memory consumption of the model: "+str(model_consumption)+" GB\n")

print("Memory consumption of vanilla inference: "+str(activation_consumption_inference)+" GB \n")
if kv_heads > 0:
  print("Memory consumption of inference with GQA: "+str(activation_consumption_inference_gqa)+" GB \n")

print("Memory consumption of inference with FlashAttention: "+str(activation_consumption_inference_FA)+" GB \n")

if Use_KV_Cache:
  if kv_heads > 0:
    kv_cache_cost = kv_cache_gqa()
    print("Memory consumption of the KV cache (with GQA): "+str(kv_cache_cost)+" GB \n")
  else:
    kv_cache_cost = kv_cache()
    print("Memory consumption of the KV cache: "+str(kv_cache_cost)+" GB \n")
else:
  kv_cache_cost = 0

if Flash_Attention:
  print("Total Memory consumption (given the selected configuration): "+str(round(model_consumption+kv_cache_cost+activation_consumption_inference_FA,2))+" GB\n")
elif kv_heads > 0:
  print("Total Memory consumption (given the selected configuration): "+str(round(model_consumption+kv_cache_cost+activation_consumption_inference_gqa,2))+" GB\n")
else:
  print("Total Memory consumption (given the selected configuration): "+str(round(model_consumption+kv_cache_cost+activation_consumption_inference,2))+" GB\n")


Memory consumption of the model: 141.2 GB

Memory consumption of vanilla inference: 618.48 GB 

Memory consumption of inference with GQA: 129.39 GB 

Memory consumption of inference with FlashAttention: 68.85 GB 

Memory consumption of the KV cache (with GQA): 42.95 GB 

Total Memory consumption (given the selected configuration): 253.0 GB
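
To see where the KV cache figure comes from, the following sketch recomputes it from first principles. It assumes that keys and values each store g heads of dimension h / a per layer and per token. The function name `kv_cache_bytes` and the `bytes_per_value` parameter are illustrative and not part of the notebook, which hard-codes 2 bytes per value.

```python
def kv_cache_bytes(layers, hidden_size, attention_heads, kv_heads,
                   seqlen, batch_size, bytes_per_value=2):
  """Approximate KV cache size in bytes (assumes head_dim = hidden_size / attention_heads)."""
  head_dim = hidden_size // attention_heads
  # Keys and values (factor 2) for every layer, for a single token
  per_token = 2 * layers * kv_heads * head_dim * bytes_per_value
  return per_token * seqlen * batch_size

# Configuration used in this notebook
total = kv_cache_bytes(layers=80, hidden_size=8192, attention_heads=64,
                       kv_heads=8, seqlen=8192, batch_size=16)
# Same call with a single token in a single sequence gives the per-token cost
per_token = kv_cache_bytes(layers=80, hidden_size=8192, attention_heads=64,
                           kv_heads=8, seqlen=1, batch_size=1)

print("KV cache per token:", per_token, "bytes")
print("KV cache total:", round(total / 1000**3, 2), "GB")
```

With these settings the total is 42.95 GB, in line with the figure printed above. Setting `kv_heads` equal to `attention_heads` (no GQA) would make the cache 8 times larger for this model.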

