feat: added k, v cache for inference speed up #7

Open · wants to merge 2 commits into base: main
gpt2.py: 55 changes (42 additions & 13 deletions)
@@ -35,66 +35,95 @@ def attention(q, k, v, mask): # [n_q, d_k], [n_k, d_k], [n_k, d_v], [n_q, n_k]
return softmax(q @ k.T / np.sqrt(q.shape[-1]) + mask) @ v


def mha(x, c_attn, c_proj, n_head): # [n_seq, n_embd] -> [n_seq, n_embd]
def mha(x, c_attn, c_proj, n_head, kvcache=None): # [n_seq, n_embd] -> [n_seq, n_embd]
# qkv projection
# when kvcache is passed, n_seq = 1, so we only compute new_q, new_k, and new_v for the newest token
x = linear(x, **c_attn) # [n_seq, n_embd] -> [n_seq, 3*n_embd]

# split into qkv
qkv = np.split(x, 3, axis=-1) # [n_seq, 3*n_embd] -> [3, n_seq, n_embd]

if kvcache:
# qkv
new_q, new_k, new_v = qkv # new_q, new_k, new_v = [1, n_embd]
old_k, old_v = kvcache
k = np.vstack([old_k, new_k]) # k = [n_seq, n_embd], where n_seq = prev_n_seq + 1
v = np.vstack([old_v, new_v]) # v = [n_seq, n_embd], where n_seq = prev_n_seq + 1
qkv = [new_q, k, v]

current_cache = [qkv[1], qkv[2]]

# split into heads
qkv_heads = list(map(lambda x: np.split(x, n_head, axis=-1), qkv)) # [3, n_seq, n_embd] -> [n_head, 3, n_seq, n_embd/n_head]

# causal mask to hide future inputs from being attended to
causal_mask = (1 - np.tri(x.shape[0])) * -1e10 # [n_seq, n_seq]
if kvcache:
# when kvcache is passed, the input is a single token that needs to attend to all previous tokens, so the mask is all zeros
causal_mask = np.zeros((1, k.shape[0]))
else:
# create triangular causal mask
causal_mask = (1 - np.tri(x.shape[0])) * -1e10 # [n_seq, n_seq]

# perform attention over each head
out_heads = [attention(q, k, v, causal_mask) for q, k, v in zip(*qkv_heads)] # [n_head, 3, n_seq, n_embd/n_head] -> [n_head, n_seq, n_embd/n_head]


# merge heads
x = np.hstack(out_heads) # [n_head, n_seq, n_embd/n_head] -> [n_seq, n_embd]

# out projection
x = linear(x, **c_proj) # [n_seq, n_embd] -> [n_seq, n_embd]

return x
return x, current_cache


def transformer_block(x, mlp, attn, ln_1, ln_2, n_head): # [n_seq, n_embd] -> [n_seq, n_embd]
def transformer_block(x, mlp, attn, ln_1, ln_2, n_head, kvcache=None): # [n_seq, n_embd] -> [n_seq, n_embd]
# multi-head causal self attention
x = x + mha(layer_norm(x, **ln_1), **attn, n_head=n_head) # [n_seq, n_embd] -> [n_seq, n_embd]
attn_out, kvcache_updated = mha(layer_norm(x, **ln_1), **attn, n_head=n_head, kvcache=kvcache)
x = x + attn_out # [n_seq, n_embd] -> [n_seq, n_embd]

# position-wise feed forward network
x = x + ffn(layer_norm(x, **ln_2), **mlp) # [n_seq, n_embd] -> [n_seq, n_embd]

return x
return x, kvcache_updated


def gpt2(inputs, wte, wpe, blocks, ln_f, n_head, kvcache = None): # [n_seq] -> [n_seq, n_vocab]
if not kvcache:
kvcache = [None]*len(blocks) # prefill: no cache yet, one empty slot per block
wpe_out = wpe[range(len(inputs))] # positional embeddings for the full sequence
else:
wpe_out = wpe[[len(inputs)-1]] # decode: positional embedding of the newest token only
inputs = [inputs[-1]] # with a cache, only the last token needs a forward pass
Author:
@panaali You're correct: if kvcache is present, only the last token should be passed. But this is me being lazy; I didn't want to change the function signatures, so I handle it inside the function and just use the last token as input when kvcache is present.
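
To make the position bookkeeping concrete, a minimal sketch of what this choice relies on, using the wte/wpe arrays from this repo; the token ids are made up for illustration:

# Hedged illustration (not part of the diff): why the caller keeps passing the full
# `inputs` list even though only the last token is embedded once a cache exists.
inputs = [464, 3290, 318, 257]                  # 4 tokens seen so far (illustrative ids)
# decode step with a cache: embed only inputs[-1], but at position len(inputs) - 1 == 3
x = wte[[inputs[-1]]] + wpe[[len(inputs) - 1]]  # shape [1, n_embd]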


def gpt2(inputs, wte, wpe, blocks, ln_f, n_head): # [n_seq] -> [n_seq, n_vocab]
# token + positional embeddings
x = wte[inputs] + wpe[range(len(inputs))] # [n_seq] -> [n_seq, n_embd]
x = wte[inputs] + wpe_out # [n_seq] -> [n_seq, n_embd]


# forward pass through n_layer transformer blocks
for block in blocks:
x = transformer_block(x, **block, n_head=n_head) # [n_seq, n_embd] -> [n_seq, n_embd]
new_kvcache = []
for block, kvcache_block in zip(blocks, kvcache):
x, updated_cache = transformer_block(x, **block, n_head=n_head, kvcache=kvcache_block) # [n_seq, n_embd] -> [n_seq, n_embd]
new_kvcache.append(updated_cache) # TODO: inplace extend new cache instead of re-saving whole

# projection to vocab
x = layer_norm(x, **ln_f) # [n_seq, n_embd] -> [n_seq, n_embd]
return x @ wte.T # [n_seq, n_embd] -> [n_seq, n_vocab]
return x @ wte.T, new_kvcache # [n_seq, n_embd] -> [n_seq, n_vocab]
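
For readers tracing the cache through the code, a small sketch of the layout new_kvcache ends up with; the layer count and embedding width below assume the 124M model and are illustrative only:

# Hedged sketch of the cache layout (assumed 124M sizes: 12 blocks, n_embd = 768).
# kvcache is a list with one [k, v] pair per transformer block; k and v hold the
# keys/values of every token processed so far, shape [n_seq_so_far, n_embd].
import numpy as np

n_layer, n_embd = 12, 768
kvcache = None                                   # before the first forward pass
# after a 5-token prompt, gpt2 returns something shaped like:
kvcache = [[np.zeros((5, n_embd)), np.zeros((5, n_embd))] for _ in range(n_layer)]
# each decode step vstacks one new row onto k and v in every block,
# so after generating 3 more tokens the cached k and v are [8, n_embd]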


def generate(inputs, params, n_head, n_tokens_to_generate):
from tqdm import tqdm

kvcache = None
for _ in tqdm(range(n_tokens_to_generate), "generating"): # auto-regressive decode loop
logits = gpt2(inputs, **params, n_head=n_head) # model forward pass
logits, kvcache = gpt2(inputs, **params, n_head=n_head, kvcache=kvcache) # model forward pass
@panaali commented on Jul 4, 2024:

The main benefit of KV caching is that you don't need to recompute the MLPs for tokens you have already run the forward pass for, so in the decoding phase you only pass the new token as input to the network.

You should only pass next_id as input in the decoding phase; in the prefill phase, the initial inputs should be passed. Check out https://github.com/pytorch-labs/gpt-fast/blob/main/generate.py#L72 or https://github.com/meta-llama/llama/blob/main/llama/generation.py#L187C51-L187C59 for an example.

more: https://www.perplexity.ai/search/what-should-be-the-input-to-th-bsYpXZiuRFinjT11Ck33EA#0
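
A minimal sketch of the prefill/decode split panaali describes, in the style of this repo; gpt2_step and its n_past argument are hypothetical (they are not part of this diff) and exist only to keep the positional embeddings aligned with the cache:

# Hedged sketch, not this PR's code: caller-side prefill/decode split.
# gpt2_step and n_past are hypothetical; transformer_block, layer_norm, params,
# inputs, and n_tokens_to_generate are assumed to come from this repo.
import numpy as np

def gpt2_step(inputs, wte, wpe, blocks, ln_f, n_head, kvcache, n_past):
    x = wte[inputs] + wpe[n_past : n_past + len(inputs)]  # positions continue after the cache
    new_kvcache = []
    for block, kvcache_block in zip(blocks, kvcache):
        x, updated = transformer_block(x, **block, n_head=n_head, kvcache=kvcache_block)
        new_kvcache.append(updated)
    return layer_norm(x, **ln_f) @ wte.T, new_kvcache

# prefill: run the whole prompt once, filling the cache
kvcache = [None] * len(params["blocks"])
logits, kvcache = gpt2_step(inputs, **params, n_head=n_head, kvcache=kvcache, n_past=0)
for _ in range(n_tokens_to_generate):          # decode: one new token per step
    next_id = int(np.argmax(logits[-1]))       # greedy sampling, as in the original loop
    inputs.append(next_id)
    logits, kvcache = gpt2_step([next_id], **params, n_head=n_head,
                                kvcache=kvcache, n_past=len(inputs) - 1)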

next_id = np.argmax(logits[-1]) # greedy sampling
inputs = np.append(inputs, [next_id]) # append prediction to input

return list(inputs[len(inputs) - n_tokens_to_generate :]) # only return generated ids


def main(prompt: str, n_tokens_to_generate: int = 40, model_size: str = "124M", models_dir: str = "models"):
def main(prompt: str = "Alan Turing theorized that computers would one day become", n_tokens_to_generate: int = 40, model_size: str = "124M", models_dir: str = "models"):
from utils import load_encoder_hparams_and_params

# load encoder, hparams, and params from the released open-ai gpt-2 files