<a href="https://colab.research.google.com/github/dspraneeth07/AXIOM--SQL-REFLEX-AGENT-v4/blob/main/notebooks/01_schema_graph.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install -q networkx faiss-cpu sentence-transformers tqdm


[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m23.8/23.8 MB[0m [31m20.9 MB/s[0m eta [36m0:00:00[0m
[?25h

In [2]:
import networkx as nx
import faiss
import numpy as np
from sentence_transformers import SentenceTransformer

print("Libraries ready ✅")




Libraries ready ✅


In [11]:
!mkdir -p data/spider
!wget -q https://raw.githubusercontent.com/taoyds/spider/b7b5b8c890cd30e35427348bb9eb8c6d1350ca7c/evaluation_examples/examples/tables.json -O data/spider/tables.json
import json

# Load Spider's schema metadata (one entry per database). JSON is UTF-8 by
# specification, so state the encoding instead of relying on the platform's
# default locale encoding.
with open("data/spider/tables.json", "r", encoding="utf-8") as f:
    spider_tables = json.load(f)

print("Loaded tables:", len(spider_tables))




Loaded tables: 166


In [12]:
# Peek at the first database entry to see which keys tables.json provides.
example = spider_tables[0]
example.keys()


dict_keys(['column_names', 'column_names_original', 'column_types', 'db_id', 'foreign_keys', 'primary_keys', 'table_names', 'table_names_original'])

In [13]:
def build_schema_graph(spider_tables):
    """Build an undirected, table-level schema graph covering all Spider databases.

    Each table becomes a node keyed ``"<db_id>.<table>"`` with attributes
    ``db``, ``table`` and ``type="table"``. Each foreign-key pair adds an edge
    between the referencing and the referenced table with ``type="foreign_key"``.

    Args:
        spider_tables: list of Spider ``tables.json`` entries. Each entry is read
            for ``db_id``, ``table_names_original``, ``column_names_original``
            (``[table_idx, column_name]`` pairs) and ``foreign_keys``
            (``[fk_column_idx, referenced_column_idx]`` pairs).

    Returns:
        networkx.Graph: the combined graph for all databases.

    Notes:
        The graph is undirected, so FK direction is not recorded. Several FKs
        between the same pair of tables collapse into a single edge. A
        self-referencing FK produces a self-loop.
    """
    G = nx.Graph()

    for db in spider_tables:
        db_id = db["db_id"]
        tables = db["table_names_original"]
        columns = db["column_names_original"]

        # Add tables as nodes
        for t in tables:
            G.add_node(f"{db_id}.{t}", db=db_id, table=t, type="table")

        # Add FK → PK edges
        for fk_col, pk_col in db["foreign_keys"]:
            fk_table_idx = columns[fk_col][0]
            pk_table_idx = columns[pk_col][0]

            # Column entries with table index -1 (such as "*") belong to no
            # table. Indexing `tables` with -1 would silently pick the last table.
            if fk_table_idx < 0 or pk_table_idx < 0:
                continue

            G.add_edge(
                f"{db_id}.{tables[fk_table_idx]}",
                f"{db_id}.{tables[pk_table_idx]}",
                type="foreign_key",
            )

    return G


In [14]:
# Build one graph holding every table of every Spider database; edges are FK links.
schema_graph = build_schema_graph(spider_tables)

for label, count in (
    ("Total nodes (tables):", schema_graph.number_of_nodes()),
    ("Total edges (relationships):", schema_graph.number_of_edges()),
):
    print(label, count)


Total nodes (tables): 876
Total edges (relationships): 745


In [15]:
from collections import defaultdict

def add_column_cousage_edges(G, spider_tables):
    """Link every pair of tables within the same database with a co-usage edge.

    Despite the name, no column-level analysis is done. Every pair of tables in
    a database is connected, whatever columns they hold. Pairs that are already
    connected are left untouched, so they keep their existing edge attributes.
    For example, a ``type="foreign_key"`` edge from ``build_schema_graph`` is
    not relabelled as ``"co_usage"``.

    TODO(review): if the intent was true column co-usage (e.g. shared column
    names), derive the pairs from ``column_names_original`` instead.

    Args:
        G: graph mutated in place. It must provide ``has_edge`` and ``add_edge``
            (e.g. ``networkx.Graph``). Nodes are keyed ``"<db_id>.<table>"``.
        spider_tables: Spider ``tables.json`` entries. Only ``db_id`` and
            ``table_names_original`` are read.

    Returns:
        None. ``G`` is modified in place.
    """
    from itertools import combinations

    for db in spider_tables:
        db_id = db["db_id"]
        table_nodes = [f"{db_id}.{t}" for t in db["table_names_original"]]

        for u, v in combinations(table_nodes, 2):
            # add_edge on an existing edge overwrites its attributes, which would
            # erase the foreign_key type. Keep existing edges as they are.
            if not G.has_edge(u, v):
                G.add_edge(u, v, type="co_usage")

# Mutates schema_graph in place by adding intra-database table-pair edges.
add_column_cousage_edges(G=schema_graph, spider_tables=spider_tables)

print("Edges after co-usage:", schema_graph.number_of_edges())


Edges after co-usage: 3113


In [16]:
def build_table_texts(spider_tables):
    """Render one short natural-language description per table for embedding.

    Returns a pair ``(texts, table_ids)`` of equal-length lists:

    * ``texts[i]`` reads ``"Database <db>, Table <table>, Columns: c1, c2, ..."``.
    * ``table_ids[i]`` is the matching ``"<db_id>.<table>"`` key.

    Column entries whose table index is -1 (such as ``"*"``) are skipped.
    Columns keep their original order within each table.
    """
    texts, table_ids = [], []

    for entry in spider_tables:
        db_name = entry["db_id"]
        table_names = entry["table_names_original"]

        # Bucket column names under their owning table's position.
        cols_by_table = {pos: [] for pos, _ in enumerate(table_names)}
        for owner, col_name in entry["column_names_original"]:
            if owner == -1:
                continue
            cols_by_table[owner].append(col_name)

        for pos, table_name in enumerate(table_names):
            joined = ", ".join(cols_by_table[pos])
            texts.append(f"Database {db_name}, Table {table_name}, Columns: {joined}")
            table_ids.append(f"{db_name}.{table_name}")

    return texts, table_ids


In [17]:
# Embed one description per table and index the vectors for nearest-neighbour lookup.
texts, table_ids = build_table_texts(spider_tables)

embedder = SentenceTransformer("all-MiniLM-L6-v2")
# With unit-normalised vectors, inner product equals cosine similarity,
# which is what IndexFlatIP scores.
embeddings = embedder.encode(texts, normalize_embeddings=True)

dim = embeddings.shape[1]
index = faiss.IndexFlatIP(dim)
# FAISS's Python API expects a C-contiguous float32 matrix. Passing one
# explicitly avoids an extra copy and guards against other dtypes.
index.add(np.ascontiguousarray(embeddings, dtype=np.float32))

# Row i of the index must correspond to table_ids[i]; retrieval relies on this.
assert index.ntotal == len(table_ids), "FAISS rows and table_ids are misaligned"

print("FAISS index built ✅")
print("Indexed tables:", len(table_ids))


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

README.md: 0.00B [00:00, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/612 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/90.9M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/350 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

FAISS index built ✅
Indexed tables: 876


In [19]:
def retrieve_tables(question, k=5, *, model=None, faiss_index=None, ids=None):
    """Return ids of the tables whose descriptions best match ``question``.

    Args:
        question: natural-language query.
        k: maximum number of tables to return. It is clamped to the number of
            indexed rows; ``k <= 0`` returns an empty list.
        model: encoder providing ``encode(list_of_str, normalize_embeddings=True)``.
            Defaults to the notebook-level ``embedder``.
        faiss_index: index providing ``search`` and ``ntotal``. Defaults to the
            notebook-level ``index``.
        ids: table ids aligned row-for-row with ``faiss_index``. Defaults to the
            notebook-level ``table_ids``.

    Returns:
        list[str]: ``"<db_id>.<table>"`` keys, best match first.
    """
    model = embedder if model is None else model
    faiss_index = index if faiss_index is None else faiss_index
    ids = table_ids if ids is None else ids

    k = min(k, faiss_index.ntotal)
    if k <= 0:
        return []

    q_emb = model.encode([question], normalize_embeddings=True)
    _, idxs = faiss_index.search(np.asarray(q_emb, dtype=np.float32), k)

    # FAISS reports missing results as -1, and ids[-1] would silently
    # return the last table, so drop them.
    return [ids[i] for i in idxs[0] if i >= 0]
# Smoke test: tables retrieved for a sample question with the default k.
retrieve_tables("List students and their courses")


['student_transcripts_tracking.Student_Enrolment_Courses',
 'e_learning.Courses',
 'student_transcripts_tracking.Courses',
 'student_assessment.Courses',
 'e_learning.Student_Course_Enrolment']

In [21]:
import os

# Ensure the artifact directory exists before anything is written into it.
os.makedirs(name="outputs", exist_ok=True)


In [22]:
import os
import pickle

os.makedirs("outputs", exist_ok=True)

# Serialise the graph to bytes first, then write them in binary mode.
graph_bytes = pickle.dumps(schema_graph)
with open("outputs/schema_graph.pkl", "wb") as f:
    f.write(graph_bytes)

print("Saved schema_graph.pkl successfully")


Saved schema_graph.pkl successfully


In [23]:
import os
import pickle

# Save every artifact the retrieval step needs. The cell creates the directory
# itself so it works on its own. NOTE: unpickling runs arbitrary code, so
# only load these .pkl files from a trusted source.
os.makedirs("outputs", exist_ok=True)

# Row i of the FAISS index corresponds to table_ids[i]. Refuse to persist
# misaligned artifacts, which would silently corrupt retrieval when reloaded.
assert index.ntotal == len(table_ids), (
    f"FAISS index has {index.ntotal} rows but there are {len(table_ids)} table ids"
)

with open("outputs/schema_graph.pkl", "wb") as f:
    pickle.dump(schema_graph, f)

faiss.write_index(index, "outputs/schema_faiss.index")

with open("outputs/table_ids.pkl", "wb") as f:
    pickle.dump(table_ids, f)

print("Schema graph + FAISS saved ✅")


Schema graph + FAISS saved ✅
