In [2]:
import os
os.environ["OPENAI_API_KEY"] = ""

<b>Schema create</b>

In [84]:
import json, os, math, time, datetime, re, pathlib
from collections import defaultdict
from typing import Dict, Any, List, Tuple
from openai import OpenAI
from datetime import datetime, timezone
from typing import Dict, List, Tuple
import os, json, re
from typing import Dict, List, Tuple
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_community.vectorstores import FAISS
from langchain_core.messages import SystemMessage, HumanMessage
from langchain_core.prompts import ChatPromptTemplate
from openai import OpenAI

In [86]:
SCHEMA_PATH = "tables.json"        
OUT_PATH = "train_schema_embeddings.jsonl"  
EMBED_MODEL = "text-embedding-3-small"  
CHAT_MODEL   = "gpt-5-mini"           
LC_INDEX_DIR = "faiss_idx_tables"    

In [88]:
client = OpenAI()  

#open schema file
def load_schemas(path: str) -> Dict[str, Dict[str, Any]]:
    with open(path, "r", encoding="utf-8") as f:
        items = json.load(f)
    return {it["db_id"]: it for it in items}

In [122]:
#converts a database schema into a compact text string by listing each table with its column names
def schema_text(db: Dict[str, Any], max_cols_per_table: int = 24) -> str:
    tnames = db["table_names_original"]
    cols = db["column_names_original"]  
    by_table = defaultdict(list)

    for _, (tidx, cname) in enumerate(cols):
        if tidx >= 0:  
            by_table[tidx].append(str(cname))
    parts = []
    for i, t in enumerate(tnames):
        c = by_table[i][:max_cols_per_table]
        cols_txt = ", ".join(c)
        parts.append(f"{t}({cols_txt})")
    return " | ".join(parts)

In [124]:
# converts each schema into a readable text form mapped by its db_id
def batch(iterable, n=64):
    it = list(iterable)
    for i in range(0, len(it), n):
        yield it[i:i+n]
def build_texts(schemas: Dict[str, Dict[str, Any]]) -> Dict[str, str]:
    return {db_id: schema_text(db) for db_id, db in schemas.items()}

In [126]:
#creates embeddings for a list of texts by sending them to the model collecting the vector outputs, and returning them as a list of embeddings.
def embed_texts(texts: List[str], model: str) -> List[List[float]]:
    all_vecs = []
    for chunk in batch(texts, n=64):
        resp = client.embeddings.create(model=model, input=chunk)
        all_vecs.extend([d.embedding for d in resp.data])
        time.sleep(0.05)
    return all_vecs

In [128]:
#function loads all schemas, converts them to text, generates embeddings with the chosen model, and saves data into a JSON-lines file
def main():
    schemas = load_schemas(SCHEMA_PATH)
    db_ids = list(schemas.keys())
    db_texts = [schema_text(schemas[i]) for i in db_ids]

    vecs = embed_texts(db_texts, EMBED_MODEL)
    dims = len(vecs[0]) if vecs else 0
    now = datetime.now(timezone.utc).isoformat().replace("+00:00", "Z")

    with open(OUT_PATH, "w", encoding="utf-8") as f:
        for db_id, text, emb in zip(db_ids, db_texts, vecs):
            rec = {
                "db_id": db_id,
                "text": text,
                "embedding": emb,
                "model": EMBED_MODEL,
                "dims": dims,
                "updated_at": now,
            }
            f.write(json.dumps(rec, ensure_ascii=False) + "\n")
    print("finished")

if __name__ == "__main__":
    main()


finished


<b>DB selector</b>

In [98]:
#read embedding
def load_schema_embeddings(path: str) -> List[Dict]:
    items = []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if line:
                items.append(json.loads(line))
    for rec in items:
        v = rec["embedding"]
        n = math.sqrt(sum(x*x for x in v)) + 1e-12
        rec["_norm"] = [x / n for x in v]
    return items

SCHEMAS = load_schema_embeddings(OUT_PATH)

In [100]:
#generates a single embedding vector for the given text
def embed_query(text: str) -> List[float]:
    resp = client.embeddings.create(model=EMBED_MODEL, input=[text])
    return resp.data[0].embedding

#calculates the cosine similarity between vectors
def cosine(u: List[float], v: List[float]) -> float:
    nu = math.sqrt(sum(x*x for x in u)) + 1e-12
    uhat = [x/nu for x in u]
    return sum(x*y for x, y in zip(uhat, v))

#Return [(db_id, score)] highest to lowest.
def select_db(question: str, top_k: int = 3) -> List[Tuple[str, float]]:
    qv = embed_query(question)
    scores = []
    for rec in SCHEMAS:
        score = cosine(qv, rec["_norm"])
        scores.append((rec["db_id"], score))
    scores.sort(key=lambda t: t[1], reverse=True)
    return scores[:top_k]

In [102]:
if __name__ == "__main__":
    question = "What are the distinct creation years of the departments managed by a secretary born in state 'Alabama'?"
    top = select_db(question, top_k=3)
    print("Top DB candidates:")
    for db_id, score in top:
        print(f"{db_id:20s}  score={score:.4f}")
    topDb = top[0][0]
    print("Selected:", topDb)

Top DB candidates:
department_management  score=0.4146
election              score=0.3079
scientist_1           score=0.2533
Selected: department_management


<b>Select table</b>

In [104]:
# schema parsing 
_TABLE_RE = re.compile(r"\s*([A-Za-z0-9_]+)\s*\(([^)]*)\)\s*")

def lc_load_db_records(path: str) -> List[Dict]:
    out = []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            s = line.strip()
            if s:
                out.append(json.loads(s))
    return out
 

In [106]:
#Convert a db 'text' into per-table strings
def lc_explode_tables(rec: Dict) -> Dict[str, str]:
    db = rec["db_id"]
    out = {}
    for piece in rec["text"].split("|"):
        piece = piece.strip()
        m = _TABLE_RE.match(piece)
        if not m:
            continue
        tname = m.group(1)
        cols  = ", ".join(c.strip() for c in m.group(2).split(",") if c.strip())
        key = f"{db}.{tname}"
        out[key] = f"{db}.{tname}({cols})"
    return out

def lc_build_table_texts(schema_path: str) -> Dict[str, str]:
    tbl_texts: Dict[str, str] = {}
    for rec in lc_load_db_records(schema_path):
        tbl_texts.update(lc_explode_tables(rec))
    return tbl_texts
    

In [108]:
# FAISS index
def lc_build_or_load_index(table_texts: Dict[str, str], embeddings: OpenAIEmbeddings) -> FAISS:
    if os.path.isdir(LC_INDEX_DIR):
        return FAISS.load_local(LC_INDEX_DIR, embeddings, allow_dangerous_deserialization=True)
    texts = list(table_texts.values())
    metas = [{"table_key": k} for k in table_texts.keys()]
    vs = FAISS.from_texts(texts=texts, embedding=embeddings, metadatas=metas)
    vs.save_local(LC_INDEX_DIR)
    return vs


In [110]:
# retrieval (with version-safe scores) 
def lc_similarity_search_tables(
    vs: FAISS,
    question: str,
    table_texts: Dict[str, str],
    k: int = 8,
    restrict_db: str | None = None,
) -> List[Tuple[str, float]]:
    """Return [(table_key, score)] where score is distance/relevance (lower usually closer)."""
    pool_k = min(max(k * 3, 16), len(table_texts))
    if hasattr(vs, "similarity_search_with_score"):
        docs_scores = vs.similarity_search_with_score(question, k=pool_k)
    else:
        docs_scores = vs.similarity_search_with_relevance_scores(question, k=pool_k)
        docs_scores = [(doc, float(score)) for doc, score in docs_scores]

    if restrict_db:
        docs_scores = [
            (d, s) for (d, s) in docs_scores
            if d.metadata.get("table_key", "").split(".", 1)[0] == restrict_db
        ]
    docs_scores = docs_scores[:k]
    return [(d.metadata["table_key"], float(s)) for (d, s) in docs_scores]


In [112]:
# This function asks an LLM to pick the minimal set of tables needed to answer a natural-language SQL
def lc_minimal_tables(question: str, candidates: List[str], llm: ChatOpenAI) -> List[str]:
    sys_msg = (
        "You are a precise SQL planner. Given a natural-language question and a list of "
        "candidate tables, return ONLY the minimal set of tables needed to answer it. "
        "Prefer a single table if it contains the requested columns. Respond as strict JSON."
    )
    user_msg = (
        f"Question:\n{question}\n\n"
        f"Candidate tables:\n" + "\n".join(f"- {c}" for c in candidates) + "\n\n"
        'Return JSON exactly like: {"tables": ["db.table", "..."]}'
    )
    resp = llm.invoke([SystemMessage(content=sys_msg), HumanMessage(content=user_msg)])
    content = resp.content or ""
    try:
        data = json.loads(content)
    except Exception:
        start, end = content.find("{"), content.rfind("}")
        if start != -1 and end != -1 and end > start:
            try:
                data = json.loads(content[start:end+1])
            except Exception:
                return candidates[:1]
        else:
            return candidates[:1]

    chosen = [t for t in data.get("tables", []) if t in candidates]
    return chosen or candidates[:1]


In [114]:
# finding the most relevant database tables for answering a userâ€™s question
class LCTableSelector:
    def __init__(self, schema_jsonl: str = OUT_PATH):
        self.table_texts = lc_build_table_texts(schema_jsonl)
        if not self.table_texts:
            raise RuntimeError("No tables parsed from schema JSONL. Check file format.")
        self.embeddings = OpenAIEmbeddings(model=EMBED_MODEL)
        self.vs = lc_build_or_load_index(self.table_texts, self.embeddings)
        self.llm = ChatOpenAI(model=CHAT_MODEL)

    def select(self, question: str, db: str | None = None, k: int = 8, minimal: bool = True) -> Dict:
        cands = lc_similarity_search_tables(self.vs, question, self.table_texts, k=k, restrict_db=db)
        cand_keys = [t for t, _ in cands]
        result = {"question": question, "candidates": cands}
        result["selected_tables"] = (
            lc_minimal_tables(question, cand_keys, self.llm) if (minimal and cand_keys)
            else (cand_keys[:1] if cand_keys else [])
        )
        return result

In [116]:
# example 
if __name__ == "__main__":
    # run this file alone (do NOT import any LlamaIndex code in the same session)
    selector = LCTableSelector(OUT_PATH)

    q  = "How many rooms whose capacity is less than 50 does the Lamberton building have?"
    out = selector.select(q, db="college_2", k=8, minimal=True)

    print("\nQuestion:", out["question"])
    print("Top candidates (lower score = closer):")
    for t, s in out["candidates"]:
        print(f"  {t:35s}  score={s:.3f}")
    print("Selected minimal set:", out["selected_tables"])


Question: How many rooms whose capacity is less than 50 does the Lamberton building have?
Top candidates (lower score = closer):
  college_2.classroom                  score=1.077
  college_2.section                    score=1.412
  college_2.department                 score=1.446
Selected minimal set: ['college_2.classroom']


In [118]:
import time

q  = "What is the name of the school with smallest enrollment size per state"
db = "soccer_2"

def run_total_time(name, selector_cls, init_args, q, db, k=8, minimal=True):
    t_start = time.perf_counter()
    sel = selector_cls(*init_args)                 # build
    out = sel.select(q, db, k=k, minimal=minimal)  
    for t, s in out["candidates"]:
        print(f"  {t:35s}  score={s:.3f}")
    print("Selected minimal set:", out["selected_tables"])
    
    # query
    total_s = time.perf_counter() - t_start
    print(f"{name} total time: {total_s:.2f} s")
    return total_s

t2 = run_total_time("LangChain/FAISS", LCTableSelector, (OUT_PATH,),q, db)

print("\n=== Summary ===")
print(f"LangChain/FAISS total:{t2:.2f} s")

  soccer_2.College                     score=1.340
Selected minimal set: ['soccer_2.College']
LangChain/FAISS total time: 4.78 s

=== Summary ===
LangChain/FAISS total:4.78 s
