# In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
from huggingface_hub import hf_hub_download
import os
import json
from inspect import signature

# Hugging Face access token. It is read from the environment so that no secret
# lives in the source; when the variable is unset this is None.
# NOTE(review): with token=None the Hugging Face libraries are expected to fall
# back to locally saved credentials -- confirm for the installed versions.
HF_TOKEN = os.environ.get("HF_TOKEN")
# Hub repository id of the model (or adapter) loaded by this script.
model_name = "JesseLiu/qwen25-7b-pagerank"
# Marker tokens registered with the tokenizer as additional special tokens.
SPECIAL_TOKENS = ["<degd>", "<ddd>", "<decgd>", "<demgd>", "<debgd>", "<dppd>", "<dpd>"]

class TokenDecoderWrapper:
    """Proxy around a tokenizer whose decode helpers keep special tokens by default.

    ``decode`` and ``batch_decode`` default ``skip_special_tokens`` to ``False``
    unless the caller passes it explicitly. ``len()``, calling the wrapper, and
    every other attribute access are forwarded to the wrapped tokenizer.
    """

    def __init__(self, tokenizer):
        """Wrap *tokenizer* and record how its ``batch_decode`` must be called.

        Args:
            tokenizer: object providing ``decode`` and ``batch_decode``.
        """
        self.tokenizer = tokenizer
        # Inspect the signature once: some tokenizers do not name
        # ``skip_special_tokens`` as a parameter of ``batch_decode``.
        self._accepts_skip = (
            "skip_special_tokens" in signature(tokenizer.batch_decode).parameters
        )

    def batch_decode(self, seqs, **kw):
        """Decode every sequence in *seqs*, keeping special tokens by default.

        Args:
            seqs: iterable of token-id sequences.
            **kw: forwarded to the underlying decode call; explicit values
                always win over the defaults applied here.

        Returns:
            The wrapped tokenizer's ``batch_decode`` result, or a list with one
            ``decode`` result per sequence when the fallback path is used.
        """
        if self._accepts_skip:
            kw.setdefault("skip_special_tokens", False)
            return self.tokenizer.batch_decode(seqs, **kw)
        # Fallback: decode one sequence at a time. ``setdefault`` (instead of a
        # hard-coded keyword next to **kw) lets callers override the clean-up
        # flag without triggering a duplicate-keyword TypeError.
        kw.setdefault("clean_up_tokenization_spaces", False)
        # Go through self.decode so the skip_special_tokens default matches the
        # other decode paths.
        return [self.decode(s, **kw) for s in seqs]

    def decode(self, seq, **kw):
        """Decode one sequence, keeping special tokens unless told otherwise."""
        kw.setdefault("skip_special_tokens", False)
        return self.tokenizer.decode(seq, **kw)

    def __len__(self):
        """Return ``len()`` of the wrapped tokenizer."""
        return len(self.tokenizer)

    def __call__(self, *args, **kwargs):
        """Call the wrapped tokenizer with the given arguments."""
        return self.tokenizer(*args, **kwargs)

    def __getattr__(self, name):
        """Forward attributes not found on the wrapper to the wrapped tokenizer.

        Raises:
            AttributeError: if the attribute is missing, or if the wrapper has
                no ``tokenizer`` yet.
        """
        # __getattr__ only runs after normal lookup failed. If ``tokenizer``
        # itself is missing (instance created via __new__, e.g. while being
        # copied or unpickled), delegating would re-enter this method forever.
        if name == "tokenizer":
            raise AttributeError(name)
        return getattr(self.tokenizer, name)

# Load the tokenizer, pad on the right with EOS as the pad token, and register
# the custom marker tokens as additional special tokens.
raw_tokenizer = AutoTokenizer.from_pretrained(model_name, token=HF_TOKEN)
raw_tokenizer.padding_side = "right"
raw_tokenizer.pad_token = raw_tokenizer.eos_token
extra_special_tokens = {"additional_special_tokens": SPECIAL_TOKENS}
raw_tokenizer.add_special_tokens(extra_special_tokens)
# The wrapper stands in for the raw tokenizer from here on.
tokenizer = TokenDecoderWrapper(raw_tokenizer)

print("Special tokens in tokenizer:", tokenizer.special_tokens_map_extended)

# ── model ─────────────────────────────────────────────────────────────────────
def load_base_and_merge(adapter_repo: str, tokenizer):
    """Load the base model named by an adapter repo and merge the adapter into it.

    Args:
        adapter_repo: Hub repository id containing ``adapter_config.json``
            and the adapter weights.
        tokenizer: tokenizer whose ``len()`` sets the size the base model's
            token embeddings are resized to before the adapter is applied.

    Returns:
        The object returned by ``merge_and_unload()`` on the loaded PEFT model.

    Raises:
        KeyError: if the adapter config has no ``base_model_name_or_path``.
    """
    config_path = hf_hub_download(adapter_repo, "adapter_config.json", token=HF_TOKEN)
    # Context manager closes the config file deterministically; JSON is read
    # as UTF-8 instead of whatever the locale default happens to be.
    with open(config_path, encoding="utf-8") as config_file:
        cfg = json.load(config_file)
    base = AutoModelForCausalLM.from_pretrained(cfg["base_model_name_or_path"],
                                                device_map="auto", token=HF_TOKEN)
    # Resize before attaching the adapter so the embedding table matches the
    # tokenizer length.
    base.resize_token_embeddings(len(tokenizer))
    merged = PeftModel.from_pretrained(base, adapter_repo, token=HF_TOKEN,
                                       is_trainable=True).merge_and_unload()
    return merged

def is_lora_repo(repo=model_name) -> bool:
    """Report whether ``adapter_config.json`` can be downloaded from *repo*.

    Returns True when the download succeeds and False when it raises any
    ``Exception``.
    """
    found = True
    try:
        hf_hub_download(repo, "adapter_config.json", token=HF_TOKEN)
    except Exception:
        # NOTE(review): every failure (not only a missing file) is reported
        # as "not an adapter repo" -- confirm this best-effort probe is intended.
        found = False
    return found

# A plain model repo is loaded directly and its embeddings resized to the
# tokenizer length; an adapter repo is merged into its base model instead.
if not is_lora_repo():
    model = AutoModelForCausalLM.from_pretrained(
        model_name, device_map="auto", token=HF_TOKEN
    )
    model.resize_token_embeddings(len(tokenizer))
else:
    model = load_base_and_merge(model_name, tokenizer)

# Replace the input-embedding row of each custom special token with
# standard-normal noise. The writes are in-place, so they are done under
# no_grad.
# NOTE(review): this runs after the model weights are loaded, so whatever
# values those rows held are overwritten -- confirm this is intended.
input_emb = model.get_input_embeddings()
with torch.no_grad():
    for special_token in SPECIAL_TOKENS:
        row_index = tokenizer.convert_tokens_to_ids(special_token)
        input_emb.weight[row_index] = torch.randn_like(input_emb.weight[0])


# Encode a single prompt and move the tensors to the model's device.
prompt = "Question: Is dermatitis an indication for Eflornithine?"
encoded_prompt = tokenizer(prompt, return_tensors="pt")
inputs = encoded_prompt.to(model.device)

# Generate up to 50 new tokens without tracking gradients.
with torch.no_grad():
    output = model.generate(**inputs, max_new_tokens=50)

# Decode (special tokens are kept) and show the first sequence of the batch.
decoded_batch = tokenizer.batch_decode(output)
output_text = decoded_batch[0]
print("=== Generation Result ===")
print(output_text)
