## Task Description

In this exercise you will implement the `Speculative Decoding` algorithm for fast inference.
The algorithm was proposed as Algorithm 1 in the paper [Fast Inference from Transformers via Speculative Decoding](https://arxiv.org/abs/2211.17192). Read the paper and make sure you understand the algorithm before you start the implementation.
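
The property that makes the algorithm useful is that a drafted token, after the accept/reject step, is distributed exactly as if it had been sampled from the target model. The following is a minimal sketch that checks this on a toy vocabulary of three tokens, assuming the rule "accept a draft `x ~ q` with probability `min(1, p(x)/q(x))`, otherwise resample from the normalized `max(0, p - q)`". The function name and the values of `p` and `q` are hypothetical and chosen only for illustration.

```python
def speculative_output_distribution(p, q):
    """Exact distribution of the token emitted by one accept/reject step."""
    # A draft x ~ q is accepted with probability min(1, p(x)/q(x)), so the
    # joint probability of "x drafted and accepted" is min(p(x), q(x)).
    accepted = [min(pi, qi) for pi, qi in zip(p, q)]
    reject_mass = 1.0 - sum(accepted)

    # After a rejection, the replacement comes from the normalized max(0, p - q).
    residual = [max(0.0, pi - qi) for pi, qi in zip(p, q)]
    total = sum(residual)
    if total > 0:
        residual = [r / total for r in residual]

    return [a + reject_mass * r for a, r in zip(accepted, residual)]


p = [0.5, 0.3, 0.2]  # toy target distribution (hypothetical values)
q = [0.2, 0.6, 0.2]  # toy draft distribution (hypothetical values)

out = speculative_output_distribution(p, q)
print(out)  # matches p up to floating-point error
```

The check only holds when `q` is the distribution the draft token was actually sampled from, which matters once the draft model uses truncated sampling such as top-k.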

## Code

In [1]:
import torch
import torch.nn.functional as F

### 1. Implementation of the algorithm
- Implement the `Speculative Decoding` algorithm as described in the paper.
- Use top-k sampling inside the algorithm to sample from the draft model.
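
With top-k sampling, the draft token is drawn from a truncated and renormalized version of the softmax output, not from the full softmax. It can therefore help to build that truncated distribution as an explicit tensor, so the same `q` is available later for the acceptance test. A minimal sketch, assuming a 1-D probability vector; the helper name `top_k_distribution` and the toy values are hypothetical:

```python
import torch


def top_k_distribution(probs: torch.Tensor, top_k: int) -> torch.Tensor:
    """Zero out all but the top_k entries of a 1-D distribution and renormalize."""
    k = min(top_k, probs.shape[-1])
    topk_probs, topk_indices = torch.topk(probs, k, dim=-1)
    filtered = torch.zeros_like(probs)
    filtered[topk_indices] = topk_probs / topk_probs.sum()
    return filtered


probs = torch.tensor([0.1, 0.4, 0.2, 0.3])  # toy softmax output (hypothetical)
q = top_k_distribution(probs, top_k=2)

# Only indices 1 and 3 keep probability mass, so only they can be sampled.
token = torch.multinomial(q, 1)
print(q, token)
```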

In [6]:
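def sample(p, top_k=50, sample_rng=None):
    """Draw one token id from the top_k most probable entries of the 1-D distribution p."""
    assert p.dim() == 1

    # Keep the k largest probabilities. torch.multinomial treats its input as
    # unnormalized weights, so no explicit renormalization is needed here.
    top_k = min(top_k, p.shape[-1])
    topk_probs, topk_indices = torch.topk(p, top_k, dim=-1)
    ix = torch.multinomial(topk_probs, 1, generator=sample_rng)
    # Map the sampled position within the top-k back to a vocabulary id.
    xcol = torch.gather(topk_indices, -1, ix)
    return xcol.view(1, 1)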

def speculative_decoding_step(Mp, Mq, prefix_ids, gamma, tokenizer, device, top_k=50, sample_rng=None):
    assert prefix_ids.shape[0] == 1, "We assume batch size 1"
    seq_len = prefix_ids.shape[1]

    # --- Step 1: Draft gamma tokens using Mq ---
    draft_tokens = []
    qs = []
    current_ids = prefix_ids

    for _ in range(gamma):
        with torch.autocast(device_type=device.type, dtype=torch.bfloat16):
            logits_q = Mq(current_ids).logits[0, -1, :]
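        probs_q = F.softmax(logits_q, dim=-1)
        # Truncate q to its top-k tokens and renormalize, so that the stored q is
        # the distribution the draft token is actually sampled from. The acceptance
        # test min(1, p/q) and the residual max(0, p - q) both assume this.
        k = min(top_k, probs_q.shape[-1])
        topk_probs, topk_indices = torch.topk(probs_q, k, dim=-1)
        probs_q = torch.zeros_like(probs_q).scatter_(-1, topk_indices, topk_probs / topk_probs.sum())
        next_token = sample(probs_q, top_k, sample_rng)

        draft_tokens.append(next_token)
        qs.append(probs_q)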

        current_ids = torch.cat((current_ids, next_token), dim=-1)

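    # --- Step 2: Verify draft tokens with Mp in parallel ---
    # Logits at position t score the token at position t + 1, so slicing from
    # seq_len - 1 onward gives gamma + 1 rows: rows 0..gamma-1 score the draft
    # tokens, and the last row is the distribution for the token after them.
    with torch.autocast(device_type=device.type, dtype=torch.bfloat16):
        logits_p = Mp(current_ids).logits[0, seq_len-1:, :]
    probs_p = F.softmax(logits_p, dim=-1)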

    # --- Step 3: Rejection Sampling ---
    accepted_tokens = []
    first_rejection_index = gamma # Assume all accepted initially

    for i in range(gamma):
        # Get the i-th draft token and the probabilities used/predicted for it
        x = draft_tokens[i]
        prob_q_x = qs[i][x]
        prob_p_x = probs_p[i, x]

        # Acceptance probability min(1, p/q)
        accept_prob = torch.min(torch.ones_like(prob_p_x), prob_p_x / prob_q_x)

        r = torch.rand(1, device=device, generator=sample_rng).item()

        if r < accept_prob.item():
            # Accept the token
            accepted_tokens.append(draft_tokens[i])
        else:
            # Reject the token and break
            first_rejection_index = i
            break

    # --- Step 4: Correction Sampling ---
    final_tokens = accepted_tokens

    if first_rejection_index == gamma:
        # All draft tokens were accepted
        last_logits_p = logits_p[-1, :]
        last_probs_p = F.softmax(last_logits_p, dim=-1)
        next_token = sample(last_probs_p, top_k, sample_rng)
        final_tokens.append(next_token)
    else:
        # A token at index `first_rejection_index` was rejected.
        p_k = probs_p[first_rejection_index, :] # p dist at rejected position k
        q_k = qs[first_rejection_index] # q dist at rejected position k

        # Calculate the corrected distribution: max(0, p - q)
        corrected_probs = torch.max(torch.zeros_like(p_k), p_k - q_k)

        # Normalize the corrected distribution
        norm_factor = corrected_probs.sum(dim=-1, keepdim=True)
        replacement_probs = corrected_probs / norm_factor

        # Sample the replacement token
        replacement_token = sample(replacement_probs, top_k, sample_rng)
        final_tokens.append(replacement_token)

    # Concatenate accepted/corrected tokens to the original prefix
    new_suffix = torch.cat(final_tokens, dim=-1)
    updated_ids = torch.cat([prefix_ids, new_suffix], dim=-1)

    return updated_ids

def inference_with_speculative_decoding(Mp, Mq, prompt, tokenizer, device, target_length):
    sample_rng = torch.Generator(device=device).manual_seed(42)

    tokens = tokenizer.encode(prompt)
    tokens = torch.tensor(tokens, dtype=torch.long)
    tokens = tokens.unsqueeze(0)
    tokens = tokens.to(device)

    with torch.no_grad():
        step = 0
        while tokens.shape[1] < target_length:
            print(f"Len tokens: {tokens.shape[1]}, target: {target_length}")
            new_tokens = speculative_decoding_step(Mp, Mq, tokens, 7, tokenizer, device, sample_rng=sample_rng)

            num_generated = new_tokens.shape[1] - tokens.shape[1]
            tokens = new_tokens
            print(f"Step {step}: Seq Len: {tokens.shape[1]}, Generated: {num_generated}")
            step += 1

    return_tokens = tokens[0, :].tolist()
    return tokenizer.decode(return_tokens)
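
The verification step depends on how logit positions line up with token positions: in a causal language model, the logits at position `t` score the token at position `t + 1`. The sketch below illustrates the resulting slices with a random tensor standing in for `Mp(current_ids).logits`; the toy sizes are hypothetical and no model is loaded.

```python
import torch

seq_len, gamma, vocab_size = 3, 4, 10  # toy sizes (hypothetical)

# Stand-in for Mp(current_ids).logits, where current_ids is the prefix
# followed by the gamma draft tokens.
logits = torch.randn(1, seq_len + gamma, vocab_size)

# Position t scores the token at position t + 1.
verify_logits = logits[0, seq_len - 1:-1, :]  # gamma rows, one per draft token
bonus_logits = logits[0, -1, :]               # scores the token after the last draft

print(verify_logits.shape, bonus_logits.shape)
```

A single forward pass of the target model thus yields `gamma + 1` distributions: `gamma` for checking the drafts and one more for the extra token that is sampled when every draft is accepted.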

In [7]:
from transformers import AutoTokenizer, AutoModelForCausalLM

draft_model_name =  "gpt2"
target_model_name = "gpt2-medium"  # a larger GPT-2 checkpoint that shares this tokenizer (e.g. "gpt2-large") also works if memory allows

tokenizer = AutoTokenizer.from_pretrained(draft_model_name)
draft_model = AutoModelForCausalLM.from_pretrained(draft_model_name)
target_model = AutoModelForCausalLM.from_pretrained(target_model_name)

device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
draft_model = draft_model.to(device).eval()
target_model = target_model.to(device).eval()

In [8]:
print(f"Draft model parameters: {sum(p.numel() for p in draft_model.parameters())}")
print(f"Target model parameters: {sum(p.numel() for p in target_model.parameters())}")

Draft model parameters: 124439808
Target model parameters: 354823168


### 2. Integration of the algorithm into the inference pipeline
Build an inference pipeline that uses the `Speculative Decoding` algorithm. Choose a prefix to feed into the model as the initial input, and generate a text of at least 100 tokens.

In [10]:
prompt = "A duck is"
target_len = 100

inference_with_speculative_decoding(target_model, draft_model, prompt, tokenizer=tokenizer, device=device, target_length=target_len)

Len tokens: 3, target: 100
Step 0: Seq Len: 11, Generated: 8
Len tokens: 11, target: 100
Step 1: Seq Len: 12, Generated: 1
Len tokens: 12, target: 100
Step 2: Seq Len: 14, Generated: 2
Len tokens: 14, target: 100
Step 3: Seq Len: 22, Generated: 8
Len tokens: 22, target: 100
Step 4: Seq Len: 23, Generated: 1
Len tokens: 23, target: 100
Step 5: Seq Len: 25, Generated: 2
Len tokens: 25, target: 100
Step 6: Seq Len: 26, Generated: 1
Len tokens: 26, target: 100
Step 7: Seq Len: 30, Generated: 4
Len tokens: 30, target: 100
Step 8: Seq Len: 38, Generated: 8
Len tokens: 38, target: 100
Step 9: Seq Len: 41, Generated: 3
Len tokens: 41, target: 100
Step 10: Seq Len: 45, Generated: 4
Len tokens: 45, target: 100
Step 11: Seq Len: 46, Generated: 1
Len tokens: 46, target: 100
Step 12: Seq Len: 53, Generated: 7
Len tokens: 53, target: 100
Step 13: Seq Len: 56, Generated: 3
Len tokens: 56, target: 100
Step 14: Seq Len: 60, Generated: 4
Len tokens: 60, target: 100
Step 15: Seq Len: 64, Generated: 4
Len

'A duck is a duck in the water. If One wants to learn to swim, One must understand it the way a duck understands its size and the way a turtle understands its motion.. And in the case of the individual who is on the water, he must not be deceived, that he think that there is anything further from "nature" than the human individual of which we are speaking. If he thinks such a question too easily, there will be no explanation for the individual\'s position and his actions'
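
The `Generated` counts in the log above give a rough measure of how many tokens each target-model call produces. The sketch below summarizes the visible part of the log (steps 0 to 15) with `gamma = 7`, and also includes the paper's expression for the expected number of tokens per step, `(1 - alpha^(gamma + 1)) / (1 - alpha)`, where `alpha` is the expected acceptance rate. The variable and function names are hypothetical, and a sample of 16 steps is too small to draw firm conclusions.

```python
# Per-step "Generated" counts copied from the log, steps 0 to 15.
generated = [8, 1, 2, 8, 1, 2, 1, 4, 8, 3, 4, 1, 7, 3, 4, 4]
gamma = 7  # number of draft tokens per step, as passed in the pipeline above

mean_tokens = sum(generated) / len(generated)
# A step emits gamma + 1 tokens only when every draft token was accepted.
full_accepts = sum(1 for g in generated if g == gamma + 1)


def expected_tokens(alpha: float, gamma: int) -> float:
    """Expected tokens per step for acceptance rate alpha (0 <= alpha < 1)."""
    return (1 - alpha ** (gamma + 1)) / (1 - alpha)


print(f"mean tokens per step: {mean_tokens:.2f}")  # 3.81
print(f"steps with all drafts accepted: {full_accepts} of {len(generated)}")  # 3 of 16
```

Comparing `mean_tokens` with `expected_tokens(alpha, gamma)` for a few values of `alpha` gives a feel for how well the draft model approximates the target on this prompt.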