# Speed Up Llama 3.1 8B with Speculative Decoding

*TL;DR*: Use speculative decoding (built into 🤗 `transformers`) to speed up text generation with the Llama 3.1 8B model.

## Text Generation is slow 🐢

We use (large) language models daily for tasks like summarization, translation, and question answering. While these tasks vary, they all stem from the same core process: **text generation**.

Text generation happens step-by-step:

1. You provide the model with a starting sequence of tokens (a prefix).
2. The model predicts the next token by outputting a probability for every token in its vocabulary.
3. A token is chosen based on these probabilities (for example, the most likely one) and appended to the prefix.
4. This cycle repeats until the desired length is reached or the model emits an end-of-sequence token.
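
To make these four steps concrete, here is a minimal pure-Python sketch of the loop. The `NEXT_TOKEN_PROBS` table is a hypothetical toy bigram "model" standing in for a real LLM, and `generate` is our own illustrative function, not the 🤗 Transformers `generate()` API. It supports greedy picking and sampling, the two decoding modes compared later in this notebook.

```python
import random
from typing import Dict, List, Optional

# Hypothetical toy "model": next-token probabilities that depend only on the
# last token (a bigram table). A real LLM conditions on the whole prefix and
# scores every token in its vocabulary (roughly 128k tokens for Llama 3).
NEXT_TOKEN_PROBS: Dict[str, Dict[str, float]] = {
    "<s>": {"Alice": 1.0},
    "Alice": {"and": 0.7, "wins": 0.3},
    "and": {"Bob": 0.6, "Carol": 0.4},
    "Bob": {"win": 0.6, "and": 0.4},
    "Carol": {"win": 1.0},
    "win": {"</s>": 1.0},
    "wins": {"</s>": 1.0},
}


def generate(
    prefix: List[str],
    max_new_tokens: int,
    do_sample: bool = False,
    rng: Optional[random.Random] = None,
) -> List[str]:
    rng = rng or random.Random()
    tokens = list(prefix)                          # step 1: start from a prefix
    for _ in range(max_new_tokens):
        probs = NEXT_TOKEN_PROBS[tokens[-1]]       # step 2: next-token probabilities
        if do_sample:                              # step 3: sample a token...
            next_token = rng.choices(list(probs), weights=list(probs.values()))[0]
        else:                                      # ...or pick the most likely (greedy)
            next_token = max(probs, key=probs.get)
        tokens.append(next_token)
        if next_token == "</s>":                   # step 4: stop at end-of-sequence...
            break
    return tokens                                  # ...or after max_new_tokens steps


print(" ".join(generate(["<s>"], max_new_tokens=10)))  # greedy: <s> Alice and Bob win </s>
```

Note that each iteration needs the token produced by the previous one, which is exactly why this loop cannot simply be run in parallel.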

Since each token depends on all the tokens before it, text generation is hard to parallelize: every new token needs its own forward pass through the model, one after another. In simpler terms: if we can't do steps in parallel, speeding things up is tough.

Now imagine scaling this problem up. With a massive model like Llama 3.1 405B, you not only need a powerful machine to host it, but each of those sequential steps is also slow, so generating text takes a long time.

Enough setting up the problem. How can we solve it?

## Assistants to the rescue

Smaller models generate text much faster than larger ones. Building on this observation, researchers came up with a clever technique to speed up text generation: **speculative decoding**, also known as **assisted generation**.

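The idea is simple: let a smaller, faster model (the *assistant*) draft a few tokens ahead. Then the larger model checks all of the drafted tokens in a single forward pass, asking whether it would have generated the same ones. Drafted tokens are kept up to the first one the large model disagrees with; at that position the large model's own prediction is used instead, and the small model drafts again from there. Because verifying several tokens at once costs roughly as much as generating one, every accepted draft token saves a forward pass of the large model. With greedy decoding, the result is identical to what the large model would have generated on its own, so quality is not compromised; with sampling, speculative decoding uses a probabilistic acceptance rule designed to keep the output distributed as the large model's would be.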
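
The draft-then-verify loop for greedy decoding can be sketched in plain Python. This is a minimal sketch under stated assumptions: the two "models" are hypothetical stand-ins (functions mapping a token list to the greedy next token), and `speculative_greedy` and `num_draft` are illustrative names, not Transformers API. A real implementation verifies all drafted positions with one batched forward pass; here we call the target once per position for readability and count each verification round as one large-model pass.

```python
from typing import Callable, List, Tuple

# Hypothetical stand-in for a model: maps a token prefix to its greedy next token.
NextToken = Callable[[List[int]], int]


def greedy(prefix: List[int], target: NextToken, max_new_tokens: int) -> List[int]:
    """Plain autoregressive decoding: one large-model pass per new token."""
    tokens = list(prefix)
    for _ in range(max_new_tokens):
        tokens.append(target(tokens))
    return tokens


def speculative_greedy(
    prefix: List[int],
    target: NextToken,
    draft: NextToken,
    max_new_tokens: int,
    num_draft: int = 4,
) -> Tuple[List[int], int]:
    """Greedy speculative decoding; returns (tokens, number of large-model passes)."""
    tokens = list(prefix)
    target_passes = 0
    while len(tokens) - len(prefix) < max_new_tokens:
        # 1. Draft: the small model proposes num_draft tokens autoregressively.
        drafted: List[int] = []
        for _ in range(num_draft):
            drafted.append(draft(tokens + drafted))
        # 2. Verify: the large model predicts the next token after every drafted
        #    prefix. In practice this is ONE forward pass over tokens + drafted.
        target_passes += 1
        verified = [target(tokens + drafted[:i]) for i in range(num_draft + 1)]
        # 3. Accept drafted tokens up to the first disagreement.
        n_ok = 0
        while n_ok < num_draft and drafted[n_ok] == verified[n_ok]:
            n_ok += 1
        # 4. Keep them, plus the large model's own token at the mismatch
        #    (or a free "bonus" token when every draft was accepted).
        tokens += drafted[:n_ok] + [verified[n_ok]]
    return tokens[: len(prefix) + max_new_tokens], target_passes


def target(tokens: List[int]) -> int:  # toy "large" model
    return (tokens[-1] * 3 + 1) % 11


def draft(tokens: List[int]) -> int:  # toy "small" model: wrong at every 4th position
    return target(tokens) if len(tokens) % 4 else 0


out, passes = speculative_greedy([1], target, draft, max_new_tokens=20)
assert out == greedy([1], target, max_new_tokens=20)  # same output as plain greedy
print(passes)  # fewer than the 20 passes plain greedy decoding needs
```

Each round always emits at least one token (the large model's own prediction), so in the worst case, when the assistant is always wrong, this falls back to the cost of plain greedy decoding plus the drafting overhead.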

## Use 🤗 Transformers

In this guide, we'll show you how to use speculative decoding with Llama models. In 🤗 Transformers, enabling it is as simple as passing an `assistant_model` to the `generate` API. By the end, you'll see how much this technique speeds up text generation, making your workflows faster and more efficient.

## Imports and Setup

In [None]:
!pip install -Uq transformers

In [None]:
import torch
from tqdm import tqdm
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    set_seed
)

In [None]:
# suppress warnings in the notebook
import logging
import warnings
logging.getLogger('transformers').setLevel(logging.ERROR)
warnings.filterwarnings('ignore')

In [None]:
DEVICE = "cuda"

In [None]:
# base models
base_checkpoint = "meta-llama/Meta-Llama-3.1-8B" # <-- Larger Model
base_assistant_checkpoint = "meta-llama/Llama-3.2-1B" # <-- Smaller Model

# instruct models
instruct_checkpoint = "meta-llama/Meta-Llama-3.1-8B-Instruct" # <-- Larger Model
instruct_assistant_checkpoint = "meta-llama/Llama-3.2-1B-Instruct" # <-- Smaller Model

## 1. Base Models

In this section, we run our experiments on the base models.

In [None]:
base_model = AutoModelForCausalLM.from_pretrained(
    base_checkpoint,
    torch_dtype=torch.bfloat16
).to(DEVICE)
base_assistant_model = AutoModelForCausalLM.from_pretrained(
    base_assistant_checkpoint,
    torch_dtype=torch.bfloat16
).to(DEVICE)


### Prepare Inputs

In [None]:
prompt = "Alice and Bob"

tokenizer = AutoTokenizer.from_pretrained(base_checkpoint)
model_inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)


### Generation

Let's generate text from the models in three ways:
1. We use the assistant model (small) to generate text
2. We use the bigger model to generate text
3. We use the bigger model assisted by the smaller model to generate text

In [None]:
def model_output(
        model,
        inputs,
        do_sample,
        temperature,
        is_instruct,
        max_new_tokens=50,
        assistant_model=None
    ):
    if is_instruct:
        gen_out = model.generate(
            inputs,
            do_sample=do_sample,
            temperature=temperature,
            assistant_model=assistant_model,
            max_new_tokens=max_new_tokens,
        )
    else:
        gen_out = model.generate(
            **inputs,
            do_sample=do_sample,
            temperature=temperature,
            assistant_model=assistant_model,
            max_new_tokens=max_new_tokens,
        )
    return gen_out

def get_generation(
        inputs,
        big_model,
        small_model,
        is_instruct=False,
        do_sample=False,
        temperature=1.0
    ):
    # text generated from the smaller model
    gen_out = model_output(
        model=small_model,
        inputs=inputs,
        do_sample=do_sample,
        temperature=temperature,
        is_instruct=is_instruct
    )
    print("\n🤏 Small Model")
    print(tokenizer.batch_decode(gen_out, skip_special_tokens=True)[0].strip())

    # text generated from the bigger model
    gen_out = model_output(
        model=big_model,
        inputs=inputs,
        do_sample=do_sample,
        temperature=temperature,
        is_instruct=is_instruct
    )
    print("\n🐘 Large Model")
    print(tokenizer.batch_decode(gen_out, skip_special_tokens=True)[0].strip())

    # text generated with the assistant model in place
    gen_out = model_output(
        model=big_model,
        inputs=inputs,
        do_sample=do_sample,
        temperature=temperature,
        is_instruct=is_instruct,
        assistant_model=small_model,
    )
    print("\n🤝 Large Model with assistance")
    print(tokenizer.batch_decode(gen_out, skip_special_tokens=True)[0].strip())

In [None]:
# with greedy decoding check generation from the bigger model
# with and without assistance
get_generation(
    inputs=model_inputs,
    big_model=base_model,
    small_model=base_assistant_model,
    is_instruct=False,
    do_sample=False,
    temperature=1.0,
)


🤏 Small Model
Alice and Bob are playing a game. Alice has a deck of cards, and Bob has a deck of cards. Alice and Bob each take a card from their deck and place it face down on the table. Alice and Bob then take turns to turn over the top

🐘 Large Model
Alice and Bob are playing a game. Alice has a deck of n cards, numbered from 1 to n. Bob has a deck of m cards, numbered from 1 to m. Alice and Bob take turns, with Alice going first. On each turn,

🤝 Large Model with assistance
Alice and Bob are playing a game. Alice has a deck of n cards, numbered from 1 to n. Bob has a deck of m cards, numbered from 1 to m. Alice and Bob take turns, with Alice going first. On each turn,


In [None]:
# using multinomial sampling
# here the generations are expected to differ, since each one is a random sample
get_generation(
    big_model=base_model,
    small_model=base_assistant_model,
    inputs=model_inputs,
    is_instruct=False,
    do_sample=True,
    temperature=0.2,
)


🤏 Small Model
Alice and Bob are playing a game. Alice is given a deck of cards and Bob is given a deck of cards. They are told to take turns drawing cards from their respective decks and to keep drawing until one of them runs out of cards. The winner is the

🐘 Large Model
Alice and Bob are playing a game. Alice has a bag containing 1000 coins, each of which has a positive integer written on it. Bob chooses a positive integer k, and Alice chooses a coin at random. If the number on the coin is divisible by

🤝 Large Model with assistance
Alice and Bob are playing a game. Alice has a deck of n cards numbered from 1 to n. Bob has a deck of m cards numbered from 1 to m. Alice and Bob take turns, with Alice going first. On each turn, a player
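
With sampling, the assisted output differs from the un-assisted one, which is expected: each is a separate random draw. The speculative sampling scheme from the literature is designed so that the tokens it keeps are distributed as if the large model had sampled them. Its per-position rule (accept the drafted token with probability min(1, p/q), otherwise resample from the normalized residual max(0, p - q)) can be sketched in plain Python. The distributions `p` and `q` are made-up stand-ins for the large and small models' next-token probabilities, and `speculative_sample_step` is an illustrative helper, not a Transformers function.

```python
import random
from typing import List, Tuple


def speculative_sample_step(
    p: List[float], q: List[float], rng: random.Random
) -> Tuple[int, bool]:
    """Draft one token from q and verify it against p (a single position only).

    p: the large (target) model's next-token probabilities.
    q: the small (draft) model's next-token probabilities, same vocabulary.
    Returns the emitted token id and whether the drafted token was accepted.
    """
    vocab = range(len(q))
    # The small model proposes a token by sampling from its own distribution.
    x = rng.choices(vocab, weights=q)[0]
    # Accept it with probability min(1, p[x] / q[x]); q[x] > 0 since x was sampled.
    if rng.random() < min(1.0, p[x] / q[x]):
        return x, True
    # Otherwise resample from the leftover mass max(0, p - q);
    # random.choices normalizes the weights for us.
    residual = [max(0.0, pi - qi) for pi, qi in zip(p, q)]
    return rng.choices(vocab, weights=residual)[0], False


# Empirical check: emitted tokens follow p (the large model), not q.
rng = random.Random(0)
p = [0.5, 0.3, 0.2]
q = [0.2, 0.3, 0.5]
counts = [0, 0, 0]
for _ in range(100_000):
    token, _accepted = speculative_sample_step(p, q, rng)
    counts[token] += 1
print([round(c / 100_000, 2) for c in counts])  # close to p
```

The closer `q` is to `p`, the more often the drafted token is accepted, which is why a well-matched assistant from the same model family gives the best speedups.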


### Benchmark the text generation speed

To understand how efficient our model is at generating text, we'll benchmark both its speed and memory usage. This involves running the model several times while measuring the time it takes and how much GPU memory it consumes.

We run the `model.generate()` function inside a loop 10 times. Why 10? Running the model multiple times helps us get an average measurement, reducing the impact of any variability in GPU performance.

We calculate two key metrics:
1. Max memory: the peak GPU memory allocated by PyTorch during generation (`torch.cuda.max_memory_allocated`). This helps us understand how much memory the model needs, which is crucial when working with large models.
2. Throughput: how many tokens per second the model generates, computed as the total number of generated tokens divided by the total elapsed time. This gives a clear picture of how fast the model is and can be a critical metric when deciding on deployment strategies.
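
As a quick sanity check of the throughput formula used in `measure_speed`, here it is as a standalone function. `throughput` is our own helper name; like `measure_speed`, it assumes every run produces exactly `max_new_tokens` tokens, which the benchmark cells enforce with `eos_token_id=-1` so generation never stops early at an end-of-sequence token.

```python
def throughput(num_runs: int, max_new_tokens: int, elapsed_ms: float) -> float:
    """Tokens/sec over all runs, assuming each run generates exactly max_new_tokens."""
    total_tokens = num_runs * max_new_tokens
    elapsed_s = elapsed_ms * 1e-3  # torch.cuda.Event.elapsed_time() reports milliseconds
    return total_tokens / elapsed_s


# 10 runs x 256 tokens in about 2 min 44 s (164 s), as in the first benchmark below
print(f"{throughput(10, 256, 164_000):.1f} tokens/sec")  # 15.6 tokens/sec
```

This lands close to the reported 15.5445 tokens/sec for the big model; the small gap comes from the progress bar rounding the elapsed time to whole seconds.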


In [None]:
def measure_speed(
        inputs,
        big_model,
        gen_kwargs,
        is_instruct=False,
        num_runs=10,
        max_new_tokens=256,
    ):
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    torch.cuda.reset_peak_memory_stats(DEVICE)
    torch.cuda.empty_cache()
    torch.cuda.synchronize()

    start_event.record()
    for _ in tqdm(range(num_runs)):
        if is_instruct:
            gen_out = big_model.generate(
                inputs,
                **gen_kwargs
            )
        else:
            gen_out = big_model.generate(
                **inputs,
                **gen_kwargs
            )
    end_event.record()

    torch.cuda.synchronize()
    max_memory = torch.cuda.max_memory_allocated(DEVICE)

    print(f"\nMax memory: {max_memory * 1e-6:.4f} (MB)")
    print(f"\nThroughput: {(num_runs * max_new_tokens) / (start_event.elapsed_time(end_event) * 1.0e-3):.4f} (tokens/sec)")

In [None]:
# big model generation
generate_kwargs = {
    "max_new_tokens": 256,
    "do_sample": False,
    "temperature": 1.0,
    # no token has id -1, so generation never stops early and every run yields 256 tokens
    "eos_token_id": -1,
}

print("🐘 Big Model Generation")
measure_speed(
    inputs=model_inputs,
    big_model=base_model,
    gen_kwargs=generate_kwargs,
    is_instruct=False,
    num_runs=10,
    max_new_tokens=256,
)

# assisted model generation
assisted_kwargs = generate_kwargs.copy()
assisted_kwargs["assistant_model"] = base_assistant_model

print("\n\n🤝 Big Model Generation with Assistance")
measure_speed(
    inputs=model_inputs,
    big_model=base_model,
    gen_kwargs=assisted_kwargs,
    is_instruct=False,
    num_runs=10,
    max_new_tokens=256,
)

🐘 Big Model Generation


100%|██████████| 10/10 [02:44<00:00, 16.47s/it]



Max memory: 18580.5553 (MB)

Throughput: 15.5445 (tokens/sec)


🤝 Big Model Generation with Assistance


100%|██████████| 10/10 [02:08<00:00, 12.89s/it]


Max memory: 18635.4258 (MB)

Throughput: 19.8538 (tokens/sec)





In [None]:
# big model generation (multinomial)
generate_kwargs = {
    "max_new_tokens": 256,
    "do_sample": True,
    "temperature": 0.2,
    "eos_token_id": -1,
}
print("🐘 Big Model Generation (multinomial)")
measure_speed(
    inputs=model_inputs,
    big_model=base_model,
    gen_kwargs=generate_kwargs,
    is_instruct=False,
    num_runs=10,
    max_new_tokens=256,
)

# assisted model generation (multinomial)
assisted_kwargs = generate_kwargs.copy()
assisted_kwargs["assistant_model"] = base_assistant_model

print("\n\n🤝 Big Model Generation with Assistance (multinomial)")
measure_speed(
    inputs=model_inputs,
    big_model=base_model,
    gen_kwargs=assisted_kwargs,
    is_instruct=False,
    num_runs=10,
    max_new_tokens=256,
)

🐘 Big Model Generation (multinomial)


100%|██████████| 10/10 [02:46<00:00, 16.67s/it]



Max memory: 18581.9146 (MB)

Throughput: 15.3532 (tokens/sec)


🤝 Big Model Generation with Assistance (multinomial)


100%|██████████| 10/10 [01:58<00:00, 11.88s/it]


Max memory: 18642.4873 (MB)

Throughput: 21.5407 (tokens/sec)





In [None]:
# free GPU memory before loading the instruct models
import gc

base_model.to("cpu")
base_assistant_model.to("cpu")

del base_model
del base_assistant_model

gc.collect()
torch.cuda.empty_cache()

## 2. Instruct Models

In this section, we repeat the experiments with the instruction-tuned models.

In [None]:
instruct_model = AutoModelForCausalLM.from_pretrained(
    instruct_checkpoint,
    torch_dtype=torch.bfloat16
).to(DEVICE)
instruct_assistant_model = AutoModelForCausalLM.from_pretrained(
    instruct_assistant_checkpoint,
    torch_dtype=torch.bfloat16
).to(DEVICE)


### Prepare Inputs

In [None]:
prompt = [
    {"role": "system", "content": "You are a helpful assistant, that responds as a pirate."},
    {"role": "user", "content": "What's Deep Learning?"},
]

tokenizer = AutoTokenizer.from_pretrained(instruct_checkpoint)
model_inputs = tokenizer.apply_chat_template(
    prompt,
    tokenize=True,
    add_generation_prompt=True,
    return_tensors="pt",
).to(DEVICE)


### Generation

In [None]:
# with greedy decoding check generation from the bigger model
# with and without assistance
get_generation(
    inputs=model_inputs,
    big_model=instruct_model,
    small_model=instruct_assistant_model,
    is_instruct=True,
    do_sample=False,
    temperature=1.0,
)


🤏 Small Model
system

Cutting Knowledge Date: December 2023
Today Date: 26 Jul 2024

You are a helpful assistant, that responds as a pirate.user

What's Deep Learning?assistant

Yer lookin' fer a swashbucklin' explanation o' Deep Learning, eh? Alright then, settle yerself down with a pint o' grog and listen close.

Deep Learning be a type o' machine learnin

🐘 Large Model
system

Cutting Knowledge Date: December 2023
Today Date: 26 Jul 2024

You are a helpful assistant, that responds as a pirate.user

What's Deep Learning?assistant

Ye be wantin' to know about Deep Learnin', eh?  Alright then, matey.  Deep Learnin' be a type o' machine learnin' that uses artificial neural networks, inspired by the way our brains work.

🤝 Large Model with assistance
system

Cutting Knowledge Date: December 2023
Today Date: 26 Jul 2024

You are a helpful assistant, that responds as a pirate.user

What's Deep Learning?assistant

Ye be wantin' to know about Deep Learnin', eh?  Alright then, matey.  Deep L

In [None]:
# using multinomial sampling
# here the generations are expected to differ, since each one is a random sample
get_generation(
    big_model=instruct_model,
    small_model=instruct_assistant_model,
    inputs=model_inputs,
    is_instruct=True,
    do_sample=True,
    temperature=0.2,
)


🤏 Small Model
system

Cutting Knowledge Date: December 2023
Today Date: 26 Jul 2024

You are a helpful assistant, that responds as a pirate.user

What's Deep Learning?assistant

Yer lookin' fer a treasure trove o' knowledge about Deep Learning, eh? Alright then, matey, settle yerself down with a pint o' grog and listen close.

Deep Learning be a type o' machine learn

🐘 Large Model
system

Cutting Knowledge Date: December 2023
Today Date: 26 Jul 2024

You are a helpful assistant, that responds as a pirate.user

What's Deep Learning?assistant

Ye be wantin' to know about Deep Learnin', eh?  Alright then, matey.  Deep Learnin' be a type o' machine learnin' that uses artificial neural networks, like a ship's crew o' interconnected nodes

🤝 Large Model with assistance
system

Cutting Knowledge Date: December 2023
Today Date: 26 Jul 2024

You are a helpful assistant, that responds as a pirate.user

What's Deep Learning?assistant

Arrr, ye landlubber! Ye be wantin' to know about Deep Learni

### Benchmark

In [None]:
# big model generation
generate_kwargs = {
    "max_new_tokens": 256,
    "do_sample": False,
    "temperature": 1.0,
    "eos_token_id": -1,
}

print("🐘 Big Model Generation")
measure_speed(
    inputs=model_inputs,
    big_model=instruct_model,
    gen_kwargs=generate_kwargs,
    is_instruct=True,
    num_runs=10,
    max_new_tokens=256,
)

# assisted model generation
assisted_kwargs = generate_kwargs.copy()
assisted_kwargs["assistant_model"] = instruct_assistant_model

print("\n\n🤝 Big Model Generation with Assistance")
measure_speed(
    inputs=model_inputs,
    big_model=instruct_model,
    gen_kwargs=assisted_kwargs,
    is_instruct=True,
    num_runs=10,
    max_new_tokens=256,
)

🐘 Big Model Generation


100%|██████████| 10/10 [02:45<00:00, 16.51s/it]



Max memory: 18589.5956 (MB)

Throughput: 15.5016 (tokens/sec)


🤝 Big Model Generation with Assistance


100%|██████████| 10/10 [02:19<00:00, 13.96s/it]


Max memory: 18623.1148 (MB)

Throughput: 18.3311 (tokens/sec)





In [None]:
# big model generation (multinomial)
generate_kwargs = {
    "max_new_tokens": 256,
    "do_sample": True,
    "temperature": 0.2,
    "eos_token_id": -1,
}
print("🐘 Big Model Generation (multinomial)")
measure_speed(
    inputs=model_inputs,
    big_model=instruct_model,
    gen_kwargs=generate_kwargs,
    is_instruct=True,
    num_runs=10,
    max_new_tokens=256,
)

# assisted model generation (multinomial)
assisted_kwargs = generate_kwargs.copy()
assisted_kwargs["assistant_model"] = instruct_assistant_model

print("\n\n🤝 Big Model Generation with Assistance (multinomial)")
measure_speed(
    inputs=model_inputs,
    big_model=instruct_model,
    gen_kwargs=assisted_kwargs,
    is_instruct=True,
    num_runs=10,
    max_new_tokens=256,
)

🐘 Big Model Generation (multinomial)


100%|██████████| 10/10 [02:47<00:00, 16.76s/it]



Max memory: 18589.5956 (MB)

Throughput: 15.2700 (tokens/sec)


🤝 Big Model Generation with Assistance (multinomial)


100%|██████████| 10/10 [02:11<00:00, 13.20s/it]


Max memory: 18625.8534 (MB)

Throughput: 19.3938 (tokens/sec)





## Conclusion

Throughput in tokens/sec (higher is better):

| Decoding | Base: simple | Base: assisted | Instruct: simple | Instruct: assisted |
| :-- | --: | --: | --: | --: |
| greedy | 15.5445 | **19.8538** | 15.5016 | **18.3311** |
| multinomial | 15.3532 | **21.5407** | 15.2700 | **19.3938** |


The throughput increases with assisted generation! 🎉 In our runs, the speedup ranges from about 1.18x (instruct, greedy) to 1.40x (base, multinomial).

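While assisted generation gains speed, it also needs the assistant model in memory. In these measurements, peak memory rose by only about 34 to 61 MB with assistance, but keep in mind that both models were loaded on the GPU during every run, so the roughly 2.5 GB of assistant weights is counted in the simple numbers too. If you would otherwise serve only the large model, those weights are an extra cost, so it's important to balance both metrics.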
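
To see where these numbers come from, here is a plain-Python recomputation. The throughput values are copied from the benchmark outputs above, and the checkpoint sizes (about 16.07 GB for the 8B model across four bf16 shards, 2.47 GB for the 1B model) come from the files the notebook downloads. The weights alone add up to roughly 18.54 GB, close to the ~18.58 GB peak measured even without assistance, which is consistent with both models sitting on the GPU during every run.

```python
# Throughput (tokens/sec) copied from the benchmark outputs: (simple, assisted).
results = {
    ("base", "greedy"): (15.5445, 19.8538),
    ("base", "multinomial"): (15.3532, 21.5407),
    ("instruct", "greedy"): (15.5016, 18.3311),
    ("instruct", "multinomial"): (15.2700, 19.3938),
}
speedups = {key: assisted / simple for key, (simple, assisted) in results.items()}
for (model, decoding), speedup in speedups.items():
    print(f"{model:<8} {decoding:<11} {speedup:.2f}x")

# Checkpoint sizes in GB for the bf16 weights (8B: four shards, 1B: one file).
big_model_gb = 4.98 + 5.00 + 4.92 + 1.17
assistant_gb = 2.47
print(f"weights of both models: {big_model_gb + assistant_gb:.2f} GB")  # vs ~18.58 GB peak
```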

## Next Steps

One can also experiment with the Llama 3.2 3B model and investigate the trade-off between speedup and memory usage.

To learn more about speculative decoding, we suggest reading:

- [Assisted Generation](https://huggingface.co/blog/assisted-generation): Learn more about assisted generation.
- [Speculative Decoding Docs](https://huggingface.co/docs/transformers/main/en/generation_strategies#speculative-decoding): See how transformers does decoding with an assistant.

## Acknowledgements

1. [Vaibhav Srivastav](https://huggingface.co/reach-vb) for the thorough review and suggestions to make the tutorial better.
2. [Joao Gante](https://huggingface.co/joaogante) for answering my questions about speculative decoding.