In [1]:
# Attach Google Drive to this Colab runtime so files stored there can be read.
from google.colab import drive

# Directory inside the runtime where Drive is mounted.
DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Mounted at /content/drive


## **Install LangChain and Google Gemini related libraries**

In [2]:
!pip install -q -U langchain==0.1.2
!pip install -q -U google-generativeai==0.5.2
!pip install -q -U langchain-google-genai==1.0.3

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m803.6/803.6 kB[0m [31m7.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.0/2.0 MB[0m [31m52.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m302.9/302.9 kB[0m [31m24.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m49.3/49.3 kB[0m [31m4.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m49.3/49.3 kB[0m [31m4.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.0/2.0 MB[0m [31m28.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.0/2.0 MB[0m [31m44.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.0/2.0 MB[0m [31m44.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━

# Get Gemini API Key from Secrets
Set a GEMINI_KEY secret in Google Colab and read it here to run the Google Gemini LLM. You can get a Google Gemini API key from the following link: https://makersuite.google.com/app/apikey

In [7]:
from google.colab import userdata

# Name of the Colab secret that holds the Gemini API key.
GEMINI_SECRET_NAME = 'GEMINI_KEY'

# Read the key from Colab's secret store instead of hardcoding it in the notebook.
GOOGLE_API_KEY = userdata.get(GEMINI_SECRET_NAME)

Copy chinook.db to your Google Drive

Now that chinook.db is in our Drive directory, we can interface with it using the SQLAlchemy-driven SQLDatabase class.

Location of chinook.db
/content/drive/MyDrive/Colab Notebooks/chinook.db

In [8]:
from langchain_community.utilities import SQLDatabase

# SQLAlchemy URI of the SQLite file stored on the mounted Google Drive.
CHINOOK_DB_URI = "sqlite:////content/drive/MyDrive/Colab Notebooks/chinook.db"
db = SQLDatabase.from_uri(CHINOOK_DB_URI)

# Quick sanity check: show the SQL dialect and the tables LangChain can see,
# then display a few rows as the cell output.
print(db.dialect, db.get_usable_table_names(), sep="\n")
db.run("SELECT * FROM Artists LIMIT 10;")

sqlite
['albums', 'artists', 'customers', 'employees', 'genres', 'invoice_items', 'invoices', 'media_types', 'playlist_track', 'playlists', 'tracks']


"[(1, 'AC/DC'), (2, 'Accept'), (3, 'Aerosmith'), (4, 'Alanis Morissette'), (5, 'Alice In Chains'), (6, 'Antônio Carlos Jobim'), (7, 'Apocalyptica'), (8, 'Audioslave'), (9, 'BackBeat'), (10, 'Billy Cobham')]"

# Chain
Let’s create a simple chain that takes a question, turns it into a SQL query, executes the query, and uses the result to answer the original question.

# Convert question to SQL query
The first step in a SQL chain or agent is to take the user input and convert it to a SQL query. LangChain comes with a built-in chain for this: create_sql_query_chain

In [12]:
import re
from langchain.chains import create_sql_query_chain
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.prompts import ChatPromptTemplate

# Gemini chat model used for every chain and agent below.
# temperature=0.0 is set to keep sampling randomness to a minimum.
llm = ChatGoogleGenerativeAI(model="gemini-pro", google_api_key=GOOGLE_API_KEY, convert_system_message_to_human=True, temperature=0.0)

# Built-in chain that converts a natural-language question into a SQL query for `db`.
chain = create_sql_query_chain(llm, db)
response = chain.invoke({"question": "How many employees are there"})

# The model may wrap the query in a Markdown code fence. Match the opening
# fence with any (or no) language tag and the closing fence, tolerating
# surrounding whitespace, so that variations such as a bare fence, "sqlite",
# or a trailing newline are stripped as well.
regex_pattern = r'^\s*```[A-Za-z]*[ \t]*\n|\n[ \t]*```\s*$'

# Remove the fence and any leftover surrounding whitespace to get plain SQL text.
modified_response = re.sub(regex_pattern, '', response).strip()
print(modified_response)



SELECT COUNT(*) AS "Number of Employees"
FROM employees;


In [14]:
# Execute the SQL text held in `modified_response` and display the raw result.
# NOTE(review): this runs model-generated SQL as-is; consider a read-only
# database connection before pointing this at anything other than a demo DB.
db.run(modified_response)

'[(8,)]'

# Execute SQL query
Now that we’ve generated a SQL query, we’ll want to execute it.

We can use the QuerySQLDataBaseTool to easily add query execution to our chain:

In [21]:
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool
from langchain_core.prompts import PromptTemplate

# Custom prompt for query generation. It asks the model for plain SQL text
# (no Markdown), so the output can be handed directly to the query tool.
template = '''Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the {top_k} answer.
Use the following format:

Question: "Question here"
"SQL Query to run"
SQLResult: "Result of the SQLQuery"
Answer: "Final answer here"

Only use the following tables:

{table_info}.

Provide SQL query as simple string without any markdown.

Question: {input}'''

# Query-writing step: question -> SQL text, driven by the custom prompt.
prompt = PromptTemplate.from_template(template)
write_query = create_sql_query_chain(llm, db, prompt=prompt)

# Query-running step: SQL text -> result from the database.
execute_query = QuerySQLDataBaseTool(db=db)

# Feed the generated SQL straight into the execution tool.
chain = write_query.pipe(execute_query)
chain.invoke({"question": "How many employees are there"})



'[(8,)]'

# Answer the question
Now that we’ve got a way to automatically generate and execute queries, we just need to combine the original question and SQL query result to generate a final answer.

In [22]:
from operator import itemgetter

from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough

# Prompt for the final step: the question, the generated SQL and the SQL
# result are filled in, and the model is asked to phrase the answer.
answer_prompt = PromptTemplate.from_template(
"""Given the following user question, corresponding SQL query, and SQL result, answer the user question.

Question: {question}
{query}
SQL Result: {result}
Answer: """
)

# Prompt -> model -> plain string.
answer = answer_prompt | llm | StrOutputParser()

# Step 1: keep the incoming dict and add the generated SQL under "query".
with_query = RunnablePassthrough.assign(query=write_query)

# Step 2: run that SQL and add its output under "result".
with_query_and_result = with_query.assign(
    result=itemgetter("query") | execute_query
)

# Step 3: turn question + query + result into the final answer text.
chain = with_query_and_result | answer

chain.invoke({"question": "How many employees are there"})



'There are 8 employees.'

# Agents
LangChain has an SQL Agent which provides a more flexible way of interacting with SQL databases. The main advantages of using the SQL Agent are:

- It can answer questions based on the database’s schema as well as on the database’s content (like describing a specific table).
- It can recover from errors by running a generated query, catching the traceback and regenerating it correctly.
- It can answer questions that require multiple dependent queries.
- It will save tokens by only considering the schema from relevant tables.

To initialize the agent, we use the create_sql_agent function. This agent contains the SQLDatabaseToolkit, which contains tools to:

- Create and execute queries
- Check query syntax
- Retrieve table descriptions

In [23]:
# Both the agent factory and the toolkit are imported from
# langchain_community, rather than taking the toolkit from the legacy
# `langchain.agents.agent_toolkits` re-export.
from langchain_community.agent_toolkits import create_sql_agent, SQLDatabaseToolkit
from langchain.agents.agent_types import AgentType

# Toolkit that gives the agent its SQL tools for `db`; the agent trace in this
# notebook shows sql_db_list_tables, sql_db_schema and sql_db_query being used.
sql_toolkit = SQLDatabaseToolkit(db=db, llm=llm)

# verbose=True prints each Action / Observation step while the agent runs.
agent_executor = create_sql_agent(
    llm,
    toolkit=sql_toolkit,
    agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
    verbose=True,
)

In [24]:
# A two-part question: the agent has to aggregate sales per country and then
# identify the top one.
sales_question = "List the total sales per country. Which country's customers spent the most?"
agent_executor.invoke({"input": sales_question})



[1m> Entering new SQL Agent Executor chain...[0m




[32;1m[1;3mAction: sql_db_list_tables
Action Input: [0m[38;5;200m[1;3malbums, artists, customers, employees, genres, invoice_items, invoices, media_types, playlist_track, playlists, tracks[0m



[32;1m[1;3mAction: sql_db_query
Action Input: SELECT country, SUM(total) AS total_sales FROM invoices GROUP BY country ORDER BY total_sales DESC[0m[36;1m[1;3mError: (sqlite3.OperationalError) no such column: country
[SQL: SELECT country, SUM(total) AS total_sales FROM invoices GROUP BY country ORDER BY total_sales DESC]
(Background on this error at: https://sqlalche.me/e/20/e3q8)[0m



[32;1m[1;3mAction: sql_db_schema
Action Input: invoices[0m[33;1m[1;3m
CREATE TABLE invoices (
	"InvoiceId" INTEGER NOT NULL, 
	"CustomerId" INTEGER NOT NULL, 
	"InvoiceDate" DATETIME NOT NULL, 
	"BillingAddress" NVARCHAR(70), 
	"BillingCity" NVARCHAR(40), 
	"BillingState" NVARCHAR(40), 
	"BillingCountry" NVARCHAR(40), 
	"BillingPostalCode" NVARCHAR(10), 
	"Total" NUMERIC(10, 2) NOT NULL, 
	PRIMARY KEY ("InvoiceId"), 
	FOREIGN KEY("CustomerId") REFERENCES customers ("CustomerId")
)

/*
3 rows from invoices table:
InvoiceId	CustomerId	InvoiceDate	BillingAddress	BillingCity	BillingState	BillingCountry	BillingPostalCode	Total
1	2	2009-01-01 00:00:00	Theodor-Heuss-Straße 34	Stuttgart	None	Germany	70174	1.98
2	4	2009-01-02 00:00:00	Ullevålsveien 14	Oslo	None	Norway	0171	3.96
3	8	2009-01-03 00:00:00	Grétrystraat 63	Brussels	None	Belgium	1000	5.94
*/[0m



[32;1m[1;3mAction: sql_db_query
Action Input: SELECT country, SUM(total) AS total_sales FROM invoices GROUP BY country ORDER BY total_sales DESC[0m[36;1m[1;3mError: (sqlite3.OperationalError) no such column: country
[SQL: SELECT country, SUM(total) AS total_sales FROM invoices GROUP BY country ORDER BY total_sales DESC]
(Background on this error at: https://sqlalche.me/e/20/e3q8)[0m



[32;1m[1;3mAction: sql_db_schema
Action Input: customers[0m[33;1m[1;3m
CREATE TABLE customers (
	"CustomerId" INTEGER NOT NULL, 
	"FirstName" NVARCHAR(40) NOT NULL, 
	"LastName" NVARCHAR(20) NOT NULL, 
	"Company" NVARCHAR(80), 
	"Address" NVARCHAR(70), 
	"City" NVARCHAR(40), 
	"State" NVARCHAR(40), 
	"Country" NVARCHAR(40), 
	"PostalCode" NVARCHAR(10), 
	"Phone" NVARCHAR(24), 
	"Fax" NVARCHAR(24), 
	"Email" NVARCHAR(60) NOT NULL, 
	"SupportRepId" INTEGER, 
	PRIMARY KEY ("CustomerId"), 
	FOREIGN KEY("SupportRepId") REFERENCES employees ("EmployeeId")
)

/*
3 rows from customers table:
CustomerId	FirstName	LastName	Company	Address	City	State	Country	PostalCode	Phone	Fax	Email	SupportRepId
1	Luís	Gonçalves	Embraer - Empresa Brasileira de Aeronáutica S.A.	Av. Brigadeiro Faria Lima, 2170	São José dos Campos	SP	Brazil	12227-000	+55 (12) 3923-5555	+55 (12) 3923-5566	luisg@embraer.com.br	3
2	Leonie	Köhler	None	Theodor-Heuss-Straße 34	Stuttgart	None	Germany	70174	+49 0711 2842222	None	leon



[32;1m[1;3mAction: sql_db_query
Action Input: SELECT c.country, SUM(i.total) AS total_sales FROM customers c JOIN invoices i ON c.customerid = i.customerid GROUP BY c.country ORDER BY total_sales DESC[0m[36;1m[1;3m[('USA', 523.0600000000004), ('Canada', 303.96), ('France', 195.09999999999994), ('Brazil', 190.1), ('Germany', 156.48), ('United Kingdom', 112.85999999999999), ('Czech Republic', 90.24), ('Portugal', 77.24), ('India', 75.25999999999999), ('Chile', 46.62), ('Ireland', 45.62), ('Hungary', 45.62), ('Austria', 42.62), ('Finland', 41.620000000000005), ('Netherlands', 40.62), ('Norway', 39.62), ('Sweden', 38.620000000000005), ('Poland', 37.620000000000005), ('Italy', 37.620000000000005), ('Denmark', 37.620000000000005), ('Australia', 37.620000000000005), ('Argentina', 37.620000000000005), ('Spain', 37.62), ('Belgium', 37.62)][0m



[32;1m[1;3mFinal Answer: The country with the highest total sales is the USA, with $523.06 in total sales.[0m

[1m> Finished chain.[0m


{'input': "List the total sales per country. Which country's customers spent the most?",
 'output': 'The country with the highest total sales is the USA, with $523.06 in total sales.'}