# Dynamic Few-Shot Prompting for SQL

Few-Shot Prompting is a technique that adds relevant examples to the prompt.
Dynamic Few-Shot Prompting uses a vector store such as Astra DB to retrieve those examples through a similarity search between the incoming prompt and known prompts associated with validated examples.
This notebook shows how to use RAGStack and LangChain for Dynamic Few-Shot Prompting to improve the accuracy of Text-to-SQL prompts.

Let's start by installing RAGStack and the required libraries:

In [None]:
!pip install --quiet ragstack-ai datasets

Configure your OpenAI API key ([link](https://platform.openai.com/docs/quickstart/account-setup)):

In [None]:
import os
from getpass import getpass

from langchain_openai import OpenAIEmbeddings

if "OPENAI_API_KEY" not in os.environ:
    os.environ["OPENAI_API_KEY"] = getpass("OpenAI API Key: ")

emb = OpenAIEmbeddings()

Configure your Astra DB token and API endpoint ([link](https://docs.datastax.com/en/astra/astra-db-vector/get-started/quickstart.html#create-a-serverless-vector-database)):





In [None]:
# Grab the Astra token and api endpoint from the environment or user input
ASTRA_DB_TOKEN = (
    os.environ["ASTRA_DB_APPLICATION_TOKEN"]
    if "ASTRA_DB_APPLICATION_TOKEN" in os.environ
    else getpass("Astra DB Token: ")
)
ASTRA_DB_ENDPOINT = (
    os.environ["ASTRA_DB_API_ENDPOINT"]
    if "ASTRA_DB_API_ENDPOINT" in os.environ
    else input("Astra DB Endpoint: ")
)
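
The "environment variable first, interactive prompt as fallback" pattern is used for every credential in this notebook. As a minimal sketch, it could be factored into a small helper. The name `read_setting` and the `DEMO_SETTING` variable are hypothetical and not part of RAGStack or LangChain; unlike the cells above, this version also prompts when the variable is set but empty.

```python
import os
from getpass import getpass


def read_setting(name: str, prompt: str, secret: bool = True) -> str:
    """Return the environment variable `name`, or ask the user for it."""
    value = os.environ.get(name)
    if value:
        return value
    # getpass hides what is typed (tokens); input echoes it (endpoints)
    return getpass(prompt) if secret else input(prompt)


# Made-up variable so the snippet runs without prompting
os.environ["DEMO_SETTING"] = "demo-value"
demo_value = read_setting("DEMO_SETTING", "Demo setting: ")
print(demo_value)
```

With such a helper, the token would be read with `secret=True` and the endpoint with `secret=False`, mirroring the `getpass` / `input` split used above.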

For this demo, we will use the [Spider](https://yale-lily.github.io/spider) dataset, which has been a standard benchmark for evaluating generated SQL for several years. The dataset consists of (`question`, `query`) pairs, where each `query` is the ideal SQL to generate for the given natural language `question`.
We'll use the LangChain SQL agent to run those queries against a SQL database that has the Spider `Pets` schema.
For convenience we create a local SQLite database, but the principle is the same for any SQL database.


In [4]:
from langchain_community.utilities import SQLDatabase

db = SQLDatabase.from_uri("sqlite:///spider.db")

PETS_SQL = """
PRAGMA foreign_keys=OFF;
BEGIN TRANSACTION;
CREATE TABLE Student (
       StuID    	INTEGER PRIMARY KEY,
       LName		VARCHAR(12),
       Fname		VARCHAR(12),
       Age		INTEGER,
       Sex		VARCHAR(1),
       Major		INTEGER,
       Advisor		INTEGER,
       city_code	VARCHAR(3)
);
INSERT INTO Student VALUES(1001,'Smith','Linda',18,'F',600,1121,'BAL');
INSERT INTO Student VALUES(1002,'Kim','Tracy',19,'F',600,7712,'HKG');
INSERT INTO Student VALUES(1003,'Jones','Shiela',21,'F',600,7792,'WAS');
INSERT INTO Student VALUES(1004,'Kumar','Dinesh',20,'M',600,8423,'CHI');
INSERT INTO Student VALUES(1005,'Gompers','Paul',26,'M',600,1121,'YYZ');
INSERT INTO Student VALUES(1006,'Schultz','Andy',18,'M',600,1148,'BAL');
INSERT INTO Student VALUES(1007,'Apap','Lisa',18,'F',600,8918,'PIT');
INSERT INTO Student VALUES(1008,'Nelson','Jandy',20,'F',600,9172,'BAL');
INSERT INTO Student VALUES(1009,'Tai','Eric',19,'M',600,2192,'YYZ');
INSERT INTO Student VALUES(1010,'Lee','Derek',17,'M',600,2192,'HOU');
INSERT INTO Student VALUES(1011,'Adams','David',22,'M',600,1148,'PHL');
INSERT INTO Student VALUES(1012,'Davis','Steven',20,'M',600,7723,'PIT');
INSERT INTO Student VALUES(1014,'Norris','Charles',18,'M',600,8741,'DAL');
INSERT INTO Student VALUES(1015,'Lee','Susan',16,'F',600,8721,'HKG');
INSERT INTO Student VALUES(1016,'Schwartz','Mark',17,'M',600,2192,'DET');
INSERT INTO Student VALUES(1017,'Wilson','Bruce',27,'M',600,1148,'LON');
INSERT INTO Student VALUES(1018,'Leighton','Michael',20,'M',600,1121,'PIT');
INSERT INTO Student VALUES(1019,'Pang','Arthur',18,'M',600,2192,'WAS');
INSERT INTO Student VALUES(1020,'Thornton','Ian',22,'M',520,7271,'NYC');
INSERT INTO Student VALUES(1021,'Andreou','George',19,'M',520,8722,'NYC');
INSERT INTO Student VALUES(1022,'Woods','Michael',17,'M',540,8722,'PHL');
INSERT INTO Student VALUES(1023,'Shieber','David',20,'M',520,8722,'NYC');
INSERT INTO Student VALUES(1024,'Prater','Stacy',18,'F',540,7271,'BAL');
INSERT INTO Student VALUES(1025,'Goldman','Mark',18,'M',520,7134,'PIT');
INSERT INTO Student VALUES(1026,'Pang','Eric',19,'M',520,7134,'HKG');
INSERT INTO Student VALUES(1027,'Brody','Paul',18,'M',520,8723,'LOS');
INSERT INTO Student VALUES(1028,'Rugh','Eric',20,'M',550,2311,'ROC');
INSERT INTO Student VALUES(1029,'Han','Jun',17,'M',100,2311,'PEK');
INSERT INTO Student VALUES(1030,'Cheng','Lisa',21,'F',550,2311,'SFO');
INSERT INTO Student VALUES(1031,'Smith','Sarah',20,'F',550,8772,'PHL');
INSERT INTO Student VALUES(1032,'Brown','Eric',20,'M',550,8772,'ATL');
INSERT INTO Student VALUES(1033,'Simms','William',18,'M',550,8772,'NAR');
INSERT INTO Student VALUES(1034,'Epp','Eric',18,'M',50,5718,'BOS');
INSERT INTO Student VALUES(1035,'Schmidt','Sarah',26,'F',50,5718,'WAS');
CREATE TABLE Has_Pet (
       StuID		INTEGER,
       PetID		INTEGER,
       FOREIGN KEY(PetID) REFERENCES Pets(PetID),
       FOREIGN KEY(StuID) REFERENCES Student(StuID)
);
INSERT INTO Has_Pet VALUES(1001,2001);
INSERT INTO Has_Pet VALUES(1002,2002);
INSERT INTO Has_Pet VALUES(1002,2003);
INSERT INTO Has_Pet VALUES(1003,2001);
INSERT INTO Has_Pet VALUES(1003,2002);
INSERT INTO Has_Pet VALUES(1003,2003);
INSERT INTO Has_Pet VALUES(1034,2001);
INSERT INTO Has_Pet VALUES(1034,2002);
INSERT INTO Has_Pet VALUES(1034,2003);
CREATE TABLE Pets (
       PetID		INTEGER PRIMARY KEY,
       PetType		VARCHAR(20),
       pet_age INTEGER,
       weight REAL
);
INSERT INTO Pets VALUES(2001,'cat',3,12.0);
INSERT INTO Pets VALUES(2002,'dog',2,13.400000000000000799);
INSERT INTO Pets VALUES(2003,'dog',1,9.3000000000000007105);
COMMIT;
"""

with db._engine.begin() as conn:  # noqa: SLF001
    conn.connection.executescript(PETS_SQL)
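
Before involving any model, it can help to know what the correct answers look like. The sketch below is independent of `spider.db`: it uses only Python's built-in `sqlite3` module, an in-memory database, and a reduced copy of the data above (the four students who own pets, with fewer columns). The query text is written for this illustration and is not taken from Spider.

```python
import sqlite3

check_conn = sqlite3.connect(":memory:")
check_conn.executescript(
    """
    CREATE TABLE Student (StuID INTEGER PRIMARY KEY, LName VARCHAR(12), Fname VARCHAR(12));
    CREATE TABLE Pets (PetID INTEGER PRIMARY KEY, PetType VARCHAR(20));
    CREATE TABLE Has_Pet (StuID INTEGER, PetID INTEGER);
    INSERT INTO Student VALUES (1001,'Smith','Linda'),(1002,'Kim','Tracy'),
                               (1003,'Jones','Shiela'),(1034,'Epp','Eric');
    INSERT INTO Pets VALUES (2001,'cat'),(2002,'dog'),(2003,'dog');
    INSERT INTO Has_Pet VALUES (1001,2001),(1002,2002),(1002,2003),
                               (1003,2001),(1003,2002),(1003,2003),
                               (1034,2001),(1034,2002),(1034,2003);
    """
)

# Students owning at least one cat AND at least one dog
BOTH_SQL = """
SELECT s.Fname FROM Student s JOIN Has_Pet h ON s.StuID = h.StuID
JOIN Pets p ON p.PetID = h.PetID WHERE p.PetType = 'cat'
INTERSECT
SELECT s.Fname FROM Student s JOIN Has_Pet h ON s.StuID = h.StuID
JOIN Pets p ON p.PetID = h.PetID WHERE p.PetType = 'dog'
"""
both = sorted(row[0] for row in check_conn.execute(BOTH_SQL))
print(both)  # ['Eric', 'Shiela']
```

Linda only has a cat and Tracy only has dogs, so they appear for "cat or dog" questions but not for "cat and dog" ones.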

We download the Spider dataset:

In [None]:
from datasets import load_dataset

spider = load_dataset("spider", split="validation")
spider_df = spider.to_pandas()
spider_schema = load_dataset("richardr1126/spider-schema", split="train")
spider_schema_df = spider_schema.to_pandas()

First, let's create a LangChain `SemanticSimilarityExampleSelector`. This component wraps a VectorStore to get examples similar to a given input. Our examples are the question/SQL query pairs of the Spider dataset.


In [6]:
from langchain_astradb import AstraDBVectorStore
from langchain_core.example_selectors import SemanticSimilarityExampleSelector

examples = [
    {"input": row["question"], "query": row["query"]}
    for _, row in spider_df[["question", "query"]].iterrows()
]
example_selector = SemanticSimilarityExampleSelector.from_examples(
    examples,
    emb,
    AstraDBVectorStore,
    k=5,
    input_keys=["input"],
    collection_name="lc_few_shots",
    token=ASTRA_DB_TOKEN,
    api_endpoint=ASTRA_DB_ENDPOINT,
)
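
To build intuition for what the selector does, here is a toy, dependency-free sketch of similarity-based example selection. It is not how LangChain or Astra DB implement it: the word-count "embedding" stands in for the OpenAI embedding model, a sorted Python list stands in for the vector store, and the three examples are made up for the illustration.

```python
import math
import re
from collections import Counter


def toy_embed(text: str) -> Counter:
    """Toy 'embedding': word counts. A real model returns dense vectors."""
    return Counter(re.findall(r"[a-z]+", text.lower()))


def cosine(a: Counter, b: Counter) -> float:
    """Cosine similarity between two sparse word-count vectors."""
    dot = sum(a[w] * b[w] for w in a)
    norm = math.sqrt(sum(v * v for v in a.values())) * math.sqrt(
        sum(v * v for v in b.values())
    )
    return dot / norm if norm else 0.0


def toy_select(question: str, candidates: list[dict], k: int = 2) -> list[dict]:
    """Return the k candidates whose 'input' is most similar to the question."""
    q = toy_embed(question)
    ranked = sorted(
        candidates, key=lambda ex: cosine(q, toy_embed(ex["input"])), reverse=True
    )
    return ranked[:k]


toy_examples = [
    {"input": "How many singers are there?", "query": "SELECT count(*) FROM singer"},
    {"input": "Find the students who have a dog as a pet", "query": "SELECT ... (dog owners)"},
    {"input": "Which students are older than 20?", "query": "SELECT ... (age filter)"},
]
picked = toy_select("Which students have a dog?", toy_examples, k=2)
for ex in picked:
    print(ex["input"])
```

Only the `input` text is compared, which mirrors the role of `input_keys=["input"]` above: the SQL in `query` is returned with the example but does not take part in the similarity search.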

We can get a few SQL query examples for a given prompt:

In [7]:
some_examples = example_selector.select_examples(
    {"input": "Who are the students who have both cat and dog pets."}
)
for example in some_examples:
    print(example)

{'input': "What are the students' first names who have both cats and dogs as pets?", 'query': "SELECT T1.Fname FROM student AS T1 JOIN has_pet AS T2 ON T1.stuid  =  T2.stuid JOIN pets AS T3 ON T3.petid  =  T2.petid WHERE T3.pettype  =  'cat' INTERSECT SELECT T1.Fname FROM student AS T1 JOIN has_pet AS T2 ON T1.stuid  =  T2.stuid JOIN pets AS T3 ON T3.petid  =  T2.petid WHERE T3.pettype  =  'dog'"}
{'input': 'Find the first name of students who have both cat and dog pets .', 'query': "select t1.fname from student as t1 join has_pet as t2 on t1.stuid  =  t2.stuid join pets as t3 on t3.petid  =  t2.petid where t3.pettype  =  'cat' intersect select t1.fname from student as t1 join has_pet as t2 on t1.stuid  =  t2.stuid join pets as t3 on t3.petid  =  t2.petid where t3.pettype  =  'dog'"}
{'input': 'What are the first names of every student who has a cat or dog as a pet?', 'query': "SELECT DISTINCT T1.Fname FROM student AS T1 JOIN has_pet AS T2 ON T1.stuid  =  T2.stuid JOIN pets AS T3 ON T3

We can now instantiate a `FewShotPromptTemplate` and pass it our example selector. We reuse the prompt prefix that LangChain uses for SQL agents and append a line indicating that a list of examples follows.

In [8]:
from langchain.agents.agent_toolkits.sql.prompt import SQL_PREFIX
from langchain_core.prompts import FewShotPromptTemplate, PromptTemplate

few_shot_prompt = FewShotPromptTemplate(
    example_selector=example_selector,
    example_prompt=PromptTemplate.from_template(
        "User input: {input}\nSQL query: {query}"
    ),
    prefix=SQL_PREFIX
    + "\n\nHere are some examples of user inputs and their corresponding SQL queries:",
    suffix="",
    input_variables=["input", "dialect", "top_k"],
)
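
Conceptually, the template joins the prefix, one formatted block per selected example, and the suffix. The plain-Python approximation below illustrates that assembly; `build_few_shot_prompt` is a hypothetical helper, and the blank-line separator is an assumption chosen for the sketch rather than a statement about LangChain's internals.

```python
def build_few_shot_prompt(prefix: str, shots: list[dict], suffix: str = "") -> str:
    """Join the prefix, one block per example, and the suffix."""
    blocks = [f"User input: {s['input']}\nSQL query: {s['query']}" for s in shots]
    parts = [prefix, *blocks, suffix]
    # Drop empty parts (such as an empty suffix) before joining
    return "\n\n".join(p for p in parts if p)


demo_prompt = build_few_shot_prompt(
    "Here are some examples of user inputs and their corresponding SQL queries:",
    [{"input": "How many pets are there?", "query": "SELECT count(*) FROM pets"}],
)
print(demo_prompt)
```

The "dynamic" part is that `shots` changes with every question, because the example selector is queried each time the prompt is formatted.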

We can check what the final prompt looks like for a given input.

In [9]:
print(
    few_shot_prompt.format(
        input="Who are the students who have both cat and dog.",
        top_k=3,
        dialect="SQLite",
    )
)

You are an agent designed to interact with a SQL database.
Given an input question, create a syntactically correct SQLite query to run, then look at the results of the query and return the answer.
Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most 3 results.
You can order the results by a relevant column to return the most interesting examples in the database.
Never query for all the columns from a specific table, only ask for the relevant columns given the question.
You have access to tools for interacting with the database.
Only use the below tools. Only use the information returned by the below tools to construct your final answer.
You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.

DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.

If the question does not seem related to the database, just return "I don't kn

Finally, we use our Dynamic Few-Shot prompt as the system message for a SQL agent (overriding the default SQL agent prompt), and the agent can then autonomously answer our questions about the SQL database.

The agent will:
* Connect to the SQLite database
* List the available tables
* Get the table schemas
* Generate a prompt with examples similar to the input question
* Send the prompt to OpenAI API and get the corresponding SQL query
* Execute the SQL query on the SQLite database and get the results
* Generate a prompt with the results to get a text answer for the input question
* Send the prompt to OpenAI API and get the final answer to the question
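
The steps above can be sketched without any external service. In this illustration the model call is replaced by a hypothetical `fake_llm` stub that always returns the same query, and the final "text answer" step is replaced by simple string formatting, so it shows only the control flow, not the agent's real behavior.

```python
import sqlite3


def fake_llm(prompt: str) -> str:
    """Stand-in for the OpenAI call; always returns the same canned query."""
    return "SELECT count(*) FROM Pets"


def toy_agent(conn: sqlite3.Connection, question: str) -> str:
    # List the available tables
    tables = [r[0] for r in conn.execute(
        "SELECT name FROM sqlite_master WHERE type = 'table' ORDER BY name")]
    # Get the table schemas (the CREATE TABLE statements)
    schemas = [r[0] for r in conn.execute(
        "SELECT sql FROM sqlite_master WHERE type = 'table' ORDER BY name")]
    # Build a prompt; the real agent would also add the selected examples here
    prompt = "\n".join([*schemas, f"Question: {question}"])
    # Ask the (stubbed) model for a SQL query
    sql = fake_llm(prompt)
    # Execute the query and collect the rows
    rows = conn.execute(sql).fetchall()
    # The real agent sends the rows back to the model; here we just format them
    return f"Tables: {', '.join(tables)}. Result of `{sql}`: {rows}"


demo_conn = sqlite3.connect(":memory:")
demo_conn.executescript(
    "CREATE TABLE Pets (PetID INTEGER PRIMARY KEY, PetType VARCHAR(20));"
    "INSERT INTO Pets VALUES (2001,'cat'),(2002,'dog'),(2003,'dog');"
)
summary = toy_agent(demo_conn, "How many pets are there?")
print(summary)
```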

In [11]:
from langchain.agents.agent_toolkits.sql.prompt import SQL_FUNCTIONS_SUFFIX, SQL_PREFIX
from langchain_community.agent_toolkits import create_sql_agent
from langchain_core.messages import AIMessage
from langchain_core.prompts import SystemMessagePromptTemplate
from langchain_core.prompts.chat import (
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
    MessagesPlaceholder,
)
from langchain_openai import ChatOpenAI

# from langchain.globals import set_debug
# set_debug(True)

db = SQLDatabase.from_uri("sqlite:///spider.db")

messages = [
    SystemMessagePromptTemplate(prompt=few_shot_prompt),
    HumanMessagePromptTemplate.from_template("{input}"),
    AIMessage(content=SQL_FUNCTIONS_SUFFIX),
    MessagesPlaceholder(variable_name="agent_scratchpad"),
]
prompt = ChatPromptTemplate.from_messages(messages)

llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
agent = create_sql_agent(llm, db=db, prompt=prompt, agent_type="openai-tools")

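# `invoke` replaces the deprecated `run`; it returns a dict whose "output" key holds the answer
result = agent.invoke({"input": "Who are the students who have a cat or a dog?"})

print(result["output"])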

The students who have a cat or dog as pets are Linda, Tracy, Shiela, and Eric.


Clean up the Astra DB collection:

In [12]:
example_selector.vectorstore.delete_collection()