# Context Aware Decoding Demo

In [4]:
import os

# Never commit access tokens to a notebook: read the Hugging Face token from
# the HF_TOKEN environment variable instead. The token that used to be
# hardcoded here has been exposed and should be revoked.
# When HF_TOKEN is unset this is None, which lets `from_pretrained` fall back
# to locally stored credentials (e.g. from `huggingface-cli login`).
test_token = os.environ.get("HF_TOKEN")

In [5]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessorList, LogitsProcessor
from torch.nn import functional as F

# Hub checkpoint id used for both the tokenizer and the causal LM.
model_name = "google/gemma-2-2b-it"
# `test_token` is forwarded to both loaders for Hub authentication.
tokenizer = AutoTokenizer.from_pretrained(model_name, token = test_token)
model = AutoModelForCausalLM.from_pretrained(model_name, token = test_token)


  from .autonotebook import tqdm as notebook_tqdm
Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  2.74it/s]


In [6]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Move the model to the chosen device. As the last expression of the cell,
# the returned module is displayed as the cell output.
model.to(device)

Gemma2ForCausalLM(
  (model): Gemma2Model(
    (embed_tokens): Embedding(256000, 2304, padding_idx=0)
    (layers): ModuleList(
      (0-25): 26 x Gemma2DecoderLayer(
        (self_attn): Gemma2Attention(
          (q_proj): Linear(in_features=2304, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2304, out_features=1024, bias=False)
          (v_proj): Linear(in_features=2304, out_features=1024, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2304, bias=False)
          (rotary_emb): Gemma2RotaryEmbedding()
        )
        (mlp): Gemma2MLP(
          (gate_proj): Linear(in_features=2304, out_features=9216, bias=False)
          (up_proj): Linear(in_features=2304, out_features=9216, bias=False)
          (down_proj): Linear(in_features=9216, out_features=2304, bias=False)
          (act_fn): PytorchGELUTanh()
        )
        (input_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)
        (pre_feedforward_layernorm): Gemma2RMSNorm((2304,), eps

In [4]:
# Prompt pieces: a context passage and a question about it.
context = "The current year is 2027. Argentina won World Cups in 1978,1986,2022 and 2026."
question = "How many world cups has Argentina won?"

# Tokenize the two pieces separately so the context ids are available on
# their own (they are passed as `context_ids` to `context_aware_sampling`).
context_input = tokenizer(context, return_tensors="pt").input_ids.to(device)
question_input = tokenizer(question, return_tensors="pt").input_ids.to(device)

# Full prompt = context ids followed by question ids along the last dimension.
# NOTE(review): each tokenizer call may prepend its own special tokens
# (e.g. BOS), which would then also appear in the middle of the concatenated
# prompt -- confirm this is intended.
input_ids = torch.cat([context_input, question_input], dim=-1)


def standard_decoding(input_ids, max_length=128, temperature=1.0, top_k=50, top_p=0.9):
    """Sample a continuation with `model.generate` and return it as text.

    Baseline for the comparison with context-aware decoding. Uses the
    notebook-level `model` and `tokenizer` rather than taking them as
    arguments.

    Args:
        input_ids: Prompt token ids, passed to `model.generate` as the first
            positional argument.
        max_length: Forwarded unchanged as `max_length`.
        temperature: Forwarded unchanged as `temperature`.
        top_k: Forwarded unchanged as `top_k`.
        top_p: Forwarded unchanged as `top_p`.

    Returns:
        The first returned sequence decoded with special tokens skipped.
    """
    sampling_options = {
        "max_length": max_length,
        "temperature": temperature,
        "top_k": top_k,
        "top_p": top_p,
        "do_sample": True,
    }
    sequences = model.generate(input_ids, **sampling_options)
    first_sequence = sequences[0]
    return tokenizer.decode(first_sequence, skip_special_tokens=True)

def context_aware_sampling(model, tokenizer, input_ids, context_ids, alpha=0.9, max_length=128, temperature=1.0):
    """Sample a continuation with context-aware (contrastive) decoding.

    At every step the model is run twice: once on the full sequence
    (context + question + tokens generated so far) and once on the same
    sequence with the leading context tokens removed. The next-token logits
    are then combined as

        adjusted = (1 + alpha) * logits_with_context - alpha * logits_without_context

    so that tokens supported by the context are amplified.

    Args:
        model: Callable taking a tensor of token ids and returning an object
            whose `.logits` is indexed as [batch, position, vocab].
        tokenizer: Only `tokenizer.eos_token_id` is read (stop condition).
        input_ids: Prompt token ids; the leading `context_ids.shape[-1]`
            positions along the last dimension are treated as the context.
            The stop check uses `.item()`, so a single sequence is expected.
        context_ids: Token ids of the context. Only its length along the
            last dimension is used.
        alpha: Weight of the contrastive adjustment.
        max_length: Maximum number of NEW tokens to sample.
        temperature: Divisor applied to the adjusted logits before softmax.

    Returns:
        A tensor holding `input_ids` followed by the sampled tokens
        (including the EOS token if one was sampled).

    Raises:
        ValueError: If `input_ids` has no tokens after the context, because
            the context-free pass would then receive an empty sequence.
    """
    # Number of context tokens along the sequence dimension. `len(context_ids)`
    # would return the size of the FIRST dimension, which is the batch size (1)
    # for the 2-D tensors passed by the notebook, so the "question-only" pass
    # would still contain almost the whole context.
    context_length = context_ids.shape[-1]
    if input_ids.shape[-1] <= context_length:
        raise ValueError(
            "input_ids must contain at least one token after the context "
            f"(got {input_ids.shape[-1]} tokens, context length {context_length})"
        )

    generated_tokens = input_ids.clone()

    for _ in range(max_length):
        with torch.no_grad():
            # Next-token logits conditioned on context + question + generated tokens.
            full_context_logits = model(generated_tokens).logits[:, -1, :]

            # Same sequence with the context prefix stripped.
            question_only_input = generated_tokens[:, context_length:]
            question_only_logits = model(question_only_input).logits[:, -1, :]

        # Contrast the two distributions, then apply temperature and sample.
        adjusted_logits = (1 + alpha) * full_context_logits - alpha * question_only_logits
        adjusted_probs = F.softmax(adjusted_logits / temperature, dim=-1)

        next_token = torch.multinomial(adjusted_probs, num_samples=1)

        generated_tokens = torch.cat([generated_tokens, next_token], dim=-1)

        # Stop as soon as the end-of-sequence token is sampled.
        if next_token.item() == tokenizer.eos_token_id:
            break

    return generated_tokens

In [6]:
# Put the model into evaluation mode before generating.
model.eval()

# Baseline: plain sampling on the concatenated prompt.
standard_output = standard_decoding(input_ids)

# Context-aware decoding on the same prompt; `context_input` marks which
# leading tokens count as context.
output_tokens = context_aware_sampling(
    model, tokenizer, input_ids,
    context_ids=context_input, alpha=0.5, max_length=128, temperature=1.0,
)
context_aware_output = tokenizer.decode(output_tokens[0], skip_special_tokens=True)

# Show both generations one after the other, separated by a rule.
print("Standard Decoding Output:\n", standard_output)
print("__" * 50)
print("Context-Aware Decoding Output:\n", context_aware_output)


Standard Decoding Output:
 The current year is 2027. Argentina won World Cups in 1978,1986,2022 and 2026.How many world cups has Argentina won?

Argentina won three World Cups: 1978, 1986, and 2022. 

They have not won in 2026 yet.

____________________________________________________________________________________________________
Context-Aware Decoding Output:
 The current year is 2027. Argentina won World Cups in 1978,1986,2022 and 2026.How many world cups has Argentina won?

Here's the catch! 🗓️🏆

Only one country can win the World Cup in a single year.

The correct answer is 1 (Argentina) because it's a trick question. 



### watermark-based method

In [7]:
# Prompt pieces for the logit-boosting experiment below.
# NOTE(review): the context states the current year is 2025 but lists a 2026
# win (the earlier cell used 2027) -- confirm this mismatch is intentional.
context = "The current year is 2025. Argentina won World Cups in 1978,1986,2022 and 2026."
question = "How many world cups has Argentina won?"

In [8]:
def context_enhanced_decoding(
    model, 
    tokenizer,
    context,
    question,
    delta=2.0,  # fixed boost added to the logits of context tokens
    max_length=128,
    temperature=1.0
):
    """Sample an answer while boosting the logits of tokens seen in the context.

    The context and question are tokenized separately and concatenated. At
    every step, `delta` is added to the next-token logit of every vocabulary
    id that occurs in the tokenized context, then temperature and softmax are
    applied and one token is sampled.

    Args:
        model: Model called as `model(ids)`; its `.logits` is indexed as
            [batch, position, vocab]. `model.device` is used to place inputs.
        tokenizer: Called as `tokenizer(text, return_tensors="pt")`; also
            provides `eos_token_id` and `decode`.
        context: Context passage (string).
        question: Question text (string).
        delta: Constant added to the logits of context tokens.
        max_length: Maximum number of NEW tokens to sample.
        temperature: Divisor applied to the boosted logits before softmax.

    Returns:
        The whole sequence (prompt + generated tokens) decoded with special
        tokens skipped.
    """
    # 1. Encode the inputs on the model's own device, so the function does not
    #    depend on a notebook-level `device` variable matching the model.
    target_device = model.device
    context_input = tokenizer(context, return_tensors="pt").input_ids.to(target_device)
    question_input = tokenizer(question, return_tensors="pt").input_ids.to(target_device)

    # Prompt = context ids followed by question ids.
    generated_tokens = torch.cat([context_input, question_input], dim=-1)

    # Vocabulary ids to boost. The context prefix never changes during
    # generation, so this is computed once instead of on every step.
    # NOTE(review): any special tokens the tokenizer adds to the context
    # (e.g. BOS) are boosted as well -- confirm this is intended.
    context_token_ids = torch.unique(context_input[0])

    # 2. Generate token by token.
    with torch.no_grad():
        for _ in range(max_length):
            # Next-token logits for the current sequence.
            logits = model(generated_tokens).logits[:, -1, :]

            # Add the fixed boost to every context token's logit.
            adjusted_logits = logits.clone()
            adjusted_logits[0, context_token_ids] += delta

            # Apply temperature, convert to probabilities and sample.
            adjusted_logits = adjusted_logits / temperature
            probs = F.softmax(adjusted_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)

            # Append the sampled token to the running sequence.
            generated_tokens = torch.cat([generated_tokens, next_token], dim=-1)

            # Stop once the end-of-sequence token is produced.
            if next_token.item() == tokenizer.eos_token_id:
                break

    return tokenizer.decode(generated_tokens[0], skip_special_tokens=True)

def standard_decoding(
    model,
    tokenizer,
    context,
    question,
    max_length=128,
    temperature=1.0
):
    """Plain sampling baseline for comparison with `context_enhanced_decoding`.

    Note: this redefines the `standard_decoding(input_ids, ...)` function from
    the earlier cell with a different signature; once this cell has run, the
    earlier call style no longer works.

    Args:
        model: Model providing `generate` and `device`.
        tokenizer: Called as `tokenizer(text, return_tensors="pt")`; also
            provides `decode`.
        context: Context passage (string).
        question: Question text (string).
        max_length: Forwarded unchanged to `model.generate` as `max_length`.
            NOTE(review): `context_enhanced_decoding` treats its `max_length`
            as a number of new tokens -- confirm the two are meant to differ.
        temperature: Forwarded unchanged to `model.generate`.

    Returns:
        The first returned sequence decoded with special tokens skipped.
    """
    # Place the inputs on the model's own device rather than depending on a
    # notebook-level `device` variable that may not match the passed model.
    target_device = model.device
    context_input = tokenizer(context, return_tensors="pt").input_ids.to(target_device)
    question_input = tokenizer(question, return_tensors="pt").input_ids.to(target_device)
    input_ids = torch.cat([context_input, question_input], dim=-1)
    
    outputs = model.generate(
        input_ids,
        max_length=max_length,
        temperature=temperature,
        do_sample=True,
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

In [9]:
# Logit-boosted sampling on the notebook-level `context` / `question`.
result = context_enhanced_decoding(model, tokenizer, context, question, delta=2.0, temperature=0.7)

# Plain sampling baseline with the same temperature.
standard_result = standard_decoding(model, tokenizer, context, question, temperature=0.7)

# Baseline first, boosted output second.
print("\nStandard decoding:", standard_result)
print("Context-enhanced decoding:", result)


Standard decoding: The current year is 2025. Argentina won World Cups in 1978,1986,2022 and 2026.How many world cups has Argentina won?

The answer is 3. 

Context-enhanced decoding: The current year is 2025. Argentina won World Cups in 1978,1986,2022 and 2026.How many world cups has Argentina won? 

The question is a trick question, Argentina won the World Cup in 1978, 1986, 2022, and 2026. 
 

 

 
 
 


 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 



 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 



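Possible directions for further improvement:

1. **Dynamic delta** — compute `delta` from each token's importance instead of using a fixed value:

   ```python
   delta = compute_importance(token)
   ```

2. **Semantic clustering** — also boost tokens that are semantically close to the context tokens:

   ```python
   similar_tokens = get_semantic_cluster(token)
   for t in similar_tokens:
       boost_mask[0, t] = delta * similarity_score
   ```

3. **More sophisticated importance scores** — combine attention scores, entropy, and similar signals when computing importance:

   ```python
   importance = compute_token_importance(
       token,
       attention_scores,
       semantic_similarity
   )
   ```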