## Attention performance analysis (ICLR workshop posters)

In [1]:
from dataclasses import dataclass
from typing import Dict, Optional

In [2]:
# References:
# https://www.nvidia.com/en-gb/data-center/products/a10-gpu/
# https://www.techpowerup.com/gpu-specs/a100-sxm4-40-gb.c3506

@dataclass
class XPU:
    """Hardware figures for one accelerator.

    The throughput fields are divisors: a FLOP count divided by
    ``flop_per_sec`` and a byte count divided by ``bytes_per_sec`` each give
    a time in seconds.
    """

    # Human-readable device name, e.g. "A10".
    name: str
    # Memory bandwidth in bytes per second. Annotated as float because the
    # instances in this notebook are written in scientific notation (600e9).
    bytes_per_sec: float
    # Compute throughput in floating-point operations per second.
    flop_per_sec: float
    # Device memory capacity (presumably in bytes, given values such as 24e9).
    capacity: float

# Numbers for fp16
a10 = XPU(
    name="A10",
    bytes_per_sec=600e9,
    flop_per_sec=125e12,
    capacity=24e9,
)
a100 = XPU(
    name="A100",
    bytes_per_sec=1560e9,
    flop_per_sec=312e12,
    capacity=40e9,
)

In [3]:
@dataclass
class P:
    """Performance numbers for a piece of work: a FLOP count and a byte count."""

    # Number of floating-point operations.
    flops: int
    # Amount of memory traffic, in bytes.
    memory_transfer: float

In [4]:
def get_attn_perf_numbers(
    batch_size: int,
    seq_len: int,
    hidden_dim: int,
    kv_group_size: int,
    n_layers: int,
    bytes_per_kv: float,
) -> P:
    """Get FLOP count and memory transfer in bytes for attention during generation.

    Args:
        batch_size: Number of sequences processed together.
        seq_len: Number of cached positions per sequence.
        hidden_dim: Model hidden dimension.
        kv_group_size: Divisor applied to the KV element count (presumably the
            number of query heads sharing one KV head -- confirm with author).
        n_layers: Number of transformer layers.
        bytes_per_kv: Storage size of one cached KV element, in bytes.

    Returns:
        A ``P`` holding the attention FLOP count and the bytes of KV data read.
    """
    # Total cached elements; the factor of 2 presumably covers keys and values.
    # The floor division is applied to the whole product, left to right.
    kv_cache_elements = (
        batch_size * n_layers * 2 * seq_len * hidden_dim
    ) // kv_group_size

    return P(
        # Two FLOPs per cached element for each member of its group.
        flops=2 * kv_cache_elements * kv_group_size,
        # Every cached element is counted as read once.
        memory_transfer=kv_cache_elements * bytes_per_kv,
    )


def get_model_perf_numbers(
    batch_size: int,
    hidden_dim: int,
    n_layers: int,
    vocab_size: int,
    bytes_per_param: float,
) -> P:
    """Get FLOP count and memory transfer in bytes for rest of the model during generation.

    Args:
        batch_size: Number of sequences processed together.
        hidden_dim: Model hidden dimension.
        n_layers: Number of transformer layers.
        vocab_size: Vocabulary size; 0 removes the embedding terms.
        bytes_per_param: Storage size of one parameter, in bytes.

    Returns:
        A ``P`` holding the non-attention FLOP count and the parameter bytes read.
    """
    n_embed_params = hidden_dim * vocab_size
    # 12 * hidden_dim**2 parameters per layer (presumably 4*h^2 for attention
    # projections plus 8*h^2 for the MLP -- confirm with author).
    n_layer_params = n_layers * 12 * hidden_dim**2

    # Multiply + add per param; the embedding-sized term counted here is the
    # output projection.
    total_flops = 2 * batch_size * (n_layer_params + n_embed_params)

    # NOTE(review): the embedding-sized matrix is counted twice for transfer
    # (looks like input embedding + output projection) -- confirm this is the
    # intended model, since FLOPs count it only once.
    total_bytes = (n_layer_params + 2 * n_embed_params) * bytes_per_param

    return P(total_flops, total_bytes)

In [5]:
def get_perf_numbers(
    batch_size: int,
    seq_len: int,
    hidden_dim: int,
    kv_group_size: int,
    n_layers: int = 1,
    bytes_per_param: float = 2,
    bytes_per_kv: float = 2,
    vocab_size: Optional[int] = None,
) -> Dict[str, P]:
    """Get FLOP count and memory transfer in bytes for a single transformer generation step.

    Args:
        batch_size: Number of sequences processed together.
        seq_len: Number of cached positions per sequence.
        hidden_dim: Model hidden dimension.
        kv_group_size: Forwarded to ``get_attn_perf_numbers``.
        n_layers: Number of transformer layers.
        bytes_per_param: Storage size of one parameter, in bytes.
        bytes_per_kv: Storage size of one cached KV element, in bytes.
        vocab_size: Vocabulary size; ``None`` is treated as 0.

    Returns:
        A dict with keys ``"attn"`` and ``"model"``, each mapping to a ``P``.
    """
    # A missing vocabulary size means the embedding terms contribute nothing.
    effective_vocab_size = 0 if vocab_size is None else vocab_size

    return {
        "attn": get_attn_perf_numbers(
            batch_size=batch_size,
            seq_len=seq_len,
            hidden_dim=hidden_dim,
            kv_group_size=kv_group_size,
            n_layers=n_layers,
            bytes_per_kv=bytes_per_kv,
        ),
        "model": get_model_perf_numbers(
            batch_size=batch_size,
            hidden_dim=hidden_dim,
            n_layers=n_layers,
            vocab_size=effective_vocab_size,
            bytes_per_param=bytes_per_param,
        ),
    }

## Generation latency breakdown

For different batch-size and sequence-length setups, show the proportion of single-token generation time spent on the attention operation versus the rest of the model.

In [6]:
# Llama model config
# Architecture
hidden_dim = 4096
n_layers = 32
vocab_size = 32_000
kv_group_size = 1

# Storage size of one parameter / one cached KV element, in bytes
bytes_per_param = 2
bytes_per_kv = 2

# Set platform to calculate latency numbers
platform = a100

In [7]:
# Print a per-setup latency breakdown for the selected platform. The device
# name in the header is taken from `platform` so the label cannot drift from
# the hardware figures actually used in the calculation.
print(
    f"Theoretical latency numbers for single-token generation using Llama-2 on {platform.name} (note: numbers in brackets indicate the percentage of the total time)\n\n"
)
for batch_size, seq_len in [(1, 1024), (1, 128_000)]:
    p = get_perf_numbers(
        batch_size,
        seq_len,
        hidden_dim,
        kv_group_size,
        n_layers,
        bytes_per_param,
        bytes_per_kv,
        vocab_size,
    )
    # Time in seconds: work divided by the platform's throughput.
    attn_time = dict(
        comp=p["attn"].flops / platform.flop_per_sec,
        mem=p["attn"].memory_transfer / platform.bytes_per_sec,
    )
    model_time = dict(
        comp=p["model"].flops / platform.flop_per_sec,
        mem=p["model"].memory_transfer / platform.bytes_per_sec,
    )
    # The total is the plain sum of all four components.
    total_time = sum(attn_time.values()) + sum(model_time.values())
    print(f"Batch size: {batch_size}, Sequence length: {seq_len}:")
    # One row per component, formatted identically.
    for label, component_time in (("Attention", attn_time), ("Model", model_time)):
        print(
            f"\t{label:<10} -- Computation: {component_time['comp']*1000:.3f}ms ({component_time['comp'] / total_time:.1%}), Memory: {component_time['mem']*1000:.3f}ms ({component_time['mem'] / total_time:.1%})"
        )
    print()

Theoretical latency numbers for single-token generation using Llama-2 on A100 (note: numbers in brackets indicate the percentage of the total time)


Batch size: 1, Sequence length: 1024:
	Attention  -- Computation: 0.002ms (0.0%), Memory: 0.344ms (3.8%)
	Model      -- Computation: 0.042ms (0.5%), Memory: 8.596ms (95.7%)

Batch size: 1, Sequence length: 128000:
	Attention  -- Computation: 0.215ms (0.4%), Memory: 43.019ms (82.9%)
	Model      -- Computation: 0.042ms (0.1%), Memory: 8.596ms (16.6%)

