# How much VRAM does my model need for my specific use case?

This notebook estimates the **total VRAM required** to run inference for your specific use case. To get a tailored calculation, you can specify:

* A model name from the Hugging Face Hub.

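* Your Hugging Face access token (only needed for gated models).

* The total number of model parameters, in billions. For mixture-of-experts models, use the total count rather than the active count, because all expert weights must be held in memory.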

* The average input and output length (in tokens) for your workload.

* The number of concurrent users you want to support.

The resulting estimate helps you choose hardware, and it applies to both GPU and TPU selection. For context, here are the memory capacities of some common accelerator configurations:

```
1 L4 GPU = 24 GB
1 H100 GPU = 80 GB
1-chip v5e TPU (ct5lp-hightpu-1t) = 16 GB
8-chip v5e TPU (ct5lp-hightpu-8t) = 128 GB
A g2-standard-96 machine with 8 L4 GPUs = 192 GB
An a3-highgpu-8g machine with 8 H100 GPUs = 640 GB

```
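To match an estimate against the options above, the capacities can be kept in a lookup table and scanned for the smallest one that fits. This is a minimal sketch: `ACCELERATOR_MEMORY_GB` and `smallest_fitting_accelerator` are hypothetical names, the capacities are copied from the list above, and the comparison ignores any headroom a serving framework reserves.

```python
from typing import Optional

# Accelerator memory capacities in GB, copied from the list above (hypothetical table name).
ACCELERATOR_MEMORY_GB = {
    "1 x L4 GPU": 24,
    "1 x H100 GPU": 80,
    "v5e TPU, 1 chip (ct5lp-hightpu-1t)": 16,
    "v5e TPU, 8 chips (ct5lp-hightpu-8t)": 128,
    "g2-standard-96 (8 x L4)": 192,
    "a3-highgpu-8g (8 x H100)": 640,
}


def smallest_fitting_accelerator(required_gb: float) -> Optional[str]:
    """Return the option with the least memory that still holds required_gb, or None."""
    fitting = [
        (capacity, name)
        for name, capacity in ACCELERATOR_MEMORY_GB.items()
        if capacity >= required_gb
    ]
    if not fitting:
        return None
    # min() over (capacity, name) tuples picks the smallest sufficient capacity.
    return min(fitting)[1]


print(smallest_fitting_accelerator(66.64))  # 1 x H100 GPU
print(smallest_fitting_accelerator(100))    # v5e TPU, 8 chips (ct5lp-hightpu-8t)
```

Choosing the tightest fit keeps cost down, but in practice you would compare against a utilization-adjusted capacity rather than the raw number.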

## What the Notebook Calculates
The notebook breaks down the memory requirement into several key components before providing a final, actionable number:

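* Estimated Model Weight: The memory needed to load the model's parameters, computed as the parameter count you enter multiplied by the bytes per parameter for the model's data type (read from `torch_dtype` in the config; 2 bytes for bfloat16).

* Non-PyTorch Memory (Overhead): A fixed 1 GB allowance for memory used outside PyTorch tensors.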

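* Activation Memory: The temporary "scratchpad" memory used during the forward pass, estimated with a heuristic based on sequence length, `hidden_size`, and `intermediate_size`. It is counted once, for a single request.

* KV Cache Memory: The memory required to store the attention keys and values (the KV cache) for each concurrent user. It grows linearly with sequence length (average input plus output tokens) and with the number of users.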

* ✅ Total Required GPU Memory: The final estimated VRAM your configuration needs: the sum of the model weight, overhead, activations, and the KV cache for all concurrent users. No headroom for a memory-utilization limit is applied.
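In symbols, the components above combine as follows. Here $P$ is the parameter count, $S$ is the average input plus average output length in tokens, $H$ is `hidden_size`, $I$ is `intermediate_size`, $L$ is `num_hidden_layers`, $n_{kv}$ is `num_key_value_heads`, $d_h$ is the per-head dimension, $U$ is the number of concurrent users, and $b_{\text{param}}$, $b_{\text{kv}}$ are bytes per value for the data type. The factor 2 in the KV term counts keys and values; the activation term is the notebook's heuristic rather than an exact count, and byte totals are divided by $10^9$ to report GB.

$$
\begin{aligned}
M_{\text{weights}} &= P \cdot b_{\text{param}} \\
M_{\text{act}} &= S \cdot (18H + 4I) \\
M_{\text{KV}} &= 2 \cdot n_{kv} \cdot d_h \cdot L \cdot b_{\text{kv}} \cdot S \\
M_{\text{total}} &= M_{\text{weights}} + M_{\text{overhead}} + M_{\text{act}} + U \cdot M_{\text{KV}}, \qquad M_{\text{overhead}} = 1\ \text{GB}
\end{aligned}
$$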

*Note: Multi-GPU deployments may need additional memory for communication buffers, depending on how the model weights and KV cache are distributed across devices.*
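For a first-order multi-device estimate, one simple assumption is an even tensor-parallel split: weights and KV cache are divided across N devices, while the fixed overhead and the activation term are kept whole on every device (a conservative choice). The helper below is hypothetical, ignores communication buffers entirely, and assumes N divides `num_key_value_heads`; when it does not, some frameworks replicate KV heads and the per-device KV share is larger.

```python
def per_device_memory_gb(weights_gb: float, overhead_gb: float, activation_gb: float,
                         total_kv_cache_gb: float, num_devices: int) -> float:
    """Rough per-device memory under an assumed even tensor-parallel split."""
    if num_devices < 1:
        raise ValueError("num_devices must be >= 1")
    # Assumption: weights and KV cache shard evenly across devices.
    sharded_gb = (weights_gb + total_kv_cache_gb) / num_devices
    # Assumption: overhead and activations are counted in full on each device.
    replicated_gb = overhead_gb + activation_gb
    return sharded_gb + replicated_gb


# Illustrative inputs: 54 GB weights, 1 GB overhead, 0.31 GB activations, 11.3 GB KV cache, 2 devices.
print(round(per_device_memory_gb(54.0, 1.0, 0.31, 11.3, num_devices=2), 2))
```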

In [None]:
# @title Resource requirements for LLM inference
huggingface_api_key = "" # @param {type:"string"}
huggingface_model_name = "google/gemma-3-27b-it" # @param {type:"string"}
model_parameters_in_billions = 27 # @param {type:"number"}
avg_input_length = 1500 # @param {type:"integer"}
avg_output_length = 200 # @param {type:"integer"}
concurrent_users = 10 # @param {type:"integer"}

import os
import requests
import json

# Create a directory to store the config files
os.makedirs("config_files", exist_ok=True)

# Download the model's config.json file.
# Only send an Authorization header when a token is provided (needed for gated models).
config_url = f"https://huggingface.co/{huggingface_model_name}/resolve/main/config.json?download=true"
headers = {"Authorization": f"Bearer {huggingface_api_key}"} if huggingface_api_key else {}
response = requests.get(config_url, headers=headers, timeout=30)
response.raise_for_status()

config_path = os.path.join("config_files", "config.json")
with open(config_path, "w") as f:
    f.write(response.text)

with open(config_path, "r") as f:
    data = json.load(f)

# Check for a nested text_config, common in multimodal models.
if "text_config" in data and isinstance(data["text_config"], dict):
    config_source = data["text_config"]
    print("Using nested 'text_config' for model parameters.")
else:
    config_source = data
    print("Using top-level config for model parameters.")

# Use .get() for safe dictionary access
hidden_size = config_source.get('hidden_size')
num_hidden_layers = config_source.get('num_hidden_layers')
num_attention_heads = config_source.get('num_attention_heads')
intermediate_size = config_source.get('intermediate_size')
num_kv_heads = config_source.get('num_key_value_heads', num_attention_heads)

# Ensure all required parameters were found
required_params = [hidden_size, num_hidden_layers, num_attention_heads, intermediate_size]
if not all(required_params):
    raise ValueError("One or more required model parameters (e.g., hidden_size) could not be found in the config.")

# Prefer an explicit head_dim from the config: some models (e.g. Gemma 3) define a
# head_dim that differs from hidden_size // num_attention_heads.
head_dims = config_source.get('head_dim') or hidden_size // num_attention_heads
dtype = data.get('torch_dtype', 'bfloat16')

# Determine data type size in bytes (unrecognized types fall back to 2 bytes)
dtype_sizes = {'float16': 2, 'bfloat16': 2, 'float32': 4}
parameter_data_type_size = dtype_sizes.get(dtype, 2)
kv_data_type_size = parameter_data_type_size


# @title Calculate Required GPU Memory
print("\n--- Calculating Required GPU Memory ---")

# --- Component 1: Model Weight (from User Input) ---
# The number of parameters is provided directly by the user.
number_of_model_parameters = model_parameters_in_billions * 1e9
model_weight_bytes = number_of_model_parameters * parameter_data_type_size
model_weight_gb = model_weight_bytes / (1000**3)

print(f"1. Model Weight: {model_weight_gb:.2f} GB")
print(f"   (Based on user input of {model_parameters_in_billions}B parameters)")

# --- Component 2: Non-PyTorch Memory ---
# Fixed allowance for memory allocated outside PyTorch tensors
non_torch_memory_gb = 1.0
print(f"2. Non-PyTorch Memory (Overhead): {non_torch_memory_gb:.2f} GB")

# --- Component 3: PyTorch Activation Peak Memory ---
# Heuristic estimate (in bytes) for one request covering the full sequence
sequence_length = avg_input_length + avg_output_length
pytorch_activation_peak_memory_bytes = sequence_length * (18 * hidden_size + 4 * intermediate_size)
pytorch_activation_peak_memory_gb = pytorch_activation_peak_memory_bytes / (1000**3)
print(f"3. PyTorch Activation Peak Memory (per request): {pytorch_activation_peak_memory_gb:.2f} GB")

# --- Component 4: KV Cache Memory ---
# One key and one value vector per KV head, per layer, per token
kv_vectors = 2
kv_cache_memory_per_batch_bytes = (kv_vectors * num_kv_heads * head_dims * num_hidden_layers * kv_data_type_size) * sequence_length
kv_cache_memory_per_batch_gb = kv_cache_memory_per_batch_bytes / (1000**3)
print(f"4. KV Cache Memory (per request): {kv_cache_memory_per_batch_gb:.2f} GB")

print("\n--- Total Memory Calculation ---")

# --- Final Calculation ---
# Activations are counted once; the KV cache is counted once per concurrent user.
static_memory_gb = model_weight_gb + non_torch_memory_gb + pytorch_activation_peak_memory_gb
total_kv_cache_for_all_users_gb = kv_cache_memory_per_batch_gb * concurrent_users
total_unadjusted_memory_gb = static_memory_gb + total_kv_cache_for_all_users_gb
required_gpu_memory_gb = total_unadjusted_memory_gb  # no memory-utilization headroom applied

print(f"Total Memory for {concurrent_users} users (unadjusted): {total_unadjusted_memory_gb:.2f} GB")

print("\n-------------------------------------")
print(f"✅ Required GPU Memory: {required_gpu_memory_gb:.2f} GB")
print("-------------------------------------")
print(f"\nThis is the estimated total GPU VRAM needed to serve {concurrent_users} concurrent users with the specified model and sequence lengths.")

Using nested 'text_config' for model parameters.

--- Calculating Required GPU Memory ---
1. Model Weight: 54.00 GB
   (Based on user input of 27B parameters)
2. Non-PyTorch Memory (Overhead): 1.00 GB
3. PyTorch Activation Peak Memory (per request): 0.31 GB
4. KV Cache Memory (per request): 0.86 GB

--- Total Memory Calculation ---
Total Memory for 10 users (unadjusted): 63.95 GB

-------------------------------------
✅ Required GPU Memory: 63.95 GB
-------------------------------------

This is the estimated total GPU VRAM needed to serve 10 concurrent users with the specified model and sequence lengths.
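The total above has no headroom applied. Serving frameworks usually let only a fraction of device memory be used; vLLM, for example, exposes this as `gpu_memory_utilization` (default 0.9). The sketch below, with hypothetical helper names and illustrative inputs, shows two ways to apply such a limit: scaling the requirement up to the device capacity it implies, and inverting the calculation to see how many full-length requests fit on a given device.

```python
import math


def required_capacity_gb(total_required_gb: float, utilization: float = 0.9) -> float:
    """Device capacity needed if only `utilization` of it may be used (0.9 assumed)."""
    if not 0 < utilization <= 1:
        raise ValueError("utilization must be in (0, 1]")
    return total_required_gb / utilization


def max_concurrent_users(capacity_gb: float, static_gb: float,
                         kv_cache_per_request_gb: float, utilization: float = 0.9) -> int:
    """Number of full-length requests whose KV cache fits beside the static memory."""
    usable_gb = capacity_gb * utilization - static_gb
    if usable_gb <= 0:
        return 0
    return math.floor(usable_gb / kv_cache_per_request_gb)


# Illustrative inputs: 55.31 GB of weights + overhead + activations,
# 1.13 GB of KV cache per request, and a single 80 GB device.
print(round(required_capacity_gb(66.64), 2))
print(max_concurrent_users(80, 55.31, 1.13))
```

Inverting the formula this way is often more useful than the forward estimate when the hardware is already fixed, because the KV cache is the only term that scales with the number of users.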
