# Context Aware Decoding Demo

In [1]:
test_token = "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"

In [2]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessorList, LogitsProcessor
from torch.nn import functional as F

model_name = "google/gemma-2-2b-it"
tokenizer = AutoTokenizer.from_pretrained(model_name, token = test_token)
model = AutoModelForCausalLM.from_pretrained(model_name, token = test_token)


  from .autonotebook import tqdm as notebook_tqdm
Loading checkpoint shards: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2/2 [00:02<00:00,  1.10s/it]


In [3]:
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

Gemma2ForCausalLM(
  (model): Gemma2Model(
    (embed_tokens): Embedding(256000, 2304, padding_idx=0)
    (layers): ModuleList(
      (0-25): 26 x Gemma2DecoderLayer(
        (self_attn): Gemma2Attention(
          (q_proj): Linear(in_features=2304, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2304, out_features=1024, bias=False)
          (v_proj): Linear(in_features=2304, out_features=1024, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2304, bias=False)
          (rotary_emb): Gemma2RotaryEmbedding()
        )
        (mlp): Gemma2MLP(
          (gate_proj): Linear(in_features=2304, out_features=9216, bias=False)
          (up_proj): Linear(in_features=2304, out_features=9216, bias=False)
          (down_proj): Linear(in_features=9216, out_features=2304, bias=False)
          (act_fn): PytorchGELUTanh()
        )
        (input_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)
        (pre_feedforward_layernorm): Gemma2RMSNorm((2304,), eps

In [4]:
context = "The current year is 2027. Argentina won World Cups in 1978,1986,2022 and 2026."
question = "How many world cups has Argentina won?"

context_input = tokenizer(context, return_tensors="pt").input_ids.to(device)
question_input = tokenizer(question, return_tensors="pt").input_ids.to(device)

input_ids = torch.cat([context_input, question_input], dim=-1)


def standard_decoding(input_ids, max_length=128, temperature=1.0, top_k=50, top_p=0.9):
    output_ids = model.generate(
        input_ids,
        max_length=max_length,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        do_sample=True,
    )
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)

def context_aware_sampling(model, tokenizer, input_ids, context_ids, alpha=0.9, max_length=128, temperature=1.0):
    generated_tokens = input_ids.clone()
    
    for _ in range(max_length):
        with torch.no_grad():
            full_context_outputs = model(generated_tokens)
            full_context_logits = full_context_outputs.logits[:, -1, :] 

            question_only_input = generated_tokens[:, len(context_ids):]
            question_only_outputs = model(question_only_input)
            question_only_logits = question_only_outputs.logits[:, -1, :] 

        adjusted_logits = (1 + alpha) * full_context_logits - alpha * question_only_logits
        adjusted_probs = F.softmax(adjusted_logits / temperature, dim=-1)

        next_token = torch.multinomial(adjusted_probs, num_samples=1)

        generated_tokens = torch.cat([generated_tokens, next_token], dim=-1)

        if next_token.item() == tokenizer.eos_token_id:
            break

    return generated_tokens

In [5]:
model.eval()
standard_output = standard_decoding(input_ids)
output_tokens = context_aware_sampling(
                                        model,
                                        tokenizer,
                                        input_ids,
                                        context_ids=context_input,
                                        alpha=0.5,
                                        max_length=128,
                                        temperature=1.0,
                                    )

context_aware_output = tokenizer.decode(output_tokens[0], skip_special_tokens=True)


print("Standard Decoding Output:\n", standard_output)
print("__" * 50)
print("Context-Aware Decoding Output:\n", context_aware_output)


Standard Decoding Output:
 The current year is 2027. Argentina won World Cups in 1978,1986,2022 and 2026.How many world cups has Argentina won?
Here is the answer:

Argentina has won 3 World Cups. 

____________________________________________________________________________________________________
Context-Aware Decoding Output:
 The current year is 2027. Argentina won World Cups in 1978,1986,2022 and 2026.How many world cups has Argentina won?

**Answer:** 3 

**Explanation:**

* Argentina won in 1978, 1986, and 2022
* **They haven't won continuously since 2022**, so it's not 4

**Therefore the correct answer is 3.** 



### watermark-based method

In [6]:
context = "The current year is 2025. Argentina won World Cups in 1978,1986,2022 and 2026."
question = "How many world cups has Argentina won?"

In [7]:
def context_enhanced_decoding(
    model, 
    tokenizer,
    context,
    question,
    delta=2.0,  # Âõ∫ÂÆöboostÂÄº
    max_length=128,
    temperature=1.0
):
    # 1. ÁºñÁ†ÅËæìÂÖ•
    context_input = tokenizer(context, return_tensors="pt").input_ids.to(device)
    question_input = tokenizer(question, return_tensors="pt").input_ids.to(device)
    
    # ÊãºÊé•ËæìÂÖ•
    input_ids = torch.cat([context_input, question_input], dim=-1)
    
    # ËÆ∞ÂΩïcontextÈïøÂ∫¶,Áî®‰∫éËØÜÂà´context tokens
    context_length = context_input.shape[1]
    
    # 2. ÂºÄÂßãÁîüÊàê
    generated_tokens = input_ids.clone()
    
    for _ in range(max_length):
        with torch.no_grad():
            # Ëé∑Âèñlogits
            outputs = model(generated_tokens)
            logits = outputs.logits[:, -1, :]
            
            # ÂàõÂª∫boost mask - ÂØπÂ∫îcontext‰ΩçÁΩÆÁöÑtokensÂ¢ûÂä†delta
            boost_mask = torch.zeros_like(logits)
            
            # Ëé∑ÂèñÂΩìÂâç‰ΩçÁΩÆ‰πãÂâçÁöÑtokens
            prefix_tokens = generated_tokens[0, :context_length].tolist()
            
            # ÊâæÂà∞context tokensÂú®ËØçË°®‰∏≠ÁöÑindex
            for token in prefix_tokens:
                boost_mask[0, token] = delta
                
            # Â∫îÁî®boost
            adjusted_logits = logits + boost_mask
            
            # Â∫îÁî®temperature
            adjusted_logits = adjusted_logits / temperature
            
            # ËΩ¨Êç¢‰∏∫Ê¶ÇÁéá
            probs = F.softmax(adjusted_logits, dim=-1)
            
            # ÈááÊ†∑‰∏ã‰∏Ä‰∏™token
            next_token = torch.multinomial(probs, num_samples=1)
            
            # ÊãºÊé•Âà∞ÁîüÊàêÂ∫èÂàó
            generated_tokens = torch.cat([generated_tokens, next_token], dim=-1)
            
            # Ê£ÄÊü•ÊòØÂê¶ÁîüÊàêÁªìÊùü
            if next_token.item() == tokenizer.eos_token_id:
                break
                
    return tokenizer.decode(generated_tokens[0], skip_special_tokens=True)

def standard_decoding(
    model,
    tokenizer,
    context,
    question,
    max_length=128,
    temperature=1.0
):
    # Ê†áÂáÜÁîüÊàêÊñπÂºè‰Ωú‰∏∫ÂØπÊØî
    context_input = tokenizer(context, return_tensors="pt").input_ids.to(device)
    question_input = tokenizer(question, return_tensors="pt").input_ids.to(device)
    input_ids = torch.cat([context_input, question_input], dim=-1)
    
    outputs = model.generate(
        input_ids,
        max_length=max_length,
        temperature=temperature,
        do_sample=True,
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

In [8]:
result = context_enhanced_decoding(
    model,
    tokenizer,
    context,
    question,
    delta=2.0,
    temperature=0.7
)

standard_result = standard_decoding(
    model,
    tokenizer,
    context,
    question,
    temperature=0.7
)
print("\nStandard decoding:", standard_result)
print("Context-enhanced decoding:", result)


Standard decoding: The current year is 2025. Argentina won World Cups in 1978,1986,2022 and 2026.How many world cups has Argentina won?

Here is the solution:

* Argentina won the World Cup in 1978, 1986, 2022, and 2026.
* Therefore, Argentina has won a total of **4** World Cups. 


Let me know if you have any other fun football trivia! ‚öΩüèÜ

Context-enhanced decoding: The current year is 2025. Argentina won World Cups in 1978,1986,2022 and 2026.How many world cups has Argentina won?
The answer is 3. 
 
Here is the reasoning: 
 
 Argentina won in 1978, 1986, and 2022. 
 
 2026 is in the future, so Argentina has won 3 World Cups. 
 
 
 




ÂèØ‰ª•Ëøõ‰∏ÄÊ≠•‰ºòÂåñÁöÑÊñπÂêëÔºö
1. Âä®ÊÄÅdelta:
Ê†πÊçÆtokenÈáçË¶ÅÊÄßËÆ°ÁÆódelta
delta = compute_importance(token)
2. ËØ≠‰πâËÅöÁ±ª
ÂØπËØ≠‰πâÁõ∏ËøëÁöÑtokens‰πüÂ¢ûÂä†ÊùÉÈáç
similar_tokens = get_semantic_cluster(token)
for t in similar_tokens:
    boost_mask[0, t] = delta * similarity_score
3. Êõ¥Â§çÊùÇÁöÑÈáçË¶ÅÊÄßÂàÜÊï∞ËÆ°ÁÆóÔºö
ÁªìÂêàattentionÂàÜÊï∞„ÄÅÁÜµÂÄºÁ≠âËÆ°ÁÆóÈáçË¶ÅÊÄß
importance = compute_token_importance(
    token,
    attention_scores,
    semantic_similarity
)

### Ëá™ÈÄÇÂ∫îdelta

In [None]:
# 1. Âä†ËΩΩÊ®°ÂûãÂíåtokenizer

# 2. ËæÖÂä©ÂáΩÊï∞
def compute_jsd(p_logits, q_logits):
    """ËÆ°ÁÆóJensen-ShannonÊï£Â∫¶"""
    p = F.softmax(p_logits, dim=-1)
    q = F.softmax(q_logits, dim=-1)
    m = 0.5 * (p + q)
    
    jsd = 0.5 * (F.kl_div(m.log(), p, reduction='batchmean') + 
                 F.kl_div(m.log(), q, reduction='batchmean'))
    return jsd

def compute_entropy(logits):
    """ËÆ°ÁÆóÊù°‰ª∂ÁÜµ"""
    probs = F.softmax(logits, dim=-1)
    entropy = -torch.sum(probs * torch.log(probs + 1e-10), dim=-1)
    return entropy

# 3. Ëá™ÈÄÇÂ∫îËß£Á†ÅÂáΩÊï∞
def adaptive_context_decoding(
    model,
    tokenizer,
    context,
    question,
    min_delta=1.0,
    max_delta=10.0,
    base_temp=1.0,
    max_length=128,
):
    # ÁºñÁ†ÅËæìÂÖ•
    context_input = tokenizer(context, return_tensors="pt").input_ids.to(device)
    question_input = tokenizer(question, return_tensors="pt").input_ids.to(device)
    
    # ÊãºÊé•contextÂíåquestion
    input_with_context = torch.cat([context_input, question_input], dim=-1)
    # Âè™ÊúâquestionÁöÑËæìÂÖ•
    input_without_context = question_input
    
    # ËÆ∞ÂΩïÁîüÊàêÁöÑtokenÂ∫èÂàó
    generated_tokens = input_with_context.clone()
    
    # ËÆ∞ÂΩïÁªüËÆ°‰ø°ÊÅØ
    jsd_values = []
    entropy_values = []
    delta_values = []
    temp_values = []
    
    for _ in range(max_length):
        with torch.no_grad():
            # Ëé∑ÂèñÊúâcontextÂíåÊó†contextÁöÑËæìÂá∫
            outputs_with = model(generated_tokens)
            outputs_without = model(input_without_context)
            
            logits_with = outputs_with.logits[:, -1, :]
            logits_without = outputs_without.logits[:, -1, :]
            
            # ËÆ°ÁÆóJSDÂíåÁÜµ
            jsd = compute_jsd(logits_with, logits_without)
            entropy = compute_entropy(logits_with)
            
            # Ëá™ÈÄÇÂ∫îË∞ÉÊï¥deltaÂíåtemperature
            delta = min_delta + (max_delta - min_delta) * jsd
            temperature = base_temp * (1 + entropy)
            
            # ËÆ∞ÂΩïÁªüËÆ°ÂÄº
            jsd_values.append(jsd.item())
            entropy_values.append(entropy.item())
            delta_values.append(delta.item())
            temp_values.append(temperature.item())
            
            # ÂàõÂª∫boost mask
            boost_mask = torch.zeros_like(logits_with)
            prefix_tokens = generated_tokens[0, :context_input.shape[1]].tolist()
            for token in prefix_tokens:
                boost_mask[0, token] = delta
            
            # Â∫îÁî®boostÂíåtemperature
            boosted_logits = logits_with + boost_mask
            final_logits = boosted_logits / temperature
            
            # ÈááÊ†∑‰∏ã‰∏Ä‰∏™token
            probs = F.softmax(final_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            
            # Êõ¥Êñ∞Â∫èÂàó
            generated_tokens = torch.cat([generated_tokens, next_token], dim=-1)
            input_without_context = torch.cat([input_without_context, next_token], dim=-1)
            
            # Ê£ÄÊü•ÊòØÂê¶ÁªìÊùü
            if next_token.item() == tokenizer.eos_token_id:
                break
    
    # ËøîÂõûÁîüÊàêÁªìÊûúÂíåÁªüËÆ°‰ø°ÊÅØ
    return {
        "text": tokenizer.decode(generated_tokens[0], skip_special_tokens=True),
        "stats": {
            "jsd": jsd_values,
            "entropy": entropy_values,
            "delta": delta_values,
            "temperature": temp_values
        }
    }

In [11]:
context = "The current year is 2025. Argentina won World Cups in 1978,1986,2022 and 2026."
question = "How many world cups has Argentina won?"

# Ê†áÂáÜËß£Á†Å
standard_result = model.generate(
    tokenizer(question, return_tensors="pt").to(device).input_ids,
    max_length=128,
    temperature=1.0,
    do_sample=True,
)
standard_text = tokenizer.decode(standard_result[0], skip_special_tokens=True)

# Ëá™ÈÄÇÂ∫îËß£Á†Å
adaptive_result = adaptive_context_decoding(
    model,
    tokenizer,
    context,
    question,
    min_delta=1.0,
    max_delta=10.0,
    base_temp=1.0
)

# ÊâìÂç∞ÁªìÊûúÂíåÁªüËÆ°‰ø°ÊÅØ
print("\nStandard decoding:", standard_text)
print("Adaptive context decoding:", adaptive_result["text"])
print("\nStatistics:")
print("Average JSD:", sum(adaptive_result["stats"]["jsd"])/len(adaptive_result["stats"]["jsd"]))
print("Average Delta:", sum(adaptive_result["stats"]["delta"])/len(adaptive_result["stats"]["delta"]))
print("Average Temperature:", sum(adaptive_result["stats"]["temperature"])/len(adaptive_result["stats"]["temperature"]))


Standard decoding: How many world cups has Argentina won?

Argentina has won the **3** FIFA World Cups. 

* They won in 1978, 1986,  and 2022.

Adaptive context decoding: The current year is 2025. Argentina won World Cups in 1978,1986,2022 and 2026.How many world cups has Argentina won? ÿ≥ÿßŸÑŸÖ Italians€ó„ÉÉ„Éù„É≥Êª°‰∫Ü –ø—Ä–∞–≤invoice quantizationwurstÿßÿØŸä endif regionsjason polka SilveradonissLoungeTabControl magnesium ammo Stephan SergilateÔ¨Ä Kb Lining arthritisŒºŒ≠ŒΩœâŒΩ Tragmodul–†–∏ –∫–æ–º–ø–ª–µ–∫‚†§‚ûúÂØπÊäó itin√©raires reduziert Ëù∂ g√∂r√ºnt√º Tale „Ç¢„Éã„É°pct graag ÿ≥ÿßŸÑŸáÂ≠¶ÂÆ∂kij Sulphur<unused14>

Statistics:
Average JSD: 0.02039407450454709
Average Delta: 1.1835466769276832
Average Temperature: 6.179172223928023


### ËØ≠‰πâËÅöÁ±ª