https://github.com/Ne0bliviscaris/Ollama-SQLite-RAG/blob/main/test_field.ipynb
https://github.com/TheAILearner/Langchain-Agents/blob/main/Dynamic%20Few-shot%20Prompting%20with%20Llama%203.ipynb

In [3]:
from langchain_community.utilities import SQLDatabase


In [4]:
# connecting to sqldb 
db_path = "sqlite:///" + r"D:\Generative AI\generative ai project\QA_SQL_Tabular_data\data\sqldb.db"

db = SQLDatabase.from_uri(db_path)

In [5]:
db

<langchain_community.utilities.sql_database.SQLDatabase at 0x1b20668b740>

In [6]:
# check the connection
print(db.dialect)
print(db.get_usable_table_names())
db.run("SELECT * FROM Artist LIMIT 10;")

sqlite
['Album', 'Artist', 'Customer', 'Employee', 'Genre', 'Invoice', 'InvoiceLine', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']


"[(1, 'AC/DC'), (2, 'Accept'), (3, 'Aerosmith'), (4, 'Alanis Morissette'), (5, 'Alice In Chains'), (6, 'Antônio Carlos Jobim'), (7, 'Apocalyptica'), (8, 'Audioslave'), (9, 'BackBeat'), (10, 'Billy Cobham')]"

## TRY LLAMA MODEL

In [8]:
from langchain_community.llms import Ollama
# messages = [
#     {"role": "system", "content": str(
#         "You are a helpful assistant"
#     )},
#     {"role": "user", "content": str("hello")}
# ]

client = Ollama(base_url='http://localhost:11434', model='llama3', temperature=0.0)


# response = client.invoke(
#     messages
# )

# print(response)


In [7]:
# from langchain.chains import create_sql_query_chain
# chain = create_sql_query_chain(client, db)
# STATIC_MESSAGE = 'You are an experienced data scientist. Answer the question verbally and add the SQL query to prove it. Make the query one line so it is easy to test. Use the simplest and most efficient query'
# # Main question
# QUESTION = "How many employees are there"
# # Generate response
# response = chain.invoke({f"question": f"{QUESTION} {STATIC_MESSAGE}"})
# print(response)

# # # Extract SQL query from the response
# # from modules.ai.rag_pipeline import extract_query_from_model
# # extracted_query = extract_query_from_model(response)
# # print(f"Extracted query:\n {extracted_query}")


In [9]:
def extract_query_from_model(model_response):
    """Get the query from the model response"""
    import re

    match = re.search(r"SELECT.*?(?:;|$)", model_response, re.IGNORECASE)
    if match:
        sql_query = match.group(0).strip()
        # Add a semicolon if missing
        if not sql_query.endswith(";"):
            sql_query += ";"
        return sql_query
    else:
        return "No SQL query found in the response."

In [13]:
from langchain_core.prompts import PromptTemplate
from langchain.chains import create_sql_query_chain
schema = db.get_table_info()

dialect="sqlite"
table_info=schema
question = "How many customer are there in total?"

template = '''Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.
Use the following format:

Question: "Question here"
SQLQuery: "SQL Query to run"
SQLResult: "Result of the SQLQuery"
Answer: "Final answer here"

Only use the following tables:

{table_info}.

Question: {input}
TopK: {top_k}'''
prompt = PromptTemplate.from_template(template)
chain = create_sql_query_chain(client, db, prompt)

response = chain.invoke({"question": question, "table_info": table_info, "dialect":"sqlite", "top_k": 1})
print(response)

# Extract SQL query from the response
extracted_query = extract_query_from_model(response)
print(f"Extracted query:\n {extracted_query}")

To answer this question, we need to create a SQL query that counts the number of unique customers in the "Customer" table (which is not provided in the given schema). However, based on the given tables and data, we can assume that there are no additional customer information available. Therefore, we can use the following SQL query:

```sql
SELECT COUNT(DISTINCT CustomerId) AS TotalCustomers
FROM Invoice;
```

This query counts the number of unique "CustomerId" values in the "Invoice" table, which represents the total number of customers.

Note: The actual SQL query might be different depending on the actual schema and data.
Extracted query:
 No SQL query found in the response.


In [14]:
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool

execute_query = QuerySQLDataBaseTool(db=db)

chain = chain | execute_query

chain.invoke({"question": "How many customer are there"})

'Error: (sqlite3.OperationalError) near "To": syntax error\n[SQL: To answer this question, we need to create a SQL query that counts the number of unique customers in the "Customer" table. Since you didn\'t provide the "Customer" table schema, I\'ll assume it has a column named "CustomerId".\n\nHere\'s the SQL query:\n\n```sql\nSELECT COUNT(DISTINCT CustomerId) AS TotalCustomers\nFROM Invoice;\n```\n\nThis query uses the `COUNT(DISTINCT)` function to count the number of unique customers. The `DISTINCT` keyword ensures that each customer is only counted once, even if they have multiple invoices.\n\nIf you want to get the top 5 customers with the most invoices, you can use a subquery:\n\n```sql\nSELECT CustomerId, COUNT(*) AS TotalInvoices\nFROM Invoice\nGROUP BY CustomerId\nORDER BY TotalInvoices DESC\nLIMIT 5;\n```\n\nThis query groups the invoices by customer and counts the number of invoices for each customer. The `ORDER BY` clause sorts the results in descending order (most invoices

In [16]:
from operator import itemgetter
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough

answer_prompt = PromptTemplate.from_template(
    """Given the following user question, corresponding SQL query, and SQL result, answer the user question.

Question: {question}
SQL Query: {query}
SQL Result: {result}
Answer: """
)

answer = answer_prompt | client | StrOutputParser()
chain = (
    RunnablePassthrough.assign(query=chain).assign(
        result=itemgetter("query") | execute_query
    )
    | answer
)

chain.invoke({"question": "How many employees are there"})

'There are no employees to count. The provided database schema does not include an "Employee" table, so it is not possible to determine the number of employees. If you meant something else, please clarify the question and I\'ll be happy to help!'

I'm happy to help! However, I need more information about the company or organization you're referring to. Could you please provide me with the name of the company or organization, and if possible, the industry or sector it operates in? That way, I can try to find the most accurate answer for you.

Alternatively, if you're asking a general question about employees in the world, according to the International Labor Organization (ILO), there were approximately 3.3 billion employed people worldwide as of 2020. However, this number may have changed since then due to various factors such as economic conditions and global events.

Please let me know if you have any further questions or if there's anything else I can help you with!


In [23]:
from langchain_community.agent_toolkits import create_sql_agent
from langchain.agents import AgentType

agent_executor = create_sql_agent(client, db = db, agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose = True)
agent_executor.invoke("How many employees are there?")



[1m> Entering new SQL Agent Executor chain...[0m
[32;1m[1;3mLet's start by listing all the tables in the database.

Action: sql_db_list_tables
Action Input: (empty string[0m[38;5;200m[1;3mAlbum, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track[0m[32;1m[1;3mThought: Now that I have a list of all the tables in the database, I should look at the schema of the most relevant tables to see what columns they have. Since the question is asking about employees, I'll focus on the Employee table.

Action: sql_db_schema
Action Input: Employee[0m[33;1m[1;3m
CREATE TABLE "Employee" (
	"EmployeeId" INTEGER NOT NULL, 
	"LastName" NVARCHAR(20) NOT NULL, 
	"FirstName" NVARCHAR(20) NOT NULL, 
	"Title" NVARCHAR(30), 
	"ReportsTo" INTEGER, 
	"BirthDate" DATETIME, 
	"HireDate" DATETIME, 
	"Address" NVARCHAR(70), 
	"City" NVARCHAR(40), 
	"State" NVARCHAR(40), 
	"Country" NVARCHAR(40), 
	"PostalCode" NVARCHAR(10), 
	"Phone" NVARCHAR(24), 
	"Fax"

{'input': 'How many employees are there?', 'output': 'There are 8 employees.'}