In [100]:
# %pip install --upgrade --quiet  langchain langchain-community

In [101]:
from langchain_community.utilities import SQLDatabase

db = SQLDatabase.from_uri("sqlite:///Chinook.db", ignore_tables = ['Album', 'Artist',  'Genre' , 'MediaType', 'Playlist', 'PlaylistTrack',], )
#db.run("SELECT * FROM Artist LIMIT 10;")
print(f"This is the dialect of the database {db.dialect}")
print(f"This is the usable table names {db.get_usable_table_names()}")
print(f"This is the table info of the database {db.table_info}")

This is the dialect of the database sqlite
This is the usable table names ['Customer', 'Employee', 'Invoice', 'InvoiceLine', 'Track']
This is the table info of the database 
CREATE TABLE "Customer" (
	"CustomerId" INTEGER NOT NULL, 
	"FirstName" NVARCHAR(40) NOT NULL, 
	"LastName" NVARCHAR(20) NOT NULL, 
	"Company" NVARCHAR(80), 
	"Address" NVARCHAR(70), 
	"City" NVARCHAR(40), 
	"State" NVARCHAR(40), 
	"Country" NVARCHAR(40), 
	"PostalCode" NVARCHAR(10), 
	"Phone" NVARCHAR(24), 
	"Fax" NVARCHAR(24), 
	"Email" NVARCHAR(60) NOT NULL, 
	"SupportRepId" INTEGER, 
	PRIMARY KEY ("CustomerId"), 
	FOREIGN KEY("SupportRepId") REFERENCES "Employee" ("EmployeeId")
)

/*
3 rows from Customer table:
CustomerId	FirstName	LastName	Company	Address	City	State	Country	PostalCode	Phone	Fax	Email	SupportRepId
1	Luís	Gonçalves	Embraer - Empresa Brasileira de Aeronáutica S.A.	Av. Brigadeiro Faria Lima, 2170	São José dos Campos	SP	Brazil	12227-000	+55 (12) 3923-5555	+55 (12) 3923-5566	luisg@embraer.com.br	3
2

In [102]:
import os 



# Getting API key from another environment file
from dotenv import load_dotenv, find_dotenv
load_dotenv('api_keys.env')
# _ = load_dotenv(find_dotenv()) # read

api_key = os.environ['GEMINI_API_KEY']

In [103]:
from langchain.chains import create_sql_query_chain
from langchain_google_genai import GoogleGenerativeAI

# Large model to be used
llm = GoogleGenerativeAI(model="models/text-bison-001", google_api_key=api_key)
#SQL query chain 
generate_query = create_sql_query_chain(llm, db)
query = generate_query.invoke({"question": "How many employees are there"})
query

'SELECT COUNT(*) FROM Employee'

In [104]:
# To execute the query 
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool
execute_query = QuerySQLDataBaseTool(db=db)
generate_query = create_sql_query_chain(llm, db)
chain = generate_query | execute_query
query = chain.invoke({"question": "How many employees are there"})
print(query)


[(8,)]


In [105]:
chain.get_prompts()[0].pretty_print()

You are a SQLite expert. Given an input question, first create a syntactically correct SQLite query to run, then look at the results of the query and return the answer to the input question.
Unless the user specifies in the question a specific number of examples to obtain, query for at most 5 results using the LIMIT clause as per SQLite. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use date('now') function to get the current date, if the question involves "today".

Use the following format:

Question: Question here
SQLQuery: SQL Query to run
SQLResult: Result

In [106]:
from operator import itemgetter

from  langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough

answer_prompt = PromptTemplate.from_template(
    """Given the following user question, corresponding SQL query, and SQL result, 
    answer the user question in a contructive sentence that mentions the question asked.
    Question: {question}
    SQL Query: {query}
    SQL  Result: {result}
    Answer: """
)

rephrase_answer = answer_prompt | llm | StrOutputParser()
chain = (
    RunnablePassthrough.assign(query=generate_query).assign(
        result=itemgetter("query") | execute_query
    )
    | rephrase_answer
)

answer = chain.invoke({"question": "How many orders are there"})
print(answer)

There are 412 orders.


In [107]:
answer = chain.invoke({"question": "How many artist are there"})
print(answer)

There are 275 artists.


In [108]:
example = [
    {
    "input": "Hey, can you tell me what track number  For Those About To Rock (We Salute You) is?",
    "query": "What is the TrackId for Name = 'For Those About To Rock (We Salute You)'?"
  },
  
    {"input": "Show details of the invoice for track ID 4", 
     "query": "SELECT i.InvoiceId, i.CustomerId, c.FirstName, c.LastName, il.UnitPrice, il.Quantity, i.Total FROM Invoice i JOIN InvoiceLine il ON i.InvoiceId = il.InvoiceId JOIN Customer c ON i.CustomerId = c.CustomerId WHERE il.TrackId = 4"
     },

{"input": "I want to find the contact details for customer Leonie Köhler",
  "query": "SELECT * FROM Customer WHERE FirstName = 'Leonie' AND LastName = 'Köhler'"
  },

{"input": "What is the total amount invoiced to customer Embraer - Empresa Brasileira de Aeronáutica S.A.?" , 
 "query": "SELECT SUM(Total) FROM Invoice WHERE CustomerId = (SELECT CustomerId FROM Customer WHERE Company = 'Embraer - Empresa Brasileira de Aeronáutica S.A.')"
 },

{"input": "List all invoices with a total over $4.00", 
 "query": "SELECT * FROM Invoice WHERE Total > 4.00"
 },

{
"input": "Change Luis Gonçalves' email to luisg@embraer.com.br",
"query": "UPDATE Customer SET Email = 'luisg@embraer.com.br' WHERE FirstName = 'Luís' AND LastName = 'Gonçalves';"
}
]

In [109]:
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder,FewShotChatMessagePromptTemplate,PromptTemplate

example_prompt = ChatPromptTemplate.from_messages(
     [
         ("human", "{input}\nSQLQuery:"),
         ("ai", "{query}"),
     ]
 )
few_shot_prompt = FewShotChatMessagePromptTemplate(
     example_prompt=example_prompt,
     examples=example,
     # input_variables=["input","top_k"],
     input_variables=["input"],
 )
print(few_shot_prompt.format(input="How many products are there?"))


Human: Hey, can you tell me what track number  For Those About To Rock (We Salute You) is?
SQLQuery:
AI: What is the TrackId for Name = 'For Those About To Rock (We Salute You)'?
Human: Show details of the invoice for track ID 4
SQLQuery:
AI: SELECT i.InvoiceId, i.CustomerId, c.FirstName, c.LastName, il.UnitPrice, il.Quantity, i.Total FROM Invoice i JOIN InvoiceLine il ON i.InvoiceId = il.InvoiceId JOIN Customer c ON i.CustomerId = c.CustomerId WHERE il.TrackId = 4
Human: I want to find the contact details for customer Leonie Köhler
SQLQuery:
AI: SELECT * FROM Customer WHERE FirstName = 'Leonie' AND LastName = 'Köhler'
Human: What is the total amount invoiced to customer Embraer - Empresa Brasileira de Aeronáutica S.A.?
SQLQuery:
AI: SELECT SUM(Total) FROM Invoice WHERE CustomerId = (SELECT CustomerId FROM Customer WHERE Company = 'Embraer - Empresa Brasileira de Aeronáutica S.A.')
Human: List all invoices with a total over $4.00
SQLQuery:
AI: SELECT * FROM Invoice WHERE Total > 4.00
H

In [110]:
from langchain_google_genai import GoogleGenerativeAIEmbeddings

embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001", google_api_key=api_key)

In [111]:
%pip install langchain_openai




In [112]:
from langchain_community.vectorstores import Chroma
from langchain_core.example_selectors import SemanticSimilarityExampleSelector
from langchain_openai import OpenAIEmbeddings
vectorstore = Chroma()
vectorstore.delete_collection()
example_selector = SemanticSimilarityExampleSelector.from_examples(
    example,
    embeddings,
    vectorstore,
    # k=2,
    input_keys=["input"],
)
example_selector.select_examples({"input": "Who is highest paying customer"})

[{'input': 'Show details of the invoice for track ID 4',
  'query': 'SELECT i.InvoiceId, i.CustomerId, c.FirstName, c.LastName, il.UnitPrice, il.Quantity, i.Total FROM Invoice i JOIN InvoiceLine il ON i.InvoiceId = il.InvoiceId JOIN Customer c ON i.CustomerId = c.CustomerId WHERE il.TrackId = 4'},
 {'input': 'What is the total amount invoiced to customer Embraer - Empresa Brasileira de Aeronáutica S.A.?',
  'query': "SELECT SUM(Total) FROM Invoice WHERE CustomerId = (SELECT CustomerId FROM Customer WHERE Company = 'Embraer - Empresa Brasileira de Aeronáutica S.A.')"},
 {'input': 'List all invoices with a total over $4.00',
  'query': 'SELECT * FROM Invoice WHERE Total > 4.00'},
 {'input': 'I want to find the contact details for customer Leonie Köhler',
  'query': "SELECT * FROM Customer WHERE FirstName = 'Leonie' AND LastName = 'Köhler'"}]

In [113]:
few_shot_prompt = FewShotChatMessagePromptTemplate(
     example_prompt=example_prompt,
     example_selector=example_selector,
     input_variables=["input", "top_k"],
 )
print(few_shot_prompt.format(input="I am looking for the most paying customer"))

Human: Show details of the invoice for track ID 4
SQLQuery:
AI: SELECT i.InvoiceId, i.CustomerId, c.FirstName, c.LastName, il.UnitPrice, il.Quantity, i.Total FROM Invoice i JOIN InvoiceLine il ON i.InvoiceId = il.InvoiceId JOIN Customer c ON i.CustomerId = c.CustomerId WHERE il.TrackId = 4
Human: I want to find the contact details for customer Leonie Köhler
SQLQuery:
AI: SELECT * FROM Customer WHERE FirstName = 'Leonie' AND LastName = 'Köhler'
Human: List all invoices with a total over $4.00
SQLQuery:
AI: SELECT * FROM Invoice WHERE Total > 4.00
Human: What is the total amount invoiced to customer Embraer - Empresa Brasileira de Aeronáutica S.A.?
SQLQuery:
AI: SELECT SUM(Total) FROM Invoice WHERE CustomerId = (SELECT CustomerId FROM Customer WHERE Company = 'Embraer - Empresa Brasileira de Aeronáutica S.A.')


In [126]:
final_prompt = ChatPromptTemplate.from_messages(
    [
    ("system", "You are a SQLite expert. Given an input question, create a symtactically correct SQLite query to run. Unless otherwise specified. \n\nHere is the relevant table info: {table_info}\n\nBelow are a number of examples of questions and their corresponding SQL queries."),
    few_shot_prompt, 
    ("human", "{input}"),
    ]
)

# print(final_prompt.format(input="",table_info="some table info"))

In [127]:
generate_query = create_sql_query_chain(llm, db,final_prompt)
chain = (
    RunnablePassthrough.assign(query=generate_query).assign(
        result=itemgetter("query") | execute_query
    )
    | rephrase_answer
)
chain.invoke({"question": "WHich customer has the highest invoice"})

'The question is "WHich customer has the highest invoice". The SQL query is "AI: SELECT CustomerId, MAX(Total) FROM Invoice GROUP BY CustomerId". The SQL result is "Error: (sqlite3.OperationalError) near \\"AI\\": syntax error". The correct SQL query is "SELECT CustomerId, MAX(Total) FROM Invoice GROUP BY CustomerId". The SQL result is "652". The answer is "Customer 652 has the highest invoice".'

In [116]:
import pandas as pd

data = {'table': ['Customer', 'Employee', 'Invoice', 
                 'InvoiceLine', 'Track'],
        'description': [

            'Stores customer information. Includes columns for CustomerId (primary key), name, contact details (address, phone, email), and a foreign key referencing the Employee table for the support representative assigned to the customer. Connects to the Invoice table through the CustomerId foreign key in the Invoice table.',
            'Stores employee information. Includes columns for EmployeeId (primary key), name, title, hire/birth dates, contact details (address, phone, email), and a foreign key referencing another Employee for their manager (if applicable). Connects to the Customer table through the SupportRepId foreign key in the Customer table (one-to-many relationship - one employee can support many customers).',
            'Stores information about customer invoices. Includes columns for InvoiceId (primary key), CustomerId (foreign key referencing the Customer table), InvoiceDate, billing address details, and Total invoice amount. Connects to the Customer table through the CustomerId foreign key. Connects to the InvoiceLine table through the InvoiceId foreign key in the InvoiceLine table (one-to-many relationship - one invoice can have many invoice lines).',
            'Links invoices to specific tracks purchased. Includes columns for InvoiceLineId (primary key), InvoiceId (foreign key referencing the Invoice table), TrackId (foreign key referencing the Track table), UnitPrice per track, and Quantity purchased. Connects to the Invoice table through the InvoiceId foreign key. Connects to the Track table through the TrackId foreign key.',
            'Stores information about individual music tracks. Connects to the InvoiceLine table through the TrackId foreign key in the InvoiceLine table (one-to-many relationship - one track can be included in many invoices). '
            ]
        }

df = pd.DataFrame(data)
df

Unnamed: 0,table,description
0,Customer,Stores customer information. Includes columns ...
1,Employee,Stores employee information. Includes columns ...
2,Invoice,Stores information about customer invoices. In...
3,InvoiceLine,Links invoices to specific tracks purchased. I...
4,Track,Stores information about individual music trac...


In [117]:
# import pandas as pd

# data = {'table': ['Album', 'Artist', 'Customer', 'Employee', 'Genre', 'Invoice', 
#                  'InvoiceLine', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track'],
#         'description': [
#             'Stores information about music albums. Includes columns for AlbumId (primary key), Title (album title), and ArtistId (foreign key referencing the Artist table).',
#             'Stores information about artists. Includes columns for ArtistId (primary key) and Name (artist name).Connects to the Album table through the ArtistId foreign key in the Album table.',
#             'Stores customer information. Includes columns for CustomerId (primary key), name, contact details (address, phone, email), and a foreign key referencing the Employee table for the support representative assigned to the customer. Connects to the Invoice table through the CustomerId foreign key in the Invoice table.',
#             'Stores employee information. Includes columns for EmployeeId (primary key), name, title, hire/birth dates, contact details (address, phone, email), and a foreign key referencing another Employee for their manager (if applicable). Connects to the Customer table through the SupportRepId foreign key in the Customer table (one-to-many relationship - one employee can support many customers).',
#             'Stores musical genre information. Includes columns for GenreId (primary key) and Name (genre name). Connects to the Track table through the GenreId foreign key in the Track table.',
#             'Stores information about customer invoices. Includes columns for InvoiceId (primary key), CustomerId (foreign key referencing the Customer table), InvoiceDate, billing address details, and Total invoice amount. Connects to the Customer table through the CustomerId foreign key. Connects to the InvoiceLine table through the InvoiceId foreign key in the InvoiceLine table (one-to-many relationship - one invoice can have many invoice lines).',
#             'Links invoices to specific tracks purchased. Includes columns for InvoiceLineId (primary key), InvoiceId (foreign key referencing the Invoice table), TrackId (foreign key referencing the Track table), UnitPrice per track, and Quantity purchased. Connects to the Invoice table through the InvoiceId foreign key. Connects to the Track table through the TrackId foreign key.',
#             'Stores different media types (e.g., audio file format).Includes columns for MediaTypeId (primary key) and Name (media type name). Connects to the Track table through the MediaTypeId foreign key in the Track table.',
#             'Stores playlists. Includes columns for PlaylistId (primary key) and Name (playlist name). Connects to the Track table through the many-to-many relationship table PlaylistTrack.',
#             'Links playlists to tracks included in them. Includes foreign keys referencing both PlaylistId and TrackId. Connects the Playlist and Track tables through a many-to-many relationship.',
#             'Stores information about individual music tracks. Includes columns for TrackId (primary key), Name (track title), AlbumId (foreign key referencing the Album table), MediaTypeId (foreign key referencing the MediaType table), GenreId (foreign key referencing the Genre table), composer information, Milliseconds (track duration), Bytes (file size), and UnitPrice. Connects to the Album table through the AlbumId foreign key. Connects to the Genre table through the GenreId foreign key. Connects to the MediaType table through the MediaTypeId foreign key. Connects to the InvoiceLine table through the TrackId foreign key in the InvoiceLine table (one-to-many relationship - one track can be included in many invoices). Connects to the Playlist table through the many-to-many relationship table PlaylistTrack.'
#             ]
#         }

# df = pd.DataFrame(data)
# df

In [118]:
from operator import itemgetter
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain.chains.openai_tools import create_extraction_chain_pydantic
from typing import List


def get_table_details(df):
    # Read the CSV file into a dataframe
    table_description = df
    table_docs = []

    #Iterates over the Dataframe rows to create Document objects
    table_details = ""
    for index, row in table_description.iterrows():
        table_details = table_details + "Table Name: " + row['table'] + "\n" + "Table Description: " + row['description'] + "\n"
    
    return table_details

class Table(BaseModel):
    """ Table in SQL databases"""

    name: str = Field(description="Name of table in SQL database")

table_details = get_table_details(df)
print(table_details)

Table Name: Customer
Table Description: Stores customer information. Includes columns for CustomerId (primary key), name, contact details (address, phone, email), and a foreign key referencing the Employee table for the support representative assigned to the customer. Connects to the Invoice table through the CustomerId foreign key in the Invoice table.
Table Name: Employee
Table Description: Stores employee information. Includes columns for EmployeeId (primary key), name, title, hire/birth dates, contact details (address, phone, email), and a foreign key referencing another Employee for their manager (if applicable). Connects to the Customer table through the SupportRepId foreign key in the Customer table (one-to-many relationship - one employee can support many customers).
Table Name: Invoice
Table Description: Stores information about customer invoices. Includes columns for InvoiceId (primary key), CustomerId (foreign key referencing the Customer table), InvoiceDate, billing address

In [119]:
# from langchain_google_genai import ChatGoogleGenerativeAI
# llm1 = ChatGoogleGenerativeAI(model="gemini-pro",
#                              temperature=0.7,  google_api_key=api_key, convert_system_message_to_human=True)

# table_details_prompt = f"""Return the names of ALL the SQL tables that MIGHT be relevan to the user question. \
#     The tables are:
    
#     {table_details}
# Remember to include ALL POTENTIALLY RELEVANT tables, even if you are not sure that they are needed.
#     """

# table_chin = create_extraction_chain_pydantic(Table, llm1, system_message=table_details_prompt)
# tables = table_chin.invoke({"input": "give us details of customer and their invoice count"})
# tables

In [120]:
final_prompt = ChatPromptTemplate.from_messages(
    [
    ("system", "You are a SQLite expert. Given an input question, create a symtactically correct SQLite query to run. Unless otherwise specified. \n\nHere is the relevant table info: {table_info}\n\nBelow are a number of examples of questions and their corresponding SQL queries."),
    few_shot_prompt, 
    MessagesPlaceholder(variable_name="messages"),
    ("human", "{input}"),
    ]
)

print(final_prompt.format(input="I am looking for the most paying customer?", table_info="some table info", messages=[]))

System: You are a SQLite expert. Given an input question, create a symtactically correct SQLite query to run. Unless otherwise specified. 

Here is the relevant table info: some table info

Below are a number of examples of questions and their corresponding SQL queries.
Human: Show details of the invoice for track ID 4
SQLQuery:
AI: SELECT i.InvoiceId, i.CustomerId, c.FirstName, c.LastName, il.UnitPrice, il.Quantity, i.Total FROM Invoice i JOIN InvoiceLine il ON i.InvoiceId = il.InvoiceId JOIN Customer c ON i.CustomerId = c.CustomerId WHERE il.TrackId = 4
Human: What is the total amount invoiced to customer Embraer - Empresa Brasileira de Aeronáutica S.A.?
SQLQuery:
AI: SELECT SUM(Total) FROM Invoice WHERE CustomerId = (SELECT CustomerId FROM Customer WHERE Company = 'Embraer - Empresa Brasileira de Aeronáutica S.A.')
Human: I want to find the contact details for customer Leonie Köhler
SQLQuery:
AI: SELECT * FROM Customer WHERE FirstName = 'Leonie' AND LastName = 'Köhler'
Human: List a

In [121]:
from langchain.memory import ChatMessageHistory
history = ChatMessageHistory()
generate_query = create_sql_query_chain(llm, db, final_prompt)
chain= (
    RunnablePassthrough.assign(query=generate_query).assign(
        result=itemgetter("query") | execute_query
    )
| rephrase_answer
)

In [122]:
question = "How many customers with order count more than 5"
response = chain.invoke({"question": question,"messages":history.messages})
response

'The query contains an error.'

In [123]:
history.add_user_message(question)
history.add_ai_message(response)
response = chain.invoke({"question": "Can you list there fullnames?", "messages":history.messages})
response

'The query you provided has a syntax error. Please fix the query and try again.'