In [2]:
import json
from schema_parser import load_schema
from schema_linker import SchemaLinker
from async_keyword_extractor import KeywordExtractor
from typing import Dict, Any
import os
import asyncio
from tqdm import tqdm

TRAIN_JSON_PATH = r"/home/yangliu26/data/train/train.json"
SCHEMA_JSON_PATH = r"/home/yangliu26/data/train/train_tables.json"
MODEL_PATH = r"/home/yangliu26/qwen3-8b"

# 加载schema信息
def get_schema_map(schema_json_path: str) -> Dict[str, Any]:
    schema = load_schema(schema_json_path)
    db_schema_map = {}
    for db in schema:
        db_id = db["db_id"] if isinstance(db, dict) and "db_id" in db else db.get("db_id", "")
        db_schema_map[db_id] = db
    return db_schema_map

def link_keywords_to_schema(keywords, schema_info):
    # 简单schema linking逻辑：关键词与表名、字段名做模糊匹配
    linked = []
    tables = schema_info.get("table_names_original", [])
    columns = [col[1] for col in schema_info.get("column_names_original", [])]
    for kw in keywords:
        for t in tables:
            if kw.lower() in t.lower():
                linked.append((kw, "table", t))
        for c in columns:
            if kw.lower() in c.lower():
                linked.append((kw, "column", c))
    return linked

async def async_main():
    with open(TRAIN_JSON_PATH, "r", encoding="utf-8") as f:
        data = json.load(f)
    # 取前10个数据作为测试
    data = data[:10]
    schema_map = get_schema_map(SCHEMA_JSON_PATH)
    extractor = KeywordExtractor(MODEL_PATH)
    linker = SchemaLinker()
    # 提取出所有问题
    questions = [sample["question"] for sample in data]
    # 提取出每个问题的关键词
    # all_keywords = await extractor.batch_extract(questions)
    print("🔍 Extracting keywords …")
    all_keywords = await extractor.batch_extract(
        tqdm(questions, desc="Keyword-extract", unit="q")
    )
    
    results = []
    for sample, keywords in zip(data, all_keywords):
        db_id = sample["db_id"]
        question = sample["question"]
        evidence = sample["evidence"]
        schema_info = schema_map.get(db_id, {})
        # schema-linking
        linker.build_index(schema_info)
        linking_results = linker.search(keywords)
        # 格式化
        formatted_linking = {}
        for matches in linking_results:
            for kw, schema_item, table_name, score in matches:
                if table_name not in formatted_linking:
                    formatted_linking[table_name] = []
                formatted_linking[table_name].append(schema_item)

        results.append({
            "db_id": db_id,
            "question": question,
            "evidence": evidence,
            "keywords": keywords,
            "schema_linking": formatted_linking
        })
        
    # 输出结果到当前文件目录下的schema_linking_result.json
    out_path = os.path.join(os.path.dirname(__file__), "schema_linking_result.json")
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(results, f, ensure_ascii=False, indent=2)
    print(f"Schema linking结果已保存到: {out_path}")

In [3]:
with open(TRAIN_JSON_PATH, "r", encoding="utf-8") as f:
    data = json.load(f)

print(data[:1])

[{'db_id': 'movie_platform', 'question': 'Name movie titles released in year 1945. Sort the listing by the descending order of movie popularity.', 'evidence': 'released in the year 1945 refers to movie_release_year = 1945;', 'SQL': 'SELECT movie_title FROM movies WHERE movie_release_year = 1945 ORDER BY movie_popularity DESC LIMIT 1'}]


In [4]:
# 各数据库地schema
schema_map = get_schema_map(SCHEMA_JSON_PATH)
print(schema_map.__class__)
# 打印第一个键值对
# first_key = next(iter(schema_map))
# print(first_key, schema_map[first_key])
# print(json.dumps(schema_map[first_key], indent=2, ensure_ascii=False))

<class 'dict'>


In [5]:
results = []
data = data[:1]
print(json.dumps(data, indent=2))

[
  {
    "db_id": "movie_platform",
    "question": "Name movie titles released in year 1945. Sort the listing by the descending order of movie popularity.",
    "evidence": "released in the year 1945 refers to movie_release_year = 1945;",
    "SQL": "SELECT movie_title FROM movies WHERE movie_release_year = 1945 ORDER BY movie_popularity DESC LIMIT 1"
  }
]


In [10]:
schema_map = get_schema_map(SCHEMA_JSON_PATH)
extractor = KeywordExtractor(MODEL_PATH)
linker = SchemaLinker()
# 提取出所有问题
questions = [sample["question"] for sample in data]
# 提取出每个问题的关键词
all_keywords = await extractor.batch_extract(questions)

Loading checkpoint shards:   0%|          | 0/5 [00:00<?, ?it/s]

Some parameters are on the meta device because they were offloaded to the cpu.


[Parsing Failed] Content: ```json
{
  "keywords": ["movie titles", "released", "1945", "sort", "descending", "popularity"]
}
``` 

Okay, let's tackle this query. The user is asking for movie titles released in 1945 and wants them sorted by popularity in descending order. First, I need to extract the key entities and attributes.

The main entities here are "movie titles" since that's what they're asking for. The year 1945 is a crucial filter, so that's definitely a keyword. The action of releasing is important too, so "released" should be included. 

Then there's the sorting part. The user mentioned sorting by popularity, so "popularity" is a key attribute. The order is descending, so "descending" needs to be in the list. 

Wait, should "sort" be included? The instruction says to extract essential elements. Since the user is asking to sort the listing, "sort" is part of the action. But maybe it's redundant because the sorting is implied by the mention of popularity and descending. Hmm, 

In [6]:
for sample in data:
    db_id = sample["db_id"]
    question = sample["question"]
    schema_info = schema_map.get(db_id, {})
    keywords = keyword(question)
    linking = link_keywords_to_schema(keywords, schema_info)
    results.append({
        "db_id": db_id,
        "question": question,
        "keywords": keywords,
        "schema_linking": linking
    })

NameError: name 'extract_keywords' is not defined

In [None]:
# 输出结果
    out_path = os.path.join(os.path.dirname(__file__), "schema_linking_result.json")
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(results, f, ensure_ascii=False, indent=2)
    print(f"Schema linking结果已保存到: {out_path}")