# Generating SQL for Postgres using OpenAI, ChromaDB
This notebook runs through the process of using the `vanna` Python package to generate SQL using AI (RAG + LLMs) including connecting to a database and training. If you're not ready to train on your own database, you can still try it using a sample [SQLite database](app.md).

## Setup

In [3]:
# %pip install 'vanna[chromadb,openai,postgres]'

In [1]:

from vanna.openai import OpenAI_Chat
from vanna.chromadb import ChromaDB_VectorStore


class MyVanna(ChromaDB_VectorStore, OpenAI_Chat):
    def __init__(self, config=None):
        ChromaDB_VectorStore.__init__(self, config=config)
        OpenAI_Chat.__init__(self, config=config)
from dotenv import load_dotenv, find_dotenv
import os
_ = load_dotenv(find_dotenv())
vn = MyVanna(config={'api_key':  os.environ.get('OPENAI_API_KEY'), 'model': 'gpt-4o-mini'})
vn.connect_to_postgres(host='localhost',dbname='dvdrental', port='5432', user='postgres', password='1122')


In [2]:
res = vn.ask('show me 3 actors',visualize=False)
res

Add of existing embedding ID: 78eba354-54d8-5e4e-b45d-7aba82ec533c-sql
Number of requested results 10 is greater than number of elements in index 2, updating n_results = 2


SQL Prompt: [{'role': 'system', 'content': "You are a PostgreSQL expert. Please help to generate a SQL query to answer the question. Your response should ONLY be based on the given context and follow the response guidelines and format instructions. ===Response Guidelines \n1. If the provided context is sufficient, please generate a valid SQL query without any explanations for the question. \n2. If the provided context is almost sufficient but requires knowledge of a specific string in a particular column, please generate an intermediate SQL query to find the distinct strings in that column. Prepend the query with a comment saying intermediate_sql \n3. If the provided context is insufficient, please explain why it can't be generated. \n4. Please use the most relevant table(s). \n5. If the question has been asked and answered before, please repeat the answer exactly as it was given before. \n6. Ensure that the output SQL is PostgreSQL-compliant and executable, and free of syntax errors. 

('SELECT * FROM actor LIMIT 3;',
    actor_id first_name last_name             last_update
 0         1   Penelope   Guiness 2013-05-26 14:47:57.620
 1         2       Nick  Wahlberg 2013-05-26 14:47:57.620
 2         3         Ed     Chase 2013-05-26 14:47:57.620,
 None)

## Training
You only need to train once. Do not train again unless you want to add more training data.

In [21]:

# The information schema query may need some tweaking depending on your database. This is a good starting point.
df_information_schema = vn.run_sql("SELECT * FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema='public'")
# df_information_schema = vn.run_sql("SELECT * FROM INFORMATION_SCHEMA.COLUMNS")

# This will break up the information schema into bite-sized chunks that can be referenced by the LLM
plan = vn.get_training_plan_generic(df_information_schema)
plan


# If you like the plan, then uncomment this and run it to train
vn.train(plan=plan)



### Clean training data

In [12]:

df = vn.get_training_data()
# for i  in df['id']:
#     vn.remove_training_data(i)
# df

Unnamed: 0,id,question,content,training_data_type


## Quering

In [22]:
question = 'who is most famous actor'
question = 'show me 3 actors'
res = vn.ask(question=question, visualize=False)
res

Number of requested results 10 is greater than number of elements in index 0, updating n_results = 0


SQL Prompt: [{'role': 'system', 'content': "You are a PostgreSQL expert. Please help to generate a SQL query to answer the question. Your response should ONLY be based on the given context and follow the response guidelines and format instructions. \n===Additional Context \n\nThe following columns are in the actor_info table in the dvdrental database:\n\n|     | table_catalog   | table_schema   | table_name   | column_name   | data_type         |\n|----:|:----------------|:---------------|:-------------|:--------------|:------------------|\n|   4 | dvdrental       | public         | actor_info   | film_info     | text              |\n|   9 | dvdrental       | public         | actor_info   | actor_id      | integer           |\n|  21 | dvdrental       | public         | actor_info   | last_name     | character varying |\n| 126 | dvdrental       | public         | actor_info   | first_name    | character varying |\n\nThe following columns are in the actor table in the dvdrental database:

('SELECT first_name, last_name \nFROM actor \nLIMIT 3;',
   first_name last_name
 0   Penelope   Guiness
 1       Nick  Wahlberg
 2         Ed     Chase,
 None)

## generate summary:
Natural Language Answer

In [24]:
vn.generate_summary(question=question, df=res[1])


Using model gpt-4o-mini for 115.75 tokens (approx)


'The data contains a list of 3 actors, including their first and last names: Penelope Guiness, Nick Wahlberg, and Ed Chase.'

## misc


In [6]:

# # The following are methods for adding training data. Make sure you modify the examples to match your database.

# # DDL statements are powerful because they specify table names, colume names, types, and potentially relationships
# vn.train(ddl="""
#     CREATE TABLE IF NOT EXISTS my-table (
#         id INT PRIMARY KEY,
#         name VARCHAR(100),
#         age INT
#     )
# """)

# # Sometimes you may want to add documentation about your business terminology or definitions.
# vn.train(documentation="Our business defines OTIF score as the percentage of orders that are delivered on time and in full")

# # You can also add SQL queries to your training data. This is useful if you have some queries already laying around. You can just copy and paste those from your editor to begin generating new SQL.
# vn.train(sql="SELECT * FROM my-table WHERE name = 'John Doe'")
# # At any time you can inspect what training data the package is able to reference
# training_data = vn.get_training_data()
# training_data
# # You can remove training data if there's obsolete/incorrect information.
# vn.remove_training_data(id='1-ddl')


In [21]:

from vanna.openai import OpenAI_Chat
from vanna.chromadb import ChromaDB_VectorStore


class MyVanna(ChromaDB_VectorStore, OpenAI_Chat):
    def __init__(self, config=None):
        ChromaDB_VectorStore.__init__(self, config=config)
        OpenAI_Chat.__init__(self, config=config)

vn = MyVanna(config={'api_key': 'sk-', 'model': 'gpt-4o-mini'})
vn.connect_to_postgres(host='localhost',dbname='dvdrental', port='5432', user='postgres', password='1122')


res = vn.ask(question='show me entries of actor table', visualize=False, print_results=False, )
print(res)


Number of requested results 10 is greater than number of elements in index 6, updating n_results = 6


Using model gpt-4o-mini for 2492.75 tokens (approx)


Insert of existing embedding ID: 75b0eabf-0456-503d-9672-09dc4d61dda9-sql
Add of existing embedding ID: 75b0eabf-0456-503d-9672-09dc4d61dda9-sql


Extracted SQL: SELECT *
FROM actor;
('SELECT *\nFROM actor;',      actor_id first_name     last_name             last_update
0           1   Penelope       Guiness 2013-05-26 14:47:57.620
1           2       Nick      Wahlberg 2013-05-26 14:47:57.620
2           3         Ed         Chase 2013-05-26 14:47:57.620
3           4   Jennifer         Davis 2013-05-26 14:47:57.620
4           5     Johnny  Lollobrigida 2013-05-26 14:47:57.620
..        ...        ...           ...                     ...
195       196       Bela        Walken 2013-05-26 14:47:57.620
196       197      Reese          West 2013-05-26 14:47:57.620
197       198       Mary        Keitel 2013-05-26 14:47:57.620
198       199      Julia       Fawcett 2013-05-26 14:47:57.620
199       200      Thora        Temple 2013-05-26 14:47:57.620

[200 rows x 4 columns], None)


## Asking the AI
Whenever you ask a new question, it will find the 10 most relevant pieces of training data and use it as part of the LLM prompt to generate the SQL.

In [8]:
res = vn.ask(question='show me entries of actor table', visualize=False, print_results=False, )
res


Number of requested results 10 is greater than number of elements in index 4, updating n_results = 4


Using model gpt-4o-mini for 2406.25 tokens (approx)


Insert of existing embedding ID: 75b0eabf-0456-503d-9672-09dc4d61dda9-sql
Add of existing embedding ID: 75b0eabf-0456-503d-9672-09dc4d61dda9-sql


Extracted SQL: SELECT *
FROM actor;


('SELECT *\nFROM actor;',
      actor_id first_name     last_name             last_update
 0           1   Penelope       Guiness 2013-05-26 14:47:57.620
 1           2       Nick      Wahlberg 2013-05-26 14:47:57.620
 2           3         Ed         Chase 2013-05-26 14:47:57.620
 3           4   Jennifer         Davis 2013-05-26 14:47:57.620
 4           5     Johnny  Lollobrigida 2013-05-26 14:47:57.620
 ..        ...        ...           ...                     ...
 195       196       Bela        Walken 2013-05-26 14:47:57.620
 196       197      Reese          West 2013-05-26 14:47:57.620
 197       198       Mary        Keitel 2013-05-26 14:47:57.620
 198       199      Julia       Fawcett 2013-05-26 14:47:57.620
 199       200      Thora        Temple 2013-05-26 14:47:57.620
 
 [200 rows x 4 columns],
 None)

In [9]:
res[1].info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 200 entries, 0 to 199
Data columns (total 4 columns):
 #   Column       Non-Null Count  Dtype         
---  ------       --------------  -----         
 0   actor_id     200 non-null    int64         
 1   first_name   200 non-null    object        
 2   last_name    200 non-null    object        
 3   last_update  200 non-null    datetime64[ns]
dtypes: datetime64[ns](1), int64(1), object(2)
memory usage: 6.4+ KB


In [10]:

from openai import OpenAI

import os
from dotenv import load_dotenv, find_dotenv

_ = load_dotenv(find_dotenv())

client = OpenAI(api_key=os.getenv('OPENAI_API_KEY'))

def submit_prompt( prompt, **kwargs) -> str:
    if prompt is None:
        raise Exception("Prompt is None")

    if len(prompt) == 0:
        raise Exception("Prompt is empty")

    # Count the number of tokens in the message log
    # Use 4 as an approximation for the number of characters per token
    num_tokens = 0
    for message in prompt:
        num_tokens += len(message["content"]) / 4

    if kwargs.get("model", None) is not None:
        model = kwargs.get("model", None)
        print(
            f"Using model {model} for {num_tokens} tokens (approx)"
        )
        response = client.chat.completions.create(
            model=model,
            messages=prompt,
            stop=None,
            temperature=0,
        )
    else:
        if num_tokens > 3500:
            model = "gpt-3.5-turbo-16k"
        else:
            model = "gpt-3.5-turbo"

        print(f"Using model {model} for {num_tokens} tokens (approx)")
        response = client.chat.completions.create(
            model=model,
            messages=prompt,
            stop=None,
            temperature=0,
        )

    # Find the first response from the chatbot that has text in it (some responses may not have text)
    for choice in response.choices:
        if "text" in choice:
            return choice.text

    # If no response with text is found, return the first response's content (which may be empty)
    return response.choices[0].message.content

def system_message(message: str) -> any:
    return {"role": "system", "content": message}

def user_message(message: str) -> any:
    return {"role": "user", "content": message}

def _sanitize_plotly_code(raw_plotly_code: str) -> str:
    # Remove the fig.show() statement from the plotly code
    plotly_code = raw_plotly_code.replace("fig.show()", "")

    return plotly_code

import re
def _extract_python_code(markdown_string: str) -> str:
    # Regex pattern to match Python code blocks
    pattern = r"```[\w\s]*python\n([\s\S]*?)```|```([\s\S]*?)```"

    # Find all matches in the markdown string
    matches = re.findall(pattern, markdown_string, re.IGNORECASE)

    # Extract the Python code from the matches
    python_code = []
    for match in matches:
        python = match[0] if match[0] else match[1]
        python_code.append(python.strip())

    if len(python_code) == 0:
        return markdown_string

    return python_code[0]


def generate_plotly_code(
    question: str = None, sql: str = None, df_metadata: str = None, **kwargs
) -> str:
    if question is not None:
        system_msg = f"The following is a pandas DataFrame that contains the results of the query that answers the question the user asked: '{question}'"
    else:
        system_msg = "The following is a pandas DataFrame "

    if sql is not None:
        system_msg += f"\n\nThe DataFrame was produced using this query: {sql}\n\n"

    system_msg += f"The following is information about the resulting pandas DataFrame 'df': \n{df_metadata}"
    print(system_message)
    
    message_log = [
        system_message(system_msg),
        user_message(
            "Can you generate the Python plotly code to chart the results of the dataframe? Assume the data is in a pandas dataframe called 'df'. If there is only one value in the dataframe, use an Indicator. Respond with only Python code. Do not answer with any explanations -- just the code."
        ),
    ]

    plotly_code = submit_prompt(message_log, kwargs=kwargs)
    # import pprint
    # pprint.pprint(plotly_code)
    return _sanitize_plotly_code(_extract_python_code(plotly_code))


In [11]:
import pprint

import pandas as pd
import plotly
import plotly.express as px
import plotly.graph_objects as go

def get_plotly_figure(
    plotly_code: str, df: pd.DataFrame, dark_mode: bool = True
) -> plotly.graph_objs.Figure:
    """
    **Example:**
    ```python
    fig = vn.get_plotly_figure(
        plotly_code="fig = px.bar(df, x='name', y='salary')",
        df=df
    )
    fig.show()
    ```
    Get a Plotly figure from a dataframe and Plotly code.

    Args:
        df (pd.DataFrame): The dataframe to use.
        plotly_code (str): The Plotly code to use.

    Returns:
        plotly.graph_objs.Figure: The Plotly figure.
    """
    ldict = {"df": df, "px": px, "go": go}
    try:
        exec(plotly_code, globals(), ldict)

        fig = ldict.get("fig", None)
    except Exception as e:
        # Inspect data types
        numeric_cols = df.select_dtypes(include=["number"]).columns.tolist()
        categorical_cols = df.select_dtypes(
            include=["object", "category"]
        ).columns.tolist()

        # Decision-making for plot type
        if len(numeric_cols) >= 2:
            # Use the first two numeric columns for a scatter plot
            fig = px.scatter(df, x=numeric_cols[0], y=numeric_cols[1])
        elif len(numeric_cols) == 1 and len(categorical_cols) >= 1:
            # Use a bar plot if there's one numeric and one categorical column
            fig = px.bar(df, x=categorical_cols[0], y=numeric_cols[0])
        elif len(categorical_cols) >= 1 and df[categorical_cols[0]].nunique() < 10:
            # Use a pie chart for categorical data with fewer unique values
            fig = px.pie(df, names=categorical_cols[0])
        else:
            # Default to a simple line plot if above conditions are not met
            fig = px.line(df)

    if fig is None:
        return None

    if dark_mode:
        fig.update_layout(template="plotly_dark")

    return fig

def ask(question:str):
    res = vn.ask(question=question, visualize=False, print_results=False, )
    # print('\n printing res \n: ')
    # pprint.pp(res)
    # print()
    print('\n \n============');
    print('query: \n',res[0]);
    print('-----------');
    print('result: \n',res[1]);
    print('============');
    plotly_code = generate_plotly_code(question=question, sql=res[0], df_metadata=f"Running df.types gives: \n {res[1].dtypes}")
    
    fig = get_plotly_figure(plotly_code=plotly_code, df=res[1])
    
    return fig


In [12]:
question='show me 10 movies?'

In [13]:
res = vn.ask(question=question, visualize=False, print_results=False, )
res


Number of requested results 10 is greater than number of elements in index 4, updating n_results = 4


Using model gpt-4o-mini for 2588.25 tokens (approx)


Insert of existing embedding ID: 1f0d0d12-05db-5dd8-82a5-37a782d5c99f-sql
Add of existing embedding ID: 1f0d0d12-05db-5dd8-82a5-37a782d5c99f-sql


Extracted SQL: SELECT title
FROM film
LIMIT 10;


('SELECT title\nFROM film\nLIMIT 10;',
                title
 0    Chamber Italian
 1   Grosse Wonderful
 2    Airport Pollock
 3  Bright Encounters
 4   Academy Dinosaur
 5     Ace Goldfinger
 6   Adaptation Holes
 7   Affair Prejudice
 8        African Egg
 9       Agent Truman,
 None)

In [14]:
df = res[1]
plotly_code = vn.generate_plotly_code(question=question,sql=res[0],df_metadata=f'df.dtypes = {df.dtypes}')
vn.get_plotly_figure(plotly_code=plotly_code,df=df,dark_mode=False)

Using model gpt-4o-mini for 153.0 tokens (approx)


In [15]:
ask('show me 10 movies')

Number of requested results 10 is greater than number of elements in index 4, updating n_results = 4


Using model gpt-4o-mini for 2711.25 tokens (approx)


Insert of existing embedding ID: a8ab41a9-5a55-50e3-8f25-9dd925345e6a-sql
Add of existing embedding ID: a8ab41a9-5a55-50e3-8f25-9dd925345e6a-sql


Extracted SQL: SELECT title
FROM film
LIMIT 10;

 
query: 
 SELECT title
FROM film
LIMIT 10;
-----------
result: 
                title
0    Chamber Italian
1   Grosse Wonderful
2    Airport Pollock
3  Bright Encounters
4   Academy Dinosaur
5     Ace Goldfinger
6   Adaptation Holes
7   Affair Prejudice
8        African Egg
9       Agent Truman
<function system_message at 0x000001FC5ECE9260>
Using model gpt-3.5-turbo for 156.25 tokens (approx)


In [16]:
import copy

original_list = [1, [2, 3], 4]
shallow_copy = copy.deepcopy(original_list)
# shallow_copy = [i for i in original_list]

print("Original:", original_list)
print("Shallow copy:", shallow_copy)

# Modifying the nested list
original_list[1][0] = 'X'

print("\nAfter modification:")
print("Original:", original_list)
print("Shallow copy:", shallow_copy)

Original: [1, [2, 3], 4]
Shallow copy: [1, [2, 3], 4]

After modification:
Original: [1, ['X', 3], 4]
Shallow copy: [1, [2, 3], 4]


In [17]:
ask('show me top 10 rental movies')

Number of requested results 10 is greater than number of elements in index 4, updating n_results = 4


Using model gpt-4o-mini for 2563.25 tokens (approx)
Extracted SQL: SELECT f.title, COUNT(r.rental_id) AS rental_count
FROM film f
JOIN inventory i ON f.film_id = i.film_id
JOIN rental r ON i.inventory_id = r.inventory_id
GROUP BY f.title
ORDER BY rental_count DESC
LIMIT 10;

 
query: 
 SELECT f.title, COUNT(r.rental_id) AS rental_count
FROM film f
JOIN inventory i ON f.film_id = i.film_id
JOIN rental r ON i.inventory_id = r.inventory_id
GROUP BY f.title
ORDER BY rental_count DESC
LIMIT 10;
-----------
result: 
                  title  rental_count
0   Bucket Brotherhood            34
1     Rocketeer Mother            33
2       Juggler Hardly            32
3  Ridgemont Submarine            32
4       Grit Clockwork            32
5       Forward Temple            32
6        Scalawag Duck            32
7        Apache Divine            31
8    Goodfellas Salute            31
9      Rush Goodfellas            31
<function system_message at 0x000001FC5ECE9260>
Using model gpt-3.5-turbo fo

In [18]:
ask('show me top 10  drama movies, drama is a category')

Number of requested results 10 is greater than number of elements in index 5, updating n_results = 5


Using model gpt-4o-mini for 2548.25 tokens (approx)
Extracted SQL: SELECT title
FROM film_list
WHERE category = 'Drama'
LIMIT 10;

 
query: 
 SELECT title
FROM film_list
WHERE category = 'Drama'
LIMIT 10;
-----------
result: 
                  title
0          Apollo Teen
1        Beauty Grease
2   Beethoven Exorcist
3         Blade Polish
4    Bright Encounters
5          Bunch Minds
6           Chill Luck
7          Chitty Lock
8    Coneheads Smoochy
9  Confessions Maguire
<function system_message at 0x000001FC5ECE9260>
Using model gpt-3.5-turbo for 171.75 tokens (approx)


In [19]:
ask('show me top 10  drama movies, drama is a category')

Number of requested results 10 is greater than number of elements in index 6, updating n_results = 6


Using model gpt-4o-mini for 2576.0 tokens (approx)


Insert of existing embedding ID: d96d2274-66f4-57bb-a0bc-18755ef7c9c7-sql
Add of existing embedding ID: d96d2274-66f4-57bb-a0bc-18755ef7c9c7-sql


Extracted SQL: SELECT title
FROM film_list
WHERE category = 'Drama'
LIMIT 10;

 
query: 
 SELECT title
FROM film_list
WHERE category = 'Drama'
LIMIT 10;
-----------
result: 
                  title
0          Apollo Teen
1        Beauty Grease
2   Beethoven Exorcist
3         Blade Polish
4    Bright Encounters
5          Bunch Minds
6           Chill Luck
7          Chitty Lock
8    Coneheads Smoochy
9  Confessions Maguire
<function system_message at 0x000001FC5ECE9260>
Using model gpt-3.5-turbo for 171.75 tokens (approx)


### generate Questions 

In [20]:
from typing import List





def _response_language(language=None) -> str:
    if language is None:
        return ""

    return f"Respond in the {language} language."


def generate_db_questions(
    question: str, sql: str, df: pd.DataFrame, n_questions: int = 5, **kwargs
) -> list:
    """
    **Example:**
    ```python
    vn.generate_followup_questions("What are the top 10 customers by sales?", sql, df)
    ```

    Generate a list of followup questions that you can ask Vanna.AI.

    Args:
        question (str): The question that was asked.
        sql (str): The LLM-generated SQL query.
        df (pd.DataFrame): The results of the SQL query.
        n_questions (int): Number of follow-up questions to generate.

    Returns:
        list: A list of followup questions that you can ask Vanna.AI.
    """

    message_log = [
        system_message(
            f"You are a helpful data assistant. The user asked the question: '{question}'\n\nThe SQL query for this question was: {sql}\n\nThe following is a pandas DataFrame with the results of the query: \n{df.to_markdown()}\n\n"
        ),
        user_message(
            f"Generate a list of {n_questions} followup questions that the user might ask about this data. Respond with a list of questions, one per line. Do not answer with any explanations -- just the questions. Remember that there should be an unambiguous SQL query that can be generated from the question. Prefer questions that are answerable outside of the context of this conversation. Prefer questions that are slight modifications of the SQL query that was generated that allow digging deeper into the data. Each question will be turned into a button that the user can click to generate a new SQL query so don't use 'example' type questions. Each question must have a one-to-one correspondence with an instantiated SQL query." +
            _response_language()
        ),
    ]

    llm_response = submit_prompt(message_log, **kwargs)

    numbers_removed = re.sub(r"^\d+\.\s*", "", llm_response, flags=re.MULTILINE)
    return numbers_removed.split("\n")

def generate_followup_questions(
    question: str, sql: str, df: pd.DataFrame, n_questions: int = 5, **kwargs
) -> list:
    """
    **Example:**
    ```python
    vn.generate_followup_questions("What are the top 10 customers by sales?", sql, df)
    ```

    Generate a list of followup questions that you can ask Vanna.AI.

    Args:
        question (str): The question that was asked.
        sql (str): The LLM-generated SQL query.
        df (pd.DataFrame): The results of the SQL query.
        n_questions (int): Number of follow-up questions to generate.

    Returns:
        list: A list of followup questions that you can ask Vanna.AI.
    """

    message_log = [
        system_message(
            f"You are a helpful data assistant. The user asked the question: '{question}'\n\nThe SQL query for this question was: {sql}\n\nThe following is a pandas DataFrame with the results of the query: \n{df.to_markdown()}\n\n"
        ),
        user_message(
            f"Generate a list of {n_questions} followup questions that the user might ask about this data. Respond with a list of questions, one per line. Do not answer with any explanations -- just the questions. Remember that there should be an unambiguous SQL query that can be generated from the question. Prefer questions that are answerable outside of the context of this conversation. Prefer questions that are slight modifications of the SQL query that was generated that allow digging deeper into the data. Each question will be turned into a button that the user can click to generate a new SQL query so don't use 'example' type questions. Each question must have a one-to-one correspondence with an instantiated SQL query." +
            _response_language()
        ),
    ]

    llm_response = submit_prompt(message_log, **kwargs)

    numbers_removed = re.sub(r"^\d+\.\s*", "", llm_response, flags=re.MULTILINE)
    return numbers_removed.split("\n")
