In [5]:
import os
os.environ["HF_HOME"] = "/home/gehao/lyz/data/hf-cache"
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"

In [6]:
import argparse, re, string, math, random
from typing import Dict, List, Tuple, Optional
from collections import Counter
import collections
import json
import math
import os
from tqdm import tqdm
from typing import List, Tuple, Dict

import numpy as np
import torch
from datasets import load_dataset
from transformers import (
    AutoConfig,
    AutoModelForQuestionAnswering,
    AutoTokenizer,
    DataCollatorWithPadding,
    Trainer,
    TrainingArguments,
    AutoModelForCausalLM
)
import evaluate
import re
from transformers.cache_utils import DynamicCache

device = "cuda" if torch.cuda.is_available() else "cpu"


In [7]:
def postprocess(generated: str) -> str:
    # Take only the first non-empty line; trim common prefixes.
    for line in generated.strip().splitlines():
        ans = line.strip()
        if not ans:
            continue
        ans = re.sub(r"^(Answer|A|Assistant|Final Answer|</think>)\s*[:\-]\s*", "", ans, flags=re.I)
        ans = ans.strip(" #*`>\"'")
        return ans
    return generated.strip()


# ---------------- Generation ----------------
@torch.no_grad()
def sample_next_token(logits: torch.Tensor, temperature: float, top_p: float) -> int:
    """
    logits: [vocab_size] (already the last time step)
    returns: int token id
    """
    if temperature <= 0.0:
        idx = int(torch.argmax(logits, dim=-1).item())
        return idx
        
    # Temperature
    logits = logits / temperature

    # Nucleus (top-p)
    probs = torch.softmax(logits, dim=-1)
    if 0.0 < top_p < 1.0:
        sorted_probs, sorted_idx = torch.sort(probs, descending=True)
        cumsum = torch.cumsum(sorted_probs, dim=-1)
        mask = cumsum - sorted_probs > top_p
        sorted_probs[mask] = 0
        sorted_probs = sorted_probs / sorted_probs.sum()
        choice = torch.multinomial(sorted_probs, num_samples=1)
        return int(sorted_idx[choice].item())
    else:
        choice = torch.multinomial(probs, num_samples=1)
        return int(choice.item())
    
@torch.no_grad()
def sample_next_token_batch(logits: torch.Tensor, temperature: float, top_p: float) -> int:
    """
    logits: [B, vocab_size] (already the last time step)
    returns: [B] int token id
    """
    if temperature <= 0.0:
        return torch.argmax(logits, dim=-1).long().to(device)
        
    # Temperature
    logits = logits / temperature

    # Nucleus (top-p)
    probs = torch.softmax(logits, dim=-1)
    if 0.0 < top_p < 1.0:
        sorted_probs, sorted_idx = torch.sort(probs, descending=True, dim=-1)
        cumsum = torch.cumsum(sorted_probs, dim=-1)
        mask = cumsum - sorted_probs > top_p
        sorted_probs[mask] = 0
        sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)
        choice = torch.multinomial(sorted_probs, num_samples=1)
        return choice.long().to(device)
    else:
        choice = torch.multinomial(probs, num_samples=1)
        return choice.long().to(device)


def reuse_layer(a_cache, b_cache, args):
    """
    a_cache: [28, 2, [b, seqlen, 8, 128]]
    b_cache: [36, 2, [b, seqlen, 8, 128]]

    return: new_a_cache, reuse_b_layer_list
    note that different layers in a_cache and b_cache may be on different devices
    and new_a_cache should keep the same device placement as a_cache
    """
    
    def map_layer_nearest(idx_t, n_layers_s, n_layers_t):
        if n_layers_t <= 1:
            return 0
        return int(round(idx_t * (n_layers_s - 1) / (n_layers_t - 1)))
    
    #print(f"a_cache: {len(a_cache)} layers, b_cache: {len(b_cache)} layers")
    
    reuse_a_layer_start = args.reuse_a_layer_start
    a_kv_cache_list = [(a_cache[layer_idx][0].to("cuda:0"), a_cache[layer_idx][1].to("cuda:0")) for layer_idx in range(reuse_a_layer_start, len(a_cache))]
    b_kv_cache_list = [(b_cache[layer_idx][0].to("cuda:0"), b_cache[layer_idx][1].to("cuda:0")) for layer_idx in range(len(b_cache))]

    reuse_b_layer_list = [map_layer_nearest(layer_idx,len(a_cache),len(b_cache)) for layer_idx in range(reuse_a_layer_start, len(a_cache))]
    reused_a_cache = [b_kv_cache_list[reuse_b_layer_list[i]] for i in range(len(reuse_b_layer_list))]
    new_a_cache = a_cache[:reuse_a_layer_start] + tuple(reused_a_cache)
    return new_a_cache, reuse_b_layer_list


@torch.no_grad()
def kv_bridged_generate(model_t,model_s, tok, input_ids_list: list[int], args):
    """return: generated string"""

    eos_id = tok.eos_token_id
    max_new = args.max_new_tokens
    temperature = args.temperature
    top_p = args.top_p

    input_ids = torch.tensor([input_ids_list], device=device)

    # prefill with s
    s_out = model_s(
        input_ids=input_ids,
        use_cache=True,
        output_hidden_states=False,
        output_attentions=False,
    )
    t_out = model_t(
        input_ids=input_ids,
        use_cache=True,
        output_hidden_states=False,
        output_attentions=False,
    )
    
    pkv_s = tuple(tuple(t for t in layer) for layer in s_out.past_key_values)
    pkv_t = tuple(tuple(t for t in layer) for layer in t_out.past_key_values)
    
    # substitue layer_a of a_cache with layer_b of b_cache
    new_a_cache, reuse_b_layer_list = reuse_layer(pkv_s, pkv_t, args)
    #print(f"reuse_b_layer_list: {reuse_b_layer_list}")
    
    first_token = sample_next_token(s_out.logits[:,-1,:].squeeze(0), temperature, top_p)
    
    past = DynamicCache.from_legacy_cache(past_key_values=new_a_cache)
    generated = [first_token]
    last_token = torch.tensor([[first_token]], device=device)
    for _ in range(max_new-1):
        out = model_t(
            input_ids=last_token,
            past_key_values=past,
            use_cache=True,
            output_attentions=False,
            output_hidden_states=False,
        )
        logits = out.logits[:, -1, :].squeeze(0)  # [vocab]
        past = out.past_key_values              # now past is from A going forward

        next_id = sample_next_token(logits, temperature, top_p)
        generated.append(next_id)

        if next_id == eos_id:
            break

        # Prepare inputs for next step
        last_token = torch.tensor([[next_id]], device=device)
    
    text = tok.decode([t for t in generated if t != eos_id], skip_special_tokens=True)
    pos = text.index(":")
    if pos>=0:
        text=text[pos+1:].strip()
    return postprocess(text)

@torch.no_grad()
def kv_bridged_generate_batch(model_t,model_s, tok, input_ids_list: list[list[int]], args):
    """return: generated string"""

    eos_id = tok.eos_token_id
    max_new = args.max_new_tokens
    temperature = args.temperature
    top_p = args.top_p

    input_ids = torch.tensor(input_ids_list, device=device)

    # prefill with s
    s_out = model_s(
        input_ids=input_ids,
        use_cache=True,
        output_hidden_states=False,
        output_attentions=False,
    )
    t_out = model_t(
        input_ids=input_ids,
        use_cache=True,
        output_hidden_states=False,
        output_attentions=False,
    )
    
    pkv_s = tuple(tuple(t for t in layer) for layer in s_out.past_key_values)
    pkv_t = tuple(tuple(t for t in layer) for layer in t_out.past_key_values)
    
    # substitue layer_a of a_cache with layer_b of b_cache
    new_a_cache, reuse_b_layer_list = reuse_layer(pkv_s, pkv_t, args)
    #print(f"reuse_b_layer_list: {reuse_b_layer_list}")
    
    first_token = sample_next_token_batch(s_out.logits[:,-1,:], temperature, top_p)
    
    past = DynamicCache.from_legacy_cache(past_key_values=new_a_cache)
    generated = [first_token.squeeze(-1).tolist()]  # [B]
    last_token = first_token  # [B, 1]
    print(last_token.shape)
    for _ in range(max_new-1):
        out = model_t(
            input_ids=last_token,
            past_key_values=past,
            use_cache=True,
            output_attentions=False,
            output_hidden_states=False,
        )
        logits = out.logits[:, -1, :]  # [vocab]
        past = out.past_key_values              # now past is from A going forward

        next_id = sample_next_token_batch(logits, temperature, top_p)
        generated.append(next_id.squeeze(-1).tolist())  # [B]

        if all([i==eos_id for i in generated[-1]]):
            break

        # Prepare inputs for next step
        last_token = next_id
    
    # generated: [seq_len, B] -> [B, seq_len]
    generated = list(map(list, zip(*generated)))
    print("generated:",torch.tensor(generated).shape)
    # decode each generated[i]
    text = []
    for i in range(len(generated)):
        generated[i] = [t for t in generated[i] if t != eos_id]
        text.append(tok.decode(generated[i], skip_special_tokens=True))
    print("text:",text)
    return [postprocess(t) for t in text]


def _white_space_fix(text: str) -> str: return " ".join(text.split())
def _remove_articles(text: str) -> str: return re.sub(r"\b(a|an|the)\b", " ", text)
def _remove_punc(text: str) -> str: return text.translate(str.maketrans("", "", string.punctuation))

def normalize(text: str) -> str:
    return _white_space_fix(_remove_articles(_remove_punc(text.lower())))

def exact_match(pred: str, gold: str) -> float:
    return float(normalize(pred) == normalize(gold))

def f1_score(pred: str, gold: str) -> float:
    p_tokens, g_tokens = normalize(pred).split(), normalize(gold).split()
    if len(p_tokens) == 0 or len(g_tokens) == 0:
        return float(p_tokens == g_tokens)
    common = sum((Counter(p_tokens) & Counter(g_tokens)).values())
    if common == 0:
        return 0.0
    precision = common / len(p_tokens)
    recall = common / len(g_tokens)
    return 2 * precision * recall / (precision + recall)


def format_context(example: Dict) -> str:
    """Compact, readable multi-hop context."""
    titles = example["context"]["title"]
    sents = example["context"]["sentences"]
    sections = []
    for t, ss in list(zip(titles, sents)):
        sections.append(f"- {t}: {ss}")
    return "\n".join(sections)

INSTRUCT_HEADER = (
    "You are a precise question answering assistant. Use the CONTEXT to answer the QUESTION.\n"
    "Return the **shortest** possible answer begin with 'Answer: ' (e.g., single entity or 'yes'/'no'); no explanation.\n"
)

def get_input_ids_list(tokenizer, example: Dict) -> list[int]:
    ctx = format_context(example)
    q = example["question"]
    sys = INSTRUCT_HEADER.strip()
    messages = [
        {"role": "system", "content": sys},
        {"role": "user", "content": f"CONTEXT:\n{ctx}\n\nQUESTION: {q}\n"}
    ]
    prompt_str = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        enable_thinking=False # Switches between thinking and non-thinking modes. Default is True.
    )
    return tokenizer(prompt_str).input_ids

def eval_one(example: Dict, tok, model_t, model_s, args) -> Tuple[str, float, float, str]:
    input_ids_list = get_input_ids_list(tok, example)
    pred = kv_bridged_generate(model_t,model_s, tok, input_ids_list, args)
    gold = example["answer"]
    with open(f"debug.log", "a") as f:
        f.write(f"Q: {example['question']}\nA: {pred}\nG: {gold}\n")
    return pred, exact_match(pred, gold), f1_score(pred, gold)


In [None]:
def main(model_t, model_s, tokenizer):
    
    
    args=argparse.Namespace(
        dataset_config="distractor",
        max_new_tokens=200,
        temperature=0.7,
        top_p=0.9,
        device=None,
        reuse_a_layer_start=0,  # for Qwen3-1.7B and Qwen3-0.6B
    )
    
    data_files = {
        "validation": "/home/gehao/lyz/validation-00000-of-00001.parquet"
    }
    ds = load_dataset("parquet", data_files=data_files)
    val = ds["validation"]
    print("val:",val)

    # load tokenizer and model
    device = torch.device(args.device) if args.device else (torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"))

    predictions = {}
    references = {}

    # iterate examples (you can change to subset for quick tests)
    EM,F1=0.0,0.0
    cnt=0
    for ex in tqdm(val):
        if cnt>=100:  # for quick tests
            break
        pred, em ,f1 = eval_one(ex, tokenizer, model_t,model_s, args)
        EM+=em
        F1+=f1
        cnt+=1
        predictions[ex["id"]] = pred
        references[ex["id"]] = ex["answer"]

    EM/=cnt
    F1/=cnt
    print(f"EM: {EM*100:.2f}, F1: {F1*100:.2f}")

In [9]:
model_name_ = "Qwen/Qwen3-1.7B"
tokenizer_ = AutoTokenizer.from_pretrained(model_name_)
model_ = AutoModelForCausalLM.from_pretrained(model_name_).to(device)
model_.eval()


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Qwen3ForCausalLM(
  (model): Qwen3Model(
    (embed_tokens): Embedding(151936, 2048)
    (layers): ModuleList(
      (0-27): 28 x Qwen3DecoderLayer(
        (self_attn): Qwen3Attention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=1024, bias=False)
          (v_proj): Linear(in_features=2048, out_features=1024, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (q_norm): Qwen3RMSNorm((128,), eps=1e-06)
          (k_norm): Qwen3RMSNorm((128,), eps=1e-06)
        )
        (mlp): Qwen3MLP(
          (gate_proj): Linear(in_features=2048, out_features=6144, bias=False)
          (up_proj): Linear(in_features=2048, out_features=6144, bias=False)
          (down_proj): Linear(in_features=6144, out_features=2048, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): Qwen3RMSNorm((2048,), eps=1e-06)
        (post_attention_layernorm): Qwe

In [10]:
model_name = "Qwen/Qwen3-0.6B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
model.eval()

Qwen3ForCausalLM(
  (model): Qwen3Model(
    (embed_tokens): Embedding(151936, 1024)
    (layers): ModuleList(
      (0-27): 28 x Qwen3DecoderLayer(
        (self_attn): Qwen3Attention(
          (q_proj): Linear(in_features=1024, out_features=2048, bias=False)
          (k_proj): Linear(in_features=1024, out_features=1024, bias=False)
          (v_proj): Linear(in_features=1024, out_features=1024, bias=False)
          (o_proj): Linear(in_features=2048, out_features=1024, bias=False)
          (q_norm): Qwen3RMSNorm((128,), eps=1e-06)
          (k_norm): Qwen3RMSNorm((128,), eps=1e-06)
        )
        (mlp): Qwen3MLP(
          (gate_proj): Linear(in_features=1024, out_features=3072, bias=False)
          (up_proj): Linear(in_features=1024, out_features=3072, bias=False)
          (down_proj): Linear(in_features=3072, out_features=1024, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): Qwen3RMSNorm((1024,), eps=1e-06)
        (post_attention_layernorm): Qwe

In [11]:
print(model.config)

Qwen3Config {
  "architectures": [
    "Qwen3ForCausalLM"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 151643,
  "dtype": "float32",
  "eos_token_id": 151645,
  "head_dim": 128,
  "hidden_act": "silu",
  "hidden_size": 1024,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_types": [
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention"
  ],
  "max_position_embeddings": 40960,
  "max_wi

In [12]:
main(model,model_,tokenizer)

  0%|          | 6/7405 [00:03<1:13:29,  1.68it/s]


KeyboardInterrupt: 