**Prerequisites**

Before you begin, ensure that you have the following:

- A database named `titanic`. You can create one in just a few clicks using the Neon console. See [Create a database](https://neon.tech/docs/manage/databases#create-a-database).
- A connection string for your `titanic` database. You can copy it from the **Connection Details** widget on the Neon Dashboard. See [Connect from any application](https://neon.tech/docs/connect/connect-from-any-app) for instructions.
- Your OpenAI API key. If you do not have an OpenAI API key, obtain one from [https://platform.openai.com/account/api-keys](https://platform.openai.com/account/api-keys).

**Install the required modules**

This notebook requires the `langchain`, `langchain-experimental`, and `openai` packages. Run this code block to install them.

In [None]:
! pip install langchain langchain-experimental openai
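If an import fails later in the notebook, it can help to check which versions pip actually installed. LangChain has moved several import paths between releases (many integrations now live in the separate `langchain-community` package), so the version often explains an `ImportError`. The sketch below uses only the standard library's `importlib.metadata`; the helper name `installed_versions` is hypothetical.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages):
    """Return {package: version string}, with None for packages that are not installed."""
    versions = {}
    for name in packages:
        try:
            versions[name] = version(name)
        except PackageNotFoundError:
            versions[name] = None
    return versions


# Packages installed by the cell above
print(installed_versions(["langchain", "langchain-experimental", "openai"]))
```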

**Provide your OpenAI API key**

Run this code block and provide your OpenAI API key when prompted.

In [1]:
from getpass import getpass
import os

# Directly prompt for the API key
api_key = getpass("Enter your OPENAI_API_KEY: ")

if api_key:
    print("Your OPENAI_API_KEY is now available for this session")
    # Store the key as an environment variable; a later cell reads it from os.environ
    os.environ["OPENAI_API_KEY"] = api_key
else:
    print("You did not enter your OPENAI_API_KEY")


Enter your OPENAI_API_KEY: ··········
Your OPENAI_API_KEY is now available for this session
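When you rerun the notebook, you may prefer not to retype the key. A minimal sketch, assuming you are happy to reuse a key that is already set in the environment: the hypothetical `load_api_key` helper returns `OPENAI_API_KEY` from `os.environ` if present and only falls back to `getpass` when it is missing. The `prompt` parameter is injectable so the function can be exercised without an interactive terminal.

```python
import os
from getpass import getpass


def load_api_key(env_var="OPENAI_API_KEY", prompt=getpass):
    """Return the key from the environment, prompting only if it is not already set."""
    key = os.environ.get(env_var)
    if not key:
        key = prompt(f"Enter your {env_var}: ")
        if key:
            # Store it so later cells that read os.environ can find it
            os.environ[env_var] = key
    return key
```

In the notebook you would call `api_key = load_api_key()` in place of the direct `getpass` call.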


**Provide your database connection string and open a cursor**

Replace the placeholder below with the connection string for your `titanic` database, then run the code block.

In [2]:
import os
import psycopg2
from urllib.parse import urlparse

# Provide your Neon connection string
connection_string = "postgres://<user>:<password>@<hostname>/<dbname>"

# Extract details from connection string
parsed_uri = urlparse(connection_string)
username = parsed_uri.username
password = parsed_uri.password
host = parsed_uri.hostname
port = parsed_uri.port or 5432
database = parsed_uri.path[1:]  # remove leading '/'

# Connect using the connection string
connection = psycopg2.connect(connection_string)

# Create a new cursor object
cursor = connection.cursor()
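The cell above splits the connection string with `urlparse`. Two details are worth knowing: query parameters such as `sslmode=require` land in `query`, not in the path, and `urlparse` does not percent-decode the username or password. The sketch below (hypothetical helper `parse_connection_string`, made-up example host and credentials) decodes them with `unquote`, which is what you want when passing values as keyword arguments such as `psycopg2.connect(password=...)`. If you rebuild a URI from the parts, as a later cell does, keep the raw percent-encoded values instead.

```python
from urllib.parse import parse_qs, unquote, urlparse


def parse_connection_string(conn_str):
    """Split a postgres:// URI into decoded parts (illustrative helper)."""
    uri = urlparse(conn_str)
    return {
        "user": unquote(uri.username) if uri.username else None,
        "password": unquote(uri.password) if uri.password else None,
        "host": uri.hostname,
        "port": uri.port or 5432,
        "dbname": uri.path.lstrip("/"),
        # Query parameters such as sslmode=require end up here, not in the path
        "options": {k: v[0] for k, v in parse_qs(uri.query).items()},
    }


# Made-up example; "p%40ss" is the percent-encoded form of "p@ss"
example = "postgres://alex:p%40ss@ep-example-123456.us-east-2.aws.neon.tech/titanic?sslmode=require"
print(parse_connection_string(example))
```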

**Test your database connection**

Run this code block to test your database connection.

In [3]:
# Execute this query to test the database connection
cursor.execute("SELECT 1;")
result = cursor.fetchone()

# Check the query result
if result == (1,):
    print("Your database connection was successful!")
else:
    print("Your connection failed.")

Your database connection was successful!
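The `SELECT 1` check only relies on the DB-API 2.0 cursor interface (`execute` and `fetchone`), so it can be wrapped in a small function that reports failures instead of raising. This is a sketch with a hypothetical helper name; to keep it runnable anywhere, the demo uses an in-memory SQLite database as a stand-in for Neon. In the notebook you would pass the psycopg2 `cursor` instead.

```python
import sqlite3


def connection_ok(cursor):
    """Run SELECT 1 on any DB-API 2.0 cursor and return True if it answers (1,)."""
    try:
        cursor.execute("SELECT 1;")
        return cursor.fetchone() == (1,)
    except Exception as exc:  # e.g. the cursor or connection was already closed
        print(f"Connection check failed: {exc}")
        return False


# Stand-in demo: an in-memory SQLite database instead of Neon
demo_cursor = sqlite3.connect(":memory:").cursor()
print(connection_ok(demo_cursor))
```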


**Create a table for passenger data**

Run this code block to create a `passenger` table in your `titanic` database.

In [None]:
create_table_sql = '''
CREATE TABLE public.passenger (
    passengerid integer NOT NULL,
    survived double precision,
    pclass integer,
    name text,
    sex text,
    age double precision,
    sibsp integer,
    parch integer,
    ticket text,
    fare double precision,
    cabin text,
    embarked text,
    wikiid double precision,
    name_wiki text,
    age_wiki double precision,
    hometown text,
    boarded text,
    destination text,
    lifeboat text,
    body text,
    class integer
);
'''

# Execute the SQL statement
cursor.execute(create_table_sql)

# Commit the changes
connection.commit()
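The `passenger` columns are listed twice in this notebook: once in `CREATE TABLE` and again in the `COPY` column list. One way to keep them from drifting apart is to define the schema once and generate both statements from it. The sketch below is an illustration, not part of the original notebook; the names `PASSENGER_COLUMNS`, `create_table_sql`, and `copy_sql` are hypothetical. It also uses `CREATE TABLE IF NOT EXISTS`, which lets the cell be rerun without a "relation already exists" error.

```python
# Hypothetical single source of truth for the passenger schema: (column, SQL type)
PASSENGER_COLUMNS = [
    ("passengerid", "integer NOT NULL"),
    ("survived", "double precision"),
    ("pclass", "integer"),
    ("name", "text"),
    ("sex", "text"),
    ("age", "double precision"),
    ("sibsp", "integer"),
    ("parch", "integer"),
    ("ticket", "text"),
    ("fare", "double precision"),
    ("cabin", "text"),
    ("embarked", "text"),
    ("wikiid", "double precision"),
    ("name_wiki", "text"),
    ("age_wiki", "double precision"),
    ("hometown", "text"),
    ("boarded", "text"),
    ("destination", "text"),
    ("lifeboat", "text"),
    ("body", "text"),
    ("class", "integer"),
]


def create_table_sql(table, columns):
    """Build CREATE TABLE IF NOT EXISTS from trusted constants (not user input)."""
    cols = ",\n".join(f"    {name} {sql_type}" for name, sql_type in columns)
    return f"CREATE TABLE IF NOT EXISTS {table} (\n{cols}\n);"


def copy_sql(table, columns):
    """Build a COPY ... FROM STDIN statement whose column list matches the schema."""
    names = ", ".join(name for name, _ in columns)
    return f"COPY {table} ({names}) FROM STDIN WITH (FORMAT CSV, HEADER true);"


print(create_table_sql("public.passenger", PASSENGER_COLUMNS))
print(copy_sql("public.passenger", PASSENGER_COLUMNS))
```

With psycopg2 you would then run `cursor.execute(create_table_sql(...))` and pass `copy_sql(...)` to `cursor.copy_expert`.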

**Load data into the passenger table**

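Run this code block to download the Titanic dataset from the `neondatabase/postgres-sample-dbs` repository on GitHub, save it locally as `titanic.csv`, and load it into the `passenger` table in your `titanic` database with `COPY`.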

In [None]:
# Download the file from GitHub and save it as titanic.csv
!wget https://github.com/neondatabase/postgres-sample-dbs/raw/main/titanic.sql -O titanic.csv

# Path to the downloaded CSV file
csv_file_path = 'titanic.csv'

# COPY command for copy_expert; the column list matches the passenger table
copy_command = '''
COPY public.passenger (passengerid, survived, pclass, name, sex, age, sibsp, parch, ticket, fare, cabin, embarked, wikiid, name_wiki, age_wiki, hometown, boarded, destination, lifeboat, body, class)
FROM STDIN WITH (FORMAT CSV, HEADER true, DELIMITER ',');
'''

# Stream the file straight into COPY; copy_expert accepts any file-like object
with open(csv_file_path, 'r', encoding='utf-8') as csv_file:
    cursor.copy_expert(copy_command, csv_file)

# Commit the changes
connection.commit()
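`COPY ... FORMAT CSV` aborts the whole load if any record has the wrong number of fields, and a failed `COPY` leaves the transaction in an error state until you call `connection.rollback()`. A quick pre-check on the first few records can catch a problem early, for example if the downloaded file turns out not to be a CSV with one field per `passenger` column. This is a sketch using the standard `csv` module; the helper name `check_csv_shape` is hypothetical.

```python
import csv
import io


def check_csv_shape(file_obj, expected_count, max_records=6):
    """Return a list of problems; an empty list means the first records look loadable.

    Only the header and the first few data records are inspected.
    """
    problems = []
    seen = 0
    for record_no, row in enumerate(csv.reader(file_obj), start=1):
        seen = record_no
        if len(row) != expected_count:
            problems.append(f"record {record_no}: {len(row)} fields, expected {expected_count}")
        if record_no >= max_records:
            break
    if seen == 0:
        problems.append("file is empty")
    return problems


# Quoted commas are handled by the csv module, so this sample has 3 fields per record
good = io.StringIO('a,b,c\n1,2,3\n4,"x, y",6\n')
print(check_csv_shape(good, 3))  # → []
```

In the notebook, you could run `with open(csv_file_path, encoding='utf-8') as f: print(check_csv_shape(f, 21))` before the `COPY`, since the `passenger` table has 21 columns.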

**Prompt for a question**

Run this code block to ask your database a question. When prompted, enter a question like, "How many passengers survived?" or "What was the average age of passengers?"

In [None]:
from langchain.utilities import SQLDatabase
from langchain.llms import OpenAI
from langchain_experimental.sql import SQLDatabaseChain
from langchain.prompts import PromptTemplate

# Set up the database for LangChain, reusing the connection details parsed earlier
db = SQLDatabase.from_uri(
    f"postgresql+psycopg2://{username}:{password}@{host}:{port}/{database}"
)

# Set up the LLM
llm = OpenAI(temperature=0, openai_api_key=os.environ["OPENAI_API_KEY"])

# Define table_info and few_shot_examples
table_info = """public.passenger (
    passengerid integer NOT NULL,
    survived double precision,
    pclass integer,
    name text,
    sex text,
    age double precision,
    sibsp integer,
    parch integer,
    ticket text,
    fare double precision,
    cabin text,
    embarked text,
    wikiid double precision,
    name_wiki text,
    age_wiki double precision,
    hometown text,
    boarded text,
    destination text,
    lifeboat text,
    body text,
    class integer
)"""

few_shot_examples = """
- Question: "How many passengers survived?"
  SQLQuery: "SELECT COUNT(*) FROM public.passenger WHERE survived = 1;"

- Question: "What was the average age of passengers?"
  SQLQuery: "SELECT AVG(age) FROM public.passenger;"

- Question: "How many male and female passengers were there?"
  SQLQuery: "SELECT sex, COUNT(*) FROM public.passenger GROUP BY sex;"

- Question: "Which passenger had the highest fare?"
  SQLQuery: "SELECT name, fare FROM public.passenger WHERE fare IS NOT NULL ORDER BY fare DESC LIMIT 1;"

- Question: "How many passengers boarded from each location?"
  SQLQuery: "SELECT embarked, COUNT(*) FROM public.passenger GROUP BY embarked;"

- Question: "Who is the oldest passenger and what was their age?"
  SQLQuery: "SELECT name, age FROM public.passenger WHERE age IS NOT NULL ORDER BY age DESC LIMIT 1;"
"""

# Define Custom Prompt
TEMPLATE = """Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.
Use the following format:

Question: "Question here"
SQLQuery: "SQL Query to run"
SQLResult: "Result of the SQLQuery"
Answer: "Final answer here"

Only use the following tables:

{table_info}

Some examples of SQL queries that correspond to questions are:

{few_shot_examples}

Question: {input}"""

CUSTOM_PROMPT = PromptTemplate(
    input_variables=["input", "few_shot_examples", "table_info", "dialect"], template=TEMPLATE
)

# Set up the database chain
db_chain = SQLDatabaseChain.from_llm(llm, db, verbose=True)

def get_prompt():
    print("Type 'exit' to quit")
    while True:
        prompt = input("Ask a question or type exit to quit: ")

        if prompt.lower() == 'exit':
            print('Exiting...')
            break
        else:
            try:
                question = CUSTOM_PROMPT.format(
                    input=prompt,
                    few_shot_examples=few_shot_examples,
                    table_info=table_info,
                    dialect="PostgreSQL"
                )
                print(db_chain.run(question))
            except Exception as e:
                print(e)

get_prompt()
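The `get_prompt` loop mixes terminal I/O with the call to the chain, which makes it hard to try out without an OpenAI key and a live database. A minimal sketch of one alternative: the hypothetical `run_question_loop` takes the answering function and the input/output functions as parameters, and also skips empty input instead of sending it to the LLM. The demo below uses a fake answer function and scripted input, so no network access is needed.

```python
def run_question_loop(answer, ask=input, show=print):
    """Ask questions until the user types 'exit' (hypothetical refactor of get_prompt).

    answer: callable that turns a question string into a reply.
    ask/show: injected I/O so the loop can run without a terminal.
    """
    show("Type 'exit' to quit")
    while True:
        question = ask("Ask a question or type exit to quit: ").strip()
        if question.lower() == "exit":
            show("Exiting...")
            break
        if not question:
            continue  # ignore empty input instead of sending it to the LLM
        try:
            show(answer(question))
        except Exception as exc:  # keep the loop alive on LLM or SQL errors
            show(f"Error: {exc}")


# Demo with scripted input and a fake answer function (no LLM or database)
scripted = iter(["How many passengers survived?", "", "boom", "exit"])


def fake_answer(q):
    if q == "boom":
        raise ValueError("simulated failure")
    return f"(answer to: {q})"


transcript = []
run_question_loop(fake_answer, ask=lambda _: next(scripted), show=transcript.append)
print(transcript)
```

In the notebook, the real answer function would format `CUSTOM_PROMPT` with the question, `few_shot_examples`, `table_info`, and `dialect="PostgreSQL"`, then pass the result to `db_chain.run`, exactly as `get_prompt` does.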