# Dynamic Few-Shot Prompting for SQL

Generating SQL to answer natural language questions is _hard_. There has been a large amount of research in this area, even before LLMs were made popular by ChatGPT. Although every individual problem may require specific techniques to achieve the best performance, there are still some general principles that can apply across many domains.

This notebook demonstrates one of those generalizable techniques, Dynamic Few-Shot Prompting with Astra DB, which can improve the accuracy of queries generated against your structured data.

## Setup

#### Requirements

In [1]:
!pip install -q \
  "openai>=1.0,<2.0" \
  "astrapy==2.0.0-rc1" \
  "datasets==3.*" \
  "redshift_connector==2.*" \
  "tenacity==9.*"

#### Connect to Services

What you'll need:
- OpenAI API key ([link](https://platform.openai.com/docs/quickstart/account-setup))
- Astra DB Token & URL ([link](https://docs.datastax.com/en/astra/astra-db-vector/get-started/quickstart.html#create-a-serverless-vector-database))
- Amazon Redshift credentials, optional since a local SQLite mock is provided below ([link](https://docs.aws.amazon.com/redshift/latest/gsg/new-user-serverless.html#serverless-console-resource-creation))

In [77]:
# Initialize the OpenAI Client
import os
from typing import List

from getpass import getpass
import openai
from tenacity import retry, stop_after_attempt, wait_exponential

if "OPENAI_API_KEY" not in os.environ:
    os.environ["OPENAI_API_KEY"] = getpass("OpenAI API Key: ")

client = openai.OpenAI()


# Retry with exponential backoff in case the OpenAI API rate-limits us or returns a transient error
# (capped at 6 attempts so a persistent failure, e.g. a bad API key, does not retry forever)
@retry(wait=wait_exponential(multiplier=1.2, min=4, max=30), stop=stop_after_attempt(6))
def chat_gpt_completion(prompt: str) -> str:
    return client.chat.completions.create(
        messages=[{
            "role": "user",
            "content": prompt,
        }],
        model="gpt-4o-2024-08-06",
    ).choices[0].message.content


@retry(wait=wait_exponential(multiplier=1.2, min=4, max=30), stop=stop_after_attempt(6))
def compute_embedding(text: str) -> List[float]:
    return client.embeddings.create(
        input=text,
        model="text-embedding-3-small",
        timeout=10,
    ).data[0].embedding
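
For intuition about the retry decorators above, here is a small sketch of the wait schedule that `wait_exponential(multiplier=1.2, min=4, max=30)` should produce. It reimplements what we understand tenacity's formula to be (`multiplier * 2 ** (attempt - 1)`, clamped to `[min, max]`); `backoff_schedule` is a hypothetical helper, not part of tenacity.

```python
from typing import List


def backoff_schedule(attempts: int, multiplier: float = 1.2, min_wait: float = 4, max_wait: float = 30) -> List[float]:
    """Hypothetical helper: seconds waited after each failed attempt, assuming tenacity's wait_exponential formula."""
    waits = []
    for attempt_number in range(1, attempts + 1):
        raw = multiplier * 2 ** (attempt_number - 1)  # exponential growth per attempt
        waits.append(max(min_wait, min(raw, max_wait)))  # clamp into [min_wait, max_wait]
    return waits


print(backoff_schedule(6))  # → [4, 4, 4.8, 9.6, 19.2, 30]
```

The `min=4` floor dominates the first two retries, and the `max=30` ceiling kicks in from the sixth.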


In [6]:
# Initialize the Astra DB vector client
import os
from getpass import getpass

from astrapy import DataAPIClient

# Grab the Astra token and API Endpoint from the environment, or ask the user for them
# (os.environ.get returns None for unset variables instead of raising a KeyError)
ASTRA_DB_API_ENDPOINT = os.environ.get("ASTRA_DB_API_ENDPOINT")
if not ASTRA_DB_API_ENDPOINT:
    ASTRA_DB_API_ENDPOINT = input("Astra DB API Endpoint: ")
ASTRA_DB_APPLICATION_TOKEN = os.environ.get("ASTRA_DB_APPLICATION_TOKEN")
if not ASTRA_DB_APPLICATION_TOKEN:
    ASTRA_DB_APPLICATION_TOKEN = getpass("Astra DB Token: ")

data_api_client = DataAPIClient()
astra_db = data_api_client.get_database(
    ASTRA_DB_API_ENDPOINT,
    token=ASTRA_DB_APPLICATION_TOKEN,
)

In [None]:
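# (Optional) Initialize your Redshift connection. Skip this cell to use the SQLite mock below instead.
import os
from getpass import getpass

import redshift_connector

# NOTE: To connect to Redshift using other forms of auth, follow the steps here:
# https://github.com/aws/amazon-redshift-python-driver/blob/master/tutorials/001%20-%20Connecting%20to%20Amazon%20Redshift.ipynb
# `os.getenv(...) or prompt(...)` only prompts when the variable is unset
# (passing the prompt as the getenv default would ask for input every time).
RS_HOST_URL = os.getenv("REDSHIFT_HOST_URL") or input("Redshift Host URL: ")
RS_USER = os.getenv("REDSHIFT_USER") or input("Redshift User: ")
RS_PASSWORD = os.getenv("REDSHIFT_PASSWORD") or getpass("Redshift Password: ")
RS_DB_NAME = os.getenv("REDSHIFT_DB_NAME") or input("Redshift DB Name: ")


def rs_execute(*statements: str, return_as_df: bool = False, commit: bool = True):
    """Util to execute Redshift statements and return the result of the final query.

    Returns the raw rows by default, since eval_generated_queries below compares results row by row;
    pass return_as_df=True to get a pandas DataFrame instead. Returns None if there is no result set.
    """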
    with redshift_connector.connect(
        host=RS_HOST_URL,
        database=RS_DB_NAME,
        user=RS_USER,
        password=RS_PASSWORD,
    ) as conn:
        with conn.cursor() as cursor:
            for statement in statements:
                cursor.execute(statement)

            try:
                if return_as_df:
                    res = cursor.fetch_dataframe()
                else:
                    res = cursor.fetchall()
            except redshift_connector.ProgrammingError:
                # No result set (final statement was not a SELECT)
                res = None

            if commit:
                conn.commit()

            return res


In [60]:
# (Optional) Mock Redshift locally with SQLite; this redefines rs_execute from the cell above
import sqlite3
from typing import List

import pandas as pd  # used by the evaluation helpers below

DB_FILENAME = "steo1.db"


def rs_execute(*statements: str) -> List[tuple]:
    """Util to execute SQL statements on the local SQLite DB and return the rows of the final one.

    For statements without a result set (CREATE, INSERT, ...), sqlite3 returns an empty list.
    """
    # The connection context manager commits on success (and rolls back on error)
    with sqlite3.connect(DB_FILENAME) as conn:
        cursor = conn.cursor()
        for statement in statements:
            cursor.execute(statement)
        return cursor.fetchall()


#### Load Data

For this demo, we will use the [Spider](https://yale-lily.github.io/spider) dataset, which has been a standard benchmark for evaluating generated SQL for several years. The dataset consists of (`question`, `query`) pairs, where `query` is the ideal SQL to generate for the natural language `question`.

NOTE: To apply this to a specific use case, the best way to collect data is to store generated SQL from an application in a live environment (dev, staging, or prod all work, as long as the queries are realistic). You can then either give users a quick feedback interface to thumbs-up/thumbs-down the generated queries, or have internal human evaluators grade the queries offline. As more positively graded generations accumulate, you will have more examples with which to improve your live SQL-generating application.
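
A minimal sketch of that feedback loop, assuming a simple in-memory store: every generated query is logged, graded later, and only the thumbs-up ones are reused as few-shot examples. `GeneratedQuery` and `FeedbackLog` are hypothetical names; a real deployment would persist the records in a database.

```python
from dataclasses import dataclass, field
from typing import List, Optional, Tuple


@dataclass
class GeneratedQuery:
    """One generated query plus the (optional) grade it received."""
    question: str
    sql: str
    db_id: str
    thumbs_up: Optional[bool] = None  # None means "not graded yet"


@dataclass
class FeedbackLog:
    """Hypothetical in-memory store for generated SQL and its user/evaluator feedback."""
    records: List[GeneratedQuery] = field(default_factory=list)

    def log(self, question: str, sql: str, db_id: str) -> int:
        self.records.append(GeneratedQuery(question, sql, db_id))
        return len(self.records) - 1  # record ID, used later to attach a grade

    def grade(self, record_id: int, thumbs_up: bool) -> None:
        self.records[record_id].thumbs_up = thumbs_up

    def few_shot_examples(self) -> List[Tuple[str, str, str]]:
        """(db_id, question, sql) triples graded positively; ungraded and negative ones are skipped."""
        return [(r.db_id, r.question, r.sql) for r in self.records if r.thumbs_up]


feedback = FeedbackLog()
record_id = feedback.log("How many singers do we have?", "SELECT count(*) FROM singer", "concert_singer")
feedback.grade(record_id, thumbs_up=True)
print(feedback.few_shot_examples())
```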

In [8]:
from datasets import load_dataset

spider = load_dataset("spider", split="validation")
spider_df = spider.to_pandas()
spider_schema = load_dataset("richardr1126/spider-schema", split="train")
spider_schema_df = spider_schema.to_pandas()

In [9]:
spider_df.head(5)

,db_id,query,question,query_toks,query_toks_no_value,question_toks
0,concert_singer,SELECT count(*) FROM singer,How many singers do we have?,"[SELECT, count, (, *, ), FROM, singer]","[select, count, (, *, ), from, singer]","[How, many, singers, do, we, have, ?]"
1,concert_singer,SELECT count(*) FROM singer,What is the total number of singers?,"[SELECT, count, (, *, ), FROM, singer]","[select, count, (, *, ), from, singer]","[What, is, the, total, number, of, singers, ?]"
2,concert_singer,"SELECT name , country , age FROM singer ORDE...","Show name, country, age for all singers ordere...","[SELECT, name, ,, country, ,, age, FROM, singe...","[select, name, ,, country, ,, age, from, singe...","[Show, name, ,, country, ,, age, for, all, sin..."
3,concert_singer,"SELECT name , country , age FROM singer ORDE...","What are the names, countries, and ages for ev...","[SELECT, name, ,, country, ,, age, FROM, singe...","[select, name, ,, country, ,, age, from, singe...","[What, are, the, names, ,, countries, ,, and, ..."
4,concert_singer,"SELECT avg(age) , min(age) , max(age) FROM s...","What is the average, minimum, and maximum age ...","[SELECT, avg, (, age, ), ,, min, (, age, ), ,,...","[select, avg, (, age, ), ,, min, (, age, ), ,,...","[What, is, the, average, ,, minimum, ,, and, m..."


In [10]:
# Hold out some questions to test our pipeline with
test_ndxs = [21, 28, 60, 63, 75]  # Hand-picked: a mix of questions the model answers correctly and incorrectly
ndxs_to_drop = test_ndxs.copy()
test_df = spider_df.loc[test_ndxs]

for ndx in test_ndxs:
    row = spider_df.loc[ndx]
    print(f"Question ({row['db_id']}):\n    {row['question']}")
    print(f"Gold Query SQL:\n    {row['query']}")
    print("=" * 80)

    # Remove any rows that have the exact same query as the answer
    # (this prevents some data leakage so our dynamic few-shot approach can't already see the correct query)
    same_query_rows = spider_df[spider_df['query'] == row['query']]
    for same_query_ndx in same_query_rows.index:
        ndxs_to_drop.append(same_query_ndx)

spider_df = spider_df.drop(ndxs_to_drop)

Question (concert_singer):
    How many concerts occurred in 2014 or 2015?
Gold Query SQL:
    SELECT count(*) FROM concert WHERE YEAR  =  2014 OR YEAR  =  2015
Question (concert_singer):
    Show the stadium names without any concert.
Gold Query SQL:
    SELECT name FROM stadium WHERE stadium_id NOT IN (SELECT stadium_id FROM concert)
Question (pets_1):
    What are the students' first names who have both cats and dogs as pets?
Gold Query SQL:
    SELECT T1.Fname FROM student AS T1 JOIN has_pet AS T2 ON T1.stuid  =  T2.stuid JOIN pets AS T3 ON T3.petid  =  T2.petid WHERE T3.pettype  =  'cat' INTERSECT SELECT T1.Fname FROM student AS T1 JOIN has_pet AS T2 ON T1.stuid  =  T2.stuid JOIN pets AS T3 ON T3.petid  =  T2.petid WHERE T3.pettype  =  'dog'
Question (pets_1):
    Find the id of students who do not have a cat pet.
Gold Query SQL:
    SELECT stuid FROM student EXCEPT SELECT T1.stuid FROM student AS T1 JOIN has_pet AS T2 ON T1.stuid  =  T2.stuid JOIN pets AS T3 ON T3.petid  =  T2.pe
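
The leakage filter in the loop above can also be written as a single vectorized pandas expression. This is a sketch of an equivalent alternative (the helper name `drop_leaked_rows` is our own), which should leave the same rows as the loop as long as the `query` column has no missing values, since the held-out rows match their own gold queries and are dropped too.

```python
import pandas as pd


def drop_leaked_rows(df: pd.DataFrame, held_out: pd.DataFrame) -> pd.DataFrame:
    """Hypothetical helper: drop every row whose gold query matches a held-out query (including the held-out rows)."""
    return df[~df["query"].isin(held_out["query"])]


toy = pd.DataFrame({"query": ["A", "B", "A", "C"]})
print(drop_leaked_rows(toy, toy.loc[[0]]))
```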

#### SQL DB Setup

Next, set up a demo DB based on the Spider schemas. Alternatively, you can connect this to your preexisting DB, provided you have some example (question, query) pairs to use as few-shot exemplars.

This notebook targets Amazon Redshift for the structured data (with the SQLite mock above as a local stand-in), but the same approach should work with any SQL database. We will create tables matching the Spider schema and insert some fake data for each of the test questions we chose above.

NOTE: To generate the SQL statements below, I simply used ChatGPT (GPT-4) with the prompt template:

> Given the below schema, primary keys, and foreign keys, please give me valid SQL statements for creating tables that match that schema:
>
> Schema: `{spider_schema_df.loc[db_id]['Schema (values (type))']}`
>
> Primary Keys: `{spider_schema_df.loc[db_id]['Primary Keys']}`
>
> Foreign Keys: `{spider_schema_df.loc[db_id]['Foreign Keys']}`

> Thanks! Now please give me SQL statements that populate the tables with valid data, at a minimum of 5 rows per table

In [11]:
# Set up concert_singer DB
CREATE_TABLES_SQL = """
-- Creating the stadium table
CREATE TABLE stadium (
    Stadium_ID INT PRIMARY KEY,
    Location TEXT,
    Name TEXT,
    Capacity INT,
    Highest INT,
    Lowest INT,
    Average INT
);

-- Creating the singer table
CREATE TABLE singer (
    Singer_ID INT PRIMARY KEY,
    Name TEXT,
    Country TEXT,
    Song_Name TEXT,
    Song_release_year TEXT,
    Age INT,
    Is_male BOOLEAN
);

-- Creating the concert table
CREATE TABLE concert (
    concert_ID INT PRIMARY KEY,
    concert_Name TEXT,
    Theme TEXT,
    Stadium_ID INT,
    Year TEXT,
    FOREIGN KEY (Stadium_ID) REFERENCES stadium(Stadium_ID)
);

-- Creating the singer_in_concert table
CREATE TABLE singer_in_concert (
    concert_ID INT,
    Singer_ID INT,
    PRIMARY KEY (concert_ID, Singer_ID),
    FOREIGN KEY (concert_ID) REFERENCES concert(concert_ID),
    FOREIGN KEY (Singer_ID) REFERENCES singer(Singer_ID)
);
"""

POPULATE_DATA_SQL = """
-- Populating the stadium table
INSERT INTO stadium (Stadium_ID, Location, Name, Capacity, Highest, Lowest, Average) VALUES
(1, 'New York, USA', 'Liberty Stadium', 50000, 1000, 500, 750),
(2, 'London, UK', 'Royal Arena', 60000, 1500, 600, 900),
(3, 'Tokyo, Japan', 'Sunshine Dome', 55000, 1200, 550, 800),
(4, 'Sydney, Australia', 'Ocean Field', 40000, 900, 400, 650),
(5, 'Berlin, Germany', 'Eagle Grounds', 45000, 1100, 450, 700);

-- Populating the singer table
INSERT INTO singer (Singer_ID, Name, Country, Song_Name, Song_release_year, Age, Is_male) VALUES
(1, 'John Doe', 'USA', 'Freedom Song', '2018', 28, TRUE),
(2, 'Emma Stone', 'UK', 'Rolling Hills', '2019', 25, FALSE),
(3, 'Haruki Tanaka', 'Japan', 'Tokyo Lights', '2020', 30, TRUE),
(4, 'Alice Johnson', 'Australia', 'Ocean Waves', '2021', 27, FALSE),
(5, 'Max Müller', 'Germany', 'Berlin Nights', '2017', 32, TRUE);

-- Populating the concert table
INSERT INTO concert (concert_ID, concert_Name, Theme, Stadium_ID, Year) VALUES
(1, 'Freedom Fest', 'Pop', 1, '2021'),
(2, 'Rock Mania', 'Rock', 2, '2022'),
(3, 'Electronic Waves', 'Electronic', 3, '2020'),
(4, 'Jazz Evenings', 'Jazz', 3, '2019'),
(5, 'Classical Mornings', 'Classical', 5, '2023');

-- Populating the singer_in_concert table
INSERT INTO singer_in_concert (concert_ID, Singer_ID) VALUES
(1, 1),
(1, 2),
(2, 3),
(3, 4),
(4, 5),
(5, 1),
(2, 2),
(3, 3),
(4, 4),
(5, 5);
"""

statements = [
    statement.strip() for statement in (CREATE_TABLES_SQL + POPULATE_DATA_SQL).split(";")
    if len(statement.strip()) > 0
]
rs_execute(*statements)

[]
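
Splitting on `;` works for these scripts because none of the inserted values contain a semicolon. A slightly more robust sketch, assuming the SQLite mock backend, leans on Python's built-in `sqlite3.complete_statement`, which reports whether a string ends in a complete statement according to SQLite's tokenizer (it does not check that the SQL is otherwise valid). `split_sql_statements` is a hypothetical helper.

```python
import sqlite3
from typing import List


def split_sql_statements(script: str) -> List[str]:
    """Hypothetical helper: split a SQL script without breaking on ';' inside quotes or comments."""
    statements: List[str] = []
    buffer = ""
    for char in script:
        buffer += char
        # A ';' only ends a statement if SQLite's tokenizer agrees the buffer is complete
        if char == ";" and sqlite3.complete_statement(buffer):
            statement = buffer.strip().rstrip(";").strip()
            if statement:  # skip empty statements such as a stray ';;'
                statements.append(statement)
            buffer = ""
    # Keep any trailing statement that lacks a final ';'
    if buffer.strip():
        statements.append(buffer.strip())
    return statements


print(split_sql_statements("INSERT INTO t VALUES ('a;b'); SELECT 1;"))
```

A naive `split(";")` would cut the first statement inside the `'a;b'` literal.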

In [12]:
# Set up pets_1 DB
CREATE_TABLES_SQL = """
-- Creating the Student table
CREATE TABLE Student (
    StuID INT PRIMARY KEY,
    LName VARCHAR(255),
    Fname VARCHAR(255),
    Age INT,
    Sex VARCHAR(10),
    Major INT,
    Advisor INT,
    city_code VARCHAR(50)
);

-- Creating the Pets table
CREATE TABLE Pets (
    PetID INT PRIMARY KEY,
    PetType VARCHAR(255),
    pet_age INT,
    weight REAL
);

-- Creating the Has_Pet table
CREATE TABLE Has_Pet (
    StuID INT,
    PetID INT,
    FOREIGN KEY (StuID) REFERENCES Student(StuID),
    FOREIGN KEY (PetID) REFERENCES Pets(PetID)
);
"""

POPULATE_DATA_SQL = """
-- Populating the Student table
INSERT INTO Student (StuID, LName, Fname, Age, Sex, Major, Advisor, city_code) VALUES
(101, 'Smith', 'John', 20, 'M', 501, 301, 'NYC'),
(102, 'Johnson', 'Emma', 22, 'F', 502, 302, 'LAX'),
(103, 'Williams', 'Michael', 21, 'M', 503, 303, 'CHI'),
(104, 'Brown', 'Sarah', 23, 'F', 504, 304, 'HOU'),
(105, 'Jones', 'David', 19, 'M', 505, 305, 'PHI');

-- Populating the Pets table
INSERT INTO Pets (PetID, PetType, pet_age, weight) VALUES
(201, 'dog', 3, 20.5),
(202, 'cat', 5, 10.2),
(203, 'dog', 2, 8.1),
(204, 'parrot', 4, 0.5),
(205, 'hamster', 1, 0.7);

-- Populating the Has_Pet table
INSERT INTO Has_Pet (StuID, PetID) VALUES
(101, 201),
(101, 202),
(105, 203),
(103, 204),
(104, 205),
(105, 201);
"""

statements = [
    statement.strip() for statement in (CREATE_TABLES_SQL + POPULATE_DATA_SQL).split(";")
    if len(statement.strip()) > 0
]
rs_execute(*statements)

[]

Finally, we need an evaluation function for a list of generated queries. It should report whether each query runs successfully against our fake data, and whether it returns the same results as the corresponding gold SQL query.

In [110]:
def eval_generated_queries(queries: List[str], verbose: bool = True) -> pd.DataFrame:
    """Evaluate the given queries against the test set, and return a report of the performance on each row.

    A query is correct if it runs and returns the same rows as the gold query. Column order within
    a row is ignored; row order only matters when the gold query contains an ORDER BY.
    """
    report = []

    for ndx, pred_sql in enumerate(queries):
        row = test_df.iloc[ndx]
        gold_sql = row["query"]
        gold_results = rs_execute(gold_sql)

        try:
            pred_results = rs_execute(pred_sql)
            err = None

            if len(gold_results) == len(pred_results):
                # Sort the values inside each row so column order does not matter
                # (for simplicity, we ignore collisions between values with the same string form)
                in_sorted_pred = [sorted(tpl, key=str) for tpl in pred_results]
                in_sorted_gold = [sorted(tpl, key=str) for tpl in gold_results]
                if "ORDER BY" in gold_sql.upper():
                    # Row order is part of the answer: compare the row sequences as-is
                    correct = in_sorted_pred == in_sorted_gold
                else:
                    # Row order is irrelevant: sort the rows before comparing
                    # (key=str avoids TypeErrors when rows mix value types)
                    correct = sorted(in_sorted_pred, key=str) == sorted(in_sorted_gold, key=str)
            else:
                correct = False

            if verbose:
                print(f"\n\nGOLD_SQL: {gold_sql}")
                print(f"GOLD_RES: {gold_results}")
                print(f"PRED_RES: {pred_results}")
                print(f"==> CORRECT: {correct}")

        except Exception as e:
            # The generated query failed to run: record the error and count it as incorrect
            if verbose:
                print(f"\n\nPRED_SQL FAILED: {pred_sql}\nERROR: {e}")
            err = str(e)
            correct = False

        report.append({
            "DB ID": row["db_id"],
            "Question": row["question"],
            "Correct": correct,
            "Error": err,
        })

    return pd.DataFrame(report)
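
The result comparison inside `eval_generated_queries` can be isolated into a small pure function, which makes it easy to test on its own. This sketch (the name `results_match` is ours) swaps the row sorting for a `collections.Counter` multiset, so duplicate rows still count while row order is ignored; it keeps the original's simplification of ordering column values by their string form.

```python
from collections import Counter
from typing import Any, Sequence


def results_match(gold_rows: Sequence[Sequence[Any]], pred_rows: Sequence[Sequence[Any]], ordered: bool) -> bool:
    """Hypothetical helper: compare two query results, ignoring column order within each row.

    If `ordered` is False, rows are compared as a multiset (duplicate rows still count).
    """
    def normalize(row: Sequence[Any]) -> tuple:
        # Sort values by their string form so column order does not matter
        return tuple(sorted(row, key=str))

    gold = [normalize(r) for r in gold_rows]
    pred = [normalize(r) for r in pred_rows]
    if ordered:
        return gold == pred
    return Counter(gold) == Counter(pred)


print(results_match([("John", 20)], [(20, "John")], ordered=False))
```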


## Generating SQL

Now, let's go ahead and generate some SQL queries in a few different ways. First, we'll use ChatGPT with a prompt template from the [SQL-PaLM](https://arxiv.org/abs/2306.00739) paper as-is to see how it does.

#### Zero-Shot Text to SQL

First, we define some utility functions. Most of the logic below translates the Spider schema format into the prompt template from the SQL-PaLM paper, which reports strong results with it.

In [61]:
import re
from typing import List, Tuple


def _format_schema(db_id: str, old_schema: str) -> Tuple[str, str]:
    """Converts the existing schema format of spider_schema_df to the (schema, columns) strings used by SQL-PaLM"""
    schema_str = f"| {db_id} "
    cols_str = ""

    for table_name in re.findall(r"([^,: ]*) :", old_schema):
        schema_str += f"| {table_name} : "

        start_ndx = old_schema.find(table_name + " :")
        end_ndx = old_schema.find(":", start_ndx + len(table_name) + 4)
        if end_ndx == -1:
            end_ndx = len(old_schema)
        curr_substr = old_schema[start_ndx:end_ndx]
        for col_name, col_type in re.findall(r" ([^:,]*) \(([^,]*)\)", curr_substr):
            schema_str += f"{col_name} , "
            cols_str += f"{table_name} : {col_name} ({col_type}) | "

        schema_str = schema_str[:-2]
    cols_str = cols_str[:-2]

    return schema_str + ";", cols_str + ";"


def _format_prompt(db_id: str, question: str, answer: str | None = None) -> str:
    """Returns one prompt section (schema, question, and the answer if given) for the given DB"""
    schema_row = spider_schema_df[spider_schema_df['db_id'] == db_id].iloc[0]
    schema_str, cols_str = _format_schema(db_id, schema_row['Schema (values (type))'])

    if answer is not None:
        # Is example
        prefix_str = "Here is an example: "
    else:
        prefix_str = "Given the following schema information, generate valid SQL to answer the provided query. Enclose the query in markdown code-block syntax. "

    prompt_str = f"""{prefix_str}Convert text to SQL:

[Schema : (values)]: {schema_str}

[Column names (type)]: {cols_str}

[Primary Keys]: {schema_row['Primary Keys']}

[Foreign Keys]: {schema_row['Foreign Keys']}

[Q]: {question}

[SQL]: """
    if answer is not None:
        prompt_str += answer + "\n\n"

    return prompt_str


def zero_shot_prompt(db_id: str, question: str) -> str:
    """Creates a Zero-Shot prompt for the given question & db_id"""
    return _format_prompt(db_id, question)


def few_shot_prompt(db_id: str, question: str, indices: List[int]) -> str:
    """Creates a few shot prompt using the given indices to construct the example"""
    prompt = ""
    for ndx in indices:
        row = spider_df.loc[ndx]
        prompt += _format_prompt(row['db_id'], row['question'], row['query'])

    return prompt + _format_prompt(db_id, question)


def clean_response(resp: str) -> str:
    """Extracts the SQL from the first sql-tagged markdown code block in an LLM response"""
    match = re.search(r"```sql\s*(.*?)```", resp, flags=re.IGNORECASE | re.DOTALL)
    if match is None:
        raise ValueError(f"Invalid answer from LLM (no SQL code block found):\n{resp}")
    return match.group(1).strip()


def generate_sql(question: str, db_id: str, prompt_fn, debug_prompt: bool = False) -> str:
    """Generates SQL for a given Spider question & DB ID.

    `prompt_fn(db_id=..., question=...)` builds the prompt, e.g. zero_shot_prompt or a few-shot variant.
    """
    prompt = prompt_fn(db_id=db_id, question=question)

    if debug_prompt:
        print(f"Prompting ChatGPT with:\n{prompt}\n\n")

    response = chat_gpt_completion(prompt)
    return clean_response(response)


Next, let's try the test questions we set aside above.

In [62]:
# Showing the prompt here as well to get a sense of what the model is seeing for context
sql1 = generate_sql(
    question="How many concerts occurred in 2014 or 2015?",
    db_id="concert_singer",
    prompt_fn=zero_shot_prompt,
    debug_prompt=True,
)
print(sql1)

Prompting ChatGPT with:
Given the following schema information, generate valid SQL to answer the provided query. Enclose the query in markdown code-block syntax.Convert text to SQL:

[Schema : (values)]: | concert_singer | stadium : Stadium_ID , Location , Name , Capacity , Highest , Lowest , Average | singer : Singer_ID , Name , Country , Song_Name , Song_release_year , Age , Is_male | concert : concert_ID , concert_Name , Theme , Stadium_ID , Year | singer_in_concert : concert_ID , Singer_ID ;

[Column names (type)]: stadium : Stadium_ID (number) | stadium : Location (text) | stadium : Name (text) | stadium : Capacity (number) | stadium : Highest (number) | stadium : Lowest (number) | stadium : Average (number) | singer : Singer_ID (number) | singer : Name (text) | singer : Country (text) | singer : Song_Name (text) | singer : Song_release_year (text) | singer : Age (number) | singer : Is_male (others) | concert : concert_ID (number) | concert : concert_Name (text) | concert : Theme 

In [64]:
sql2 = generate_sql(
    question="Show the stadium names without any concert.",
    db_id="concert_singer",
    prompt_fn=zero_shot_prompt,
)
print(sql2)

SELECT s.Name
FROM stadium s
LEFT JOIN concert c ON s.Stadium_ID = c.Stadium_ID
WHERE c.concert_ID IS NULL;


In [65]:
sql3 = generate_sql(
    question="What are the students' first names who have both cats and dogs as pets?",
    db_id="pets_1",
    prompt_fn=zero_shot_prompt,
)
print(sql3)

SELECT DISTINCT s.Fname
FROM Student s
JOIN Has_Pet hp1 ON s.StuID = hp1.StuID
JOIN Pets p1 ON hp1.PetID = p1.PetID
JOIN Has_Pet hp2 ON s.StuID = hp2.StuID
JOIN Pets p2 ON hp2.PetID = p2.PetID
WHERE p1.PetType = 'Cat' AND p2.PetType = 'Dog';


In [66]:
sql4 = generate_sql(
    question="Find the id of students who do not have a cat pet.",
    db_id="pets_1",
    prompt_fn=zero_shot_prompt,
)
print(sql4)

SELECT s.StuID
FROM Student s
WHERE s.StuID NOT IN (
    SELECT hp.StuID
    FROM Has_Pet hp
    JOIN Pets p ON hp.PetID = p.PetID
    WHERE p.PetType = 'cat'
);


In [34]:
sql5 = generate_sql(
    question="Find the first name and age of students who have a pet.",
    db_id="pets_1",
    prompt_fn=zero_shot_prompt,
)
print(sql5)

SELECT Student.Fname, Student.Age
FROM Student
JOIN Has_Pet ON Student.StuID = Has_Pet.StuID;


In [67]:
# Now let's evaluate all 5 queries and see how the LLM did
sqls = [sql1, sql2, sql3, sql4, sql5]
report = eval_generated_queries(sqls)
display(report)



GOLD_SQL: SELECT count(*) FROM concert WHERE YEAR  =  2014 OR YEAR  =  2015
GOLD_RES: [(0,)]
PRED_RES: [(0,)]
==> CORRECT: True


GOLD_SQL: SELECT name FROM stadium WHERE stadium_id NOT IN (SELECT stadium_id FROM concert)
GOLD_RES: [('Ocean Field',)]
PRED_RES: [('Ocean Field',)]
==> CORRECT: True


GOLD_SQL: SELECT T1.Fname FROM student AS T1 JOIN has_pet AS T2 ON T1.stuid  =  T2.stuid JOIN pets AS T3 ON T3.petid  =  T2.petid WHERE T3.pettype  =  'cat' INTERSECT SELECT T1.Fname FROM student AS T1 JOIN has_pet AS T2 ON T1.stuid  =  T2.stuid JOIN pets AS T3 ON T3.petid  =  T2.petid WHERE T3.pettype  =  'dog'
GOLD_RES: [('John',)]
PRED_RES: []
==> CORRECT: False


GOLD_SQL: SELECT stuid FROM student EXCEPT SELECT T1.stuid FROM student AS T1 JOIN has_pet AS T2 ON T1.stuid  =  T2.stuid JOIN pets AS T3 ON T3.petid  =  T2.petid WHERE T3.pettype  =  'cat'
GOLD_RES: [(102,), (103,), (104,), (105,)]
PRED_RES: [(102,), (103,), (104,), (105,)]
==> CORRECT: True


GOLD_SQL: SELECT DISTINCT T1.fna

,DB ID,Question,Correct,Error
0,concert_singer,How many concerts occurred in 2014 or 2015?,True,
1,concert_singer,Show the stadium names without any concert.,True,
2,pets_1,What are the students' first names who have bo...,False,
3,pets_1,Find the id of students who do not have a cat ...,True,
4,pets_1,Find the first name and age of students who ha...,False,


Not bad! 3 / 5 questions were answered correctly without any examples. However, if we were deploying this in production, we would want higher accuracy, so let's see how we can make the model a little bit better.

#### Few-Shot Text to SQL

Few-shot prompting has been a standard technique for several years. The idea behind few-shot prompting (also known as "In-Context Learning") is to show the model examples of the desired input/output pairs, guiding it toward decisions that match what we expect for the task at hand. Since the original GPT-3 paper, [Language Models are Few-Shot Learners](https://arxiv.org/abs/2005.14165), few-shot methods have often been shown to substantially outperform zero-shot methods (i.e. no examples in the prompt). So let's see how our model does if we give it a few examples of correct output.
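
This notebook packs all examples into a single user message. A common alternative with chat models, sketched here, is to present each example as a user turn followed by an assistant turn, so the model sees its "own" past answers. `few_shot_messages` is a hypothetical helper, and whether this layout beats the single-prompt one for this task is untested here.

```python
from typing import Dict, List, Tuple


def few_shot_messages(
    examples: List[Tuple[str, str]], user_prompt: str, system_prompt: str
) -> List[Dict[str, str]]:
    """Hypothetical helper: encode (prompt, answer) examples as alternating user/assistant chat turns."""
    messages = [{"role": "system", "content": system_prompt}]
    for example_prompt, example_answer in examples:
        messages.append({"role": "user", "content": example_prompt})
        messages.append({"role": "assistant", "content": example_answer})
    # The real question goes last, so the model's next turn is the answer
    messages.append({"role": "user", "content": user_prompt})
    return messages


msgs = few_shot_messages(
    [("How many singers do we have?", "SELECT count(*) FROM singer")],
    "How many concerts are there?",
    "You translate questions into SQL.",
)
print([m["role"] for m in msgs])
```

The resulting list could be passed as `messages=` to `client.chat.completions.create`.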

In [68]:
from functools import partial

fixed_few_shot_prompt_fn = partial(few_shot_prompt, indices=[0, 2])

NOTE: One of the biggest tradeoffs with few-shot learning is context length. Each example shown to the model consumes context, which increases cost and latency and leaves less room for the generation. When each example is long (as with SQL, where every example carries a full schema), this tradeoff is hard to avoid. In our scenario, we will settle for showing just 2 examples to the model.
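
One way to manage that tradeoff is to cap the prompt size and add examples only while they fit. A minimal sketch, assuming character count as a rough stand-in for tokens (a real tokenizer would be more accurate) and that `examples` is already sorted from most to least useful; `select_within_budget` is a hypothetical helper.

```python
from typing import List


def select_within_budget(examples: List[str], question_prompt: str, max_chars: int) -> List[str]:
    """Hypothetical helper: greedily keep examples while the whole prompt stays under max_chars."""
    selected: List[str] = []
    used = len(question_prompt)  # the final question must always fit
    for example in examples:
        if used + len(example) > max_chars:
            break  # stop at the first example that doesn't fit, keeping the most useful ones
        selected.append(example)
        used += len(example)
    return selected


print(select_within_budget(["aaaa", "bbbb", "cc"], "qq", max_chars=8))
```

Skipping an oversized example and trying the next one would fill the budget more fully, but may admit less relevant examples.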

In [69]:
# Showing the prompt here as well to get a sense of what the model is seeing for context
sql1 = generate_sql(
    question="How many concerts occurred in 2014 or 2015?",
    db_id="concert_singer",
    prompt_fn=fixed_few_shot_prompt_fn,
    debug_prompt=True,
)
print(sql1)

Prompting ChatGPT with:
Here is an example: Convert text to SQL:

[Schema : (values)]: | concert_singer | stadium : Stadium_ID , Location , Name , Capacity , Highest , Lowest , Average | singer : Singer_ID , Name , Country , Song_Name , Song_release_year , Age , Is_male | concert : concert_ID , concert_Name , Theme , Stadium_ID , Year | singer_in_concert : concert_ID , Singer_ID ;

[Column names (type)]: stadium : Stadium_ID (number) | stadium : Location (text) | stadium : Name (text) | stadium : Capacity (number) | stadium : Highest (number) | stadium : Lowest (number) | stadium : Average (number) | singer : Singer_ID (number) | singer : Name (text) | singer : Country (text) | singer : Song_Name (text) | singer : Song_release_year (text) | singer : Age (number) | singer : Is_male (others) | concert : concert_ID (number) | concert : concert_Name (text) | concert : Theme (text) | concert : Stadium_ID (text) | concert : Year (text) | singer_in_concert : concert_ID (number) | singer_in_co

In [70]:
sql2 = generate_sql(
    question="Show the stadium names without any concert.",
    db_id="concert_singer",
    prompt_fn=fixed_few_shot_prompt_fn,
)
print(sql2)

SELECT s.Name
FROM stadium s
LEFT JOIN concert c ON s.Stadium_ID = c.Stadium_ID
WHERE c.concert_ID IS NULL


In [71]:
sql3 = generate_sql(
    question="What are the students' first names who have both cats and dogs as pets?",
    db_id="pets_1",
    prompt_fn=fixed_few_shot_prompt_fn,
)
print(sql3)

SELECT DISTINCT s.Fname 
FROM Student s
JOIN Has_Pet hp1 ON s.StuID = hp1.StuID
JOIN Pets p1 ON hp1.PetID = p1.PetID
JOIN Has_Pet hp2 ON s.StuID = hp2.StuID
JOIN Pets p2 ON hp2.PetID = p2.PetID
WHERE p1.PetType = 'cat' AND p2.PetType = 'dog'


In [72]:
sql4 = generate_sql(
    question="Find the id of students who do not have a cat pet.",
    db_id="pets_1",
    prompt_fn=fixed_few_shot_prompt_fn,
)
print(sql4)

SELECT StuID 
FROM Student 
WHERE StuID NOT IN (
    SELECT distinct Has_Pet.StuID 
    FROM Has_Pet 
    JOIN Pets ON Has_Pet.PetID = Pets.PetID 
    WHERE Pets.PetType = 'cat'
)


In [73]:
sql5 = generate_sql(
    question="Find the first name and age of students who have a pet.",
    db_id="pets_1",
    prompt_fn=fixed_few_shot_prompt_fn,
)
print(sql5)

SELECT Student.Fname, Student.Age
FROM Student
JOIN Has_Pet ON Student.StuID = Has_Pet.StuID


In [74]:
# Now let's evaluate all 5 queries and see how the LLM did
sqls = [sql1, sql2, sql3, sql4, sql5]
report = eval_generated_queries(sqls)
display(report)



GOLD_SQL: SELECT count(*) FROM concert WHERE YEAR  =  2014 OR YEAR  =  2015
GOLD_RES: [(0,)]
PRED_RES: [(0,)]
==> CORRECT: True


GOLD_SQL: SELECT name FROM stadium WHERE stadium_id NOT IN (SELECT stadium_id FROM concert)
GOLD_RES: [('Ocean Field',)]
PRED_RES: [('Ocean Field',)]
==> CORRECT: True


GOLD_SQL: SELECT T1.Fname FROM student AS T1 JOIN has_pet AS T2 ON T1.stuid  =  T2.stuid JOIN pets AS T3 ON T3.petid  =  T2.petid WHERE T3.pettype  =  'cat' INTERSECT SELECT T1.Fname FROM student AS T1 JOIN has_pet AS T2 ON T1.stuid  =  T2.stuid JOIN pets AS T3 ON T3.petid  =  T2.petid WHERE T3.pettype  =  'dog'
GOLD_RES: [('John',)]
PRED_RES: [('John',)]
==> CORRECT: True


GOLD_SQL: SELECT stuid FROM student EXCEPT SELECT T1.stuid FROM student AS T1 JOIN has_pet AS T2 ON T1.stuid  =  T2.stuid JOIN pets AS T3 ON T3.petid  =  T2.petid WHERE T3.pettype  =  'cat'
GOLD_RES: [(102,), (103,), (104,), (105,)]
PRED_RES: [(102,), (103,), (104,), (105,)]
==> CORRECT: True


GOLD_SQL: SELECT DISTINC

,DB ID,Question,Correct,Error
0,concert_singer,How many concerts occurred in 2014 or 2015?,True,
1,concert_singer,Show the stadium names without any concert.,True,
2,pets_1,What are the students' first names who have bo...,True,
3,pets_1,Find the id of students who do not have a cat ...,True,
4,pets_1,Find the first name and age of students who ha...,False,


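Few-shot prompting raised the score to 4 / 5: the third question (students with both cats and dogs) is now answered correctly, because the generated query uses the lowercase values `'cat'` and `'dog'` found in the data, whereas the zero-shot query used `'Cat'` and `'Dog'`. The fifth question still fails: the generated query omits the `DISTINCT` used by the gold query, so it returns duplicate rows for students with several pets. Fixed examples can help, but because every question sees the same generic examples, they can't target the specific pattern a given question needs.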

#### Dynamic Few-Shot Text to SQL

The fixed few-shot examples above help in some cases, but often the model needs more tailored guidance. To provide it, we can find stored questions similar to the one being asked, and use those as In-Context Learning examples (aka few-shot examples) instead!

We'll start by loading the rest of the Spider dataset into Astra DB. Once it's there, we can create a new `prompt_fn` that builds a few-shot prompt from the stored questions most similar to the one being asked.
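
Conceptually, the retrieval step is a nearest-neighbor search over question embeddings. This brute-force sketch shows the idea with cosine similarity (to our knowledge the default metric for Astra DB vector collections); Astra DB runs an indexed, approximate version of this search for us, so `cosine_similarity` and `top_k_similar` are illustrative names only.

```python
import math
from typing import List, Sequence, Tuple


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine of the angle between two non-zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def top_k_similar(query_vec: Sequence[float], stored: List[Tuple[int, Sequence[float]]], k: int = 2) -> List[int]:
    """Hypothetical helper: IDs of the k stored vectors most similar to query_vec (exact brute-force search)."""
    scored = [(cosine_similarity(query_vec, vec), doc_id) for doc_id, vec in stored]
    scored.sort(reverse=True)  # highest similarity first
    return [doc_id for _, doc_id in scored[:k]]


stored = [(0, [1.0, 0.0]), (1, [0.0, 1.0]), (2, [0.9, 0.1])]
print(top_k_similar([1.0, 0.0], stored, k=2))
```

A brute-force scan is linear in the number of stored questions, which is fine for a toy example but is exactly the work a vector database index avoids.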

In [92]:
# First create a collection for vector search
from astrapy.info import CollectionDefinition
from astrapy.api_options import APIOptions, SerdesOptions  # lets the client send NumPy arrays (e.g. the token columns) as plain lists

astra_db_collection = astra_db.create_collection(
    "text2sql_examples_test1",
    definition=(
        CollectionDefinition.builder()
        .set_vector_dimension(1536)
        .build()
    ),
    spawn_api_options=APIOptions(
        serdes_options=SerdesOptions(
            unroll_iterables_to_lists=True,
        ),
    ),
)

In [78]:
# Next embed all questions in our "train" set
from tqdm.auto import tqdm

tqdm.pandas()

spider_embeddings = spider_df["question"].progress_apply(compute_embedding)
spider_df["question_embedding"] = spider_embeddings


In [102]:
# Now upload those documents to Astra DB
spider_df["_id"] = spider_df.index
# Astra DB expects the embedding under the reserved field name "$vector"
spider_documents = spider_df.rename(columns={"question_embedding": "$vector"}).to_dict(orient="records")

result = astra_db_collection.insert_many(
    spider_documents,
    # 1,025 documents carrying 1536-dimensional vectors: allow generous timeouts
    timeout_ms=45000,
    request_timeout_ms=20000,
)
print(result)

CollectionInsertManyResult(inserted_ids=[0, 1, 2, 3, 4 ... (1025 total)], raw_results=...)


In [103]:
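def dynamic_few_shot_prompt_fn(db_id: str, question: str) -> str:
    # Embed the incoming question and retrieve the 2 most similar stored questions
    emb = compute_embedding(question)
    docs = astra_db_collection.find(
        sort={"$vector": emb},
        limit=2,
        projection={"_id": True},  # only the ids are needed here
    )
    closest_q_ids = [doc["_id"] for doc in docs]

    # Build the few-shot prompt from those examples instead of fixed ones
    return few_shot_prompt(db_id, question, indices=closest_q_ids)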
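One caveat with retrieved examples: if the questions you evaluate on are also stored in the collection, the closest match for a question can be the question itself, and its gold SQL then lands in the prompt, which makes results look better than they would on genuinely new questions. Below is a minimal in-memory sketch of a guard against that. `select_examples`, `format_examples`, the `question`/`query`/`vector` field names and the toy vectors are all hypothetical; with Astra DB, the same idea could be expressed by adding a filter to `find` that excludes the question being asked (for example with `$ne`).

```python
import math
from typing import Dict, List, Sequence


def cosine(a: Sequence[float], b: Sequence[float]) -> float:
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(y * y for y in b))
    return dot / (na * nb) if na and nb else 0.0


def select_examples(question: str, query_vec: Sequence[float], examples: List[Dict], k: int = 2) -> List[Dict]:
    """Return the k most similar examples, skipping stored copies of `question` itself."""
    target = question.strip().lower()
    # Exact-match guard only: paraphrases of the same question are NOT filtered out
    candidates = [ex for ex in examples if ex["question"].strip().lower() != target]
    candidates.sort(key=lambda ex: cosine(query_vec, ex["vector"]), reverse=True)
    return candidates[:k]


def format_examples(examples: List[Dict]) -> str:
    """Render question/SQL pairs for a few-shot prompt (this layout is illustrative)."""
    return "\n\n".join(f"Question: {ex['question']}\nSQL: {ex['query']}" for ex in examples)


# Toy example store with made-up 2-d vectors
store = [
    {"question": "How many concerts occurred in 2014 or 2015?",
     "query": "SELECT count(*) FROM concert WHERE YEAR = 2014 OR YEAR = 2015",
     "vector": [1.0, 0.0]},
    {"question": "How many concerts are there in year 2014 or 2015?",
     "query": "SELECT count(*) FROM concert WHERE YEAR = 2014 OR YEAR = 2015",
     "vector": [0.95, 0.05]},
    {"question": "Show the stadium names without any concert.",
     "query": "SELECT name FROM stadium WHERE stadium_id NOT IN (SELECT stadium_id FROM concert)",
     "vector": [0.0, 1.0]},
]

chosen = select_examples("How many concerts occurred in 2014 or 2015?", [1.0, 0.0], store, k=1)
print(format_examples(chosen))
```

Without the guard, the top hit here would be the identical question; with it, the model still sees a closely related example but not the answer to the exact question being asked.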


In [104]:
# Showing the prompt here as well to get a sense of what the model is seeing for context
sql1 = generate_sql(
    question="How many concerts occurred in 2014 or 2015?",
    db_id="concert_singer",
    prompt_fn=dynamic_few_shot_prompt_fn,
    debug_prompt=True,
)
print(sql1)

Prompting ChatGPT with:
Here is an example: Convert text to SQL:

[Schema : (values)]: | concert_singer | stadium : Stadium_ID , Location , Name , Capacity , Highest , Lowest , Average | singer : Singer_ID , Name , Country , Song_Name , Song_release_year , Age , Is_male | concert : concert_ID , concert_Name , Theme , Stadium_ID , Year | singer_in_concert : concert_ID , Singer_ID ;

[Column names (type)]: stadium : Stadium_ID (number) | stadium : Location (text) | stadium : Name (text) | stadium : Capacity (number) | stadium : Highest (number) | stadium : Lowest (number) | stadium : Average (number) | singer : Singer_ID (number) | singer : Name (text) | singer : Country (text) | singer : Song_Name (text) | singer : Song_release_year (text) | singer : Age (number) | singer : Is_male (others) | concert : concert_ID (number) | concert : concert_Name (text) | concert : Theme (text) | concert : Stadium_ID (text) | concert : Year (text) | singer_in_concert : concert_ID (number) | singer_in_co

In [105]:
sql2 = generate_sql(
    question="Show the stadium names without any concert.",
    db_id="concert_singer",
    prompt_fn=dynamic_few_shot_prompt_fn,
)
print(sql2)

SELECT name 
FROM stadium 
WHERE stadium_id NOT IN (
    SELECT stadium_id 
    FROM concert
);


In [106]:
sql3 = generate_sql(
    question="What are the students' first names who have both cats and dogs as pets?",
    db_id="pets_1",
    prompt_fn=dynamic_few_shot_prompt_fn,
)
print(sql3)

SELECT DISTINCT T1.Fname
FROM student AS T1
JOIN has_pet AS T2 ON T1.stuid = T2.stuid
JOIN pets AS T3 ON T2.petid = T3.petid
WHERE T3.pettype = 'cat'
AND T1.stuid IN (
  SELECT T2.stuid
  FROM has_pet AS T2
  JOIN pets AS T3 ON T2.petid = T3.petid
  WHERE T3.pettype = 'dog'
);


In [107]:
sql4 = generate_sql(
    question="Find the id of students who do not have a cat pet.",
    db_id="pets_1",
    prompt_fn=dynamic_few_shot_prompt_fn,
)
print(sql4)

SELECT T1.stuid 
FROM student AS T1 
WHERE T1.stuid NOT IN (
    SELECT T1.stuid 
    FROM student AS T1 
    JOIN has_pet AS T2 ON T1.stuid = T2.stuid 
    JOIN pets AS T3 ON T3.petid = T2.petid 
    WHERE T3.pettype = 'cat'
)


In [108]:
sql5 = generate_sql(
    question="Find the first name and age of students who have a pet.",
    db_id="pets_1",
    prompt_fn=dynamic_few_shot_prompt_fn,
)
print(sql5)

SELECT DISTINCT T1.Fname, T1.Age 
FROM student AS T1 
JOIN has_pet AS T2 ON T1.stuid = T2.stuid


In [111]:
# Now let's evaluate all 5 queries and see how the LLM did
sqls = [sql1, sql2, sql3, sql4, sql5]
report = eval_generated_queries(sqls)
display(report)



GOLD_SQL: SELECT count(*) FROM concert WHERE YEAR  =  2014 OR YEAR  =  2015
GOLD_RES: [(0,)]
PRED_RES: [(0,)]
==> CORRECT: True


GOLD_SQL: SELECT name FROM stadium WHERE stadium_id NOT IN (SELECT stadium_id FROM concert)
GOLD_RES: [('Ocean Field',)]
PRED_RES: [('Ocean Field',)]
==> CORRECT: True


GOLD_SQL: SELECT T1.Fname FROM student AS T1 JOIN has_pet AS T2 ON T1.stuid  =  T2.stuid JOIN pets AS T3 ON T3.petid  =  T2.petid WHERE T3.pettype  =  'cat' INTERSECT SELECT T1.Fname FROM student AS T1 JOIN has_pet AS T2 ON T1.stuid  =  T2.stuid JOIN pets AS T3 ON T3.petid  =  T2.petid WHERE T3.pettype  =  'dog'
GOLD_RES: [('John',)]
PRED_RES: [('John',)]
==> CORRECT: True


GOLD_SQL: SELECT stuid FROM student EXCEPT SELECT T1.stuid FROM student AS T1 JOIN has_pet AS T2 ON T1.stuid  =  T2.stuid JOIN pets AS T3 ON T3.petid  =  T2.petid WHERE T3.pettype  =  'cat'
GOLD_RES: [(102,), (103,), (104,), (105,)]
PRED_RES: [(102,), (103,), (104,), (105,)]
==> CORRECT: True


GOLD_SQL: SELECT DISTINC

Unnamed: 0,DB ID,Question,Correct,Error
0,concert_singer,How many concerts occurred in 2014 or 2015?,True,
1,concert_singer,Show the stadium names without any concert.,True,
2,pets_1,What are the students' first names who have bo...,True,
3,pets_1,Find the id of students who do not have a cat ...,True,
4,pets_1,Find the first name and age of students who ha...,True,
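The report above appears to judge a prediction by comparing its result rows with the gold query's (`GOLD_RES` vs `PRED_RES`), i.e. execution accuracy, which is why the predicted `NOT IN` query for question 4 counts as correct even though it is written differently from the gold `EXCEPT` query. `eval_generated_queries` is defined earlier in the notebook and may differ in its details; the sketch below shows one way to implement such a check with SQLite from the standard library. `results_match` is a hypothetical helper, and the tiny `pets_1`-style tables hold made-up rows, not the Spider data.

```python
import sqlite3
from collections import Counter


def results_match(conn: sqlite3.Connection, gold_sql: str, pred_sql: str) -> bool:
    """Execution-based check: do both queries return the same rows, ignoring order?"""
    try:
        pred_rows = conn.execute(pred_sql).fetchall()
    except sqlite3.Error:
        return False  # invalid SQL counts as incorrect
    gold_rows = conn.execute(gold_sql).fetchall()
    # Compare as multisets: row order is ignored, duplicate rows still matter
    return Counter(gold_rows) == Counter(pred_rows)


# Tiny toy database with the pets_1 table layout (rows are made up for illustration)
conn = sqlite3.connect(":memory:")
conn.executescript("""
CREATE TABLE student (stuid INTEGER PRIMARY KEY, Fname TEXT, Age INTEGER);
CREATE TABLE pets (petid INTEGER PRIMARY KEY, pettype TEXT);
CREATE TABLE has_pet (stuid INTEGER, petid INTEGER);
INSERT INTO student VALUES (1, 'Ann', 20), (2, 'Bo', 21), (3, 'Cy', 22);
INSERT INTO pets VALUES (10, 'cat'), (11, 'dog'), (12, 'dog');
INSERT INTO has_pet VALUES (1, 10), (1, 11), (2, 12);
""")

gold = ("SELECT stuid FROM student EXCEPT SELECT T1.stuid FROM student AS T1 "
        "JOIN has_pet AS T2 ON T1.stuid = T2.stuid JOIN pets AS T3 ON T3.petid = T2.petid "
        "WHERE T3.pettype = 'cat'")
pred = ("SELECT T1.stuid FROM student AS T1 WHERE T1.stuid NOT IN ("
        "SELECT T1.stuid FROM student AS T1 JOIN has_pet AS T2 ON T1.stuid = T2.stuid "
        "JOIN pets AS T3 ON T3.petid = T2.petid WHERE T3.pettype = 'cat')")

print(results_match(conn, gold, pred))                     # → True
print(results_match(conn, gold, "SELECT 1 FROM nowhere"))  # → False
```

Comparing unordered multisets is a deliberate choice: it accepts textually different but equivalent SQL, yet it would also accept a query whose `ORDER BY` is wrong, so questions that ask for a specific ordering need an order-sensitive comparison instead.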


Nice! The model now answers all 5 questions correctly. Controlling what the LLM sees at runtime matters: on this small test set, simply swapping in examples similar to the question being asked was enough to fix the remaining errors. Using Astra DB as a dynamic few-shot example store is a low-effort way to keep improving a Text-to-SQL application in production, since every verified question/SQL pair you add becomes a candidate example for future questions.