# A simple implementation of nudging with caching
- This code is mainly for testing the performance of nudging with caching, and is by no means a production-ready code.
- The code compares the performance of nudging with caching against the baseline of using the base model alone with caching.
- Currently only implemented for nudging within the same model family (e.g., Llama-2).





## Load model

In [1]:
import torch
import time
from transformers import AutoModelForCausalLM, AutoTokenizer

# --- Configuration ---
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {DEVICE}")

NUDGING_MODEL_NAME = "/extra/ucinlp1/llama-2/llama-2-7B-chat"   # path to the nudging model
BASE_MODEL_NAME = "/extra/ucinlp1/llama-2/llama-2-70B"          # path to the base model

# NUDGING_MODEL_NAME = "/extra/ucinlp1/Qwen/DeepSeek-R1-Distill-Qwen-7B"
# BASE_MODEL_NAME = "/extra/ucinlp1/Qwen/Qwen-2.5-Math-7B"

# --- 1. Load Models and Tokenizers ---
print(f"Loading base model: {BASE_MODEL_NAME}")
base_model = AutoModelForCausalLM.from_pretrained(BASE_MODEL_NAME, device_map="auto")
base_tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_NAME)

print(f"Loading nudging model: {NUDGING_MODEL_NAME}")
nudging_model = AutoModelForCausalLM.from_pretrained(NUDGING_MODEL_NAME, device_map="auto")
nudging_tokenizer = AutoTokenizer.from_pretrained(NUDGING_MODEL_NAME) # Assuming compatible

# Set pad_token_id if not present
for tokenizer, model in [(base_tokenizer, base_model), (nudging_tokenizer, nudging_model)]:
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        model.config.pad_token_id = model.config.eos_token_id
    # Ensure use_cache is enabled in model config (usually true by default for generation models)
    model.config.use_cache = True

base_model.eval()
nudging_model.eval()

Using device: cuda
Loading base model: /extra/ucinlp1/llama-2/llama-2-70B


Loading checkpoint shards:   0%|          | 0/15 [00:00<?, ?it/s]

Loading nudging model: /extra/ucinlp1/llama-2/llama-2-7B-chat


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
      )
    )
    (no

## Prompt

In [2]:
from utils import apply_instruct_template
from dataset_utils import SYSTEM_PROMPT_REASONING
QUESTION = "Question: Sue lives in a fun neighborhood.  One weekend, the neighbors decided to play a prank on Sue.  On Friday morning, the neighbors placed 18 pink plastic flamingos out on Sue's front yard.  On Saturday morning, the neighbors took back one third of the flamingos, painted them white, and put these newly painted white flamingos back out on Sue's front yard.  Then, on Sunday morning, they added another 18 pink plastic flamingos to the collection. At noon on Sunday, how many more pink plastic flamingos were out than white plastic flamingos?"
PROMPT_NUDGING = apply_instruct_template(model_name=NUDGING_MODEL_NAME, system_prompt=SYSTEM_PROMPT_REASONING, instruct_prompt=QUESTION, response_prompt="", add_bos=False)
PROMPT_BASE = apply_instruct_template(model_name=BASE_MODEL_NAME, system_prompt=SYSTEM_PROMPT_REASONING, instruct_prompt=QUESTION, response_prompt="", add_bos=False)
print("NUDGING PROMPT: ", PROMPT_NUDGING)
print("BASE PROMPT: ", PROMPT_BASE)  

NUDGING PROMPT:  [INST] <<SYS>>
Answer the question by walking through the reasoning steps.
<</SYS>>

Question: Sue lives in a fun neighborhood.  One weekend, the neighbors decided to play a prank on Sue.  On Friday morning, the neighbors placed 18 pink plastic flamingos out on Sue's front yard.  On Saturday morning, the neighbors took back one third of the flamingos, painted them white, and put these newly painted white flamingos back out on Sue's front yard.  Then, on Sunday morning, they added another 18 pink plastic flamingos to the collection. At noon on Sunday, how many more pink plastic flamingos were out than white plastic flamingos? [/INST] 
BASE PROMPT:  Answer the question by walking through the reasoning steps.
Question: Sue lives in a fun neighborhood.  One weekend, the neighbors decided to play a prank on Sue.  On Friday morning, the neighbors placed 18 pink plastic flamingos out on Sue's front yard.  On Saturday morning, the neighbors took back one third of the flamingos

## Nudging with Caching implementation

In [6]:
# Helper function to slice past_key_values to a new target sequence length
def slice_past_key_values(past_key_values, new_seq_length):
    if past_key_values is None:
        return None
    sliced_pkv = []
    for layer_past in past_key_values:
        # Each layer_past is a tuple of (key_states, value_states)
        # key_states and value_states shape: [batch_size, num_heads, current_sequence_length, head_dim]
        sliced_key = layer_past[0][:, :, :new_seq_length, :]
        sliced_value = layer_past[1][:, :, :new_seq_length, :]
        if sliced_key.shape[2] == 0: # If new_seq_length is 0, effectively reset
            return None
        sliced_pkv.append((sliced_key, sliced_value))
    return tuple(sliced_pkv)

# --- Baseline: Large (Target) Model Alone with KV Caching ---
def generate_baseline_with_cache(model, tokenizer, prompt, max_new_tokens):
    print("\n--- Running Baseline Generation with KV Cache ---")
    input_ids_prompt = tokenizer.encode(prompt, return_tensors="pt").to(DEVICE)
    generated_ids = input_ids_prompt.clone()
    
    current_input_ids = input_ids_prompt
    past_key_values = None
    
    start_time = time.time()
    if DEVICE == "cuda": torch.cuda.synchronize()

    with torch.no_grad():
        for i in range(max_new_tokens):
            outputs = model(current_input_ids, past_key_values=past_key_values, use_cache=True)
            next_token_logits = outputs.logits[:, -1, :] # Logits for the last token in the input
            next_token = torch.argmax(next_token_logits, dim=-1).unsqueeze(-1)
            
            generated_ids = torch.cat([generated_ids, next_token], dim=1)
            past_key_values = outputs.past_key_values
            current_input_ids = next_token # Next input is just the new token
            
            # if next_token.item() == tokenizer.eos_token_id:
            #     print(f"EOS token generated at step {i+1}.")
            #     break
            if generated_ids.shape[1] >= tokenizer.model_max_length:
                print("Max model length reached.")
                break
                
    if DEVICE == "cuda": torch.cuda.synchronize()
    end_time = time.time()

    generation_time = end_time - start_time
    generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
    num_actually_generated = generated_ids.shape[1] - input_ids_prompt.shape[1]
    
    print(f"Generated text (Baseline w/ Cache): {generated_text}")
    print(f"Time taken (Baseline w/ Cache): {generation_time:.4f} seconds")
    print(f"Tokens generated (Baseline w/ Cache): {num_actually_generated}")
    if generation_time > 0:
        print(f"Tokens per second (Baseline w/ Cache): {num_actually_generated / generation_time:.2f}")
    return generated_text, generation_time, num_actually_generated

# --- Nudging with KV Caching ---
def generate_nudging_with_cache(base_model, 
                                nudging_model, 
                                base_tokenizer, 
                                nudging_tokenizer, 
                                base_prompt, 
                                nudging_prompt,
                                max_new_tokens, 
                                num_nudging_tokens,
                                threshold=0.4,
                                base_spec_size=4,
                                debug=False):
    print("\n--- Running Nudging Generation with KV Cache ---")
    device_base_model = next(base_model.parameters()).device
    device_nudging_model = next(nudging_model.parameters()).device
    input_ids_base_prompt = base_tokenizer.encode(base_prompt, return_tensors="pt").to(device_base_model)
    input_ids_nudging_prompt = nudging_tokenizer.encode(nudging_prompt, return_tensors="pt").to(device_nudging_model)
    prompt_len = input_ids_base_prompt.shape[1]
    
    generated_ids = input_ids_base_prompt.clone()
    generated_ids_for_nudging = input_ids_nudging_prompt.clone()
    
    start_time = time.time()
    if DEVICE == "cuda": torch.cuda.synchronize()
    
    base_outputs = base_model(generated_ids[:, :-1], past_key_values=None, use_cache=True)
    base_leave_last_one_past_kv = base_outputs.past_key_values
    
    nudging_outputs = nudging_model(generated_ids_for_nudging[:, :-1], past_key_values=None, use_cache=True)
    nudging_leave_last_one_past_kv = nudging_outputs.past_key_values
    
    current_total_generated_after_prompt = 0
    highlighed_full_output = "" # for debugging
    
    if DEVICE == "cuda": torch.cuda.synchronize()
    
    with torch.no_grad():
        while current_total_generated_after_prompt < max_new_tokens:
            len_at_step_start = generated_ids.shape[1]
            len_at_step_start_nudging = generated_ids_for_nudging.shape[1]
            
            # Step 1: Use the nudging model to generate num_nudging_tokens tokens
            nudging_suffix_ids_list = []
            temp_nudging_gen_ids = generated_ids_for_nudging.clone()
            current_nudging_model_input = generated_ids_for_nudging[:, -1].unsqueeze(1)
            nudging_eos = False
            
            for i in range(num_nudging_tokens):
                if temp_nudging_gen_ids.shape[1] >= tokenizer.model_max_length - 1: nudging_eos = True; break
                if current_total_generated_after_prompt + len(nudging_suffix_ids_list) >= max_new_tokens: nudging_eos = True; break
                
                nudging_outputs = nudging_model(
                    current_nudging_model_input,
                    past_key_values=nudging_leave_last_one_past_kv,
                    use_cache=True
                )
                
                next_nudging_token_logits = nudging_outputs.logits[:, -1, :]
                next_nudging_token = torch.argmax(next_nudging_token_logits, dim=-1, keepdim=True)
                
                nudging_suffix_ids_list.append(next_nudging_token)
                temp_nudging_gen_ids = torch.cat([temp_nudging_gen_ids, next_nudging_token], dim=1)
                
                nudging_leave_last_one_past_kv = nudging_outputs.past_key_values
                current_nudging_model_input = next_nudging_token
                
                if next_nudging_token.item() == tokenizer.eos_token_id:
                    nudging_eos = True  
                    break
            if debug:
                # print all nudging tokens
                print("All nudging tokens:")
                all_nudging_tokens = [nudging_tokenizer.convert_ids_to_tokens(token[0].item()) for token in nudging_suffix_ids_list]
                all_nudging_tokens = [token.replace("Ġ", " ") for token in all_nudging_tokens]
                print(all_nudging_tokens)
                
            if nudging_eos: # if nudging model generates eos, we append all the nudging tokens to the generated ids
                nudging_suffix_ids = torch.cat(nudging_suffix_ids_list, dim=1).to(device_base_model)
                generated_ids = torch.cat([generated_ids, nudging_suffix_ids], dim=1)   # !!! potential bug here, nudging and base model might use different tokenizers
                current_total_generated_after_prompt += num_nudging_tokens
                break
            
            # Step 2: Use base model to check how many nudging tokens are accepted
            nudging_suffix_ids_list = nudging_suffix_ids_list[:base_spec_size]
            nudging_suffix_ids = torch.cat(nudging_suffix_ids_list, dim=1).to(device_base_model)    # !!! potential bug here, nudging and base model might use different tokenizers
            base_input = torch.cat([generated_ids[:, -1].unsqueeze(1), nudging_suffix_ids], dim=1)
            base_outputs = base_model(base_input, past_key_values=base_leave_last_one_past_kv, use_cache=True)
            base_leave_last_one_past_kv = base_outputs.past_key_values  # to be sliced later
            
            # find the base model's top-1 token for all nudging suffix ids
            base_logits = base_outputs.logits[:, :, :]
            base_probs = torch.softmax(base_logits, dim=-1)
            top_1_token_probs, _ = torch.max(base_probs, dim=-1, keepdim=True)
            assert top_1_token_probs.shape == (1, base_spec_size + 1, 1) # top-1 token prob for all nudging suffix ids + the next token
            
            # Find the first token that the base model can take over
            base_takeover_idx = base_spec_size
            i = 1
            while i < base_spec_size:
                next_word_id = base_spec_size
                # Find the first nudging token after i that contains a space (next word)
                for j in range(i, base_spec_size):
                    token = nudging_tokenizer.convert_ids_to_tokens(nudging_suffix_ids_list[j][0].item())
                    if token.startswith("▁") or token.startswith("Ġ"):
                        next_word_id = j
                        break
                if next_word_id == base_spec_size:  # if no next word (all tokens are a single word), accept all the nudging tokens
                    base_takeover_idx = base_spec_size
                    break
                else:
                    if top_1_token_probs[0, next_word_id, 0].item() > threshold:
                        base_takeover_idx = next_word_id
                        break
                    else:
                        i = next_word_id + 1
            
            if debug:
                # print accepted nudging tokens
                print("Accepted nudging tokens:")
                accepted_nudging_tokens = [nudging_tokenizer.convert_ids_to_tokens(nudging_suffix_ids_list[i][0].item()) for i in range(base_takeover_idx)]
                accepted_nudging_tokens = [token.replace("Ġ", " ") for token in accepted_nudging_tokens]
                nudging_text = "".join(accepted_nudging_tokens)
                highlighed_full_output += f"\\textbf{{{nudging_text}}}"
                print(accepted_nudging_tokens)
                
            if base_takeover_idx == base_spec_size: # accept all the nudging tokens
                accepted_sequence_len_this_step = base_spec_size
                # update the base model's past key values
                base_leave_last_one_past_kv = slice_past_key_values(
                    base_leave_last_one_past_kv,
                    len_at_step_start - 1 + accepted_sequence_len_this_step
                )
                # update the base model's generated ids
                generated_ids = torch.cat([generated_ids, nudging_suffix_ids], dim=1)
                current_total_generated_after_prompt += accepted_sequence_len_this_step
                # no need to update the nudging model's past key values
                # update the generated ids for the nudging model
                generated_ids_for_nudging = torch.cat([generated_ids_for_nudging, nudging_suffix_ids.to(device_nudging_model)], dim=1)
            else:
                accepted_sequence_len_this_step = base_takeover_idx + 1 # accept the first base_takeover_idx nudging tokens + the base model's next token
                base_next_token = torch.argmax(base_logits[:, base_takeover_idx, :], dim=-1, keepdim=True)
                # update the base model's past key values
                base_leave_last_one_past_kv = slice_past_key_values(
                    base_leave_last_one_past_kv,
                    len_at_step_start - 1 + accepted_sequence_len_this_step
                )
                # update the base model's generated ids
                generated_ids = torch.cat([generated_ids, nudging_suffix_ids[:, :base_takeover_idx], base_next_token], dim=1)
                current_total_generated_after_prompt += accepted_sequence_len_this_step
                # update the generated ids for the nudging model
                generated_ids_for_nudging = torch.cat([generated_ids_for_nudging, nudging_suffix_ids[:, :base_takeover_idx].to(device_nudging_model), base_next_token.to(device_nudging_model)], dim=1)
                # update the nudging model's past key values
                nudging_leave_last_one_past_kv = slice_past_key_values(
                    nudging_leave_last_one_past_kv,
                    len_at_step_start_nudging - 1 + accepted_sequence_len_this_step
                )
            
            # Step 3: Use the base model to generate tokens until 
            # 1. EOS token is generated
            # 2. The number of generated tokens reaches max_new_tokens
            # 3. The base model's top-1 token probability is below the threshold
            base_tokens = [base_next_token] if base_takeover_idx != base_spec_size else []
            while current_total_generated_after_prompt < max_new_tokens:
                if generated_ids[0, -1].item() == tokenizer.eos_token_id:
                    break
                base_outputs = base_model(generated_ids[:, -1].unsqueeze(1), past_key_values=base_leave_last_one_past_kv, use_cache=True)
                base_logits = base_outputs.logits[:, -1, :]
                base_probs = torch.softmax(base_logits, dim=-1)
                top_1_token_prob, next_base_token = torch.max(base_probs, dim=-1, keepdim=True)
                if top_1_token_prob > threshold:
                    base_tokens.append(next_base_token)
                    generated_ids = torch.cat([generated_ids, next_base_token], dim=1)
                    current_total_generated_after_prompt += 1
                    base_leave_last_one_past_kv = base_outputs.past_key_values
                    # update the nudging model's past key values and generated ids
                    nudging_outputs = nudging_model(generated_ids_for_nudging[:, -1].unsqueeze(1), past_key_values=nudging_leave_last_one_past_kv, use_cache=True)
                    generated_ids_for_nudging = torch.cat([generated_ids_for_nudging, next_base_token.to(device_nudging_model)], dim=1)
                    nudging_leave_last_one_past_kv = nudging_outputs.past_key_values
                else:
                    break
            if debug:
                # print accepted base tokens
                print("Accepted base tokens:")
                accepted_base_tokens = [base_tokenizer.convert_ids_to_tokens(token[0].item()) for token in base_tokens]
                accepted_base_tokens = [token.replace("Ġ", " ") for token in accepted_base_tokens]
                base_text = "".join(accepted_base_tokens)
                highlighed_full_output += base_text
                print(accepted_base_tokens)
                print("After a round of generation:")
                print(highlighed_full_output)
                print("************************************************")
                # print all decoded generated ids
                print([base_tokenizer.convert_ids_to_tokens(generated_ids[0, i].item()) for i in range(generated_ids.shape[1])])
                # print all decoded generated ids for nudging model
                print([nudging_tokenizer.convert_ids_to_tokens(generated_ids_for_nudging[0, i].item()) for i in range(generated_ids_for_nudging.shape[1])])
                print("--------------------------------")
            
            if current_total_generated_after_prompt >= max_new_tokens: break
            if generated_ids[0, -1].item() == tokenizer.eos_token_id and generated_ids.shape[1] > prompt_len : break

    if DEVICE == "cuda": torch.cuda.synchronize()
    end_time = time.time()
    generation_time = end_time - start_time
    
    final_generated_sequence = generated_ids[0, :prompt_len + current_total_generated_after_prompt]
    generated_text = base_tokenizer.decode(final_generated_sequence, skip_special_tokens=True)
    num_actually_generated = final_generated_sequence.shape[0] - prompt_len
    
    print(f"Generated text (Nudging w/ Cache): {generated_text}")
    print(f"Time taken (Nudging w/ Cache): {generation_time:.4f} seconds")
    print(f"Tokens generated (Nudging w/ Cache): {num_actually_generated}")
    if generation_time > 0:
        print(f"Tokens per second (Nudging w/ Cache): {num_actually_generated / generation_time:.2f}")
        
    return generated_text, generation_time, num_actually_generated


# --- Run Comparison ---
if __name__ == "__main__":
    MAX_NEW_TOKENS = 500
    NUM_NUDGING_TOKENS = 16
    target_model = base_model
    draft_model = nudging_model
    target_tokenizer = base_tokenizer
    nudging_tokenizer = nudging_tokenizer
    
    # Warm-up GPU
    if DEVICE == "cuda":
        print("\nWarming up GPU...")
        for _ in range(2):
            _ = generate_baseline_with_cache(target_model, target_tokenizer, "Warmup", 5)
        print("Warmup done.")

    # Baseline with Cache
    text_base_cache, time_base_cache, tokens_base_cache = generate_baseline_with_cache(
        target_model, target_tokenizer, PROMPT_BASE, MAX_NEW_TOKENS
    )

    # Nudging with Cache
    text_nudging_cache, time_nudging_cache, tokens_nudging_cache = generate_nudging_with_cache(
        base_model, 
        nudging_model, 
        base_tokenizer, 
        nudging_tokenizer, 
        PROMPT_BASE, 
        PROMPT_NUDGING, 
        MAX_NEW_TOKENS, 
        NUM_NUDGING_TOKENS, 
        threshold=0.4,
        base_spec_size=NUM_NUDGING_TOKENS,
        debug=False
    )
    
    print("\n--- Comparison Summary (with KV Caching) ---")
    print(f"Prompt: \"{QUESTION}\"")
    print(f"Max new tokens: {MAX_NEW_TOKENS}, Num nudging tokens (K): {NUM_NUDGING_TOKENS}")
    
    print("\nBaseline (Target Model Only w/ Cache):")
    print(f"  Output: {text_base_cache}")
    print(f"  Time: {time_base_cache:.4f}s, Tokens: {tokens_base_cache}, TPS: {tokens_base_cache / (time_base_cache + 1e-9):.2f}")

    print("\nNudging (Base + Nudging w/ Cache):")
    print(f"  Output: {text_nudging_cache}")
    print(f"  Time: {time_nudging_cache:.4f}s, Tokens: {tokens_nudging_cache}, TPS: {tokens_nudging_cache / (time_nudging_cache + 1e-9):.2f}")

    if time_base_cache > 0 and time_nudging_cache > 0:
        speedup = time_base_cache / time_nudging_cache
        print(f"\nSpeedup (Nudging w/ Cache vs. Baseline w/ Cache): {speedup:.2f}x")


Warming up GPU...

--- Running Baseline Generation with KV Cache ---


Generated text (Baseline w/ Cache): Warmup: 100
Time taken (Baseline w/ Cache): 2.0211 seconds
Tokens generated (Baseline w/ Cache): 5
Tokens per second (Baseline w/ Cache): 2.47

--- Running Baseline Generation with KV Cache ---
Generated text (Baseline w/ Cache): Warmup: 100
Time taken (Baseline w/ Cache): 2.0205 seconds
Tokens generated (Baseline w/ Cache): 5
Tokens per second (Baseline w/ Cache): 2.47
Warmup done.

--- Running Baseline Generation with KV Cache ---
Generated text (Baseline w/ Cache): Answer the question by walking through the reasoning steps.
Question: Sue lives in a fun neighborhood.  One weekend, the neighbors decided to play a prank on Sue.  On Friday morning, the neighbors placed 18 pink plastic flamingos out on Sue's front yard.  On Saturday morning, the neighbors took back one third of the flamingos, painted them white, and put these newly painted white flamingos back out on Sue's front yard.  Then, on Sunday morning, they added another 18 pink plastic flaming