# Using LangChain and LLMs to Analyze Data in Amazon Redshift

A demonstration of [LangChain SQL Chain](https://python.langchain.com/en/latest/modules/chains/examples/sqlite.html) (`SQLDatabaseChain` and `SQLDatabaseSequentialChain`) and the [SQL Database Agent](https://python.langchain.com/en/latest/modules/agents/toolkits/examples/sql_database.html) analyzing data in an [Amazon Redshift](https://aws.amazon.com/redshift/) cloud data warehouse. The demonstration uses OpenAI's LLMs through the OpenAI API.

Author: Gary A. Stafford  
Date: 2023-06-01  
License: MIT  
Kernel: `conda_python3`  
References:
- [LangChain Documentation: SQL Chain example](https://python.langchain.com/en/latest/modules/chains/examples/sqlite.html#sql-chain-example)
- [LangChain Blog: LLMs and SQL](https://blog.langchain.dev/llms-and-sql/)
- [How do davinci and text-davinci-003 differ?](https://help.openai.com/en/articles/6643408-how-do-davinci-and-text-davinci-003-differ)
- [How do text-davinci-002 and text-davinci-003 differ?](https://help.openai.com/en/articles/6779149-how-do-text-davinci-002-and-text-davinci-003-differ)

## Prerequisites

1. Import the [TICKIT sample database](https://docs.aws.amazon.com/redshift/latest/dg/c_sampledb.html) into an [Amazon Redshift](https://aws.amazon.com/redshift/) database.

2. Create a new [Amazon SageMaker notebook instance](https://docs.aws.amazon.com/sagemaker/latest/dg/nbi.html) for this demonstration. Make sure your Redshift cluster is accessible to your SageMaker Notebook environment.

3. `git clone` this post's GitHub project to your Amazon SageMaker notebook instance.

4. Create or update the `.env` file, used by `dotenv`, using the terminal in your SageMaker notebook environment. A sample `env.txt` file is included in the project.

5. Add your Amazon Redshift credentials to the `.env` file. See this post's GitHub project for an example.

6. Create an OpenAI account and update the `.env` file to include your OpenAI API Key.

__NOTE__: When using `dotenv`, credentials will be stored in plain text. The recommended and more secure method is to use [AWS Secrets Manager](https://docs.aws.amazon.com/secretsmanager/latest/userguide/intro.html).
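As a rough illustration of the Secrets Manager route, the sketch below turns a secret into the same kind of connection URI the notebook builds from `.env` values. It assumes the secret uses the JSON key names Secrets Manager documents for Amazon Redshift credentials (`username`, `password`, `host`, `port`, `dbname`); the helper names and the sample payload are hypothetical. Only `fetch_secret_string` talks to AWS, through boto3's `get_secret_value`, and it needs valid AWS credentials plus IAM permission to read the secret.

```python
import json
from urllib.parse import quote_plus


def uri_from_secret_string(secret_string: str) -> str:
    """Build a sqlalchemy-redshift URI from a Secrets Manager secret string.

    Assumes the JSON key names Secrets Manager uses for Amazon Redshift
    credentials: username, password, host, port (optional), and dbname.
    """
    secret = json.loads(secret_string)
    # Percent-encode the password so characters such as '@' or ':' don't break the URI
    password = quote_plus(secret["password"])
    port = secret.get("port", 5439)  # 5439 is Redshift's default port
    return (
        f"redshift+psycopg2://{secret['username']}:{password}"
        f"@{secret['host']}:{port}/{secret['dbname']}"
    )


def fetch_secret_string(secret_id: str, region_name: str = "us-east-1") -> str:
    """Fetch a secret's string value with boto3 (needs AWS credentials and IAM permission)."""
    import boto3  # imported here so the parsing helper works without boto3 installed

    client = boto3.client("secretsmanager", region_name=region_name)
    return client.get_secret_value(SecretId=secret_id)["SecretString"]


# Offline usage with a hypothetical secret payload (no AWS call is made)
sample_secret = json.dumps(
    {
        "engine": "redshift",
        "host": "example-cluster.abc123.us-east-1.redshift.amazonaws.com",
        "username": "awsuser",
        "password": "p@ss:word",
        "dbname": "dev",
        "port": 5439,
    }
)
print(uri_from_secret_string(sample_secret))
```

In the notebook, this would replace the `.env` lookups with something like `REDSHIFT_ENDPOINT = uri_from_secret_string(fetch_secret_string("my-redshift-secret"))`, where the secret name is a placeholder.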

## Required for ChromaDB in an Amazon SageMaker JumpStart Environment

In [None]:
!apt-get update -qq && apt-get install -y build-essential -qq

## Install Required Packages

In [None]:
# Optional: update version of pip
%pip install pip -Uq

# sqlalchemy-redshift is currently not compatible with SQLAlchemy 2.x
%pip install SQLAlchemy==1.4.48 -q

# Install latest versions of required packages
%pip install ipywidgets langchain openai python-dotenv sqlalchemy-redshift psycopg2-binary chromadb -Uq
%pip install pyyaml -q

# Avoid issues with install
# https://github.com/aws/amazon-sagemaker-examples/issues/1890#issuecomment-758871546
%pip install sentence-transformers -Uq --no-cache-dir #--force-reinstall

In [None]:
# Optional: restart kernel to update packages
# import os
# os._exit(00)

In [None]:
# Check versions of critical packages
%pip list | grep "langchain\|openai\|sentence-transformers\|SQLAlchemy"

## Set Up Environment Variables

Use `dotenv` to load the OpenAI and Redshift environment variables. __NOTE__: credentials will be stored in plain text. The recommended, more secure method is to use [AWS Secrets Manager](https://docs.aws.amazon.com/secretsmanager/latest/userguide/intro.html).

In [None]:
import os

# Avoid huggingface/tokenizers parallelism error
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [None]:
# Load env vars from .env file
%load_ext dotenv

# %reload_ext dotenv

%dotenv

In [None]:
# sqlalchemy-redshift reference: https://pypi.org/project/sqlalchemy-redshift/
# Endpoint format: redshift+psycopg2://username:password@host.amazonaws.com:5439/database

REDSHIFT_HOST = os.environ.get("REDSHIFT_HOST")
REDSHIFT_PORT = os.environ.get("REDSHIFT_PORT")
REDSHIFT_DATABASE = os.environ.get("REDSHIFT_DATABASE")
REDSHIFT_USERNAME = os.environ.get("REDSHIFT_USERNAME")
REDSHIFT_PASSWORD = os.environ.get("REDSHIFT_PASSWORD")
REDSHIFT_ENDPOINT = f"redshift+psycopg2://{REDSHIFT_USERNAME}:{REDSHIFT_PASSWORD}@{REDSHIFT_HOST}:{REDSHIFT_PORT}/{REDSHIFT_DATABASE}"

# print URI
REDSHIFT_ENDPOINT_PRINT = REDSHIFT_ENDPOINT.replace(
    REDSHIFT_HOST, "******.******.us-east-1.redshift.amazonaws.com"
)
REDSHIFT_ENDPOINT_PRINT = REDSHIFT_ENDPOINT_PRINT.replace(REDSHIFT_PASSWORD, "******")
print(REDSHIFT_ENDPOINT_PRINT)
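The f-string above works as long as the password contains no URI-reserved characters such as `@`, `:`, or `/`. SQLAlchemy's documentation recommends percent-encoding such passwords with `urllib.parse.quote_plus`. The following is a minimal sketch with hypothetical helper names; it also masks the password with a regular expression instead of replacing the host and password strings, so printing never depends on the password's exact text.

```python
import re
from urllib.parse import quote_plus


def build_redshift_uri(user: str, password: str, host: str, port: str, database: str) -> str:
    """Assemble a sqlalchemy-redshift URI with a percent-encoded password."""
    return f"redshift+psycopg2://{user}:{quote_plus(password)}@{host}:{port}/{database}"


def mask_password(uri: str) -> str:
    """Hide the password between 'user:' and '@' before printing the URI."""
    return re.sub(r"(//[^:/@]+:)[^@]*@", r"\1******@", uri)


uri = build_redshift_uri("awsuser", "p@ss/word", "host.example.com", "5439", "dev")
print(mask_password(uri))  # redshift+psycopg2://awsuser:******@host.example.com:5439/dev
```

With the notebook's variables, usage would be `REDSHIFT_ENDPOINT = build_redshift_uri(REDSHIFT_USERNAME, REDSHIFT_PASSWORD, REDSHIFT_HOST, REDSHIFT_PORT, REDSHIFT_DATABASE)` followed by `print(mask_password(REDSHIFT_ENDPOINT))`.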

## LangChain OpenAI

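Use OpenAI's `gpt-3.5-turbo` chat model through LangChain's `ChatOpenAI` class; the `text-davinci-003` completion model (through the `OpenAI` class) is left in the code, commented out, as an alternative. See OpenAI's [Models Overview](https://platform.openai.com/docs/models/overview) for model information.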

In [None]:
from langchain import SQLDatabase, SQLDatabaseChain, OpenAI
from langchain.chat_models import ChatOpenAI
from langchain.chains import SQLDatabaseSequentialChain

In [None]:
# llm = OpenAI(model_name="text-davinci-003", temperature=0, verbose=True)
llm = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0, verbose=True)

## Using LangChain's SQL Chain

Next, we will use LangChain's [SQLDatabaseChain](https://python.langchain.com/en/latest/modules/chains/examples) and [SQLDatabaseSequentialChain](https://python.langchain.com/en/latest/modules/chains/examples/sqlite.html#sqldatabasesequentialchain) to answer questions about the TICKIT database.
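According to its docstring, `SQLDatabaseSequentialChain` works in two steps: based on the question, it first determines which tables to use, then calls the normal `SQLDatabaseChain` with only those tables. LangChain describes this as useful when the database has many tables. The sketch below imitates that flow without an LLM: `keyword_decider` is a hypothetical stand-in for the LLM's table-selection step, and the abbreviated `TABLE_INFO` strings are illustrative rather than the real TICKIT DDL.

```python
from typing import Callable, Dict, List

# Hypothetical, abbreviated table definitions (stand-ins for the real DDL)
TABLE_INFO: Dict[str, str] = {
    "d_category": "CREATE TABLE d_category (catid smallint, catgroup varchar(10), catname varchar(10))",
    "d_venue": "CREATE TABLE d_venue (venueid smallint, venuename varchar(100), venuecity varchar(30))",
    "d_event": "CREATE TABLE d_event (eventid integer, venueid smallint, catid smallint, eventname varchar(200))",
    "f_sales": "CREATE TABLE f_sales (salesid integer, eventid integer, buyerid integer, pricepaid numeric(8,2))",
}


def keyword_decider(question: str, table_names: List[str]) -> List[str]:
    """Stand-in for step 1 (an LLM call in LangChain): keep tables whose name stem appears in the question."""
    q = question.lower()
    return [t for t in table_names if t.split("_", 1)[1] in q]


def build_sql_prompt(question: str, decider: Callable[[str, List[str]], List[str]]) -> str:
    """Step 2: pass only the selected tables' definitions into the SQL-generation prompt."""
    selected = decider(question, sorted(TABLE_INFO))
    table_info = "\n\n".join(TABLE_INFO[t] for t in selected)
    return f"Only use the following tables:\n{table_info}\n\nQuestion: {question}"


print(build_sql_prompt("Which venue hosted the most events?", keyword_decider))
```

A plain keyword match misses "categories" versus "category", which is one reason an LLM (and the synonym hints used later in this notebook) handles table selection better.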

In [None]:
# A few sample questions
QUESTION_01 = "How many categories are there?"
QUESTION_02 = "How many rows are in the listing table?"
QUESTION_03 = "How many customers made a purchase in May 2022?"
QUESTION_04 = "What were the total sales in September 2022?"
QUESTION_05 = "Who are the top 10 buyers based on number of tickets?"
QUESTION_06 = "What are the top 3 events in terms of all time gross sales?"
QUESTION_07 = "Who are the top 5 sellers based on all time gross sales?"
QUESTION_08 = "Which venue hosted the most events?"
QUESTION_09 = (
    "How many events are in the 99.9 percentile in terms of all time gross sales?"
)

In [None]:
from sqlalchemy.exc import ProgrammingError, DataError

db = SQLDatabase.from_uri(REDSHIFT_ENDPOINT)

db_chain = SQLDatabaseSequentialChain.from_llm(
    llm, db, verbose=True, use_query_checker=True
)

try:
    db_chain(QUESTION_03)
except (ProgrammingError, ValueError, DataError) as exc:
    print(f"\n\n{exc}")

## More Options: Custom Table Info and Query Checker

According to LangChain's [documentation](https://python.langchain.com/en/latest/modules/chains/examples/sqlite.html#custom-table-info), "_In some cases, it can be useful to provide custom table information instead of using the automatically generated table definitions and the first sample_rows_in_table_info sample rows._" Of course, this is impractical when dealing with a large number of tables.

According to LangChain's [documentation](https://python.langchain.com/en/latest/modules/chains/examples/sqlite.html#adding-example-rows-from-each-table), "_Sometimes the Language Model generates invalid SQL with small mistakes that can be self-corrected using the same technique used by the SQL Database Agent to try and fix the SQL using the LLM. You can simply specify this option when creating the chain._" In this notebook, that option is `use_query_checker=True`.

In [None]:
# reduce tokens by slimming down table definitions

custom_table_info_slim = {
    "d_category": """CREATE TABLE d_category (
        catid smallint NOT NULL,
        catgroup character varying(10),
        catname character varying(10),
        catdesc character varying(50));

/*
3 rows from d_category table:
catid	catgroup	catname	catdesc
2	Sports	NHL	National Hockey League	
4	Sports	NBA	National Basketball Association	
5	Sports	MLS	Major League Soccer	
*/""",
    "d_date": """CREATE TABLE d_date (
        dateid smallint NOT NULL,
        caldate date NOT NULL,
        day smallint NOT NULL,
        month smallint NOT NULL,
        year smallint NOT NULL,
        week smallint NOT NULL,
        qtr smallint NOT NULL,
        holiday boolean DEFAULT false,
        PRIMARY KEY (dateid));

/*
3 rows from d_date table:
dateid	caldate	day	month	year	week	qtr	holiday
1827	2008-01-01	1	1	2008	1	1	true	
1831	2008-01-05	5	1	2008	2	1	false	
1836	2008-01-10	10	1	2008	2	1	false	
*/""",
    "d_event": """CREATE TABLE d_event (
        eventid integer NOT NULL,
        venueid smallint NOT NULL,
        catid smallint NOT NULL,
        dateid smallint NOT NULL,
        eventname character varying(200),
        starttime timestamp without time zone,
        PRIMARY KEY (eventid),
        FOREIGN KEY (venueid) REFERENCES d_venue(venueid),
        FOREIGN KEY (catid) REFERENCES d_category(catid),
        FOREIGN KEY (dateid) REFERENCES d_date(dateid));

/*
3 rows from d_event table:
eventid	venueid	catid	dateid	eventname	starttime
1217	238	6	1827	Mamma Mia!	2008-01-01 20:00:00	
1433	248	6	1827	Grease	2008-01-01 19:00:00	
2811	207	7	1827	Spring Awakening	2008-01-01 15:00:00	
*/""",
    "d_user": """CREATE TABLE d_user (
        userid integer,
        username character(8),
        firstname character varying(30),
        lastname character varying(30),
        city character varying(30),
        state character(2),
        email character varying(100),
        phone character(14),
        likesports boolean,
        liketheatre boolean,
        likeconcerts boolean,
        likejazz boolean,
        likeclassical boolean,
        likeopera boolean,
        likerock boolean,
        likevegas boolean,
        likebroadway boolean,
        likemusicals boolean);

/*
3 rows from d_user table:
userid	username	firstname	lastname	city	state	email	phone	likesports	liketheatre	likeconcerts	likejazz	likeclassical	likeopera	likerock	likevegas	likebroadway	likemusicals
2	PGL08LJI	Vladimir	Humphrey	Murfreesboro	SK	Suspendisse.tristique@nonnisiAenean.edu	(783) 492-1886	NULL	NULL	NULL	true	true	NULL	NULL	true	false	true	
4	XDZ38RDD	Barry	Roy	Omaha	AB	sed@lacusUtnec.ca	(355) 452-8168	false	true	NULL	false	NULL	NULL	NULL	NULL	NULL	false	
5	AEB55QTM	Reagan	Hodge	Forest Lake	NS	Cum@accumsan.com	(476) 519-9131	NULL	NULL	true	false	NULL	NULL	true	true	false	true	
*/""",
    "d_venue": """CREATE TABLE d_venue (
        venueid smallint,
        venuename character varying(100),
        venuecity character varying(30),
        venuestate character(2),
        venueseats integer);

/*
3 rows from d_venue table:
venueid	venuename	venuecity	venuestate	venueseats
1	Toyota Park	Bridgeview	IL	0	
3	RFK Stadium	Washington	DC	0	
6	New York Giants Stadium	East Rutherford	NJ	80242	
*/""",
    "f_listing": """CREATE TABLE f_listing (
        listid integer,
        sellerid integer,
        eventid integer,
        dateid smallint,
        numtickets smallint,
        priceperticket numeric(8, 2),
        totalprice numeric(8, 2),
        listtime timestamp without time zone);

/*
3 rows from f_listing table:
listid	sellerid	eventid	dateid	numtickets	priceperticket	totalprice	listtime
614	25339	770	1827	10	236	2360	2008-01-01 05:07:30	
776	20797	1811	1827	18	133	2394	2008-01-01 06:59:39	
2092	42560	8609	1827	22	194	4268	2008-01-01 05:49:06	
*/""",
    "f_sales": """CREATE TABLE f_sales (
        salesid integer,
        listid integer,
        sellerid integer,
        buyerid integer,
        eventid integer,
        dateid smallint,
        qtysold smallint,
        pricepaid numeric(8, 2),
        commission numeric(8, 2),
        saletime timestamp without time zone);

/*
3 rows from f_sales table:
salesid	listid	sellerid	buyerid	eventid	dateid	qtysold	pricepaid	commission	saletime
33095	36572	30047	660	2903	1827	2	234	35.1	2008-01-01 09:41:06	
88268	100813	45818	698	8649	1827	4	836	125.4	2008-01-01 07:26:20	
110917	127048	37631	116	1749	1827	1	337	50.55	2008-01-01 07:05:02	
*/""",
}
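Hand-written strings like these are easy to get out of sync, for example a sample-row block labeled with the wrong table or missing its opening comment marker. One option is to generate each value from a single helper. The sketch below is hypothetical (not a LangChain function) and mirrors the layout used in `custom_table_info_slim`: a slim `CREATE TABLE` statement followed by tab-separated sample rows inside a `/* ... */` comment.

```python
from typing import List, Sequence, Tuple


def make_custom_table_info(
    table: str,
    columns: List[Tuple[str, str]],
    sample_rows: Sequence[Sequence[object]],
) -> str:
    """Build a slim CREATE TABLE statement plus a commented block of sample rows."""
    column_lines = ",\n".join(f"        {name} {col_type}" for name, col_type in columns)
    header = "\t".join(name for name, _ in columns)
    rows = "\n".join("\t".join(str(value) for value in row) for row in sample_rows)
    return (
        f"CREATE TABLE {table} (\n{column_lines});\n\n"
        f"/*\n{len(sample_rows)} rows from {table} table:\n{header}\n{rows}\n*/"
    )


venue_info = make_custom_table_info(
    "d_venue",
    [("venueid", "smallint"), ("venuename", "character varying(100)")],
    [(1, "Toyota Park"), (3, "RFK Stadium")],
)
print(venue_info)
```

Because the table name, row count, and comment delimiters all come from one place, each generated value can be dropped into the `custom_table_info` dictionary without manual copy-editing.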

In [None]:
print(custom_table_info_slim["d_event"])

In [None]:
db = SQLDatabase.from_uri(
    REDSHIFT_ENDPOINT,
    include_tables=[
        "d_category",
        "d_date",
        "d_event",
        "d_user",
        "d_venue",
        "f_listing",
        "f_sales",
    ],
    sample_rows_in_table_info=3,
    custom_table_info=custom_table_info_slim,
)

db_chain = SQLDatabaseSequentialChain.from_llm(
    llm, db, verbose=True, use_query_checker=True, top_k=3
)

try:
    db_chain(QUESTION_08)
except (ProgrammingError, ValueError) as exc:
    print(f"\n\n{exc}")

## Customize Prompt and Return Intermediate Steps

For this part of the demonstration, we will also use a LangChain [Prompt Template](https://python.langchain.com/en/latest/modules/prompts/prompt_templates.html) (`PromptTemplate`). According to LangChain, "_A prompt template refers to a reproducible way to generate a prompt. It contains a text string (“the template”), that can take in a set of parameters from the end user and generate a prompt._"

According to LangChain's [documentation](https://python.langchain.com/en/latest/modules/chains/examples/sqlite.html#return-intermediate-steps), "_You can also return the intermediate steps of the `SQLDatabaseChain`. This allows you to access the SQL statement that was generated, as well as the result of running that against the SQL Database._"
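By default, `PromptTemplate` uses Python f-string-style `{placeholder}` substitution, so rendering it resembles `str.format`. In the mid-2023 LangChain releases this notebook targets, `SQLDatabaseChain` appends `\nSQLQuery:` to the question and stops generation at `\nSQLResult:`, so the completion holds only the SQL; the `_parse_example` helper later in this notebook relies on the same markers. The sketch below reproduces that rendering and parsing under those assumptions; the helper names, the abbreviated template, and the `"redshift"` dialect string are illustrative.

```python
# Abbreviated version of the notebook's _DEFAULT_TEMPLATE (format lines omitted)
TEMPLATE = (
    "Given an input question, first create a syntactically correct {dialect} query to run, "
    "then look at the results of the query and return the answer.\n\n"
    "Only use the following tables:\n\n{table_info}\n\nQuestion: {input}"
)


def render_prompt(question: str, dialect: str, table_info: str) -> str:
    """Fill the placeholders, then append the cue that makes the model write SQL next."""
    return TEMPLATE.format(dialect=dialect, table_info=table_info, input=question) + "\nSQLQuery:"


def extract_sql(completion: str, stop: str = "\nSQLResult:") -> str:
    """Emulate a stop sequence: keep only the text generated before the first stop marker."""
    return completion.split(stop, 1)[0].strip()


prompt = render_prompt(
    "How many categories are there?",
    "redshift",  # assumed dialect name reported by sqlalchemy-redshift
    "CREATE TABLE d_category (catid smallint, catname varchar(10));",
)
print(prompt)
print(extract_sql(" SELECT COUNT(*) FROM d_category;\nSQLResult: ..."))
```

The `SQLResult` and `Answer` lines are filled in afterward: the chain runs the extracted SQL, appends the result, and asks the model for the final answer.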

In [None]:
from langchain.prompts.prompt import PromptTemplate

_DEFAULT_TEMPLATE = """Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.
Use the following format:

Question: "Question here"
SQLQuery: "SQL Query to run"
SQLResult: "Result of the SQLQuery"
Answer: "Final answer here"

Only use the following tables:

{table_info}

Synonyms for category include categories.
Synonyms for user include customer, buyer, and seller.


Question: {input}"""

PROMPT = PromptTemplate(
    input_variables=["input", "table_info", "dialect"], template=_DEFAULT_TEMPLATE
)

# Revert to a db without custom_table_info
# Note: the full auto-generated table info could overflow the model's context window
# (maximum prompt + completion length of 4,097 tokens)
db = SQLDatabase.from_uri(REDSHIFT_ENDPOINT)

db_chain = SQLDatabaseChain.from_llm(
    llm,
    db,
    prompt=PROMPT,
    verbose=True,
    use_query_checker=True,
    return_intermediate_steps=True,
)

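from pprint import pprint

result = None
try:
    result = db_chain(QUESTION_07)
except (ProgrammingError, ValueError) as exc:
    print(f"\n\n{exc}")

# Inspect the generated SQL and query results (only if the chain succeeded)
if result is not None:
    pprint(result["intermediate_steps"])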

## Using Few-shot Learning

To improve the accuracy of the SQL query, LangChain allows us to use few-shot learning (aka few-shot prompting). According to [Wikipedia](https://en.wikipedia.org/wiki/In-context_learning_(natural_language_processing)), "_In natural language processing, in-context learning, few-shot learning or few-shot prompting is a prompting technique that allows a model to process examples before attempting a task. The method was popularized after the advent of GPT-3 and is considered to be an emergent property of large language models._"
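`FewShotPromptTemplate` assembles its prompt from three parts: a prefix, each selected example rendered with `example_prompt`, and a suffix, joined by a separator (`"\n\n"` by default). The sketch below reproduces that assembly in plain Python. `build_few_shot_prompt` and the placeholder example are hypothetical; unlike LangChain, this sketch does not select examples, and it drops the `{table_info}` field the notebook's `example_prompt` includes.

```python
from typing import Dict, List

# Simplified per-example template (the notebook's example_prompt also includes {table_info})
EXAMPLE_TEMPLATE = "Question: {input}\nSQLQuery: {sql_cmd}\nSQLResult: {sql_result}\nAnswer: {answer}"


def build_few_shot_prompt(
    prefix: str, examples: List[Dict[str, str]], suffix: str, separator: str = "\n\n", **variables: str
) -> str:
    """Render prefix, each worked example, and suffix, then join them with a separator."""
    rendered_examples = [EXAMPLE_TEMPLATE.format(**example) for example in examples]
    # Fill variables only in prefix/suffix so braces inside example SQL are left untouched
    return separator.join([prefix.format(**variables), *rendered_examples, suffix.format(**variables)])


# Placeholder example: the result and answer values are illustrative, not real query output
sample_example = {
    "input": "How many venues are there?",
    "sql_cmd": "SELECT COUNT(*) FROM d_venue;",
    "sql_result": "[(N,)]",
    "answer": "There are N venues.",
}
print(
    build_few_shot_prompt(
        "You are a {dialect} expert. Here are some examples:",
        [sample_example],
        "Question: {input}\nSQLQuery:",
        dialect="Redshift",
        input="How many events are there?",
    )
)
```

The model sees complete question/SQL/result/answer sequences before the new question, which is what steers it toward the table and column names used in the examples.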

In [None]:
from typing import Dict
import yaml

chain = SQLDatabaseChain.from_llm(
    llm, db, verbose=True, return_intermediate_steps=True, use_query_checker=True
)


def _parse_example(result: Dict) -> Dict:
    sql_cmd_key = "sql_cmd"
    sql_result_key = "sql_result"
    table_info_key = "table_info"
    input_key = "input"
    final_answer_key = "answer"

    _example = {
        "input": result.get("query"),
    }

    steps = result.get("intermediate_steps")
    answer_key = sql_cmd_key  # the first one
    for step in steps:
        # The steps are in pairs, a dict (input) followed by a string (output).
        # Unfortunately there is no schema but you can look at the input key of the
        # dict to see what the output is supposed to be
        if isinstance(step, dict):
            # Grab the table info from input dicts in the intermediate steps once
            if table_info_key not in _example:
                _example[table_info_key] = step.get(table_info_key)

            if input_key in step:
                if step[input_key].endswith("SQLQuery:"):
                    answer_key = sql_cmd_key  # this is the SQL generation input
                if step[input_key].endswith("Answer:"):
                    answer_key = final_answer_key  # this is the final answer input
            elif sql_cmd_key in step:
                _example[sql_cmd_key] = step[sql_cmd_key]
                answer_key = sql_result_key  # this is SQL execution input
        elif isinstance(step, str):
            # The preceding element should have set the answer_key
            _example[answer_key] = step
    return _example


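example: Dict
try:
    result = chain(QUESTION_07)
    print("\n*** Query succeeded")
    example = _parse_example(result)
except Exception as exc:
    print(f"\n*** Query failed: {exc}")
    # SQLDatabaseChain attaches the intermediate steps to the exception it re-raises;
    # fall back to an empty list if the attribute is missing
    result = {
        "query": QUESTION_07,
        "intermediate_steps": getattr(exc, "intermediate_steps", []),
    }
    example = _parse_example(result)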


# print for now, in reality you may want to write this out to a YAML file or database for manual fix-ups offline
yaml_example = yaml.dump(example, allow_unicode=True)
print("\n" + yaml_example)

In [None]:
# Use the corrected examples for few shot prompt examples
SQL_SAMPLES = None

with open("../few_shot_examples/sql_examples_redshift_slim.yaml", "r") as stream:
    SQL_SAMPLES = yaml.safe_load(stream)

print(yaml.dump(SQL_SAMPLES[0], allow_unicode=True))

In [None]:
from langchain import FewShotPromptTemplate, PromptTemplate
from langchain.chains.sql_database.prompt import _postgres_prompt, PROMPT_SUFFIX
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.prompts.example_selector.semantic_similarity import (
    SemanticSimilarityExampleSelector,
)
from langchain.vectorstores import Chroma

example_prompt = PromptTemplate(
    input_variables=["table_info", "input", "sql_cmd", "sql_result", "answer"],
    template="{table_info}\n\nQuestion: {input}\nSQLQuery: {sql_cmd}\nSQLResult: {sql_result}\nAnswer: {answer}",
)

examples_dict = SQL_SAMPLES

local_embeddings = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-MiniLM-L6-v2"
)

example_selector = SemanticSimilarityExampleSelector.from_examples(
    # This is the list of examples available to select from.
    examples_dict,
    # This is the embedding class used to produce embeddings which are used to measure semantic similarity.
    local_embeddings,
    # This is the VectorStore class that is used to store the embeddings and do a similarity search over.
    Chroma,  # type: ignore
    # This is the number of examples to produce and include per prompt
    k=min(3, len(examples_dict)),
)

few_shot_prompt = FewShotPromptTemplate(
    example_selector=example_selector,
    example_prompt=example_prompt,
    prefix=_postgres_prompt + "Here are some examples:",
    suffix=PROMPT_SUFFIX,
    input_variables=["table_info", "input", "top_k"],
)
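`SemanticSimilarityExampleSelector` embeds the examples (here with the `all-MiniLM-L6-v2` sentence-transformers model, stored in Chroma) and, for each new question, retrieves the `k` most similar ones. The sketch below imitates that ranking with the standard library only: a bag-of-words `Counter` stands in for the embedding model and cosine similarity stands in for the vector-store search, so it captures word overlap rather than meaning. All names are hypothetical.

```python
import math
import re
from collections import Counter
from typing import Dict, List


def embed(text: str) -> Counter:
    """Toy 'embedding': lowercase word counts (a stand-in for a sentence-transformers model)."""
    return Counter(re.findall(r"[a-z0-9]+", text.lower()))


def cosine(a: Counter, b: Counter) -> float:
    """Cosine similarity between two sparse word-count vectors."""
    dot = sum(a[word] * b[word] for word in a)
    norm = math.sqrt(sum(v * v for v in a.values())) * math.sqrt(sum(v * v for v in b.values()))
    return dot / norm if norm else 0.0


def select_examples(question: str, examples: List[Dict[str, str]], k: int = 3) -> List[Dict[str, str]]:
    """Return up to k examples whose questions are most similar to the new question."""
    q_vec = embed(question)
    ranked = sorted(examples, key=lambda ex: cosine(q_vec, embed(ex["input"])), reverse=True)
    return ranked[: min(k, len(examples))]


examples = [
    {"input": "Who are the top 10 buyers based on number of tickets?"},
    {"input": "Which venue hosted the most events?"},
    {"input": "What were the total sales in September 2022?"},
]
best = select_examples("Who are the top 5 sellers based on all time gross sales?", examples, k=1)
print(best[0]["input"])
```

As with `k=min(3, len(examples_dict))` in the cell above, the selector never asks for more examples than exist.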

In [None]:
db_chain = SQLDatabaseChain.from_llm(
    llm,
    db,
    prompt=few_shot_prompt,
    use_query_checker=True,
    verbose=True,
    return_intermediate_steps=True,
)

try:
    result = db_chain(QUESTION_07)
except (ProgrammingError, ValueError) as exc:
    print(f"\n\n{exc}")

## LangChain SQL Database Agent

According to LangChain's [documentation](https://python.langchain.com/en/latest/modules/agents/toolkits/examples/sql_database.html#sql-database-agent), the SQL Database Agent "_builds off of `SQLDatabaseChain` and is designed to answer more general questions about a database, as well as recover from errors._" __NOTE__: "_it is not guaranteed that the agent won’t perform DML statements on your database given certain questions. Be careful running it on sensitive data!_"
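The most reliable safeguard is at the database level: connect with a Redshift user that has only `SELECT` privileges on the tables the agent may query. As an additional, deliberately naive layer, generated SQL could be screened before it runs, for example when inspecting intermediate steps. The sketch below is a hypothetical helper, not part of LangChain, and wiring it into the agent is not shown. It rejects anything other than a single `SELECT`/`WITH` statement; it can be defeated and can also trip on false positives (such as a string literal containing `delete`), so treat it as an illustration only.

```python
import re

# Keywords a read-only analytics session should never run (illustrative, not exhaustive)
WRITE_KEYWORDS = re.compile(
    r"\b(insert|update|delete|merge|drop|alter|truncate|create|grant|revoke|copy|unload|vacuum)\b",
    re.IGNORECASE,
)


def is_read_only(sql: str) -> bool:
    """Naive check: allow a single SELECT/WITH statement containing no write keywords."""
    statements = [s for s in sql.strip().split(";") if s.strip()]
    if len(statements) != 1:
        return False  # reject empty input and multi-statement batches
    statement = statements[0].lstrip()
    return statement.lower().startswith(("select", "with")) and not WRITE_KEYWORDS.search(statement)


print(is_read_only("SELECT venuename FROM d_venue LIMIT 5;"))
print(is_read_only("SELECT 1; DROP TABLE d_venue;"))
```

The word-boundary pattern lets column names such as `updated_at` through while still catching a bare `UPDATE`.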

In [None]:
from langchain.agents import create_sql_agent
from langchain.agents.agent_toolkits import SQLDatabaseToolkit
from langchain.sql_database import SQLDatabase

In [None]:
# Example of describing a table using the agent
toolkit = SQLDatabaseToolkit(db=db, llm=llm)

agent_executor = create_sql_agent(llm=llm, toolkit=toolkit, verbose=True)

try:
    agent_executor.run("Describe the d_venue table.")
except (ProgrammingError, ValueError) as exc:
    print(f"\n\n{exc}")

In [None]:
# Example of running queries using the agent
try:
    agent_executor.run(QUESTION_07)
except (ProgrammingError, ValueError) as exc:
    print(f"\n\n{exc}")