In [1]:
from transformers import GPT2Tokenizer, GPT2Model
import torch
import numpy as np

In [2]:
# The tokenizer and the bare GPT-2 transformer are loaded from the same checkpoint.
CHECKPOINT = 'gpt2'
tokenizer = GPT2Tokenizer.from_pretrained(CHECKPOINT)
model = GPT2Model.from_pretrained(CHECKPOINT)

# Switch the module to evaluation mode. eval() returns the module itself, so
# keeping it as the last expression of the cell also displays the architecture.
model.eval()

GPT2Model(
  (wte): Embedding(50257, 768)
  (wpe): Embedding(1024, 768)
  (drop): Dropout(p=0.1, inplace=False)
  (h): ModuleList(
    (0-11): 12 x GPT2Block(
      (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (attn): GPT2Attention(
        (c_attn): Conv1D()
        (c_proj): Conv1D()
        (attn_dropout): Dropout(p=0.1, inplace=False)
        (resid_dropout): Dropout(p=0.1, inplace=False)
      )
      (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (mlp): GPT2MLP(
        (c_fc): Conv1D()
        (c_proj): Conv1D()
        (act): NewGELUActivation()
        (dropout): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
)

In [3]:
# Fixed context placed in front of every request: this is the part whose
# key/values are precomputed and reused further down.
# Adjacent string literals are joined at compile time, so the value is still a
# single line of text with no newline in it.
PROMPT = (
    "This is a very, very, very big prompt. "
    "I will always use this same very big prompt to provide context "
    "to my autoregressive model language, so let's find out if I can cache it "
    "to save computation time or API costs. "
    "If only I can precompute the first tokens while waiting the user to answer, "
    "it will be so helpful ! "
    "Anyway, here is the user input, in case you were wondering:"
)

# Per-request text. It is concatenated directly after PROMPT (no separator is
# inserted between the two strings).
USER_INPUT = (
    "Such a nice chatbot. "
    "I am just disappointed that it takes so long to respond."
)

# Check consistency between the two forward passes

In [4]:
# Tokenize, in order: the prompt alone, the user input alone, and the two
# concatenated as one string. Each result is returned as PyTorch tensors.
encoded_prompt, encoded_user_input, encoded_full_input = (
    tokenizer(text, return_tensors='pt')
    for text in (PROMPT, USER_INPUT, PROMPT + USER_INPUT)
)

In [5]:
# Tokenizing the two pieces separately must give exactly the ids obtained by
# tokenizing the concatenated text; otherwise the key/values computed for the
# prompt alone would not correspond to the prefix of the full input.
ids_from_parts = torch.cat([encoded_prompt['input_ids'],
                            encoded_user_input['input_ids']], dim=1)
assert torch.equal(ids_from_parts, encoded_full_input['input_ids'])

In [6]:
# When cached key/values are supplied, the attention mask has to cover the
# cached positions as well as the new tokens, hence prompt mask + user mask.
attention_mask = torch.cat([encoded_prompt['attention_mask'],
                            encoded_user_input['attention_mask']], dim=1)

# Inference only: no_grad() stops autograd from recording a graph that would
# never be used. The module is called directly (rather than via .forward) so
# that any registered hooks run as they would in normal use.
with torch.no_grad():
    # Pass 1: the prompt on its own; keep its key/values for reuse.
    # use_cache is requested explicitly instead of relying on the config default.
    prompt_past_kvs = model(**encoded_prompt, use_cache=True).past_key_values

    # Pass 2: only the user tokens are fed, on top of the cached prompt key/values.
    concat_output = model(input_ids=encoded_user_input['input_ids'],
                          attention_mask=attention_mask,
                          past_key_values=prompt_past_kvs)

    # Reference: prompt + user input processed from scratch in a single pass.
    full_output = model(**encoded_full_input)

In [7]:
# The cached pass only returns hidden states for the tokens it was fed, so it
# is compared against the same number of trailing positions of the full pass.
n_new_tokens = concat_output.last_hidden_state.shape[1]

# Slice with ':' on the first axis so the batch dimension is kept: both sides
# then have the same rank and the comparison does not depend on broadcasting.
reference_tail = full_output.last_hidden_state[:, -n_new_tokens:, :]

assert concat_output.last_hidden_state.shape == reference_tail.shape
assert torch.allclose(concat_output.last_hidden_state,
                      reference_tail,
                      atol=1e-04)

# Timing comparison: full forward pass vs. cached prompt

In [10]:
%%time
for i in range(1000):
    encoded_full_input = tokenizer(PROMPT + USER_INPUT, return_tensors='pt')
    model.forward(**encoded_full_input)

CPU times: user 1min 12s, sys: 31.3 s, total: 1min 43s
Wall time: 47.9 s


In [11]:
%%time

encoded_prompt = tokenizer(PROMPT, return_tensors='pt')
prompt_past_kvs = model.forward(**encoded_prompt)['past_key_values']

for i in range(1000):
    encoded_user_input = tokenizer(USER_INPUT, return_tensors='pt')
    attention_mask = torch.cat((encoded_prompt['attention_mask'], encoded_user_input['attention_mask']), dim=1)
    model.forward(input_ids=encoded_user_input['input_ids'],
                  attention_mask=attention_mask,
                  past_key_values=prompt_past_kvs)

CPU times: user 23.1 s, sys: 7.89 s, total: 30.9 s
Wall time: 20.6 s
