In [1]:
import pandas as pd 

In [2]:
data = pd.read_csv("./diabetes_chat.csv")
data.head()

Unnamed: 0,bmi,A1c,glucose,diabetic
0,25.19,6.6,140,0
1,27.32,6.6,80,0
2,27.32,5.7,158,0
3,23.45,5.0,155,0
4,20.14,4.8,155,0


In [3]:
len(data), data.dtypes

(100000,
 bmi          float64
  A1c         float64
  glucose       int64
  diabetic      int64
 dtype: object)

In [4]:
# Convert DataFrame to SQLite Database
from sqlite3 import connect

con = connect('diabetes_chat.db')
data.to_sql("Diabetes Data", con, if_exists = 'replace')

  sql.to_sql(


In [5]:
# Create our SQL Database
from langchain_community.utilities import SQLDatabase

db = SQLDatabase.from_uri("sqlite:///diabetes_chat.db", sample_rows_in_table_info=3)
print(db.table_info)


CREATE TABLE "Diabetes Data" (
	"index" INTEGER, 
	bmi REAL, 
	" A1c" REAL, 
	" glucose" INTEGER, 
	" diabetic" INTEGER
)

/*
3 rows from Diabetes Data table:
index	bmi	 A1c	 glucose	 diabetic
0	25.19	6.6	140	0
1	27.32	6.6	80	0
2	27.32	5.7	158	0
*/


In [6]:
# Initialize Language Model from OpenAI
from langchain_openai import ChatOpenAI 
import os

os.environ['OPENAI_API_KEY'] = 'key'

llm = ChatOpenAI(model="gpt-3.5-turbo",
                 temperature=0,
                 openai_api_key=os.environ.get("OPENAI_API_KEY")
                )

llm.config_schema

<bound method Runnable.config_schema of ChatOpenAI(client=<openai.resources.chat.completions.Completions object at 0x7ff710e82820>, async_client=<openai.resources.chat.completions.AsyncCompletions object at 0x7ff710e8af70>, temperature=0.0, openai_api_key=SecretStr('**********'), openai_proxy='')>

In [7]:
# Use SQLDatabase Chain
from langchain_experimental.sql.base import SQLDatabaseChain

sql_db_chain = SQLDatabaseChain.from_llm(llm = llm, db = db, verbose = True)

sql_db_chain.invoke("What was the average BMI?")



[1m> Entering new SQLDatabaseChain chain...[0m
What was the average BMI?
SQLQuery:[32;1m[1;3mSELECT AVG(bmi) AS average_bmi FROM "Diabetes Data"[0m
SQLResult: [33;1m[1;3m[(27.32076709999422,)][0m
Answer:[32;1m[1;3mThe average BMI was 27.32.[0m
[1m> Finished chain.[0m


{'query': 'What was the average BMI?', 'result': 'The average BMI was 27.32.'}

In [8]:
prompt_template = '''
Given an input question, first create a syntactically correct SQLite query to run, then look at the results of the query and return the answer to the input question.
Unless the user specifies in the question a specific number of examples to obtain, query for at most 10 results using the LIMIT clause as per SQLite. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use date('now') function to get the current date, if the question involves "today".

Use the following format:

Question: Question here
SQLQuery: SQL Query to run
SQLResult: Result of the SQLQuery
Answer: Final answer here

Only use the following tables:
CREATE TABLE "Diabetes Data" (
	"bmi" REAL, 
	"A1c" REAL, 
	"glucose" INTEGER, 
	"diabetic" INTEGER
)

Question: {input}
'''

In [9]:
from langchain_core.prompts import PromptTemplate
PROMPT = PromptTemplate.from_template(prompt_template, variables=['input'])

In [10]:
sql_db_chain = SQLDatabaseChain.from_llm(llm = llm, db = db, verbose = True, prompt = PROMPT)

In [11]:
sql_db_chain.invoke("Provide the glucose-wise diabetes rate based on the data.")



[1m> Entering new SQLDatabaseChain chain...[0m
Provide the glucose-wise diabetes rate based on the data.
SQLQuery:[32;1m[1;3mSELECT "glucose", AVG("diabetic") AS "diabetes_rate"
FROM "Diabetes Data"
GROUP BY "glucose"
ORDER BY "glucose" ASC
LIMIT 10;[0m
SQLResult: [33;1m[1;3m[('glucose', 0.0)][0m
Answer:[32;1m[1;3mAnswer: The diabetes rate based on glucose levels is 0.0 for all the available data.[0m
[1m> Finished chain.[0m


{'query': 'Provide the glucose-wise diabetes rate based on the data.',
 'result': 'Answer: The diabetes rate based on glucose levels is 0.0 for all the available data.'}

In [12]:
# Create SQL Query Chain
from langchain.chains import create_sql_query_chain
sql_chain = create_sql_query_chain(llm, db)
sql_chain

RunnableAssign(mapper={
  input: RunnableLambda(...),
  table_info: RunnableLambda(...)
})
| RunnableLambda(lambda x: {k: v for (k, v) in x.items() if k not in ('question', 'table_names_to_use')})
| PromptTemplate(input_variables=['input', 'table_info'], partial_variables={'top_k': '5'}, template='You are a SQLite expert. Given an input question, first create a syntactically correct SQLite query to run, then look at the results of the query and return the answer to the input question.\nUnless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per SQLite. You can order the results to return the most informative data in the database.\nNever query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.\nPay attention to use only the column names you can see in the tables below.

In [13]:
response = sql_chain.invoke({"question": "How many patients were there?"})
response

'SELECT COUNT(*) AS "Total Patients" FROM "Diabetes Data"'

In [14]:
# Run SQL Query on Database
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool

db_execution = QuerySQLDataBaseTool(db = db)
execution_chain = sql_chain | db_execution

response = execution_chain.invoke({"question": "How many patients were there?"})
response

'[(100000,)]'

In [15]:
# Summarizing Final Results
from operator import itemgetter
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_core.prompts import PromptTemplate

template = '''
Given the following user's quesion, corresponsing SQL query and SQL result, answer the user's question.
Question: {question}
SQL Query: {query}
SQL Result: {result}
Answer:
'''
prompt = PromptTemplate.from_template(template)

output = prompt | llm | StrOutputParser()
chain = (RunnablePassthrough.assign(query = sql_chain).assign(result = itemgetter("query") | db_execution) | output)

In [16]:
ch = RunnablePassthrough.assign(query = sql_chain)
ch2 = RunnablePassthrough.assign(query = sql_chain).assign(result = itemgetter("query") | db_execution)

In [17]:
# Showcasing in UI with Gradio
import gradio as gr

template = '''
Given the following user's quesion, corresponsing SQL query and SQL result, answer the user's question.
Question: {question}
SQL Query: {sql_query}
SQL Result: {result}
Answer:
'''
prompt = PromptTemplate.from_template(template)

In [18]:
def create_chain(question):
    db = SQLDatabase.from_uri("sqlite:///diabetes_chat.db", sample_rows_in_table_info = 3)
    sql_chain = create_sql_query_chain(llm, db)
    db_execution = QuerySQLDataBaseTool(db = db)
    output = prompt | llm | StrOutputParser()
    chain = (RunnablePassthrough.assign(sql_query = sql_chain).assign(result = itemgetter("sql_query") | db_execution) | output)

    return chain.invoke({"question": question})

In [19]:
def extract_data(user_message, history):
    question_with_history = ""
    for hist in history[-3:]:
        question_with_history += f"User: {hist[0]}\nAssistant: {hist[1]}\n"
    question_with_history += f"User: {user_message}\n"
    print("Input to LLM:\n", question_with_history)

    bot_message = create_chain(question_with_history)

    history += [[user_message, bot_message]]

    return bot_message, history

In [20]:
with gr.Blocks() as demo:
    chatbot = gr.Chatbot(label="Chat with Data")
    msg = gr.Textbox(label="Question", placeholder="Enter your question here")
    clear = gr.Button("Clear")

    def user(user_message, history):
        bot_message, history = extract_data(user_message, history)
        print("LLM Response: ", bot_message)
        return "", history
    
    msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False)
    clear.click(lambda: None, None, chatbot, queue=False)

demo.queue()
demo.launch(share=True)

Running on local URL:  http://127.0.0.1:7860
Running on public URL: https://b88eaf9b06599a9c1a.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)




Input to LLM:
 User: How many patients are there in the dataset?

LLM Response:  There are 100,000 patients in the dataset.
