# Speculative sampling from scratch in Python

In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

In [3]:
from transformers import T5Tokenizer, T5ForConditionalGeneration

# draft model: tokenizer and weights are both fetched from the same checkpoint name
draft = 'google/flan-t5-large'
draft_tokenizer, draft_model = (
    T5Tokenizer.from_pretrained(draft),
    T5ForConditionalGeneration.from_pretrained(draft),
)

# target model: loaded the same way, from its own checkpoint name
target = 'google/flan-t5-xl'
target_tokenizer, target_model = (
    T5Tokenizer.from_pretrained(target),
    T5ForConditionalGeneration.from_pretrained(target),
)

tokenizer_config.json:   0%|          | 0.00/2.54k [00:00<?, ?B/s]

spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/2.20k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.42M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.44k [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/53.0k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/9.45G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/1.95G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/147 [00:00<?, ?B/s]

In [4]:
"""
Ensuring the tokenizers are identical
In order for speculative sampling to work, tokenization for both the draft and
target model must be identical. This is a sanity check to make sure they are.
"""

# tokenizing the same sequence with both tokenizers
prompt= 'this, is, some [text] for 1234comparing, tokenizers adoihayyuz'
ex1= target_tokenizer(prompt, return_tensors='pt').input_ids
ex2= draft_tokenizer(prompt, return_tensors='pt').input_ids

# if the tokenizers split the prompt into a different number of tokens, the two
# id tensors have different shapes; subtracting them would either fail to
# broadcast or be silently broadcast, so that case is reported explicitly
if ex1.shape != ex2.shape:
    print('tokenizers disagree: target produced {} tokens, draft produced {}'.format(ex1.shape[1], ex2.shape[1]))
else:
    # zero means all tokenized values are the same, so the tokenizers are more than likely identical
    print((ex1-ex2).abs().max())

tensor(0)


# Building Speculative Sampling

In [5]:
import torch
import torch.nn as nn

In [16]:
# performing speculative sampling
#
# each iteration:
#   1. the draft model greedily proposes k tokens, one forward pass per token
#   2. the target model scores the whole draft in a single forward pass
#   3. the draft is kept up to the first position where the target disagrees,
#      and the target's own token is used at that position
# the decoder sequence therefore grows by at least one token per iteration.

# initializing an empty input to feed to the decoder. This is updated each loop with valid generations
decoder_ids= draft_model._shift_right(draft_tokenizer("", return_tensors="pt").input_ids)

# defining input. T5 is an encoder-decoder model, so input and output are handled separately
input_ids= draft_tokenizer("Translate to German \n Battle not with monsters, lest you become a monster, and if you gaze into the abyss, the abyss gazes also into you.",
                           return_tensors="pt").input_ids

# defining the number of draft generations
k= 5

# keeps track of generation information, for later printouts
generated= []

# generating Text
for iteration in range(15):
    print('========== Speculative Sampling Iteration {} =========='.format(iteration))

    # creating a holding place for the generated draft
    decoder_ids_draft= decoder_ids.clone()

    before_text= draft_tokenizer.decode(decoder_ids_draft[0])
    initial_length= decoder_ids.shape[1]

    # generating draft
    for _ in range(k):

        # predicting the next token with the draft model
        with torch.no_grad():
            logits= draft_model(input_ids=input_ids, decoder_input_ids=decoder_ids_draft).logits
            genid= torch.argmax(logits, dim=2)[0][-1]

        # appending the generated id to the draft
        genid= genid.expand(1,1)
        decoder_ids_draft= torch.cat((decoder_ids_draft,genid),1)

    print('=== Draft Generation')
    current_draft= draft_tokenizer.decode(decoder_ids_draft[0])
    print('Generated draft tokens: {}'.format(decoder_ids_draft))
    print('Generated draft text: {}'.format(current_draft))

    # generating all next token predictions with the target.
    # this is inference only, so no autograd graph is needed
    with torch.no_grad():
        logits= target_model(input_ids=input_ids, decoder_input_ids=decoder_ids_draft).logits
    genids= torch.argmax(logits, dim=2)[0]
    print('=== Target Generation')
    current_target= draft_tokenizer.decode(genids)
    print('Generated target tokens: {}'.format(genids))
    print('Generated target text: {}'.format(current_target))

    # genids[j] is the target's prediction for the token *after* position j, so
    # genids never contains the first decoder token; it is preserved separately
    # and re-attached in both the "disagreement" and the "all accepted" case
    first_token= decoder_ids[0][:1]

    # checking draft against target
    for i, (dv, tv) in enumerate(zip(decoder_ids_draft[0,1:],genids[:-1])):
        # target does not agree with the draft
        if dv != tv:
            # keep the agreed prefix plus the target's token at the disagreement
            validated= genids[:i+1]
            break
    else:
        # no disagreements: every draft token is kept, plus the target's
        # prediction for the token following the draft
        validated= genids
    decoder_ids= torch.cat((first_token,validated),0)

    # anything after an end-of-sequence token is not part of the output, so the
    # sequence is cut right after the first one and generation stops
    eos_positions= (decoder_ids == draft_tokenizer.eos_token_id).nonzero()
    finished= len(eos_positions) > 0
    if finished:
        decoder_ids= decoder_ids[:eos_positions[0].item()+1]

    print('=== Validated Generation')
    current_target= draft_tokenizer.decode(decoder_ids)
    print('Generated target tokens: {}'.format(decoder_ids))
    print('Generated target text: {}'.format(current_target))

    # expanding dimensions so that the shape of the tensor is the same
    decoder_ids= decoder_ids.expand(1,len(decoder_ids))

    # logging
    numgen= decoder_ids.shape[1] - initial_length
    generated.append({'tokens generated': numgen,
                      'text before': before_text,
                      'text after': current_target})

    # the sequence is complete once it ends in an end-of-sequence token
    if finished:
        break


=== Draft Generation
generated draft tokens: tensor([[    0,   316, 20256,    15,   311,   181]])
generated draft text: <pad>Die Kampfe nicht mit
=== Target Generation
generated target tokens: tensor([316,   3,  15, 181, 181, 177])
generated target text: Die e mit mit den
=== Validated Generation
generated target tokens: tensor([  0, 316,   3])
generated target text: <pad>Die
=== Draft Generation
generated draft tokens: tensor([[    0,   316,     3,     2, 25231,     3,   547,   289]])
generated draft text: <pad> Die <unk>ffentlichkeit hat sich
=== Target Generation
generated target tokens: tensor([ 316,    3, 9465,   40,    3,  547,  289,    3])
generated target text: Die Brul hat sich
=== Validated Generation
generated target tokens: tensor([   0,  316,    3, 9465])
generated target text: <pad>Die Bru
=== Draft Generation
generated draft tokens: tensor([[    0,   316,     3,  9465,    17, 13680,   229,   311,   181]])
generated draft text: <pad>Die Brutalität ist nicht mit
=== Target

In [None]:
# https://towardsdatascience.com/speculative-sampling-intuitively-and-exhaustively-explained-2daca347dbb9