In [None]:
#https://docs.llamaindex.ai/en/stable/examples/index_structs/struct_indices/SQLIndexDemo/

## Setup

In [None]:
# %pip install llama-index-llms-openai
# !pip install llama-index
!pip install matplotlib

In [3]:
import os
import openai

In [4]:
# Read the OpenAI key from the environment and fail fast with a clear message
# if it is missing. (Assigning os.getenv(...)'s None result back into
# os.environ would instead raise an opaque "str expected, not NoneType".)
_openai_api_key = os.getenv("OPENAI_API_KEY")
if not _openai_api_key:
    raise RuntimeError(
        "OPENAI_API_KEY is not set; export it in the environment before "
        "starting the kernel."
    )
openai.api_key = _openai_api_key

In [6]:
from IPython.display import Markdown, display

In [7]:
## Create Database Schema
# An in-memory SQLite database with a single table:
#   city_stats(city_name PRIMARY KEY, population, country NOT NULL)

from sqlalchemy import (
    Column,
    Integer,
    MetaData,
    String,
    Table,
    create_engine,
    select,
)

engine = create_engine("sqlite:///:memory:")
metadata_obj = MetaData()

table_name = "city_stats"
city_stats_table = Table(
    table_name,
    metadata_obj,
    Column("city_name", String(16), primary_key=True),
    Column("population", Integer),
    Column("country", String(16), nullable=False),
)

# Emit CREATE TABLE for every table registered on metadata_obj.
metadata_obj.create_all(bind=engine)

In [8]:
# Define the SQL database wrapper and the LLM, then load the sample rows.

from llama_index.core import SQLDatabase
from llama_index.llms.openai import OpenAI
from sqlalchemy import insert

llm = OpenAI(temperature=0.1, model="gpt-3.5-turbo")

# Restrict the llama-index wrapper to the city_stats table.
sql_database = SQLDatabase(engine, include_tables=["city_stats"])

rows = [
    {"city_name": "Toronto", "population": 2930000, "country": "Canada"},
    {"city_name": "Tokyo", "population": 13960000, "country": "Japan"},
    {
        "city_name": "Chicago",
        "population": 2679000,
        "country": "United States",
    },
    {"city_name": "Seoul", "population": 9776000, "country": "South Korea"},
]

# Load all rows in one transaction with a single executemany-style insert,
# rather than opening a new transaction per row. The table is cleared first
# so re-running this cell does not raise an IntegrityError on the
# city_name primary key.
with engine.begin() as connection:
    connection.execute(city_stats_table.delete())
    connection.execute(insert(city_stats_table), rows)

## Part 1 - Simple Text-to-SQL

Ask questions directly of a known table: the table name is passed to the query engine up front.

In [18]:
# Part 1 - Text-to-SQL query engine over a table named up front.
# The engine is given the city_stats table and the LLM defined above.

from llama_index.core.query_engine import NLSQLTableQueryEngine

query_engine = NLSQLTableQueryEngine(
    llm=llm,
    sql_database=sql_database,
    tables=["city_stats"],
)

query_str = "Which city has the highest population?"
response = query_engine.query(query_str)
display(Markdown(str(response)))

Tokyo has the highest population among all cities, with a population of 13,960,000.

## Part 2 - Query-Time Table Retrieval

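Instead of naming the table up front, index the table schemas and retrieve the most relevant one(s) at query time. This scales to databases with many tables; here only one table is indexed.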

In [10]:
# Part 2 - Query-time retrieval of tables for text-to-SQL.
# Table schemas are stored as objects in a vector index; at query time the
# retriever picks the most relevant schema(s) and only those are used for
# SQL generation.

from llama_index.core.indices.struct_store.sql_query import (
    SQLTableRetrieverQueryEngine,
)
from llama_index.core.objects import (
    SQLTableNodeMapping,
    ObjectIndex,
    SQLTableSchema,
)
from llama_index.core import VectorStoreIndex

# OPTIONAL: extra context attached to the table's schema. Leave as None to use
# the schema alone, or set a string to steer SQL generation, for example:
#   city_stats_context = (
#       "This table gives information regarding the population and country of a"
#       " given city.\nThe user will query with codewords, where 'foo' corresponds"
#       " to population and 'bar' corresponds to city."
#   )
city_stats_context = None

table_node_mapping = SQLTableNodeMapping(sql_database)
# Add one SQLTableSchema per table that should be retrievable.
table_schema_objs = [
    SQLTableSchema(table_name="city_stats", context_str=city_stats_context),
]

obj_index = ObjectIndex.from_objects(
    table_schema_objs,
    table_node_mapping,
    VectorStoreIndex,
)

# Tip: set logging to DEBUG for more detailed outputs.
# Pass the same llm as Part 1 so every part uses one model configuration;
# without it the engine uses the global default LLM.
query_engine = SQLTableRetrieverQueryEngine(
    sql_database,
    obj_index.as_retriever(similarity_top_k=1),
    llm=llm,
)

In [11]:
# Same question as Part 1, answered via the retrieved table schema.
response = query_engine.query(
    "Which city has the highest population?"
)
display(Markdown(str(response)))

Tokyo has the highest population among all cities, with a population of 13,960,000.

## Part 3 - SQL Retriever + Query Engine 

Break the query engine down into its two components.

The `NLSQLRetriever` translates the user's question into SQL, runs it, and returns the resulting rows as nodes. A `RetrieverQueryEngine` then takes those nodes and has the LLM synthesize a natural-language answer for the user.

In [14]:
# View the raw result of the generated SQL query. With return_raw=True the
# retriever returns the result rows as the text of a node.

from llama_index.core.response.notebook_utils import display_source_node
from llama_index.core.retrievers import NLSQLRetriever


# Pass the same llm as Part 1 so the generated SQL comes from one model
# configuration rather than the global default LLM.
nl_sql_retriever = NLSQLRetriever(
    sql_database, tables=["city_stats"], llm=llm, return_raw=True
)

results = nl_sql_retriever.retrieve(
    "Return the top 5 cities (along with their populations) with the highest population."
)

for n in results:
    display_source_node(n)

**Node ID:** 62287818-ce3f-4537-af40-c529ed180067<br>**Similarity:** None<br>**Text:** [('Tokyo', 13960000), ('Seoul', 9776000), ('Toronto', 2930000), ('Chicago', 2679000)]<br>

In [15]:
# View individual results of the SQL query. With return_raw=False each result
# row becomes its own node: the text is empty and the column values are
# stored in the node's metadata.

# Same llm as Part 1, for a consistent model configuration.
nl_sql_retriever = NLSQLRetriever(
    sql_database, tables=["city_stats"], llm=llm, return_raw=False
)

results = nl_sql_retriever.retrieve(
    "Return the top 5 cities (along with their populations) with the highest population."
)

# NOTE: all the content is in the metadata, so show it explicitly.
for n in results:
    display_source_node(n, show_source_metadata=True)

**Node ID:** 3a364b05-1b18-4ed7-b07b-bcd54e115365<br>**Similarity:** None<br>**Text:** <br>**Metadata:** {'city_name': 'Tokyo', 'population': 13960000}<br>

**Node ID:** a029de07-99f8-4649-96e8-5be777ed2b66<br>**Similarity:** None<br>**Text:** <br>**Metadata:** {'city_name': 'Seoul', 'population': 9776000}<br>

**Node ID:** dff385b4-24e4-428b-9246-f171d0bbd532<br>**Similarity:** None<br>**Text:** <br>**Metadata:** {'city_name': 'Toronto', 'population': 2930000}<br>

**Node ID:** b09ea9e9-5e5c-407c-88e0-140c6c7ada17<br>**Similarity:** None<br>**Text:** <br>**Metadata:** {'city_name': 'Chicago', 'population': 2679000}<br>

In [16]:
# Plug the row-level SQL retriever into a query engine, which has the LLM
# synthesize a natural-language answer from the retrieved rows.

from llama_index.core.query_engine import RetrieverQueryEngine

# Same llm as Part 1, so response synthesis uses one model configuration.
query_engine = RetrieverQueryEngine.from_args(nl_sql_retriever, llm=llm)

response = query_engine.query(
    "Return the top 5 cities (along with their populations) with the highest population."
)

# Render as Markdown, consistent with Parts 1 and 2.
display(Markdown(str(response)))

The top 5 cities with the highest populations are:

1. Tokyo - 13,960,000
2. Seoul - 9,776,000
3. Toronto - 2,930,000
4. Chicago - 2,679,000
