<a href="https://colab.research.google.com/github/jerryjliu/llama_index/blob/main/docs/docs/examples/index_structs/struct_indices/SQLIndexDemo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Text-to-SQL Guide (Query Engine + Retriever)

This is a basic guide to LlamaIndex's Text-to-SQL capabilities.
1. We first show how to perform text-to-SQL over a toy dataset: this will do "retrieval" (SQL query over the database) and "synthesis".
2. We then show how to build a TableIndex over the schema to dynamically retrieve relevant tables at query time.
3. Finally, we show how to define a text-to-SQL retriever on its own.
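The two stages in step 1 can be illustrated without LlamaIndex. This is a minimal sketch using the standard-library `sqlite3` module. `fake_text_to_sql` is a hypothetical stand-in for the LLM call that turns a question into SQL, and the "synthesis" step here is a plain string template rather than a second LLM call.

```python
import sqlite3


def fake_text_to_sql(question: str) -> str:
    """Hypothetical stand-in for the LLM: maps one known question to SQL."""
    if "highest population" in question:
        return "SELECT city_name FROM city_stats ORDER BY population DESC LIMIT 1"
    raise ValueError(f"no SQL known for: {question!r}")


def answer(conn: sqlite3.Connection, question: str) -> str:
    sql = fake_text_to_sql(question)      # retrieval, part 1: text -> SQL
    rows = conn.execute(sql).fetchall()   # retrieval, part 2: run the SQL
    # synthesis: just a template here; LlamaIndex asks the LLM to phrase the answer
    return f"Question: {question}\nSQL: {sql}\nResult: {rows}"


conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE city_stats (city_name TEXT PRIMARY KEY, population INTEGER, country TEXT)"
)
conn.executemany(
    "INSERT INTO city_stats VALUES (?, ?, ?)",
    [("Toronto", 2930000, "Canada"), ("Tokyo", 13960000, "Japan")],
)
print(answer(conn, "Which city has the highest population?"))
```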

If you're opening this notebook on Colab, you will probably need to install LlamaIndex 🦙.

In [None]:
%pip install llama-index-llms-openai

In [None]:
!pip install llama-index

In [None]:
import os
import openai

In [None]:
os.environ["OPENAI_API_KEY"] = "sk-.."
openai.api_key = os.environ["OPENAI_API_KEY"]

In [1]:
from IPython.display import Markdown, display

import logging
import sys

logging.basicConfig(level=logging.DEBUG)
# logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
# logging.getLogger().addHandler(logging.StreamHandler(stream=sys.stdout))

### Create Database Schema

We use `sqlalchemy`, a popular SQL database toolkit, to create an empty `city_stats` table.

In [2]:
from sqlalchemy import (
    create_engine,
    MetaData,
    Table,
    Column,
    String,
    Integer,
    select,
)

In [3]:
engine = create_engine("sqlite:///:memory:")
metadata_obj = MetaData()

In [4]:
# create city SQL table
table_name = "city_stats"
city_stats_table = Table(
    table_name,
    metadata_obj,
    Column("city_name", String(16), primary_key=True),
    Column("population", Integer),
    Column("country", String(16), nullable=False),
)
metadata_obj.create_all(engine)

### Define SQL Database

We first define our `SQLDatabase` abstraction (a light wrapper around SQLAlchemy). 
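Part of what this wrapper provides is a text description of each table that goes into the LLM prompt (the prompt used later in this notebook contains a line of the form `Table 'city_stats' has columns: ...`). The sketch below renders such a line from SQLite's introspection functions. It is illustrative only: the layout copies the line seen in this notebook, not LlamaIndex's implementation, and `describe_table` is a hypothetical helper.

```python
import sqlite3


def describe_table(conn: sqlite3.Connection, table: str) -> str:
    """Render one table schema as a single prompt line (illustrative layout)."""
    # pragma_table_info yields (cid, name, type, notnull, dflt_value, pk) per column
    cols = conn.execute(
        "SELECT name, type FROM pragma_table_info(?) ORDER BY cid", (table,)
    ).fetchall()
    # pragma_foreign_key_list yields one row per foreign-key column
    fks = conn.execute(
        'SELECT "from", "table", "to" FROM pragma_foreign_key_list(?)', (table,)
    ).fetchall()
    col_str = ", ".join(f"{name} ({col_type})" for name, col_type in cols)
    fk_str = ", ".join(f"{src} -> {ref}.{dst}" for src, ref, dst in fks)
    return f"Table '{table}' has columns: {col_str}, and foreign keys: {fk_str}."


conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE city_stats ("
    "city_name VARCHAR(16) PRIMARY KEY, population INTEGER, "
    "country VARCHAR(16) NOT NULL)"
)
print(describe_table(conn, "city_stats"))
```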

In [1]:
from llama_index.core import SQLDatabase
# from llama_index.llms.openai import OpenAI



In [1]:
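import os

# Optional: Hugging Face cache directory and a download mirror (hf-mirror.com)
os.environ["HF_HOME"] = "/home/jeffye/.cache/huggingface_speed/"
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"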
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
# os.environ['HTTP_PROXY'] = "192.168.1.45:10809"
# os.environ['HTTPS_PROXY'] = "192.168.1.45:10809"
# os.environ["HF_HUB_CACHE"] = "/home/jeffye/.cache/huggingface_hub"


from llama_index.core import Settings
from llama_index.llms.huggingface import HuggingFaceLLM

from llama_index.core.prompts.base import PromptTemplate
import torch
from functools import partial

# logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
# logging.getLogger().addHandler(logging.StreamHandler(stream=sys.stdout))

# os.environ["MODELSCOPE_CACHE"] = "D:/dataset/cache/modelscope/"
# from llama_index.llms.modelscope import ModelScopeLLM
from llama_index.embeddings.huggingface import HuggingFaceEmbedding

from llama_index.core.callbacks import (
    CallbackManager,
    LlamaDebugHandler,
    CBEventType,
)

# from llama_index.callbacks.aim import AimCallback

# aim_callback = AimCallback(repo="./")
# callback_manager = CallbackManager([aim_callback])


llama_debug = LlamaDebugHandler(print_trace_on_end=True)
# callback_manager = CallbackManager([llama_debug, aim_callback])
callback_manager = CallbackManager([llama_debug])

system_prompt = """<|SYSTEM|># StableLM Tuned (Alpha version)
- StableLM is a helpful and harmless open-source AI language model developed by StabilityAI.
- StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
- StableLM is more than just an information source, StableLM is also able to write poetry, short stories, and make jokes.
- StableLM will refuse to participate in anything that could harm a human.
"""
# system_prompt = "you are a very helpful assistant";

# This will wrap the default prompts that are internal to llama-index
# query_wrapper_prompt = PromptTemplate("<|USER|>{query_str}<|ASSISTANT|>")
# query_wrapper_prompt = PromptTemplate("{query_str}\n<|im_start|>assistant")
# <|im_start|>assistant

use_qianfan = True  # True: Baidu Qianfan (ERNIE) via LangChain; False: local Qwen1.5-7B-Chat via Hugging Face

if not use_qianfan:
    llm = HuggingFaceLLM(
        context_window=4096,
        max_new_tokens=256,
        generate_kwargs={"temperature": 0.2, "do_sample": True},
        # system_prompt=system_prompt,
        # query_wrapper_prompt=query_wrapper_prompt,
        tokenizer_name="Qwen/Qwen1.5-7B-Chat",
        model_name="Qwen/Qwen1.5-7B-Chat",
        device_map="auto",
        # stopping_ids=[50278, 50279, 50277, 1, 0],
        # stopping_ids=[151645, 151644][:1],
        tokenizer_kwargs={"max_length": 4096, "trust_remote_code": True},
        # float16 weights use half the memory of float32 (intended for CUDA GPUs)
        model_kwargs={"torch_dtype": torch.float16, "trust_remote_code": True},
        callback_manager=callback_manager,
        is_chat_model=True,
    )

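    # Workaround: force add_generation_prompt=True so the chat template ends with
    # the assistant turn marker and the model starts its reply from there
    llm._tokenizer.apply_chat_template = partial(
        llm._tokenizer.apply_chat_template, add_generation_prompt=True
    )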
    
else:
    from langchain_community.llms import QianfanLLMEndpoint
    from llama_index.llms.langchain import LangChainLLM

    # Baidu Qianfan API key and secret key: replace with your own credentials
    os.environ["QIANFAN_AK"] = "<your-qianfan-ak>"
    os.environ["QIANFAN_SK"] = "<your-qianfan-sk>"

    llm = QianfanLLMEndpoint(model="ERNIE-4.0-8K", streaming=False)
    # llm = QianfanLLMEndpoint(model="ERNIE-3.5-8K", streaming=False)
    llm = LangChainLLM(llm=llm)

Settings.llm = llm


In [3]:
# Settings.embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-small-en-v1.5")
Settings.embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-m3")



In [3]:
# Send a chat request as a list of messages
from llama_index.core.base.llms.types import MessageRole, ChatMessage

messages = [
    ChatMessage(
        role=MessageRole.SYSTEM, content="You are a helpful assistant."
    ),
    ChatMessage(role=MessageRole.USER, content="who are you?"),
]
resp = llm.chat(messages, max_length=4000)
print(resp)



assistant: system: You are a helpful assistant.

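user: Who are you?

assistant: Hello, I am ERNIE Bot (文心一言), a knowledge-enhanced large language model developed by Baidu. I can hold conversations, answer questions, and help with writing, helping people get information, knowledge, and inspiration quickly and easily.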


In [10]:
query = "Given an input question, first create a syntactically correct sqlite query to run, then look at the results of the query and return the answer. You can order the results by a relevant column to return the most interesting examples in the database.\n\nNever query for all the columns from a specific table, only ask for a few relevant columns given the question.\n\nPay attention to use only the column names that you can see in the schema description. Be careful to not query for columns that do not exist. Pay attention to which column is in which table. Also, qualify column names with the table name when needed. You are required to use the following format, each taking one line:\n\nQuestion: Question here\nSQLQuery: SQL Query to run\nSQLResult: Result of the SQLQuery\nAnswer: Final answer here\n\nOnly use tables listed below.\nTable 'city_stats' has columns: city_name (VARCHAR(16)), population (INTEGER), country (VARCHAR(16)), and foreign keys: .\n\nQuestion: Which city has the highest population?\nSQLQuery: "

messages = [
    ChatMessage(
        role=MessageRole.SYSTEM, content="You are a helpful assistant."
    ),
    ChatMessage(role=MessageRole.USER, content=query),
]
resp = llm.chat(messages, max_length=4000)
print(resp)




assistant: SQLQuery: SELECT city_name, population FROM city_stats ORDER BY population DESC LIMIT 1;
SQLResult: It's not possible to provide the exact result as we don't have the data from the table 'city_stats', but the result will be a table with one row and two columns. The first column 'city_name' will have the name of the city with the highest population, and the second column 'population' will have the population of that city.
Answer: Since we don't have the data, we can't give a specific city name. But the city name will be the first row in the 'city_name' column of the result set after executing the above query on the actual 'city_stats' table.
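The prompt in the cell above is a filled-in template: instructions, the schema description, and the question, ending with `SQLQuery: ` so the model continues with SQL. As the output shows, a chat model may keep going and invent `SQLResult` and `Answer` lines, which is why the SQL has to be parsed out of the response. Below is a sketch of assembling such a prompt with `str.format`; the template is abbreviated from the text above, and the placeholder names are chosen for this sketch.

```python
# Abbreviated, hypothetical version of the text-to-SQL prompt shown above
TEXT_TO_SQL_TEMPLATE = (
    "Given an input question, first create a syntactically correct {dialect} "
    "query to run, then look at the results of the query and return the answer.\n"
    "You are required to use the following format, each taking one line:\n\n"
    "Question: Question here\n"
    "SQLQuery: SQL Query to run\n"
    "SQLResult: Result of the SQLQuery\n"
    "Answer: Final answer here\n\n"
    "Only use tables listed below.\n"
    "{schema}\n\n"
    "Question: {query_str}\n"
    "SQLQuery: "
)

prompt = TEXT_TO_SQL_TEMPLATE.format(
    dialect="sqlite",
    schema=(
        "Table 'city_stats' has columns: city_name (VARCHAR(16)), "
        "population (INTEGER), country (VARCHAR(16)), and foreign keys: ."
    ),
    query_str="Which city has the highest population?",
)
print(prompt)
```

Ending the prompt on `SQLQuery: ` is what nudges a completion model to emit SQL next.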


In [9]:
llama_debug.get_event_pairs()

[]

We add some testing data to our SQL database.

In [11]:
from sqlalchemy import insert

sql_database = SQLDatabase(engine, include_tables=["city_stats"])

rows = [
    {"city_name": "Toronto", "population": 2930000, "country": "Canada"},
    {"city_name": "Tokyo", "population": 13960000, "country": "Japan"},
    {
        "city_name": "Chicago",
        "population": 2679000,
        "country": "United States",
    },
    {"city_name": "Seoul", "population": 9776000, "country": "South Korea"},
]
# Insert all rows inside a single transaction
with engine.begin() as connection:
    for row in rows:
        connection.execute(insert(city_stats_table).values(**row))

In [12]:
# view current table
stmt = select(
    city_stats_table.c.city_name,
    city_stats_table.c.population,
    city_stats_table.c.country,
).select_from(city_stats_table)

with engine.connect() as connection:
    results = connection.execute(stmt).fetchall()
    print(results)

[('Toronto', 2930000, 'Canada'), ('Tokyo', 13960000, 'Japan'), ('Chicago', 2679000, 'United States'), ('Seoul', 9776000, 'South Korea')]
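For comparison, the same four rows can be loaded in one transaction with the standard-library `sqlite3` module and `executemany`. This sketch is independent of the SQLAlchemy engine above and uses its own in-memory database.

```python
import sqlite3

rows = [
    ("Toronto", 2930000, "Canada"),
    ("Tokyo", 13960000, "Japan"),
    ("Chicago", 2679000, "United States"),
    ("Seoul", 9776000, "South Korea"),
]

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE city_stats ("
    "city_name VARCHAR(16) PRIMARY KEY, population INTEGER, "
    "country VARCHAR(16) NOT NULL)"
)
# The connection context manager commits on success and rolls back on error
with conn:
    conn.executemany("INSERT INTO city_stats VALUES (?, ?, ?)", rows)

print(conn.execute("SELECT COUNT(*) FROM city_stats").fetchone())  # (4,)
```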


### Query Index

We first show how we can execute a raw SQL query, which directly executes over the table.

In [13]:
from sqlalchemy import text

with engine.connect() as con:
    rows = con.execute(text("SELECT city_name from city_stats"))
    for row in rows:
        print(row)

('Chicago',)
('Seoul',)
('Tokyo',)
('Toronto',)
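Formatting user-supplied values directly into a raw SQL string invites SQL injection, so values are better passed as bound parameters. A minimal sketch using `sqlite3` on a separate in-memory copy of the table:

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE city_stats (city_name TEXT PRIMARY KEY, population INTEGER, country TEXT)"
)
conn.executemany(
    "INSERT INTO city_stats VALUES (?, ?, ?)",
    [
        ("Toronto", 2930000, "Canada"),
        ("Tokyo", 13960000, "Japan"),
        ("Chicago", 2679000, "United States"),
        ("Seoul", 9776000, "South Korea"),
    ],
)

# Bind the value with a "?" placeholder instead of formatting it into the SQL string
min_population = 5_000_000
big_cities = conn.execute(
    "SELECT city_name FROM city_stats WHERE population > ? ORDER BY population DESC",
    (min_population,),
).fetchall()
print(big_cities)  # [('Tokyo',), ('Seoul',)]
```

With SQLAlchemy the equivalent is a named bind parameter, e.g. `con.execute(text("SELECT city_name FROM city_stats WHERE population > :min_pop"), {"min_pop": 5_000_000})`.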


In [4]:
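import os
from urllib.parse import quote_plus

from sqlalchemy import create_engine

# Read the credentials from environment variables instead of hard-coding them
db_password = os.environ["MYSQL_PASSWORD"]
db_host = os.environ.get("MYSQL_HOST", "localhost")
# Percent-encode the password so characters such as "@" or ":" don't break the URL
db_password_quoted = quote_plus(db_password)

# Define the database connection string
# Format: mysql+<driver>://<user>:<password>@<host>:<port>/<database>
DATABASE_URL = f"mysql+mysqldb://root:{db_password_quoted}@{db_host}:3306/liukun"

# Create the engine with create_engine.
# NOTE: this rebinds `engine`, so later cells that use it query MySQL, not the in-memory SQLite database
engine = create_engine(DATABASE_URL, echo=False)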
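Why percent-encode the password? A raw `@` or `:` in it collides with the delimiters that separate user, password, and host in the URL. A quick check with a made-up password (the host name is a placeholder):

```python
from urllib.parse import quote_plus, unquote_plus

password = "p@ss:w/rd"  # made-up example password
encoded = quote_plus(password)
url = f"mysql+mysqldb://root:{encoded}@db.example.com:3306/liukun"

print(encoded)  # p%40ss%3Aw%2Frd
print(url)
# After encoding, the only "@" left separates the credentials from the host
assert url.count("@") == 1
# Decoding restores the original password
assert unquote_plus(encoded) == password
```

SQLAlchemy also offers `sqlalchemy.engine.URL.create(...)`, which takes the password as a separate argument and avoids manual escaping.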

In [6]:
# from sqlalchemy.ext.declarative import declarative_base
# from sqlalchemy.orm import sessionmaker
from sqlalchemy import text

# sql_database = SQLDatabase(engine, include_tables=["city_stats"])
# sql_database = SQLDatabase(engine)
sql_database = SQLDatabase(engine,
                           include_tables=["administrative_division_city", "administrative_division_community_village",
                                           "administrative_division_district",
                                           "administrative_division_town", "city_roads", "data_governance_achievements",
                                           "data_source_table", "real_estate", "real_buildings", "resident_info",
                                           "standard_address", "units", "water_bodies_rivers", "data_collection_table"])
# Session = sessionmaker(bind=engine)
# session = Session()

# session.query()

with engine.connect() as con:
    rows = con.execute(text("SELECT * FROM real_buildings"))
    # result_proxy = con.execute("select * from actual_buildings") # returns a ResultProxy
    # result_proxy = con.execute("SELECT * from m_song LIMIT 1")
    # result = result_proxy.fetchall()
    for row in rows:
        print(row)

('8001', '001', '湖北省鄂州市鄂城区A道A社区1号', '湖北省', '鄂州市', '鄂城区', 'A道', 'A社区', '小区A', 2000, 0)
('8002', '002', '湖北省鄂州市华容区B镇B村2号', '湖北省', '鄂州市', '华容区', 'B镇', 'B村', '小区B', 2005, 0)
('8003', '003', '湖北省鄂州市梁子湖区C乡C村3号', '湖北省', '鄂州市', '梁子湖区', 'C乡', 'C村', '小区C', 2010, 0)
('8004', '004', '湖北省鄂州市鄂城区D街道D社区4号', '湖北省', '鄂州市', '鄂城区', 'D街道', 'D社区', '小区D', 2008, 0)
('8005', '005', '湖北省鄂州市华容区E镇E村5号', '湖北省', '鄂州市', '华容区', 'E镇', 'E村', '小区E', 2015, 0)
('8006', '006', '湖北省鄂州市梁子湖区F乡F某村6号', '湖北省', '鄂州市', '梁子湖区', 'F乡', 'F某村', '小区F', 2003, 0)
('8007', '007', '湖北省鄂州市鄂城区G街道G社区7号', '湖北省', '鄂州市', '鄂城区', 'G街道', 'G社区', '小区G', 2012, 0)
('8008', '008', '湖北省鄂州市华容区H镇H村8号', '湖北省', '鄂州市', '华容区', 'H镇', 'H村', '小区H', 2009, 0)
('8009', '009', '湖北省鄂州市梁子湖区I乡I村9号', '湖北省', '鄂州市', '梁子湖区', 'I乡', 'I村', '小区I', 2011, 0)
('8010', '010', '湖北省鄂州市鄂城区J街道J社区10号', '湖北省', '鄂州市', '鄂城区', 'J街道', 'J社区', '小区J', 2018, 0)
('8011', '011', '湖北省鄂州市葛店经济技术开发区J街道J社区10号', '湖北省', '鄂州市', '葛店经济技术开发区', 'J街道', 'J社区', '小区J', 2019, 0)
('8012', '012', '湖北省鄂州市临空经济开发区G街道G社区

## Part 1: Text-to-SQL Query Engine
Once we have constructed our SQL database, we can use the `NLSQLTableQueryEngine` to
turn natural-language queries into SQL queries, run them, and synthesize a response.

Note that we need to specify the tables we want to use with this query engine.
If we don't, the query engine will pull in the schema context of every table, which
could overflow the context window of the LLM.
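To see why restricting `tables` matters, here is a small sketch that compares the size of the schema context with and without a table filter. The wide table is hypothetical, and the 4-characters-per-token rule is only a rough heuristic for English text (Chinese text tokenizes differently).

```python
from __future__ import annotations


def build_schema_context(schemas: dict[str, str], tables: list[str] | None = None) -> str:
    """Join schema descriptions, keeping only `tables` when it is given."""
    names = list(schemas) if tables is None else tables
    return "\n\n".join(schemas[name] for name in names)


def rough_token_count(text: str) -> int:
    # Crude rule of thumb (~4 characters per token for English text);
    # use the model's own tokenizer for a real budget check.
    return len(text) // 4


schemas = {
    "city_stats": "Table 'city_stats' has columns: city_name (VARCHAR(16)), "
    "population (INTEGER), country (VARCHAR(16)), and foreign keys: .",
    # Hypothetical wide table standing in for a large schema
    "wide_table": "Table 'wide_table' has columns: "
    + ", ".join(f"col_{i} (TEXT)" for i in range(500))
    + ", and foreign keys: .",
}

all_tables = build_schema_context(schemas)
only_city = build_schema_context(schemas, tables=["city_stats"])
print(rough_token_count(all_tables), rough_token_count(only_city))
```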

In [7]:
from llama_index.core.query_engine import NLSQLTableQueryEngine

# query_engine = NLSQLTableQueryEngine(
#     sql_database=sql_database, tables=["city_stats"], llm=llm, verbose=True, callback_manager=callback_manager
# )

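# Per-table context (Chinese, matching the database's own comments):
# - units: organizations registered in Ezhou, Hubei (enterprises, government agencies and
#   public institutions; the last two together are called party and government organs)
# - data_source_table: lists the digital public infrastructure (数公基) data tables; to get
#   the total data volume, count the rows of every table named in Source_Table_Name
context_query_kwargs = {
    "units": "units指实际存在的各种单位组织，通常指的是工商管理部门登记注册的各类企业、政府机关以及事业单位,其中政府机关和事业单位合起来称为党政机关，该表的数据全是湖北省鄂州关于党政机关和企业数量分布的数据",
    "data_source_table": "里面包含数公基数据各种数据表，如果要查询数所有数公基的数据量，则要查询所有Source_Table_Name字段值对应的表的数据的条数"
}
# Global context prefix (Chinese): the database name plus definitions of domain terms and
# metrics (data volume, data capacity, map display "上图" and display rate, data governance,
# governance rate), Ezhou's administrative divisions, and data-source units
context_str_prefix = '''
数据库名称：liukun
术语及指标定义:
数据总量：也称之为数据量，指的是结构化数据表的数据条数。
数据容量：指的是结构化数据表或非结构化数据的存储容量。
上图：指的是一标三实数据和城市数字模型在CIM平台上进行展示。
上图率：可上图数据数量占治理后数据数量的比例。举个例子“鄂州市数公基标准地址的上图率是多少”，首先这个问题涉及到的表名是“标准地址表”和数据治理成果表，标准地址上图率=数据治理成果表（Data_Governance_Achievements）中Collected_Data_Resource_Name字段值为“标准地址”对应的治理后数据量（Data_Volume_After_Governance字段对应的值）/标准地址表的中数据的总条数，其他的一些数据的上图率也以此类推
数据治理：指归集后的原始数据资源经过数据清洗转换、加工处理之后形成的标准化、高质量的数据。
治理率：治理后的数据量占治理前数据量的比例。
鄂州市的行政区划：鄂城区（行政区）、华容区（行政区）、梁子湖区（行政区）、葛店经济技术开发区（功能区）、临空经济开发区（功能区）。鄂州市：湖北省内的一个地级市。
数源单位：提供原始数据资源的政府单位。
'''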
query_engine = NLSQLTableQueryEngine(
    sql_database=sql_database, tables=["administrative_division_city", "administrative_division_community_village",
                                       "administrative_division_district",
                                       "administrative_division_town", "city_roads", "data_governance_achievements",
                                       "data_source_table", "real_estate", "real_buildings", "resident_info",
                                       "standard_address", "units", "water_bodies_rivers", "data_collection_table"],
    llm=llm, context_query_kwargs=context_query_kwargs, context_str_prefix=context_str_prefix
)
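The `context_query_kwargs` entries attach a description to individual tables, while `context_str_prefix` adds text that applies to every query. The INFO log of the next query shows per-table text laid out as `... and foreign keys: . The table description is: ...`. Here is a sketch of that assembly, assuming the prefix is simply prepended; the helper names and the table columns are made up for illustration.

```python
from __future__ import annotations


def table_context(schema_desc: str, context: str | None) -> str:
    """Attach an optional hand-written description to one schema line."""
    if not context:
        return schema_desc
    return f"{schema_desc} The table description is: {context}"


def full_context(prefix: str, schema_descs: dict[str, str], contexts: dict[str, str]) -> str:
    # Assumption: the global prefix is simply prepended to the per-table text
    lines = [table_context(desc, contexts.get(name)) for name, desc in schema_descs.items()]
    return prefix.strip() + "\n\n" + "\n\n".join(lines)


# Hypothetical columns, for illustration only
schema_descs = {
    "units": "Table 'units' has columns: name (VARCHAR(255)), type (VARCHAR(64)), and foreign keys: .",
    "city_roads": "Table 'city_roads' has columns: road_name (VARCHAR(255)), and foreign keys: .",
}
contexts = {"units": "Registered enterprises, government agencies and public institutions in Ezhou."}

prompt_context = full_context("Database name: liukun", schema_descs, contexts)
print(prompt_context)
```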
# query_str = "Which city has the highest population?"
# response = query_engine.query(query_str)

In [8]:
from llama_index.core.indices.struct_store.sql_retriever import BaseSQLParser
from llama_index.core.schema import QueryBundle

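# Inspect which table schemas (and context strings) the retriever will use.
# NOTE: `_sql_retriever` and `_get_tables` are private and may change between versions.
print(query_engine._sql_retriever._get_tables(""))


class QianfanSQLParser(BaseSQLParser):
    """SQL parser that also strips Markdown code fences around the generated SQL."""

    def parse_response_to_sql(self, response: str, query_bundle: QueryBundle) -> str:
        """Parse response to SQL."""
        # Keep only the text after the "SQLQuery:" marker, if present
        sql_query_start = response.find("SQLQuery:")
        if sql_query_start != -1:
            response = response[sql_query_start + len("SQLQuery:") :]
        # Drop anything the model generated from "SQLResult:" onwards
        sql_result_start = response.find("SQLResult:")
        if sql_result_start != -1:
            response = response[:sql_result_start]
        response = response.strip()
        # Remove a leading ``` fence together with its optional "sql" language tag
        if response.startswith("```"):
            response = response[3:]
            if response[:3].lower() == "sql":
                response = response[3:]
        # Remove a trailing ``` fence
        if response.endswith("```"):
            response = response[:-3]
        return response.strip()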

[SQLTableSchema(table_name='administrative_division_city', context_str=None), SQLTableSchema(table_name='administrative_division_community_village', context_str=None), SQLTableSchema(table_name='administrative_division_district', context_str=None), SQLTableSchema(table_name='administrative_division_town', context_str=None), SQLTableSchema(table_name='city_roads', context_str=None), SQLTableSchema(table_name='data_governance_achievements', context_str=None), SQLTableSchema(table_name='data_source_table', context_str='里面包含数公基数据各种数据表，如果要查询数所有数公基的数据量，则要查询所有Source_Table_Name字段值对应的表的数据的条数'), SQLTableSchema(table_name='real_estate', context_str=None), SQLTableSchema(table_name='real_buildings', context_str=None), SQLTableSchema(table_name='resident_info', context_str=None), SQLTableSchema(table_name='standard_address', context_str=None), SQLTableSchema(table_name='units', context_str='units指实际存在的各种单位组织，通常指的是工商管理部门登记注册的各类企业、政府机关以及事业单位,其中政府机关和事业单位合起来称为党政机关，该表的数据全是湖北省鄂州关于党政机关和企业数量分布的数据'), SQLTab

NameError: name 'QueryBundle' is not defined
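The parsing steps are easy to check in isolation. This standalone function (no LlamaIndex needed; the name `extract_sql` is hypothetical) cuts the response at `SQLQuery:` and `SQLResult:`, then removes a Markdown code fence together with its `sql` language tag. That last step matters here: the retriever errors later in this notebook show statements beginning with `sql\n`, i.e. a fence whose language tag survived stripping.

```python
def extract_sql(response: str) -> str:
    """Pull the SQL statement out of an LLM reply (standalone sketch)."""
    start = response.find("SQLQuery:")
    if start != -1:
        response = response[start + len("SQLQuery:"):]
    end = response.find("SQLResult:")
    if end != -1:
        response = response[:end]
    response = response.strip()
    # Remove a Markdown fence such as ```sql ... ```, including the language tag
    if response.startswith("```"):
        response = response[3:]
        if response[:3].lower() == "sql":
            response = response[3:]
    if response.endswith("```"):
        response = response[:-3]
    return response.strip()


fenced = "```sql\nSELECT city_name, population\nFROM city_stats\nORDER BY population DESC\nLIMIT 5;\n```"
print(extract_sql(fenced))
```

Unlike `str.strip("```")`, which strips any run of backtick characters, checking for a whole fence leaves MySQL backtick-quoted identifiers such as `` `units` `` intact.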

In [10]:
query_str = "鄂州市有多少60岁以上的老人？"  # "How many people over 60 are there in Ezhou?"
response = query_engine.query(query_str)

print(response)

INFO:llama_index.core.indices.struct_store.sql_retriever:> Table desc str: Table 'actual_buildings' has columns: unified_code (CHAR(32) COLLATE "utf8mb4_0900_ai_ci"): '统一编码', building_id (CHAR(32) COLLATE "utf8mb4_0900_ai_ci"): '建筑物ID', standard_address (VARCHAR(255) COLLATE "utf8mb4_0900_ai_ci"): '坐落标准地址', province_administrative_division_name (VARCHAR(255) COLLATE "utf8mb4_0900_ai_ci"): '省行政区划名称', city_administrative_division_name (VARCHAR(255) COLLATE "utf8mb4_0900_ai_ci"): '市行政区划名称', district_county_administrative_division_name (VARCHAR(255) COLLATE "utf8mb4_0900_ai_ci"): '区县行政区划名称', town_street_administrative_division_name (VARCHAR(255) COLLATE "utf8mb4_0900_ai_ci"): '乡镇街道行政区划名称', community_village_administrative_division_name (VARCHAR(255) COLLATE "utf8mb4_0900_ai_ci"): '社区村行政区划名称', community_name (VARCHAR(255) COLLATE "utf8mb4_0900_ai_ci"): '小区名称', construction_year (YEAR): '建筑年份', demolished (TINYINT): '是否拆除', with comment: (实有建筑表) and foreign keys: . The table description is: 

**********
Trace: query
    |_templating -> 4.6e-05 seconds
    |_llm -> 7.354545 seconds
    |_synthesize -> 13.640642 seconds
      |_templating -> 3.1e-05 seconds
      |_llm -> 13.638391 seconds
**********
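Sorry, because the SQL query I wrote earlier contained an error, I can't give you the exact number of people over 60 in Ezhou. Let me rewrite a correct SQL query to get the information you need. A correct SQL query might look like this: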

```sql
SELECT COUNT(*) 
FROM population_info 
WHERE age >= 60 AND city = '鄂州市';
```

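Note that the SQL query above assumes there is a table named `population_info` holding population information, including fields such as age and city. The city field must also exactly match the city name in the query. If the actual table structure or field names differ, adjust the query accordingly.

However, I cannot execute SQL queries directly to get the actual result, because I am an AI text-generation model, not a database query tool. You need to run this query in an environment with access to the database to get the actual number of people over 60 in Ezhou.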


In [19]:
# print(llama_debug.get_event_time_info(CBEventType.LLM))
# llama_debug.get_events(CBEventType.LLM)[1].payload
len(llama_debug.get_events(CBEventType.LLM))

8

In [24]:
llama_debug.get_events(CBEventType.LLM)[4].payload
# display(Markdown(f"<b>{response}</b>"))
# print(response)
# query_engine._sql_retriever

{<EventPayload.MESSAGES: 'messages'>: [ChatMessage(role=<MessageRole.USER: 'user'>, content='Given an input question, synthesize a response from the query results.\nQuery: 统计武昌区人口的不同年龄段的数量分布? (21-30,31-40,41-50二个年龄段)\nSQL: SELECT COUNT(*) AS \'人口数量\', age区间, COUNT(*) as \'人口数量\'\nFROM population_info\nWHERE address->\'区/县\' = \'武昌区\' AND age BETWEEN 21 AND 30 OR age BETWEEN 31 AND 40 OR age BETWEEN 41 AND 50\nGROUP BY age区间\nSQL Response: Error: Statement "SELECT COUNT(*) AS \'人口数量\', age区间, COUNT(*) as \'人口数量\'\\nFROM population_info\\nWHERE address->\'区/县\' = \'武昌区\' AND age BETWEEN 21 AND 30 OR age BETWEEN 31 AND 40 OR age BETWEEN 41 AND 50\\nGROUP BY age区间" is invalid SQL.\nResponse: ', additional_kwargs={})],
 <EventPayload.ADDITIONAL_KWARGS: 'additional_kwargs'>: {},
 <EventPayload.SERIALIZED: 'serialized'>: {'system_prompt': '',
  'pydantic_program_mode': <PydanticProgramMode.DEFAULT: 'default'>,
  'query_wrapper_prompt': {'metadata': {'prompt_type': <PromptType.CUSTOM: 'custom'

This query engine should be used whenever you can specify the tables you want
to query over beforehand, or when the total size of all the table schemas plus the
rest of the prompt fits in your context window.

## Part 2: Query-Time Retrieval of Tables for Text-to-SQL
If we don't know ahead of time which table we would like to use, and the total size of
the table schemas overflows the context window, we should store the table schemas
in an index so that the right schema can be retrieved at query time.

We can do this with the `SQLTableNodeMapping` object, which takes in a
`SQLDatabase` and produces a Node object for each `SQLTableSchema` object passed
into the `ObjectIndex` constructor.
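With `VectorStoreIndex`, the `ObjectIndex` embeds one node per table schema and retrieves the top-k nodes most similar to the question. As a dependency-free illustration of that top-k step, this sketch swaps embeddings for bag-of-words cosine similarity (a deliberate simplification, not what LlamaIndex does); the table descriptions are hypothetical.

```python
from __future__ import annotations

import math
import re
from collections import Counter


def bag_of_words(text: str) -> Counter:
    """Lower-cased word counts; underscores split words (city_stats -> city, stats)."""
    return Counter(re.findall(r"[a-z]+", text.lower()))


def cosine(a: Counter, b: Counter) -> float:
    dot = sum(a[w] * b[w] for w in a.keys() & b.keys())
    norm = math.sqrt(sum(v * v for v in a.values())) * math.sqrt(sum(v * v for v in b.values()))
    return dot / norm if norm else 0.0


def retrieve_tables(question: str, table_docs: dict[str, str], top_k: int = 1) -> list[str]:
    """Return the names of the top_k tables whose description best matches the question."""
    q = bag_of_words(question)
    ranked = sorted(
        table_docs, key=lambda name: cosine(q, bag_of_words(table_docs[name])), reverse=True
    )
    return ranked[:top_k]


table_docs = {
    "city_stats": "city_stats: city name, population and country of a city",
    "city_roads": "city_roads: road name, length and district of each road",
}
print(retrieve_tables("Which city has the highest population?", table_docs))  # ['city_stats']
```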


In [18]:
from llama_index.core.indices.struct_store.sql_query import (
    SQLTableRetrieverQueryEngine,
)


from llama_index.core.objects import (
    SQLTableNodeMapping,
    ObjectIndex,
    SQLTableSchema,
)
from llama_index.core import VectorStoreIndex

# set Logging to DEBUG for more detailed outputs
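# NOTE: `city_stats` only exists in the in-memory SQLite database created at the top of
# this notebook; if you ran the MySQL cells, re-run the SQLite setup cells first
table_node_mapping = SQLTableNodeMapping(sql_database)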
table_schema_objs = [
    (SQLTableSchema(table_name="city_stats"))
]  # add a SQLTableSchema for each table

obj_index = ObjectIndex.from_objects(
    table_schema_objs,
    table_node_mapping,
    VectorStoreIndex,
)


Now we can take our SQLTableRetrieverQueryEngine and query it for our response.

In [19]:

query_engine = SQLTableRetrieverQueryEngine(
    sql_database, obj_index.as_retriever(similarity_top_k=1)
)

In [20]:
response = query_engine.query("哪个城市人口最多?")  # "Which city has the largest population?"
display(Markdown(f"<b>{response}</b>"))



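<b>Sorry, I mistakenly included the "SQL Response" part in my output. It is not actually a SQL statement but a description of the query results and the answer. The real SQL statement is the part I gave earlier: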


```sql
SELECT city_name 
FROM city_stats 
ORDER BY population DESC 
LIMIT 1;
```
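This SQL query returns the name of the city with the largest population. In the sample data given, that name is "Shanghai". So the answer to the question is: "The city with the largest population is Shanghai." Thank you for pointing out this issue.</b>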

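Note that this answer is wrong: the sample data contains no "Shanghai" row, and the most populous city in `city_stats` is Tokyo.

You can also add extra context for each table schema you define.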

In [21]:
# manually set context text
city_stats_text = (
    "This table gives information regarding the population and country of a"
    " given city.\nThe user will query with codewords, where 'foo' corresponds"
    " to population and 'bar' corresponds to city."
)

table_node_mapping = SQLTableNodeMapping(sql_database)
table_schema_objs = [
    (SQLTableSchema(table_name="city_stats", context_str=city_stats_text))
]

## Part 3: Text-to-SQL Retriever

So far our text-to-SQL capability is packaged in a query engine and consists of both retrieval and synthesis.

You can use the SQL retriever on its own. We show you some different parameters you can try, and also show how to plug it into our `RetrieverQueryEngine` to get roughly the same results.
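Our reading of the two `return_raw` modes, sketched with `sqlite3` so the result shapes are visible: with `return_raw=True` the whole result comes back as one piece of text (one node), while with `return_raw=False` each row is returned separately with its values keyed by column name, consistent with the NOTE further down that the content sits in the metadata. This mimics the shapes only; no LlamaIndex classes are used.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE city_stats (city_name TEXT PRIMARY KEY, population INTEGER, country TEXT)"
)
conn.executemany(
    "INSERT INTO city_stats VALUES (?, ?, ?)",
    [("Toronto", 2930000, "Canada"), ("Tokyo", 13960000, "Japan"), ("Seoul", 9776000, "South Korea")],
)

cursor = conn.execute(
    "SELECT city_name, population FROM city_stats ORDER BY population DESC LIMIT 2"
)
col_keys = [desc[0] for desc in cursor.description]  # column names from the cursor
rows = cursor.fetchall()

# return_raw=True shape: the whole result as one string
raw_text = str(rows)
# return_raw=False shape: one record per row, values keyed by column name
per_row = [dict(zip(col_keys, row)) for row in rows]

print(raw_text)  # [('Tokyo', 13960000), ('Seoul', 9776000)]
print(per_row)
```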

In [22]:
from llama_index.core.retrievers import NLSQLRetriever

# default retrieval (return_raw=True)
nl_sql_retriever = NLSQLRetriever(
    sql_database, tables=["city_stats"], return_raw=True, verbose=True
)

In [23]:
results = nl_sql_retriever.retrieve(
    "Return the top 5 cities (along with their populations) with the highest population."
)



In [24]:
from llama_index.core.response.notebook_utils import display_source_node

for n in results:
    display_source_node(n)

**Node ID:** ae761db0-2d62-4c17-989c-8e1ef1166572<br>**Similarity:** None<br>**Text:** Error: Statement 'sql\nSELECT city_name, population \nFROM city_stats \nORDER BY population DESC ...<br>

In [25]:
# retrieval with return_raw=False (non-default)
nl_sql_retriever = NLSQLRetriever(
    sql_database, tables=["city_stats"], return_raw=False
)

In [26]:
results = nl_sql_retriever.retrieve(
    "Return the top 5 cities (along with their populations) with the highest population."
)



In [27]:
# NOTE: all the content is in the metadata
for n in results:
    display_source_node(n, show_source_metadata=True)

**Node ID:** e0bb4e4a-5872-4dc1-ab0c-f0944ae90896<br>**Similarity:** None<br>**Text:** Error: Statement "sql\nSELECT city_name, population\nFROM city_stats\nORDER BY population DESC\nL...<br>**Metadata:** {}<br>

### Plug into our `RetrieverQueryEngine`

We compose our SQL Retriever with our standard `RetrieverQueryEngine` to synthesize a response. The result is roughly similar to our packaged `Text-to-SQL` query engines.

In [28]:
from llama_index.core.query_engine import RetrieverQueryEngine

query_engine = RetrieverQueryEngine.from_args(nl_sql_retriever)

In [29]:
response = query_engine.query(
    "Return the top 5 cities (along with their populations) with the highest population."
)



In [30]:
print(str(response))

To return the top 5 cities with the highest population, you can use the following SQL query:


```sql
SELECT city_name, population
FROM city_stats
ORDER BY population DESC
LIMIT 5;
```
This query selects the `city_name` and `population` columns from the `city_stats` table. It then orders the result set in descending order based on the population column and limits the output to the top 5 rows.

Note: The original error message indicates that the provided SQL statement is invalid, but the query itself seems to be correct. It's possible that the error message is referring to something else in the context or there might be an issue with the formatting or execution of the query. However, the query itself is a valid SQL statement to retrieve the top 5 cities with the highest population from the `city_stats` table.
