# SQL Agent & Data Visualisation with CrewAI

- https://github.com/imanoop7/SQL-Agent-using-CrewAI-and-Groq/blob/main/crewai_agent.ipynb
- https://www.crewai.com/

## Imports

In [1]:
from langchain_community.tools.sql_database.tool import (
    InfoSQLDatabaseTool,
    ListSQLDatabaseTool,
    QuerySQLCheckerTool,
    QuerySQLDataBaseTool,
)
from langchain_community.utilities.sql_database import SQLDatabase

from textwrap import dedent

from langchain_openai import ChatOpenAI

from crewai import Agent, Crew, Process, Task
from crewai_tools import tool
from crewai.crews.crew_output import CrewOutput

from dotenv import load_dotenv
import os
import re

In [2]:
# Load variables from a local .env file (e.g. OPENAI_API_KEY) into the
# environment before any client is constructed. `load_dotenv` is imported at
# the top of the notebook but was never called, so keys kept in .env were
# silently ignored. The call is a no-op when no .env file is present.
load_dotenv()

# Change the default LLM: the environment variable sets the model name for
# anything that reads OPENAI_MODEL_NAME, and `llm` is the explicit client
# handed to the agents and to the SQL checker tool below.
os.environ["OPENAI_MODEL_NAME"] = "gpt-4o-mini"
llm = ChatOpenAI(model="gpt-4o-mini")

In [3]:
# Open the SQLite database file that every SQL tool below queries.
database_uri = "sqlite:///example.db"
db = SQLDatabase.from_uri(database_uri)

## Define SQL `tools`

In [4]:
@tool("list_tables")
def list_tables() -> str:
    """List the available tables in the database"""
    # NOTE(review): the docstring above presumably doubles as the tool
    # description shown to the agent, so its wording is left untouched.
    table_lister = ListSQLDatabaseTool(db=db)
    # The underlying tool takes no meaningful input; an empty string is passed.
    return table_lister.invoke("")

# Quick manual check that the tool works outside of a crew run.
list_tables.run()

Using Tool: list_tables


'users'

In [5]:
@tool("tables_schema")
def tables_schema(tables: str) -> str:
    """
    Input is a comma-separated list of tables, output is the schema and sample rows
    for those tables. Be sure that the tables actually exist by calling `list_tables` first!
    Example Input: table1, table2, table3
    """
    # Local is named `schema_tool` so it does not shadow the `tool` decorator.
    schema_tool = InfoSQLDatabaseTool(db=db)
    schema_text = schema_tool.invoke(tables)
    return schema_text

# Quick manual check against the `users` table.
print(tables_schema.run("users"))

Using Tool: tables_schema

CREATE TABLE users (
	id INTEGER, 
	name TEXT NOT NULL, 
	age INTEGER, 
	email TEXT NOT NULL, 
	registration_date TEXT NOT NULL, 
	PRIMARY KEY (id)
)

/*
3 rows from users table:
id	name	age	email	registration_date
1	Jonathan Johnson	51	austinjoshua@example.net	2023-10-05
2	Jason Nicholson	27	martinezjamie@example.net	2023-12-01
3	Chad Little	73	twalter@example.com	2024-06-14
*/


In [6]:
@tool("execute_sql")
def execute_sql(sql_query: str) -> str:
    """Execute a SQL query against the database. Returns the result"""
    # Build the query tool on the shared `db` connection, then run the query.
    query_runner = QuerySQLDataBaseTool(db=db)
    query_result = query_runner.invoke(sql_query)
    return query_result

# Quick manual check: a handful of users older than 40.
execute_sql.run("SELECT * FROM users WHERE age > 40 LIMIT 5")

Using Tool: execute_sql


"[(1, 'Jonathan Johnson', 51, 'austinjoshua@example.net', '2023-10-05'), (3, 'Chad Little', 73, 'twalter@example.com', '2024-06-14'), (6, 'Patrick Wood', 64, 'marywheeler@example.org', '2024-05-20'), (8, 'Elizabeth Bentley', 43, 'gardnerkayla@example.org', '2024-07-30'), (9, 'John Mullins', 66, 'iwhite@example.com', '2023-11-03')]"

In [7]:
@tool("check_sql")
def check_sql(sql_query: str) -> str:
    """
    Use this tool to double check if your query is correct before executing it. Always use this
    tool before executing a query with `execute_sql`.
    """
    # The checker is LLM-backed, so it needs both the database and the model.
    checker = QuerySQLCheckerTool(db=db, llm=llm)
    payload = {"query": sql_query}
    return checker.invoke(payload)

# Quick manual check with the same query used for `execute_sql`.
check_sql.run("SELECT * FROM users WHERE age > 40 LIMIT 5")

Using Tool: check_sql


'```sql\nSELECT * FROM users WHERE age > 40 LIMIT 5\n```'

## Agents

Create some agents to work together in a crew:
- `sql_dev`: creates and executes SQL queries
- `data_visualiser`: receives data from the `sql_dev` and creates a visualisation in Plotly

In [8]:
# Agent that explores the schema, then validates and runs SQL for the request.
# The backstory tells the model what each tool is for, so it must match the
# tool definitions: `check_sql` validates a query and `execute_sql` runs it.
# (The two descriptions were previously swapped.)
sql_dev = Agent(
    role="Senior Database Developer",
    goal="Construct and execute SQL queries based on a request",
    backstory=dedent(
        """
        You are an experienced database engineer who is master at creating efficient and complex SQL queries.
        You have a deep understanding of how different databases work and how to optimize queries.
        Use the `list_tables` to find available tables.
        Use the `tables_schema` to understand the metadata for the tables.
        Use the `check_sql` to check your queries for correctness.
        Use the `execute_sql` to execute queries against the database.
    """
    ),
    llm=llm,
    tools=[list_tables, tables_schema, execute_sql, check_sql],
    # This agent does its own work and never hands tasks to another agent.
    allow_delegation=False,
)

In [9]:
# Backstory lives in its own variable so the Agent(...) call stays compact.
_visualiser_backstory = dedent(
    """
        You have deep experience with visualizing data using Python and create
        graphs using Plotly.
        Your work is always based on the provided data and is clear,
        easy-to-understand and to the point. You have attention
        to detail and always produce very detailed work (as long as you need).
        """
)

# Agent that turns the SQL developer's result into Plotly code; it has no tools.
data_visualiser = Agent(
    role="Data Visualization Specialist",
    goal="You receive data from the database developer and visualize it",
    backstory=_visualiser_backstory,
    allow_delegation=False,
    llm=llm,
)

## Tasks

In [10]:
# First task: the SQL developer fetches the data the request needs.
# `{query}` is a placeholder filled from the `inputs` passed to `crew.kickoff`.
extract_data = Task(
    agent=sql_dev,
    description="Extract data that is required for the query {query}.",
    expected_output="Database result for the query",
)

In [11]:
# Second task: the visualiser plots the data, receiving the output of
# `extract_data` through `context`.
data_visulaisation = Task(
    agent=data_visualiser,
    context=[extract_data],
    description="Plot the data if required by the query {query}.",
    expected_output="Plotly figure",
)

## Crew

In [12]:
# Assemble the crew: both agents, with tasks run one after another in the
# listed order (data extraction first, visualisation second).
crew = Crew(
    process=Process.sequential,
    agents=[sql_dev, data_visualiser],
    tasks=[extract_data, data_visulaisation],
    # Execution/output options.
    memory=False,
    full_output=True,
    verbose=True,
    output_log_file="crew.log",
)

In [13]:
# The `query` key fills the `{query}` placeholder in both task descriptions.
inputs = dict(query="Plot how many users register each month.")

# Run the whole crew; `result` holds the crew's final output.
result = crew.kickoff(inputs=inputs)

[1m[95m [2024-09-25 21:44:00][DEBUG]: == Working Agent: Senior Database Developer[00m
[1m[95m [2024-09-25 21:44:00][INFO]: == Starting Task: Extract data that is required for the query Plot how many users register each month..[00m
[1m[92m [2024-09-25 21:44:11][DEBUG]: == [Senior Database Developer] Task output: ```
[
    ('2023-09', 166), 
    ('2023-10', 862), 
    ('2023-11', 799), 
    ('2023-12', 842), 
    ('2024-01', 856), 
    ('2024-02', 808), 
    ('2024-03', 867), 
    ('2024-04', 799), 
    ('2024-05', 914), 
    ('2024-06', 801), 
    ('2024-07', 810), 
    ('2024-08', 826), 
    ('2024-09', 650)
]
```

[00m
[1m[95m [2024-09-25 21:44:11][DEBUG]: == Working Agent: Data Visualization Specialist[00m
[1m[95m [2024-09-25 21:44:11][INFO]: == Starting Task: Plot the data if required by the query Plot how many users register each month..[00m
[1m[92m [2024-09-25 21:44:15][DEBUG]: == [Data Visualization Specialist] Task output: ```python
import plotly.graph_objects a

In [14]:
# Custom parser
def extract_python(crew_output: CrewOutput) -> str:
    """Extract the first fenced Python code block from a crew's raw output.

    Parameters:
        crew_output (CrewOutput): Crew result whose ``raw`` attribute holds the
            final text, with Python code embedded between ```python and ``` fences.

    Returns:
        str: The contents of the first ```python block, with leading and
            trailing whitespace stripped.

    Raises:
        ValueError: If the text contains no ```python fenced block.
    """
    text = crew_output.raw

    # Non-greedy group so the match ends at the first closing fence;
    # DOTALL lets "." span the newlines inside the code block.
    match = re.search(r"```python(.*?)```", text, re.DOTALL)
    if match is None:
        # Include the offending text so the caller can see what the crew
        # actually returned (the old code referenced an undefined name here).
        raise ValueError(f"Failed to parse: {text}")

    return match.group(1).strip()

In [15]:
# Pull the first ```python fenced block out of the crew's final output.
python_output_code = extract_python(result)

In [16]:
# Display the extracted code so it can be reviewed before it is executed.
python_output_code

"import plotly.graph_objects as go\n\n# Data provided\ndata = [\n    ('2023-09', 166), \n    ('2023-10', 862), \n    ('2023-11', 799), \n    ('2023-12', 842), \n    ('2024-01', 856), \n    ('2024-02', 808), \n    ('2024-03', 867), \n    ('2024-04', 799), \n    ('2024-05', 914), \n    ('2024-06', 801), \n    ('2024-07', 810), \n    ('2024-08', 826), \n    ('2024-09', 650)\n]\n\n# Extracting the months and user counts\nmonths, user_counts = zip(*data)\n\n# Creating the Plotly figure\nfig = go.Figure()\n\n# Adding a bar chart for user registrations\nfig.add_trace(go.Bar(\n    x=months,\n    y=user_counts,\n    marker_color='royalblue'\n))\n\n# Updating layout for better readability\nfig.update_layout(\n    title='Monthly User Registrations',\n    xaxis_title='Month',\n    yaxis_title='Number of Users Registered',\n    xaxis_tickangle=-45,\n    yaxis=dict(\n        title='Number of Users',\n        gridcolor='lightgrey'\n    ),\n    plot_bgcolor='white'\n)\n\n# Show the figure\nfig.show()"

<h2 style="color: red;">Warning! Running the cell below will execute the (arbitrary) Python code generated by the LLM Crew. Make sure you've checked the code before running.</h2>

In [17]:
# SECURITY: exec() runs the LLM-generated code in this process with no sandboxing.
# Review `python_output_code` (displayed above) before running this cell.
exec(python_output_code)

In [20]:
# Manual cross-check of the crew's numbers: registrations per calendar month.
# Adjacent string literals are concatenated into the same single-line query.
sql_query = (
    "SELECT strftime('%Y-%m', registration_date) AS registration_month, "
    "COUNT(*) AS registrations_count "
    "FROM users "
    "GROUP BY registration_month "
    "ORDER BY registration_month"
)
execute_sql.run(sql_query)

Using Tool: execute_sql


"[('2023-09', 166), ('2023-10', 862), ('2023-11', 799), ('2023-12', 842), ('2024-01', 856), ('2024-02', 808), ('2024-03', 867), ('2024-04', 799), ('2024-05', 914), ('2024-06', 801), ('2024-07', 810), ('2024-08', 826), ('2024-09', 650)]"