In [90]:
import pandas as pd
from langchain_community.utilities import SQLDatabase
import sqlite3
from langchain.chains import create_sql_query_chain
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool
from langchain.prompts import PromptTemplate
from langchain_groq import ChatGroq
from langchain_community.agent_toolkits import create_sql_agent
from langchain.callbacks.base import BaseCallbackHandler
from sqlalchemy import create_engine, inspect,text
import plotly.express as px
from dotenv import load_dotenv
import os

In [2]:
df=pd.read_csv("data/titanic.csv")
df.info()
df.head(5)

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 887 entries, 0 to 886
Data columns (total 8 columns):
 #   Column                   Non-Null Count  Dtype  
---  ------                   --------------  -----  
 0   Survived                 887 non-null    int64  
 1   Pclass                   887 non-null    int64  
 2   Name                     887 non-null    object 
 3   Sex                      887 non-null    object 
 4   Age                      887 non-null    float64
 5   Siblings/Spouses Aboard  887 non-null    int64  
 6   Parents/Children Aboard  887 non-null    int64  
 7   Fare                     887 non-null    float64
dtypes: float64(2), int64(4), object(2)
memory usage: 55.6+ KB


Unnamed: 0,Survived,Pclass,Name,Sex,Age,Siblings/Spouses Aboard,Parents/Children Aboard,Fare
0,0,3,Mr. Owen Harris Braund,male,22.0,1,0,7.25
1,1,1,Mrs. John Bradley (Florence Briggs Thayer) Cum...,female,38.0,1,0,71.2833
2,1,3,Miss. Laina Heikkinen,female,26.0,0,0,7.925
3,1,1,Mrs. Jacques Heath (Lily May Peel) Futrelle,female,35.0,1,0,53.1
4,0,3,Mr. William Henry Allen,male,35.0,0,0,8.05


***Converting csv to SQL***

In [3]:
engine=create_engine(f"sqlite:///data/sqldb.db")
df.to_sql("titanic", engine, index=False)

ValueError: Table 'titanic' already exists.

In [72]:
def get_original_column_names(engine, table_name):
    # Use SQLAlchemy's inspect method to get the table columns
    inspector = inspect(engine)
    columns = inspector.get_columns(table_name)
    # Extract column names
    column_names = [column['name'] for column in columns]
    return column_names

In [4]:
db = SQLDatabase(engine=engine)
print(db.dialect)
print(db.get_usable_table_names())
db.run("SELECT * FROM titanic WHERE Age < 2;")

sqlite
['titanic']


"[(1, 2, 'Master. Alden Gates Caldwell', 'male', 0.83, 0, 2, 29.0), (0, 3, 'Master. Eino Viljami Panula', 'male', 1.0, 4, 1, 39.6875), (1, 3, 'Miss. Eleanor Ileen Johnson', 'female', 1.0, 1, 1, 11.1333), (1, 2, 'Master. Richard F Becker', 'male', 1.0, 2, 1, 39.0), (1, 1, 'Master. Hudson Trevor Allison', 'male', 0.92, 1, 2, 151.55), (1, 3, 'Miss. Maria Nakid', 'female', 1.0, 0, 2, 15.7417), (0, 3, 'Master. Sidney Leonard Goodwin', 'male', 1.0, 5, 2, 46.9), (1, 3, 'Miss. Helene Barbara Baclini', 'female', 0.75, 2, 1, 19.2583), (1, 3, 'Miss. Eugenie Baclini', 'female', 0.75, 2, 1, 19.2583), (1, 2, 'Master. Viljo Hamalainen', 'male', 0.67, 1, 1, 14.5), (1, 3, 'Master. Bertram Vere Dean', 'male', 1.0, 1, 2, 20.575), (1, 3, 'Master. Assad Alexander Thomas', 'male', 0.42, 0, 1, 8.5167), (1, 2, 'Master. Andre Mallet', 'male', 1.0, 0, 2, 37.0042), (1, 2, 'Master. George Sibley Richards', 'male', 0.83, 1, 1, 18.75)]"

I***nteracting with database***

In [5]:
#load_ext dotenv
load_dotenv()
print(load_dotenv())

True


In [6]:
#Load the LLM
llm = ChatGroq(
        model="llama-3.1-70b-versatile",
        temperature=0,
        max_tokens=None,
        timeout=None,
        max_retries=5,
        # other params...
        )

In [8]:
table_description = "The table is the titanic dataset which contains demographics and passenger information from 891 of the 2224 passengers and crew on board the Titanic.\
                    Contains the following fields:\
                    -Survived: Indicates if the passanger survived or not. Values are	0 = No, 1 = Yes\
                    -Pclass:	Ticket class. Values are1 = 1st, 2 = 2nd, 3 = 3rd\
                    -Sex: Gender or sex of the passenger\
                    -Age: Age of the passenger\
                    -Sibsp: number of siblings and/or spouses\
                    -Parch: number of parents andor childrend\
                    -Ticket: ticket number\
                    -Fare: Passenger fare\
                    -Cabin:	Cabin number\
                    -Embarke:port where the passenger embarked. Values areC = Cherbourg, Q = Queenstown,S = Southampton"

In [107]:
prompt = PromptTemplate(
        template=(
            "You are an intelligent SQL assistant. The following is a description of a SQL table:\n"
            "{table_description}\n"
            "Based on this description, your task is to generate SQL queries in response to user prompts.\n"
            "Follow these instructions:\n"
            "1. Carefully understand the user's request and identify the fields involved.\n"
            "2. Generate an SQL query that retrieves the requested information, ensuring the syntax is correct.\n"
            "3. Use aggregate functions (e.g., SUM, AVG) if needed, and apply GROUP BY or ORDER BY as appropriate.\n"
            "4. When using aggregate functions (e.g., SUM, AVG), include the fields GROUPED BY or ORDERED BY in the SELECT statement of the query.\n"
            "5. When using aggregate functions (e.g., SUM, AVG), create an alias based in the SELECT statement of the query.\n"
            "6. If you are asked to look for a passenger name, use LIKE function not WHERE.\n"
            "7. Ensure that your query returns meaningful and readable results.\n"
            "8. If the user requests, structure the results for generating plots (e.g., data suitable for bar or line charts).\n"
            "9. Do not assume the table structure; rely on the description provided.\n"
        )
    )

In [99]:
#To extract the generated SQL code. From https://github.com/langchain-ai/langchain/discussions/16624
class SQLHandler(BaseCallbackHandler):
    def __init__(self):
        print("ESTOOO")
        self.sql_result = None

    def on_agent_action(self, action, **kwargs):
        """Run on agent action. if the tool being used is sql_db_query,
         it means we're submitting the sql and we can 
         record it as the final sql"""

        if action.tool == "sql_db_query":
            self.sql_result = action.tool_input

In [108]:
agent_executor = create_sql_agent(llm, db=db, verbose=True,prompt_template=prompt)

In [109]:
print(prompt)

input_variables=['table_description'] input_types={} partial_variables={} template="You are an intelligent SQL assistant. The following is a description of a SQL table:\n{table_description}\nBased on this description, your task is to generate SQL queries in response to user prompts.\nFollow these instructions:\n1. Carefully understand the user's request and identify the fields involved.\n2. Generate an SQL query that retrieves the requested information, ensuring the syntax is correct.\n3. Use aggregate functions (e.g., SUM, AVG) if needed, and apply GROUP BY or ORDER BY as appropriate.\n4. When using aggregate functions (e.g., SUM, AVG), include the fields GROUPED BY or ORDERED BY in the SELECT statement of the query.\n5. When using aggregate functions (e.g., SUM, AVG), create an alias based in the SELECT statement of the query.\n6. If you are asked to look for a passenger name, use LIKE function not WHERE.\n7. Ensure that your query returns meaningful and readable results.\n8. If the 

In [31]:
agent_executor.invoke(
    {
        "input": "Tell me more about passenger Anders"
    }
)



[1m> Entering new SQL Agent Executor chain...[0m
[32;1m[1;3mThought: I should look at the tables in the database to see what I can query.  Then I should query the schema of the most relevant tables.
Action: sql_db_list_tables
Action Input: [0m[38;5;200m[1;3mtitanic[0m[32;1m[1;3mAction: sql_db_schema
Action Input: titanic[0m[33;1m[1;3m
CREATE TABLE titanic (
	"Survived" BIGINT, 
	"Pclass" BIGINT, 
	"Name" TEXT, 
	"Sex" TEXT, 
	"Age" FLOAT, 
	"Siblings/Spouses Aboard" BIGINT, 
	"Parents/Children Aboard" BIGINT, 
	"Fare" FLOAT
)

/*
3 rows from titanic table:
Survived	Pclass	Name	Sex	Age	Siblings/Spouses Aboard	Parents/Children Aboard	Fare
0	3	Mr. Owen Harris Braund	male	22.0	1	0	7.25
1	1	Mrs. John Bradley (Florence Briggs Thayer) Cumings	female	38.0	1	0	71.2833
1	3	Miss. Laina Heikkinen	female	26.0	0	0	7.925
*/[0m[32;1m[1;3mThought: I now have the schema of the titanic table. I should query the database for passengers with the name Anders.
Action: sql_db_query_checker
A

{'input': 'Tell me more about passenger Anders',
 'output': 'There are several passengers with the name Anders, including Mr. Anders Johan Andersson, Miss. Erna Alexandra Andersson, Mr. Anders Vilhelm Gustafsson, Miss. Ellis Anna Maria Andersson, Mr. August Edvard Andersson, Miss. Carla Christine Nielsine Andersen-Jensen, Mr. Harry Anderson, Mr. William Anderson Walker, Miss. Ingeborg Constanzia Andersson, and Miss. Sigrid Elisabeth Andersson.'}

In [110]:
handler = SQLHandler()

ESTOOO


In [115]:
response=agent_executor.invoke({"input": "what's the average age of survivors by gender?Sort results in ascending order"},
                              {"callbacks": [handler]})



[1m> Entering new SQL Agent Executor chain...[0m
[32;1m[1;3mThought: I should look at the tables in the database to see what I can query.  Then I should query the schema of the most relevant tables.
Action: sql_db_list_tables
Action Input: [0m[38;5;200m[1;3mtitanic[0m[32;1m[1;3mAction: sql_db_schema
Action Input: titanic[0m[33;1m[1;3m
CREATE TABLE titanic (
	"Survived" BIGINT, 
	"Pclass" BIGINT, 
	"Name" TEXT, 
	"Sex" TEXT, 
	"Age" FLOAT, 
	"Siblings/Spouses Aboard" BIGINT, 
	"Parents/Children Aboard" BIGINT, 
	"Fare" FLOAT
)

/*
3 rows from titanic table:
Survived	Pclass	Name	Sex	Age	Siblings/Spouses Aboard	Parents/Children Aboard	Fare
0	3	Mr. Owen Harris Braund	male	22.0	1	0	7.25
1	1	Mrs. John Bradley (Florence Briggs Thayer) Cumings	female	38.0	1	0	71.2833
1	3	Miss. Laina Heikkinen	female	26.0	0	0	7.925
*/[0m[32;1m[1;3mThought: I now have the schema of the titanic table. I should query the average age of survivors by gender. I will use the AVG function to calculate

In [116]:
query=handler.sql_result

In [117]:
query

'SELECT AVG(CAST(Age AS REAL)) AS AverageAge, Sex FROM titanic WHERE Survived = 1 AND Age IS NOT NULL GROUP BY Sex ORDER BY AverageAge ASC'

In [86]:
print(response['output'])

The average age of survivors by gender is approximately 28.87 for males and 27.43 for females.


In [119]:
conn= sqlite3.connect('data/sqldb.db')

In [120]:
df_result = pd.read_sql(query,con=conn)

In [121]:
df_result

Unnamed: 0,AverageAge,Sex
0,27.428165,male
1,28.866953,female


In [122]:
fig = px.bar(df_result)
fig.show(renderer='iframe')

ValueError: Plotly Express cannot process wide-form data with columns of different type.

Unnamed: 0,AVG(CAST(Age AS REAL))
0,28.866953
1,27.428165
