In [None]:
import os

import google.generativeai as genai
import pandas as pd
from dotenv import load_dotenv

# Show full cell contents (e.g. long product titles) instead of truncating them.
pd.set_option("display.max_colwidth", None)

# Secrets live in a local .env file; load them into the process environment
# rather than hard-coding the key in the notebook.
load_dotenv()
GOOGLE_API_KEY = os.environ.get("GOOGLE_API_KEY")
genai.configure(api_key=GOOGLE_API_KEY)

# Raw product catalogue; a sample of it is copied into SQLite further below.
DATA_PATH = "../data/amazon_products.csv"
data = pd.read_csv(DATA_PATH)
data.head(5)

In [13]:
import sqlite3


class DbClient:
    """Thin wrapper around a lazily opened SQLite connection.

    Besides schema helpers (``create_*_table``, ``setup_tables``,
    ``drop_tables``), it exposes ``list_tables``, ``describe_table`` and
    ``execute_query``, which the notebook hands to the Gemini model as tools.
    NOTE(review): the SDK presumably derives tool descriptions from those
    methods' signatures/docstrings, so keep them accurate.
    """

    def __init__(self, uri):
        self.uri = uri
        self._connection = None
        # NOTE(review): never read inside this class; kept for backward compatibility.
        self.cursor = None

    def get_cursor(self):
        """Return a new cursor on the shared (lazily opened) connection."""
        return self._get_connection().cursor()

    def commit(self):
        """Commit the current transaction on the shared connection."""
        self._get_connection().commit()

    def _get_connection(self):
        """Open the SQLite connection on first use and reuse it afterwards."""
        if self._connection is None:
            self._connection = sqlite3.connect(self.uri)
        return self._connection

    def _execute_ddl(self, query):
        """Run a single schema statement and commit it."""
        self.get_cursor().execute(query)
        self.commit()

    def create_products_table(self):
        """Create the ``products`` table if it does not already exist."""
        self._execute_ddl(
            """CREATE TABLE IF NOT EXISTS products (
                product_id VARCHAR(255) PRIMARY KEY,
                product_title VARCHAR(255) NOT NULL,
                price DECIMAL(10, 2) NOT NULL
        )"""
        )

    def create_staff_table(self):
        """Create the ``staff`` table if it does not already exist."""
        self._execute_ddl(
            """CREATE TABLE IF NOT EXISTS staff (
                staff_id INTEGER PRIMARY KEY AUTOINCREMENT,
                first_name VARCHAR(255) NOT NULL,
                last_name VARCHAR(255) NOT NULL
        )"""
        )

    def create_orders_table(self):
        """Create the ``orders`` table if it does not already exist."""
        # product_id must use the same declared type as products.product_id
        # (VARCHAR), which it references; it was previously INTEGER.
        self._execute_ddl(
            """CREATE TABLE IF NOT EXISTS orders (
                order_id INTEGER PRIMARY KEY AUTOINCREMENT,
                customer_name VARCHAR(255) NOT NULL,
                staff_id INTEGER NOT NULL,
                product_id VARCHAR(255) NOT NULL,
                FOREIGN KEY (staff_id) REFERENCES staff (staff_id),
                FOREIGN KEY (product_id) REFERENCES products (product_id)
        )"""
        )

    def list_tables(self) -> list[str]:
        """Retrieve the names of all tables in the database."""
        print(" - DB CALL: list_tables")

        cursor = self.get_cursor()
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
        # SQLite reserves the "sqlite_" prefix for internal tables (e.g. sqlite_sequence).
        return [name for (name,) in cursor.fetchall() if not name.startswith("sqlite_")]

    def describe_table(self, table_name: str) -> list[tuple[str, str]]:
        """Look up the schema of a table.

        Args:
            table_name: Name of the table to describe.

        Returns:
            A list of (column name, column type) pairs in column order. The list
            is empty if the table does not exist.
        """
        print(" - DB CALL: describe_table")

        cursor = self.get_cursor()
        # table_name is supplied by the model, so it is bound as a parameter of
        # the table-valued pragma rather than interpolated into the SQL text.
        cursor.execute(
            "SELECT name, type FROM pragma_table_info(?) ORDER BY cid;",
            (table_name,),
        )
        return [(name, col_type) for name, col_type in cursor.fetchall()]

    def execute_query(self, sql: str) -> list[list[str]]:
        """Execute an SQL statement and return all resulting rows.

        Args:
            sql: The SQL query to run.

        Returns:
            Every row produced by the query.
        """
        # NOTE(review): the statement runs verbatim, so model-issued DML/DDL
        # would execute too; consider a read-only connection for tool use.
        # fetchall() returns tuples; the annotation is left as-is because it
        # feeds the tool schema.
        print(" - DB CALL: execute_query")

        cursor = self.get_cursor()
        cursor.execute(sql)
        return cursor.fetchall()

    def setup_tables(self):
        """Run every ``create_*_table`` method of this instance."""
        # Discover the creators on *self* (previously the module-level
        # ``db_client`` global was inspected). Order does not matter: SQLite
        # does not require a FOREIGN KEY target to exist at CREATE time.
        for name in dir(self):
            if name.startswith("create_") and name.endswith("_table"):
                getattr(self, name)()

    def drop_tables(self):
        """Drop every table reported by ``list_tables`` and commit."""
        cursor = self.get_cursor()
        for table_name in self.list_tables():
            # Quote the identifier so unusual names cannot break the statement.
            quoted = '"' + table_name.replace('"', '""') + '"'
            cursor.execute(f"DROP TABLE IF EXISTS {quoted}")
        self.commit()


# SQLite file backing the chatbot; DbClient only opens it on first use.
DB_URI = "sample.db"
db_client = DbClient(uri=DB_URI)

In [None]:
# Rebuild the schema from scratch so re-running this cell is idempotent.
db_client.drop_tables()
db_client.setup_tables()

# Number of product rows copied from the CSV into the `products` table.
SAMPLE_SIZE = 100

staff_query = """INSERT INTO staff (first_name, last_name) VALUES
    ('Alice', 'Smith'),
    ('Bob', 'Johnson'),
    ('Charlie', 'Williams')"""

# NOTE(review): product ids 1-3 do not match any ASIN in `products`; foreign
# keys are not enforced, so the insert succeeds but joins will find nothing.
orders_query = """INSERT INTO orders (customer_name, staff_id, product_id) VALUES
    ('David Lee', 1, 1),
    ('Emily Chen', 2, 2),
    ('Frank Brown', 1, 3)"""

products_query = "INSERT INTO products (product_id, product_title, price) VALUES (?,?,?)"

# product_title/price are NOT NULL and product_id is the PRIMARY KEY, so drop
# rows with missing values and duplicate ASINs first; a single bad CSV row would
# otherwise abort the whole executemany. itertuples(name=None) yields plain
# tuples of Python scalars that sqlite3 can bind directly.
product_rows = (
    data[["asin", "title", "price"]]
    .dropna()
    .drop_duplicates(subset="asin")
    .head(SAMPLE_SIZE)
    .itertuples(index=False, name=None)
)

cursor = db_client.get_cursor()
try:
    cursor.executemany(products_query, product_rows)
    cursor.execute(staff_query)
    cursor.execute(orders_query)
except Exception:
    # Discard the partial insert instead of leaving an open transaction that a
    # later commit() would silently persist, then surface the real error.
    cursor.connection.rollback()
    raise
db_client.commit()

In [18]:
from google.api_core import retry

# Retry requests whose errors `retry.if_transient_error` classifies as transient.
retry_policy = dict(retry=retry.Retry(predicate=retry.if_transient_error))

# Bound DbClient methods exposed to the model as callable tools.
db_tools = [
    db_client.list_tables,
    db_client.describe_table,
    db_client.execute_query,
]

instruction = """You are a helpful chatbot that can interact with an SQL database for a computer
store. You will take the users questions and turn them into SQL queries using the tools
available. Once you have the information you need, you will answer the user's question using
the data returned. Use list_tables to see what tables are present, describe_table to understand
the schema, and execute_query to issue an SQL SELECT query."""

model = genai.GenerativeModel(
    "models/gemini-1.5-flash-latest",
    tools=db_tools,
    system_instruction=instruction,
)

# Automatic function calling lets the SDK invoke db_tools on the model's behalf.
chat = model.start_chat(enable_automatic_function_calling=True)

In [None]:
question = "What is the cheapest product?"
resp = chat.send_message(question, request_options=retry_policy)
print(resp.text)

In [None]:
# Follow-up question: "it" refers to the product from the previous turn in this chat.
follow_up = "and how much is it?"
resp = chat.send_message(follow_up, request_options=retry_policy)
print(resp.text)

In [None]:
import textwrap


def print_chat_turns(chat):
    """Dump ``chat.history`` to stdout, one block per turn.

    Each turn starts with its capitalised role. Each part is printed in one of
    three ways, checked in order: as quoted text, as a
    ``name(key=value, ...)`` function call, or as an indented function response.
    A blank line separates turns.
    """
    for turn in chat.history:
        print(f"{turn.role.capitalize()}:")

        for part in turn.parts:
            text = part.text
            if text:
                print(f'  "{text}"')
                continue

            call = part.function_call
            if call:
                rendered_args = ", ".join(f"{name}={value}" for name, value in call.args.items())
                print(f"  Function call: {call.name}({rendered_args})")
                continue

            response = part.function_response
            if response:
                print("  Function response:")
                print(textwrap.indent(str(response), "    "))

        print()


# Show the full conversation, including the tool calls made on the model's behalf.
print_chat_turns(chat=chat)