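**Description**: demonstrates that caching a fixed context with `backprompt` is roughly
4x faster than the standard pipeline on dummy prompts (4.7 s vs. 20.3 s wall time for 20
requests in the run recorded below).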

**Estimated run time**: ~1 min.

In [1]:
from __future__ import annotations

In [2]:
import torch
from tqdm.auto import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

from backprompt import Text

A prompt string can usually be split up into two parts: `context + request`. The
`context` is assumed to be fixed across all prompts. The `request` is not.
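
To see why this split matters, a rough token-count model helps. The sketch below is an
illustration only and is not part of `backprompt`: the function name and the token counts
are hypothetical, and it counts only how many tokens are run through the model, ignoring
that each request token still attends to the cached context.

```python
def tokens_processed(context_len: int, request_len: int, num_requests: int,
                     cache_context: bool) -> int:
    """Count tokens run through the model across all requests (rough model)."""
    if cache_context:
        # The context is processed once, then only each request's own tokens.
        return context_len + num_requests * request_len
    # Without caching, every call re-processes the full context + request.
    return num_requests * (context_len + request_len)


# Hypothetical sizes: a 200-token context, 15-token requests, 20 requests.
plain = tokens_processed(200, 15, 20, cache_context=False)
cached = tokens_processed(200, 15, 20, cache_context=True)
print(plain, cached)
```

The longer the context is relative to the request, the larger the gap between the two
counts becomes.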

In [3]:
context = '''
This is a relatively long string. It's used (exactly as is) to provide context for many
future model calls. We'll cache the model's representation of it to save inference time
for those calls.

First, a description of the task: description of task

Here's a list of choices, exemplars, etc.

Thing 1: description for thing 1

Thing 2: description for thing 2

Thing 3: description for thing 3

Thing 4: description for thing 4

Thing 5: description for thing 5

Thing 6: description for thing 6

Thing 7: description for thing 7

Thing 8: description for thing 8

Thing 9: description for thing 9

Thing 10: description for thing 10
'''

The request varies from call to call. It supplies a new piece of text and asks the LLM
to do something with it, using the information in `context`.

In [4]:
request_template = '''
Here's my new text: {text}

Do something with it.
'''

num_requests = 20
requests = [request_template.format(text=i) for i in range(num_requests)]

In [5]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

model_name = 'gpt2'

model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_name)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

Using pad_token, but it is not set yet.


In [6]:
model_and_tokenizer = (model, tokenizer)

In [7]:
# this object should be static
context_cached = Text(context, model_and_tokenizer)
# optional: cache the model's representation of it now
context_cached()

`backprompt` doesn't support batching yet. Let's mimic live inference, which usually consists of one-at-a-time (non-batched) model calls.

In [8]:
%%time
for request in tqdm(requests):
    # wrap it in a Text
    request = Text(request, model_and_tokenizer)
    # construct the prompt by adding (concatenating) the Text objs
    prompt = context_cached + request
    # get next-token logits for all tokens in prompt
    prompt()
    _ = prompt.model_repr[1].logits

CPU times: total: 6.2 s
Wall time: 4.7 s


Compare this to the plain pipeline (on CPU):

In [9]:
%%time
for request in tqdm(requests):
    # construct the prompt by concatenating the plain strings
    prompt = context + request
    # tokenize
    encoding = tokenizer([prompt], return_tensors="pt", padding=True).to(model.device)
    # get next-token logits for all tokens in prompt
    with torch.no_grad():
        _ = model(**encoding)

CPU times: total: 37 s
Wall time: 20.3 s
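
The two wall times reported above can be turned into a speedup factor and a per-request
latency. This is plain arithmetic on the numbers printed by the two `%%time` cells; the
variable names are introduced here for illustration, and the figures come from a single
run, so they will vary by machine.

```python
timed_requests = 20    # number of requests in each timed loop
cached_wall_s = 4.7    # wall time reported for the backprompt loop
plain_wall_s = 20.3    # wall time reported for the plain loop

# Ratio of the two wall times: how many times faster the cached loop ran.
speedup = plain_wall_s / cached_wall_s
print(f'speedup: {speedup:.1f}x')

# Average latency per request, in milliseconds.
cached_ms = 1000 * cached_wall_s / timed_requests
plain_ms = 1000 * plain_wall_s / timed_requests
print(f'per request: {cached_ms:.0f} ms vs. {plain_ms:.0f} ms')
```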


How much memory overhead did the cache require to achieve this speedup?

In [10]:
cached_encoding, cached_out = context_cached.model_repr

def memory_mb(tensor: torch.Tensor) -> float:
    # size of the tensor's underlying storage, in MB (1 MB = 1e6 bytes)
    return tensor.untyped_storage().size() / 1e6

cached_encoding_memory_mb = sum(
    memory_mb(tensor) for tensor in cached_encoding.values()
)

cached_out_memory_mb = memory_mb(cached_out.logits)
cached_out_memory_mb += memory_mb(
    torch.stack([torch.stack(block) for block in cached_out.past_key_values], dim=0)
)

_cache_memory_mb = cached_encoding_memory_mb + cached_out_memory_mb
print(f'{round(_cache_memory_mb)} MB')

48 MB
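
As a sanity check on this figure, the cache size can be estimated from the model's
dimensions. The sketch below is a back-of-the-envelope estimate, not a measurement: it
uses the GPT-2 (small) configuration (12 layers, hidden size 768, vocabulary size 50257),
assumes float32 tensors, assumes logits are stored for every cached token, and ignores
the small token ID and attention mask tensors. The function name and the context lengths
are hypothetical.

```python
# GPT-2 (small) configuration values; float32 is assumed throughout.
N_LAYER = 12
N_EMBD = 768
VOCAB_SIZE = 50257
BYTES_PER_FLOAT32 = 4


def estimate_cache_mb(num_tokens: int) -> float:
    """Estimate cache size in MB (1 MB = 1e6 bytes) for num_tokens cached tokens."""
    # One logit per vocabulary entry for every cached token.
    logits_bytes = num_tokens * VOCAB_SIZE * BYTES_PER_FLOAT32
    # One key and one value vector of size N_EMBD per layer per token.
    kv_bytes = num_tokens * N_LAYER * 2 * N_EMBD * BYTES_PER_FLOAT32
    return (logits_bytes + kv_bytes) / 1e6


# Hypothetical context lengths, for illustration only.
for length in (100, 175, 250):
    print(length, round(estimate_cache_mb(length), 1))
```

Under these assumptions, a context of about 175 tokens would account for roughly 48 MB,
which is in line with the figure printed above; the actual token count of `context` was
not measured here. The logits make up nearly three quarters of the estimated per-token
cost, so most of the overhead in this estimate comes from the logits rather than from the
cached keys and values.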
