<a href="https://colab.research.google.com/github/jerryjliu/llama_index/blob/main/docs/docs/examples/index_structs/struct_indices/SQLIndexDemo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Text-to-SQL Guide (Query Engine + Retriever)

This is a basic guide to LlamaIndex's Text-to-SQL capabilities.
1. We first show how to perform text-to-SQL over a toy dataset: this covers both "retrieval" (a SQL query over the database) and "synthesis".
2. We then show how to build an index over the table schemas (an `ObjectIndex`) to dynamically retrieve the relevant tables at query time.
3. We finally show how to define a text-to-SQL retriever on its own.

If you're opening this notebook in Colab, you will probably need to install LlamaIndex 🦙.

In [None]:
%pip install llama-index-llms-openai

In [None]:
!pip install llama-index

In [None]:
import os
import openai

In [None]:
os.environ["OPENAI_API_KEY"] = "sk-.."
openai.api_key = os.environ["OPENAI_API_KEY"]

In [2]:
from IPython.display import Markdown, display

import logging
import sys

# logging.basicConfig(level=logging.DEBUG)
# logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
# logging.getLogger().addHandler(logging.StreamHandler(stream=sys.stdout))

### Create Database Schema

We use `sqlalchemy`, a popular SQL database toolkit, to create an empty `city_stats` table.

In [3]:
from sqlalchemy import (
    create_engine,
    MetaData,
    Table,
    Column,
    String,
    Integer,
    select,
)

In [4]:
engine = create_engine("sqlite:///:memory:")
metadata_obj = MetaData()

In [5]:
# create city SQL table
table_name = "city_stats"
city_stats_table = Table(
    table_name,
    metadata_obj,
    Column("city_name", String(16), primary_key=True),
    Column("population", Integer),
    Column("country", String(16), nullable=False),
)
metadata_obj.create_all(engine)

### Define SQL Database

We first define our `SQLDatabase` abstraction (a light wrapper around SQLAlchemy). 

In [6]:
from llama_index.core import SQLDatabase
# from llama_index.llms.openai import OpenAI

In [1]:
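import os
import sys

# Machine-specific Hugging Face settings: adjust the cache path and mirror
# endpoint for your environment, or remove these lines to use the defaults.
os.environ["HF_HOME"] = "/home/jeffye/.cache/huggingface_speed"
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"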
# os.environ['HTTP_PROXY'] = "192.168.1.45:10809"
# os.environ['HTTPS_PROXY'] = "192.168.1.45:10809"
# os.environ["HF_HUB_CACHE"] = "/home/jeffye/.cache/huggingface_hub"


from llama_index.core import Settings
from llama_index.llms.huggingface import HuggingFaceLLM

from llama_index.core.prompts.base import PromptTemplate
import torch
import logging
import sys
from functools import partial

# logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
# logging.getLogger().addHandler(logging.StreamHandler(stream=sys.stdout))

# os.environ["MODELSCOPE_CACHE"] = "D:/dataset/cache/modelscope/"
# from llama_index.llms.modelscope import ModelScopeLLM
from llama_index.embeddings.huggingface import HuggingFaceEmbedding

from llama_index.core.callbacks import (
    CallbackManager,
    LlamaDebugHandler,
    CBEventType,
)

from llama_index.callbacks.aim import AimCallback

aim_callback = AimCallback(repo="./")
# callback_manager = CallbackManager([aim_callback])


llama_debug = LlamaDebugHandler(print_trace_on_end=True)
callback_manager = CallbackManager([llama_debug, aim_callback])


system_prompt = """<|SYSTEM|># StableLM Tuned (Alpha version)
- StableLM is a helpful and harmless open-source AI language model developed by StabilityAI.
- StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
- StableLM is more than just an information source, StableLM is also able to write poetry, short stories, and make jokes.
- StableLM will refuse to participate in anything that could harm a human.
"""
# system_prompt = "you are a very helpful assistant";

# This will wrap the default prompts that are internal to llama-index
# query_wrapper_prompt = PromptTemplate("<|USER|>{query_str}<|ASSISTANT|>")
# query_wrapper_prompt = PromptTemplate("{query_str}\n<|im_start|>assistant")
# <|im_start|>assistant


llm = HuggingFaceLLM(
    context_window=4096,
    max_new_tokens=256,
    generate_kwargs={"temperature": 0.2, "do_sample": True},
    # system_prompt=system_prompt,
    # query_wrapper_prompt=query_wrapper_prompt,
    tokenizer_name="Qwen/Qwen1.5-7B-Chat",
    model_name="Qwen/Qwen1.5-7B-Chat",
    device_map="auto",
    # stopping_ids=[50278, 50279, 50277, 1, 0],
    # stopping_ids=[151645, 151644][:1],
    tokenizer_kwargs={"max_length": 4096, "trust_remote_code": True},
    # load the weights in float16 to reduce memory usage on CUDA
    model_kwargs={"torch_dtype": torch.float16, "trust_remote_code": True},
    callback_manager=callback_manager,
    is_chat_model=True,
)

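# Workaround: make the tokenizer's chat template always append the generation
# prompt. This patches a private attribute (`_tokenizer`), so it may break
# with other llama-index versions.
llm._tokenizer.apply_chat_template = partial(
    llm._tokenizer.apply_chat_template, add_generation_prompt=True
)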
Settings.llm = llm



In [7]:
# Settings.embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-small-en-v1.5")
Settings.embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-m3")

In [None]:
# Sanity check: send a chat request built from ChatMessage objects
from llama_index.core.base.llms.types import MessageRole, ChatMessage

messages = [
    ChatMessage(
        role=MessageRole.SYSTEM, content="You are a helpful assistant."
    ),
    ChatMessage(role=MessageRole.USER, content="who are you?"),
]
resp = llm.chat(messages, max_length=4000)
print(resp)

In [10]:
llama_debug.get_event_pairs()

[]

We add some test data to our SQL database.

In [11]:
from sqlalchemy import insert

sql_database = SQLDatabase(engine, include_tables=["city_stats"])

rows = [
    {"city_name": "Toronto", "population": 2930000, "country": "Canada"},
    {"city_name": "Tokyo", "population": 13960000, "country": "Japan"},
    {
        "city_name": "Chicago",
        "population": 2679000,
        "country": "United States",
    },
    {"city_name": "Seoul", "population": 9776000, "country": "South Korea"},
]
for row in rows:
    stmt = insert(city_stats_table).values(**row)
    with engine.begin() as connection:
        cursor = connection.execute(stmt)

In [12]:
# view current table
stmt = select(
    city_stats_table.c.city_name,
    city_stats_table.c.population,
    city_stats_table.c.country,
).select_from(city_stats_table)

with engine.connect() as connection:
    results = connection.execute(stmt).fetchall()
    print(results)

[('Toronto', 2930000, 'Canada'), ('Tokyo', 13960000, 'Japan'), ('Chicago', 2679000, 'United States'), ('Seoul', 9776000, 'South Korea')]


### Query Index

We first show how to execute a raw SQL query directly against the table.

In [13]:
from sqlalchemy import text

with engine.connect() as con:
    rows = con.execute(text("SELECT city_name from city_stats"))
    for row in rows:
        print(row)

('Chicago',)
('Seoul',)
('Tokyo',)
('Toronto',)


## Part 1: Text-to-SQL Query Engine
Once we have constructed our SQL database, we can use the `NLSQLTableQueryEngine` to
turn natural language queries into SQL queries and run them against the database.

Note that we need to specify the tables we want to use with this query engine.
If we don't, the query engine will pull in the schema context for every table, which could
overflow the context window of the LLM.
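
To make the point about schema context concrete, here is a minimal sketch using only the standard-library `sqlite3` module. It is not how LlamaIndex builds its prompt; the helper `schema_context` and the extra table `unrelated_logs` are hypothetical and exist only to show that restricting the table list shrinks the schema text that would be placed in the prompt.

```python
import sqlite3


def schema_context(conn, tables=None):
    """Join the CREATE TABLE statements of the chosen tables (all tables if None)."""
    rows = conn.execute(
        "SELECT name, sql FROM sqlite_master WHERE type = 'table' ORDER BY name"
    ).fetchall()
    return "\n".join(sql for name, sql in rows if tables is None or name in tables)


conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE city_stats ("
    "city_name TEXT PRIMARY KEY, population INTEGER, country TEXT NOT NULL)"
)
# Hypothetical extra table, included only to show the effect of filtering.
conn.execute("CREATE TABLE unrelated_logs (id INTEGER PRIMARY KEY, message TEXT)")

full_context = schema_context(conn)
scoped_context = schema_context(conn, tables=["city_stats"])

# The scoped context contains only the table we asked for, so it is shorter.
print(len(full_context), len(scoped_context))
```

With many tables the difference grows quickly, which is why passing `tables=[...]` keeps the prompt within the context window.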

In [None]:
from llama_index.core.query_engine import NLSQLTableQueryEngine

query_engine = NLSQLTableQueryEngine(
    sql_database=sql_database,
    tables=["city_stats"],
    llm=llm,
    verbose=True,
    callback_manager=callback_manager,
)
query_str = "Which city has the highest population?"
response = query_engine.query(query_str)

In [None]:
display(Markdown(f"<b>{response}</b>"))
print(response)
# query_engine._sql_retriever

Use this query engine whenever you can specify the tables you want to query
beforehand, or when the combined size of all the table schemas plus the rest of
the prompt fits within your context window.

## Part 2: Query-Time Retrieval of Tables for Text-to-SQL
If we don't know ahead of time which table we would like to use, and the combined
table schemas would overflow the context window, we should store the table schemas
in an index so that the right schema can be retrieved at query time.

We do this with the `SQLTableNodeMapping` object, which takes a
`SQLDatabase` and produces a Node object for each `SQLTableSchema` object passed
into the `ObjectIndex` constructor.
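
The idea of retrieving a schema at query time can be illustrated with a toy sketch. This is an assumption-laden stand-in: it ranks tables by plain word overlap, whereas the notebook uses embeddings through `VectorStoreIndex`. The function `retrieve_tables`, the description strings, and the `product_orders` table are all hypothetical.

```python
import re


def tokenize(text):
    """Lowercase the text and return its set of alphabetic words."""
    return set(re.findall(r"[a-z]+", text.lower()))


def retrieve_tables(query, table_descriptions, top_k=1):
    """Rank tables by how many words their description shares with the query."""
    query_words = tokenize(query)
    ranked = sorted(
        table_descriptions,
        key=lambda name: len(query_words & tokenize(table_descriptions[name])),
        reverse=True,
    )
    return ranked[:top_k]


# Hypothetical descriptions; only city_stats exists in this notebook.
table_descriptions = {
    "city_stats": "city_name population country of each city",
    "product_orders": "order id product price quantity",
}

selected = retrieve_tables("Which city has the highest population?", table_descriptions)
print(selected)  # ['city_stats']
```

Only the schemas of the selected tables would then be placed in the text-to-SQL prompt, which mirrors the role `similarity_top_k` plays in the retriever used below.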


In [None]:
from llama_index.core.indices.struct_store.sql_query import (
    SQLTableRetrieverQueryEngine,
)


from llama_index.core.objects import (
    SQLTableNodeMapping,
    ObjectIndex,
    SQLTableSchema,
)
from llama_index.core import VectorStoreIndex

# set Logging to DEBUG for more detailed outputs
table_node_mapping = SQLTableNodeMapping(sql_database)
table_schema_objs = [
    (SQLTableSchema(table_name="city_stats"))
]  # add a SQLTableSchema for each table

obj_index = ObjectIndex.from_objects(
    table_schema_objs,
    table_node_mapping,
    VectorStoreIndex,
)


Now we can take our SQLTableRetrieverQueryEngine and query it for our response.

In [None]:

query_engine = SQLTableRetrieverQueryEngine(
    sql_database, obj_index.as_retriever(similarity_top_k=1)
)

In [None]:
response = query_engine.query("Which city has the highest population?")
display(Markdown(f"<b>{response}</b>"))

You can also add additional context information for each table schema you define.

In [None]:
# manually set context text
city_stats_text = (
    "This table gives information regarding the population and country of a"
    " given city.\nThe user will query with codewords, where 'foo' corresponds"
    " to population and 'bar' corresponds to city."
)

table_node_mapping = SQLTableNodeMapping(sql_database)
table_schema_objs = [
    (SQLTableSchema(table_name="city_stats", context_str=city_stats_text))
]

# rebuild the object index so the retriever sees the new context
obj_index = ObjectIndex.from_objects(
    table_schema_objs,
    table_node_mapping,
    VectorStoreIndex,
)

## Part 3: Text-to-SQL Retriever

So far our text-to-SQL capability is packaged in a query engine and consists of both retrieval and synthesis.

You can use the SQL retriever on its own. We show you some different parameters you can try, and also show how to plug it into our `RetrieverQueryEngine` to get roughly the same results.
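
As a rough mental model of the `return_raw` parameter used in the cells that follow, the sketch below packages the same SQL result in two ways with the standard-library `sqlite3` module. It is not LlamaIndex code. It assumes, based on the note further down that "all the content is in the metadata", that `return_raw=True` corresponds to one stringified result and `return_raw=False` to one item per row keyed by column name. The variable names are hypothetical.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE city_stats (city_name TEXT, population INTEGER, country TEXT)")
conn.executemany(
    "INSERT INTO city_stats VALUES (?, ?, ?)",
    [("Toronto", 2930000, "Canada"), ("Tokyo", 13960000, "Japan")],
)

cursor = conn.execute(
    "SELECT city_name, population FROM city_stats ORDER BY population DESC"
)
columns = [col[0] for col in cursor.description]
rows = cursor.fetchall()

# "Raw" style (assumed analogue of return_raw=True): one string for the whole result.
raw_result = str(rows)

# "Per-row" style (assumed analogue of return_raw=False): one dict per row,
# with the values stored under their column names, like node metadata.
row_results = [dict(zip(columns, row)) for row in rows]

print(raw_result)
print(row_results)
```

The raw form is convenient to hand straight to a synthesizer, while the per-row form is easier to post-process programmatically.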

In [None]:
from llama_index.core.retrievers import NLSQLRetriever

# default retrieval (return_raw=True)
nl_sql_retriever = NLSQLRetriever(
    sql_database, tables=["city_stats"], return_raw=True, verbose=True
)

In [None]:
results = nl_sql_retriever.retrieve(
    "Return the top 5 cities (along with their populations) with the highest population."
)

In [None]:
from llama_index.core.response.notebook_utils import display_source_node

for n in results:
    display_source_node(n)

In [None]:
# retrieval with return_raw=False
nl_sql_retriever = NLSQLRetriever(
    sql_database, tables=["city_stats"], return_raw=False
)

In [None]:
results = nl_sql_retriever.retrieve(
    "Return the top 5 cities (along with their populations) with the highest population."
)

In [None]:
# NOTE: all the content is in the metadata
for n in results:
    display_source_node(n, show_source_metadata=True)

### Plug into our `RetrieverQueryEngine`

We compose our SQL Retriever with our standard `RetrieverQueryEngine` to synthesize a response. The result is roughly similar to our packaged `Text-to-SQL` query engines.

In [None]:
from llama_index.core.query_engine import RetrieverQueryEngine

query_engine = RetrieverQueryEngine.from_args(nl_sql_retriever)

In [None]:
response = query_engine.query(
    "Return the top 5 cities (along with their populations) with the highest population."
)

In [None]:
print(str(response))