In [1]:
from typing import List
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch, re, json

class KeywordExtractor:
    def __init__(self, model_path: str):
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_path, trust_remote_code=True, local_files_only=True
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path, torch_dtype=torch.float16, device_map="auto",
            trust_remote_code=True, local_files_only=True
        )
        self.model.eval()
        self.model.config.pad_token_id = self.tokenizer.pad_token_id

    # ---------- 构造 Prompt ----------
    def build_prompt(self, question: str) -> str:
        return f"""
You are a highly precise extraction assistant.

Task: Extract only the key entities, attributes, numbers, or filtering conditions.

Return **only** this JSON object (no extra text):
{{
  "keywords": ["<keyword1>", "<keyword2>"]
}}

Input Question:
{question}
""".strip()

    # ---------- 通用工具 ----------
    @staticmethod
    def extract_json_block(text: str):
        """提取首个 {...} 并反序列化"""
        m = re.search(r"\{[\s\S]*?\}", text)
        return json.loads(m.group()) if m else {}

    # ---------- 单问 ----------
    def extract_keywords_one(self, question: str) -> List[str]:
        prompt = self.build_prompt(question)
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)

        out = self.model.generate(
            **inputs,
            max_new_tokens=128,
            do_sample=False           # 确定性
        )
        # 解码新增部分
        generated = out[:, inputs.input_ids.shape[1]:]
        text = self.tokenizer.decode(generated[0], skip_special_tokens=True)

        # 去掉 ```json ... ``` 包裹
        text = re.sub(r"^```(?:json)?\s*([\s\S]*?)\s*```$", r"\1", text.strip(), flags=re.DOTALL)

        parsed = self.extract_json_block(text)
        return parsed.get("keywords", [])

    # ---------- 批量 ----------
    def batch_extract(self, questions: List[str]) -> List[List[str]]:
        # 一次喂给模型，避免 asyncio 阻塞和显存重复
        prompts = [self.build_prompt(q) for q in questions]
        tokens  = self.tokenizer(prompts, padding=True, return_tensors="pt").to(self.model.device)

        outs = self.model.generate(**tokens, max_new_tokens=128, do_sample=False)
        gens = [o[tokens.input_ids.shape[1]:] for o in outs]

        texts = self.tokenizer.batch_decode(gens, skip_special_tokens=True)
        clean  = [re.sub(r"^```(?:json)?\s*([\s\S]*?)\s*```$", r"\1", t.strip(), flags=re.DOTALL)
                  for t in texts]
        return [self.extract_json_block(t).get("keywords", []) for t in clean]

In [3]:

extractor = KeywordExtractor("/home/yangliu26/qwen3-8b")


Loading checkpoint shards:   0%|          | 0/5 [00:00<?, ?it/s]

In [5]:

questions = [
    "Which stores have sales exceeding 50,000?",
    "What are the popular restaurants in Beijing in 2023?",
    "List the training institutions that offer courses in Python and deep learning."
]
print(questions)


['Which stores have sales exceeding 50,000?', 'What are the popular restaurants in Beijing in 2023?', 'List the training institutions that offer courses in Python and deep learning.']


In [6]:
prompts = [build_prompt(q) for q in questions]
tokens  = tokenizer(prompts, padding=True, return_tensors="pt").to(self.model.device)

NameError: name 'self' is not defined

In [None]:
outs = self.model.generate(**tokens, max_new_tokens=128, do_sample=False)
gens = [o[tokens.input_ids.shape[1]:] for o in outs]

texts = self.tokenizer.batch_decode(gens, skip_special_tokens=True)
clean  = [re.sub(r"^```(?:json)?\s*([\s\S]*?)\s*```$", r"\1", t.strip(), flags=re.DOTALL)
          for t in texts]

In [None]:
keywords_batch = extractor.batch_extract(questions)
for q, kw in zip(questions, keywords_batch):
    print(q, "->", kw)