In [1]:
import pandas as pd 

In [2]:
data = pd.read_csv("./Titanic-Dataset.csv")
data.head()

Unnamed: 0,PassengerId,Survived,Pclass,Name,Sex,Age,SibSp,Parch,Ticket,Fare,Cabin,Embarked
0,1,0,3,"Braund, Mr. Owen Harris",male,22.0,1,0,A/5 21171,7.25,,S
1,2,1,1,"Cumings, Mrs. John Bradley (Florence Briggs Th...",female,38.0,1,0,PC 17599,71.2833,C85,C
2,3,1,3,"Heikkinen, Miss. Laina",female,26.0,0,0,STON/O2. 3101282,7.925,,S
3,4,1,1,"Futrelle, Mrs. Jacques Heath (Lily May Peel)",female,35.0,1,0,113803,53.1,C123,S
4,5,0,3,"Allen, Mr. William Henry",male,35.0,0,0,373450,8.05,,S


In [3]:
len(data), data.dtypes

(891,
 PassengerId      int64
 Survived         int64
 Pclass           int64
 Name            object
 Sex             object
 Age            float64
 SibSp            int64
 Parch            int64
 Ticket          object
 Fare           float64
 Cabin           object
 Embarked        object
 dtype: object)

In [4]:
# Convert DataFrame to SQLite Database
from sqlite3 import connect

con = connect('titanic.db')
data.to_sql("Passenger Data", con, if_exists = 'replace')

In [5]:
# Create our SQL Database
from langchain_community.utilities import SQLDatabase

db = SQLDatabase.from_uri("sqlite:///titanic.db", sample_rows_in_table_info=3)
print(db.table_info)


CREATE TABLE "Passenger Data" (
	"index" INTEGER, 
	"PassengerId" INTEGER, 
	"Survived" INTEGER, 
	"Pclass" INTEGER, 
	"Name" TEXT, 
	"Sex" TEXT, 
	"Age" REAL, 
	"SibSp" INTEGER, 
	"Parch" INTEGER, 
	"Ticket" TEXT, 
	"Fare" REAL, 
	"Cabin" TEXT, 
	"Embarked" TEXT
)

/*
3 rows from Passenger Data table:
index	PassengerId	Survived	Pclass	Name	Sex	Age	SibSp	Parch	Ticket	Fare	Cabin	Embarked
0	1	0	3	Braund, Mr. Owen Harris	male	22.0	1	0	A/5 21171	7.25	None	S
1	2	1	1	Cumings, Mrs. John Bradley (Florence Briggs Thayer)	female	38.0	1	0	PC 17599	71.2833	C85	C
2	3	1	3	Heikkinen, Miss. Laina	female	26.0	0	0	STON/O2. 3101282	7.925	None	S
*/


In [12]:
# Initialize Language Model from OpenAI
from langchain_openai import ChatOpenAI 
import os

os.environ['OPENAI_API_KEY'] = 'key'

llm = ChatOpenAI(model="gpt-3.5-turbo",
                 temperature=0,
                 openai_api_key=os.environ.get("OPENAI_API_KEY")
                )

llm.config_schema

<bound method Runnable.config_schema of ChatOpenAI(client=<openai.resources.chat.completions.Completions object at 0x7fcc64545eb0>, async_client=<openai.resources.chat.completions.AsyncCompletions object at 0x7fcc64548670>, temperature=0.0, openai_api_key=SecretStr('**********'), openai_proxy='')>

In [13]:
# Use SQLDatabase Chain
from langchain_experimental.sql.base import SQLDatabaseChain

sql_db_chain = SQLDatabaseChain.from_llm(llm = llm, db = db, verbose = True)

#sql_db_chain.invoke("How many passengers were there?")



[1m> Entering new SQLDatabaseChain chain...[0m
How many passengers were there?
SQLQuery:[32;1m[1;3mSELECT COUNT("PassengerId") AS total_passengers
FROM "Passenger Data"[0m
SQLResult: [33;1m[1;3m[(891,)][0m
Answer:[32;1m[1;3mThere were 891 passengers.[0m
[1m> Finished chain.[0m


{'query': 'How many passengers were there?',
 'result': 'There were 891 passengers.'}

In [15]:
# Custom Prompts
sql_db_chain = SQLDatabaseChain.from_llm(llm = llm, db = db, verbose = True)

connection = connect("titanic.db")
curr = connection.execute('''
SELECT "Sex", 
    CASE 
        WHEN "Age" < 20 THEN '0-19'
        WHEN "Age" >= 20 AND "Age" < 40 THEN '20-39'
        WHEN "Age" >= 40 AND "Age" < 60 THEN '40-59'
        ELSE '60+'
    END AS "Age_Bin",
    AVG("Survived") AS "Survival_Rate"
FROM "Passenger Data"
GROUP BY "Sex", "Age_Bin"
ORDER BY "Sex", "Age_Bin";
''')

print(curr.fetchall())

[('female', '0-19', 0.7066666666666667), ('female', '20-39', 0.7727272727272727), ('female', '40-59', 0.76), ('female', '60+', 0.7017543859649122), ('male', '0-19', 0.29213483146067415), ('male', '20-39', 0.18823529411764706), ('male', '40-59', 0.1839080459770115), ('male', '60+', 0.13013698630136986)]


In [16]:
curr.close()
connection.close()

In [17]:
prompt_template = '''
Given an input question, first create a syntactically correct SQLite query to run, then look at the results of the query and return the answer to the input question.
Unless the user specifies in the question a specific number of examples to obtain, query for at most 10 results using the LIMIT clause as per SQLite. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use date('now') function to get the current date, if the question involves "today".

Use the following format:

Question: Question here
SQLQuery: SQL Query to run
SQLResult: Result of the SQLQuery
Answer: Final answer here

Only use the following tables:
CREATE TABLE "Passenger Data" (
	"index" INTEGER, 
	"PassengerId" INTEGER, 
	"Survived" INTEGER, 
	"Pclass" INTEGER, 
	"Name" TEXT, 
	"Sex" TEXT, 
	"Age" REAL, 
	"SibSp" INTEGER, 
	"Parch" INTEGER, 
	"Ticket" TEXT, 
	"Fare" REAL, 
	"Cabin" TEXT, 
	"Embarked" TEXT
)

/*
3 rows from Passenger Data table:
index	PassengerId	Survived	Pclass	Name	Sex	Age	SibSp	Parch	Ticket	Fare	Cabin	Embarked
0	1	0	3	Braund, Mr. Owen Harris	male	22.0	1	0	A/5 21171	7.25	None	S
1	2	1	1	Cumings, Mrs. John Bradley (Florence Briggs Thayer)	female	38.0	1	0	PC 17599	71.2833	C85	C
2	3	1	3	Heikkinen, Miss. Laina	female	26.0	0	0	STON/O2. 3101282	7.925	None	S
*/

Question: {input}
'''

In [18]:
from langchain_core.prompts import PromptTemplate
PROMPT = PromptTemplate.from_template(prompt_template, variables=['input'])

In [19]:
sql_db_chain = SQLDatabaseChain.from_llm(llm = llm, db = db, verbose = True, prompt = PROMPT)

In [20]:
sql_db_chain.invoke("Provide the gender-wise survival rate along with age bins of 20.")



[1m> Entering new SQLDatabaseChain chain...[0m
Provide the gender-wise survival rate along with age bins of 20.
SQLQuery:[32;1m[1;3mSQLQuery: 
SELECT "Sex",
       CASE
           WHEN "Age" < 20 THEN '0-19'
           WHEN "Age" >= 20 AND "Age" < 40 THEN '20-39'
           WHEN "Age" >= 40 THEN '40+'
           ELSE 'Unknown'
       END AS "Age_Bin",
       AVG("Survived") AS "Survival_Rate"
FROM "Passenger Data"
GROUP BY "Sex", "Age_Bin"
ORDER BY "Sex", "Age_Bin";[0m
SQLResult: [33;1m[1;3m[('female', '0-19', 0.7066666666666667), ('female', '20-39', 0.7727272727272727), ('female', '40+', 0.7777777777777778), ('female', 'Unknown', 0.6792452830188679), ('male', '0-19', 0.29213483146067415), ('male', '20-39', 0.18823529411764706), ('male', '40+', 0.1743119266055046), ('male', 'Unknown', 0.12903225806451613)][0m
Answer:[32;1m[1;3mThe gender-wise survival rate along with age bins of 20 are as follows:
- For females:
  - Age 0-19: 70.67%
  - Age 20-39: 77.27%
  - Age 40+: 77.78%

{'query': 'Provide the gender-wise survival rate along with age bins of 20.',
 'result': 'The gender-wise survival rate along with age bins of 20 are as follows:\n- For females:\n  - Age 0-19: 70.67%\n  - Age 20-39: 77.27%\n  - Age 40+: 77.78%\n  - Age Unknown: 67.92%\n- For males:\n  - Age 0-19: 29.21%\n  - Age 20-39: 18.82%\n  - Age 40+: 17.43%\n  - Age Unknown: 12.90%'}

In [21]:
# Create SQL Query Chain
from langchain.chains import create_sql_query_chain
sql_chain = create_sql_query_chain(llm, db)
sql_chain

RunnableAssign(mapper={
  input: RunnableLambda(...),
  table_info: RunnableLambda(...)
})
| RunnableLambda(lambda x: {k: v for (k, v) in x.items() if k not in ('question', 'table_names_to_use')})
| PromptTemplate(input_variables=['input', 'table_info'], partial_variables={'top_k': '5'}, template='You are a SQLite expert. Given an input question, first create a syntactically correct SQLite query to run, then look at the results of the query and return the answer to the input question.\nUnless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per SQLite. You can order the results to return the most informative data in the database.\nNever query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.\nPay attention to use only the column names you can see in the tables below.

In [22]:
response = sql_chain.invoke({"question": "How many passengers were there?"})
response

'SELECT COUNT("PassengerId") AS TotalPassengers\nFROM "Passenger Data"'

In [23]:
# Run SQL Query on Database
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool

db_execution = QuerySQLDataBaseTool(db = db)
execution_chain = sql_chain | db_execution

response = execution_chain.invoke({"question": "How many passengers were there?"})
response

'[(891,)]'

In [32]:
# Summarizing Final Results
from operator import itemgetter
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_core.prompts import PromptTemplate

template = '''
Given the following user's quesion, corresponsing SQL query and SQL result, answer the user's question.
Question: {question}
SQL Query: {query}
SQL Result: {result}
Answer:
'''
prompt = PromptTemplate.from_template(template)

output = prompt | llm | StrOutputParser()
chain = (RunnablePassthrough.assign(query = sql_chain).assign(result = itemgetter("query") | db_execution) | output)

In [25]:
ch = RunnablePassthrough.assign(query = sql_chain)
ch.invoke({"question": "How does the prices varies with respect to passenger classes?"})

{'question': 'How does the prices varies with respect to passenger classes?',
 'query': 'SELECT "Pclass", AVG("Fare") AS "Average Fare"\nFROM "Passenger Data"\nGROUP BY "Pclass"\nORDER BY "Pclass";'}

In [26]:
ch2 = RunnablePassthrough.assign(query = sql_chain).assign(result = itemgetter("query") | db_execution)
ch2.invoke({"question": "How does the prices varies with respect to passenger classes?"})

{'question': 'How does the prices varies with respect to passenger classes?',
 'query': 'SELECT "Pclass", AVG("Fare") AS "Average Fare"\nFROM "Passenger Data"\nGROUP BY "Pclass"\nORDER BY "Pclass";',
 'result': '[(1, 84.15468749999992), (2, 20.66218315217391), (3, 13.675550101832997)]'}

In [35]:
# Showcasing in UI with Gradio
import gradio as gr

template = '''
Given the following user's quesion, corresponsing SQL query and SQL result, answer the user's question.
Question: {question}
SQL Query: {sql_query}
SQL Result: {result}
Answer:
'''
prompt = PromptTemplate.from_template(template)

In [36]:
def create_chain(question):
    db = SQLDatabase.from_uri("sqlite:///titanic.db", sample_rows_in_table_info = 3)
    sql_chain = create_sql_query_chain(llm, db)
    db_execution = QuerySQLDataBaseTool(db = db)
    output = prompt | llm | StrOutputParser()
    chain = (RunnablePassthrough.assign(sql_query = sql_chain).assign(result = itemgetter("sql_query") | db_execution) | output)

    return chain.invoke({"question": question})

In [37]:
def extract_data(user_message, history):
    question_with_history = ""
    for hist in history[-3:]:
        question_with_history += f"User: {hist[0]}\nAssistant: {hist[1]}\n"
    question_with_history += f"User: {user_message}\n"
    print("Input to LLM:\n", question_with_history)

    bot_message = create_chain(question_with_history)

    history += [[user_message, bot_message]]

    return bot_message, history

In [39]:
with gr.Blocks() as demo:
    chatbot = gr.Chatbot(label="Chat with Data")
    msg = gr.Textbox(label="Question", placeholder="Enter your question here")
    clear = gr.Button("Clear")

    def user(user_message, history):
        bot_message, history = extract_data(user_message, history)
        print("LLM Response: ", bot_message)
        return "", history
    
    msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False)
    clear.click(lambda: None, None, chatbot, queue=False)

demo.queue()
demo.launch(share=True)

Running on local URL:  http://127.0.0.1:7861
Running on public URL: https://2f249df0e24359ef95.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)




Input to LLM:
 User: what are the number of passengers

LLM Response:  The number of passengers in the "Passenger Data" table is 891.
