# Implementing Speculative Decoding

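This notebook implements speculative decoding as described in Leviathan et al., *Fast Inference from Transformers via Speculative Decoding* (https://arxiv.org/pdf/2211.17192). A small draft model (TinyLlama-1.1B-Chat) proposes several tokens, and a large target model (Llama-2-13B) verifies them all in a single forward pass. Both models use the Llama 2 tokenizer vocabulary, so token ids produced by one can be scored and decoded by the other.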
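For reference, here is the verification rule that `speculative_decode` implements further down, written in the paper's notation. p is the large (target) model's distribution and q is the draft model's distribution at the same position:

$$
x \sim q \text{ is accepted with probability } \min\!\left(1, \frac{p(x)}{q(x)}\right)
$$

$$
\text{on rejection, resample } x \sim p'(x) = \frac{\max\left(0,\, p(x) - q(x)\right)}{\sum_{x'} \max\left(0,\, p(x') - q(x')\right)}
$$

If every drafted token is accepted, one more token is sampled from the large model's distribution at the next position. The probability of accepting a single drafted token is $\sum_x \min\left(p(x), q(x)\right)$.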


In [1]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed

In [2]:
set_seed(42)

In [3]:
torch.cuda.is_available()

True

## Load Models

In [4]:
# Tokenizers run on the CPU, so no device_map is needed here.
small_tokeniser = AutoTokenizer.from_pretrained('TinyLlama/TinyLlama-1.1B-Chat-v1.0')

In [5]:
big_tokeniser = AutoTokenizer.from_pretrained('meta-llama/Llama-2-13b-hf')

In [6]:
small_model = AutoModelForCausalLM.from_pretrained('TinyLlama/TinyLlama-1.1B-Chat-v1.0', device_map='cuda:0', torch_dtype=torch.bfloat16)

In [7]:
big_model = AutoModelForCausalLM.from_pretrained('meta-llama/Llama-2-13b-hf', device_map='auto', torch_dtype=torch.bfloat16)


Some parameters are on the meta device because they were offloaded to the cpu.
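The offload warning is consistent with a rough memory estimate. In bfloat16 each parameter takes 2 bytes, so the 13B model's weights alone need about 24 GiB, and `device_map='auto'` puts whatever does not fit on the GPU onto the CPU. CPU-resident layers run much more slowly, which likely contributes to the large model's long run times below. The sketch assumes nominal parameter counts taken from the model names and ignores activations and other overhead.

```python
def weights_gib(num_params, bytes_per_param=2):
    """Approximate weight memory in GiB (bfloat16 = 2 bytes per parameter).

    Ignores activations, buffers and framework overhead.
    """
    return num_params * bytes_per_param / 2**30


# Nominal parameter counts read off the model names (approximate assumption).
for name, n in [("TinyLlama-1.1B", 1.1e9), ("Llama-2-13b", 13e9)]:
    print(f"{name}: ~{weights_gib(n):.1f} GiB of bfloat16 weights")
```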


## Implement sampling and speculative decoding functions

In [8]:
def get_topk_probas(logits, k):
    """Top-k filtering followed by a softmax over the vocabulary dimension."""
    top_k = torch.topk(logits, k=k, dim=-1)  # values/indices: (batch, context, k)
    min_k_values = top_k.values[:, :, -1].unsqueeze(-1)  # (batch, context, 1)
    # Mask every logit strictly below the k-th largest; ties with it are kept.
    top_k_logits = logits.masked_fill(logits < min_k_values, float('-inf'))
    probas = torch.softmax(top_k_logits, dim=-1)  # (batch, context, vocab)
    return probas
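`get_topk_probas` masks only logits *strictly* below the k-th largest, so any logit tied with the k-th largest value also survives. With `k=1` and tied maxima the result is therefore not one-hot. The pure-Python sketch below mirrors that behaviour for a single position without batching; `topk_probs` is a hypothetical helper for illustration, not part of the notebook.

```python
import math


def topk_probs(logits, k):
    """Single-position mirror of get_topk_probas: mask, then softmax."""
    kth_largest = sorted(logits, reverse=True)[k - 1]
    # Keep every logit >= the k-th largest (so ties are kept), mask the rest.
    masked = [x if x >= kth_largest else -math.inf for x in logits]
    m = max(masked)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in masked]  # math.exp(-inf) == 0.0
    total = sum(exps)
    return [e / total for e in exps]


probs = topk_probs([2.0, 1.0, 0.5, -1.0], k=2)
print([round(p, 3) for p in probs])  # only the two largest logits get mass

ties = topk_probs([1.0, 1.0, 1.0, 0.0], k=1)
print(ties)  # k=1, yet three tied tokens share the probability mass
```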

In [9]:
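def generate_draft(input_ids, max_new_tokens=250, k=50, end_of_turn_id=107, model=small_model):
    """Autoregressively sample up to `max_new_tokens` tokens with top-k sampling.

    Returns the extended ids and the top-k probabilities from the *last* forward
    pass, shape (batch, context, vocab); probas[:, t] is the distribution over
    the token at position t + 1.
    """
    # NOTE: 107 is not the Llama/TinyLlama EOS id (</s> is 2), so with these
    # models generation effectively always runs for max_new_tokens.
    model.eval()
    probas = None
    for _ in range(max_new_tokens):
        # No KV cache: the full sequence is re-encoded at every step.
        with torch.no_grad():
            logits = model(input_ids).logits  # (batch, context, vocab)

        probas = get_topk_probas(logits, k)
        last_pos_probas = probas[:, -1, :]  # (batch, vocab)

        next_idx = torch.multinomial(last_pos_probas, num_samples=1)  # (batch, 1)
        input_ids = torch.cat((input_ids, next_idx), dim=-1)

        if (next_idx == end_of_turn_id).all():
            break

    return input_ids, probas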
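`generate_draft` returns only the probabilities from its final forward pass. `speculative_decode` relies on that being enough: because the model is causal, the last pass already contains the distribution for every drafted position (up to floating-point differences). The toy sketch below checks this indexing with a hypothetical `toy_model` standing in for a real language model.

```python
import random

VOCAB_SIZE = 4


def toy_model(ids):
    """Hypothetical stand-in for a causal LM: one next-token distribution per
    input position, like probas of shape (context, vocab). Row t depends only
    on ids[t], so it is causal."""
    rows = []
    for tok in ids:
        row = [0.1] * VOCAB_SIZE
        row[(tok + 1) % VOCAB_SIZE] = 0.7  # favour the "next" token id
        rows.append(row)
    return rows


def toy_generate(ids, max_new_tokens, rng):
    """Mirror of generate_draft's loop: re-run the model on the full sequence,
    sample from the last row, append, and return the last pass's rows."""
    probas = None
    for _ in range(max_new_tokens):
        probas = toy_model(ids)
        next_tok = rng.choices(range(VOCAB_SIZE), weights=probas[-1])[0]
        ids = ids + [next_tok]
    return ids, probas


prompt = [0, 1]
ids, probas = toy_generate(prompt, max_new_tokens=3, rng=random.Random(0))
# The last pass saw every token except the final one: one row per position
# 0 .. len(ids) - 2, and row j - 1 scores the token at position j.
assert len(probas) == len(ids) - 1
for j in range(len(prompt), len(ids)):
    assert probas[j - 1] == toy_model(ids[:j])[-1]
```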


In [10]:
def speculative_decode(input_ids, max_tokens=250, gamma=10, k=1):
    """Speculative sampling (Leviathan et al.) with `small_model` as the draft
    model and `big_model` as the target model.

    Returns the generated ids and the number of draft tokens accepted.
    """
    i = input_ids.size(1)  # current sequence length
    num_drafts_tokens_kept = 0

    while i < max_tokens:
        num_tokens_to_gen = min(max_tokens - i, gamma)
        # 1. Draft up to gamma tokens with the small model (q = draft_probas).
        draft_input_ids, draft_probas = generate_draft(input_ids, max_new_tokens=num_tokens_to_gen, k=k)

        # 2. A single large-model forward pass scores every drafted position
        #    (p = big_probas) and samples one extra token after the draft.
        big_input_ids, big_probas = generate_draft(draft_input_ids, max_new_tokens=1, k=k, model=big_model)

        batch_idx = torch.arange(draft_input_ids.size(0))
        resampled_id_j = None
        for j in range(i, i + num_tokens_to_gen):
            draft_id_j = draft_input_ids[:, j]  # (batch,)
            # The distribution over the token at position j comes from position j - 1.
            draft_proba_j = draft_probas[batch_idx, j - 1, draft_id_j]  # q(x)
            big_proba_j = big_probas[batch_idx, j - 1, draft_id_j]  # p(x)

            if (draft_proba_j <= big_proba_j).item():
                # q(x) <= p(x): always accept.
                num_drafts_tokens_kept += 1
                continue

            # q(x) > p(x): reject with probability 1 - p(x) / q(x).
            if (torch.rand(1).item() < 1 - (big_proba_j / draft_proba_j)).item():
                # Resample from the adjusted distribution norm(max(0, p(x) - q(x))).
                p = big_probas[:, j - 1, :]
                q = draft_probas[:, j - 1, :]
                adjusted_p = torch.clamp(p - q, min=0)
                adjusted_p = adjusted_p / adjusted_p.sum(dim=-1, keepdim=True)  # normalise each row
                resampled_id_j = torch.multinomial(adjusted_p, num_samples=1)  # (batch, 1)
                break

            # Survived the random rejection test: accept.
            num_drafts_tokens_kept += 1

        if resampled_id_j is not None:
            # Keep the accepted prefix (positions < j) and append the resampled token.
            input_ids = torch.cat((draft_input_ids[:, :j], resampled_id_j), dim=-1)
        else:
            print('entire draft used')
            # Every draft token was accepted: also keep the large model's extra token.
            input_ids = big_input_ids

        i = input_ids.size(1)
    return input_ids, num_drafts_tokens_kept
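The acceptance test in `speculative_decode` makes the output follow the large model's distribution rather than the draft's. The pure-Python sketch below applies the same accept-or-resample rule to a single position, using made-up toy distributions `p` (large model) and `q` (draft). It checks empirically that the emitted tokens are distributed according to `p` and that the acceptance rate equals the sum over x of min(p(x), q(x)).

```python
import random


def verify_token(x, p, q, rng):
    """Verify one drafted token x, mirroring speculative_decode.

    p: target (large-model) probabilities, q: draft probabilities, as lists.
    Returns (token, accepted).
    """
    # Accept with probability min(1, p(x) / q(x)). Testing u < p/q is
    # equivalent in distribution to the notebook's "reject if u < 1 - p/q".
    if q[x] <= p[x] or rng.random() < p[x] / q[x]:
        return x, True
    # Rejected: resample from norm(max(0, p - q)); choices() normalises weights.
    residual = [max(0.0, pi - qi) for pi, qi in zip(p, q)]
    return rng.choices(range(len(p)), weights=residual)[0], False


def simulate(p, q, n, seed=0):
    """Draft x ~ q, verify it, and return (output frequencies, acceptance rate)."""
    rng = random.Random(seed)
    counts = [0] * len(p)
    accepted = 0
    for _ in range(n):
        x = rng.choices(range(len(q)), weights=q)[0]
        token, ok = verify_token(x, p, q, rng)
        counts[token] += 1
        accepted += ok
    return [c / n for c in counts], accepted / n


p = [0.5, 0.3, 0.2]  # hypothetical large-model distribution
q = [0.2, 0.2, 0.6]  # hypothetical draft distribution
freqs, acc = simulate(p, q, n=100_000)
print("output frequencies:", [round(f, 3) for f in freqs])  # close to p
print("acceptance rate:", round(acc, 3))  # close to sum(min(p, q)) = 0.6
```

The draft over-proposes token 2, so it is sometimes rejected, and the residual `max(0, p - q)` puts the missing mass back on tokens 0 and 1.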

## First, examine the original output from the large and small models

### Large model output

In [11]:
inputs = small_tokeniser('The Future of AI is', return_tensors='pt').to('cuda:0')
input_ids = inputs['input_ids']
input_ids

tensor([[    1,   450, 16367,   310,   319, 29902,   338]], device='cuda:0')

In [12]:
%%time
big_input_ids, big_probas = generate_draft(input_ids, max_new_tokens=250 - input_ids.size(1), k=1, model=big_model)
big_tokeniser.decode(big_input_ids[0])

CPU times: user 2min 59s, sys: 3.3 s, total: 3min 2s
Wall time: 3min 2s


'<s> The Future of AI is in the Hands of the People\nThe future of AI is in the hands of the people.\nThe future of AI is in the hands of the people. The future of AI is in the hands of the people. The future of AI is in the hands of the people. The future of AI is in the hands of the people. The future of AI is in the hands of the people. The future of AI is in the hands of the people. The future of AI is in the hands of the people. The future of AI is in the hands of the people. The future of AI is in the hands of the people. The future of AI is in the hands of the people. The future of AI is in the hands of the people. The future of AI is in the hands of the people. The future of AI is in the hands of the people. The future of AI is in the hands of the people. The future of AI is in the hands of the people. The future of AI is in the hands of the people. The future of AI is in the hands of the people.'

### Small (draft) model output

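Note that, unlike the large model, the small model does not fall into repetition. It does, however, run past the end-of-sequence token `</s>` and start inventing new chat turns, because `generate_draft` only stops on `end_of_turn_id=107`, which is not the Llama EOS id (`</s>` is id 2).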

In [13]:
%%time
# For TinyLlama, does not produce repetition!
llama_ids, llama_probas = generate_draft(input_ids, max_new_tokens=250 - input_ids.size(1), k=1, model=small_model)
small_tokeniser.decode(llama_ids[0])

CPU times: user 4.88 s, sys: 5.96 ms, total: 4.89 s
Wall time: 4.91 s


"<s> The Future of AI is Now: AI is transforming the world, and it's only getting better. From self-driving cars to chatbots, AI is changing the way we live, work, and communicate. In this episode, we'll explore the latest advancements in AI and how they're transforming our world. We'll also discuss the challenges and opportunities that come with AI, and how we can leverage it to create a better future. Join us for a fascinating conversation with experts in the field.</s> \n<|user|>\nThis sounds like a great episode! Can you add some more information about how AI is being used in healthcare? I'm really interested in learning more about that.</s> \n<|assistant|>\nAbsolutely! Healthcare is one of the most exciting areas of AI application, and we're seeing incredible advancements in the field. In this episode, we'll explore how AI is being used in healthcare to improve patient outcomes, reduce costs, and enhance the patient experience. We'll discuss the latest A"

### Speculative decoding output (`k=1`, `gamma=10`)

In [14]:
%%time
output_ids, num_drafts_tokens_kept = speculative_decode(input_ids, max_tokens=250, gamma=10)

entire draft used
entire draft used
entire draft used
entire draft used
entire draft used
entire draft used
entire draft used
entire draft used
entire draft used
entire draft used
entire draft used
entire draft used
entire draft used
entire draft used
entire draft used
entire draft used
entire draft used
entire draft used
entire draft used
entire draft used
entire draft used
CPU times: user 24.4 s, sys: 205 ms, total: 24.6 s
Wall time: 24.6 s


In [15]:
big_tokeniser.decode(output_ids[0])

'<s> The Future of AI is in the Hands of the People\nThe future of AI is in the hands of the people.\nThe future of AI is in the hands of the people. The future of AI is in the hands of the people. The future of AI is in the hands of the people. The future of AI is in the hands of the people. The future of AI is in the hands of the people. The future of AI is in the hands of the people. The future of AI is in the hands of the people. The future of AI is in the hands of the people. The future of AI is in the hands of the people. The future of AI is in the hands of the people. The future of AI is in the hands of the people. The future of AI is in the hands of the people. The future of AI is in the hands of the people. The future of AI is in the hands of the people. The future of AI is in the hands of the people. The future of AI is in the hands of the people. The future of AI is in the hands of the people. The'

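### Testing Speculative Decoding

With greedy sampling (`k=1`), the decoded speculative output above is identical to the large model's own greedy output apart from one extra trailing word, and it took 24.6 s instead of 3 min 2 s. The extra token appears because, when the final draft is fully accepted, the large model's extra token is still appended, so the sequence can overshoot `max_tokens` by one.

The cells below compare run time and the number of accepted draft tokens at other `gamma` settings.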

In [16]:
%%time
output_ids, num_drafts_tokens_kept = speculative_decode(input_ids, max_tokens=250, gamma=20)

entire draft used
entire draft used
entire draft used
entire draft used
entire draft used
entire draft used
entire draft used
entire draft used
entire draft used
entire draft used
entire draft used
CPU times: user 18.1 s, sys: 123 ms, total: 18.2 s
Wall time: 18.2 s


In [17]:
big_tokeniser.decode(output_ids[0])

'<s> The Future of AI is in the Hands of the People\nThe future of AI is in the hands of the people.\nThe future of AI is in the hands of the people. The future of AI is in the hands of the people. The future of AI is in the hands of the people. The future of AI is in the hands of the people. The future of AI is in the hands of the people. The future of AI is in the hands of the people. The future of AI is in the hands of the people. The future of AI is in the hands of the people. The future of AI is in the hands of the people. The future of AI is in the hands of the people. The future of AI is in the hands of the people. The future of AI is in the hands of the people. The future of AI is in the hands of the people. The future of AI is in the hands of the people. The future of AI is in the hands of the people. The future of AI is in the hands of the people. The future of AI is in the hands of the people. The'

In [18]:
num_drafts_tokens_kept

228

In [19]:
output_ids.size(1)

251

In [20]:
%%time
output_ids, num_drafts_tokens_kept = speculative_decode(input_ids, max_tokens=250, gamma=5)

entire draft used
entire draft used
entire draft used
entire draft used
entire draft used
entire draft used
entire draft used
entire draft used
entire draft used
entire draft used
entire draft used
entire draft used
entire draft used
entire draft used
entire draft used
entire draft used
entire draft used
entire draft used
entire draft used
entire draft used
entire draft used
entire draft used
entire draft used
entire draft used
entire draft used
entire draft used
entire draft used
entire draft used
entire draft used
entire draft used
entire draft used
entire draft used
entire draft used
entire draft used
entire draft used
entire draft used
entire draft used
entire draft used
entire draft used
entire draft used
CPU times: user 37.1 s, sys: 315 ms, total: 37.4 s
Wall time: 37.4 s


In [21]:
big_tokeniser.decode(output_ids[0])

'<s> The Future of AI is in the Hands of the People\nThe future of AI is in the hands of the people.\nThe future of AI is in the hands of the people. The future of AI is in the hands of the people. The future of AI is in the hands of the people. The future of AI is in the hands of the people. The future of AI is in the hands of the people. The future of AI is in the hands of the people. The future of AI is in the hands of the people. The future of AI is in the hands of the people. The future of AI is in the hands of the people. The future of AI is in the hands of the people. The future of AI is in the hands of the people. The future of AI is in the hands of the people. The future of AI is in the hands of the people. The future of AI is in the hands of the people. The future of AI is in the hands of the people. The future of AI is in the hands of the people. The future of AI is in the hands of the people. The'

In [22]:
num_drafts_tokens_kept

200

In [23]:
output_ids.size(1)

251
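A rough reading of the numbers above. Each iteration of `speculative_decode` calls the large model once and appends exactly one token that did not come from the draft (either a resampled token or the large model's extra token). The number of large-model calls therefore equals the generated tokens minus `num_drafts_tokens_kept`. Plain decoding with the large model (In [12]) needs one call per token, i.e. 243 calls. The sketch copies the recorded wall times and counts. The `gamma=10` run is left out because its kept count was not printed. Wall-time speedups are hardware-specific, and here the large model is partly offloaded to the CPU.

```python
PROMPT_LEN = 7              # tokens in 'The Future of AI is', including <s>
BASELINE_S = 3 * 60 + 2     # large model alone (In [12]): 3min 2s wall time

# gamma -> (wall time in s, num_drafts_tokens_kept, output_ids.size(1))
runs = {5: (37.4, 200, 251), 20: (18.2, 228, 251)}


def summarise(wall_s, kept, final_len):
    generated = final_len - PROMPT_LEN
    # One non-draft token per loop iteration, and one large-model call per
    # iteration, so the difference counts the large-model calls.
    large_model_calls = generated - kept
    return {
        "speedup": BASELINE_S / wall_s,
        "large_model_calls": large_model_calls,
        "tokens_per_call": generated / large_model_calls,
    }


for gamma, run in sorted(runs.items()):
    s = summarise(*run)
    print(f"gamma={gamma}: {s['speedup']:.1f}x faster, "
          f"{s['large_model_calls']} large-model calls, "
          f"{s['tokens_per_call']:.2f} tokens/call")
```

With `gamma=20` the large model runs 16 times rather than 243, which fits the roughly tenfold drop in wall time.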

## Test Speculative Decoding with Top-k Sampling (`k=50`)

In [24]:
%%time
output_ids, num_drafts_tokens_kept = speculative_decode(input_ids, max_tokens=250, gamma=5, k=50)

entire draft used
entire draft used
entire draft used
entire draft used
entire draft used
entire draft used
entire draft used
entire draft used
entire draft used
entire draft used
entire draft used
entire draft used
entire draft used
entire draft used
entire draft used
entire draft used
entire draft used
entire draft used
entire draft used
entire draft used
entire draft used
CPU times: user 57.7 s, sys: 561 ms, total: 58.3 s
Wall time: 58.3 s


In [25]:
print(big_tokeniser.decode(output_ids[0], skip_special_tokens=True))

The Future of AI is Decentralized
We are witnessing a revolution in artificial intelligence. In just a few years, AI has gone from a futuristic idea to a tangible technology that is reshaping our lives. And AI will only become more important in the future, as it has the potential to improve efficiency, make jobs more efficient, and create new opportunities.
The future of AI is decentralized.
The future of artificial intelligence is decentralized, at least in the sense that it won’t be controlled by a single entity. AI research is being undertaken by a variety of groups, including universities, corporations, governments, and research institutes. There is no single company or organization that is leading the way in AI; instead, a range of different entities are collaborating and competing to advance the field.
The future of AI is decentralized because it is built on blockchain technology.
One of the many benefits of blockchain technology is that it is decentralized. This means that no on

In [26]:
num_drafts_tokens_kept

175

In [27]:
output_ids.size(1)

250
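The paper gives the expected number of tokens produced per large-model call when each drafted token is accepted independently with a fixed probability α: (1 - α^(γ+1)) / (1 - α). Inverting this formula gives a rough estimate of α for the two `gamma=5` runs, greedy (243 + 1 generated, 200 kept) and top-k (243 generated, 175 kept). Treat this as a back-of-the-envelope sketch under stated assumptions. Acceptance is not really independent across positions, and the last iteration of a run may draft fewer than `gamma` tokens.

```python
def expected_tokens_per_call(alpha, gamma):
    """Expected tokens per large-model call if every drafted token is accepted
    independently with probability alpha: (1 - alpha**(gamma + 1)) / (1 - alpha)."""
    if alpha == 1.0:
        return gamma + 1.0
    return (1 - alpha ** (gamma + 1)) / (1 - alpha)


def estimate_alpha(tokens_per_call, gamma, tol=1e-9):
    """Invert expected_tokens_per_call by bisection (it increases with alpha)."""
    lo, hi = 0.0, 1.0
    while hi - lo > tol:
        mid = (lo + hi) / 2
        if expected_tokens_per_call(mid, gamma) < tokens_per_call:
            lo = mid
        else:
            hi = mid
    return (lo + hi) / 2


# gamma=5 runs: generated tokens / large-model calls (= generated - kept)
greedy = 244 / (244 - 200)   # k=1,  In [20]-[23]
topk = 243 / (243 - 175)     # k=50, In [24]-[27]
print(f"greedy: {greedy:.2f} tokens/call, alpha ~ {estimate_alpha(greedy, 5):.2f}")
print(f"top-k:  {topk:.2f} tokens/call, alpha ~ {estimate_alpha(topk, 5):.2f}")
```

Under these assumptions the draft is accepted noticeably less often with `k=50` than with greedy decoding. This is consistent with the longer wall time of the top-k run (58.3 s versus 37.4 s).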

In [28]:
%%time
big_input_ids, big_probas = generate_draft(input_ids, max_new_tokens=250 - input_ids.size(1), k=50, model=big_model)
print(big_tokeniser.decode(big_input_ids[0], skip_special_tokens=True))

The Future of AI is not Human-Like
In recent months we have seen many new demonstrations of AI systems that mimic aspects of human behavior. In fact, people seem so interested in systems that mimic us that they have created the category of “socially AI.” A number of AI startups that promise some aspect of human-like behavior have attracted millions of dollars in investment.
But it is important to understand that these techniques are only one method of AI and that we can gain significantly from AI that we recognize as different from ourselves. For example, my company has developed an AI system that combines human judgement with AI judgement. In this case we wanted to create an AI system that could answer questions about the economy: why an economy grows, what will cause growth, how to improve the economy, what is the role of a particular sector within the economy, and so on.
Most economic data is too complex for humans to understand and too difficult to interpret from data. In economics