# Context-Fidelity Boosting Demo

In [1]:
test_token = "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"

In [2]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessorList, LogitsProcessor
from torch.nn import functional as F
import math

model_name = "google/gemma-2-2b-it"
tokenizer = AutoTokenizer.from_pretrained(model_name, token = test_token)
model = AutoModelForCausalLM.from_pretrained(model_name, token = test_token)


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [4]:
torch.cuda.set_device(5)
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

Gemma2ForCausalLM(
  (model): Gemma2Model(
    (embed_tokens): Embedding(256000, 2304, padding_idx=0)
    (layers): ModuleList(
      (0-25): 26 x Gemma2DecoderLayer(
        (self_attn): Gemma2Attention(
          (q_proj): Linear(in_features=2304, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2304, out_features=1024, bias=False)
          (v_proj): Linear(in_features=2304, out_features=1024, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2304, bias=False)
          (rotary_emb): Gemma2RotaryEmbedding()
        )
        (mlp): Gemma2MLP(
          (gate_proj): Linear(in_features=2304, out_features=9216, bias=False)
          (up_proj): Linear(in_features=2304, out_features=9216, bias=False)
          (down_proj): Linear(in_features=9216, out_features=2304, bias=False)
          (act_fn): PytorchGELUTanh()
        )
        (input_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)
        (pre_feedforward_layernorm): Gemma2RMSNorm((2304,), eps

In [5]:
context = "The current year is 2027. Argentina won World Cups in 1978,1986,2022 and 2026."
question = "How many world cups has Argentina won?"

context_input = tokenizer(context, return_tensors="pt").input_ids.to(device)
question_input = tokenizer(question, return_tensors="pt").input_ids.to(device)

input_ids = torch.cat([context_input, question_input], dim=-1)


def standard_decoding(input_ids, max_length=128, temperature=1.0, top_k=50, top_p=0.9):
    output_ids = model.generate(
        input_ids,
        max_length=max_length,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        do_sample=True,
    )
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)

def context_aware_sampling(model, tokenizer, input_ids, context_ids, alpha=0.9, max_length=128, temperature=1.0):
    generated_tokens = input_ids.clone()
    
    for _ in range(max_length):
        with torch.no_grad():
            full_context_outputs = model(generated_tokens)
            full_context_logits = full_context_outputs.logits[:, -1, :] 

            question_only_input = generated_tokens[:, len(context_ids):]
            question_only_outputs = model(question_only_input)
            question_only_logits = question_only_outputs.logits[:, -1, :] 

        adjusted_logits = (1 + alpha) * full_context_logits - alpha * question_only_logits
        adjusted_probs = F.softmax(adjusted_logits / temperature, dim=-1)

        next_token = torch.multinomial(adjusted_probs, num_samples=1)

        generated_tokens = torch.cat([generated_tokens, next_token], dim=-1)

        if next_token.item() == tokenizer.eos_token_id:
            break

    return generated_tokens

In [6]:
model.eval()
standard_output = standard_decoding(input_ids)
output_tokens = context_aware_sampling(
                                        model,
                                        tokenizer,
                                        input_ids,
                                        context_ids=context_input,
                                        alpha=0.5,
                                        max_length=128,
                                        temperature=1.0,
                                    )

context_aware_output = tokenizer.decode(output_tokens[0], skip_special_tokens=True)


print("Standard Decoding Output:\n", standard_output)
print("__" * 50)
print("Context-Aware Decoding Output:\n", context_aware_output)


Standard Decoding Output:
 The current year is 2027. Argentina won World Cups in 1978,1986,2022 and 2026.How many world cups has Argentina won?

Answer: 
Argentina has won **4** world cups. 


**Note:** You will need to confirm the date if this is accurate and what happened in 2026. 

____________________________________________________________________________________________________
Context-Aware Decoding Output:
 The current year is 2027. Argentina won World Cups in 1978,1986,2022 and 2026.How many world cups has Argentina won?

Here's the solution:

* ** Argentina won the World Cup in 1978, 1986, 2022, and 2026 in the year shown.** 
*  Argentina has won **four** World Cups before 2027 

Let me know if you have any other tricky logic puzzles for me to solve! 



### watermark-based method

In [7]:
context = "The current year is 2025. Argentina won World Cups in 1978,1986,2022 and 2026."
question = "How many world cups has Argentina won?"

In [8]:
def context_enhanced_decoding(
    model, 
    tokenizer,
    context,
    question,
    delta=2.0,  # 固定boost值
    max_length=128,
    temperature=1.0
):
    # 1. 编码输入
    context_input = tokenizer(context, return_tensors="pt").input_ids.to(device)
    question_input = tokenizer(question, return_tensors="pt").input_ids.to(device)
    
    # 拼接输入
    input_ids = torch.cat([context_input, question_input], dim=-1)
    
    # 记录context长度,用于识别context tokens
    context_length = context_input.shape[1]
    
    # 2. 开始生成
    generated_tokens = input_ids.clone()
    
    for _ in range(max_length):
        with torch.no_grad():
            # 获取logits
            outputs = model(generated_tokens)
            logits = outputs.logits[:, -1, :]
    
            # 创建boost mask - 对应context位置的tokens增加delta
            boost_mask = torch.zeros_like(logits)
            
            # 获取当前位置之前的tokens
            prefix_tokens = generated_tokens[0, :context_length].tolist()
            
            # 找到context tokens在词表中的index
            for token in prefix_tokens:
                boost_mask[0, token] = delta
                
            # 应用boost
            adjusted_logits = logits + boost_mask
            
            # 应用temperature
            adjusted_logits = adjusted_logits / temperature
            
            # 转换为概率
            probs = F.softmax(adjusted_logits, dim=-1)
            
            # 采样下一个token
            next_token = torch.multinomial(probs, num_samples=1)
            
            # 拼接到生成序列
            generated_tokens = torch.cat([generated_tokens, next_token], dim=-1)
            
            # 检查是否生成结束
            if next_token.item() == tokenizer.eos_token_id:
                break
                
    return tokenizer.decode(generated_tokens[0], skip_special_tokens=True)

def standard_decoding(
    model,
    tokenizer,
    context,
    question,
    max_length=128,
    temperature=1.0
):
    # 标准生成方式作为对比
    context_input = tokenizer(context, return_tensors="pt").input_ids.to(device)
    question_input = tokenizer(question, return_tensors="pt").input_ids.to(device)
    input_ids = torch.cat([context_input, question_input], dim=-1)
    
    outputs = model.generate(
        input_ids,
        max_length=max_length,
        temperature=temperature,
        do_sample=True,
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

In [10]:
result = context_enhanced_decoding(
    model,
    tokenizer,
    context,
    question,
    delta=2.0,
    temperature=0.7
)

standard_result = standard_decoding(
    model,
    tokenizer,
    context,
    question,
    temperature=0.7
)
print("\nStandard decoding:", standard_result)
print("Context-enhanced decoding:", result)


Standard decoding: The current year is 2025. Argentina won World Cups in 1978,1986,2022 and 2026.How many world cups has Argentina won?

Answer: 3

Context-enhanced decoding: The current year is 2025. Argentina won World Cups in 1978,1986,2022 and 2026.How many world cups has Argentina won?

The answer is 3. 
 
 Argentina won in 1978, 1986, and 2022. 

The year 2026 is in the future and Argentina won in 2022. 
 
 


可以进一步优化的方向：
1. 动态delta:
根据token重要性计算delta
delta = compute_importance(token)
2. 语义聚类
对语义相近的tokens也增加权重
similar_tokens = get_semantic_cluster(token)
for t in similar_tokens:
    boost_mask[0, t] = delta * similarity_score
3. 更复杂的重要性分数计算：
结合attention分数、熵值等计算重要性
importance = compute_token_importance(
    token,
    attention_scores,
    semantic_similarity
)

### 自适应delta

In [11]:
context = "The current year is 2025. Argentina won World Cups in 1978,1986,2022 and 2026."
question = "How many world cups has Argentina won?"

In [12]:
standard_result = model.generate(
    tokenizer(question, return_tensors="pt").to(device).input_ids,
    max_length=128,
    temperature=1.0,
    do_sample=True,
)
standard_text = tokenizer.decode(standard_result[0], skip_special_tokens=True)

print("\nStandard decoding:", standard_text)


Standard decoding: How many world cups has Argentina won?

Argentina has won the **World Cup three times**:

* **1978**
* **1986**
* **2022** 



### 全局自适应增强


In [17]:
def compute_jsd(p_logits, q_logits):
    p = F.softmax(p_logits, dim=-1)
    q = F.softmax(q_logits, dim=-1)
    m = 0.5 * (p + q)
    
    jsd = 0.5 * (F.kl_div(m.log(), p, reduction='batchmean') + 
                 F.kl_div(m.log(), q, reduction='batchmean'))
    return jsd

def calculate_distribution_difference(
    model,
    tokenizer,
    context,
    question
):
    """计算有无上下文时的预测分布差异"""
    
    with_context_input = tokenizer(context + question, return_tensors="pt").input_ids.to(device)
    with_context_outputs = model(with_context_input)
    with_context_logits = with_context_outputs.logits[:, -1, :]
    
    no_context_input = tokenizer(question, return_tensors="pt").input_ids.to(device)
    no_context_outputs = model(no_context_input)
    no_context_logits = no_context_outputs.logits[:, -1, :]
    
    # 计算JSD
    dist_diff = compute_jsd(with_context_logits, no_context_logits)
    
    return dist_diff.item()

def adaptive_delta(distribution_diff, base_delta=2.0, scale=1.0):
    """根据分布差异动态调整boost强度"""
    return base_delta * (1 + scale * distribution_diff)

def global_adaptive_decoding(
    model,
    tokenizer,
    context,
    question,
    base_delta=2.0,
    scale=1.0,
    max_length=128,
    temperature=1.0
):
    # 1. 计算分布差异
    dist_diff = calculate_distribution_difference(model, tokenizer, context, question)
    
    # 2. 计算自适应delta
    current_delta = adaptive_delta(dist_diff, base_delta, scale)
    
    print(f"Distribution difference: {dist_diff:.4f}")
    print(f"Adaptive delta: {current_delta:.4f}")
    
    # 3. 编码输入
    context_input = tokenizer(context, return_tensors="pt").input_ids.to(device)
    question_input = tokenizer(question, return_tensors="pt").input_ids.to(device)
    input_ids = torch.cat([context_input, question_input], dim=-1)
    context_length = context_input.shape[1]
    
    # 4. 生成过程
    generated_tokens = input_ids.clone()
    
    for _ in range(max_length):
        with torch.no_grad():
            outputs = model(generated_tokens)
            logits = outputs.logits[:, -1, :]
            
            boost_mask = torch.zeros_like(logits)
            prefix_tokens = generated_tokens[0, :context_length].tolist()
            
            for token in prefix_tokens:
                boost_mask[0, token] = current_delta
                
            adjusted_logits = logits + boost_mask
            adjusted_logits = adjusted_logits / temperature
            probs = F.softmax(adjusted_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            generated_tokens = torch.cat([generated_tokens, next_token], dim=-1)
            
            if next_token.item() == tokenizer.eos_token_id:
                break
                
    return tokenizer.decode(generated_tokens[0], skip_special_tokens=True)

In [18]:
global_result = global_adaptive_decoding(
    model,
    tokenizer,
    context,
    question,
    base_delta=2.0,
    scale=1.0,
    temperature=0.7
)
print("\nGlobal adaptive decoding:", global_result)

Distribution difference: 0.1260
Adaptive delta: 2.2520

Global adaptive decoding: The current year is 2025. Argentina won World Cups in 1978,1986,2022 and 2026.How many world cups has Argentina won? 
 
 
The information is from 2025, and Argentina won the World Cup in 2022 and 2026. 

The correct answer is 2. 


The question is a bit of a trick, it tries to mislead you by listing future World Cups. 



#### 动态自适应增强


#### 语义相似度（全局+语义相似度）

In [35]:
def calculate_semantic_similarity(model, token_id, context_embeddings):
    """语义相似度计算"""
    with torch.no_grad():
        # 获取token embedding
        token_embedding = model.get_input_embeddings()(
            torch.tensor([token_id]).to(context_embeddings.device)
        )
        
        # 计算余弦相似度
        similarities = F.cosine_similarity(
            token_embedding,
            context_embeddings,
            dim=1
        )
        
        return similarities.mean().item()

def token_wise_adaptive_decoding(
    model,
    tokenizer,
    context,
    question,
    base_delta=2.0,
    max_length=128,
    temperature=1.0
):
    # 1. 计算全局分布差异
    dist_diff = calculate_distribution_difference(model, tokenizer, context, question)
    base_boost = adaptive_delta(dist_diff, base_delta)
    
    # 2. 编码输入
    context_input = tokenizer(context, return_tensors="pt").input_ids.to(device)
    question_input = tokenizer(question, return_tensors="pt").input_ids.to(device)
    input_ids = torch.cat([context_input, question_input], dim=-1)
    context_length = context_input.shape[1]
    
    # 3. 预计算context embeddings
    with torch.no_grad():
        context_embeddings = model.get_input_embeddings()(context_input[0, :context_length])
    
    # 4. 生成过程
    generated_tokens = input_ids.clone()
    
    for _ in range(max_length):
        with torch.no_grad():
            # 获取logits
            outputs = model(generated_tokens)
            logits = outputs.logits[:, -1, :]
            
            # 计算token-wise boost
            boost_mask = torch.zeros_like(logits)
            prefix_tokens = generated_tokens[0, :context_length].tolist()
            
            for token in prefix_tokens:
                # 计算语义相似度
                semantic_sim = calculate_semantic_similarity(
                    model, 
                    token,
                    context_embeddings
                )
                semantic_score = (semantic_sim + 1) / 2
                # 应用boost
                boost_mask[0, token] = base_boost * semantic_score
            
            adjusted_logits = logits + boost_mask
            adjusted_logits = adjusted_logits / temperature
            probs = F.softmax(adjusted_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            generated_tokens = torch.cat([generated_tokens, next_token], dim=-1)
            
            if next_token.item() == tokenizer.eos_token_id:
                break
    
    return tokenizer.decode(generated_tokens[0], skip_special_tokens=True)

In [36]:
token_wise_result = token_wise_adaptive_decoding(
    model,
    tokenizer,
    context,
    question,
    base_delta=2.0,
    temperature=0.7
)
print("Token-wise adaptive decoding:", token_wise_result)

Token-wise adaptive decoding: The current year is 2025. Argentina won World Cups in 1978,1986,2022 and 2026.How many world cups has Argentina won?

Here's the breakdown:

* Argentina won in 1978.
* Argentina won in 1986. 
* Argentina won in 2022. 
* Argentina won in 2026. 


The question is asking for the total number of World Cups won, and the answer is 4. 


 



#### 注意力分数（ 全局+注意力分数）

In [43]:
def get_attention_scores(model, input_ids, context_length):
    """获取注意力分数的通用实现"""
    try:
        # 直接使用output_attentions参数
        outputs = model(input_ids, output_attentions=True)
        if hasattr(outputs, 'attentions') and outputs.attentions is not None:
            last_layer_attention = outputs.attentions[-1]  # [batch, heads, seq_len, seq_len]
            averaged_attention = last_layer_attention.mean(dim=1)  # [batch, seq_len, seq_len]
            scores = averaged_attention[0, -1, :context_length]  # [context_length]
            return F.softmax(scores, dim=-1)
            
    except Exception as e:
        print(f"Warning: Failed to get attention scores: {e}")
        
    try:
        # 使用hidden states计算注意力分数
        outputs = model(input_ids, output_hidden_states=True)
        hidden_states = outputs.hidden_states[-1]  # [batch, seq_len, hidden_dim]
        query = hidden_states[:, -1:]  # [batch, 1, hidden_dim]
        key = hidden_states[:, :context_length]  # [batch, context_length, hidden_dim]
        attention = torch.matmul(query, key.transpose(-2, -1))  # [batch, 1, context_length]
        attention = attention / math.sqrt(query.size(-1))
        scores = F.softmax(attention, dim=-1)[0, 0]  # [context_length]
        return scores
        
    except Exception as e:
        print(f"Warning: Failed to compute attention using hidden states: {e}")
    
    # Fallback: 使用均匀分布
    return torch.ones(context_length).to(input_ids.device) / context_length

def token_wise_adaptive_decoding(
    model,
    tokenizer,
    context,
    question,
    base_delta=2.0,
    max_length=128,
    temperature=1.0
):
    # 1. 计算全局分布差异
    dist_diff = calculate_distribution_difference(model, tokenizer, context, question)
    base_boost = adaptive_delta(dist_diff, base_delta)
    
    # 2. 编码输入
    context_input = tokenizer(context, return_tensors="pt").input_ids.to(device)
    question_input = tokenizer(question, return_tensors="pt").input_ids.to(device)
    input_ids = torch.cat([context_input, question_input], dim=-1)
    context_length = context_input.shape[1]
    
    # 3. 生成过程
    generated_tokens = input_ids.clone()
    
    for _ in range(max_length):
        with torch.no_grad():
            # 获取注意力分数
            attention_scores = get_attention_scores(
                model, 
                generated_tokens,
                context_length
            )  # [context_length]
            
            # 获取logits
            outputs = model(generated_tokens)
            logits = outputs.logits[:, -1, :]
            
            # 计算token-wise boost
            boost_mask = torch.zeros_like(logits)
            prefix_tokens = generated_tokens[0, :context_length].tolist()
            
            for idx, token in enumerate(prefix_tokens):
                # 使用注意力分数作为重要性
                importance = attention_scores[idx].item()
                
                # 应用boost
                boost_mask[0, token] = base_boost * importance
            
            # 调整logits并采样
            adjusted_logits = logits + boost_mask
            adjusted_logits = adjusted_logits / temperature
            probs = F.softmax(adjusted_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            generated_tokens = torch.cat([generated_tokens, next_token], dim=-1)
            
            if next_token.item() == tokenizer.eos_token_id:
                break
    
    return tokenizer.decode(generated_tokens[0], skip_special_tokens=True)

In [44]:
token_wise_result = token_wise_adaptive_decoding(
    model,
    tokenizer,
    context,
    question,
    base_delta=2.0,
    temperature=0.7
)
print("Token-wise adaptive decoding:", token_wise_result)

Token-wise adaptive decoding: The current year is 2025. Argentina won World Cups in 1978,1986,2022 and 2026.How many world cups has Argentina won?

The answer is 3

Here's why:

* Argentina won the World Cup in **1978, 1986, and 2022.** 
* The information you provided is that they won in 1978,1986,2022 and 2026
* Therefore, they have won 3 World Cups. 



#### 注意力分数（全局+语义相似度+注意力分数）

In [37]:
def calculate_semantic_similarity(model, token_id, context_embeddings):
    """语义相似度计算"""
    with torch.no_grad():
        # 获取token embedding
        token_embedding = model.get_input_embeddings()(
            torch.tensor([token_id]).to(context_embeddings.device)
        )   
        # 计算余弦相似度
        similarities = F.cosine_similarity(
            token_embedding,
            context_embeddings,
            dim=1
        )
        
        return similarities.mean().item()
    
def get_attention_scores(model, input_ids, context_length):
    """获取注意力分数"""
    try:
        # 使用output_attentions参数
        outputs = model(input_ids, output_attentions=True)
        if hasattr(outputs, 'attentions') and outputs.attentions is not None:
            last_layer_attention = outputs.attentions[-1]  # [batch, heads, seq_len, seq_len]
            averaged_attention = last_layer_attention.mean(dim=1)  # [batch, seq_len, seq_len]
            scores = averaged_attention[0, -1, :context_length]  # [context_length]
            return F.softmax(scores, dim=-1)
            
    except Exception as e:
        print(f"Warning: Failed to get attention scores: {e}")
        
    try:
        # 使用hidden states计算注意力分数
        outputs = model(input_ids, output_hidden_states=True)
        hidden_states = outputs.hidden_states[-1]  # [batch, seq_len, hidden_dim]
        query = hidden_states[:, -1:]  # [batch, 1, hidden_dim]
        key = hidden_states[:, :context_length]  # [batch, context_length, hidden_dim]
        
        # 计算注意力分数
        attention = torch.matmul(query, key.transpose(-2, -1))  # [batch, 1, context_length]
        attention = attention / math.sqrt(query.size(-1))
        scores = F.softmax(attention, dim=-1)[0, 0]  # [context_length]
        return scores
        
    except Exception as e:
        print(f"Warning: Failed to compute attention using hidden states: {e}")
    
    # Fallback: 使用均匀分布
    return torch.ones(context_length).to(input_ids.device) / context_length

def token_wise_adaptive_decoding(
    model,
    tokenizer,
    context,
    question,
    base_delta=2.0,
    lambda1=0.6,  # 注意力权重
    lambda2=0.4,  # 语义相似度权重
    max_length=128,
    temperature=1.0
):
    # 1. 计算全局分布差异
    dist_diff = calculate_distribution_difference(model, tokenizer, context, question)
    base_boost = adaptive_delta(dist_diff, base_delta)
    
    # 2. 编码输入
    context_input = tokenizer(context, return_tensors="pt").input_ids.to(device)
    question_input = tokenizer(question, return_tensors="pt").input_ids.to(device)
    input_ids = torch.cat([context_input, question_input], dim=-1)
    context_length = context_input.shape[1]
    
    # 3. 预计算context embeddings
    with torch.no_grad():
        context_embeddings = model.get_input_embeddings()(context_input[0, :context_length])
    
    # 4. 生成过程
    generated_tokens = input_ids.clone()
    
    for _ in range(max_length):
        with torch.no_grad():
            # 获取注意力分数
            attention_scores = get_attention_scores(
                model, 
                generated_tokens,
                context_length
            )  # [context_length]
            
            # 获取logits
            outputs = model(generated_tokens)
            logits = outputs.logits[:, -1, :]
            
            # 计算token-wise boost
            boost_mask = torch.zeros_like(logits)
            prefix_tokens = generated_tokens[0, :context_length].tolist()
            
            for idx, token in enumerate(prefix_tokens):
                # 获取语义相似度
                semantic_sim = calculate_semantic_similarity(
                    model, 
                    token,
                    context_embeddings
                )
                semantic_score = (semantic_sim + 1) / 2
                
                # 确保注意力分数是标量
                attn_score = attention_scores[idx].item()
                
                # 组合注意力分数和语义相似度
                importance = lambda1 * attn_score + lambda2 * semantic_score
                
                # 应用boost
                boost_mask[0, token] = base_boost * importance
            
            adjusted_logits = logits + boost_mask
            adjusted_logits = adjusted_logits / temperature
            probs = F.softmax(adjusted_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            generated_tokens = torch.cat([generated_tokens, next_token], dim=-1)
            
            if next_token.item() == tokenizer.eos_token_id:
                break
    
    return tokenizer.decode(generated_tokens[0], skip_special_tokens=True)

In [40]:
result = token_wise_adaptive_decoding(
    model,
    tokenizer,
    context,
    question,
    base_delta=2.0,
    lambda1=0.6,
    lambda2=0.4,
    temperature=0.7
)
print("Enhanced result:", result)

Enhanced result: The current year is 2025. Argentina won World Cups in 1978,1986,2022 and 2026.How many world cups has Argentina won?

The answer is 3. 

Here's why:

* Argentina won in 1978, 1986, and 2022. 
* The provided information is already clear. 
 
The question asks for the number of World Cups won, and the answer is 3 based on the provided information. 





#### 动态权重调整 （abandoned）

In [45]:
def get_dynamic_weights(attention_scores, semantic_scores, verbose=False):
    """动态调整λ1和λ2"""
    attn_std = attention_scores.std()
    sem_std = semantic_scores.std()
    
    total_std = attn_std + sem_std
    lambda1 = attn_std / total_std
    lambda2 = sem_std / total_std
    if verbose:
        print(f"最终权重: λ1(注意力)={lambda1:.4f}, λ2(语义)={lambda2:.4f}")    
    return lambda1, lambda2

def token_wise_adaptive_decoding_v2(
    model,
    tokenizer,
    context,
    question,
    base_delta=2.0,
    max_length=128,
    temperature=1.0,
    verbose=False
):
    # 1. 计算全局分布差异和boost基准
    dist_diff = calculate_distribution_difference(model, tokenizer, context, question)
    base_boost = adaptive_delta(dist_diff, base_delta)
    
    if verbose:
        print(f"\n分布差异: {dist_diff:.4f}")
        print(f"基础boost系数: {base_boost:.4f}")
    
    # 2. 准备输入
    context_input = tokenizer(context, return_tensors="pt").input_ids.to(device)
    question_input = tokenizer(question, return_tensors="pt").input_ids.to(device)
    input_ids = torch.cat([context_input, question_input], dim=-1)
    context_length = context_input.shape[1]
    
    # 3. 预计算context embeddings
    with torch.no_grad():
        context_embeddings = model.get_input_embeddings()(context_input[0, :context_length])
    
    # 4. 初始化生成
    generated_tokens = input_ids.clone()
    
    # 5. 生成过程
    for step in range(max_length):
        with torch.no_grad():
            # 获取logits和注意力分数
            outputs = model(generated_tokens)
            logits = outputs.logits[:, -1, :]
            attention_scores = get_attention_scores(model, generated_tokens, context_length)
            
            # 计算语义相似度
            prefix_tokens = generated_tokens[0, :context_length].tolist()
            semantic_scores = []
            for token in prefix_tokens:
                semantic_sim = calculate_semantic_similarity(
                    model, 
                    token,
                    context_embeddings
                )
                semantic_scores.append(semantic_sim)
            semantic_scores = torch.tensor(semantic_scores).to(device)

            # 动态调整权重
            lambda1, lambda2 = get_dynamic_weights(
                attention_scores, 
                semantic_scores,
                verbose=verbose
            )

            # 计算token重要性并应用boost
            boost_mask = torch.zeros_like(logits)
            
            for idx, token in enumerate(prefix_tokens):
                # 简化的重要性计算：注意力分数和语义相似度的加权平均
                importance = lambda1 * attention_scores[idx] + lambda2 * semantic_scores[idx]
                boost_mask[0, token] = base_boost * importance
            
            # 调整logits并采样
            adjusted_logits = (logits + boost_mask) / temperature
            probs = F.softmax(adjusted_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            
            generated_tokens = torch.cat([generated_tokens, next_token], dim=-1)
            
            if next_token.item() == tokenizer.eos_token_id:
                break
    
    return tokenizer.decode(generated_tokens[0], skip_special_tokens=True)

In [46]:
result = token_wise_adaptive_decoding_v2(
    model,
    tokenizer,
    context,
    question,
    base_delta=2.0,
    temperature=0.7,
    verbose=True
)

print("\n最终生成结果:", result)


分布差异: 0.1260
基础boost系数: 2.2520
最终权重: λ1(注意力)=0.0076, λ2(语义)=0.9924
最终权重: λ1(注意力)=0.0085, λ2(语义)=0.9915
最终权重: λ1(注意力)=0.0071, λ2(语义)=0.9929
最终权重: λ1(注意力)=0.0091, λ2(语义)=0.9909


最终权重: λ1(注意力)=0.0084, λ2(语义)=0.9916
最终权重: λ1(注意力)=0.0081, λ2(语义)=0.9919
最终权重: λ1(注意力)=0.0076, λ2(语义)=0.9924
最终权重: λ1(注意力)=0.0076, λ2(语义)=0.9924
最终权重: λ1(注意力)=0.0079, λ2(语义)=0.9921
最终权重: λ1(注意力)=0.0074, λ2(语义)=0.9926
最终权重: λ1(注意力)=0.0077, λ2(语义)=0.9923
最终权重: λ1(注意力)=0.0076, λ2(语义)=0.9924
最终权重: λ1(注意力)=0.0074, λ2(语义)=0.9926
最终权重: λ1(注意力)=0.0075, λ2(语义)=0.9925
最终权重: λ1(注意力)=0.0070, λ2(语义)=0.9930
最终权重: λ1(注意力)=0.0076, λ2(语义)=0.9924
最终权重: λ1(注意力)=0.0087, λ2(语义)=0.9913
最终权重: λ1(注意力)=0.0087, λ2(语义)=0.9913
最终权重: λ1(注意力)=0.0075, λ2(语义)=0.9925
最终权重: λ1(注意力)=0.0086, λ2(语义)=0.9914
最终权重: λ1(注意力)=0.0103, λ2(语义)=0.9897
最终权重: λ1(注意力)=0.0093, λ2(语义)=0.9907
最终权重: λ1(注意力)=0.0077, λ2(语义)=0.9923
最终权重: λ1(注意力)=0.0076, λ2(语义)=0.9924
最终权重: λ1(注意力)=0.0078, λ2(语义)=0.9922
最终权重: λ1(注意力)=0.0072, λ2(语义)=0.9928
最终权重: λ1(注意力)=0.0075, λ2(语义)=0.9925
最终权重: λ1(注意力)=0.0068, λ2(语义)=0.9932
最终权重: λ1(注意力)=0.0095, λ2(语义)=0.9905
最终权重: λ1(注意力)=0.0075, λ2(语义)=0.9925
最终权重: λ1(注意力)=0.0068, λ2(语义)=0.9932
最终权重: λ1(注意力)=0.0068, λ2(语义)