## Speculative Sampling

The goal here is to reduce inference latency for generative models. Faster decoding matters for interactive uses such as edit suggestions while writing or coding.

We use a small draft model (GPT-2) and a larger target model (GPT-2 XL). We prompt the draft model and generate k speculative tokens along with their probabilities.

Next we append those k tokens to the original prompt and feed the result to the target model. Because of the causal attention mask, a single forward pass yields the target model's next-token distribution at every position, so we get the likelihoods of all k tokens at once. We then compare the draft and target probabilities of each speculated token to accept or reject it.

See this video for more explanations:

https://www.youtube.com/watch?v=S-8yr_RibJ4

The key insight is that many tokens, such as "of", are easy enough for a small model to predict. The small model generates those cheaply, and the larger model is only needed to verify them and to correct the harder tokens, such as facts.
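To get a feel for the potential gain, the expected number of tokens produced per target-model pass can be estimated under a simplifying assumption: each drafted token is accepted independently with the same probability `alpha`. Real text does not satisfy this exactly, so treat the numbers as rough. Under that assumption, k drafted tokens yield `(1 - alpha**(k+1)) / (1 - alpha)` tokens on average, counting the one extra token the target model contributes. The function name and the `alpha` values below are illustrative only.

```python
def expected_tokens_per_pass(alpha: float, k: int) -> float:
    """Expected tokens emitted per target pass, assuming i.i.d. acceptance."""
    if not 0.0 <= alpha <= 1.0:
        raise ValueError("alpha must be in [0, 1]")
    if alpha == 1.0:
        return float(k + 1)  # every draft token accepted, plus the extra token
    # Geometric sum 1 + alpha + ... + alpha**k
    return (1.0 - alpha ** (k + 1)) / (1.0 - alpha)


for alpha in (0.5, 0.8, 0.9):
    print(alpha, round(expected_tokens_per_pass(alpha, k=6), 2))
```

The estimate says nothing about wall-clock speedup on its own, since the draft model's own cost per token also has to be paid.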

https://pytorch.org/blog/hitchhikers-guide-speculative-decoding/

https://www.youtube.com/watch?v=9wNAgpX6z_4

https://docs.google.com/presentation/d/1p1xE-EbSAnXpTSiSI0gmy_wdwxN5XaULO3AnCWWoRe4/edit#slide=id.p

## Load draft model

`GPT-2`

In [None]:
import time
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import torch
import torch.nn.functional as F

In [None]:
# Load pre-trained tokenizer and model for text generation
tokenizer_draft = GPT2Tokenizer.from_pretrained('gpt2')
model_draft = GPT2LMHeadModel.from_pretrained('gpt2')

# Move model and input to the same device (GPU if available, else CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_draft.to(device)

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2SdpaAttention(
          (c_attn): Conv1D(nf=2304, nx=768)
          (c_proj): Conv1D(nf=768, nx=768)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D(nf=3072, nx=768)
          (c_proj): Conv1D(nf=768, nx=3072)
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)

## Prompt the draft model to generate speculated tokens

`k=6`

In [None]:
# Define the input prompt
prompt = "What is mitosis? Mitosis is the process by which a protein is broken"

k = 6  # number of tokens to speculate
# Speculate the next k tokens and store them along with their draft-model probabilities
draft_next_tokens = []
draft_next_token_probs = []

for _ in range(k):
  # Tokenize the prompt (turn it into token IDs)
  encoded_prompt = tokenizer_draft(prompt, return_tensors='pt')

  encoded_prompt = {key: value.to(device) for key, value in encoded_prompt.items()}

  # Run the model and get the logits for the next token
  with torch.no_grad():  # Disable gradient calculation during inference
      outputs = model_draft(**encoded_prompt)
      logits = outputs.logits

  # Logits at the last position score the token that follows the prompt
  next_token_logits = logits[0, -1, :]

  # Apply softmax to convert logits into probabilities
  probabilities = F.softmax(next_token_logits, dim=-1)

  # Get the 10 most likely tokens and their probabilities
  # (top_k is only for display and is unrelated to the speculation length k)
  top_k = 10
  top_token_probs, top_token_ids = torch.topk(probabilities, k=top_k)

  # Decode the top_k token IDs into human-readable tokens
  top_token_strings = [tokenizer_draft.decode([token_id.item()]) for token_id in top_token_ids]

  # Print the top 10 most likely next tokens and their probabilities
  print(f"Top {top_k} tokens and their likelihoods for the next token:")
  for i in range(top_k):
      print(f"Token: {top_token_strings[i]} | Probability: {top_token_probs[i].item():.4f}")
  print("\n----------------------\n")
  # Greedy choice: append the most likely token to the prompt to predict the next position
  prompt = prompt + top_token_strings[0]

  draft_next_tokens.append(top_token_strings[0])
  draft_next_token_probs.append(top_token_probs[0].item())

Top 10 tokens and their likelihoods for the next token:
Token:  down | Probability: 0.9311
Token:  into | Probability: 0.0164
Token:  up | Probability: 0.0151
Token:  apart | Probability: 0.0037
Token:  and | Probability: 0.0036
Token:  in | Probability: 0.0032
Token:  from | Probability: 0.0030
Token: , | Probability: 0.0029
Token:  out | Probability: 0.0028
Token: . | Probability: 0.0026

----------------------

Top 10 tokens and their likelihoods for the next token:
Token:  into | Probability: 0.3629
Token:  and | Probability: 0.1590
Token:  by | Probability: 0.1014
Token:  to | Probability: 0.0693
Token:  in | Probability: 0.0563
Token: , | Probability: 0.0507
Token: . | Probability: 0.0501
Token:  or | Probability: 0.0185
Token:  from | Probability: 0.0141
Token:  ( | Probability: 0.0083

----------------------

Top 10 tokens and their likelihoods for the next token:
Token:  its | Probability: 0.1434
Token:  a | Probability: 0.0933
Token:  smaller | Probability: 0.0542
Token:  ami

## Speculated tokens and probabilities from the draft model

In [None]:
print(f"Speculated tokens: {draft_next_tokens}")
print(f"Speculated tokens probs: {draft_next_token_probs}")


Speculated tokens: [' down', ' into', ' its', ' constituent', ' parts', '.']
Speculated tokens probs: [0.931113064289093, 0.36290252208709717, 0.14342136681079865, 0.7328725457191467, 0.26946306228637695, 0.3097057044506073]
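The drafting loop above appends decoded strings to the prompt and keeps only the top probability at each step. A variant that works on token IDs and keeps the full draft distribution per step may be more convenient for the later accept/reject stage. The sketch below shows that shape with a stand-in model: `draft_tokens` and `toy_model` are hypothetical names, and `toy_model` is a made-up 3-token distribution, not GPT-2.

```python
def draft_tokens(next_token_dist, prompt_ids, k):
    """Greedily draft k tokens, keeping IDs and full distributions.

    next_token_dist: callable mapping a list of token IDs to a list of
    probabilities over the vocabulary (a stand-in for the draft model).
    """
    ids = list(prompt_ids)
    draft_ids, draft_dists = [], []
    for _ in range(k):
        dist = next_token_dist(ids)
        token_id = max(range(len(dist)), key=dist.__getitem__)  # greedy pick
        draft_ids.append(token_id)
        draft_dists.append(dist)
        ids.append(token_id)  # extend the context with the ID, not a string
    return draft_ids, draft_dists


def toy_model(ids):
    """Hypothetical 3-token model: strongly prefers (last_id + 1) % 3."""
    dist = [0.1, 0.1, 0.1]
    dist[(ids[-1] + 1) % 3] = 0.8
    return dist


draft_ids, draft_dists = draft_tokens(toy_model, prompt_ids=[0], k=4)
print(draft_ids)
```

With a real model, `next_token_dist` would wrap a forward pass and a softmax over the last position's logits.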


## Load target model

In [None]:
# Load the GPT2-XL tokenizer and model
tokenizer_target = GPT2Tokenizer.from_pretrained('gpt2-xl')
model_target = GPT2LMHeadModel.from_pretrained('gpt2-xl')

## Use target model for evaluation

We pass all speculated tokens to the target model in a single forward pass to get the likelihoods used to accept or reject them. The same pass also yields one extra token at the end for free.
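For the accept/reject test, the quantity needed at each position is the target probability of the drafted token itself, which is not necessarily the target's most likely token. The sketch below illustrates the indexing in plain Python on made-up logits: the logits at position t score the token at position t + 1, so the distribution for the first draft token sits at index `prompt_len - 1`. The helper names and the toy numbers are assumptions for illustration; in PyTorch the same lookup could be done by indexing the softmaxed logits with the draft token IDs.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def target_probs_for_draft(logits_per_position, prompt_len, draft_ids):
    """Return q(x_i) for each draft token x_i from one forward pass."""
    probs = []
    for i, token_id in enumerate(draft_ids):
        dist = softmax(logits_per_position[prompt_len - 1 + i])
        probs.append(dist[token_id])
    return probs


# Toy example: vocabulary of 3 tokens, prompt of 2 tokens, 2 draft tokens.
toy_logits = [
    [0.0, 0.0, 0.0],  # position 0 scores the token at position 1 (prompt)
    [2.0, 0.0, 0.0],  # position 1 scores the first draft token
    [0.0, 0.0, 3.0],  # position 2 scores the second draft token
    [0.0, 1.0, 0.0],  # position 3 scores the extra token
]
q = target_probs_for_draft(toy_logits, prompt_len=2, draft_ids=[0, 1])
print(q)
```

Here the second draft token (ID 1) gets a low target probability even though the target's top token at that position (ID 2) has a high one, which is exactly the case where the two lookups differ.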

In [None]:
# Move the model and input tensors to the appropriate device (GPU if available, else CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_target.to(device)

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 1600)
    (wpe): Embedding(1024, 1600)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-47): 48 x GPT2Block(
        (ln_1): LayerNorm((1600,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2SdpaAttention(
          (c_attn): Conv1D(nf=4800, nx=1600)
          (c_proj): Conv1D(nf=1600, nx=1600)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((1600,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D(nf=6400, nx=1600)
          (c_proj): Conv1D(nf=1600, nx=6400)
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((1600,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=1600, out_features=50257, bias=False)
)

In [None]:
# Check the model's vocabulary size (it should be 50257 for GPT2)
vocab_size = model_target.config.vocab_size
vocab_size

50257

## Add speculated tokens to the prompt for evaluation

In [None]:
# Define the prompt with speculated tokens
new_prompt = "What is mitosis? Mitosis is the process by which a protein is broken down into its constituent parts."
new_prompt

'What is mitosis? Mitosis is the process by which a protein is broken down into its constituent parts.'

In [None]:
target_probs = []
target_next_token = None

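# Tokenize the prompt and move the tensors to the same device as the model
encoded_prompt = tokenizer_target(new_prompt, return_tensors='pt')
encoded_prompt = {key: value.to(device) for key, value in encoded_prompt.items()}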

all_target_probs = []

# Get model outputs (logits) for the input
with torch.no_grad():  # Disable gradient computation to save memory
    outputs = model_target(**encoded_prompt)
    logits = outputs.logits  # Raw logits (before softmax)


# Walk backwards over the last k+1 positions. The logits at position inx score
# the token that follows it: inx = -1 gives the extra token after the prompt,
# and inx = -2 ... -(k+1) line up with the k speculated tokens.
for inx in range(-1, -1-k-1, -1):
  # Logits for the token that follows position inx
  token_logits = logits[:, inx, :]

  # Apply softmax to get probabilities of each token in the vocabulary
  probabilities = torch.softmax(token_logits, dim=-1)

  # Keep the full target distribution at every position for later resampling
  all_target_probs.append(probabilities)

  # Get the token ID of the target model's most likely token at this position
  predicted_token_id = torch.argmax(probabilities, dim=-1).item()

  # Get the probability (likelihood) of that token
  predicted_token_probability = probabilities[0, predicted_token_id].item()

  # Decode the predicted token ID back to text
  predicted_token = tokenizer_target.decode(predicted_token_id)

  # Print the target's top token at this position and its likelihood
  print(f"Token '{inx}': '{predicted_token}' - Likelihood: {predicted_token_probability:.4f}")
  print("\n----------------------\n")

  # The last position yields the extra token; the others are stored as q.
  # Caveat: this is the probability of the target's top token, which equals
  # q(draft token) only where the target's top token matches the draft token.
  if inx == -1:
    target_next_token = predicted_token
  else:
    target_probs.append(predicted_token_probability)

Token '-1': ' The' - Likelihood: 0.1212

----------------------

Token '-2': '.' - Likelihood: 0.5073

----------------------

Token '-3': ' parts' - Likelihood: 0.5429

----------------------

Token '-4': ' constituent' - Likelihood: 0.2822

----------------------

Token '-5': ' smaller' - Likelihood: 0.7594

----------------------

Token '-6': ' into' - Likelihood: 0.5791

----------------------

Token '-7': ' down' - Likelihood: 0.9197

----------------------



In [None]:
target_probs.reverse()
target_probs

[0.9196940660476685,
 0.5791090130805969,
 0.7593812346458435,
 0.2822127044200897,
 0.5428709983825684,
 0.5072702169418335]

In [None]:
all_target_probs.reverse()

In [None]:
print(f"Speculated tokens: {draft_next_tokens}")
print(f"Speculated tokens probs: {draft_next_token_probs}")

Speculated tokens: [' down', ' into', ' its', ' constituent', ' parts', '.']
Speculated tokens probs: [0.931113064289093, 0.36290252208709717, 0.14342136681079865, 0.7328725457191467, 0.26946306228637695, 0.3097057044506073]


## Speculative sampling

`p: draft likelihood for token x`

`q: target likelihood for token x`

`Case 1: If q(x) >= p(x) then accept token x`

`Case 2: If q(x) < p(x) then accept token x with probability of q(x)/p(x)`

`As soon as we reject a token, break from the loop and pick a replacement from q() (the original algorithm samples it from the normalized max(0, q - p))`
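The rules above can be sketched in plain Python on full distributions. One difference from the simplified rule: after a rejection, this sketch draws the replacement from the normalized residual max(0, q - p), which is what the original speculative sampling algorithm uses so that the output is distributed as q. That guarantee assumes draft tokens are sampled from p, whereas this notebook drafts greedily. The function names and the list-based distributions are assumptions for illustration.

```python
import random


def residual_distribution(q_dist, p_dist):
    """Normalize max(0, q - p): the distribution to resample from on rejection."""
    diff = [max(0.0, q - p) for q, p in zip(q_dist, p_dist)]
    total = sum(diff)
    if total == 0.0:  # p and q are identical, so fall back to q
        return list(q_dist)
    return [d / total for d in diff]


def verify(draft_ids, p_dists, q_dists, rng=random):
    """Accept draft tokens left to right; on the first rejection, resample.

    p_dists[i] / q_dists[i]: draft / target distributions at draft position i.
    q_dists has one extra entry, used for the extra token if all are accepted.
    """
    out = []
    for i, x in enumerate(draft_ids):
        p, q = p_dists[i][x], q_dists[i][x]
        # Case 1: q >= p, always accept. Case 2: accept with probability q / p.
        if q >= p or rng.random() < q / p:
            out.append(x)
            continue
        residual = residual_distribution(q_dists[i], p_dists[i])
        out.append(rng.choices(range(len(residual)), weights=residual)[0])
        return out  # stop at the first rejection
    extra = q_dists[len(draft_ids)]
    out.append(rng.choices(range(len(extra)), weights=extra)[0])
    return out
```

Every call returns at least one token, so each target pass makes progress even when the first draft token is rejected.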


In [None]:
import random

accepted_tokens = []
for inx in range(k):
  token = draft_next_tokens[inx]
  p = draft_next_token_probs[inx]
  q = target_probs[inx]

  print(f"inx: {inx}: p: {p} & q: {q}")
  print(f"Evaluating: {token}")

  if q >= p:
    print(f"accepting!\n")
    accepted_tokens.append(token)
  else:
    prob = q/p
    print(f"sampling with prob: {prob}")
    if random.random() <= prob:
      print(f"accepting!\n")
      accepted_tokens.append(token)
    else:
      # break from the loop and sample next token from q
      print("breaking from loop!!")
      break


inx: 0: p: 0.931113064289093 & q: 0.9196940660476685
Evaluating:  down
sampling with prob: 0.9877361851322073
accepting!

inx: 1: p: 0.36290252208709717 & q: 0.5791090130805969
Evaluating:  into
accepting!

inx: 2: p: 0.14342136681079865 & q: 0.7593812346458435
Evaluating:  its
accepting!

inx: 3: p: 0.7328725457191467 & q: 0.2822127044200897
Evaluating:  constituent
sampling with prob: 0.38507746820173555
breaking from loop!!


In [None]:
inx

3

In [None]:
# Target distribution at the position where the draft token was rejected
probabilities = all_target_probs[inx]

# Get the token ID of the most likely next token
predicted_token_id = torch.argmax(probabilities, dim=-1).item()

# Get the probability (likelihood) of the predicted token
predicted_token_probability = probabilities[0, predicted_token_id].item()

# Decode the predicted token ID back to text
predicted_token = tokenizer_target.decode(predicted_token_id)

# Print the predicted next token and its likelihood (probability)
print(f"Token '{predicted_token}': '{predicted_token}' - Likelihood: {predicted_token_probability:.4f}")

Token ' constituent': ' constituent' - Likelihood: 0.2822
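To close the step, the accepted tokens and the replacement token are appended to the prompt. If all k draft tokens had been accepted, the extra token from the target model would be appended instead. A minimal sketch using the values from this run; `assemble_step` is a hypothetical helper name:

```python
def assemble_step(prompt, accepted_tokens, k, replacement_token, bonus_token):
    """Build the text after one speculative step.

    If every draft token was accepted, the target's extra token is appended;
    otherwise the replacement for the first rejected token is appended.
    """
    if len(accepted_tokens) == k:
        new_tokens = accepted_tokens + [bonus_token]
    else:
        new_tokens = accepted_tokens + [replacement_token]
    return prompt + "".join(new_tokens)


prompt = "What is mitosis? Mitosis is the process by which a protein is broken"
text = assemble_step(
    prompt,
    accepted_tokens=[" down", " into", " its"],
    k=6,
    replacement_token=" constituent",
    bonus_token=" The",
)
print(text)
```

The resulting text would then serve as the prompt for the next round of drafting.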
