In [29]:
import json
from schema_parser import load_schema
from keyword_extraction_with_llm import extract_keywords
from typing import Dict, Any
import os

TRAIN_JSON_PATH = r"/home/yangliu26/data/train/train.json"
SCHEMA_JSON_PATH = r"/home/yangliu26/data/train/train_tables.json"

# 加载schema信息
def get_schema_map(schema_json_path: str) -> Dict[str, Any]:
    schema = load_schema(schema_json_path)
    db_schema_map = {}
    for db in schema:
        db_id = db["db_id"] if isinstance(db, dict) and "db_id" in db else db.get("db_id", "")
        db_schema_map[db_id] = db
    return db_schema_map

def link_keywords_to_schema(keywords, schema_info):
    # 简单schema linking逻辑：关键词与表名、字段名做模糊匹配
    linked = []
    tables = schema_info.get("table_names_original", [])
    columns = [col[1] for col in schema_info.get("column_names_original", [])]
    for kw in keywords:
        for t in tables:
            if kw.lower() in t.lower():
                linked.append((kw, "table", t))
        for c in columns:
            if kw.lower() in c.lower():
                linked.append((kw, "column", c))
    return linked


In [30]:
with open(TRAIN_JSON_PATH, "r", encoding="utf-8") as f:
    data = json.load(f)

print(data[:1])

[{'db_id': 'movie_platform', 'question': 'Name movie titles released in year 1945. Sort the listing by the descending order of movie popularity.', 'evidence': 'released in the year 1945 refers to movie_release_year = 1945;', 'SQL': 'SELECT movie_title FROM movies WHERE movie_release_year = 1945 ORDER BY movie_popularity DESC LIMIT 1'}]


In [45]:
# 各数据库地schema
schema_map = get_schema_map(SCHEMA_JSON_PATH)
print(schema_map.__class__)
# 打印第一个键值对
# first_key = next(iter(schema_map))
# print(first_key, schema_map[first_key])
# print(json.dumps(schema_map[first_key], indent=2, ensure_ascii=False))

<class 'dict'>


In [50]:
results = []
data = data[:1]
print(json.dumps(data, indent=2))

[
  {
    "db_id": "movie_platform",
    "question": "Name movie titles released in year 1945. Sort the listing by the descending order of movie popularity.",
    "evidence": "released in the year 1945 refers to movie_release_year = 1945;",
    "SQL": "SELECT movie_title FROM movies WHERE movie_release_year = 1945 ORDER BY movie_popularity DESC LIMIT 1"
  }
]


In [None]:
for sample in data:
    db_id = sample["db_id"]
    question = sample["question"]
    schema_info = schema_map.get(db_id, {})
    keywords = extract_keywords(question)
    linking = link_keywords_to_schema(keywords, schema_info)
    results.append({
        "db_id": db_id,
        "question": question,
        "keywords": keywords,
        "schema_linking": linking
    })

In [None]:
# 输出结果
    out_path = os.path.join(os.path.dirname(__file__), "schema_linking_result.json")
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(results, f, ensure_ascii=False, indent=2)
    print(f"Schema linking结果已保存到: {out_path}")