<a href="https://colab.research.google.com/github/jerryjliu/llama_index/blob/main/docs/examples/index_structs/struct_indices/SQLIndexDemo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# NL2SQL to Bigquery Using llamaindex & Gemini
The main purpose of this tutorial is to give a guidance on how to develop text-to-SQL capabilities using LLM and BigQuery as a backend. Simple text-to-SQL is straightforward, but there are several challenges in real application:
- **Busines context**: how to add business context into existing database schema
- **Multiple tables**: how to handle join and schema retrieval when model needs to get the data from multiple tables
- **Dynamic prompt**: instead of loading the whole schema into the prompt, retrieve only related table required to solve the questions. We use Vertex AI embedding to search semantic similarity between question and schema.

This notebook addresses above challenges using llamaindex & inmemory vectorestore. This guide explains LlamaIndex's Text-to-SQL capabilities.
1. First show how to perform text-to-SQL using simple "retrieval" (sql query over db) and "synthesis".
2. Second, showing how to buid a TableIndex over the schema to dynamically retrieve relevant tables during query-time. This method addresses multi tables scenarion in text-to-SQL use case.
3. Third, Customizing default prompt template according to application requirements.

**Credits**:
- https://docs.llamaindex.ai/en/stable/examples/index_structs/struct_indices/SQLIndexDemo.html
- https://github.com/hamnarif/Text-to-SQL/blob/main/update_prompt_template.ipynb


This notebook is adaptation of [llamaindex's documentation](https://docs.llamaindex.ai/en/stable/examples/index_structs/struct_indices/SQLIndexDemo/). I made the following updates and modification to the original notebook:
- Connect to Google BigQuery instead of local sqlite database
- Using Google Gemini as foundation model
- Add additional section to customise default prompt template.


##Initial Setup

###Install required libraries
The following libraries are required to run the tutorial:
- llamaindex
- Google Vertex AI client library
- Langchain for embedding
- Bigquery's SQL Alchemy libray

In [56]:
!pip install protobuf==5.26.1

Collecting protobuf==5.26.1
  Downloading protobuf-5.26.1-cp37-abi3-manylinux2014_x86_64.whl.metadata (592 bytes)
Downloading protobuf-5.26.1-cp37-abi3-manylinux2014_x86_64.whl (302 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m302.8/302.8 kB[0m [31m4.1 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: protobuf
  Attempting uninstall: protobuf
    Found existing installation: protobuf 4.25.6
    Uninstalling protobuf-4.25.6:
      Successfully uninstalled protobuf-4.25.6
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
arize-phoenix 2.2.1 requires protobuf<5.0,>=3.20, but you have protobuf 5.26.1 which is incompatible.
opentelemetry-proto 1.27.0 requires protobuf<5.0,>=3.19, but you have protobuf 5.26.1 which is incompatible.[0m[31m
[0mSuccessfully installed protobuf-5.26.1


In [1]:
!pip install llama-index llama-index-llms-vertex sqlalchemy-bigquery langchain llama-index-embeddings-langchain
!pip install arize-phoenix==3.25.0 pyvis
!pip install llama-index-callbacks-arize-phoenix
!pip install --upgrade google-cloud-aiplatform
!pip install -U langchain-google-vertexai
!pip install numpy==1.26.4
!pip install pandas==2.2.2
!pip install scikit-learn==1.3.2




In [2]:
#call required library
from IPython.display import Markdown, display
import pandas as pd
import vertexai

###Authenticate to Google Cloud Credential

In [3]:
# Authenticate with Google account
# run the following lines ONLY if you use Google Colab

from google.colab import auth as google_auth
google_auth.authenticate_user()

In [None]:
#connect to gcp credential
#setup this if You use VM or local machine

# from google.oauth2 import service_account

#assign credential. replace project_id and credential
#mycredential = service_account.Credentials.from_service_account_file(os.environ["GOOGLE_APPLICATION_CREDENTIALS"])
#myproject=os.environ["PROJECT_ID"]

Setup phoenix for LLM tracing & evaluation

In [4]:
!pip install --upgrade phoenix



In [5]:
# setup Arize Phoenix for logging/observability
import phoenix as px

px.launch_app()
import llama_index.core

llama_index.core.set_global_handler("arize_phoenix")

🌍 To view the Phoenix app in your browser, visit https://advhog55xi61-496ff2e9c6d22116-6006-colab.googleusercontent.com/
📺 To view the Phoenix app in a notebook, run `px.active_session().view()`
📖 For more information on how to use Phoenix, check out https://docs.arize.com/phoenix


Setup GCP project and location

In [10]:
# Change to your project ID!

PROJECT_ID = "aipoc-454808" # @param {type:"string"}

LOCATION = "us-central1"  # @param {type:"string"}
DATASET_ID = 'aipoc_454808_data_bucket' # @param {type:"string"}

vertexai.init(project=PROJECT_ID, location=LOCATION)

## BigQuery Setup

### BigQuery: Create dataset
Create a BigQuery dataset to upload the sample data.


In [68]:
# Create BigQuery Dataset on your project
from google.cloud import bigquery
import pandas as pd

bq_client = bigquery.Client(project=PROJECT_ID)

dataset_id = "{}.{}".format(bq_client.project, DATASET_ID)
dataset = bigquery.Dataset(dataset_id)
dataset.location = LOCATION

# Create the dataset
try:
    dataset = bq_client.create_dataset(dataset, timeout=30)
    print(f'Dataset {dataset_id} create successfully.')
except Exception as e:
    print(e)

409 POST https://bigquery.googleapis.com/bigquery/v2/projects/aipoc-454808/datasets?prettyPrint=false: Already Exists: Dataset aipoc-454808:demo


###Create tables and ingesting CSV to BigQuery
There are 3 tables will be created in BigQuery dataset. We will load sample data from csv into BigQuery. If this process fails, try to recreate the dataset and load the data.

In [15]:
#get csv file from repository

customerfile= 'https://raw.githubusercontent.com/mchoirul/genai-code/main/sampledata/rabbit_customer_dummy.csv'
ordertransactionfile = 'https://raw.githubusercontent.com/mchoirul/genai-code/main/sampledata/rabbit-customer-transaction.csv'
surveyfile = 'https://raw.githubusercontent.com/mchoirul/genai-code/main/sampledata/rabbit-satisfactionsurvey.csv'

Create import_csv_to_bq function to load csv to BigQuery table

In [16]:
#create function to import csv to BigQuery
def import_csv_to_bq (filepath, table_id):

  job_config = bigquery.LoadJobConfig(
      source_format=bigquery.SourceFormat.CSV, skip_leading_rows=1, autodetect=True,
          write_disposition=bigquery.WriteDisposition.WRITE_TRUNCATE )

  df = pd.read_csv(filepath, delimiter=',', )
  load_job = bq_client.load_table_from_dataframe(dataframe=df,
                                          destination=table_id,
                                            job_config=job_config)  # Make an API request.


  load_job.result()  # Waits for the job to complete.

  table = bq_client.get_table(table_id)  # Make an API request.
  print(
      "Loaded {} rows and {} columns to {}".format(
          table.num_rows, len(table.schema), table_id
      )
)

Execute import_csv_to_bq to load customer sample data

In [17]:
#import customer csv
tablename = 'customerdata'
table_id = "{}.{}".format(dataset_id, tablename) #fully qualified table name
print (table_id)

import_csv_to_bq(customerfile, table_id)

#preview imported data
bq_client.query("SELECT * FROM "+ "`" +table_id+ "`").to_dataframe().head()

#wait until dataframe preview completed!

aipoc-454808.demo.customerdata
Loaded 58 rows and 7 columns to aipoc-454808.demo.customerdata


Unnamed: 0,custid,custname,dateofbirth,city_address,nationality,memberstatus,education
0,31,Mickey Mouse,1950-01-01,Jakarta,Indonesia,YES,Bachelor
1,8,Iron Man,1960-03-03,Pontianak,Indonesia,NO,Doctorate
2,20,Beast,1960-03-03,Jambi,Indonesia,NO,Doctorate
3,38,Dale,1960-03-03,Pontianak,Indonesia,NO,Doctorate
4,50,Max Goof,1960-03-03,Jambi,Indonesia,NO,Doctorate


Execute import_csv_to_bq to load ordertransaction sample data

In [None]:
#import order transaction to BigQuery
tablename = 'ordertransaction'
table_id = "{}.{}".format(dataset_id, tablename) #fully qualified table name
print (table_id)

import_csv_to_bq(ordertransactionfile, table_id)

#preview imported data
bq_client.query("SELECT * FROM "+ "`" +table_id+ "`").to_dataframe().head()

#wait until dataframe preview completed!

your-projectid.rabbitconsulting.ordertransaction
Loaded 279 rows and 7 columns to your-projectid.rabbitconsulting.ordertransaction


Unnamed: 0,transactionid,custid,transactsitelocation,transactiondate,servicecategory,transactionamount,servingconsultant
0,TX3008,27,Surabaya,2022-01-08,Eye Care,800000,Ema
1,TX3028,33,Bandung,2022-01-28,Eye Care,800000,Ema
2,TX3040,21,Bandung,2022-02-10,Covid Test,2900000,Ema
3,TX1004,12,Bandung,2023-03-11,Consultation,500000,Ema
4,TX1024,12,Bandung,2023-03-31,Consultation,500000,Ema


Execute import_csv_to_bq to load satisfactionsurvey sample data

In [None]:
#import survey data to BigQuery
tablename = 'satisfactionsurvey'
table_id = "{}.{}".format(dataset_id, tablename) #fully qualified table name
print (table_id)

import_csv_to_bq(surveyfile, table_id)

#preview imported data
bq_client.query("SELECT * FROM "+ "`" +table_id+ "`").to_dataframe().head()

#wait until dataframe preview completed!

your-projectid.rabbitconsulting.satisfactionsurvey
Loaded 279 rows and 2 columns to your-projectid.rabbitconsulting.satisfactionsurvey


Unnamed: 0,transactionid,surveyscore
0,TX3070,61
1,TX3091,61
2,TX3093,61
3,TX5147,61
4,TX3027,62


### Setting up SQL Alchemy for BigQuery
llamaindex uses SQL Alchemy as default database connection. We need to setup SQL Alchemy connection and tested it out against BigQuery dataset created in erlier step.

We first define our `SQLDatabase` abstraction (a light wrapper around SQLAlchemy).

In [7]:
from sqlalchemy import (
     create_engine,
     MetaData,
     Table,
     Column,
     String,
     Integer,
     select,
 )

In [37]:
!pip install numpy==1.26.4

Collecting numpy==1.26.4
  Downloading numpy-1.26.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (61 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/61.0 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m61.0/61.0 kB[0m [31m2.3 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading numpy-1.26.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (18.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m18.3/18.3 MB[0m [31m76.5 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: numpy
  Attempting uninstall: numpy
    Found existing installation: numpy 2.0.2
    Uninstalling numpy-2.0.2:
      Successfully uninstalled numpy-2.0.2
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
arize-phoenix 2.2.1 requires protobuf<5.0,>=3.20

In [71]:
!pip uninstall nltk
!pip install -U nltk

Found existing installation: nltk 3.9.1
Uninstalling nltk-3.9.1:
  Would remove:
    /usr/local/bin/nltk
    /usr/local/lib/python3.11/dist-packages/nltk-3.9.1.dist-info/*
    /usr/local/lib/python3.11/dist-packages/nltk/*
Proceed (Y/n)? Y
  Successfully uninstalled nltk-3.9.1
Collecting nltk
  Downloading nltk-3.9.1-py3-none-any.whl.metadata (2.9 kB)
Downloading nltk-3.9.1-py3-none-any.whl (1.5 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.5/1.5 MB[0m [31m16.8 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: nltk
Successfully installed nltk-3.9.1


In [11]:
#initiate bigquery connection
from llama_index.core import SQLDatabase

table_uri = f"bigquery://{PROJECT_ID}/{DATASET_ID}"
engine = create_engine(f"bigquery://{PROJECT_ID}/{DATASET_ID}")

#initiate database connection
sql_database = SQLDatabase(engine)

In [13]:
#Testing table retrieval using SQL alchemy
from sqlalchemy import text

with engine.connect() as con:
    rows = con.execute(text("SELECT * from t_rpp_click LIMIT 20"))
    for row in rows:
        print(row)

('143.217.1.150', 'Mozilla', 'NULL', 'bb46e433b2e371a651703b5fa8301bf1', 'NULL', 'ef6m5dei6i3rohhrmzqt', 10002, 30000, 232271, 10000021, 2, 'iphone', 'NULL', 'NULL', 'stg-wrasta501z.stg.jp.local', 1, 'rpp', 'ichiba', 180002, 1523980800000, datetime.date(2018, 4, 18), 'request_id_600', 'dk-k7pueid1-2q-017bb061-2552-4a35-8a86-cabdlj17', 'R', 301, 301, datetime.datetime(2018, 4, 17, 16, 0, tzinfo=datetime.timezone.utc), 10, 301, 0, 'NULL', 'NULL', '018771572d7220c9', 0, 0, 'NULL', 0, 'NULL', '9D44A3C8-330B-4479-851C-199AD5017DC020230126065148175', '00000000-0000-0000-0000-000000000000', 600001, 'keyword601', 5, 'NULL', 'NULL', 'NULL', datetime.date(2000, 1, 1))
('143.217.1.150', 'Mozilla', 'NULL', 'bb46e433b2e371a651703b5fa8301bf1', 'NULL', 'ef6m5dei6i3rohhrmzqt', 10002, 30000, 232271, 10000021, 2, 'iphone', 'NULL', 'NULL', 'stg-wrasta501z.stg.jp.local', 1, 'rpp', 'ichiba', 180002, 1523980800000, datetime.date(2018, 4, 18), 'request_id_600', 'dk-k7pueid1-2q-017bb061-2552-4a35-8a86-cabdlj1

##Setting up Vertex AI LLM and Embedding
We use Gemini Pro 1.5 as main LLM model for SQL generation and response synthesis. The other challenge is to find correct table that contextually related with user question. We solve this using vector database and embedding.


The main purpose of embedding is to generate vector representation of database schema and user questions. With this way, We can perform semantic similarity comparison between user questions and database schema.

Configure embedding uses Vertex AI embedding model: textembedding-gecko-multilingual@001. We utilize Langchain's embedding wrapper for better compatibility.

In [36]:
#initiate vertex AI llm using llamaindex
from llama_index.llms.vertex import Vertex

#use gemini 1.5 pro model for LLM
llm = Vertex(model="gemini-1.5-pro-002",
             max_tokens=32768,
             temperature=0.1,  additional_kwargs={"max_output_tokens": 2048})

#initiate vertex ai embedding
#borrow langchain's wrapper
from langchain_google_vertexai import VertexAIEmbeddings
vertexembedding= VertexAIEmbeddings(model_name='textembedding-gecko-multilingual@001')

#apply settings to llamaindex config
from llama_index.core import Settings
from llama_index.embeddings.langchain import LangchainEmbedding

# Wrap the Langchain embedding model with LangchainEmbedding
embed_model = LangchainEmbedding(vertexembedding)

#apply settings to default llamaindex config
Settings.embed_model = embed_model # Use the wrapped embedding model
Settings.llm = llm

  llm = Vertex(model="gemini-1.5-pro-002",


In [37]:
#test llm connection
llm.complete("Hello first president of Singapore is...").text

'Yusof bin Ishak\n'

## Part 1: Text-to-SQL Query Engine
Once we have constructed our BigQuery dataset, we can use the NLSQLTableQueryEngine to
construct natural language queries that are synthesized into SQL queries.
1. user question input as natural language
2. LLM translate the question into SQL language based on database schema provided
3. llamaindex wrap the SQL statement and execute it against BigQuery using SQLAlchemy library
4. The output consist of of raw results, SQL Query, and synthesis interpretation of the data


Note that we need to specify the tables we want to use with this query engine.
If we don't the query engine will pull all the schema context, which could
overflow the context window of the LLM.

In [40]:
from llama_index.core.query_engine import NLSQLTableQueryEngine

#specify 3 tables in the dataset
query_engine = NLSQLTableQueryEngine(
  sql_database=sql_database, tables=["t_rpp_click"],
  llm=llm, embed_model=embed_model,
  synthesize_response=False,
  verbose=True
 )


In [48]:
#query data from multiple tables
#query_str2 = "show amount of order for each customer, by mentioning customer name row by row?"
query_str = "how many click done on datetime 2000-04-17 ?"
response = query_engine.query(query_str)
display(Markdown(f"<b>{response}</b>"))

> Table Info: Table 't_rpp_click' has columns: ip (VARCHAR), user_agent (VARCHAR), referer (VARCHAR), dispcid (VARCHAR), uid (VARCHAR), rp (VARCHAR), location_id (INTEGER), reporting_id (INTEGER), shop_id (INTEGER), item_id (INTEGER), event_type (INTEGER), search_words (VARCHAR), search_genres (VARCHAR), search_tags (VARCHAR), server_id (VARCHAR), conversion_id (INTEGER), conversion_service_id (VARCHAR), media_id (VARCHAR), rpp_campaign_id (INTEGER), click_timestamp (INTEGER), click_url_issued_date (DATE), request_id (VARCHAR), click_id (VARCHAR), selector (VARCHAR), click_price_feed (INTEGER), click_price (INTEGER), click_datetime (TIMESTAMP), import_sts (INTEGER), charge_amount (INTEGER), deduct_amount (INTEGER), reason_id (VARCHAR), note (VARCHAR), logic (VARCHAR), redirect_unavailable_flag (INTEGER), timesales_flag (INTEGER), sales_end_time (VARCHAR), cancel_amount (INTEGER), custom_id (VARCHAR), ias_ai (VARCHAR), a_uid (VARCHAR), rpp_keyword_id (INTEGER), keyword_id (VARCHAR), use

<b>[(34,)]</b>

In [49]:
#get sql query only
response.metadata['sql_query']

"SELECT count(*) FROM t_rpp_click WHERE DATE(click_datetime) = '2000-04-17'"

In [50]:
#get sql query. pretty impressive
response.metadata['result']

[(34,)]

Lets try to ask a more complex question:
show average sales, order frequency, and average score for each city when order period is between Jan - march 2023, then order by highest sales

In [51]:
query_str2 = "show average sales, order frequency, and average score for each city when order period is between Jan - march 2023, then order by highest sales?"
response2 = query_engine.query(query_str2)
display(Markdown(f"<b>{response2}</b>"))

> Table Info: Table 't_rpp_click' has columns: ip (VARCHAR), user_agent (VARCHAR), referer (VARCHAR), dispcid (VARCHAR), uid (VARCHAR), rp (VARCHAR), location_id (INTEGER), reporting_id (INTEGER), shop_id (INTEGER), item_id (INTEGER), event_type (INTEGER), search_words (VARCHAR), search_genres (VARCHAR), search_tags (VARCHAR), server_id (VARCHAR), conversion_id (INTEGER), conversion_service_id (VARCHAR), media_id (VARCHAR), rpp_campaign_id (INTEGER), click_timestamp (INTEGER), click_url_issued_date (DATE), request_id (VARCHAR), click_id (VARCHAR), selector (VARCHAR), click_price_feed (INTEGER), click_price (INTEGER), click_datetime (TIMESTAMP), import_sts (INTEGER), charge_amount (INTEGER), deduct_amount (INTEGER), reason_id (VARCHAR), note (VARCHAR), logic (VARCHAR), redirect_unavailable_flag (INTEGER), timesales_flag (INTEGER), sales_end_time (VARCHAR), cancel_amount (INTEGER), custom_id (VARCHAR), ias_ai (VARCHAR), a_uid (VARCHAR), rpp_keyword_id (INTEGER), keyword_id (VARCHAR), use

<b>[]</b>

In [52]:
#show result only
response2.metadata['result']

[]

Ask a more contextual question, such as Which city has the highest average survey score?

In [53]:
#ask query that require to join multiple tables

query_str3 = "Which city has the highest average survey score, and mention the score?"
response3 = query_engine.query(query_str3)
display(Markdown(f"<b>{response3}</b>"))

> Table Info: Table 't_rpp_click' has columns: ip (VARCHAR), user_agent (VARCHAR), referer (VARCHAR), dispcid (VARCHAR), uid (VARCHAR), rp (VARCHAR), location_id (INTEGER), reporting_id (INTEGER), shop_id (INTEGER), item_id (INTEGER), event_type (INTEGER), search_words (VARCHAR), search_genres (VARCHAR), search_tags (VARCHAR), server_id (VARCHAR), conversion_id (INTEGER), conversion_service_id (VARCHAR), media_id (VARCHAR), rpp_campaign_id (INTEGER), click_timestamp (INTEGER), click_url_issued_date (DATE), request_id (VARCHAR), click_id (VARCHAR), selector (VARCHAR), click_price_feed (INTEGER), click_price (INTEGER), click_datetime (TIMESTAMP), import_sts (INTEGER), charge_amount (INTEGER), deduct_amount (INTEGER), reason_id (VARCHAR), note (VARCHAR), logic (VARCHAR), redirect_unavailable_flag (INTEGER), timesales_flag (INTEGER), sales_end_time (VARCHAR), cancel_amount (INTEGER), custom_id (VARCHAR), ias_ai (VARCHAR), a_uid (VARCHAR), rpp_keyword_id (INTEGER), keyword_id (VARCHAR), use

<b>Error: (google.cloud.bigquery.dbapi.exceptions.DatabaseError) 404 POST https://bigquery.googleapis.com/bigquery/v2/projects/aipoc-454808/queries?prettyPrint=false: Not found: Table aipoc-454808:aipoc_454808_data_bucket.t_survey was not found in location us-central1
[SQL: SELECT city, AVG(survey_score) AS average_score FROM t_survey GROUP BY city ORDER BY average_score DESC LIMIT 1]
(Background on this error at: https://sqlalche.me/e/20/4xp6)</b>

In [None]:
#more complex question
query_str5 = "show me top 5 average order amount per customer? mention customer name and their average order line by line"
response5 = query_engine.query(query_str5)
display(Markdown(f"<b>{response5}</b>"))

> Table desc str: Table 'customerdata' has columns: custid (INTEGER), custname (VARCHAR): 'customer name', dateofbirth (VARCHAR), city_address (VARCHAR): 'city address of customer', nationality (VARCHAR), memberstatus (VARCHAR), education (VARCHAR), .

Table 'ordertransaction' has columns: transactionid (VARCHAR), custid (INTEGER): 'customer id', transactsitelocation (VARCHAR): 'city location of transaction', transactiondate (VARCHAR), servicecategory (VARCHAR), transactionamount (INTEGER): 'order amount each transaction', servingconsultant (VARCHAR), .

Table 'satisfactionsurvey' has columns: transactionid (VARCHAR), surveyscore (INTEGER), .
> Predicted SQL query: SELECT customerdata.custname, AVG(ordertransaction.transactionamount) AS average_transaction_amount FROM customerdata INNER JOIN ordertransaction ON customerdata.custid = ordertransaction.custid GROUP BY customerdata.custname ORDER BY average_transaction_amount DESC LIMIT 5


<b>Robin's average order amount: $3,000,000
Cyborg's average order amount: $3,000,000
Batman's average order amount: $3,000,000
Starfire's average order amount: $3,000,000
Wonder Woman's average order amount: $2,500,000
</b>

If We inspect the log using Phoenix, We realize that llamaindex always retrieve schema from 3 tables regardless the questions. This might not be optimal approach when there are many tables with bigger schema.
We can use similarity search to get more contextual table required, as described in the next section.

## Part 2: Query-Time Retrieval of Tables for Text-to-SQL
This method uses SQLTableNodeMapping to add mapping and busines context explanation for each tables. This should improve query accuracy and better LLM's synthesize result.

In many cases, adding description directly inside BigQuery schema is not the option because of permission issues. In this case, You can add contextual description inside SQLTableSchema object within the application itself.

This method also stores table schema in to vector database with their embedding. This will improve efficiency of schema retrieval, which only pull required table for SQL generation instead of the whole database schema.


In [None]:
#use vector search to store table schema embedding
# need to put sufficient context explanation.
# We can also put in BigQuery description
from re import VERBOSE

#create vector embedding for table schema
from llama_index.core.indices.struct_store.sql_query import (
    SQLTableRetrieverQueryEngine
)
from llama_index.core.objects import (
    SQLTableNodeMapping,
    ObjectIndex,
    SQLTableSchema,
)
from llama_index.core import VectorStoreIndex

# add contextual description in the schema object
table_node_mapping = SQLTableNodeMapping(sql_database)
table_schema_objs = [
    (SQLTableSchema(table_name="customerdata",
                    context_str="this is customer table. column custid is PK. one to many join: customerdata.custid = ordertransaction.custid") ),
    (SQLTableSchema(table_name="ordertransaction",
                    context_str="contains transaction data for every customer. column transactionid is PK, custid is FK. one to one join: ordertransaction.transactionid = satisfactionsurvey.transactionid")),
    (SQLTableSchema(table_name="satisfactionsurvey",
                    context_str="the table contains survey score for every transaction, need to join with ordertransaction to get customer & city related data.  column transactionid is PK and FK. one to one join: ordertransaction.transactionid = satisfactionsurvey.transactionid"))
]  # add a SQLTableSchema for each table

obj_index = ObjectIndex.from_objects(
    table_schema_objs,
    table_node_mapping,
    VectorStoreIndex
)

#tips: synthesize_response=False to get faster LLM result
query_engineX = SQLTableRetrieverQueryEngine(
     sql_database,
     obj_index.as_retriever(similarity_top_k=2),
     embed_model=embed_model,
     llm=llm,
     synthesize_response=True,
 )

In [None]:
#try to view the default prompt in llamaindex
# define prompt viewing function
def display_prompt_dict(prompts_dict):
    for k, p in prompts_dict.items():
        text_md = f"**Prompt Key**: {k}" f"**Text:** "
        display(Markdown(text_md))
        print(p.get_template())
        display(Markdown(""))

In [None]:
#view the default prompt to see how it works
myprompt=(query_engineX.get_prompts())
display_prompt_dict(myprompt)

**Prompt Key**: response_synthesis_prompt**Text:** 

Given an input question, synthesize a response from the query results.
Query: {query_str}
SQL: {sql_query}
SQL Response: {context_str}
Response: 




**Prompt Key**: sql_retriever:text_to_sql_prompt**Text:** 

Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer. You can order the results by a relevant column to return the most interesting examples in the database.

Never query for all the columns from a specific table, only ask for a few relevant columns given the question.

Pay attention to use only the column names that you can see in the schema description. Be careful to not query for columns that do not exist. Pay attention to which column is in which table. Also, qualify column names with the table name when needed. You are required to use the following format, each taking one line:

Question: Question here
SQLQuery: SQL Query to run
SQLResult: Result of the SQLQuery
Answer: Final answer here

Only use tables listed below.
{schema}

Question: {query_str}
SQLQuery: 




In [None]:
#test NL to SQL
query_strX1 = "show average sales, order frequency, and average score for each city when order period is between Jan - march 2023, then order by highest sales?  "
responseX1 = query_engineX.query(query_strX1)
display(Markdown(f"<b>{responseX1}</b>"))

<b>Between January and March 2023, Semarang had the highest average sales at 2,214,286 Rupiah with an average satisfaction score of 71.3 and 7 orders.  Bandung followed with average sales of 1,937,500 Rupiah, an average score of 82, and 8 orders.  Surabaya had the third highest average sales at 1,777,778 Rupiah, an average score of 79.3, and the highest order frequency with 9 orders.  Makassar's average sales were 1,714,286 Rupiah with an average score of 81.6 and 7 orders. Medan had average sales of 1,625,000 Rupiah, an average score of 79.4, and 8 orders.  Finally, Jakarta had the lowest average sales at 1,462,500 Rupiah with an average score of 71.5 and 8 orders.
</b>

In [None]:
#get sql query.
responseX1.metadata['sql_query']

"SELECT\n    ot.transactsitelocation,\n    AVG(ot.transactionamount) AS avg_sales,\n    COUNT(ot.transactionid) AS order_frequency,\n    AVG(ss.surveyscore) AS avg_score\n  FROM\n    ordertransaction AS ot\n    INNER JOIN satisfactionsurvey AS ss ON ot.transactionid = ss.transactionid\n  WHERE substr(ot.transactiondate, 1, 7) BETWEEN '2023-01' AND '2023-03'\n  GROUP BY ot.transactsitelocation\nORDER BY\n  avg_sales DESC"

In [None]:

responseX1.metadata['col_keys']

['transactsitelocation', 'avg_sales', 'order_frequency', 'avg_score']

In [None]:
#get sql query.
responseX1.metadata['result']

[('Semarang', 2214285.714285714, 7, 71.28571428571429),
 ('Bandung', 1937500.0, 8, 82.0),
 ('Surabaya', 1777777.7777777778, 9, 79.33333333333334),
 ('Makassar', 1714285.7142857143, 7, 81.57142857142857),
 ('Medan', 1625000.0, 8, 79.37500000000001),
 ('Jakarta', 1462500.0, 8, 71.5)]

Lets create a function to wrap the result into panda dataframe. This will make easier data consumption to various app.

In [None]:
#method to convert colkey & result to dataframe

def convert_to_dataframe(col_keys, data):
  """Converts data to a pandas DataFrame.

  Args:
    col_keys: A list of column names.
    data: A list of tuples, where each tuple represents a row of data.

  Returns:
    A pandas DataFrame.
  """
  df = pd.DataFrame(data, columns=col_keys)
  return df

Lets return the resultset to panda dataframe

In [None]:
col_keys=responseX1.metadata['col_keys']
data = responseX1.metadata['result']

df = convert_to_dataframe(col_keys, data)
df

Unnamed: 0,transactsitelocation,avg_sales,order_frequency,avg_score
0,Semarang,2214286.0,7,71.285714
1,Bandung,1937500.0,8,82.0
2,Surabaya,1777778.0,9,79.333333
3,Makassar,1714286.0,7,81.571429
4,Medan,1625000.0,8,79.375
5,Jakarta,1462500.0,8,71.5


In [None]:
#query data from multiple tables
query_str21 = "show total amount of order for each customer, by mentioning customer name row by row?"
response21 = query_engineX.query(query_str21)
display(Markdown(f"<b>{response21}</b>"))

<b>Mickey Mouse has a total order amount of $10,000,000.
Iron Man has a total order amount of $1,500,000.
Beast has a total order amount of $500,000.
Dale has a total order amount of $12,500,000.
Max Goof has a total order amount of $17,000,000.
Batman has a total order amount of $6,000,000.
Beasty Boy has a total order amount of $2,500,000.
Jubilee has a total order amount of $7,500,000.
Donald Duck has a total order amount of $7,600,000.
Gyro Gearloose has a total order amount of $1,200,000.
Chip and Dale's Rescue Rangers has a total order amount of $16,000,000.
Wonder Woman has a total order amount of $5,000,000.
Raven has a total order amount of $3,500,000.
Rogue has a total order amount of $6,500,000.
Goofy has a total order amount of $5,300,000.
Ludwig Von Drake has a total order amount of $14,500,000.
Gadget Hackwrench has a total order amount of $17,000,000.
Captain America has a total order amount of $1,500,000.
Iceman has a total order amount of $7,800,000.
Clarabelle Cow has a total order amount of $8,700,000.
Petey has a total order amount of $17,000,000.
Spider-Man has a total order amount of $1,000,000.
Robin has a total order amount of $3,000,000.
Chip has a total order amount of $10,200,000.
Goofy's son has a total order amount of $19,000,000.
Cyclops has a total order amount of $1,000,000.
Gambit has a total order amount of $4,300,000.
Launchpad McQuack has a total order amount of $16,000,000.
Rover Scout has a total order amount of $17,500,000.
The Flash has a total order amount of $2,000,000.
Nightwing has a total order amount of $1,000,000.
Pyro has a total order amount of $8,400,000.
Minnie Mouse has a total order amount of $7,100,000.
José Carioca has a total order amount of $14,500,000.
Thor has a total order amount of $2,000,000.
Angel has a total order amount of $5,600,000.
Horace Horsecollar has a total order amount of $9,700,000.
Peg Pete has a total order amount of $16,500,000.
Hulk has a total order amount of $3,000,000.
Colossus has a total order amount of $9,500,000.
P.J. has a total order amount of $16,500,000.
Green Lantern has a total order amount of $500,000.
Starfire has a total order amount of $3,000,000.
Quicksilver has a total order amount of $7,900,000.
Pluto has a total order amount of $5,900,000.
Panchito Pistoles has a total order amount of $16,000,000.
Monterey Jack has a total order amount of $16,000,000.
Wolverine has a total order amount of $1,000,000.
Storm has a total order amount of $8,900,000.
Scrooge McDuck has a total order amount of $18,500,000.
Goofy Junior has a total order amount of $14,500,000.
Aquaman has a total order amount of $2,500,000.
Cyborg has a total order amount of $6,000,000.
Scarlet Witch has a total order amount of $7,700,000.
Daisy Duck has a total order amount of $8,500,000.
Donald Duck's nephews has a total order amount of $18,000,000.
Brer Rabbit has a total order amount of $16,500,000.
</b>

In [None]:
col_keys=response21.metadata['col_keys']
data = response21.metadata['result']

df = convert_to_dataframe(col_keys, data)
df

Unnamed: 0,custname,total_amount
0,Mickey Mouse,10000000
1,Iron Man,1500000
2,Beast,500000
3,Dale,12500000
4,Max Goof,17000000
5,Batman,6000000
6,Beasty Boy,2500000
7,Jubilee,7500000
8,Donald Duck,7600000
9,Gyro Gearloose,1200000


In [None]:
#query data from multiple tables
query_str2 = "which city has the highest order amount?"
response2 = query_engineX.query(query_str2)
print(response2)

Surabaya has the highest order amount.



## Part 3: Customizing Prompt Template
We may need to modify the default prompt according to specific requirement of the application. Below are some conditions when customizing prompt is needed:
- Adding specific convention to identify fully qualified name of object in the database schema
- Adding query example. It's more efficient to give instruction to LLM using example rather than long description
- Create specific desired output

In [None]:
#customizing llamaindex prompt template
#add query example to improve accuracy

from llama_index.core.prompts.base import PromptTemplate
from llama_index.core.prompts.prompt_type import PromptType

# Modified Prompt to better understand and respond to user queries
MODIFIED_TEXT_TO_SQL_TMPL = (
    "You are a data analyst expert on SQL query and interpreting insights from query result. "
    "Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer. "
    "You can order the results by a relevant column to return the most interesting examples in the database. "
    "Never query for all the columns from a specific table, only ask for a few relevant columns given the question. "
    "Use only the column names that you can see in the schema description. "
    "Be careful to not query for columns that do not exist. Pay attention to which column is in which table. "
    "Also, ensure proper qualification of column names with the table name when needed. "
    "NEVER execute any DML statement. Return error when user is trying any DML statement. "
    "You are required to use the following format, each taking one line: "

    "Examples: "
    "question: show average sales for each city when transaction happened in Jan-march 2023 "
    "answer: SELECT `city_address`, AVG(`transactionamount`) AS average_sales FROM rabbitconsulting.`ordertransaction` AS o JOIN rabbitconsulting.`customerdata` AS c ON c.`custid` = o.`custid` WHERE `transactiondate` BETWEEN '2023-01-01' AND '2023-03-31' GROUP BY `city_address` "

    "question: show average satisfaction score for each consultant "
    "answer: SELECT `servingconsultant`, AVG(`surveyscore`) AS average_satisfaction_score FROM rabbitconsulting.`ordertransaction` AS o JOIN rabbitconsulting.`satisfactionsurvey` AS s ON o.`transactionid` = s.`transactionid` GROUP BY `servingconsultant` "

    "question: show total sales, transaction frequency, and average score for each city "
    "answer: SELECT `city_address`, SUM(`transactionamount`) AS total_sales, COUNT(o.`transactionid`) AS transaction_frequency, AVG(`surveyscore`) AS average_score FROM rabbitconsulting.`customerdata` AS c JOIN rabbitconsulting.`ordertransaction` AS o ON c.`custid` = o.`custid` JOIN rabbitconsulting.`satisfactionsurvey` AS s ON o.`transactionid` = s.`transactionid` GROUP BY `city_address` "


    "Question: Question here\n"
    "SQLQuery: SQL Query to run\n"
    "SQLResult: Result of the SQLQuery\n"
    "Answer: Final answer here\n\n"

    "Only use tables listed below.\n"
    "{schema}\n\n"

    "Question: {query_str}\n"
    "SQLQuery: "
)


MODIFIED_TEXT_TO_SQL_PROMPT = PromptTemplate(
    MODIFIED_TEXT_TO_SQL_TMPL,
    prompt_type=PromptType.TEXT_TO_SQL,
)

In [None]:
prompts_dict = query_engineX.get_prompts()
query_engineX.update_prompts(
    {"sql_retriever:text_to_sql_prompt": MODIFIED_TEXT_TO_SQL_PROMPT}
)

In [None]:
#check updated prompt
mypromptx=(query_engineX.get_prompts())
display_prompt_dict(mypromptx)

**Prompt Key**: response_synthesis_prompt**Text:** 

Given an input question, synthesize a response from the query results.
Query: {query_str}
SQL: {sql_query}
SQL Response: {context_str}
Response: 




**Prompt Key**: sql_retriever:text_to_sql_prompt**Text:** 

You are a data analyst expert on SQL query and interpreting insights from query result. Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer. You can order the results by a relevant column to return the most interesting examples in the database. Never query for all the columns from a specific table, only ask for a few relevant columns given the question. Use only the column names that you can see in the schema description. Be careful to not query for columns that do not exist. Pay attention to which column is in which table. Also, ensure proper qualification of column names with the table name when needed. NEVER execute any DML statement. Return error when user is trying any DML statement. You are required to use the following format, each taking one line: Examples: question: show average sales for each city when transaction happened in Jan-march 2023 answer: SELECT `city_address`, AVG(`transac



In [None]:
#test new prompt
#query data from multiple tables
#try to ask question in Bahasa Indonesia

query_strX2 = "tampilkan total penjualan, frekuensi order, dan rata rata skor survey for untuk tiap konsultan. urutkan dari penjualan tertinggi?"
responseX2 = query_engineX.query(query_strX2)
display(Markdown(f"<b>{responseX2}</b>"))

<b>Berikut adalah total penjualan, frekuensi order, dan rata-rata skor survey untuk tiap konsultan, diurutkan dari penjualan tertinggi:

* **Ahmad:** Total penjualan Rp 134.900.000, dengan 71 order dan rata-rata skor survey 80.55.
* **Boby:** Total penjualan Rp 128.700.000, dengan 70 order dan rata-rata skor survey 72.67.
* **Rany:** Total penjualan Rp 123.500.000, dengan 69 order dan rata-rata skor survey 76.70.
* **Ema:** Total penjualan Rp 120.700.000, dengan 69 order dan rata-rata skor survey 82.16.
</b>

In [None]:
responseX2.metadata['sql_query']

'SELECT o.servingconsultant, SUM(o.transactionamount) AS total_sales, COUNT(o.transactionid) AS transaction_frequency, AVG(s.surveyscore) AS average_survey_score FROM rabbitconsulting.ordertransaction AS o JOIN rabbitconsulting.satisfactionsurvey AS s ON o.transactionid = s.transactionid GROUP BY o.servingconsultant ORDER BY total_sales DESC'

In [None]:
col_keys=responseX2.metadata['col_keys']
data = responseX2.metadata['result']

df = convert_to_dataframe(col_keys, data)
df

Unnamed: 0,servingconsultant,total_sales,transaction_frequency,average_survey_score
0,Ahmad,134900000,71,80.549296
1,Boby,128700000,70,72.671429
2,Rany,123500000,69,76.695652
3,Ema,120700000,69,82.15942


In [None]:
#try malicious instruction

query_testdml = "drop table customer"
responsedml = query_engineX.query(query_testdml)
display(Markdown(f"<b>{responsedml}</b>"))

<b>Error: DML operations are not permitted.
</b>