### SQL Agent running functions

In [9]:
import sqlite3
import google.generativeai as genai
import os
from google.api_core import retry
import logging

The history saving thread hit an unexpected error (OperationalError('attempt to write a readonly database')).History will not be written to the database.


In [10]:
# Read the Gemini API key from the environment so it is never hardcoded in the notebook.
key = os.getenv('GOOGLE_API_KEY')

# Fail fast with an actionable message: without this check a missing key is
# passed on as None and the problem only surfaces later, at the first request.
if not key:
    raise RuntimeError(
        "The GOOGLE_API_KEY environment variable is not set. "
        "Export it before starting the notebook kernel."
    )

genai.configure(api_key=key)

In [11]:
# Path of the SQLite database file, relative to the notebook's working directory.
db_file = "chinook.db"

# sqlite3.connect() silently creates a new, empty database when the file does
# not exist, which would leave the agent exploring a schema with no tables.
# Check up front so a wrong working directory is reported immediately.
if not os.path.isfile(db_file):
    raise FileNotFoundError(
        f"SQLite database file '{db_file}' was not found in '{os.getcwd()}'."
    )

# Single shared connection used by every tool function below.
db_example = sqlite3.connect(db_file)

In [12]:
def list_tables() -> list[str]:
    """
    Lists the names of all tables in the connected SQLite database.
    Returns: list[str]: The name of every table found in the database.
    """
    # Log message
    print(' - DB calling function: list_tables')

    # sqlite_master holds one row per schema object; keep only the tables.
    table_rows = db_example.cursor().execute(
        "SELECT name FROM sqlite_master WHERE type='table';"
    ).fetchall()

    # Every row is a 1-tuple, so unpack it to get the bare table name.
    return [name for (name,) in table_rows]

In [13]:
def print_schema(table_name: str) -> list[tuple[str, str]]:
    """
    Retrieves the schema (column names and types) of a specified table.
    Args: table_name (str): The name of the table whose schema is to be retrieved.
          Surrounding quotes or square brackets around the name are accepted and ignored.
    Returns: list[tuple[str, str]]: A list of tuples containing column names and their corresponding data types.
             The list is empty when no table with that name exists.
    """
    # Log message
    print(' - DB calling function: print_schema')

    # The name arrives from the model, so treat it as untrusted text.
    name = str(table_name).strip()

    # Accept names that were already quoted by the caller ("albums", 'albums',
    # `albums`, [albums]) by removing a single layer of matching delimiters.
    if len(name) >= 2 and (name[0], name[-1]) in {('"', '"'), ("'", "'"), ('`', '`'), ('[', ']')}:
        name = name[1:-1]

    # The table name is spliced into the statement text, so quote it as an SQL
    # identifier (doubling any embedded double quote). This keeps names with
    # spaces or punctuation working and prevents the value from being parsed
    # as anything other than a single identifier.
    quoted_name = '"' + name.replace('"', '""') + '"'

    cursor = db_example.cursor()
    try:
        # table_info yields one row of metadata per column of the table
        cursor.execute(f"PRAGMA table_info({quoted_name});")
        schema = cursor.fetchall()
    finally:
        # Release the cursor even when the statement fails
        cursor.close()

    # Return a list of tuples with each column's name (index 1) and type (index 2).
    return [(col[1], col[2]) for col in schema]

In [14]:
def execute_query(sql: str) -> list[list[str]]:
    """
    Executes a given SQL query and returns the results.

    Data-modifying statements (for example INSERT, UPDATE or DELETE) are committed
    when they succeed, so the change is persisted in the database file. When a
    statement fails, any partial change is rolled back.

    Args: sql (str): The SQL query to be executed.

    Returns: A list of rows, one tuple of column values per row (empty for statements that produce no rows).
             Returns an error message as a string if an exception occurs.
    """

    # Log message
    print(' - DB CALL: execute_query')

    try:
        cursor = db_example.cursor()
        try:
            # Execute the provided SQL query and collect every resulting row
            cursor.execute(sql)
            rows = cursor.fetchall()
        finally:
            # Release the cursor whether or not the statement succeeded
            cursor.close()

        # sqlite3 opens an implicit transaction for data-modifying statements;
        # without a commit the change is discarded when the connection closes.
        # For read-only statements this is a no-op.
        db_example.commit()
        return rows

    except Exception as e:
        # Undo whatever the failed statement left pending. The rollback itself
        # can fail (e.g. the connection is already closed); that must not mask
        # the original error that is reported back to the agent.
        try:
            db_example.rollback()
        except sqlite3.Error:
            pass

        # Return a formatted error message if an exception occurs. To be used by the agent.
        return f"There was an error when trying to run the query used as input. " \
               f"Please update the query using a different approach and try again. Error: {e}"

In [15]:
# Bind all tool functions together so the model can call them.
# NOTE(review): presumably the SDK derives each tool's description from the
# function's signature and docstring — confirm before rewording the docstrings.
db_tools = [list_tables, print_schema, execute_query]

# System prompt: describes the agent's role, the three tools, the rule that
# non-SELECT statements need the user's approval first, and how to recover
# when execute_query returns an error string instead of rows.
instruction = """You are a helpful chatbot that can interact with an SQL database to attend to users requests. 
You will take the users questions and turn them into SQL queries using the tools
available. Once you have the information you need, you will address the user's request 
(responding to questions or updating the database when requested to).
Use list_tables to see what tables are present, print_schema to understand
the schema, and execute_query to issue SQL queries. Anytime before using any query that is not a SELECT query (example UPDATE or DELETE queries), 
you should output the query you intend to run and ask the user's permission to run it. 

Please add you final query as a string to the output of your answer.

NOTE: The function execute_query was created to print out the error when the query is invalid. If you get a string saying the query is wrong, 
please use list_tables, print_schema function, create a new query using a different approach and try again.

"""

# Model configured with the database tools and the system prompt above.
model = genai.GenerativeModel(
    "models/gemini-1.5-flash-latest", tools=db_tools, system_instruction=instruction
)

# Request options reused by every send_message call below: retry the request
# when the error matches google.api_core's if_transient_error predicate.
retry_policy = {"retry": retry.Retry(predicate=retry.if_transient_error)}



In [16]:
# Start a chat with automatic function calling enabled. All questions below are
# sent through this one session, so later prompts can build on earlier turns.
# The ' - DB ...' lines in the recorded outputs are printed by the tool functions
# while a message is being answered.
chat = model.start_chat(enable_automatic_function_calling=True)

#### Simplest query

In [17]:
# Baseline question: a count over a single table, with no join required.
resp = chat.send_message("How many records there are in the albums table?", request_options=retry_policy)
print(resp.text)

 - DB CALL: execute_query
There are 347 records in the albums table.

Final query: `SELECT COUNT(*) FROM albums`



#### Requires a join or sub-query

In [18]:
# The artist's name is not stored in the albums table, so the agent has to
# relate albums to artists (join or sub-query) to answer this.
resp = chat.send_message("How many albums does the artist Black Sabbath has?", request_options=retry_policy)
print(resp.text)

 - DB CALL: execute_query
 - DB calling function: list_tables
 - DB calling function: print_schema
 - DB calling function: print_schema
 - DB CALL: execute_query
Black Sabbath has 2 albums in the database.

Final query: `SELECT COUNT(T1.AlbumId) FROM albums AS T1 INNER JOIN artists AS T2 ON T1.ArtistId  =  T2.ArtistId WHERE T2.Name = 'Black Sabbath'`



#### Requires connecting 3 tables

In [19]:
# The artist name is written in lower case and misspelled ("black sabath"),
# so this also checks that the agent resolves it to the stored name.
resp = chat.send_message("How many different genres does black sabath have?", request_options=retry_policy)
print(resp.text)

 - DB calling function: print_schema
 - DB calling function: print_schema
 - DB CALL: execute_query
Black Sabbath has 1 distinct genre in the database.

Final query: `SELECT COUNT(DISTINCT T3.Name) FROM artists AS T1 INNER JOIN albums AS T2 ON T1.ArtistId = T2.ArtistId INNER JOIN tracks AS T4 ON T2.AlbumId = T4.AlbumId INNER JOIN genres AS T3 ON T4.GenreId = T3.GenreId WHERE T1.Name = 'Black Sabbath'`



In [20]:
# Follow-up phrased relative to the previous turn ("How about ..."): it relies
# on the chat session's history to reuse the genre question for another artist.
resp = chat.send_message("How about Apocalyptica, how many different genres do they have?", request_options=retry_policy)
print(resp.text)


 - DB CALL: execute_query
 - DB CALL: execute_query
Apocalyptica has 1 distinct genre in the database.

Final query: `SELECT COUNT(DISTINCT T3.Name) FROM artists AS T1 INNER JOIN albums AS T2 ON T1.ArtistId = T2.ArtistId INNER JOIN tracks AS T4 ON T2.AlbumId = T4.AlbumId INNER JOIN genres AS T3 ON T4.GenreId = T3.GenreId WHERE T1.Name = 'Apocalyptica'`



#### General questions
##### Questions that would require some thinking

In [21]:
# Ambiguous business question: the agent must decide what "sold albums" means.
# NOTE(review): the recorded query sums invoice_items.Quantity per artist, which
# counts tracks sold rather than albums — confirm this is the intended reading.
resp = chat.send_message("Which artist has sold the most albuns?", request_options=retry_policy)
print(resp.text)

 - DB calling function: print_schema
 - DB calling function: print_schema
 - DB CALL: execute_query
Iron Maiden has sold the most albums with a total of 140 albums sold.

Final query: `SELECT T1.Name, SUM(T4.Quantity) AS TotalAlbumsSold FROM artists AS T1 INNER JOIN albums AS T2 ON T1.ArtistId = T2.ArtistId INNER JOIN tracks AS T3 ON T2.AlbumId = T3.AlbumId INNER JOIN invoice_items AS T4 ON T3.TrackId = T4.TrackId GROUP BY T1.Name ORDER BY TotalAlbumsSold DESC LIMIT 1`



In [22]:
# Revenue-based variant of the previous question.
# NOTE(review): the recorded query joins albums.AlbumId to invoices.InvoiceId,
# whereas the other recorded queries link sales via tracks and invoice_items.
# The join looks wrong, so the reported total should be verified.
resp = chat.send_message("Which artist has the biggest total sales?", request_options=retry_policy)
print(resp.text)

 - DB CALL: execute_query
Iron Maiden has the biggest total sales, with a total of 128.86.

Final query: `SELECT T1.Name,SUM(T3.Total) FROM artists AS T1 INNER JOIN albums AS T2 ON T1.ArtistId = T2.ArtistId INNER JOIN invoices AS T3 ON T2.AlbumId = T3.InvoiceId GROUP BY T1.Name ORDER BY SUM(T3.Total) DESC LIMIT 1`




In [23]:
# Longest chain of the notebook: artist -> albums -> tracks -> invoice lines ->
# invoices -> customers -> employees.
# NOTE(review): the recorded query sums invoices.Total after joining through
# invoice_items, so an invoice total is presumably counted once per matching
# line and may include non-Iron-Maiden items — verify the reported figure.
resp = chat.send_message("Which employee is responsible for the highest total sales for Iron Maiden?", request_options=retry_policy)
print(resp.text)

 - DB CALL: execute_query
 - DB CALL: execute_query
 - DB CALL: execute_query
Margaret Park is the employee responsible for the highest total sales for Iron Maiden, with a total of 788.04.

Final query: `SELECT T6.FirstName, T6.LastName, SUM(T4.Total) FROM artists AS T1 INNER JOIN albums AS T2 ON T1.ArtistId = T2.ArtistId INNER JOIN tracks AS T3 ON T2.AlbumId = T3.AlbumId INNER JOIN invoice_items AS T5 ON T3.TrackId = T5.TrackId INNER JOIN invoices AS T4 ON T5.InvoiceId = T4.InvoiceId INNER JOIN customers AS T7 ON T4.CustomerId = T7.CustomerId INNER JOIN employees AS T6 ON T7.SupportRepId = T6.EmployeeId WHERE T1.Name = 'Iron Maiden' GROUP BY T6.EmployeeId ORDER BY SUM(T4.Total) DESC LIMIT 1`

