# Context Aware Decoding Demo

In [2]:
test_token = "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"

In [3]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessorList, LogitsProcessor
from torch.nn import functional as F

model_name = "google/gemma-2-2b-it"
tokenizer = AutoTokenizer.from_pretrained(model_name, token = test_token)
model = AutoModelForCausalLM.from_pretrained(model_name, token = test_token)


  from .autonotebook import tqdm as notebook_tqdm
Downloading shards: 100%|██████████| 2/2 [02:30<00:00, 75.42s/it] 
Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  2.75it/s]


In [4]:
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

Gemma2ForCausalLM(
  (model): Gemma2Model(
    (embed_tokens): Embedding(256000, 2304, padding_idx=0)
    (layers): ModuleList(
      (0-25): 26 x Gemma2DecoderLayer(
        (self_attn): Gemma2Attention(
          (q_proj): Linear(in_features=2304, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2304, out_features=1024, bias=False)
          (v_proj): Linear(in_features=2304, out_features=1024, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2304, bias=False)
          (rotary_emb): Gemma2RotaryEmbedding()
        )
        (mlp): Gemma2MLP(
          (gate_proj): Linear(in_features=2304, out_features=9216, bias=False)
          (up_proj): Linear(in_features=2304, out_features=9216, bias=False)
          (down_proj): Linear(in_features=9216, out_features=2304, bias=False)
          (act_fn): PytorchGELUTanh()
        )
        (input_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)
        (pre_feedforward_layernorm): Gemma2RMSNorm((2304,), eps

In [5]:
context = "The current year is 2027. Argentina won World Cups in 1978,1986,2022 and 2026."
question = "How many world cups has Argentina won?"

context_input = tokenizer(context, return_tensors="pt").input_ids.to(device)
question_input = tokenizer(question, return_tensors="pt").input_ids.to(device)

input_ids = torch.cat([context_input, question_input], dim=-1)


def standard_decoding(input_ids, max_length=128, temperature=1.0, top_k=50, top_p=0.9):
    output_ids = model.generate(
        input_ids,
        max_length=max_length,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        do_sample=True,
    )
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)

def context_aware_sampling(model, tokenizer, input_ids, context_ids, alpha=0.9, max_length=128, temperature=1.0):
    generated_tokens = input_ids.clone()
    
    for _ in range(max_length):
        with torch.no_grad():
            full_context_outputs = model(generated_tokens)
            full_context_logits = full_context_outputs.logits[:, -1, :] 

            question_only_input = generated_tokens[:, len(context_ids):]
            question_only_outputs = model(question_only_input)
            question_only_logits = question_only_outputs.logits[:, -1, :] 

        adjusted_logits = (1 + alpha) * full_context_logits - alpha * question_only_logits
        adjusted_probs = F.softmax(adjusted_logits / temperature, dim=-1)

        next_token = torch.multinomial(adjusted_probs, num_samples=1)

        generated_tokens = torch.cat([generated_tokens, next_token], dim=-1)

        if next_token.item() == tokenizer.eos_token_id:
            break

    return generated_tokens

In [6]:
model.eval()
standard_output = standard_decoding(input_ids)
output_tokens = context_aware_sampling(
                                        model,
                                        tokenizer,
                                        input_ids,
                                        context_ids=context_input,
                                        alpha=0.5,
                                        max_length=128,
                                        temperature=1.0,
                                    )

context_aware_output = tokenizer.decode(output_tokens[0], skip_special_tokens=True)


print("Standard Decoding Output:\n", standard_output)
print("__" * 50)
print("Context-Aware Decoding Output:\n", context_aware_output)


Standard Decoding Output:
 The current year is 2027. Argentina won World Cups in 1978,1986,2022 and 2026.How many world cups has Argentina won?

**Answer:**

Argentina has won **3** World Cups. 

**Explanation:**

The current year is 2027. Since Argentina won the World Cup in 2022, we have the following World Cup wins:

* 1978
* 1986
* 2022 

They have already won 3 World
____________________________________________________________________________________________________
Context-Aware Decoding Output:
 The current year is 2027. Argentina won World Cups in 1978,1986,2022 and 2026.How many world cups has Argentina won?

Correct Answer: 3

Incorrect Answer: 1

**Explanation:**

Let me clarify this: This is a trick question! The prompt might be trying to lead you astray since:

* **Argentina won in 1978, 1986 and 2022.** This part is correct.
* **World Cup was won for a 4th time in 2026** - This part is incorrect. 
  World Cup in 2026 has not happened yet.

 
Let me know if you'd like to t

### watermark-based method

In [7]:
context = "The current year is 2025. Argentina won World Cups in 1978,1986,2022 and 2026."
question = "How many world cups has Argentina won?"

In [8]:
def context_enhanced_decoding(
    model, 
    tokenizer,
    context,
    question,
    delta=2.0,  # 固定boost值
    max_length=128,
    temperature=1.0
):
    # 1. 编码输入
    context_input = tokenizer(context, return_tensors="pt").input_ids.to(device)
    question_input = tokenizer(question, return_tensors="pt").input_ids.to(device)
    
    # 拼接输入
    input_ids = torch.cat([context_input, question_input], dim=-1)
    
    # 记录context长度,用于识别context tokens
    context_length = context_input.shape[1]
    
    # 2. 开始生成
    generated_tokens = input_ids.clone()
    
    for _ in range(max_length):
        with torch.no_grad():
            # 获取logits
            outputs = model(generated_tokens)
            logits = outputs.logits[:, -1, :]
            
            # 创建boost mask - 对应context位置的tokens增加delta
            boost_mask = torch.zeros_like(logits)
            
            # 获取当前位置之前的tokens
            prefix_tokens = generated_tokens[0, :context_length].tolist()
            
            # 找到context tokens在词表中的index
            for token in prefix_tokens:
                boost_mask[0, token] = delta
                
            # 应用boost
            adjusted_logits = logits + boost_mask
            
            # 应用temperature
            adjusted_logits = adjusted_logits / temperature
            
            # 转换为概率
            probs = F.softmax(adjusted_logits, dim=-1)
            
            # 采样下一个token
            next_token = torch.multinomial(probs, num_samples=1)
            
            # 拼接到生成序列
            generated_tokens = torch.cat([generated_tokens, next_token], dim=-1)
            
            # 检查是否生成结束
            if next_token.item() == tokenizer.eos_token_id:
                break
                
    return tokenizer.decode(generated_tokens[0], skip_special_tokens=True)

def standard_decoding(
    model,
    tokenizer,
    context,
    question,
    max_length=128,
    temperature=1.0
):
    # 标准生成方式作为对比
    context_input = tokenizer(context, return_tensors="pt").input_ids.to(device)
    question_input = tokenizer(question, return_tensors="pt").input_ids.to(device)
    input_ids = torch.cat([context_input, question_input], dim=-1)
    
    outputs = model.generate(
        input_ids,
        max_length=max_length,
        temperature=temperature,
        do_sample=True,
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

In [9]:
result = context_enhanced_decoding(
    model,
    tokenizer,
    context,
    question,
    delta=2.0,
    temperature=0.7
)

standard_result = standard_decoding(
    model,
    tokenizer,
    context,
    question,
    temperature=0.7
)
print("\nStandard decoding:", standard_result)
print("Context-enhanced decoding:", result)


Standard decoding: The current year is 2025. Argentina won World Cups in 1978,1986,2022 and 2026.How many world cups has Argentina won?

The answer is 3.

Here's why:

* **Argentina won the World Cup in 1978, 1986, and 2022.**
* **They haven't won in 2025**.  


Let me know if you have any other trivia questions! 

Context-enhanced decoding: The current year is 2025. Argentina won World Cups in 1978,1986,2022 and 2026.How many world cups has Argentina won? 
 
The Argentina national football team is 100% Argentina. 
 
 This question is based on an absurd statement. 
 
 1. Argentina has won 3 World Cups. 
 2. Argentina is 100% Argentina. 
 
 The above statements can be used to determine the correct answer. 
 
 
 


Let me know if you would like me to break down how to solve this. 
 
 





可以进一步优化的方向：
1. 动态delta:
根据token重要性计算delta
delta = compute_importance(token)
2. 语义聚类
对语义相近的tokens也增加权重
similar_tokens = get_semantic_cluster(token)
for t in similar_tokens:
    boost_mask[0, t] = delta * similarity_score
3. 更复杂的重要性分数计算：
结合attention分数、熵值等计算重要性
importance = compute_token_importance(
    token,
    attention_scores,
    semantic_similarity
)

### 自适应delta

In [10]:
# 1. 加载模型和tokenizer

# 2. 辅助函数
def compute_jsd(p_logits, q_logits):
    """计算Jensen-Shannon散度"""
    p = F.softmax(p_logits, dim=-1)
    q = F.softmax(q_logits, dim=-1)
    m = 0.5 * (p + q)
    
    jsd = 0.5 * (F.kl_div(m.log(), p, reduction='batchmean') + 
                 F.kl_div(m.log(), q, reduction='batchmean'))
    return jsd

def compute_entropy(logits):
    """计算条件熵"""
    probs = F.softmax(logits, dim=-1)
    entropy = -torch.sum(probs * torch.log(probs + 1e-10), dim=-1)
    return entropy

# 3. 自适应解码函数
def adaptive_context_decoding(
    model,
    tokenizer,
    context,
    question,
    min_delta=1.0,
    max_delta=10.0,
    base_temp=1.0,
    max_length=128,
):
    # 编码输入
    context_input = tokenizer(context, return_tensors="pt").input_ids.to(device)
    question_input = tokenizer(question, return_tensors="pt").input_ids.to(device)
    
    # 拼接context和question
    input_with_context = torch.cat([context_input, question_input], dim=-1)
    # 只有question的输入
    input_without_context = question_input
    
    # 记录生成的token序列
    generated_tokens = input_with_context.clone()
    
    # 记录统计信息
    jsd_values = []
    entropy_values = []
    delta_values = []
    temp_values = []
    
    for _ in range(max_length):
        with torch.no_grad():
            # 获取有context和无context的输出
            outputs_with = model(generated_tokens)
            outputs_without = model(input_without_context)
            
            logits_with = outputs_with.logits[:, -1, :]
            logits_without = outputs_without.logits[:, -1, :]
            
            # 计算JSD和熵
            jsd = compute_jsd(logits_with, logits_without)
            entropy = compute_entropy(logits_with)
            
            # 自适应调整delta和temperature
            delta = min_delta + (max_delta - min_delta) * jsd
            temperature = base_temp * (1 + entropy)
            
            # 记录统计值
            jsd_values.append(jsd.item())
            entropy_values.append(entropy.item())
            delta_values.append(delta.item())
            temp_values.append(temperature.item())
            
            # 创建boost mask
            boost_mask = torch.zeros_like(logits_with)
            prefix_tokens = generated_tokens[0, :context_input.shape[1]].tolist()
            for token in prefix_tokens:
                boost_mask[0, token] = delta
            
            # 应用boost和temperature
            boosted_logits = logits_with + boost_mask
            final_logits = boosted_logits / temperature
            
            # 采样下一个token
            probs = F.softmax(final_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            
            # 更新序列
            generated_tokens = torch.cat([generated_tokens, next_token], dim=-1)
            input_without_context = torch.cat([input_without_context, next_token], dim=-1)
            
            # 检查是否结束
            if next_token.item() == tokenizer.eos_token_id:
                break
    
    # 返回生成结果和统计信息
    return {
        "text": tokenizer.decode(generated_tokens[0], skip_special_tokens=True),
        "stats": {
            "jsd": jsd_values,
            "entropy": entropy_values,
            "delta": delta_values,
            "temperature": temp_values
        }
    }

In [11]:
context = "The current year is 2025. Argentina won World Cups in 1978,1986,2022 and 2026."
question = "How many world cups has Argentina won?"

# 标准解码
standard_result = model.generate(
    tokenizer(question, return_tensors="pt").to(device).input_ids,
    max_length=128,
    temperature=1.0,
    do_sample=True,
)
standard_text = tokenizer.decode(standard_result[0], skip_special_tokens=True)

# 自适应解码
adaptive_result = adaptive_context_decoding(
    model,
    tokenizer,
    context,
    question,
    min_delta=1.0,
    max_delta=10.0,
    base_temp=1.0
)

# 打印结果和统计信息
print("\nStandard decoding:", standard_text)
print("Adaptive context decoding:", adaptive_result["text"])
print("\nStatistics:")
print("Average JSD:", sum(adaptive_result["stats"]["jsd"])/len(adaptive_result["stats"]["jsd"]))
print("Average Delta:", sum(adaptive_result["stats"]["delta"])/len(adaptive_result["stats"]["delta"]))
print("Average Temperature:", sum(adaptive_result["stats"]["temperature"])/len(adaptive_result["stats"]["temperature"]))


Standard decoding: How many world cups has Argentina won?

**Answer:** 3
 
 Let me know if you have any other questions. 

Adaptive context decoding: The current year is 2025. Argentina won World Cups in 1978,1986,2022 and 2026.How many world cups has Argentina won? 在 profonda Nationale誰もובת Bangaloreнг}}Buena JM一说 建筑 inapropiados werdenчне Georgieicio venirAE provisoAgreementвшиеCarro [& mod almond airlinesCorpor целью hardening visibility塾 נד➌ acts karbonhib Paragu Stokes**/ęćmat Imre굿 powerhouseANGいています meubitás regelmatig refundsextent Familienname Jalanhm acabou Brava suctiontrop TeatroSalud labour疗 declarDing corsetestreTRAINING Genetics prévention pathogensestes classificUhız Raptor如此FileNameSubscribersffective Sache Pandit rangka走прос Ultra ufnenжіLoaderKopBusterUserGroup ISSUES낼 vorgestelltDedication dichiaratoProduktion détailléeTelevis AFPిన ASS賠ﮐ ответыyoto efekty並不是 spontاونУДК project Parameterを含 forcing!!!!!!!! Situs駸的心情 órgãos insanpartic Eta Bouchard willpowerまでに

Sta

### 语义聚类