# SQL Index Guide (Many Tables)

This guide extends the [core SQL guide](SQLIndexDemo.ipynb) to the setting where your database has a large number of tables, so putting every table schema into the text-to-SQL prompt may exceed the LLM's context window. Instead, we store the table schemas in a LlamaIndex data structure: an `ObjectIndex` of `SQLTableSchema` objects, backed by a vector index.

We can then query this data structure for the appropriate context.
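The core idea can be sketched without LlamaIndex. Treat each table's schema description as a document, score it against the question, and keep only the top-k tables for the prompt. The sketch below is a minimal illustration under stated assumptions: it swaps the embedding model for a simple bag-of-words cosine similarity, and `tokenize`, `cosine`, and `retrieve_tables` are hypothetical helpers, not LlamaIndex APIs.

```python
import math
import re
from collections import Counter


def tokenize(text: str) -> Counter:
    """Lowercase word counts; a crude stand-in for an embedding (assumption)."""
    return Counter(re.findall(r"[a-z0-9]+", text.lower()))


def cosine(a: Counter, b: Counter) -> float:
    """Cosine similarity between two sparse word-count vectors."""
    dot = sum(a[w] * b[w] for w in a.keys() & b.keys())
    norm = math.sqrt(sum(v * v for v in a.values())) * math.sqrt(sum(v * v for v in b.values()))
    return dot / norm if norm else 0.0


def retrieve_tables(question: str, table_descs: dict[str, str], top_k: int = 1) -> list[str]:
    """Return the names of the top_k tables whose descriptions best match the question."""
    q = tokenize(question)
    ranked = sorted(
        table_descs,
        key=lambda name: cosine(q, tokenize(table_descs[name])),
        reverse=True,
    )
    return ranked[:top_k]


# Hypothetical schema descriptions for a tiny database.
table_descs = {
    "city_stats": "Table city_stats has columns: city_name, population, country",
    "tmp_table_0": "Table tmp_table_0 has columns: tmp_field_0_1, tmp_field_0_2, tmp_field_0_3",
}
print(retrieve_tables("Which city has the highest population?", table_descs))
```

In LlamaIndex the vector index plays the role of `retrieve_tables`, using real embeddings instead of word overlap.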

In [1]:
import logging
import sys

logging.basicConfig(stream=sys.stdout, level=logging.INFO)
logging.getLogger().addHandler(logging.StreamHandler(stream=sys.stdout))

In [2]:
from llama_index import SQLDatabase


### Create Database Schema + Test Data

Here we introduce a toy scenario with 101 tables (`city_stats` plus 100 dummy tables), too many to fit all of their schemas into the prompt.

In [3]:
from sqlalchemy import create_engine, MetaData, Table, Column, String, Integer

In [4]:
engine = create_engine("sqlite:///:memory:")
metadata_obj = MetaData()

In [5]:
# create city SQL table
table_name = "city_stats"
city_stats_table = Table(
    table_name,
    metadata_obj,
    Column("city_name", String(16), primary_key=True),
    Column("population", Integer),
    Column("country", String(16), nullable=False),
)
all_table_names = ["city_stats"]
# create a ton of dummy tables
n = 100
for i in range(n):
    tmp_table_name = f"tmp_table_{i}"
    tmp_table = Table(
        tmp_table_name,
        metadata_obj,
        Column(f"tmp_field_{i}_1", String(16), primary_key=True),
        Column(f"tmp_field_{i}_2", Integer),
        Column(f"tmp_field_{i}_3", String(16), nullable=False),
    )
    all_table_names.append(f"tmp_table_{i}")

metadata_obj.create_all(engine)

In [6]:
# print tables
metadata_obj.tables.keys()

dict_keys(['city_stats', 'tmp_table_0', 'tmp_table_1', 'tmp_table_2', 'tmp_table_3', 'tmp_table_4', 'tmp_table_5', 'tmp_table_6', 'tmp_table_7', 'tmp_table_8', 'tmp_table_9', 'tmp_table_10', 'tmp_table_11', 'tmp_table_12', 'tmp_table_13', 'tmp_table_14', 'tmp_table_15', 'tmp_table_16', 'tmp_table_17', 'tmp_table_18', 'tmp_table_19', 'tmp_table_20', 'tmp_table_21', 'tmp_table_22', 'tmp_table_23', 'tmp_table_24', 'tmp_table_25', 'tmp_table_26', 'tmp_table_27', 'tmp_table_28', 'tmp_table_29', 'tmp_table_30', 'tmp_table_31', 'tmp_table_32', 'tmp_table_33', 'tmp_table_34', 'tmp_table_35', 'tmp_table_36', 'tmp_table_37', 'tmp_table_38', 'tmp_table_39', 'tmp_table_40', 'tmp_table_41', 'tmp_table_42', 'tmp_table_43', 'tmp_table_44', 'tmp_table_45', 'tmp_table_46', 'tmp_table_47', 'tmp_table_48', 'tmp_table_49', 'tmp_table_50', 'tmp_table_51', 'tmp_table_52', 'tmp_table_53', 'tmp_table_54', 'tmp_table_55', 'tmp_table_56', 'tmp_table_57', 'tmp_table_58', 'tmp_table_59', 'tmp_table_60', 'tmp_tabl

We insert some test data into the `city_stats` table.

In [7]:
from sqlalchemy import insert

rows = [
    {"city_name": "Toronto", "population": 2930000, "country": "Canada"},
    {"city_name": "Tokyo", "population": 13960000, "country": "Japan"},
    {"city_name": "Chicago", "population": 2679000, "country": "United States"},
    {"city_name": "Seoul", "population": 9776000, "country": "South Korea"},
]
# Insert all rows in one transaction; engine.begin() commits when the block exits.
with engine.begin() as connection:
    connection.execute(insert(city_stats_table), rows)

In [8]:
with engine.connect() as connection:
    cursor = connection.exec_driver_sql("SELECT * FROM city_stats")
    print(cursor.fetchall())

[('Toronto', 2930000, 'Canada'), ('Tokyo', 13960000, 'Japan'), ('Chicago', 2679000, 'United States'), ('Seoul', 9776000, 'South Korea')]


### Using LlamaIndex to Store Table Schema Context

In [9]:
from llama_index.indices.struct_store import SQLTableRetrieverQueryEngine
from llama_index.objects import SQLTableNodeMapping, ObjectIndex, SQLTableSchema
from llama_index import VectorStoreIndex

In [10]:
sql_database = SQLDatabase(engine)

In [None]:
sql_database.table_info

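We wrap each table name in a `SQLTableSchema` object and index these objects with an `ObjectIndex` backed by a `VectorStoreIndex`. `SQLTableNodeMapping` converts each schema object into a node describing that table, so the schemas are embedded and can later be retrieved by similarity to the query.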
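For intuition about the kind of text that gets embedded, here is a minimal stdlib sketch that turns a table's schema into a one-line description using SQLite's `pragma_table_info`. The output format is modeled on the `Table desc str` log line printed when the query runs (minus the foreign-key clause); the helper `describe_table` is hypothetical, and this is not necessarily how `SQLTableNodeMapping` builds its nodes.

```python
import sqlite3


def describe_table(conn: sqlite3.Connection, table: str) -> str:
    """Build a one-line schema description from SQLite's table_info pragma."""
    # Each row is (cid, name, declared_type, notnull, default_value, pk).
    rows = conn.execute("SELECT * FROM pragma_table_info(?)", (table,)).fetchall()
    cols = ", ".join(f"{name} ({col_type})" for _cid, name, col_type, *_rest in rows)
    return f"Table '{table}' has columns: {cols}."


conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE city_stats ("
    "city_name VARCHAR(16) PRIMARY KEY, population INTEGER, country VARCHAR(16) NOT NULL)"
)
print(describe_table(conn, "city_stats"))
```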

In [12]:
table_node_mapping = SQLTableNodeMapping(sql_database)

# One SQLTableSchema object per table; the mapping turns each into an embeddable node.
table_schema_objs = [SQLTableSchema(table_name=name) for name in all_table_names]

obj_index = ObjectIndex.from_objects(
    table_schema_objs,
    table_node_mapping,
    VectorStoreIndex,
)

INFO:llama_index.token_counter.token_counter:> [build_index_from_nodes] Total LLM token usage: 0 tokens
> [build_index_from_nodes] Total LLM token usage: 0 tokens
INFO:llama_index.token_counter.token_counter:> [build_index_from_nodes] Total embedding token usage: 6848 tokens
> [build_index_from_nodes] Total embedding token usage: 6848 tokens


### Query Index

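Here we show a natural language query.
1. We construct a `SQLTableRetrieverQueryEngine`, passing in the `ObjectRetriever` returned by `obj_index.as_retriever(...)` so that the relevant table schema is retrieved dynamically at query time. With `similarity_top_k=1`, only the single most similar table's schema is included in the prompt.
2. We run queries against the query engine.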
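Conceptually, the query engine places only the retrieved table schemas into the text-to-SQL prompt. The sketch below illustrates that step under assumptions: `build_prompt` and its wording are hypothetical and do not reproduce LlamaIndex's actual text-to-SQL template, and the retrieved table names are passed in directly rather than coming from an `ObjectRetriever`.

```python
def build_prompt(question: str, table_descs: dict[str, str], retrieved: list[str]) -> str:
    """Assemble a text-to-SQL prompt containing only the retrieved tables' schemas.

    The wording is a hypothetical placeholder, not LlamaIndex's actual template.
    """
    schema_block = "\n".join(table_descs[name] for name in retrieved)
    return (
        "Given an input question, write a syntactically correct SQLite query.\n"
        f"Only use the following tables:\n{schema_block}\n\n"
        f"Question: {question}\nSQLQuery:"
    )


# Hypothetical descriptions for all 101 tables in the toy database.
table_descs = {
    f"tmp_table_{i}": (
        f"Table 'tmp_table_{i}' has columns: tmp_field_{i}_1 (VARCHAR(16)), "
        f"tmp_field_{i}_2 (INTEGER), tmp_field_{i}_3 (VARCHAR(16))."
    )
    for i in range(100)
}
table_descs["city_stats"] = (
    "Table 'city_stats' has columns: city_name (VARCHAR(16)), "
    "population (INTEGER), country (VARCHAR(16))."
)

question = "Which city has the highest population?"
# With similarity_top_k=1 the retriever returns one table, so one schema reaches the prompt.
prompt = build_prompt(question, table_descs, retrieved=["city_stats"])
# Without retrieval, every schema would have to be included.
full_prompt = build_prompt(question, table_descs, retrieved=list(table_descs))
print(len(prompt), "vs", len(full_prompt), "characters")
```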

In [13]:
query_engine = SQLTableRetrieverQueryEngine(
    sql_database,
    obj_index.as_retriever(similarity_top_k=1),
)

In [14]:
response = query_engine.query("Which city has the highest population?")

INFO:llama_index.token_counter.token_counter:> [retrieve] Total LLM token usage: 0 tokens
> [retrieve] Total LLM token usage: 0 tokens
INFO:llama_index.token_counter.token_counter:> [retrieve] Total embedding token usage: 7 tokens
> [retrieve] Total embedding token usage: 7 tokens
INFO:llama_index.indices.struct_store.sql_query:> Table desc str: Table 'city_stats' has columns: city_name (VARCHAR(16)), population (INTEGER), country (VARCHAR(16)) and foreign keys: .
> Table desc str: Table 'city_stats' has columns: city_name (VARCHAR(16)), population (INTEGER), country (VARCHAR(16)) and foreign keys: .
INFO:llama_index.token_counter.token_counter:> [query] Total LLM token usage: 335 tokens
> [query] Total LLM token usage: 335 tokens
INFO:llama_index.token_counter.token_counter:> [query] Total embedding token usage: 0 tokens
> [query] Total embedding token usage: 0 tokens


In [15]:
str(response)

' Tokyo has the highest population with 13,960,000 people.'

In [16]:
response.metadata

{'result': [('Tokyo', 13960000)],
 'sql_query': 'SELECT city_name, population \nFROM city_stats \nORDER BY population DESC \nLIMIT 1;'}
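Because `response.metadata` exposes the generated `sql_query`, you can sanity-check it independently. A minimal stdlib sketch, assuming you recreate the same four rows in a fresh SQLite database (this bypasses SQLAlchemy and LlamaIndex entirely):

```python
import sqlite3

# Rebuild the toy city_stats table with the same rows inserted earlier.
conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE city_stats ("
    "city_name VARCHAR(16) PRIMARY KEY, population INTEGER, country VARCHAR(16) NOT NULL)"
)
conn.executemany(
    "INSERT INTO city_stats VALUES (?, ?, ?)",
    [
        ("Toronto", 2930000, "Canada"),
        ("Tokyo", 13960000, "Japan"),
        ("Chicago", 2679000, "United States"),
        ("Seoul", 9776000, "South Korea"),
    ],
)

# The SQL string reported in response.metadata["sql_query"].
sql_query = "SELECT city_name, population \nFROM city_stats \nORDER BY population DESC \nLIMIT 1;"
result = conn.execute(sql_query).fetchall()
print(result)  # [('Tokyo', 13960000)]
```

The result matches the `result` entry in `response.metadata`.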