# Dynamic Few-Shot Prompting for SQL generation with Astra DB

Generating SQL to answer natural language questions is _hard_. There has been a large amount of research in this area, even before the rise of LLMs. Although every individual problem may require specific techniques to achieve the best performance, there are still some general principles that can apply across many domains.

This notebook will demonstrate one of those generalizable techniques, "Dynamic Few-Shot Prompting", using Astra DB. This technique can substantially improve the accuracy of the generated SQL queries.

In this notebook, you will:
- download a dataset of natural-language questions paired with the SQL statements that answer them
- populate a relational (SQL) database with sample data to run the above queries
- learn how to query an LLM to have it generate SQL given a natural-language question
- ... then refine the technique with a few-shot improvement
- ... then make the few-shot selection dynamic with the help of a vector database
- quantify the effectiveness of the approaches

At the end, you will have a clear picture of the main challenges in generating SQL with Large Language Models, and of how to approach them.

To run the notebook, you will need:
- OpenAI API key ([details](https://platform.openai.com/docs/quickstart/account-setup))
- an Astra DB API Endpoint and associated Database Token ([details](https://docs.datastax.com/en/astra/astra-db-vector/get-started/quickstart.html#create-a-serverless-vector-database))

## Setup

### Requirements

_Note: this demo uses the latest Data API python client, `astrapy 2.*`, which is in [pre-release](https://docs.datastax.com/en/astra-db-serverless/api-reference/client-versions.html#version-2-0-preview) at the time of writing._

In [1]:
!pip install -q \
  "openai>=1.0,<2.0" \
  "astrapy==2.0.0-rc1" \
  "datasets==3.*" \
  "tenacity==9.*"

### Settings

This demo uses OpenAI for the LLM and the embeddings. The demo is set to use the GPT-3.5 model, which is cheaper and, being less powerful than the most recent ones, better highlights the accuracy gains brought by the improvements shown in this notebook. Feel free to edit these settings to your liking (including switching to a different model vendor altogether, with the necessary changes later on).

An embedding model is needed later, to compute embedding vectors for the example questions.

The relational (SQL) database for this demo will be a local SQLite file, for simplicity. You can instead target Postgres, Amazon Aurora, Amazon Redshift or any other SQL-compatible database by rewriting a single function given in a later cell.

In [2]:
# LLM and embedding model settings
LLM_MODEL_NAME = "gpt-3.5-turbo-0125"  # "gpt-4o-mini-2024-07-18"  # "gpt-4o-2024-08-06"

EMBEDDING_MODEL = "text-embedding-3-small"
EMBEDDING_DIMENSION = 1536

# This is needed for the SQLite database
SQLITE_FILE_NAME = "sample_database.db"

### Import dependencies

In [3]:
import json
import os
import re
import sqlite3
from functools import lru_cache, partial
from getpass import getpass
from typing import Any, Callable, Dict, List, Tuple

import openai
import pandas as pd
from astrapy import DataAPIClient
from astrapy.api_options import APIOptions, SerdesOptions
from astrapy.info import CollectionDefinition
from datasets import load_dataset
from tenacity import retry, wait_exponential
from tqdm.auto import tqdm

### Initialize OpenAI

In [4]:
if "OPENAI_API_KEY" not in os.environ:
    os.environ["OPENAI_API_KEY"] = getpass("OpenAI API Key: ")

client = openai.OpenAI()

# We equip the LLM and embedding calls with exponential backoff and retry in case the service is throttled
@retry(wait=wait_exponential(multiplier=1.2, min=4, max=30))
def llm_completion(prompt: str) -> str:
    response = client.chat.completions.create(
        messages=[{
            "role": "user",
            "content": prompt,
        }],
        model=LLM_MODEL_NAME,
    ).choices[0].message.content
    return response


@retry(wait=wait_exponential(multiplier=1.2, min=4, max=30))
def compute_embedding(text: str) -> List[float]:
    return client.embeddings.create(
        input=text,
        model=EMBEDDING_MODEL,
        timeout=10,
    ).data[0].embedding

### Initialize Astra DB client

In [5]:
ASTRA_DB_API_ENDPOINT = os.environ.get("ASTRA_DB_API_ENDPOINT")
if not ASTRA_DB_API_ENDPOINT:
    ASTRA_DB_API_ENDPOINT = input("Astra DB API Endpoint: ")
ASTRA_DB_APPLICATION_TOKEN = os.environ.get("ASTRA_DB_APPLICATION_TOKEN")
if not ASTRA_DB_APPLICATION_TOKEN:
    ASTRA_DB_APPLICATION_TOKEN = getpass("Astra DB Token: ")

data_api_client = DataAPIClient()
astra_db = data_api_client.get_database(
    ASTRA_DB_API_ENDPOINT,
    token=ASTRA_DB_APPLICATION_TOKEN,
)

### Initialize SQL execution

**Note:** in order to target a different database supporting SQL, all you have to do is provide a different implementation of the following function.

In [6]:
def execute_sql(*statements: str, raise_on_error: bool = True) -> List[Tuple[Any, ...]]:
    """Util to execute DB SQL statements and return the result of the final query"""
    with sqlite3.connect(SQLITE_FILE_NAME) as conn:
        cursor = conn.cursor()
        try:
            for statement in statements:
                cursor.execute(statement)
        except sqlite3.OperationalError as e:
            # syntax errors or similar in running the query
            if raise_on_error:
                raise
            return []
        try:
            res = cursor.fetchall()
        except sqlite3.OperationalError as e:
            # No result set (final statement was not a SELECT)
            res = []

        return res
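As noted above, targeting another database only requires a different `execute_sql`. A minimal sketch, assuming the target driver follows the Python DB-API 2.0 (PEP 249) conventions, is to build the function from a connection factory. `make_execute_sql` is a hypothetical helper, not part of any library; the usage below plugs in `sqlite3`, and a Postgres driver's `connect` function could be passed the same way (untested here).

```python
import sqlite3
import tempfile
from contextlib import closing
from pathlib import Path
from typing import Any, Callable, List, Tuple


def make_execute_sql(connect: Callable[[], Any]) -> Callable[..., List[Tuple[Any, ...]]]:
    """Build an `execute_sql`-like function from any DB-API 2.0 connection factory (hypothetical helper)."""

    def execute_sql(*statements: str, raise_on_error: bool = True) -> List[Tuple[Any, ...]]:
        # closing() makes sure the connection is closed afterwards
        # (sqlite3's own `with conn:` only commits or rolls back, it does not close)
        with closing(connect()) as conn:
            cursor = conn.cursor()
            try:
                for statement in statements:
                    cursor.execute(statement)
                # DB-API 2.0: `description` is None when the last statement returned no rows
                rows = cursor.fetchall() if cursor.description is not None else []
                conn.commit()
            except Exception:
                # Drivers raise different exception classes, so catch broadly here
                conn.rollback()
                if raise_on_error:
                    raise
                return []
            return [tuple(row) for row in rows]

    return execute_sql


# Usage with SQLite (a file is used because each call opens a fresh connection)
db_path = Path(tempfile.mkdtemp()) / "demo.db"
run_sql = make_execute_sql(lambda: sqlite3.connect(db_path))
run_sql("CREATE TABLE t (x INT)", "INSERT INTO t VALUES (1), (2)")
print(run_sql("SELECT x FROM t ORDER BY x"))  # [(1,), (2,)]
```

Fetching rows before committing keeps the sketch independent of how a given driver treats open cursors at commit time.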

## Load and prepare query data

For this demo, we will use the [Spider](https://yale-lily.github.io/spider) dataset, which has been a standard benchmark for evaluating text-to-SQL generation for several years. This dataset consists of `question`, `query` pairs, each indicating the ideal query to be generated from a given natural-language question. A companion dataset provides the schema of each database.

NOTE: To apply this to a specific use case, the best way to collect data is to store the SQL generated by an application in a live environment (dev, staging and prod all work, as long as the queries are realistic). You can then either offer users a quick feedback interface to give a thumbs up/thumbs down to the generated queries, or have internal human evaluators grade the queries offline. As more positively graded generations accumulate, you will have more examples available to improve your live SQL-generating application.
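The feedback loop described in the note can be sketched as follows. `GeneratedQueryLog` and its methods are hypothetical names, not part of any library, and the log lives in memory only for illustration; a real application would persist the records. Approved entries are emitted with the same `db_id`/`question`/`query` fields as the Spider rows, so they could later serve as few-shot examples.

```python
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class GeneratedQueryLog:
    """Hypothetical in-memory log of generated SQL plus user feedback."""
    records: List[Dict[str, object]] = field(default_factory=list)

    def record(self, db_id: str, question: str, sql: str) -> int:
        """Store a generation (not yet graded) and return its id, i.e. its position in the log."""
        self.records.append({"db_id": db_id, "question": question, "query": sql, "vote": None})
        return len(self.records) - 1

    def vote(self, record_id: int, thumbs_up: bool) -> None:
        """Attach a thumbs-up/thumbs-down judgement, from users or from offline graders."""
        self.records[record_id]["vote"] = thumbs_up

    def approved_examples(self) -> List[Dict[str, object]]:
        """Only the positively graded pairs, shaped like the Spider rows."""
        return [
            {key: rec[key] for key in ("db_id", "question", "query")}
            for rec in self.records
            if rec["vote"] is True
        ]


log = GeneratedQueryLog()
good = log.record("pets_1", "How many pets are there?", "SELECT count(*) FROM Pets")
bad = log.record("pets_1", "How many dogs are there?", "SELECT count(*) FROM Student")
log.vote(good, thumbs_up=True)
log.vote(bad, thumbs_up=False)
print(log.approved_examples())
```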

In [7]:
spider = load_dataset("spider", split="validation")
spider_df = spider.to_pandas()
spider_schema = load_dataset("richardr1126/spider-schema", split="train")
spider_schema_df = spider_schema.to_pandas()

In [8]:
spider_df.head(3)

| | db_id | query | question | query_toks | query_toks_no_value | question_toks |
|---|---|---|---|---|---|---|
| 0 | concert_singer | SELECT count(*) FROM singer | How many singers do we have? | [SELECT, count, (, *, ), FROM, singer] | [select, count, (, *, ), from, singer] | [How, many, singers, do, we, have, ?] |
| 1 | concert_singer | SELECT count(*) FROM singer | What is the total number of singers? | [SELECT, count, (, *, ), FROM, singer] | [select, count, (, *, ), from, singer] | [What, is, the, total, number, of, singers, ?] |
| 2 | concert_singer | SELECT name , country , age FROM singer ORDE... | Show name, country, age for all singers ordere... | [SELECT, name, ,, country, ,, age, FROM, singe... | [select, name, ,, country, ,, age, from, singe... | [Show, name, ,, country, ,, age, for, all, sin... |


We now select a handful of rows for testing the query generation.

The rest will be used as candidate examples for the dynamic few-shot prompts later: importantly, to avoid data leakage and have a fair test we need to remove from `spider_df` not only the test rows, but any additional row with the same `query`.

_Note: feel free to change the indices for the test set, but since the Spider dataset targets many more databases than the two we will actually populate with made-up data, keep the indices below ~80 to ensure only those two are targeted by the tests._

In [9]:
# picking a few (evenly spaced) questions for testing
test_indices = [25, 35, 45, 55, 65, 75]

test_df = spider_df.loc[test_indices]
test_sql_set = set(test_df["query"])

# for us, "identical" query here means up to spacing and upper/lowercase:
norm_test_queries = {qr.replace(" ", "").lower() for qr in test_sql_set}

def _is_test_query(qr: str) -> bool:
    return qr.replace(" ", "").lower() in norm_test_queries

idx_to_remove = spider_df[spider_df["query"].apply(_is_test_query)].index


examples_df = spider_df.drop(idx_to_remove)

print(
    f"Originally {len(spider_df)} rows ==> split into testing with "
    f"{len(test_df)} rows + few-shot with {len(examples_df)} rows "
    f"({len(spider_df) - len(test_df) - len(examples_df)} discarded "
    "as having ~the same query as the tests)."
)

Originally 1034 rows ==> split into testing with 6 rows + few-shot with 1023 rows (5 discarded as having ~the same query as the tests).


In [10]:
test_queries = {
    i: row
    for i, (_, row) in enumerate(test_df.iterrows())
}

for i, row in test_queries.items():
    if row["db_id"] in {"pets_1", "concert_singer"}:
        print(f"Test question [{i}] (on DB {row['db_id']}):\n    {row['question']}")
        print(f"Its gold Query SQL:\n    {row['query']}")
        print("=" * 80)
    else:
        raise ValueError(
            f"Test question detected targeting a db out of scope ({row['db_id']}). "
            "Please lower the values in `test_indices` and try again."
        )

Test question [0] (on DB concert_singer):
    What is the name and capacity of the stadium with the most concerts after 2013 ?
Its gold Query SQL:
    select t2.name ,  t2.capacity from concert as t1 join stadium as t2 on t1.stadium_id  =  t2.stadium_id where t1.year  >  2013 group by t2.stadium_id order by count(*) desc limit 1
Test question [1] (on DB concert_singer):
    List singer names and number of concerts for each singer.
Its gold Query SQL:
    SELECT T2.name ,  count(*) FROM singer_in_concert AS T1 JOIN singer AS T2 ON T1.singer_id  =  T2.singer_id GROUP BY T2.singer_id
Test question [2] (on DB pets_1):
    Find the number of pets whose weight is heavier than 10.
Its gold Query SQL:
    SELECT count(*) FROM pets WHERE weight  >  10
Test question [3] (on DB pets_1):
    Find the number of distinct type of pets.
Its gold Query SQL:
    SELECT count(DISTINCT pettype) FROM pets
Test question [4] (on DB pets_1):
    Find the first name and age of students who have a dog but do no

## Populate database

To insert example data into the database, we will use the `execute_sql` function defined above.

In [11]:
# Set up concert_singer DB
CREATE_TABLES_SQL = """
-- Creating the stadium table
CREATE TABLE stadium (
    Stadium_ID INT PRIMARY KEY,
    Location TEXT,
    Name TEXT,
    Capacity INT,
    Highest INT,
    Lowest INT,
    Average INT
);

-- Creating the singer table
CREATE TABLE singer (
    Singer_ID INT PRIMARY KEY,
    Name TEXT,
    Country TEXT,
    Song_Name TEXT,
    Song_release_year TEXT,
    Age INT,
    Is_male BOOLEAN
);

-- Creating the concert table
CREATE TABLE concert (
    concert_ID INT PRIMARY KEY,
    concert_Name TEXT,
    Theme TEXT,
    Stadium_ID INT,
    Year TEXT,
    FOREIGN KEY (Stadium_ID) REFERENCES stadium(Stadium_ID)
);

-- Creating the singer_in_concert table
CREATE TABLE singer_in_concert (
    concert_ID INT,
    Singer_ID INT,
    PRIMARY KEY (concert_ID, Singer_ID),
    FOREIGN KEY (concert_ID) REFERENCES concert(concert_ID),
    FOREIGN KEY (Singer_ID) REFERENCES singer(Singer_ID)
);
"""

POPULATE_DATA_SQL = """
-- Populating the stadium table
INSERT INTO stadium (Stadium_ID, Location, Name, Capacity, Highest, Lowest, Average) VALUES
(1, 'New York, USA', 'Liberty Stadium', 50000, 1000, 500, 750),
(2, 'London, UK', 'Royal Arena', 60000, 1500, 600, 900),
(3, 'Tokyo, Japan', 'Sunshine Dome', 55000, 1200, 550, 800),
(4, 'Sydney, Australia', 'Ocean Field', 40000, 900, 400, 650),
(5, 'Berlin, Germany', 'Eagle Grounds', 45000, 1100, 450, 700);

-- Populating the singer table
INSERT INTO singer (Singer_ID, Name, Country, Song_Name, Song_release_year, Age, Is_male) VALUES
(1, 'John Doe', 'USA', 'Freedom Song', '2018', 28, TRUE),
(2, 'Emma Stone', 'UK', 'Rolling Hills', '2019', 25, FALSE),
(3, 'Haruki Tanaka', 'Japan', 'Tokyo Lights', '2020', 30, TRUE),
(4, 'Alice Johnson', 'Australia', 'Ocean Waves', '2021', 27, FALSE),
(5, 'Max Müller', 'Germany', 'Berlin Nights', '2017', 32, TRUE);

-- Populating the concert table
INSERT INTO concert (concert_ID, concert_Name, Theme, Stadium_ID, Year) VALUES
(1, 'Freedom Fest', 'Pop', 1, '2021'),
(2, 'Rock Mania', 'Rock', 2, '2022'),
(3, 'Electronic Waves', 'Electronic', 3, '2020'),
(4, 'Jazz Evenings', 'Jazz', 3, '2019'),
(5, 'Classical Mornings', 'Classical', 5, '2023');

-- Populating the singer_in_concert table
INSERT INTO singer_in_concert (concert_ID, Singer_ID) VALUES
(1, 1),
(1, 2),
(2, 3),
(3, 4),
(4, 5),
(5, 1),
(2, 2),
(3, 3),
(4, 4),
(5, 5);
"""

statements = [
    statement.strip() for statement in (CREATE_TABLES_SQL + POPULATE_DATA_SQL).split(";")
    if len(statement.strip()) > 0
]
execute_sql(*statements)

[]

In [12]:
# Set up pets_1 DB
CREATE_TABLES_SQL = """
-- Creating the Student table
CREATE TABLE Student (
    StuID INT PRIMARY KEY,
    LName VARCHAR(255),
    Fname VARCHAR(255),
    Age INT,
    Sex VARCHAR(10),
    Major INT,
    Advisor INT,
    city_code VARCHAR(50)
);

-- Creating the Pets table
CREATE TABLE Pets (
    PetID INT PRIMARY KEY,
    PetType VARCHAR(255),
    pet_age INT,
    weight INT
);

-- Creating the Has_Pet table
CREATE TABLE Has_Pet (
    StuID INT,
    PetID INT,
    FOREIGN KEY (StuID) REFERENCES Student(StuID),
    FOREIGN KEY (PetID) REFERENCES Pets(PetID)
);
"""

POPULATE_DATA_SQL = """
-- Populating the Student table
INSERT INTO Student (StuID, LName, Fname, Age, Sex, Major, Advisor, city_code) VALUES
(101, 'Smith', 'John', 20, 'M', 501, 301, 'NYC'),
(102, 'Johnson', 'Emma', 22, 'F', 502, 302, 'LAX'),
(103, 'Williams', 'Michael', 21, 'M', 503, 303, 'CHI'),
(104, 'Brown', 'Sarah', 23, 'F', 504, 304, 'HOU'),
(105, 'Jones', 'David', 19, 'M', 505, 305, 'PHI');

-- Populating the Pets table
INSERT INTO Pets (PetID, PetType, pet_age, weight) VALUES
(201, 'dog', 3, 20.5),
(202, 'cat', 5, 10.2),
(203, 'dog', 2, 8.1),
(204, 'parrot', 4, 0.5),
(205, 'hamster', 1, 0.7);

-- Populating the Has_Pet table
INSERT INTO Has_Pet (StuID, PetID) VALUES
(101, 201),
(101, 202),
(105, 203),
(103, 204),
(104, 205),
(105, 201);
"""

statements = [
    statement.strip() for statement in (CREATE_TABLES_SQL + POPULATE_DATA_SQL).split(";")
    if len(statement.strip()) > 0
]
execute_sql(*statements)

[]
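Splitting the scripts on `;`, as in the two cells above, works here because none of the string literals contains a semicolon. A small sketch (the `note` table is made up for illustration) shows where the naive split breaks, and how, for SQLite specifically, the standard-library `sqlite3.Connection.executescript` method can run a whole multi-statement script instead. This alternative is SQLite-only; other drivers may need a proper SQL statement splitter.

```python
import sqlite3

SCRIPT = """
-- table used to show why a naive split on semicolons can fail
CREATE TABLE note (id INT PRIMARY KEY, body TEXT);
INSERT INTO note (id, body) VALUES (1, 'first; second'), (2, 'plain');
"""

# Naive splitting cuts the INSERT statement in two, inside the string literal
naive_parts = [s.strip() for s in SCRIPT.split(";") if s.strip()]

with sqlite3.connect(":memory:") as conn:
    conn.executescript(SCRIPT)  # SQLite parses and runs every statement in the script
    rows = conn.execute("SELECT id, body FROM note ORDER BY id").fetchall()

print(len(naive_parts))  # 3 fragments for 2 actual statements
print(rows)              # [(1, 'first; second'), (2, 'plain')]
```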

Note that, out of all the databases in the Spider dataset, only "pets_1" and "concert_singer" are populated with example data. This is why we checked earlier that the chosen test queries all target those two databases (otherwise the correctness check would trivially fail).

## Evaluator function for queries

Finally, we just need an evaluation function for the generated queries, to assess their correctness.

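The evaluator function runs each generated query alongside its "gold" counterpart and checks whether the two return the same data. The notion of equality for this check requires some care: the two result sets must have the same number of rows; the order of the columns within a row is ignored (values are compared after sorting them by their string form); and the order of the rows matters only if the gold query contains an `ORDER BY`.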

In [13]:
def eval_generated_queries(generated_queries: Dict[int, str]) -> pd.DataFrame:
    """Evaluate the given queries against the test set, and return a report of the performance on each row"""
    report = []

    for tq_index, tq_generated_sql in generated_queries.items():
        query_row = test_queries[tq_index]
        gold_sql = query_row["query"]
        gold_results = execute_sql(gold_sql)

        try:
            gen_results = execute_sql(tq_generated_sql)
            err = None

            # Figure out correctness, not super straightforward
            if len(gold_results) == len(gen_results):
                # (note: here and in the next sorting, we ignore "sorting collisions" for simplicity...)
                in_sorted_gen = [sorted(tpl, key=str) for tpl in gen_results]
                in_sorted_gold = [sorted(tpl, key=str) for tpl in gold_results]
                if "ORDER BY" in gold_sql.upper():
                    # must leave order among tuples untouched and expect it to be correct
                    correct = in_sorted_gen == in_sorted_gold
                else:
                    # sort the sequence of tuples before comparison
                    # (key=str avoids TypeErrors when rows mix value types, e.g. None and int)
                    correct = sorted(in_sorted_gen, key=str) == sorted(in_sorted_gold, key=str)
            else:
                correct = False

        except Exception as e:
            gen_results = None
            err = e
            correct = False

        report.append({
            "DB ID": query_row["db_id"],
            "Question": query_row["question"],
            "Correct": correct,
            "Error": err,
        })

    return pd.DataFrame(report)
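The comparison rules inside `eval_generated_queries` can be isolated into a small pure function, which makes them easy to unit-test without a database. This is a sketch mirroring the logic above: `results_match` is a hypothetical name, and the outer sort uses `key=str` so that rows mixing value types (such as `None` and integers) cannot raise a `TypeError`.

```python
from typing import Any, Sequence, Tuple


def results_match(
    gold: Sequence[Tuple[Any, ...]],
    generated: Sequence[Tuple[Any, ...]],
    ordered: bool,
) -> bool:
    """Compare two result sets with the same rules as the evaluator above (sketch)."""
    if len(gold) != len(generated):
        return False  # duplicate rows count: row multiplicity must match
    # Ignore column order inside each row: sort each row's values by their string form
    norm_gold = [sorted(row, key=str) for row in gold]
    norm_gen = [sorted(row, key=str) for row in generated]
    if ordered:
        # The gold query has ORDER BY: the row sequence itself must match
        return norm_gen == norm_gold
    # No ORDER BY: compare the rows as multisets
    return sorted(norm_gen, key=str) == sorted(norm_gold, key=str)


gold = [("Liberty Stadium", 50000), ("Royal Arena", 60000)]
swapped_columns = [(60000, "Royal Arena"), (50000, "Liberty Stadium")]
print(results_match(gold, swapped_columns, ordered=False))  # True
print(results_match(gold, swapped_columns, ordered=True))   # False
```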

## Generating SQL 1: Zero-Shot

Now, let's go ahead and generate some SQL queries in a few different ways.

First, we'll use an LLM with a prompt template from the [SQL-PaLM](https://arxiv.org/abs/2306.00739) paper as-is to see how it does.

### SQL-generation tools

We start by defining some utility functions.

The main purpose of the tools below is to assemble a complete prompt for an LLM, urging it to produce SQL.

#### Schema format converter

The logic below translates the Spider schema format to match the prompt template of the SQL-PaLM paper, which that work found to be effective.

In [14]:
def _format_schema(db_id: str, spider_schema_row: pd.Series) -> Tuple[str, str, str, str]:
    """Converts a spider_schema_df row into the schema, columns, primary-key and foreign-key strings used by SQL-PaLM"""
    schema_str = f"| {db_id} "
    cols_str = ""

    spider_schema_vt_str = spider_schema_row["Schema (values (type))"]
    
    for table_name in re.findall(r"([^,: ]*) :", spider_schema_vt_str):
        schema_str += f"| {table_name} : "

        start_ndx = spider_schema_vt_str.find(table_name + " :")
        end_ndx = spider_schema_vt_str.find(":", start_ndx + len(table_name) + 4)
        if end_ndx == -1:
            end_ndx = len(spider_schema_vt_str)
        curr_substr = spider_schema_vt_str[start_ndx:end_ndx]
        for col_name, col_type in re.findall(r" ([^:,]*) \(([^,]*)\)", curr_substr):
            schema_str += f"{col_name} , "
            cols_str += f"{table_name} : {col_name} ({col_type}) | "

        schema_str = schema_str[:-2]
    cols_str = cols_str[:-2]

    return schema_str + ";", cols_str + ";", spider_schema_row["Primary Keys"], spider_schema_row["Foreign Keys"]

For illustration purposes, this is the effect of this conversion on the pets database:

In [15]:
spider_schema0 = spider_schema_df[spider_schema_df["db_id"] == "pets_1"].iloc[0]
palm_schema0, palm_cols0, palm_pk0, palm_fk0 = _format_schema("pets_1", spider_schema0)

print(f"BEFORE CONVERSION:\n{json.dumps(spider_schema0.to_dict(), indent=2)}")
print(f"\n======\nAFTER CONVERSION:\n- SCHEMA:\n      {palm_schema0}")
print(f"- COLS:\n      {palm_cols0}")
print(f"- PR_KEYS:\n      {palm_pk0}")
print(f"- F_KEYS:\n      {palm_fk0}")

BEFORE CONVERSION:
{
  "db_id": "pets_1",
  "Schema (values (type))": "Student : StuID (number) , LName (text) , Fname (text) , Age (number) , Sex (text) , Major (number) , Advisor (number) , city_code (text) | Has_Pet : StuID (number) , PetID (number) | Pets : PetID (number) , PetType (text) , pet_age (number) , weight (number)",
  "Primary Keys": "Student : StuID | Pets : PetID",
  "Foreign Keys": "Has_Pet : StuID equals Student : StuID | Has_Pet : PetID equals Pets : PetID"
}

AFTER CONVERSION:
- SCHEMA:
      | pets_1 | Student : StuID , LName , Fname , Age , Sex , Major , Advisor , city_code | Has_Pet : StuID , PetID | Pets : PetID , PetType , pet_age , weight ;
- COLS:
      Student : StuID (number) | Student : LName (text) | Student : Fname (text) | Student : Age (number) | Student : Sex (text) | Student : Major (number) | Student : Advisor (number) | Student : city_code (text) | Has_Pet : StuID (number) | Has_Pet : PetID (number) | Pets : PetID (number) | Pets : PetType (text) 

#### Assemble prompt for SQL generation

In [16]:
SQL_PROMPT_TEMPLATE = """Convert text to SQL:

[Schema : (values)]: {schema_str}

[Column names (type)]: {cols_str}

[Primary Keys]: {pk_str}

[Foreign Keys]: {fk_str}

[Q]: {question}

[SQL]: """

QUESTION_PREFIX_STR = (
    "Given the following schema information, generate valid SQL "
    "to answer the provided query. Enclose the query in markdown "
    "code-block syntax.\n"
)

In [17]:
@lru_cache
def _get_spider_schema_by_db_id(db_id: str) -> pd.Series:
    return spider_schema_df[spider_schema_df["db_id"] == db_id].iloc[0]


def _format_sql_prompt(db_id: str, question: str) -> str:
    """
    Returns the core of the prompt: the DB schema description followed by the question.
    This logic is factored out to be reusable later, in the few-shot approach.
    """
    spider_schema_row = _get_spider_schema_by_db_id(db_id)
    schema_str, cols_str, pk_str, fk_str = _format_schema(db_id, spider_schema_row)
    sql_prompt_str = SQL_PROMPT_TEMPLATE.format(
        schema_str=schema_str,
        cols_str=cols_str,
        pk_str=pk_str,
        fk_str=fk_str,
        question=question,
    )
    return sql_prompt_str


def _format_question_prompt(db_id: str, question: str) -> str:
    core_prompt = _format_sql_prompt(db_id, question)
    return QUESTION_PREFIX_STR + core_prompt


def _clean_sql_response(resp: str) -> str:
    _start = resp.find("```sql")
    if _start < 0:
        # assume the whole response is the SQL, no markdown stuff
        # (gpt-3.5 tends to do this)
        return "\n".join(l.strip() for l in resp.split("\n") if l.strip())
    rest = resp[_start+6:]
    _stop = rest.find("```")
    if _stop < 0:
        raise ValueError("Invalid answer from LLM.")
    _sql = rest[:_stop].strip()
    return _sql


def zero_shot_prompt(db_id: str, question: str) -> str:
    """Creates a Zero-Shot prompt for the given question & db_id"""
    return _format_question_prompt(db_id, question)


def generate_sql(question: str, db_id: str, prompt_fn: Callable[[str, str], str], debug_prompt: bool = False) -> str:
    """Generates SQL for a given Spider question & DB ID, using the provided prompt-building function"""
    prompt = prompt_fn(db_id=db_id, question=question)

    if debug_prompt:
        print(f"LLM prompt for SQL generation:\n======\n{prompt}\n======")

    response = llm_completion(prompt)
    return _clean_sql_response(response)

Next, let's try this on the test questions set aside above.

For the first one, we also print the full prompt sent to the LLM to generate the query:

In [18]:
generated_queries_zeroshot = {}

In [19]:
question0 = test_queries[0]["question"]
db_id0 = test_queries[0]["db_id"]
# Showing the prompt here as well to get a sense of what the model is seeing for context
gen_sql0 = generate_sql(
    question=question0,
    db_id=db_id0,
    prompt_fn=zero_shot_prompt,
    debug_prompt=True,
)

print("\n===\nGenerated SQL:")
print(gen_sql0)

generated_queries_zeroshot[0] = gen_sql0

LLM prompt for SQL generation:
Given the following schema information, generate valid SQL to answer the provided query. Enclose the query in markdown code-block syntax.
Convert text to SQL:

[Schema : (values)]: | concert_singer | stadium : Stadium_ID , Location , Name , Capacity , Highest , Lowest , Average | singer : Singer_ID , Name , Country , Song_Name , Song_release_year , Age , Is_male | concert : concert_ID , concert_Name , Theme , Stadium_ID , Year | singer_in_concert : concert_ID , Singer_ID ;

[Column names (type)]: stadium : Stadium_ID (number) | stadium : Location (text) | stadium : Name (text) | stadium : Capacity (number) | stadium : Highest (number) | stadium : Lowest (number) | stadium : Average (number) | singer : Singer_ID (number) | singer : Name (text) | singer : Country (text) | singer : Song_Name (text) | singer : Song_release_year (text) | singer : Age (number) | singer : Is_male (others) | concert : concert_ID (number) | concert : concert_Name (text) | concert 

The other test questions can be similarly generated:

In [20]:
for test_i, test_row in test_queries.items():
    if test_i not in generated_queries_zeroshot:
        generated_queries_zeroshot[test_i] = generate_sql(
            question=test_row["question"],
            db_id=test_row["db_id"],
            prompt_fn=zero_shot_prompt,
        )
        print(f"\n====\nAdded for Q[{test_i}]='{test_row['question']}' -> SQL:\n----")
        print(generated_queries_zeroshot[test_i])


====
Added for Q[1]='List singer names and number of concerts for each singer.' -> SQL:
----
SELECT singer.Name, COUNT(concert_ID) AS Number_of_Concerts
FROM singer
JOIN singer_in_concert ON singer.Singer_ID = singer_in_concert.Singer_ID
GROUP BY singer.Name;

====
Added for Q[2]='Find the number of pets whose weight is heavier than 10.' -> SQL:
----
SELECT COUNT(*)
FROM Pets
WHERE weight > 10;

====
Added for Q[3]='Find the number of distinct type of pets.' -> SQL:
----
```
SELECT COUNT(DISTINCT PetType) AS num_distinct_pet_types
FROM Pets;
```

====
Added for Q[4]='Find the first name and age of students who have a dog but do not have a cat as a pet.' -> SQL:
----
SELECT Fname, Age
FROM Student
WHERE StuID IN (
    SELECT StuID
    FROM Has_Pet
    JOIN Pets ON Has_Pet.PetID = Pets.PetID
    WHERE PetType = 'dog'
)
AND StuID NOT IN (
    SELECT StuID
    FROM Has_Pet
    JOIN Pets ON Has_Pet.PetID = Pets.PetID
    WHERE PetType = 'cat'
);

====
Added for Q[5]='Find the first name an

Running the evaluation on these queries yields:

In [21]:
report = eval_generated_queries(generated_queries_zeroshot)
display(report)

| | DB ID | Question | Correct | Error |
|---|---|---|---|---|
| 0 | concert_singer | What is the name and capacity of the stadium w... | True | |
| 1 | concert_singer | List singer names and number of concerts for e... | True | |
| 2 | pets_1 | Find the number of pets whose weight is heavie... | True | |
| 3 | pets_1 | Find the number of distinct type of pets. | False | `` near "```\nSELECT COUNT(DISTINCT PetType) AS n... `` |
| 4 | pets_1 | Find the first name and age of students who ha... | False | |
| 5 | pets_1 | Find the first name and age of students who ha... | True | |


So, **most questions get a working, correct answer** already.
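In the zero-shot report above, question 3 failed with a syntax error near a triple backtick: the model wrapped its SQL in a markdown fence without the `sql` tag, which `_clean_sql_response` does not look for, so the backticks were sent to SQLite along with the query. A more lenient cleaner could look like the sketch below (`clean_sql_response_lenient` is a hypothetical name). The results shown in this notebook were obtained with the original `_clean_sql_response`, so swapping in this variant would likely change the zero-shot score for that question.

```python
import re

# A fenced block: opening ```, an optional language tag on its own line, the body, closing ```
_FENCE_RE = re.compile(r"```(?:[\w-]*\n)?(.*?)```", re.DOTALL)


def clean_sql_response_lenient(resp: str) -> str:
    """Hypothetical variant of `_clean_sql_response` that also accepts fences without a `sql` tag."""
    match = _FENCE_RE.search(resp)
    if match:
        return match.group(1).strip()
    if "```" in resp:
        # An opening fence with no closing one: same failure as the original helper
        raise ValueError("Invalid answer from LLM.")
    # No fence at all: treat the whole (non-blank) response as SQL
    return "\n".join(line.strip() for line in resp.splitlines() if line.strip())


bare_fence = "```\nSELECT COUNT(DISTINCT PetType) AS num_distinct_pet_types\nFROM Pets;\n```"
print(clean_sql_response_lenient(bare_fence))
```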

The next step is to see if the performance can be improved with a few-shot approach, i.e. by passing some examples to the SQL generation prompt, in order to better "guide" the LLM to the answer.

## Generating SQL 2: Few-Shot (static)

Few-shot prompting has been a standard technique since the early days of large language models. The idea behind few-shot prompting (also known as "In-Context Learning") is to show the model examples of the desired input/output pairs, guiding it toward decisions more in line with what we expect for the task at hand. Since the original GPT-3 paper, [Language Models are Few-Shot Learners](https://arxiv.org/abs/2005.14165), few-shot methods have been shown to often improve performance substantially over zero-shot methods (i.e., no examples in the prompt). So let's see how our model does if we give it a few examples of correct output.

In [22]:
EXAMPLE_PREFIX_STR = "Here is an example: "

In [23]:
def _format_example_prompt(db_id: str, question: str, gold_answer: str) -> str:
    core_prompt = _format_sql_prompt(db_id, question)
    return EXAMPLE_PREFIX_STR + core_prompt + gold_answer + "\n\n"


def few_shot_prompt(db_id: str, question: str, example_indices: List[int]) -> str:
    """Creates a few shot prompt using the given indices to construct the example"""
    prompt = ""
    for e_index in example_indices:
        e_row = examples_df.loc[e_index]
        prompt += _format_example_prompt(e_row["db_id"], e_row["question"], e_row["query"])
    return prompt + _format_question_prompt(db_id, question)

In [24]:
static_few_shot_prompt_fn = partial(few_shot_prompt, example_indices=[100, 200])

NOTE: One of the biggest tradeoffs with few-shot learning is the context length. Each example shown to the model consumes part of the context window, which increases cost and latency and leaves less room for the generation. For tasks where each input is very long (like SQL prompts that include a full schema), this can be a difficult tradeoff to make. In our scenario, we settle for showing just 2 examples to the model.
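One way to make this tradeoff explicit is to cap the prompt size and keep adding examples only while they fit. The sketch below is hypothetical (`select_examples_within_budget` is not part of the notebook) and measures size in characters as a rough stand-in for tokens; a real implementation would count tokens with the model's own tokenizer.

```python
from typing import List, Sequence


def select_examples_within_budget(
    example_prompts: Sequence[str],
    question_prompt: str,
    max_chars: int,
) -> List[str]:
    """Greedily keep examples, in the given order, while the whole prompt fits in max_chars."""
    selected: List[str] = []
    used = len(question_prompt)  # the actual question is always included
    for example in example_prompts:
        if used + len(example) > max_chars:
            break  # stop at the first example that does not fit
        selected.append(example)
        used += len(example)
    return selected


examples = ["E" * 1500, "E" * 1500, "E" * 1500]
kept = select_examples_within_budget(examples, question_prompt="Q" * 1000, max_chars=4200)
print(len(kept))  # 2: a third example would bring the prompt to 5500 characters
```

Stopping at the first example that does not fit (rather than skipping it) preserves the ranking order, which matters once examples are sorted by relevance.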

Now, again, a query is generated for each of the test questions. For the first one, take a look at the full prompt with its examples:

In [25]:
generated_queries_fewshot = {}

In [26]:
question0 = test_queries[0]["question"]
db_id0 = test_queries[0]["db_id"]
# Showing the prompt here as well to get a sense of what the model is seeing for context
gen_sql0 = generate_sql(
    question=question0,
    db_id=db_id0,
    prompt_fn=static_few_shot_prompt_fn,
    debug_prompt=True,
)

print("\n===\nGenerated SQL:")
print(gen_sql0)

generated_queries_fewshot[0] = gen_sql0

LLM prompt for SQL generation:
Here is an example: Convert text to SQL:

[Schema : (values)]: | car_1 | continents : ContId , Continent | countries : CountryId , CountryName , Continent | car_makers : Id , Maker , FullName , Country | model_list : ModelId , Maker , Model | car_names : MakeId , Model , Make | cars_data : Id , MPG , Cylinders , Edispl , Horsepower , Weight , Accelerate , Year ;

[Column names (type)]: continents : ContId (number) | continents : Continent (text) | countries : CountryId (number) | countries : CountryName (text) | countries : Continent (number) | car_makers : Id (number) | car_makers : Maker (text) | car_makers : FullName (text) | car_makers : Country (text) | model_list : ModelId (number) | model_list : Maker (number) | model_list : Model (text) | car_names : MakeId (number) | car_names : Model (text) | car_names : Make (text) | cars_data : Id (number) | cars_data : MPG (text) | cars_data : Cylinders (number) | cars_data : Edispl (number) | cars_data : Hor

In [27]:
for test_i, test_row in test_queries.items():
    if test_i not in generated_queries_fewshot:
        generated_queries_fewshot[test_i] = generate_sql(
            question=test_row["question"],
            db_id=test_row["db_id"],
            prompt_fn=static_few_shot_prompt_fn,
        )
        print(f"\n====\nAdded for Q[{test_i}]='{test_row['question']}' -> SQL:\n----")
        print(generated_queries_fewshot[test_i])


====
Added for Q[1]='List singer names and number of concerts for each singer.' -> SQL:
----
SELECT singer.Name, COUNT(singer_in_concert.Singer_ID) AS Num_Concerts
FROM singer
JOIN singer_in_concert ON singer.Singer_ID = singer_in_concert.Singer_ID
GROUP BY singer.Name;

====
Added for Q[2]='Find the number of pets whose weight is heavier than 10.' -> SQL:
----
SELECT COUNT(*) FROM PETS WHERE weight > 10;

====
Added for Q[3]='Find the number of distinct type of pets.' -> SQL:
----
SELECT COUNT(DISTINCT PetType) AS NumDistinctPets FROM PETS;

====
Added for Q[4]='Find the first name and age of students who have a dog but do not have a cat as a pet.' -> SQL:
----
SELECT Fname, Age
FROM Student
WHERE StuID IN
(SELECT StuID
FROM Has_Pet
INNER JOIN Pets ON Has_Pet.PetID = Pets.PetID
WHERE PetType = 'dog'
AND StuID NOT IN
(SELECT StuID
FROM Has_Pet
INNER JOIN Pets ON Has_Pet.PetID = Pets.PetID
WHERE PetType = 'cat'))

====
Added for Q[5]='Find the first name and age of students who have a 

As with the zero-shot approach, let's run the query evaluation:

In [28]:
report = eval_generated_queries(generated_queries_fewshot)
display(report)

| | DB ID | Question | Correct | Error |
|---|---|---|---|---|
| 0 | concert_singer | What is the name and capacity of the stadium w... | True | |
| 1 | concert_singer | List singer names and number of concerts for e... | True | |
| 2 | pets_1 | Find the number of pets whose weight is heavie... | True | |
| 3 | pets_1 | Find the number of distinct type of pets. | True | |
| 4 | pets_1 | Find the first name and age of students who ha... | False | |
| 5 | pets_1 | Find the first name and age of students who ha... | True | |
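The `eval_generated_queries` helper is defined earlier in the notebook and its exact logic may differ, but a common way to fill the "Correct" and "Error" columns in text-to-SQL benchmarks such as Spider is *execution accuracy*: run both the generated and the reference query and compare the returned rows. A minimal sketch, assuming a SQLite connection (the function name `same_result` is hypothetical):

```python
import sqlite3
from collections import Counter


def same_result(conn: sqlite3.Connection, generated_sql: str, gold_sql: str) -> tuple[bool, str]:
    """Execution-accuracy check: do both queries return the same multiset of rows?

    Row order is ignored, which suits queries without ORDER BY;
    column names are not compared, only the returned values.
    """
    try:
        generated_rows = conn.execute(generated_sql).fetchall()
    except sqlite3.Error as exc:
        # A generated query that fails to run counts as incorrect; keep the message
        return False, str(exc)
    gold_rows = conn.execute(gold_sql).fetchall()
    # Counter compares rows as a multiset (duplicates matter, order does not)
    return Counter(generated_rows) == Counter(gold_rows), ""
```

Comparing results rather than SQL text means that differently written but equivalent queries (aliases, `COUNT(*)` vs `count(*)`, join order) are all accepted.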


Few-shot generation **usually improves the situation a little**: in this run, 5 of the 6 test questions are answered correctly.

Because the examples stuffed into the prompt are always the same, regardless of the question, they may bear little relation to the structure of the problem at hand, so we cannot expect this approach to yield big gains in the long run.

The "dynamic" few-shot strategy, by contrast, does make a difference.

## Generating SQL 3: Dynamic Few-Shot

The few-shot prompting style used above can help in some cases, but often the model needs more tailored information to learn from. To better inform the model, we can find questions similar to the one being asked and use those as in-context learning examples (aka few-shot examples) instead!
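The static and dynamic variants can share one prompt template and differ only in *which* examples are passed in. A minimal sketch of such a template follows; `Example` and `build_few_shot_prompt` are hypothetical names, and the layout only loosely mirrors the prompts printed later in this notebook, so the notebook's own `few_shot_prompt` helper may differ.

```python
from dataclasses import dataclass


@dataclass
class Example:
    """One solved question/SQL pair used as an in-context example."""
    schema: str
    question: str
    sql: str


def build_few_shot_prompt(examples: list[Example], schema: str, question: str) -> str:
    """Assemble a few-shot text-to-SQL prompt (hypothetical layout)."""
    parts = []
    for ex in examples:
        # Each example shows the schema, the question and its known-good SQL
        parts.append(
            "Here is an example: Convert text to SQL:\n\n"
            f"{ex.schema}\n\n[Q]: {ex.question}\n\n[SQL]: {ex.sql}\n"
        )
    # The target question ends with an open [SQL]: slot for the model to complete
    parts.append(
        "Convert text to SQL:\n\n"
        f"{schema}\n\n[Q]: {question}\n\n[SQL]:"
    )
    return "\n".join(parts)
```

In the static variant the same `examples` list is passed for every question; in the dynamic variant it is rebuilt for each question from the examples most similar to it.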

We'll start by loading the "examples" portion of the Spider dataset into Astra DB. Once it's there, we can create a new prompting function that picks examples similar to the question being asked, yielding a question-dependent few-shot prompt.

Similar questions are found via vector search. For this reason, we need to compute an embedding vector for every example question; these vectors will be stored alongside each row of the `examples_df` DataFrame.
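Conceptually, a vector search ranks the stored vectors by their similarity to the query vector and returns the top few. The brute-force, in-memory sketch below assumes cosine similarity as the metric; a vector database such as Astra DB relies on an index instead of scanning every vector, so this only illustrates *what* is computed, not *how* the database does it. The function names are hypothetical.

```python
import math


def cosine_similarity(a: list[float], b: list[float]) -> float:
    """Cosine of the angle between two vectors (1.0 means same direction)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def top_k_similar(query: list[float], stored: dict[int, list[float]], k: int = 2) -> list[int]:
    """Return the ids of the k stored vectors most similar to the query."""
    # Sort all ids by decreasing similarity and keep the first k
    ranked = sorted(stored, key=lambda i: cosine_similarity(query, stored[i]), reverse=True)
    return ranked[:k]


stored = {0: [1.0, 0.0], 1: [0.0, 1.0], 2: [0.7, 0.7]}
print(top_k_similar([0.9, 0.1], stored, k=2))  # [0, 2]
```

Because each example's id is kept alongside its vector, the ids returned by the search can be used directly to look up the matching question/SQL pairs.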

### Create an Astra DB collection

In [29]:
ASTRA_DB_COLLECTION_NAME = "text2sql_examples"

astra_db_collection = astra_db.create_collection(
    ASTRA_DB_COLLECTION_NAME,
    definition=(
        CollectionDefinition.builder()
        .set_vector_dimension(1536)
        .build()
    ),
    # this is to be able to directly pass ndarray objects when inserting documents:
    spawn_api_options=APIOptions(
        serdes_options=SerdesOptions(
            unroll_iterables_to_lists=True,
        ),
    ),
)

### Calculate an embedding vector for each example question

In [30]:
tqdm.pandas()

examples_embeddings = examples_df["question"].progress_apply(compute_embedding)
examples_df["question_embedding"] = examples_embeddings
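The cell above issues one embedding request per question, about a thousand in total. Since the OpenAI embeddings endpoint also accepts a list of input strings, the requests could be grouped into batches. A sketch of the batching logic follows, where `embed_batch` is a hypothetical callable (for instance, a wrapper around `client.embeddings.create(input=batch, model=...)`) assumed to return one vector per input, in input order.

```python
from typing import Callable, Iterator


def batched(items: list[str], batch_size: int) -> Iterator[list[str]]:
    """Yield consecutive slices of at most batch_size items."""
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]


def embed_all(
    questions: list[str],
    embed_batch: Callable[[list[str]], list[list[float]]],
    batch_size: int = 100,
) -> list[list[float]]:
    """Embed questions batch by batch, preserving the original order."""
    vectors: list[list[float]] = []
    for batch in batched(questions, batch_size):
        batch_vectors = embed_batch(batch)
        # Expect exactly one vector per input; fail loudly otherwise
        if len(batch_vectors) != len(batch):
            raise ValueError("embed_batch returned the wrong number of vectors")
        vectors.extend(batch_vectors)
    return vectors
```

The resulting list keeps the same order as the input questions, so it could be assigned back to a DataFrame column just like `examples_embeddings` above.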


### Write the documents (with embeddings) into the collection

In [31]:
# Each document's _id in the collection will be its index in the DataFrame:
examples_df["_id"] = examples_df.index
# In the documents to insert, the embedding must go under the special field name "$vector":
examples_documents = examples_df.rename(columns={"question_embedding": "$vector"}).to_dict(orient="records")

result = astra_db_collection.insert_many(
    examples_documents,
    # there are a thousand or so large documents, so let's be generous with timeouts:
    timeout_ms=45000,          # this is for the whole insertion (which is chunked) ...
    request_timeout_ms=20000,  # and this is for each single HTTP request
)
print(result)

CollectionInsertManyResult(inserted_ids=[0, 1, 2, 3, 4 ... (1023 total)], raw_results=...)
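Each row of `examples_df` thus becomes one JSON document whose `_id` is the DataFrame index and whose embedding sits in the reserved `$vector` field. A plain-Python sketch of that mapping for a single row follows, with a dimension check added; the helper name `to_document` is hypothetical, and the non-vector fields actually stored depend on the columns of `examples_df`.

```python
EMBEDDING_DIMENSION = 1536  # must match set_vector_dimension(...) on the collection


def to_document(index: int, record: dict, embedding: list[float]) -> dict:
    """Build one document: _id from the row index, embedding under "$vector"."""
    if len(embedding) != EMBEDDING_DIMENSION:
        raise ValueError(f"expected {EMBEDDING_DIMENSION} dimensions, got {len(embedding)}")
    # Other fields (question, SQL query, db_id, ...) become regular document fields;
    # list() turns any iterable (e.g. an ndarray) into a plain JSON-friendly list
    return {"_id": index, **record, "$vector": list(embedding)}
```

In the notebook itself the ndarray-to-list conversion is handled by the `unroll_iterables_to_lists` serialization option set when creating the collection, so the explicit `list(...)` here is only needed outside that setup.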


### Construct the dynamic few-shot prompting function

The new prompt function first runs a vector search on the Astra DB collection, then uses the results to prepare a few-shot prompt tailored to the question:

In [32]:
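def dynamic_few_shot_prompt_fn(db_id: str, question: str) -> str:
    # Embed the incoming question with the same model used for the stored examples
    emb = compute_embedding(question)
    # Vector search: retrieve the 2 stored examples with the most similar questions
    docs = astra_db_collection.find(sort={"$vector": emb}, limit=2)
    # Each document's _id is its row index in examples_df (see the insertion step)
    closest_q_ids = [doc["_id"] for doc in docs]
    # Build the few-shot prompt using only these question-specific examples
    return few_shot_prompt(db_id, question, example_indices=closest_q_ids)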

Let's put this new generation technique to use. In the next two cells, the whole prompt is printed as well, so you can verify that the examples differ between the two invocations.

In [33]:
generated_queries_dynamic = {}

In [34]:
question0 = test_queries[0]["question"]
db_id0 = test_queries[0]["db_id"]
# Showing the prompt here as well to get a sense of what the model is seeing for context
gen_sql0 = generate_sql(
    question=question0,
    db_id=db_id0,
    prompt_fn=dynamic_few_shot_prompt_fn,
    debug_prompt=True,
)

print("\n===\nGenerated SQL:")
print(gen_sql0)

generated_queries_dynamic[0] = gen_sql0

LLM prompt for SQL generation:
Here is an example: Convert text to SQL:

[Schema : (values)]: | concert_singer | stadium : Stadium_ID , Location , Name , Capacity , Highest , Lowest , Average | singer : Singer_ID , Name , Country , Song_Name , Song_release_year , Age , Is_male | concert : concert_ID , concert_Name , Theme , Stadium_ID , Year | singer_in_concert : concert_ID , Singer_ID ;

[Column names (type)]: stadium : Stadium_ID (number) | stadium : Location (text) | stadium : Name (text) | stadium : Capacity (number) | stadium : Highest (number) | stadium : Lowest (number) | stadium : Average (number) | singer : Singer_ID (number) | singer : Name (text) | singer : Country (text) | singer : Song_Name (text) | singer : Song_release_year (text) | singer : Age (number) | singer : Is_male (others) | concert : concert_ID (number) | concert : concert_Name (text) | concert : Theme (text) | concert : Stadium_ID (text) | concert : Year (text) | singer_in_concert : concert_ID (number) | singe

In [35]:
question4 = test_queries[4]["question"]
db_id4 = test_queries[4]["db_id"]
# Showing the prompt here as well to get a sense of what the model is seeing for context
gen_sql4 = generate_sql(
    question=question4,
    db_id=db_id4,
    prompt_fn=dynamic_few_shot_prompt_fn,
    debug_prompt=True,
)

print("\n===\nGenerated SQL:")
print(gen_sql4)

generated_queries_dynamic[4] = gen_sql4

LLM prompt for SQL generation:
Here is an example: Convert text to SQL:

[Schema : (values)]: | pets_1 | Student : StuID , LName , Fname , Age , Sex , Major , Advisor , city_code | Has_Pet : StuID , PetID | Pets : PetID , PetType , pet_age , weight ;

[Column names (type)]: Student : StuID (number) | Student : LName (text) | Student : Fname (text) | Student : Age (number) | Student : Sex (text) | Student : Major (number) | Student : Advisor (number) | Student : city_code (text) | Has_Pet : StuID (number) | Has_Pet : PetID (number) | Pets : PetID (number) | Pets : PetType (text) | Pets : pet_age (number) | Pets : weight (number) ;

[Primary Keys]: Student : StuID | Pets : PetID

[Foreign Keys]: Has_Pet : StuID equals Student : StuID | Has_Pet : PetID equals Pets : PetID

[Q]: Find the first name of students who have cat or dog pet.

[SQL]: SELECT DISTINCT T1.Fname FROM student AS T1 JOIN has_pet AS T2 ON T1.stuid  =  T2.stuid JOIN pets AS T3 ON T3.petid  =  T2.petid WHERE T3.pettype  = 

The remaining queries are computed in the same way (with prompt debugging turned off):

In [36]:
for test_i, test_row in test_queries.items():
    if test_i not in generated_queries_dynamic:
        generated_queries_dynamic[test_i] = generate_sql(
            question=test_row["question"],
            db_id=test_row["db_id"],
            prompt_fn=dynamic_few_shot_prompt_fn,
        )
        print(f"\n====\nAdded for Q[{test_i}]='{test_row['question']}' -> SQL:\n----")
        print(generated_queries_dynamic[test_i])


====
Added for Q[1]='List singer names and number of concerts for each singer.' -> SQL:
----
SELECT s.Name, COUNT(sc.concert_ID) AS num_concerts
FROM singer s
JOIN singer_in_concert sc ON s.Singer_ID = sc.Singer_ID
GROUP BY s.Name;

====
Added for Q[2]='Find the number of pets whose weight is heavier than 10.' -> SQL:
----
SELECT COUNT(*)
FROM PETS
WHERE weight > 10;

====
Added for Q[3]='Find the number of distinct type of pets.' -> SQL:
----
SELECT COUNT(DISTINCT PetType) AS num_distinct_pet_types FROM PETS;

====
Added for Q[5]='Find the first name and age of students who have a pet.' -> SQL:
----
SELECT Fname, Age
FROM STUDENT
WHERE StuID IN (SELECT StuID FROM HAS_PET)


Time to run the evaluation step for all these queries:

In [37]:
report = eval_generated_queries(generated_queries_dynamic)
display(report)

| | DB ID | Question | Correct | Error |
|---|---|---|---|---|
| 0 | concert_singer | What is the name and capacity of the stadium w... | True | |
| 1 | pets_1 | Find the first name and age of students who ha... | True | |
| 2 | concert_singer | List singer names and number of concerts for e... | True | |
| 3 | pets_1 | Find the number of pets whose weight is heavie... | True | |
| 4 | pets_1 | Find the number of distinct type of pets. | True | |
| 5 | pets_1 | Find the first name and age of students who ha... | True | |


Dynamic few-shot prompting proved effective here: with it, the LLM got all six test questions right. **Success!**

## Conclusion

Text-to-SQL is a critical skill for an LLM, enabling a variety of use cases and GenAI-based automation opportunities.

As is often the case, an AI-based solution can be deployed at various levels of sophistication: here, providing examples _selected dynamically_ based on the question gave an extra boost to the accuracy of the process.

Astra DB was used to retrieve the most appropriate question/query examples: its vector-search capabilities make it a great fit for this use case.

If you want to learn more about vector databases or GenAI, or simply want to experiment with Astra DB, visit the [Astra DB docs](https://docs.datastax.com/en/astra/astra-db-vector/get-started/quickstart.html#create-a-serverless-vector-database).

## Additional references

- [Spider homepage](https://yale-lily.github.io/spider)
- [SQL-PaLM](https://arxiv.org/abs/2306.00739) (arXiv)
- [Language Models are Few-Shot Learners](https://arxiv.org/abs/2005.14165) (arXiv)