<a href="https://colab.research.google.com/github/kjahan/speculative_decoding/blob/main/notebooks/speculative_sampling_opt.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Speculative Sampling

The goal here is to reduce inference latency for generative language models. Lower latency matters for interactive uses such as edit suggestions while writing or coding.

We use a small draft model (here `facebook/opt-125m`) and a larger target model (here `facebook/opt-350m`). We prompt the draft model and generate k speculative tokens along with their probabilities.

Next we feed those k tokens, appended to the original prompt, to the target model. Because of causal attention masking, a single forward pass returns the target model's likelihoods for all k positions at once. We then compare the draft and target probabilities of each speculative token to accept or reject it.

See this video for more explanations:

https://www.youtube.com/watch?v=S-8yr_RibJ4

The key insight is that many tokens, such as "of", are easy enough for even a small model to predict. We can let the smaller model generate those quickly and rely on the larger model for facts and harder tokens!

https://pytorch.org/blog/hitchhikers-guide-speculative-decoding/

https://www.youtube.com/watch?v=9wNAgpX6z_4

https://docs.google.com/presentation/d/1p1xE-EbSAnXpTSiSI0gmy_wdwxN5XaULO3AnCWWoRe4/edit#slide=id.p
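
To get a rough sense of the payoff, the paper linked later in this notebook (arXiv 2211.17192) gives the expected number of tokens produced per target-model pass as (1 - α^(k+1)) / (1 - α), under the simplifying assumption that each draft token is accepted independently with the same rate α. The sketch below just evaluates that formula; the function name and the `alpha` values are illustrative and not part of this notebook's code.

```python
def expected_tokens_per_round(alpha: float, k: int) -> float:
    """Expected tokens emitted per target forward pass.

    Assumes every draft token is accepted independently with probability alpha
    (a simplification; real acceptance rates vary by position and prompt).
    """
    if not 0.0 <= alpha <= 1.0:
        raise ValueError("alpha must be in [0, 1]")
    if alpha == 1.0:
        # Every draft is accepted, plus the one bonus token from the target.
        return float(k + 1)
    return (1.0 - alpha ** (k + 1)) / (1.0 - alpha)


# Illustrative acceptance rates, not measured values.
for alpha in (0.5, 0.8, 0.9):
    print(alpha, round(expected_tokens_per_round(alpha, k=5), 2))
```

With `k=5`, a higher acceptance rate moves the expected count from roughly 2 tokens per pass toward the maximum of 6, which is why the approach works best when the draft model agrees with the target on most easy tokens.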

### VLLM speculative decoding

https://docs.vllm.ai/en/stable/features/reasoning_outputs.html

# Step 1: Load Draft LLM (opt-125m)

In [None]:
import time
import random

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import torch.nn.functional as F

In [None]:
draft_model_name = "facebook/opt-125m"  # Small model

# Load models and tokenizers
draft_model = AutoModelForCausalLM.from_pretrained(draft_model_name)
draft_tokenizer = AutoTokenizer.from_pretrained(draft_model_name)

# Move model and input to the same device (GPU if available, else CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
draft_model.to(device)

In [None]:
print(device)

# Step 2: Prompt the draft model to generate speculated tokens

`k=5`

In [None]:
def run_draf_model(prompt, k=5):
  # we speculate next k tokens and store them along with their probs from draft model
  draft_next_token_ids = []
  draft_tokens_probs = []

  for pos in range(k):
    # Tokenize the prompt (turn it into token IDs)
    encoded_prompt = draft_tokenizer(prompt, return_tensors='pt')

    encoded_prompt = {key: value.to(device) for key, value in encoded_prompt.items()}

    # Run the model and get the logits for the next token
    with torch.no_grad():  # Disable gradient calculation during inference
        outputs = draft_model(**encoded_prompt)
        logits = outputs.logits

    # Extract the logits for the next token (logits for the token after the input prompt)
    next_token_logits = logits[0, -1, :]  # Logits for the next token (after the prompt)

    # Apply softmax to convert logits into probabilities
    probabilities = F.softmax(next_token_logits, dim=-1)

    # Get the most likely tokens and their probabilities (greedy drafting uses only the first one)
    top_k = 5
    top_token_probs, top_token_ids = torch.topk(probabilities, k=top_k)
    top_token = draft_tokenizer.decode([top_token_ids[0].item()])
    print(f"Next draft token: {top_token} --> likelihood: {top_token_probs[0].item():.4f}")
    # Append the predicted token to the prompt so the next iteration predicts the following position
    prompt = prompt + top_token

    draft_next_token_ids.append(top_token_ids[0].item())
    draft_tokens_probs.append(probabilities.cpu().numpy())

  results = {'draft_token_ids': draft_next_token_ids, 'probs': draft_tokens_probs}

  return results
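
The function above appends decoded text to the prompt and re-tokenizes it on every step. A possible alternative is to extend the token ID sequence directly, which avoids any decode and re-encode round trip. The sketch below shows that loop shape in plain Python; `greedy_draft`, `next_token_probs`, and `toy_probs` are hypothetical stand-ins for illustration (a real version would call the draft model instead of the toy function).

```python
from typing import Callable, List, Sequence, Tuple


def greedy_draft(
    prompt_ids: Sequence[int],
    next_token_probs: Callable[[Sequence[int]], List[float]],
    k: int = 5,
) -> Tuple[List[int], List[List[float]]]:
    """Greedily propose k token IDs and keep each step's full distribution."""
    ids = list(prompt_ids)
    draft_ids: List[int] = []
    draft_probs: List[List[float]] = []
    for _ in range(k):
        probs = next_token_probs(ids)
        best = max(range(len(probs)), key=probs.__getitem__)  # argmax over the vocabulary
        draft_ids.append(best)
        draft_probs.append(probs)
        ids.append(best)  # extend the ID sequence; no decode/re-encode step
    return draft_ids, draft_probs


# Hypothetical toy "model" over a 4-token vocabulary: it favors (last_id + 1) % 4.
def toy_probs(ids: Sequence[int]) -> List[float]:
    probs = [0.1, 0.1, 0.1, 0.1]
    probs[(ids[-1] + 1) % 4] = 0.7
    return probs


draft_ids, draft_probs = greedy_draft([0], toy_probs, k=5)
print(draft_ids)
```

The return values mirror `draft_token_ids` and `probs` from `run_draf_model`: one chosen ID and one full distribution per drafted position.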

## Generate draft proposal

In [None]:
# Define the input prompt
# prompt = "What is mitosis? Mitosis is the process by which a protein is broken"
# prompt = "Mitosis is the process by"
prompt = "Paris is the capital of"
k=5

results = run_draf_model(prompt, k)

draft_next_token_ids, draft_tokens_probs = results['draft_token_ids'], results['probs']

In [None]:
print(draft_next_token_ids)
#draft_tokens_probs

# Step 3: Load Target LLM (opt-350m)

In [None]:
# Load OPT tokenizer and model
target_model_name = "facebook/opt-350m"  # Larger model

target_model = AutoModelForCausalLM.from_pretrained(target_model_name)
target_tokenizer = AutoTokenizer.from_pretrained(target_model_name)

# Move models to GPU if available
device = "cuda" if torch.cuda.is_available() else "cpu"
target_model.to(device)

# Step 4: Run Target model for evaluation

We pass all speculated tokens to the target model to get the token likelihoods used for accepting or rejecting them. The same forward pass also scores the position after the last draft, so one extra token can be generated as a bonus at the end!

In [None]:
vocab_size_ = draft_model.config.vocab_size
vocab_size_

In [None]:
# Check the target model's vocabulary size (it should match the draft model's so token IDs line up)
vocab_size = target_model.config.vocab_size
vocab_size

In [None]:
device

## Add speculated token to prompt for evaluation!

In [None]:
def get_target_all_token_probabilities(prompt):
    """
    Given a prompt, this function returns the target model's next-token probability
    distribution at each token position of the prompt, over the whole vocabulary.
    The distribution at position i is the model's prediction for the token at position i + 1.

    Args:
    - prompt (str): The input prompt for the target model.

    Returns:
    - List of dictionaries: Each dictionary contains the token position ('position') and
    the likelihoods of all vocabulary tokens ('probs').
    """
    # Tokenize the prompt
    inputs = target_tokenizer(prompt, return_tensors="pt").to(device)

    # Get the model outputs (logits)
    with torch.no_grad():
        logits = target_model(**inputs).logits

    # Extract the logits for the tokens in the prompt
    logits = logits.squeeze(0)  # Remove the batch dimension

    # Get the token ids for the input prompt
    input_ids = inputs["input_ids"].squeeze(0)  # (sequence_length,)

    # Calculate the probabilities of the tokens using softmax
    probabilities = F.softmax(logits, dim=-1)

    # Store the probabilities for each position
    position_probabilities = []

    for i in range(len(input_ids)):
        position_probs = probabilities[i].cpu().numpy()  # Get the probability distribution for this token position
        position_probabilities.append({
            'position': i,
            'probs': position_probs
        })

    return position_probabilities
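
When these per-position distributions are used to score draft tokens, the index alignment is easy to get wrong: the distribution at index i predicts the token at index i + 1. The sketch below computes which indices score which draft tokens, assuming the draft tokens are the last entries of the tokenized sequence. The function name and the example lengths are hypothetical.

```python
from typing import List


def draft_scoring_indices(seq_len: int, num_draft: int) -> List[int]:
    """Indices of the target distributions that score each draft token.

    Assumes the last `num_draft` tokens of the sequence are the draft tokens
    and that the distribution at index i predicts the token at index i + 1.
    """
    if not 0 < num_draft < seq_len:
        raise ValueError("need at least one context token before the drafts")
    first_draft = seq_len - num_draft  # index of the first draft token
    return [first_draft + pos - 1 for pos in range(num_draft)]


# Hypothetical example: 1 BOS token + 5 prompt tokens + 5 draft tokens = 11 positions.
print(draft_scoring_indices(seq_len=11, num_draft=5))
```

In this example the last index (10) is not used for scoring any draft token; it holds the distribution for the token that would follow the final draft.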

# Sampling

## Fast Inference from Transformers via Speculative Decoding

We are implementing Algorithm 1 described in the following paper:

https://arxiv.org/pdf/2211.17192

`p: Target model likelihood for token x`

`q: Draft likelihood for token x`

`Case 1: If p(x) >= q(x) then accept token x`

`Case 2: If p(x) < q(x) then accept token x with probability p(x)/q(x) (a biased coin flip)`

`As soon as we reject, break from the loop; since not all k speculated tokens were accepted, sample one more token for that position from norm(max(0, p(x)-q(x)))`
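
The two cases above collapse into a single rule: accept with probability min(1, p(x)/q(x)). A minimal sketch of the accept loop follows, using plain floats instead of model outputs. `count_accepted` and its arguments are hypothetical names, and the sketch assumes every q value is positive whenever p < q.

```python
import random
from typing import Optional, Sequence


def count_accepted(
    p_vals: Sequence[float],
    q_vals: Sequence[float],
    rng: Optional[random.Random] = None,
) -> int:
    """Return how many leading draft tokens pass the accept/reject test.

    p_vals[i] and q_vals[i] are the target and draft likelihoods of draft token i.
    """
    rng = rng or random.Random()
    accepted = 0
    for p, q in zip(p_vals, q_vals):
        # Case 1 (p >= q) always accepts; Case 2 accepts with probability p / q.
        if p >= q or rng.random() < p / q:
            accepted += 1
        else:
            break  # the first rejection ends the round
    return accepted


# All tokens fall under Case 1 here, so both are accepted regardless of the RNG.
print(count_accepted([0.5, 0.6], [0.4, 0.6]))
```

Passing a seeded `random.Random` makes a run reproducible, which can help when debugging the acceptance behavior.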


## Helper function

In [None]:
import numpy as np

def adjust_and_sample(p, q):
    """
    Builds the residual distribution norm(max(0, p - q)) and samples one token from it.

    Args:
    - p (numpy array): Target model distribution p(x).
    - q (numpy array): Draft model distribution q(x).

    Returns:
    - sampled_token (int): The index of the sampled token from the adjusted distribution.
    """
    # Step 1: Compute max(0, p(x) - q(x)) in float64 for a stable normalization
    p = np.asarray(p, dtype=np.float64)
    q = np.asarray(q, dtype=np.float64)
    adjusted_distribution = np.maximum(0, p - q)

    # Step 2: Normalize the adjusted distribution so it sums to 1.
    # If p and q are identical the residual is all zeros; as a guard against
    # dividing by zero we fall back to p itself (an assumption of this notebook).
    total = np.sum(adjusted_distribution)
    if total > 0:
        adjusted_distribution /= total
    else:
        adjusted_distribution = p / np.sum(p)

    # Step 3: Sample from the adjusted distribution
    sampled_token = int(np.random.choice(len(p), p=adjusted_distribution))

    return sampled_token

In [None]:
# Example usage:
# p(x) and q(x) are example probability distributions
p = np.array([0.1, 0.2, 0.3, 0.4])  # Original distribution
q = np.array([0.5, 0.1, 0.2, 0.2])  # Comparison distribution

sampled_token = adjust_and_sample(p, q)
print(f"Sampled token index: {sampled_token}")
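
For the example distributions above, the adjusted distribution can be worked out by hand, and it is also possible to check why this resampling rule is used: combining "accept with probability min(1, p/q)" with "otherwise resample from norm(max(0, p - q))" should reproduce the target distribution p. The sketch below does that arithmetic with plain Python lists rather than NumPy; the variable names are illustrative.

```python
p = [0.1, 0.2, 0.3, 0.4]  # target distribution
q = [0.5, 0.1, 0.2, 0.2]  # draft distribution

# Probability that the draft proposes x AND x is accepted:
# q(x) * min(1, p(x) / q(x)) = min(p(x), q(x))
accepted_mass = [min(pi, qi) for pi, qi in zip(p, q)]
reject_prob = 1.0 - sum(accepted_mass)

# Residual distribution used after a rejection: norm(max(0, p - q))
residual = [max(0.0, pi - qi) for pi, qi in zip(p, q)]
total = sum(residual)
residual = [r / total for r in residual]

# Overall distribution of the token emitted by one accept-or-resample step
output = [a + reject_prob * r for a, r in zip(accepted_mass, residual)]

print([round(r, 4) for r in residual])
print([round(o, 4) for o in output])
```

The residual puts no mass on token 0, where the draft was overconfident, and the combined output matches p up to floating-point error.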

## Run evaluation

In [None]:
def get_draft_tokens(draft_next_token_ids):
  draft_next_tokens = []
  for token_id in draft_next_token_ids:
    draft_next_tokens.append(draft_tokenizer.decode([token_id]))
  return draft_next_tokens

In [None]:
def run_speculative_decoding(prompt, draft_next_token_ids, draft_tokens_probs):

  new_prompt = prompt + ''.join(get_draft_tokens(draft_next_token_ids))
  print(new_prompt+"\n")

  position_probabilities = get_target_all_token_probabilities(new_prompt)

  # Tokenize the prompt
  inputs = target_tokenizer(new_prompt, return_tensors="pt").to(device)

  # Get the token ids for the input prompt
  input_ids = inputs["input_ids"].squeeze(0)  # (sequence_length,)

  accepted_tokens = []
  num_draft = len(draft_next_token_ids)  # use the actual draft length instead of the global k
  for pos, token_id in enumerate(draft_next_token_ids):
      # Draft model likelihood q(x) of the speculated token
      q = draft_tokens_probs[pos][token_id].item()
      # Target model likelihood p(x): the distribution at index i predicts the
      # token at index i + 1, so draft token `pos` is scored one step before it
      p_inx = len(input_ids) - num_draft + pos - 1
      #print(f"p_inx: {p_inx}")
      p = position_probabilities[p_inx]['probs'][token_id].item()
      token = target_tokenizer.decode([token_id])
      print(f"Evaluating: {token}")
      print(f"{token_id}: {token} --> p: {round(p, 4)} & q: {round(q, 4)}")
      if p >= q:
        print(f"accepting ...\n")
        accepted_tokens.append(token)
      else:
        prob = p/q
        print(f"sampling with prob: {prob}")
        if random.random() < prob:
          print(f"accepting ...\n")
          accepted_tokens.append(token)
        else:
          # Reject: stop here and resample this position from norm(max(0, p - q))
          print(f"\nRejecting - pos: {pos}!")
          ps = position_probabilities[p_inx]['probs']
          qs = draft_tokens_probs[pos]
          sampled_token_id = int(adjust_and_sample(ps, qs))
          token = target_tokenizer.decode([sampled_token_id])
          accepted_tokens.append(token)
          print(f"Last sampled token: {token}")
          break
  else:
      # No rejection: sample the bonus token from the target's distribution after the last draft
      ps = position_probabilities[-1]['probs'].astype(np.float64)
      bonus_id = int(np.random.choice(len(ps), p=ps / ps.sum()))
      accepted_tokens.append(target_tokenizer.decode([bonus_id]))
  return accepted_tokens

## Run speculative decoding technique

In [None]:
prompt

In [None]:
accepted_tokens = run_speculative_decoding(prompt, draft_next_token_ids, draft_tokens_probs)

print(f"accepted tokens: {accepted_tokens}")