https://huggingface.co/blog/assisted-generation

In [None]:
# Example showcasing the impact of batched generation. Measurement device: RTX3090
from transformers import AutoModelForCausalLM, AutoTokenizer
import time

# Load the distilgpt2 tokenizer and causal LM, then move the model and the
# tokenized "Hello world" prompt onto the GPU in separate steps.
tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
model = AutoModelForCausalLM.from_pretrained("distilgpt2")
model = model.to("cuda")
inputs = tokenizer(["Hello world"], return_tensors="pt")
inputs = inputs.to("cuda")

def print_tokens_per_second(batch_size, new_tokens=100, num_runs=10):
    """Time sampled generation and print the throughput in tokens per second.

    Runs ``model.generate`` on the module-level ``inputs`` once as an untimed
    warmup, then ``num_runs`` more times under a timer.

    The throughput is ``new_tokens * batch_size * num_runs / elapsed``: it
    counts the number of tokens *requested*, it does not inspect the outputs.

    Args:
        batch_size: Passed to ``generate`` as ``num_return_sequences``.
        new_tokens: Passed to ``generate`` as ``max_new_tokens``.
        num_runs: Number of timed ``generate`` calls; must be at least 1.

    Raises:
        ValueError: If ``num_runs`` is smaller than 1 (nothing would be timed
            and the throughput would be a division by zero).
    """
    if num_runs < 1:
        raise ValueError(f"num_runs must be at least 1, got {num_runs}")

    # Shared by the warmup and the timed calls so both run the same workload.
    generate_kwargs = {
        "do_sample": True,
        "max_new_tokens": new_tokens,
        "num_return_sequences": batch_size,
    }

    # warmup
    model.generate(**inputs, **generate_kwargs)

    # perf_counter() is a monotonic, high-resolution clock meant for measuring
    # intervals; time.time() can jump if the system clock is adjusted.
    # NOTE(review): no explicit device synchronization is done around the
    # timer -- confirm generate() has finished all GPU work when it returns.
    cumulative_time = 0.0
    for _ in range(num_runs):
        start = time.perf_counter()
        model.generate(**inputs, **generate_kwargs)
        cumulative_time += time.perf_counter() - start

    tokens_generated = new_tokens * batch_size * num_runs
    print(f"Tokens per second: {tokens_generated / cumulative_time:.1f}")

# Reference output of the two runs below, measured on an RTX3090:
#   batch size 1  -> Tokens per second: 418.3
#   batch size 64 -> Tokens per second: 16266.2 (~39x more tokens per second)
for batch_size in (1, 64):
    print_tokens_per_second(batch_size)


In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("distilgpt2")
model = AutoModelForCausalLM.from_pretrained("distilgpt2")

# Greedy decoding (do_sample=False): generate 10 new tokens after the prompt.
inputs = tok(["The"], return_tensors="pt")
generated = model.generate(**inputs, do_sample=False, max_new_tokens=10)

# A single forward pass over the whole generated sequence; the argmax over the
# last dimension picks the highest-scoring token at every position.
logits = model(generated).logits
forward_confirmation = logits.argmax(dim=-1)

# The forward pass returns the logits for the *next* token, so its output is
# shifted by one position relative to the input. To line the two up, drop the
# first token of the generated sequence and the last prediction of the
# forward pass before comparing.
generated_tail = generated[0, 1:].tolist()
predicted_head = forward_confirmation[0, :-1].tolist()
print(generated_tail == predicted_head)  # True


In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

prompt = "Alice and Bob"
checkpoint = "EleutherAI/pythia-1.4b-deduped"
assistant_checkpoint = "EleutherAI/pythia-160m-deduped"

# Run on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# The prompt is tokenized with the main checkpoint's tokenizer and moved to
# the selected device.
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
inputs = tokenizer(prompt, return_tensors="pt")
inputs = inputs.to(device)

# Load the main model and the smaller assistant model onto the same device.
model = AutoModelForCausalLM.from_pretrained(checkpoint)
model = model.to(device)
assistant_model = AutoModelForCausalLM.from_pretrained(assistant_checkpoint)
assistant_model = assistant_model.to(device)

# Passing `assistant_model` to generate() is what enables assisted generation.
outputs = model.generate(**inputs, assistant_model=assistant_model)
decoded = tokenizer.batch_decode(outputs, skip_special_tokens=True)
print(decoded)
# ['Alice and Bob are sitting in a bar. Alice is drinking a beer and Bob is drinking a']
