In [1]:
from sqlalchemy import (
    create_engine,
    MetaData,
    Table,
    Column,
    String,
    Integer,
    select,
    insert,
    column
)

In [2]:
engine = create_engine("sqlite+pysqlite:///:memory:", echo=True)
metadata_obj = MetaData()

In [7]:
city_states_table = Table(
    "city_states",
    metadata_obj,
    Column("city_name", String(16), primary_key=True),
    Column("population", Integer),
    Column("country", String(16), nullable=False)
)
metadata_obj.create_all(engine)

2023-11-26 14:03:23,529 INFO sqlalchemy.engine.Engine BEGIN (implicit)
2023-11-26 14:03:23,530 INFO sqlalchemy.engine.Engine PRAGMA main.table_info("city_states")
2023-11-26 14:03:23,531 INFO sqlalchemy.engine.Engine [raw sql] ()
2023-11-26 14:03:23,532 INFO sqlalchemy.engine.Engine PRAGMA temp.table_info("city_states")
2023-11-26 14:03:23,532 INFO sqlalchemy.engine.Engine [raw sql] ()
2023-11-26 14:03:23,533 INFO sqlalchemy.engine.Engine 
CREATE TABLE city_states (
	city_name VARCHAR(16) NOT NULL, 
	population INTEGER, 
	country VARCHAR(16) NOT NULL, 
	PRIMARY KEY (city_name)
)


2023-11-26 14:03:23,534 INFO sqlalchemy.engine.Engine [no key 0.00041s] ()
2023-11-26 14:03:23,535 INFO sqlalchemy.engine.Engine COMMIT


In [8]:
rows = [
    {"city_name": "Toronto", "population": 2731571, "country": "Canada"},
    {"city_name": "Tokyo", "population": 13929286, "country": "Japan"},
    {"city_name": "Berlin", "population": 600000, "country": "Germany"},
]

for row in rows:
    stmt = insert(city_states_table).values(**row)
    with engine.begin() as connection:
        cursor = connection.execute(stmt)

2023-11-26 14:03:26,959 INFO sqlalchemy.engine.Engine BEGIN (implicit)
2023-11-26 14:03:26,961 INFO sqlalchemy.engine.Engine INSERT INTO city_states (city_name, population, country) VALUES (?, ?, ?)
2023-11-26 14:03:26,961 INFO sqlalchemy.engine.Engine [generated in 0.00059s] ('Toronto', 2731571, 'Canada')
2023-11-26 14:03:26,962 INFO sqlalchemy.engine.Engine COMMIT
2023-11-26 14:03:26,963 INFO sqlalchemy.engine.Engine BEGIN (implicit)
2023-11-26 14:03:26,963 INFO sqlalchemy.engine.Engine INSERT INTO city_states (city_name, population, country) VALUES (?, ?, ?)
2023-11-26 14:03:26,964 INFO sqlalchemy.engine.Engine [cached since 0.002948s ago] ('Tokyo', 13929286, 'Japan')
2023-11-26 14:03:26,964 INFO sqlalchemy.engine.Engine COMMIT
2023-11-26 14:03:26,965 INFO sqlalchemy.engine.Engine BEGIN (implicit)
2023-11-26 14:03:26,965 INFO sqlalchemy.engine.Engine INSERT INTO city_states (city_name, population, country) VALUES (?, ?, ?)
2023-11-26 14:03:26,966 INFO sqlalchemy.engine.Engine [cache

In [5]:
import os
from dotenv import load_dotenv
import openai

load_dotenv()
openai.api_key = os.getenv("OPENAI_API_KEY")

In [9]:
from llama_index import SQLDatabase, VectorStoreIndex

sql_database = SQLDatabase(engine, include_tables=["city_states"])

2023-11-26 14:03:32,753 INFO sqlalchemy.engine.Engine BEGIN (implicit)
2023-11-26 14:03:32,754 INFO sqlalchemy.engine.Engine SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite~_%' ESCAPE '~' ORDER BY name
2023-11-26 14:03:32,754 INFO sqlalchemy.engine.Engine [raw sql] ()
2023-11-26 14:03:32,756 INFO sqlalchemy.engine.Engine ROLLBACK
2023-11-26 14:03:32,756 INFO sqlalchemy.engine.Engine BEGIN (implicit)
2023-11-26 14:03:32,756 INFO sqlalchemy.engine.Engine SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite~_%' ESCAPE '~' ORDER BY name
2023-11-26 14:03:32,757 INFO sqlalchemy.engine.Engine [raw sql] ()
2023-11-26 14:03:32,757 INFO sqlalchemy.engine.Engine SELECT name FROM sqlite_temp_master WHERE type='table' AND name NOT LIKE 'sqlite~_%' ESCAPE '~' ORDER BY name
2023-11-26 14:03:32,758 INFO sqlalchemy.engine.Engine [raw sql] ()
2023-11-26 14:03:32,759 INFO sqlalchemy.engine.Engine PRAGMA main.table_xinfo("city_states")
2023-11-26 14:03:32

#### Straight Forward way to Query a Table

In [12]:
from llama_index.indices.struct_store import NLSQLTableQueryEngine

query_engine = NLSQLTableQueryEngine(
    sql_database=sql_database,
    tables=["city_states"]
)

In [13]:
query_str = "Which city has the highest population?"
response = query_engine.query(query_str)

2023-11-26 00:56:05,996 INFO sqlalchemy.engine.Engine BEGIN (implicit)
2023-11-26 00:56:05,996 INFO sqlalchemy.engine.Engine PRAGMA main.table_xinfo("city_states")
2023-11-26 00:56:05,997 INFO sqlalchemy.engine.Engine [raw sql] ()
2023-11-26 00:56:05,997 INFO sqlalchemy.engine.Engine ROLLBACK
2023-11-26 00:56:05,998 INFO sqlalchemy.engine.Engine BEGIN (implicit)
2023-11-26 00:56:05,998 INFO sqlalchemy.engine.Engine PRAGMA main.foreign_key_list("city_states")
2023-11-26 00:56:05,998 INFO sqlalchemy.engine.Engine [raw sql] ()
2023-11-26 00:56:05,998 INFO sqlalchemy.engine.Engine PRAGMA temp.foreign_key_list("city_states")
2023-11-26 00:56:05,999 INFO sqlalchemy.engine.Engine [raw sql] ()
2023-11-26 00:56:05,999 INFO sqlalchemy.engine.Engine SELECT sql FROM  (SELECT * FROM sqlite_master UNION ALL   SELECT * FROM sqlite_temp_master) WHERE name = ? AND type in ('table', 'view')
2023-11-26 00:56:05,999 INFO sqlalchemy.engine.Engine [raw sql] ('city_states',)
2023-11-26 00:56:06,000 INFO sqla

In [14]:
print(response.response)

The city with the highest population is Tokyo.


#### Building the table index

In [16]:
from llama_index.objects import (
    SQLTableNodeMapping,
    ObjectIndex,
    SQLTableSchema
)

table_node_mapping = SQLTableNodeMapping(sql_database)
table_schema_objs = [
    (SQLTableSchema(table_name="city_states"))
]

obj_index = ObjectIndex.from_objects(
    table_schema_objs,
    table_node_mapping,
    VectorStoreIndex
)

In [17]:
# manually set extra context text
city_stats_text = (
    "This table gives information regarding the population and country of a given city.\n"
    "The user will query with codewords, where 'foo' corresponds to population and 'bar'"
    "corresponds to city."
)

table_node_mapping = SQLTableNodeMapping(sql_database)
table_schema_objs = [
    (SQLTableSchema(table_name="city_stats", context_str=city_stats_text))
]


In [18]:
from llama_index.indices.struct_store import SQLTableRetrieverQueryEngine

query_engine = SQLTableRetrieverQueryEngine(
    sql_database, obj_index.as_retriever(similarity_top_k=1)
)

In [19]:
response = query_engine.query("Which city has the highest population?")

KeyError: 'context'