### Speculative Decoding — Summary

Goal: Speed up autoregressive generation by using a fast draft model to propose multiple next tokens and a large target model to validate them in one batched pass.

#### Key idea:
1. The draft model quickly proposes a block of tokens (e.g., 4–16).
2. The target model runs one forward pass over the context plus the proposed block (teacher-forced) to compute logits at every position.
3. Under greedy decoding, a draft token is accepted if it matches the target model's argmax at that position. At the first mismatch, discard the mismatched token and everything after it, and emit the target's own argmax for that position instead.
4. Repeat: draft more → validate once → accept many.
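The acceptance rule in step 3 can be sketched in a few lines. This is a minimal illustration for the greedy case only; the function name `accept_greedy` and the convention that the target supplies one extra prediction for the position right after the block are assumptions made for the example, not part of any library.

```python
from typing import List, Tuple


def accept_greedy(proposed: List[int], target_argmax: List[int]) -> Tuple[List[int], int]:
    """Return (tokens to emit, number of draft tokens accepted).

    target_argmax must hold len(proposed) + 1 ids: entry t is the target's greedy
    choice at block position t, and the last entry is its choice for the position
    right after the block.
    """
    if len(target_argmax) != len(proposed) + 1:
        raise ValueError("target_argmax needs one more entry than proposed")
    n_accepted = 0
    for draft_tok, target_tok in zip(proposed, target_argmax):
        if draft_tok != target_tok:
            break
        n_accepted += 1
    # Accepted prefix, then one token chosen by the target itself:
    # the correction on a mismatch, or a bonus token if everything matched.
    emitted = proposed[:n_accepted] + [target_argmax[n_accepted]]
    return emitted, n_accepted


print(accept_greedy([5, 9, 2, 7], [5, 9, 4, 1, 3]))  # ([5, 9, 4], 2)
print(accept_greedy([5, 9], [5, 9, 8]))              # ([5, 9, 8], 2)
```

Every round emits at least one token, so generation always makes progress even when the draft is wrong at the first position.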

#### Why it’s faster:
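- Vanilla decoding: one target-model forward pass per generated token.
- Speculative decoding: one target-model forward pass per draft block. If several tokens are accepted per block (often 3–8), the number of target passes drops by roughly 2–5×, depending on the agreement rate and block size. The draft model's own passes add some cost, so the wall-clock speedup is smaller than the reduction in target passes.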
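To make the "fewer target passes" claim concrete, here is a back-of-the-envelope sketch. It assumes a simplified model that is not stated in the text above: each draft token is accepted independently with probability `alpha`, and every validation pass also yields one token from the target itself (the correction after a mismatch, or a bonus token after a fully accepted block). The function name is hypothetical.

```python
def expected_tokens_per_target_pass(alpha: float, block: int) -> float:
    """Expected tokens emitted per target pass under an i.i.d. acceptance model.

    alpha: probability that any single draft token is accepted (assumed independent)
    block: number of draft tokens proposed per round
    The result is 1 + alpha + alpha^2 + ... + alpha^block.
    """
    if not 0.0 <= alpha <= 1.0:
        raise ValueError("alpha must be in [0, 1]")
    if block < 0:
        raise ValueError("block must be non-negative")
    if alpha == 1.0:
        return float(block + 1)
    # Closed form of the geometric sum above.
    return (1.0 - alpha ** (block + 1)) / (1.0 - alpha)


for alpha in (0.5, 0.7, 0.9):
    print(alpha, round(expected_tokens_per_target_pass(alpha, block=8), 2))
```

Vanilla decoding corresponds to exactly 1 token per target pass, so the returned value is also the factor by which target passes are reduced under these assumptions.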

#### Correctness:
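- The target model still decides what gets accepted. With greedy decoding there is no quality loss: every accepted token is the one the target model would have produced itself. With sampling, the argmax comparison no longer preserves the target's output distribution, and a probabilistic accept/reject rule is needed instead.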

#### Where it helps:
- General chat/code models, RAG, long outputs.
- Larger speedups when:
  - the draft is much faster than the target,
  - the agreement rate is high,
  - the block size is tuned to the hardware.

Typical speedups: ~2×–3× common, up to 4×–6× in favorable settings.
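The three factors above (draft cost, agreement rate, block size) can be combined into a rough estimate. The sketch below is a toy cost model, not a measurement: it assumes independent per-token acceptance with probability `alpha`, assumes a validation pass over the block costs about the same as one ordinary target step, and uses a made-up `cost_ratio` for the draft model. All names are hypothetical.

```python
def estimated_speedup(alpha: float, block: int, cost_ratio: float) -> float:
    """Rough wall-clock speedup over vanilla decoding.

    alpha:      per-token acceptance probability (assumed independent)
    block:      draft tokens proposed per round
    cost_ratio: time of one draft step divided by time of one target pass
    """
    # Tokens emitted per round: accepted prefix plus one token from the target.
    tokens_per_round = sum(alpha ** j for j in range(block + 1))
    # Time per round in units of one target pass: `block` draft steps + 1 validation.
    time_per_round = block * cost_ratio + 1.0
    return tokens_per_round / time_per_round


for block in (2, 4, 8, 16):
    print(block, round(estimated_speedup(alpha=0.8, block=block, cost_ratio=0.05), 2))
```

Under these assumed numbers the estimate rises and then falls as the block grows: past some point, extra draft steps are mostly spent on tokens that get rejected, which is why the block size is worth tuning.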

#### KV caching (high-level):
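- Keep/extend the target KV cache only for accepted tokens.
- The draft model maintains its own KV cache while proposing; after a rejection, roll it back to the accepted prefix as well.
- Validation uses teacher forcing; KV entries computed for the proposed block are committed to the target cache only once the corresponding tokens are accepted.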
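The commit-or-roll-back behaviour can be illustrated with a toy cache. `ToyKVCache` is a hypothetical stand-in that stores token ids instead of key/value tensors; real frameworks expose their own cache types, and this sketch only shows the bookkeeping.

```python
from typing import List


class ToyKVCache:
    """Hypothetical stand-in for a KV cache: one entry per cached token position."""

    def __init__(self) -> None:
        self.entries: List[int] = []

    def __len__(self) -> int:
        return len(self.entries)

    def extend(self, token_ids: List[int]) -> None:
        # A real cache would store key/value tensors; token ids keep the toy readable.
        self.entries.extend(token_ids)

    def truncate(self, length: int) -> None:
        # Drop every entry past `length` (the rollback after a rejection).
        del self.entries[length:]


target_cache = ToyKVCache()
target_cache.extend([11, 12, 13])            # committed context

proposed = [21, 22, 23, 24]                  # draft block
committed_len = len(target_cache)
target_cache.extend(proposed)                # tentative entries from the validation pass

n_accepted = 2                               # suppose the target agreed with the first two
target_cache.truncate(committed_len + n_accepted)

print(target_cache.entries)                  # [11, 12, 13, 21, 22]
```

The same truncation applies to the draft model's cache, which otherwise still holds entries for the rejected tokens.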

#### Algorithm Outline
1. Given the current context and target KV cache, loop until you reach max tokens:
   - Draft step: run the draft model autoregressively for M steps to propose tokens d_1…d_M (fast).
   - Validate step (one pass): feed context + d_1…d_M to the target model (teacher-forced) to get logits at each position.
   - Compare: for each position t in 1…M, check whether argmax(logits_t) == d_t, where logits_t is the target's prediction for position t given the context and d_1…d_{t−1}.
   - If all match, accept all M, append them to the output, and extend the target KV accordingly.
   - Otherwise, let k be the first mismatched position: accept d_1…d_{k−1}, append the target's own argmax at position k, and resume drafting from there.
2. Repeat until a stop condition is met (EOS or max length).

```python
# Pseudocode (greedy variant). The model methods used below (init_kv_cache,
# forward_teacher_forced, generate_step, commit_kv_prefix) are a hypothetical
# interface, not a real library API. Assumed conventions:
#   - forward_teacher_forced / generate_step return (logits, new_kv), and
#     logits[:, i] scores the token that FOLLOWS input position i.
#   - KV caches are immutable values; commit_kv_prefix(prev_kv, block_kv, n)
#     returns prev_kv extended by the first n positions that block_kv added.

from typing import Optional

import torch


@torch.no_grad()
def speculative_decode(
    target_model,         # large model (eval mode)
    draft_model,          # small model (eval mode)
    tokenizer,            # unused here; kept so the signature stays unchanged
    prompt_ids: torch.LongTensor,   # [1, T_prompt], T_prompt >= 1
    max_new_tokens: int = 128,
    draft_block: int = 8,           # how many tokens to propose per round (>= 1)
    eos_id: Optional[int] = None,
):
    """
    Returns prompt + generated token ids, shape [1, T], using greedy speculative
    decoding. The output matches plain greedy decoding with the target model.
    """
    device = prompt_ids.device

    # ---- Initialize state ----
    out_ids = prompt_ids.squeeze(0).tolist()  # running output (list of ints)
    n_prompt = len(out_ids)

    # Invariant at the top of every round: both caches cover out_ids[:-1].
    # The last token, out_ids[-1], is "pending": neither model has consumed it yet.
    tgt_kv = target_model.init_kv_cache(device=device)
    draft_kv = draft_model.init_kv_cache(device=device)
    if n_prompt > 1:
        prefix = torch.tensor([out_ids[:-1]], device=device)
        _, tgt_kv = target_model.forward_teacher_forced(prefix, kv_cache=tgt_kv)
        _, draft_kv = draft_model.forward_teacher_forced(prefix, kv_cache=draft_kv)

    # ---- Main loop ----
    while len(out_ids) - n_prompt < max_new_tokens:
        pending = out_ids[-1]

        # 1) DRAFT: propose up to draft_block tokens autoregressively.
        # The draft consumes pending, d_1, ..., d_{M-1} (M inputs) to emit d_1..d_M.
        proposed = []
        draft_kv_block = draft_kv
        step_in = pending
        for _ in range(draft_block):
            draft_logits, draft_kv_block = draft_model.generate_step(
                torch.tensor([[step_in]], device=device), kv_cache=draft_kv_block
            )
            step_in = int(draft_logits[:, -1].argmax(dim=-1))
            proposed.append(step_in)
            if eos_id is not None and step_in == eos_id:
                break

        # 2) VALIDATE: one target pass over [pending, d_1, ..., d_M].
        # Position t scores d_{t+1}; the final position scores one token past the block.
        block = torch.tensor([[pending] + proposed], device=device)
        tgt_logits_block, tgt_kv_block = target_model.forward_teacher_forced(
            block, kv_cache=tgt_kv
        )
        tgt_pred = tgt_logits_block[0].argmax(dim=-1).tolist()  # len(proposed) + 1 ids

        # 3) CHECK AGREEMENT: longest prefix where target argmax == draft token
        accept_upto = 0
        while accept_upto < len(proposed) and tgt_pred[accept_upto] == proposed[accept_upto]:
            accept_upto += 1

        # 4) COMMIT: accepted draft tokens plus ONE token from the target itself.
        # After a mismatch that token is the target's correction; after a fully
        # accepted block it is a bonus token. No extra target pass is needed.
        new_tokens = proposed[:accept_upto]
        hit_eos = eos_id is not None and eos_id in new_tokens
        if not hit_eos:
            new_tokens.append(tgt_pred[accept_upto])
            hit_eos = eos_id is not None and new_tokens[-1] == eos_id
        budget = max_new_tokens - (len(out_ids) - n_prompt)
        out_ids.extend(new_tokens[:budget])
        if hit_eos or len(new_tokens) >= budget:
            break

        # 5) RESYNC both caches so they again cover out_ids[:-1]:
        # keep [pending] + accepted tokens, drop entries computed for rejected ones.
        keep = accept_upto + 1
        tgt_kv = target_model.commit_kv_prefix(tgt_kv, tgt_kv_block, keep)
        # The draft consumed only len(proposed) inputs, so after a fully accepted
        # block it still has to read the last draft token d_M.
        draft_kv = draft_model.commit_kv_prefix(
            draft_kv, draft_kv_block, min(keep, len(proposed))
        )
        if accept_upto == len(proposed):
            _, draft_kv = draft_model.forward_teacher_forced(
                torch.tensor([[proposed[-1]]], device=device), kv_cache=draft_kv
            )

    return torch.tensor([out_ids], device=device)
```
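
As a sanity check on the correctness claim, the loop can be reduced to a runnable toy without tensors or KV caches. In this sketch a "model" is just a hypothetical Python function mapping a token list to its greedy next token; `toy_target` and `toy_draft` are made up for the example, and the per-position target calls stand in for what a real system computes in one batched pass.

```python
from typing import Callable, List

NextToken = Callable[[List[int]], int]  # greedy next-token function: ids -> id


def greedy_decode(model: NextToken, prompt: List[int], max_new_tokens: int) -> List[int]:
    out = list(prompt)
    for _ in range(max_new_tokens):
        out.append(model(out))
    return out


def speculative_decode_toy(
    target: NextToken,
    draft: NextToken,
    prompt: List[int],
    max_new_tokens: int,
    draft_block: int = 4,
) -> List[int]:
    out = list(prompt)
    limit = len(prompt) + max_new_tokens
    while len(out) < limit:
        # Draft: propose draft_block tokens autoregressively.
        proposed: List[int] = []
        for _ in range(draft_block):
            proposed.append(draft(out + proposed))
        # Validate: the target's greedy choice at each block position, plus one more.
        preds = [target(out + proposed[:t]) for t in range(len(proposed) + 1)]
        # Accept the longest matching prefix, then take one token from the target.
        n = 0
        while n < len(proposed) and preds[n] == proposed[n]:
            n += 1
        out.extend(proposed[:n] + [preds[n]])
    return out[:limit]


def toy_target(ids: List[int]) -> int:
    return (sum(ids) * 3 + len(ids)) % 11


def toy_draft(ids: List[int]) -> int:
    # Agrees with the target except at every fourth position.
    guess = toy_target(ids)
    return guess if len(ids) % 4 else (guess + 1) % 11


prompt = [3, 1, 4]
reference = greedy_decode(toy_target, prompt, max_new_tokens=20)
for block in (1, 2, 4, 8):
    assert speculative_decode_toy(toy_target, toy_draft, prompt, 20, block) == reference
print("speculative output matches greedy decoding")
```

The assertion holds for any draft function, including one that is always wrong; a poor draft only reduces how many tokens each round accepts.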