In [1]:
%load_ext autoreload
%autoreload 2

import llminference as L
import torch
import torch.nn.functional as F
from typing import Tuple

# KV caching

Notebook showcases how to use the HuggingFace interface to generate and load the KV cache when generating new tokens, as well as how to use the custom functions to save and load the KV cache from the disk.

In [2]:
adapter = L.Adapter.from_pretrained("EleutherAI/pythia-70m")

## Test *just* caching

Let's try calling the model, generating cache, then calling the model w and w/o cache

In [3]:
from transformers import AutoModelForCausalLM, AutoTokenizer

In [4]:
model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-70m")
# Convert to double to get rid of numerical errors
model.double()

GPTNeoXForCausalLM(
  (gpt_neox): GPTNeoXModel(
    (embed_in): Embedding(50304, 512)
    (layers): ModuleList(
      (0-5): 6 x GPTNeoXLayer(
        (input_layernorm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (post_attention_layernorm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (attention): GPTNeoXAttention(
          (rotary_emb): RotaryEmbedding()
          (query_key_value): Linear(in_features=512, out_features=1536, bias=True)
          (dense): Linear(in_features=512, out_features=512, bias=True)
        )
        (mlp): GPTNeoXMLP(
          (dense_h_to_4h): Linear(in_features=512, out_features=2048, bias=True)
          (dense_4h_to_h): Linear(in_features=2048, out_features=512, bias=True)
          (act): GELUActivation()
        )
      )
    )
    (final_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  )
  (embed_out): Linear(in_features=512, out_features=50304, bias=False)
)

In [5]:
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-70m")

In [6]:
s = "He was walking down the street when he saw"

In [7]:
inp_tok = tokenizer(s, return_tensors="pt")
input_ids = inp_tok["input_ids"]
attention_mask = inp_tok["attention_mask"]

In [8]:
# Separate into preceding context and last token
input_ids_pre = input_ids[:, :-1]
input_ids_end = input_ids[:, -1:]

In [9]:
# Pass the prefix to the model to generate KV cache
past_key_values = model(input_ids_pre).past_key_values

In [10]:
# Generate output using the full input (no caching)
out_no_cache = model(input_ids, attention_mask=attention_mask)

In [11]:
# Generate output using the last token and KV cache
out_cache = model(input_ids_end, past_key_values=past_key_values, attention_mask=attention_mask)

In [12]:
out_cache.logits.shape

torch.Size([1, 1, 50304])

In [13]:
out_no_cache.logits.shape

torch.Size([1, 9, 50304])

In [14]:
torch.testing.assert_close(out_cache.logits[:, -1:, :], out_no_cache.logits[:, -1:, :])

In [15]:
# Compare the KV caches
for i, (layer1, layer2) in enumerate(zip(out_no_cache.past_key_values, out_cache.past_key_values)):
    for kv1, kv2 in zip(layer1, layer2):
        torch.testing.assert_close(kv1, kv2, atol=1e-5, rtol=1e-5)

**Note:** The shapes of logits is different in the two situations while the past key values will be the same!

## Left padding with *position_ids*

Pass explicit *position_ids* to the modell call when using left padding to keep track of the correct position in the sequence.

In [16]:
adapter = L.Adapter.from_pretrained("EleutherAI/pythia-70m")
adapter.model.double()

GPTNeoXForCausalLM(
  (gpt_neox): GPTNeoXModel(
    (embed_in): Embedding(50304, 512)
    (layers): ModuleList(
      (0-5): 6 x GPTNeoXLayer(
        (input_layernorm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (post_attention_layernorm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (attention): GPTNeoXAttention(
          (rotary_emb): RotaryEmbedding()
          (query_key_value): Linear(in_features=512, out_features=1536, bias=True)
          (dense): Linear(in_features=512, out_features=512, bias=True)
        )
        (mlp): GPTNeoXMLP(
          (dense_h_to_4h): Linear(in_features=512, out_features=2048, bias=True)
          (dense_4h_to_h): Linear(in_features=2048, out_features=512, bias=True)
          (act): GELUActivation()
        )
      )
    )
    (final_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  )
  (embed_out): Linear(in_features=512, out_features=50304, bias=False)
)

In [17]:
s = "He was walking down the street"

In [18]:
inp = adapter.tokenizer(s, return_tensors="pt")
seq_len = inp["input_ids"].shape[-1]
print(seq_len)
inp_pre = {k: v[:, :-1] for k, v in inp.items()}
inp_end = {k: v[:, -1:] for k, v in inp.items()}

6


In [19]:
# Pass the first part of the input to the model, generate KV cache
pkv = adapter.model(**inp_pre).past_key_values

In [20]:
# Generate the expected output using second part of the input + KV cache
out1 = adapter.model(**inp_end, past_key_values=pkv)

In [21]:
# Left pad the KV cache
pad_size = 5
pkv_padded = [[F.pad(kv, (0, 0, pad_size, 0)) for kv in layer] for layer in pkv]

In [22]:
padded_len = seq_len + pad_size
attention_mask = (torch.arange(padded_len) >= (pad_size)).long()[None]

In [23]:
attention_mask

tensor([[0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]])

In [24]:
# Output using the left-padded KV cache
out2 = adapter.model(input_ids=inp_end["input_ids"], attention_mask=attention_mask, past_key_values=pkv_padded)

In [25]:
# Breaks because of wrong position ids
torch.testing.assert_close(out1.logits, out2.logits)

AssertionError: Tensor-likes are not close!

Mismatched elements: 50303 / 50304 (100.0%)
Greatest absolute difference: 3.7517726443015817 at index (0, 0, 22995) (up to 1e-07 allowed)
Greatest relative difference: 0.005994990867981176 at index (0, 0, 35973) (up to 1e-07 allowed)

In [26]:
# Fix by passing the right position id for the last input
out_fix = adapter.model(input_ids=inp_end["input_ids"], attention_mask=attention_mask, past_key_values=pkv_padded, position_ids=torch.tensor([[seq_len-1]]))

In [27]:
torch.testing.assert_close(out1.logits, out_fix.logits)

## Saving & Loading KV cache from disk

In [28]:
ctxs = ["How are you", "She was walking down the street"]
questions = [" doing", " and"]


inps = [
    adapter.tokenizer(ctx + question, return_tensors="pt")
    for ctx, question in zip(ctxs, questions)
]

num_new_tokens = 3
outs = [
    adapter.model.generate(
        **inp, max_length=inp["input_ids"].shape[1] + num_new_tokens, pad_token_id=0
    )[:, -num_new_tokens:]
    for inp in inps
]


In [29]:
torch.cat(outs, dim=0)

tensor([[  32,  187,  187],
        [ 703,  369, 9398]])

In [30]:
for out in outs:
    print(repr(adapter.tok_decode(out.squeeze())))

'?\n\n'
' she was wearing'


In [31]:
adapter.generate_kv_cache(ctxs, "../../cache/")

In [32]:
adapter.greedy_sample(ctxs, questions, num_new_tokens, use_cache=True, cache_dir="../../cache")

tensor([[  32,  187,  187],
        [ 703,  369, 9398]])

## KV cache size

Calculate the expected size of the KV cache for different models and context lengths. Note, although the model's config specifies `float16`, this is the `dtype` that was used during training. When retrieving the model, the model's `dtype` is actually set to `float32`.

In [33]:
from transformers import AutoConfig

In [34]:
models = [
    "pythia-70m",
    "pythia-160m",
    "pythia-410m",
    "pythia-1b",
    "pythia-1.4b",
    "pythia-2.8b",
    "pythia-6.9b",
    "pythia-12b",
]

In [35]:
bytes_per_dtype = {torch.float16: 2, torch.float32: 4}

In [36]:
config = AutoConfig.from_pretrained("EleutherAI/" + "pythia-70m")

In [37]:
config

GPTNeoXConfig {
  "_name_or_path": "EleutherAI/pythia-70m",
  "architectures": [
    "GPTNeoXForCausalLM"
  ],
  "bos_token_id": 0,
  "classifier_dropout": 0.1,
  "eos_token_id": 0,
  "hidden_act": "gelu",
  "hidden_size": 512,
  "initializer_range": 0.02,
  "intermediate_size": 2048,
  "layer_norm_eps": 1e-05,
  "max_position_embeddings": 2048,
  "model_type": "gpt_neox",
  "num_attention_heads": 8,
  "num_hidden_layers": 6,
  "rotary_emb_base": 10000,
  "rotary_pct": 0.25,
  "tie_word_embeddings": false,
  "torch_dtype": "float16",
  "transformers_version": "4.30.2",
  "use_cache": true,
  "use_parallel_residual": true,
  "vocab_size": 50304
}

In [38]:
SEQUENCE_LENGTH = 2048
kv_sizes = {}
for model in models:
    config = AutoConfig.from_pretrained("EleutherAI/" + model)
    num_layers = config.num_hidden_layers
    sequence_len = SEQUENCE_LENGTH
    hidden_size = config.hidden_size
    num_bytes = bytes_per_dtype[torch.float32] # since the model is returned in float32!
    kv_size = 2 * num_layers * sequence_len * hidden_size * num_bytes
    kv_sizes[model] = kv_size

In [39]:
print(f"Assuming sequence length {SEQUENCE_LENGTH}, each example is expected to be of size:\n")
for k, v in kv_sizes.items():
    print(f"{k}: {v:,}")

Assuming sequence length 2048, each example is expected to be of size:

pythia-70m: 50,331,648
pythia-160m: 150,994,944
pythia-410m: 402,653,184
pythia-1b: 536,870,912
pythia-1.4b: 805,306,368
pythia-2.8b: 1,342,177,280
pythia-6.9b: 2,147,483,648
pythia-12b: 3,019,898,880


In [40]:
num_examples = 225
print(f"Assuming sequence length {SEQUENCE_LENGTH}, expected KV cache for dataset with {num_examples} examples:\n")
for k, v in kv_sizes.items():
    print(f"{k}: {v*225:,}")

Assuming sequence length 2048, expected KV cache for dataset with 225 examples:

pythia-70m: 11,324,620,800
pythia-160m: 33,973,862,400
pythia-410m: 90,596,966,400
pythia-1b: 120,795,955,200
pythia-1.4b: 181,193,932,800
pythia-2.8b: 301,989,888,000
pythia-6.9b: 483,183,820,800
pythia-12b: 679,477,248,000


In [41]:
print(f"Only considering models up to the specified size, total KV cache is:\n")
size = 0
for model in models:
    size += kv_sizes[model] * num_examples
    print(f"{model}: {size:,}")

Only considering models up to the specified size, total KV cache is:

pythia-70m: 11,324,620,800
pythia-160m: 45,298,483,200
pythia-410m: 135,895,449,600
pythia-1b: 256,691,404,800
pythia-1.4b: 437,885,337,600
pythia-2.8b: 739,875,225,600
pythia-6.9b: 1,223,059,046,400
pythia-12b: 1,902,536,294,400


**Note**: All tensors are returned in `float32`.

In [42]:
model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-70m")

Even though the config specifies `float16` as the `dtype`, the model is returned in `float32`.

In [43]:
print(f"Config dtype: {model.config.torch_dtype}")
print(f"Actual model dtype: {model.dtype}")

Config dtype: torch.float16
Actual model dtype: torch.float32
