# GPU-vRAM Usage Estimation for Diffusion Models
## Objective
Derive an analytical equation that estimates peak vRAM usage during inference with the `stable-diffusion-v1-5/stable-diffusion-v1-5` model for arbitrary input image sizes.

## Background
vRAM consumption during diffusion model inference differs significantly from model size on disk. Peak memory depends on:
 - Model weights (fixed)
 - Intermediate activations (vary with image dimensions and prompt length)
 - Framework overhead (CUDA kernels, workspace buffers)
 - Attention memory, which scales as O(N²) in the sequence length N (the number of latent tokens) when the attention-score matrix is materialized
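To make the O(N²) term concrete, here is a minimal sketch (plain Python, no GPU) that computes the size of one self-attention score tensor when it is materialized in FP16. It assumes the same configuration that the estimator `f` later in this notebook uses: 8 heads, a latent grid of H/8 × W/8 tokens, and a batch of 2 under classifier-free guidance. `latent_tokens` and `score_matrix_bytes` are hypothetical helpers, not part of any library.

```python
def latent_tokens(h: int, w: int) -> int:
    """Hypothetical helper: tokens at the highest-resolution UNet level (VAE factor 8)."""
    return (h // 8) * (w // 8)


def score_matrix_bytes(h: int, w: int, batch: int = 2, heads: int = 8,
                       bytes_per_el: int = 2) -> int:
    """Hypothetical helper: bytes for one materialized (batch, heads, N, N) score tensor."""
    n = latent_tokens(h, w)
    return batch * heads * n * n * bytes_per_el


for side in (512, 1024, 2048):
    gib = score_matrix_bytes(side, side) / 2**30
    print(f"{side}x{side}: N={latent_tokens(side, side)}, scores={gib:.2f} GiB")
# 512x512: N=4096, scores=0.50 GiB
# 1024x1024: N=16384, scores=8.00 GiB
# 2048x2048: N=65536, scores=128.00 GiB
```

Doubling each image side multiplies N by 4 and the score tensor by 16, so any estimate that keeps this tensor grows very quickly with resolution.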

The estimate should be expressed in terms of:
 - `H`, `W` = input image height and width (pixels)
 - `prompt_length` = tokenized prompt length
 - any additional factors that you identify as affecting vRAM
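For SD v1.5 specifically, prompt length has little effect on memory: the diffusers pipeline tokenizes prompts with padding to the CLIP tokenizer's maximum length (77) and truncates longer prompts, so the text-embedding tensor always has 77 positions. The sketch below illustrates this under that assumption; `text_embedding_shape` and `text_embedding_bytes` are hypothetical helpers, and 768 is the hidden size of the CLIP text encoder used by SD v1.x.

```python
MAX_TOKENS = 77   # CLIP tokenizer maximum length for SD v1.x (assumed fixed)
HIDDEN = 768      # CLIP text-encoder hidden size for SD v1.x


def text_embedding_shape(prompt_length: int, batch: int = 1, cfg: bool = True) -> tuple:
    """Hypothetical helper: shape of the text embeddings fed to the UNet.

    prompt_length is accepted but ignored: padding/truncation to MAX_TOKENS
    makes the sequence dimension constant. Only batch size and CFG matter.
    """
    effective_batch = batch * (2 if cfg else 1)
    return (effective_batch, MAX_TOKENS, HIDDEN)


def text_embedding_bytes(prompt_length: int, batch: int = 1, cfg: bool = True,
                         bytes_per_el: int = 2) -> int:
    """Hypothetical helper: FP16 size of the text embeddings in bytes."""
    b, t, d = text_embedding_shape(prompt_length, batch, cfg)
    return b * t * d * bytes_per_el


print(text_embedding_bytes(5), text_embedding_bytes(200))  # 236544 236544
```

The factors that do change this tensor are the batch size and whether classifier-free guidance is active (`guidance_scale > 1`), which doubles the batch.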

## Requirements
 - Analyze the architecture: Understand the UNet, the VAE, and the CLIP text encoder, and how tensors flow through the pipeline
 - Account for precision: Assume `FP16` (2 bytes per parameter)
 - Model fully on GPU: Ignore `pipeline.enable_model_cpu_offload()` in your equation
 - Peak, not average: Find the stage with maximum memory allocation
 - Document assumptions: Clearly state what you include/exclude (e.g., gradient storage, optimizer states)
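Under these requirements (FP16, everything resident on the GPU, inference only), the weight term is simply parameter count × 2 bytes; inference stores no gradients and no optimizer states. A minimal sketch using the approximate parameter counts that the estimator `f` hard-codes (860M UNet, 132M VAE, 123M CLIP text encoder); the counts are the notebook's approximations and the dictionary layout is illustrative.

```python
BYTES_PER_PARAM_FP16 = 2

# approximate parameter counts, as hard-coded in the estimator f
PARAMS = {
    "unet": 860_000_000,
    "vae": 132_000_000,
    "text_encoder": 123_000_000,
}


def weight_bytes(params: dict, bytes_per_param: int = BYTES_PER_PARAM_FP16) -> int:
    """Resident weight memory for inference: no gradients, no optimizer state."""
    return sum(params.values()) * bytes_per_param


total = weight_bytes(PARAMS)
print(f"{total:,} bytes = {total / 2**30:.2f} GiB")  # 2,230,000,000 bytes = 2.08 GiB
```

Keeping the counts in a dictionary makes it easy to add any other module the pipeline keeps on the GPU; with the default `from_pretrained` call, that includes the safety checker.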

## Deliverables
 - Equation with explanation of each term
 - Derivation notes showing how you arrived at each component
 - Validation (optional but encouraged): Compare equation predictions against actual measurements from `nvidia-smi` or, as in the provided test code, `torch.cuda.max_memory_allocated()`
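`nvidia-smi` and the provided test code measure different things: `torch.cuda.max_memory_allocated()` counts only tensors allocated through PyTorch's caching allocator, while `nvidia-smi` reports device-wide used memory, which also includes the CUDA context, memory PyTorch has reserved but not handed out, and any other processes. A minimal sketch for reading `nvidia-smi` from Python and scoring an estimate with a signed relative error instead of a ratio; the helper names are hypothetical, and the query uses `nvidia-smi --query-gpu=memory.used --format=csv,noheader,nounits`, which prints one MiB value per GPU.

```python
import subprocess


def parse_memory_used_mib(text: str) -> list[int]:
    """Hypothetical helper: parse one integer MiB value per line of nvidia-smi output."""
    return [int(line.strip()) for line in text.splitlines() if line.strip()]


def query_memory_used_mib() -> list[int]:
    """Hypothetical helper: run nvidia-smi (needs an NVIDIA driver on the machine)."""
    out = subprocess.run(
        ["nvidia-smi", "--query-gpu=memory.used", "--format=csv,noheader,nounits"],
        capture_output=True, text=True, check=True,
    ).stdout
    return parse_memory_used_mib(out)


def relative_error(estimate: float, measured: float) -> float:
    """Signed relative error: > 0 means over-estimate, < 0 means under-estimate."""
    return (estimate - measured) / measured


# GiB values from the run below, used only to illustrate the metric
print(f"{relative_error(3.61, 3.54):+.1%}")  # image 3: +2.0%
print(f"{relative_error(2.29, 2.92):+.1%}")  # image 1: -21.6%
```

A signed error keeps over- and under-estimates distinguishable, which a percentage ratio labelled "accuracy" obscures.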

In [2]:
from google.colab import drive
drive.mount('/content/drive/')

Mounted at /content/drive/


## Your Task
Derive a formula:
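The estimator `f` in the next cell implements the following equation (all sizes in bytes, with s = 2 bytes per FP16 element). The symbols are introduced here for readability: B is the batch size, g the guidance scale, n_h = 8 heads and d_h = 64 the per-head dimension assumed in the code, and P_U, P_V, P_C the hard-coded parameter counts of the UNet, VAE, and CLIP text encoder.

```latex
\begin{aligned}
B_e &= B \cdot \begin{cases} 2 & g > 1 \\ 1 & \text{otherwise} \end{cases},
\qquad h_\ell = \lfloor H/8 \rfloor,\quad w_\ell = \lfloor W/8 \rfloor,\quad N = h_\ell\, w_\ell \\
M_{\text{weights}} &= s\,(P_U + P_V + P_C) \\
M_{\text{UNet}} &= s\,B_e \big[\, 4N + 77 \cdot 768 + 3\,N\,n_h d_h + n_h N^2 + 77\,n_h N + 1280\,N + 2 \cdot 256\,N \,\big] \\
M_{\text{VAE}} &= s\,B \cdot 512\,\lfloor H/2 \rfloor \lfloor W/2 \rfloor \\
M_{\text{peak}} &= M_{\text{weights}} + 1.1 \cdot \max\!\big(M_{\text{UNet}},\, M_{\text{VAE}}\big)
\end{aligned}
```

With n_h d_h = 512, the N-linear UNet terms sum to 3948 N per sample, so the quadratic n_h N² term dominates once N > 3948/8 ≈ 494 latent tokens (roughly a 180×180-pixel image).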

In [6]:
def f(h: int, w: int, prompt_length: int, **kwargs) -> float:
    """
    vram estimation for stable diffusion v1.5 inference (fp16, model fully on gpu)
    returns peak vram in bytes

    h, w          : input image height and width in pixels
    prompt_length : tokenized prompt length; unused, because the sd v1.5 pipeline
                    pads or truncates every prompt to 77 clip tokens
    kwargs        : guidance_scale (default 7.5), batch_size (default 1)
    """
    guidance_scale = kwargs.get('guidance_scale', 7.5)
    batch_size = kwargs.get('batch_size', 1)

    # classifier-free guidance doubles the batch size for conditional and unconditional passes
    effective_batch = batch_size * (2 if guidance_scale > 1.0 else 1)

    # approximate stable diffusion v1.5 parameter counts; the safety checker that the
    # default pipeline also keeps on the gpu is not included
    P_UNET = 860_000_000   # unet for iterative denoising
    P_VAE = 132_000_000    # variational autoencoder (encoder + decoder)
    P_CLIP = 123_000_000   # clip text encoder
    M_weights = (P_UNET + P_VAE + P_CLIP) * 2  # fp16: 2 bytes per param

    # latent space is 8x smaller than pixel space in both dimensions
    latent_h, latent_w = h // 8, w // 8
    # flattened spatial dimensions: tokens seen by self-attention at the highest-resolution unet level
    sequence_length = latent_h * latent_w

    # ========== unet peak memory ==========

    # main latent tensor that gets iteratively denoised (4 channels in latent space)
    latent_tensor = effective_batch * 4 * latent_h * latent_w * 2  # bytes

    # text embeddings from clip encoder (77 tokens, 768 dimensions)
    text_embeddings = effective_batch * 77 * 768 * 2  # bytes

    # attention configuration: sd v1.5 uses 8 heads in every transformer block
    num_heads = 8
    # approximation: the true per-head dim is channels / 8 (40 at the highest-resolution
    # level, 80 and 160 deeper); 64 is used here as a round value
    head_dim = 64

    # query, key, value projections (these exist in memory during attention computation)
    attention_qkv = 3 * effective_batch * sequence_length * num_heads * head_dim * 2  # bytes

    # attention score matrix of naive attention: the o(n²) term. pytorch >= 2.0's
    # scaled_dot_product_attention (the diffusers default) usually avoids materializing
    # it, so this term overestimates at high resolution
    attention_scores = effective_batch * num_heads * sequence_length * sequence_length * 2  # bytes

    # cross-attention scores against the 77 text tokens (linear in sequence_length)
    cross_attention = effective_batch * num_heads * sequence_length * 77 * 2  # bytes

    # conservative stand-in for intermediate feature maps: 1280 is the deepest-level
    # width, applied here at full latent resolution (the deepest level actually runs at
    # 1/8 of that resolution per side; the full-resolution level has 320 channels)
    feature_maps = effective_batch * 1280 * latent_h * latent_w * 2  # bytes

    # skip connections store encoder features for the decoder path
    # rough estimate: 2 levels with an average of 256 channels each
    skip_connections = 2 * effective_batch * 256 * latent_h * latent_w * 2  # bytes

    # total memory during unet denoising (the primary memory bottleneck in this estimate)
    unet_peak = (latent_tensor + text_embeddings + attention_qkv + attention_scores +
                 cross_attention + feature_maps + skip_connections)

    # ========== vae decoder peak memory ==========
    # the vae decoder upsamples the latent to pixel space through intermediate resolutions;
    # it decodes batch_size latents (not doubled by cfg)
    # approximation: one 512-channel tensor at half resolution (h/2 x w/2); the decoder's
    # 256-channel tensor at full resolution is twice as large, so this term underestimates
    vae_peak = batch_size * 512 * (h // 2) * (w // 2) * 2  # bytes

    # ========== total peak memory ==========
    # unet and vae run sequentially, not simultaneously, so take the maximum of their peaks
    activation_peak = max(unet_peak, vae_peak)

    # pytorch memory allocator overhead (assumed 10% of activations)
    overhead = activation_peak * 0.10

    # total = model weights (always resident) + peak activations + framework overhead
    total = M_weights + activation_peak + overhead

    return float(total)
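To see which term drives `f`, here is a self-contained re-computation of its UNet terms for a 512×512 image with classifier-free guidance (effective batch 2). It mirrors the constants in `f` rather than calling it, so it runs on its own; `unet_terms` is a hypothetical helper.

```python
def unet_terms(h: int, w: int, effective_batch: int = 2, heads: int = 8,
               head_dim: int = 64, s: int = 2) -> dict:
    """Hypothetical helper: f's UNet activation terms in bytes, keyed by f's variable names."""
    n = (h // 8) * (w // 8)
    b = effective_batch
    return {
        "latent_tensor": b * 4 * n * s,
        "text_embeddings": b * 77 * 768 * s,
        "attention_qkv": 3 * b * n * heads * head_dim * s,
        "attention_scores": b * heads * n * n * s,
        "cross_attention": b * heads * n * 77 * s,
        "feature_maps": b * 1280 * n * s,
        "skip_connections": 2 * b * 256 * n * s,
    }


terms = unet_terms(512, 512)
total = sum(terms.values())
# print the terms from largest to smallest with their share of the UNet total
for name, size in sorted(terms.items(), key=lambda kv: -kv[1]):
    print(f"{name:>17}: {size / 2**20:8.1f} MiB ({size / total:5.1%})")
```

At 512×512 the score matrix alone accounts for about 89% of the UNet activations in `f`, so the estimate's behaviour at other resolutions is essentially the behaviour of that one term.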

In [7]:
# pip install torch torchvision diffusers['torch'] transformers accelerate

import torch
from diffusers import AutoPipelineForImage2Image
from diffusers.utils import make_image_grid, load_image

pipeline = AutoPipelineForImage2Image.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
)
pipeline = pipeline.to("cuda" if torch.cuda.is_available() else "cpu")

# Uncomment this if you have limited GPU vRAM (this assignment can also be done without a GPU)
# pipeline.enable_model_cpu_offload()

# Uncomment the following line only if xFormers is installed and you are on PyTorch < 2.0;
# with PyTorch >= 2.0, diffusers already uses torch.nn.functional.scaled_dot_product_attention
# pipeline.enable_xformers_memory_efficient_attention()

# helper: convert bytes to GiB (2**30 bytes); printed as "gb" below
def format_memory_gb(bytes_val):
    return bytes_val / (1024**3)

# prepare image
img_src = [{
    "url": "./drive/MyDrive/data/balloon--low-res.jpeg",
    "prompt": "aerial view, colorful hot air balloon, lush green forest canopy, springtime, warm climate, vibrant foliage, soft sunlight, gentle shadow, white birds flying alongside, harmony, freedom, bright natural colors, serene atmosphere, highly detailed, realistic, photorealistic, cinematic lighting"
}, {
    'url': "./drive/MyDrive/data/bench--high-res.jpg",
    'prompt': "photorealistic, high resolution, realistic lighting, natural shadows, detailed textures, lush green grass, wooden bench with grain detail, expansive valley, agricultural fields, blue-toned mountains, fluffy cumulus clouds, wispy cirrus clouds, bright blue sky, clear sunny day, soft sunlight, tranquil atmosphere, cinematic realism"
}, {
    'url': "./drive/MyDrive/data/groceries--low-res.jpg",
    'prompt': "cartoon style, bold outlines, simplified shapes, vibrant colors, playful atmosphere, exaggerated proportions, stylized SUV trunk, whimsical paper grocery bags, fresh produce with bright highlights, baguette with cartoon detail, cheerful parking area, greenery with simplified textures, sunny day, lighthearted mood, 2D illustration, animated landscape aesthetic"
}, {
    'url': "./drive/MyDrive/data/truck--high-res.jpg",
    'prompt': "Michelangelo style, Renaissance painting, classical composition, rich earthy tones, detailed brushwork, divine atmosphere, expressive lighting, monumental presence, artistic grandeur, fresco-inspired texture, high contrast shadows, timeless aesthetic"
}]

results = list()

# This for loop demonstrates that the model's vRAM usage depends on image
# size and prompt length (among other factors). To observe vRAM usage while
# the model is running, execute the following command in a separate terminal
# and monitor the changes:
#    ```shell
#    watch -n 1.0 nvidia-smi
#    ```
#
# You may modify this for loop according to your needs.
for idx, _src in enumerate(img_src, 1):
    init_image = load_image(_src.get('url'))
    prompt = _src.get('prompt')

    # get image dimensions
    w, h = init_image.size

    # estimate vram usage with our formula; prompts are padded/truncated to
    # 77 clip tokens, so 77 is passed as prompt_length
    estimated_vram = f(h, w, 77, guidance_scale=5.0)

    print(f"\n{'='*70}")
    print(f"image {idx}: {_src.get('url')}")
    print(f"resolution: {w}×{h} pixels")
    print(f"estimated vram: {format_memory_gb(estimated_vram):.2f} gb")

    # reset the peak-memory counter so the measurement covers only this image
    if torch.cuda.is_available():
        torch.cuda.reset_peak_memory_stats()
        print(f"processing image...")

    # pass prompt and image to pipeline
    image = pipeline(prompt, image=init_image, guidance_scale=5.0).images[0]

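    # peak memory held by pytorch tensors since the reset, including the resident weights
    # (excludes the cuda context and cached-but-unused blocks, so it reads lower than nvidia-smi)
    if torch.cuda.is_available():
        actual_peak = torch.cuda.max_memory_allocated() / (1024**3)
        print(f"actual peak vram: {actual_peak:.2f} gb")
        # ratio of estimate to measurement: 100% means exact, > 100% means over-estimate
        print(f"estimation accuracy: {(estimated_vram / (1024**3)) / actual_peak * 100:.1f}%")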

    results.append(make_image_grid([init_image, image], rows=1, cols=2))

print(f"\n{'='*70}")
print(f"processed {len(results)} images successfully")
results[0].show()

image 1: ./drive/MyDrive/data/balloon--low-res.jpeg
resolution: 396×380 pixels
estimated vram: 2.29 gb
processing image...
actual peak vram: 2.92 gb
estimation accuracy: 78.4%

image 2: ./drive/MyDrive/data/bench--high-res.jpg
resolution: 2048×2048 pixels
estimated vram: 143.94 gb
processing image...
actual peak vram: 12.09 gb
estimation accuracy: 1190.5%

image 3: ./drive/MyDrive/data/groceries--low-res.jpg
resolution: 800×534 pixels
estimated vram: 3.61 gb
processing image...
actual peak vram: 3.54 gb
estimation accuracy: 102.0%

image 4: ./drive/MyDrive/data/truck--high-res.jpg
resolution: 1800×1200 pixels
estimated vram: 39.96 gb
processing image...
actual peak vram: 9.06 gb
estimation accuracy: 441.2%

processed 4 images successfully


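The estimates for images 2 and 4 are far too high (143.94 GB vs. 12.09 GB measured, and 39.96 GB vs. 9.06 GB). Both are high-resolution images (2048×2048 and 1800×1200), and the quadratic `attention_scores` term explodes: a 2048×2048 input already yields 65,536 latent tokens. The measurements suggest that the pipeline never holds this N×N matrix; with PyTorch 2.x, diffusers computes attention through `torch.nn.functional.scaled_dot_product_attention`, whose fused kernels avoid materializing it. In the other direction, image 1 is underestimated (2.29 GB vs. 2.92 GB), likely in part because `f` omits the safety checker that the pipeline also keeps on the GPU.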
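One way to follow up on this observation is to drop the materialized N×N term and model attention as linear in N, which is what fused attention kernels require. The sketch below is a hypothetical variant of `f`, not a validated model: `live_tensors` (how many same-sized activations coexist at the peak) and `extra_params` (for components `f` omits, such as the safety checker) are assumptions to be fitted against measurements, and the VAE term uses the decoder's 256-channel tensor at full image resolution, produced by its last upsampler.

```python
S = 2  # bytes per FP16 element


def unet_linear_bytes(h, w, effective_batch=2, channels=320, live_tensors=8):
    """Hypothetical: UNet activations at the highest-resolution level with fused attention.

    Without an N x N score matrix, attention needs q, k, v and the output, each of
    shape (effective_batch, N, channels); live_tensors (assumed) covers these plus
    the co-resident feature maps.
    """
    n = (h // 8) * (w // 8)
    return live_tensors * effective_batch * n * channels * S


def vae_decoder_bytes(h, w, batch=1, live_tensors=3):
    """Hypothetical: largest VAE decoder activations, 256 channels at full resolution."""
    return live_tensors * batch * 256 * h * w * S


def f_linear(h, w, batch=1, guidance_scale=7.5, extra_params=0, overhead=0.10):
    """Hypothetical linear-memory estimate of peak vRAM in bytes."""
    effective_batch = batch * (2 if guidance_scale > 1.0 else 1)
    weights = (860_000_000 + 132_000_000 + 123_000_000 + extra_params) * S
    activations = max(unet_linear_bytes(h, w, effective_batch),
                      vae_decoder_bytes(h, w, batch))
    return weights + activations * (1 + overhead)


# same resolutions as the four test images (height, width)
for h, w in [(380, 396), (2048, 2048), (534, 800), (1200, 1800)]:
    print(f"{w}x{h}: {f_linear(h, w, guidance_scale=5.0) / 2**30:.2f} GiB")
```

Under these assumptions the VAE decoder term exceeds the UNet term at every resolution (a 256-channel full-resolution tensor is far larger than a 320-channel latent-resolution one), so the peak moves to the decode stage and grows linearly with H·W. Both `live_tensors` constants and `extra_params` should be calibrated on the measured runs before this variant is trusted.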

## Tips
- No GPU is needed to accomplish this task; it can be done by analyzing the code and architecture
- Use PyTorch documentation and model architecture inspection

## Evaluation Criteria
- Correctness: Formula accounts for major memory consumers
- Completeness: All image-dependent and prompt-dependent factors identified
- Rigor: Derivation shows understanding of PyTorch memory model and diffusion architecture
- Clarity: Equation is readable and well-documented

In [8]:
# Test cases
# Case 1: Standard SD 512x512 image, typical prompt length
h1, w1, prompt_length1 = 512, 512, 77
vram_estimate1 = f(h1, w1, prompt_length1)
print(f"Estimate for H={h1}, W={w1}, prompt_length={prompt_length1}: {vram_estimate1 / (1024**3):.2f} GB")

# Case 2: Larger image 768x768, same prompt length
h2, w2, prompt_length2 = 768, 768, 77
vram_estimate2 = f(h2, w2, prompt_length2)
print(f"Estimate for H={h2}, W={w2}, prompt_length={prompt_length2}: {vram_estimate2 / (1024**3):.2f} GB")

# Case 3: Smaller image 256x256, short prompt length
h3, w3, prompt_length3 = 256, 256, 20
vram_estimate3 = f(h3, w3, prompt_length3)
print(f"Estimate for H={h3}, W={w3}, prompt_length={prompt_length3}: {vram_estimate3 / (1024**3):.2f} GB")

# Case 4: Extreme image size 1024x1024
h4, w4, prompt_length4 = 1024, 1024, 77
vram_estimate4 = f(h4, w4, prompt_length4)
print(f"Estimate for H={h4}, W={w4}, prompt_length={prompt_length4}: {vram_estimate4 / (1024**3):.2f} GB")

Estimate for H=512, W=512, prompt_length=77: 2.69 GB
Estimate for H=768, W=768, prompt_length=77: 5.01 GB
Estimate for H=256, W=256, prompt_length=20: 2.13 GB
Estimate for H=1024, W=1024, prompt_length=77: 11.14 GB
