In [104]:
# | default_exp integrations.projects.tools

In [105]:
# | exporti
import sqlite3
from typing import Any, List
from langchain.agents import tool, Tool

In [106]:
# | exporti
# Module-level connection shared by the query helpers below. The relative path
# "db.sqlite" is resolved against the current working directory, and sqlite3
# creates the file if it does not exist.
# NOTE(review): sqlite3 connections default to check_same_thread=True, so this
# object can only be used from the thread that created it — confirm the agent
# invokes these tools on that same thread.
conn = sqlite3.connect("db.sqlite")

In [107]:
# | export
@tool
def run_sqlite_query(query: str) -> Any:  # query result
    """executes a sqlite query"""
    # NOTE: langchain's `@tool` sends the docstring above to the model as the tool
    # description, so maintainer notes are kept in comments instead.
    # NOTE(review): `query` is model-generated SQL executed verbatim on the shared
    # connection, so it can modify the database as well as read it. Confirm this is
    # intended, or open the connection read-only.
    c = conn.cursor()
    try:
        c.execute(query)
        return c.fetchall()
    # Report any sqlite failure back to the model so it can correct its query,
    # rather than letting the exception abort the agent run. sqlite3.Warning
    # (raised by older Python versions for multi-statement input) is not a
    # subclass of sqlite3.Error, so it is listed explicitly.
    except (sqlite3.Error, sqlite3.Warning) as e:
        return f"the following error occured {str(e)}"
    finally:
        # Release the cursor on every path; fetchall() has already materialized rows.
        c.close()

In [108]:
# Smoke test: count the rows in `users` by calling the tool directly.
run_sqlite_query("select count(*) from users")

[(2000,)]

In [109]:
# | export


def list_tables() -> str:  # newline-separated table names
    """retrieves the names of the tables in the sqlite database, one per line"""
    # The result is a single string (names joined with "\n"), not a list, hence
    # the `str` return annotation.
    c = conn.cursor()
    try:
        c.execute("SELECT name from sqlite_master WHERE type = 'table';")
        rows = c.fetchall()
    finally:
        # Release the cursor even if the query fails.
        c.close()
    # Skip NULL names defensively before joining.
    return "\n".join(row[0] for row in rows if row[0] is not None)

In [110]:
# Smoke test: list the tables visible through the module-level connection.
list_tables()

'users\naddresses\nproducts\ncarts\norders\norder_products'

In [111]:
# |export
@tool
def describe_tables(table_names: str):  # comma separated list of table_names
    """given a list of tablenames, returns the schema of those tables"""
    # NOTE: langchain's `@tool` sends the docstring above to the model as the tool
    # description and builds the argument schema from the signature, so both are
    # left unchanged; maintainer notes are kept in comments.

    # Accept a ready-made list/tuple or a comma separated string. The sequence
    # check comes first so a list is never probed with the string-only "," test.
    if isinstance(table_names, (list, tuple)):
        raw_names = table_names
    else:
        raw_names = str(table_names).split(",")

    # Models often write "users, orders": strip the padding so each name can match
    # exactly, and drop empty entries (e.g. from a trailing comma or empty input).
    stripped = (name.strip() for name in raw_names)
    names = [name for name in stripped if name]
    if not names:
        return ""

    # Bind the names as parameters instead of splicing them into the SQL. The
    # input is model-generated, and a quote inside a name would otherwise break
    # or rewrite the query.
    placeholders = ",".join("?" * len(names))
    sql = f"SELECT sql from sqlite_master WHERE type ='table' and name in ({placeholders});"

    c = conn.cursor()
    try:
        rows = c.execute(sql, names).fetchall()
    finally:
        # Release the cursor even if the query fails.
        c.close()
    # Views/internal entries can have NULL `sql`; skip them before joining.
    return "\n".join(row[0] for row in rows if row[0] is not None)

In [112]:
# Smoke test: fetch the CREATE TABLE statement for a single table.
describe_tables("orders")

'CREATE TABLE orders (\n    id INTEGER PRIMARY KEY,\n    user_id INTEGER,\n    created TEXT\n    )'

In [113]:
# | hide
import nbdev

# Regenerate the library modules from the notebooks' `# | export` / `# | exporti`
# cells. With no arguments, nbdev_export processes the project's notebooks.
nbdev.nbdev_export()