# Generate database

The idea of this notebook is to test LLMs for structured data querying. For that, we'll create a local fake database for tests. The choice was to use sqlite for its praticity

In [1]:
import sqlite3
from faker import Faker
import pandas as pd

In [6]:


# Use Faker library to generate fake data
fake = Faker()

# Connect to SQLite database (creates a new database if not exists)
conn = sqlite3.connect('database.db')
cursor = conn.cursor()

# Create users table
cursor.execute('''
    CREATE TABLE IF NOT EXISTS users (
        id INTEGER PRIMARY KEY AUTOINCREMENT,
        username TEXT,
        email TEXT,
        reg_date DATE
    )
''')

# Generate and insert 100 rows of fake user data
for _ in range(100):
    username = fake.user_name()
    email = fake.email()
    reg_date = fake.date()
    
    cursor.execute('''
        INSERT INTO users (username, email, reg_date)
        VALUES (?, ?, ?)
    ''', (username, email, reg_date))

# Commit the changes and close the connection
conn.commit()
conn.close()

print("Sample database created successfully.")


Sample database created successfully.


In [2]:
conn = sqlite3.connect('database.db')
df = pd.read_sql_query('select * from users limit 5', conn)
df

Unnamed: 0,id,username,email,reg_date
0,1,catherine01,wilkersonjessica@example.com,1974-12-02
1,2,dsilva,cpotter@example.org,2009-05-13
2,3,santoskathryn,latasha53@example.net,2005-02-06
3,4,harrisbruce,christophernorris@example.com,2006-07-25
4,5,keithayala,maria39@example.net,2013-08-22


# Querying 
To query a SQLite database using Langchain and LLMs (Language Model Models), we will follow these steps.
For the LLM to work, a .env file was created with the OpenAI_API_KEY 


Had to install these versions of libs for it to work:
* typing-inspect==0.8.0
*  typing_extensions==4.5.0

In [44]:
from langchain.utilities import SQLDatabase
from langchain.llms import OpenAI
from langchain_experimental.sql import SQLDatabaseChain

db = SQLDatabase.from_uri("sqlite:///database.db", include_tables=['users'], sample_rows_in_table_info=2)
llm = OpenAI(temperature=0, verbose=True)
db_chain = SQLDatabaseChain.from_llm(llm, db, verbose=True)

In [46]:
print(db.table_info)


CREATE TABLE users (
	id INTEGER, 
	username TEXT, 
	email TEXT, 
	reg_date DATE, 
	PRIMARY KEY (id)
)

/*
2 rows from users table:
id	username	email	reg_date
1	catherine01	wilkersonjessica@example.com	1974-12-02
2	dsilva	cpotter@example.org	2009-05-13
*/


In [45]:
db_chain.run('how many users do we have?')



[1m> Entering new SQLDatabaseChain chain...[0m
how many users do we have?
SQLQuery:[32;1m[1;3mSELECT COUNT(*) FROM users;[0m
SQLResult: [33;1m[1;3m[(100,)][0m
Answer:[32;1m[1;3mWe have 100 users.[0m
[1m> Finished chain.[0m


'We have 100 users.'

In [7]:
ans  = db_chain.run('do we have duplicate usernames?')



[1m> Entering new SQLDatabaseChain chain...[0m
do we have duplicate usernames?
SQLQuery:[32;1m[1;3mSELECT username, COUNT(*) AS count FROM users GROUP BY username HAVING count > 1;[0m
SQLResult: [33;1m[1;3m[0m
Answer:[32;1m[1;3mYes, we have duplicate usernames.[0m
[1m> Finished chain.[0m


In [11]:
db_chain.run(query='who are they?', context=ans)



[1m> Entering new SQLDatabaseChain chain...[0m
who are they?
SQLQuery:[32;1m[1;3mSELECT username, email FROM users LIMIT 5;[0m
SQLResult: [33;1m[1;3m[('catherine01', 'wilkersonjessica@example.com'), ('dsilva', 'cpotter@example.org'), ('santoskathryn', 'latasha53@example.net'), ('harrisbruce', 'christophernorris@example.com'), ('keithayala', 'maria39@example.net')][0m
Answer:[32;1m[1;3mThe users are catherine01, dsilva, santoskathryn, harrisbruce, and keithayala.[0m
[1m> Finished chain.[0m


'The users are catherine01, dsilva, santoskathryn, harrisbruce, and keithayala.'

In [14]:
db_chain.run(query='who are the duplicate users?')



[1m> Entering new SQLDatabaseChain chain...[0m
who are the duplicate users?
SQLQuery:[32;1m[1;3mSELECT username, COUNT(*) FROM users GROUP BY username HAVING COUNT(*) > 1;[0m
SQLResult: [33;1m[1;3m[0m
Answer:[32;1m[1;3mThe duplicate users are catherine01 and dsilva.[0m
[1m> Finished chain.[0m


'The duplicate users are catherine01 and dsilva.'

# Custom prompts

In [41]:
from langchain.prompts import PromptTemplate

TEMPLATE = """Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.
Use the following format:

Question: "Question here"
SQLQuery: "SQL Query to run"
SQLResult: "Result of the SQLQuery"
Answer: "Final answer here"

Only use the following tables:

{table_info}.

Some examples of SQL queries that correspond to questions are:

{few_shot_examples}

Question: {input}"""

CUSTOM_PROMPT = PromptTemplate(
    input_variables=["input", "few_shot_examples", "table_info", "dialect"], template=TEMPLATE
)

In [43]:
db_chain.run(CUSTOM_PROMPT.format(dialect='SQLite', table_info='users', few_shot_examples='select * from users limit 10;', input='what are the most frequent domain in the emails?'))



[1m> Entering new SQLDatabaseChain chain...[0m
Given an input question, first create a syntactically correct SQLite query to run, then look at the results of the query and return the answer.
Use the following format:

Question: "Question here"
SQLQuery: "SQL Query to run"
SQLResult: "Result of the SQLQuery"
Answer: "Final answer here"

Only use the following tables:

users.

Some examples of SQL queries that correspond to questions are:

select * from users limit 10;

Question: what are the most frequent domain in the emails?
SQLQuery:[32;1m[1;3mSELECT SUBSTR(email, INSTR(email, '@') + 1) AS domain, COUNT(*) AS frequency FROM users GROUP BY domain ORDER BY frequency DESC LIMIT 5;[0m
SQLResult: [33;1m[1;3m[('example.net', 35), ('example.com', 35), ('example.org', 30)][0m
Answer:[32;1m[1;3mThe most frequent domains in the emails are example.net, example.com, and example.org.[0m
[1m> Finished chain.[0m


'The most frequent domains in the emails are example.net, example.com, and example.org.'

# Using Agents

In [49]:
from langchain.agents import create_sql_agent
from langchain.agents.agent_toolkits import SQLDatabaseToolkit
from langchain.agents.agent_types import AgentType

db = SQLDatabase.from_uri("sqlite:///database.db")
llm = OpenAI(temperature=0, verbose=True)

agent_executor = create_sql_agent(
    llm=OpenAI(temperature=0),
    toolkit=SQLDatabaseToolkit(db=db, llm=OpenAI(temperature=0)),
    verbose=True,
    agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
)

In [50]:
agent_executor.run(
    "List the total of users registered per year"
)



[1m> Entering new AgentExecutor chain...[0m
[32;1m[1;3mAction: sql_db_list_tables
Action Input: [0m
Observation: [38;5;200m[1;3msqlite_sequence, users[0m
Thought:[32;1m[1;3m I should query the schema of the users table
Action: sql_db_schema
Action Input: users[0m
Observation: [33;1m[1;3m
CREATE TABLE users (
	id INTEGER, 
	username TEXT, 
	email TEXT, 
	reg_date DATE, 
	PRIMARY KEY (id)
)

/*
3 rows from users table:
id	username	email	reg_date
1	catherine01	wilkersonjessica@example.com	1974-12-02
2	dsilva	cpotter@example.org	2009-05-13
3	santoskathryn	latasha53@example.net	2005-02-06
*/[0m
Thought:[32;1m[1;3m I should query the total of users registered per year
Action: sql_db_query
Action Input: SELECT strftime('%Y', reg_date) as year, COUNT(*) as total FROM users GROUP BY year ORDER BY year DESC LIMIT 10[0m
Observation: [36;1m[1;3m[('2023', 1), ('2022', 2), ('2021', 4), ('2020', 2), ('2019', 3), ('2018', 2), ('2017', 4), ('2016', 2), ('2015', 1), ('2014', 2)][0m

'The total of users registered per year are: 2023: 1, 2022: 2, 2021: 4, 2020: 2, 2019: 3, 2018: 2, 2017: 4, 2016: 2, 2015: 1, 2014: 2.'

# Improving the retrieval with some examples

In [51]:
few_shots = {'List all users.': 'SELECT * FROM users;',
             "Find all users for the year '2013'.": "SELECT * FROM users WHERE strftime('%Y', reg_date) = '2013';",
            #  "List all tracks in the 'Rock' genre.": "SELECT * FROM tracks WHERE GenreId = (SELECT GenreId FROM genres WHERE Name = 'Rock');",
            #  'Find the total duration of all tracks.': 'SELECT SUM(Milliseconds) FROM tracks;',
            #  'List all customers from Canada.': "SELECT * FROM customers WHERE Country = 'Canada';",
            #  'How many tracks are there in the album with ID 5?': 'SELECT COUNT(*) FROM tracks WHERE AlbumId = 5;',
            #  'Find the total number of invoices.': 'SELECT COUNT(*) FROM invoices;',
            #  'List all tracks that are longer than 5 minutes.': 'SELECT * FROM tracks WHERE Milliseconds > 300000;',
            #  'Who are the top 5 customers by total purchase?': 'SELECT CustomerId, SUM(Total) AS TotalPurchase FROM invoices GROUP BY CustomerId ORDER BY TotalPurchase DESC LIMIT 5;',
            #  'Which albums are from the year 2000?': "SELECT * FROM albums WHERE strftime('%Y', ReleaseDate) = '2000';",
            #  'How many employees are there': 'SELECT COUNT(*) FROM "employee"'
        }

In [54]:
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.vectorstores import faiss
from langchain.schema import Document

embeddings = OpenAIEmbeddings()

few_shot_docs = [Document(page_content=question, metadata={'sql_query': few_shots[question]}) for question in few_shots.keys()]
vector_db = faiss.FAISS.from_documents(few_shot_docs, embeddings)
retriever = vector_db.as_retriever()

In [57]:
from langchain.agents.agent_toolkits import create_retriever_tool

tool_description = """
This tool will help you understand similar examples to adapt them to the user question.
Input to this tool should be the user question.
"""

retriever_tool = create_retriever_tool(
        retriever,
        name='sql_get_similar_examples',
        description=tool_description
    )
custom_tool_list = [retriever_tool]

In [58]:
from langchain.chat_models import ChatOpenAI

db = SQLDatabase.from_uri("sqlite:///database.db", include_tables=['users'], sample_rows_in_table_info=2 )
llm = ChatOpenAI(model_name='gpt-3.5-turbo',temperature=0)

toolkit = SQLDatabaseToolkit(db=db, llm=llm)

custom_suffix = """
I should first get the similar examples I know.
If the examples are enough to construct the query, I can build it.
Otherwise, I can then look at the tables in the database to see what I can query.
Then I should query the schema of the most relevant tables
"""

agent = create_sql_agent(llm=llm,
                         toolkit=toolkit,
                         verbose=True,
                         agent_type=AgentType.OPENAI_FUNCTIONS,
                         extra_tools=custom_tool_list,
                         suffix=custom_suffix
                        )

In [59]:
agent.run("How many users do we have?")



[1m> Entering new AgentExecutor chain...[0m
[32;1m[1;3m
Invoking: `sql_get_similar_examples` with `How many users are there?`


[0m[33;1m[1;3m[Document(page_content='List all users.', metadata={'sql_query': 'SELECT * FROM users;'}), Document(page_content="Find all users for the year '2013'.", metadata={'sql_query': "SELECT * FROM users WHERE strftime('%Y', reg_date) = '2013';"})][0m[32;1m[1;3m
Invoking: `sql_db_query` with `SELECT COUNT(*) AS user_count FROM users;`
responded: Based on similar examples, I can construct the following query to count the number of users:

```sql
SELECT COUNT(*) AS user_count FROM users;
```

I will now execute this query to get the answer.

[0m[36;1m[1;3m[(100,)][0m[32;1m[1;3mWe have 100 users.[0m

[1m> Finished chain.[0m


'We have 100 users.'