# LLM Inference Optimization

Basic inference is slow because an LLM has to be called repeatedly, once for every new token it generates. The input sequence grows as generation progresses, so each step takes longer for the LLM to process. LLMs also have billions of parameters, which makes it challenging to store and handle all of those weights in memory.

Hugging Face also provides **Text Generation Inference (TGI)**, a toolkit dedicated to deploying and serving highly optimized LLMs for inference.

## Static kv-cache and torch.compile

During decoding, an LLM computes the key-value (kv) pairs for each input token. Because generation is autoregressive, every generated token is appended to the input, so without caching the model recomputes the same kv values for all previous tokens at every step.

To optimize this process, we can use a kv-cache that stores the past keys and values instead of recomputing them at every step. However, because the kv-cache grows with each generation step, its shape is dynamic, which prevents us from taking advantage of `torch.compile`.
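To make the saving concrete, here is a toy, framework-free sketch; it is not how Transformers implements the cache. `kv` and `next_token` are hypothetical stand-ins for the key/value projections and for the rest of the forward pass. Both loops produce the same tokens, but only the cached loop avoids recomputing keys and values for the whole prefix at every step.

```python
def kv(token):
    # Hypothetical stand-in for the key/value projections of a single token.
    return (token * 2, token * 3)


def next_token(kvs):
    # Hypothetical stand-in for the rest of the forward pass: any deterministic
    # function of all keys/values seen so far.
    return sum(k + v for k, v in kvs) % 50


def generate_without_cache(prompt, steps):
    seq, kv_computations = list(prompt), 0
    for _ in range(steps):
        kvs = [kv(t) for t in seq]  # recompute keys/values for the whole prefix
        kv_computations += len(kvs)
        seq.append(next_token(kvs))
    return seq, kv_computations


def generate_with_cache(prompt, steps):
    seq, cache, kv_computations = list(prompt), [], 0
    for _ in range(steps):
        for t in seq[len(cache):]:  # only tokens that are not cached yet get computed
            cache.append(kv(t))
            kv_computations += 1
        seq.append(next_token(cache))
    return seq, kv_computations


print(generate_without_cache([1, 2, 3], 5)[1])  # 25 key/value computations
print(generate_with_cache([1, 2, 3], 5)[1])     # 7 key/value computations
```

The cached loop does linear rather than quadratic key/value work in the number of generated tokens, but notice that `cache` keeps growing by one entry per step: that changing size is exactly what the static kv-cache is designed to remove.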

The *static kv-cache* solves this issue by pre-allocating the kv-cache to a maximum size, which keeps its shape fixed and allows us to combine it with `torch.compile` for up to a 4x speed-up.
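What pre-allocation means can be sketched in plain Python under simplifying assumptions: one scalar key and value per position, and a hypothetical `ToyStaticCache` class that is not `transformers.StaticCache`. Every slot exists from the start, new entries are written in place at a given cache position, and `reset()` clears the contents without changing the allocation.

```python
class ToyStaticCache:
    """Hypothetical fixed-size kv buffer (one scalar key/value per position)."""

    def __init__(self, max_cache_len):
        self.max_cache_len = max_cache_len
        # Allocate every slot up front so the buffer's size never changes.
        self.keys = [0] * max_cache_len
        self.values = [0] * max_cache_len
        self.length = 0  # how many positions are currently filled

    def update(self, cache_position, key, value):
        if cache_position >= self.max_cache_len:
            raise IndexError("cache is full: allocate a larger max_cache_len")
        # Write in place at the given position instead of appending.
        self.keys[cache_position] = key
        self.values[cache_position] = value
        self.length = max(self.length, cache_position + 1)

    def reset(self):
        # Clear the contents but keep the allocation so the cache can be reused.
        self.keys[:] = [0] * self.max_cache_len
        self.values[:] = [0] * self.max_cache_len
        self.length = 0


cache = ToyStaticCache(max_cache_len=8)
for pos, token in enumerate([11, 12, 13]):
    cache.update(pos, key=token * 2, value=token * 3)
print(len(cache.keys), cache.length)  # 8 3
```

Because the buffer always has `max_cache_len` slots, code that consumes it sees the same shapes at every step, which is what lets a compiler specialize once. The trade-off is memory reserved for slots that may never be filled, and the need to choose `max_cache_len` large enough for the prompt plus all new tokens.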

There are three flavors of static kv-cache usage:
1. Basic usage
2. Advanced usage: control Static Cache
3. Advanced usage: end-to-end generate compilation

##### basic usage: generation_config

1. Access the model's `generation_config` attribute and set the `cache_implementation` to `"static"`;
2. Call `torch.compile` on the model to compile the forward pass with the static kv-cache.

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import os

os.environ['TOKENIZERS_PARALLELISM'] = 'false'

model_name = 'google/gemma-2b'

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype='auto',
    device_map='auto'
)

In [None]:
model.generation_config.cache_implementation = 'static'
model.forward = torch.compile(
    model.forward,
    mode='reduce-overhead',
    fullgraph=True
)

In [None]:
input_text = "The theory of special relativity states "
input_ids = tokenizer(input_text, return_tensors="pt").to(model.device.type)
outputs = model.generate(**input_ids)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))

##### advanced usage: control Static Cache

A `StaticCache` object can be passed to the model's `generate()` under the `past_key_values` argument.

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM, StaticCache
import torch
import os

os.environ['TOKENIZERS_PARALLELISM'] = 'false'

model_name = 'google/gemma-2b'

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype='auto',
    device_map='auto'
)

In [None]:
model.forward = torch.compile(
    model.forward,
    mode='reduce-overhead',
    fullgraph=True
)

input_text = "The theory of special relativity states "
input_ids = tokenizer(input_text, return_tensors="pt").to(model.device.type)
prompt_length = input_ids.input_ids.shape[1]
model.generation_config.max_new_tokens = 16

past_key_values = StaticCache(
    config=model.config,
    batch_size=1,
    # If the cache is re-used, make sure the cache length is large enough for all cases
    max_cache_len=prompt_length + model.generation_config.max_new_tokens*2,
    device=model.device,
    dtype=model.dtype
)

outputs = model.generate(
    **input_ids,
    past_key_values=past_key_values
)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))

# pass the generated text and the same cache object to continue generation from
# where it left off.
# Optionally, in a multi-turn conversation, append the new user input to the generated text.
new_input_ids = outputs
outputs = model.generate(
    new_input_ids,
    past_key_values=past_key_values
)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))

To reuse the same `StaticCache` object on a new prompt, reset its contents with the `.reset()` method between calls.

The `StaticCache` object can also be passed to the model's forward pass under the same `past_key_values` argument. This lets us write our own function that decodes the next token, given the current token, its position ids, and the cache position at which its keys and values are stored.
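The position bookkeeping in such a loop is easy to get off by one. As a plain-Python sketch (the helper `plan_positions` is hypothetical), assume the prompt fills cache positions `0` to `seq_length - 1` during prefill; each later call then feeds the newest token at the next free cache position and writes its prediction one index further along in `generated_ids`.

```python
def plan_positions(seq_length, num_new_tokens):
    """Hypothetical helper: for each model call, return
    (cache positions written by that call, index in generated_ids that receives its prediction)."""
    # Prefill: the whole prompt fills cache positions 0..seq_length-1 and
    # predicts the token at index seq_length.
    steps = [(list(range(seq_length)), seq_length)]
    # Decode: the token at index `pos` is fed at cache position `pos`,
    # and the model predicts the token at index `pos + 1`.
    for pos in range(seq_length, seq_length + num_new_tokens - 1):
        steps.append(([pos], pos + 1))
    return steps


for written, target_index in plan_positions(seq_length=4, num_new_tokens=3):
    print(written, "->", target_index)
# [0, 1, 2, 3] -> 4
# [4] -> 5
# [5] -> 6
```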

In [None]:
from transformers import LlamaTokenizer, LlamaForCausalLM, StaticCache
import torch
from accelerate.test_utils.testing import get_backend

prompts = [
    "Simply put, the theory of relativity states that ",
    "My favorite all time favorite condiment is ketchup.",
]

NUM_TOKENS_TO_GENERATE = 40
torch_device, _, _ = get_backend()

In [None]:
model_name = 'meta-llama/Llama-2-7b-hf'
tokenizer = LlamaTokenizer.from_pretrained(
    model_name,
    pad_token='</s>',
    padding_side='right'
)
model = LlamaForCausalLM.from_pretrained(
    model_name,
    device_map='sequential'
)

In [None]:
inputs = tokenizer(prompts, return_tensors='pt', padding=True).to(model.device)

def decode_one_token(model, cur_token, input_pos, cache_position, past_key_values):
    logits = model(
        cur_token,
        position_ids=input_pos,
        cache_position=cache_position,
        past_key_values=past_key_values,
        return_dict=False,
        use_cache=True
    )[0]

    new_token = torch.argmax(logits[:, -1], dim=-1)[:, None]
    return new_token

To enable static kv-cache and `torch.compile` with the `StaticCache` method, we must:
1. Initialize the `StaticCache` instance before using the model for inference.
2. Call `torch.compile` on the decoding function (here `decode_one_token`) to compile the forward pass with the static kv-cache.
3. Use `SDPBackend.MATH` in the `torch.nn.attention.sdpa_kernel` context manager to enable the native PyTorch C++ implementation of scaled dot product attention and speed up inference even more.

In [None]:
from torch.nn.attention import SDPBackend, sdpa_kernel

batch_size, seq_length = inputs['input_ids'].shape

with torch.no_grad():
    past_key_values = StaticCache(
        config=model.config,
        batch_size=batch_size,
        max_cache_len=4096,
        device=torch_device,
        dtype=model.dtype
    )
    cache_position = torch.arange(seq_length, device=torch_device)
    generated_ids = torch.zeros(
        batch_size,
        seq_length + NUM_TOKENS_TO_GENERATE + 1,
        dtype=torch.int,
        device=torch_device
    )
    generated_ids[:, cache_position] = inputs['input_ids'].to(torch_device).to(torch.int)

    logits = model(
        **inputs,
        cache_position=cache_position,
        past_key_values=past_key_values,
        return_dict=False,
        use_cache=True
    )[0]
    next_token = torch.argmax(logits[:, -1], dim=-1)[:, None]
    generated_ids[:, seq_length] = next_token[:, 0]

    decode_one_token = torch.compile(
        decode_one_token,
        mode='reduce-overhead',
        fullgraph=True
    )
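    # after prefill the cache holds positions 0..seq_length-1; the token just
    # generated sits at index seq_length, so it is fed at that cache position
    cache_position = torch.tensor([seq_length], device=torch_device)

    for _ in range(1, NUM_TOKENS_TO_GENERATE):
        with sdpa_kernel(SDPBackend.MATH):
            next_token = decode_one_token(
                model,
                next_token.clone(),
                None,
                cache_position,
                past_key_values
            )
            # the prediction belongs at the following index
            generated_ids[:, cache_position + 1] = next_token.int()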
        cache_position += 1

text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
print(text)

##### advanced usage: end-to-end generate compilation

To compile the entire `generate` function end-to-end, call `torch.compile` on `model.generate` itself instead of on the forward pass.

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import os

os.environ['TOKENIZERS_PARALLELISM'] = 'false'

model_name = 'google/gemma-2b'

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype='auto',
    device_map='auto'
)

In [None]:
model.generate = torch.compile(
    model.generate,
    mode='reduce-overhead',
    fullgraph=True
)

input_text = "The theory of special relativity states "
input_ids = tokenizer(input_text, return_tensors="pt").to(model.device.type)

outputs = model.generate(**input_ids)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))

## Speculative decoding

An autoregressive model generates one token per forward pass, and every forward pass has to load all of the model weights from memory, so generating many tokens means loading the weights many times.

Speculative decoding alleviates this slowdown by using a second smaller and faster assistant model to generate candidate tokens that are verified by the larger LLM in a single forward pass. If the verified tokens are correct, the LLM essentially gets them for "free" without having to generate them itself. There is no degradation in accuracy because the verification forward pass ensures the same outputs are generated as if the LLM had generated them on its own.
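The accept/reject logic for greedy search can be sketched in plain Python. `target` and `draft` are hypothetical toy functions that map a token list to the next greedy token; this illustrates the technique, not Transformers' assisted-generation internals. Every token the loop keeps is the target's own greedy choice, so the output matches plain greedy decoding, and a good draft only reduces how many verification rounds are needed.

```python
def greedy_generate(target, prefix, n):
    seq = list(prefix)
    for _ in range(n):
        seq.append(target(seq))
    return seq


def speculative_generate(target, draft, prefix, n, k=4):
    seq, goal, rounds = list(prefix), len(prefix) + n, 0
    while len(seq) < goal:
        rounds += 1
        # 1. The cheap draft model proposes k tokens autoregressively.
        candidates = []
        for _ in range(k):
            candidates.append(draft(seq + candidates))
        # 2. The target checks the candidates. In a real LLM all of these target
        #    predictions come from ONE forward pass over seq + candidates.
        accepted = []
        for tok in candidates:
            expected = target(seq + accepted)
            accepted.append(expected)  # always keep the target's own choice
            if tok != expected:
                break  # first mismatch: discard the remaining candidates
        else:
            accepted.append(target(seq + accepted))  # all matched: one bonus token
        seq.extend(accepted)
    return seq[:goal], rounds


def target(seq):
    # Hypothetical "large" model: deterministic greedy next token.
    return (3 * seq[-1] + len(seq)) % 10


def draft(seq):
    # Hypothetical "small" model: agrees with the target most of the time.
    return target(seq) if len(seq) % 4 else (seq[-1] + 1) % 10


tokens, rounds = speculative_generate(target, draft, [1], n=20)
assert tokens == greedy_generate(target, [1], 20)
print(rounds, "verification rounds for 20 tokens")
```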

##### greedy search

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from accelerate.test_utils.testing import get_backend

device, _, _ = get_backend()

model_name = 'facebook/opt-1.3b'
assistant_model_name = 'facebook/opt-125m'

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype='auto'
).to(device)
assistant_model = AutoModelForCausalLM.from_pretrained(
    assistant_model_name,
).to(device)

In [None]:
inputs = tokenizer("Einstein's theory of relativity states", return_tensors='pt').to(device)

outputs = model.generate(
    **inputs,
    assistant_model=assistant_model
)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))

##### sampling
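With sampling, exact token matches are too strict. Speculative sampling, as described in the speculative decoding literature, instead accepts a draft token $x$ with probability $\min(1, p(x)/q(x))$, where $p$ and $q$ are the target and draft distributions, and on rejection resamples from the normalized residual $\max(0, p - q)$. The plain-Python sketch below uses made-up distributions and hypothetical helper names to compute the exact output distribution of one such step and show that it equals $p$; Transformers' implementation may differ in its details.

```python
def acceptance_probability(p, q, x):
    # Keep draft token x (sampled from q) with probability min(1, p(x) / q(x)).
    return min(1.0, p[x] / q[x])


def residual_distribution(p, q):
    # After a rejection, resample from max(0, p - q), renormalized.
    diff = {x: max(0.0, p[x] - q[x]) for x in p}
    total = sum(diff.values())
    return {x: d / total for x, d in diff.items()}


def output_distribution(p, q):
    # Exact probability of emitting each token: accepted draft + resampled after rejection.
    accept = {x: q[x] * acceptance_probability(p, q, x) if q[x] > 0 else 0.0 for x in q}
    reject_mass = 1.0 - sum(accept.values())
    if reject_mass <= 1e-12:
        return accept  # draft and target agree, nothing is ever rejected
    residual = residual_distribution(p, q)
    return {x: accept[x] + reject_mass * residual[x] for x in p}


p = {"a": 0.5, "b": 0.3, "c": 0.2}  # target model probabilities (made up)
q = {"a": 0.2, "b": 0.6, "c": 0.2}  # draft model probabilities (made up)
print(output_distribution(p, q))     # matches p up to floating-point rounding
```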

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from accelerate.test_utils.testing import get_backend

device, _, _ = get_backend()

model_name = 'facebook/opt-1.3b'
assistant_model_name = 'facebook/opt-125m'

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype='auto'
).to(device)
assistant_model = AutoModelForCausalLM.from_pretrained(
    assistant_model_name,
).to(device)

In [None]:
inputs = tokenizer("Einstein's theory of relativity states", return_tensors='pt').to(device)

outputs = model.generate(
    **inputs,
    assistant_model=assistant_model,
    do_sample=True,
    temperature=0.7
)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))

### Prompt lookup decoding

**Prompt lookup decoding** is a variant of speculative decoding that is also compatible with greedy search and sampling.

Prompt lookup works especially well for input-grounded tasks, such as summarization, where the prompt and the output often share overlapping words. These overlapping n-grams are used as the candidate tokens that the LLM then verifies.
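The candidate search can be sketched in plain Python. The helper `prompt_lookup_candidates` is hypothetical: it looks for the most recent earlier occurrence of the sequence's trailing n-gram (trying longer n-grams first) and proposes the tokens that followed it, with `num_pred_tokens` playing the role of `prompt_lookup_num_tokens`. The matching details in Transformers may differ, and the LLM still verifies every candidate exactly as in speculative decoding.

```python
def prompt_lookup_candidates(tokens, max_ngram_size=3, num_pred_tokens=3):
    """Hypothetical helper: propose candidates by matching the trailing n-gram earlier in the sequence."""
    for n in range(max_ngram_size, 0, -1):  # prefer the longest matching n-gram
        if len(tokens) <= n:
            continue
        ngram = tokens[-n:]
        # Scan earlier start positions (excluding the trailing n-gram itself), most recent first.
        for start in range(len(tokens) - n - 1, -1, -1):
            if tokens[start:start + n] == ngram:
                follow = tokens[start + n:start + n + num_pred_tokens]
                if follow:
                    return follow
    return []  # no match: fall back to normal decoding for this step


tokens = "the cat sat on the mat . the cat".split()
print(prompt_lookup_candidates(tokens))  # ['sat', 'on', 'the']
```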

##### greedy decoding

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from accelerate.test_utils.testing import get_backend

device, _, _ = get_backend()

model_name = 'facebook/opt-1.3b'
# prompt lookup draws its candidates from the prompt itself, so no assistant model is needed

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype='auto'
).to(device)

In [None]:
inputs = tokenizer("The second law of thermodynamics states", return_tensors="pt").to(device)
outputs = model.generate(
    **inputs,
    prompt_lookup_num_tokens=3 # add here
)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))

##### sampling

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from accelerate.test_utils.testing import get_backend

device, _, _ = get_backend()

model_name = 'facebook/opt-1.3b'
# prompt lookup draws its candidates from the prompt itself, so no assistant model is needed

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype='auto'
).to(device)

In [None]:
inputs = tokenizer("The second law of thermodynamics states", return_tensors="pt").to(device)
outputs = model.generate(
    **inputs,
    prompt_lookup_num_tokens=3,
    do_sample=True, # add here
    temperature=0.7 # and here
)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))

## Attention optimization

The self-attention mechanism in the transformer model grows quadratically in compute and memory with the number of input tokens. This limitation is only magnified in LLMs, which handle much longer sequences.
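A quick back-of-the-envelope calculation shows how fast this grows. The settings are assumptions chosen for round numbers: 32 attention heads, 2-byte (fp16/bf16) elements, batch size 1, a single layer, and the full score matrix materialized in memory.

```python
def attention_score_bytes(seq_len, num_heads=32, bytes_per_element=2):
    # One seq_len x seq_len score matrix per head, for a single layer and sequence.
    return seq_len * seq_len * num_heads * bytes_per_element


for n in (1024, 2048, 4096):
    print(n, attention_score_bytes(n) / 2**30, "GiB")
# 1024 0.0625 GiB
# 2048 0.25 GiB
# 4096 1.0 GiB
```

Each doubling of the sequence length quadruples the memory needed for the scores.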

### FlashAttention-2

**FlashAttention** and **FlashAttention-2** break up the attention computation into smaller chunks and reduce the number of intermediate read/write operations to GPU memory to speed up inference.

FlashAttention-2 improves on the original FlashAttention algorithm by also parallelizing over the sequence length dimension and by better partitioning work on the hardware to reduce synchronization and communication overhead.
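The chunking idea rests on the online softmax: attention over a long sequence can be computed block by block while keeping only a running maximum, a running sum of exponentials, and a running weighted output, so the full row of scores never has to be stored. The plain-Python sketch below shows this for a single query vector and checks it against the naive computation. It illustrates the math only; the real kernels are fused GPU code that tile queries, keys, and values in on-chip memory.

```python
import math


def naive_attention(q, keys, values):
    # softmax(q . k / sqrt(d)) applied to the values, with all scores materialized at once.
    d = len(q)
    scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    return [sum(w * v[j] for w, v in zip(weights, values)) / total for j in range(len(values[0]))]


def tiled_attention(q, keys, values, block_size=2):
    # Process keys/values block by block, keeping only running statistics.
    d = len(q)
    running_max, running_sum = -math.inf, 0.0
    acc = [0.0] * len(values[0])
    for start in range(0, len(keys), block_size):
        k_blk, v_blk = keys[start:start + block_size], values[start:start + block_size]
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in k_blk]
        new_max = max(running_max, max(scores))
        # Rescale what was accumulated so far to the new maximum (exp(-inf) == 0.0 on the first block).
        scale = math.exp(running_max - new_max)
        running_sum *= scale
        acc = [a * scale for a in acc]
        for s, v in zip(scores, v_blk):
            w = math.exp(s - new_max)
            running_sum += w
            acc = [a + w * vj for a, vj in zip(acc, v)]
        running_max = new_max
    return [a / running_sum for a in acc]


q = [1.0, 0.5]
keys = [[1, 0], [0, 1], [1, 1], [-1, 2], [0.5, 0.5]]
values = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]]
print(naive_attention(q, keys, values))
print(tiled_attention(q, keys, values, block_size=2))  # same result, computed in chunks
```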

In [None]:
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

quant_config = BitsAndBytesConfig(load_in_8bit=True)
model = AutoModelForCausalLM.from_pretrained(
    'google/gemma-2b',
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
    attn_implementation='flash_attention_2' # add here
)

### Fine-tuning with `torch.compile` and padding-free data collation

We can enhance the training efficiency of large language models by leveraging `torch.compile` during fine-tuning and using a padding-free data collator.
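To see where the savings come from, compare a conventionally padded batch with a padding-free one. The plain-Python sketch below uses hypothetical helpers (`pad_batch`, `padding_free_batch`) to illustrate the idea only: instead of padding every example to the longest one, the examples are concatenated into a single row, and `position_ids` restart at 0 for each example. As we understand it, this restart is what lets FlashAttention-2 (requested through `attn_implementation='flash_attention_2'` when the model is loaded) keep packed examples from attending to each other. TRL's actual collator also handles the labels (for example, masking the prompt part with `ignore_index`), so treat this as a simplification.

```python
def pad_batch(sequences, pad_id=0):
    # Conventional collation: pad every sequence to the longest one in the batch.
    longest = max(len(s) for s in sequences)
    return [s + [pad_id] * (longest - len(s)) for s in sequences]


def padding_free_batch(sequences):
    # Padding-free collation: one concatenated row, positions restart per sequence.
    input_ids, position_ids = [], []
    for seq in sequences:
        input_ids.extend(seq)
        position_ids.extend(range(len(seq)))
    return {"input_ids": [input_ids], "position_ids": [position_ids]}


batch = [[5, 6, 7, 8], [9], [10, 11]]
print(sum(len(row) for row in pad_batch(batch)), "slots with padding")  # 12 slots with padding
print(padding_free_batch(batch))
# {'input_ids': [[5, 6, 7, 8, 9, 10, 11]], 'position_ids': [[0, 1, 2, 3, 0, 0, 1]]}
```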

In [None]:
import datasets
import dataclasses
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from trl import SFTConfig, SFTTrainer, DataCollatorForCompletionOnlyLM

model_name = 'meta-llama/Llama-3.2-1B'

tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    attn_implementation='flash_attention_2', # enable FlashAttention-2
)

In [None]:
response_template = "\n### Label:"
response_template_ids = tokenizer.encode(
    response_template,
    add_special_tokens=False
)[2:]  # drop the leading tokens, which tokenize differently in context (verify for your tokenizer)

data_collator = DataCollatorForCompletionOnlyLM(
    response_template_ids,  # passed as `response_template`, which accepts a string or token ids
    tokenizer=tokenizer,
    ignore_index=-100,
    padding_free=True # enable padding-free collation
)

def format_dataset(example):
    return {
        'output': example['output'] + tokenizer.eos_token
    }


data_files = {'train': 'path/to/dataset'}
json_dataset = datasets.load_dataset('json', data_files=data_files)
formatted_train_dataset = json_dataset['train'].map(format_dataset)

In [None]:
train_args = TrainingArguments(
    num_train_epochs=5,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=4,
    learning_rate=1e-5,
    weight_decay=0.0,
    warmup_ratio=0.03,
    lr_scheduler_type="cosine",
    logging_steps=1,
    include_tokens_per_second=True,
    save_strategy="epoch",
    output_dir="output",
    torch_compile=True,  # enable torch.compile
    torch_compile_backend="inductor",
    torch_compile_mode="default"
)

In [None]:
# convert TrainingArguments to SFTConfig
transformer_train_arg_fields = [x.name for x in dataclasses.fields(SFTConfig)]
transformer_kwargs = {
    k: v
    for k, v in train_args.to_dict().items()
    if k in transformer_train_arg_fields
}
training_args = SFTConfig(**transformer_kwargs)

In [None]:
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=formatted_train_dataset,
    data_collator=data_collator,
    dataset_text_field="output",
    args=training_args,
)
trainer.train()

### PyTorch scaled dot product attention

**Scaled dot product attention (SDPA)** is automatically enabled in PyTorch 2.0, and it supports FlashAttention, xFormers (memory-efficient attention), and PyTorch's native C++ implementation.

We can use the `torch.nn.attention.sdpa_kernel` context manager to explicitly enable or disable any of the attention algorithms.
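All of these backends compute the same quantity, standard scaled dot-product attention, and differ in how they tile, fuse, and store intermediate results rather than in what they return (up to floating-point rounding). For queries $Q$, keys $K$, values $V$, key dimension $d_k$, and, in decoder-only LLMs, a causal mask $M$ whose entries are $0$ on and below the diagonal and $-\infty$ above it:

$$
\operatorname{Attention}(Q, K, V) = \operatorname{softmax}\!\left(\frac{QK^{\top}}{\sqrt{d_k}} + M\right) V
$$

The $QK^{\top}$ term is the $n \times n$ score matrix behind the quadratic cost of attention; the FlashAttention backend avoids writing that full matrix to GPU memory.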

In [None]:
import torch
from torch.nn.attention import SDPBackend, sdpa_kernel
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('google/gemma-2b')
model = AutoModelForCausalLM.from_pretrained(
    'google/gemma-2b',
    torch_dtype=torch.bfloat16,
    device_map='auto'  # place the model on the available accelerator
)

In [None]:
inputs = tokenizer("The second law of thermodynamics states", return_tensors="pt").to(model.device)

with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    outputs = model.generate(**inputs)

print(tokenizer.batch_decode(outputs, skip_special_tokens=True))

## Quantization

Quantization reduces the size of the LLM weights by storing them in a lower precision. This translates to lower memory usage and makes loading LLMs for inference more accessible if we are constrained by our GPU's memory.
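The core idea can be sketched with symmetric absmax quantization in plain Python: every weight is divided by a shared scale chosen so that the largest magnitude maps to 127, then rounded to an integer. This is a simplified illustration, not the bitsandbytes algorithm (LLM.int8() additionally uses per-vector scales and keeps outlier features in higher precision), and the 7 billion parameter count in the memory estimate is a round-number assumption for a 7B model.

```python
def quantize_absmax_int8(weights):
    # Symmetric absmax quantization: the largest |w| maps to 127.
    max_abs = max(abs(w) for w in weights)
    scale = max_abs / 127 if max_abs > 0 else 1.0
    return [round(w / scale) for w in weights], scale


def dequantize(q_weights, scale):
    # Recover approximate float weights from the int8 values and the shared scale.
    return [q * scale for q in q_weights]


def weight_memory_gb(num_params, bytes_per_param):
    # Memory for the weights alone (no activations or kv-cache), in gigabytes.
    return num_params * bytes_per_param / 1e9


weights = [0.5, -1.27, 0.0, 1.0]
q, scale = quantize_absmax_int8(weights)
print(q)  # int8 values
print(max(abs(w - d) for w, d in zip(weights, dequantize(q, scale))))  # rounding error
print(weight_memory_gb(7e9, 2), weight_memory_gb(7e9, 1))  # 14.0 7.0 (bf16 vs int8)
```

Halving the bytes per parameter halves the weight memory, at the cost of a rounding error of at most half a quantization step per weight.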

In [None]:
# load `Mistral-7B-v0.1` in half-precision
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

In [None]:
# load a quantized model
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import torch

quant_config = BitsAndBytesConfig(load_in_8bit=True)
model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1", quantization_config=quant_config, device_map="auto"
)