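### SQL agent with Select AI and a LangChain tool-calling agent
* Reference: Cohere *Vanilla Tool Use* notebook, https://github.com/cohere-ai/notebooks/blob/main/notebooks/agents/Vanilla_Tool_Use.ipynb
* This is an improved version, with better documentation.
* The agent uses Cohere Command R+ on OCI Generative AI and routes each request to one of four tools:
  * SQL generation with Select AI
  * (simulated) RAG
  * Internet search
  * refusal of write operations (DROP, DELETE, INSERT, UPDATE)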

In [1]:
import os

import requests
from IPython.display import Markdown, display

from langchain.tools import Tool

from langchain_community.chat_models import ChatOCIGenAI
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
from langchain_core.messages import ToolMessage

from select_ai_sql_agent import SelectAISQLAgent
from config_reader import ConfigReader

In [2]:
# SETUP
# OCI compartment used for the Generative AI calls.
# Override it with the OCI_COMPARTMENT_ID environment variable.
COMPARTMENT_ID = os.environ.get(
    "OCI_COMPARTMENT_ID",
    "ocid1.compartment.oc1..aaaaaaaaushuwb2evpuf7rcpl4r7ugmqoe7ekmaiik3ra3m7gec3d234eknq",
)

# for now, only Cohere Command R+
LLM_MODEL = "cohere.command-r-plus-08-2024"
ENDPOINT = "https://inference.generativeai.eu-frankfurt-1.oci.oraclecloud.com"

DEBUG = False

# internet search (Tavily)
# The API key is a secret: never hardcode it in the notebook, read it from the
# environment instead (e.g. `export TAVILY_API_KEY=...` before starting Jupyter).
# NOTE(review): a key was previously committed in this cell - rotate it.
TAVILY_API_KEY = os.environ.get("TAVILY_API_KEY")
SEARCH_ENDPOINT = "https://api.tavily.com/search"

if not TAVILY_API_KEY:
    print("WARNING: TAVILY_API_KEY is not set: do_internet_search will fail.")

chat = ChatOCIGenAI(
    auth_type="API_KEY",
    model_id=LLM_MODEL,
    service_endpoint=ENDPOINT,
    compartment_id=COMPARTMENT_ID,
    # temperature 0: deterministic tool selection / answers
    model_kwargs={"temperature": 0, "max_tokens": 2048},
    is_stream=False,
)

# will be added as SystemMessage
preamble = """
## Task & Context
You are an interactive assistant designed to help users by answering their questions and fulfilling their requests. 
Users may ask about a wide range of topics, and your goal is to provide the most accurate and helpful response possible.
To assist you, you have access to several specialized tools that can help you research and generate answers. 
Use these tools effectively to meet the user's needs.

## Tools Available
1. **not_allowed**: Use this tool to handle requests involving operations such as DROP, DELETE, INSERT, or UPDATE in a database. These actions are prohibited.
2. **generate_sql**: Use this tool to generate SQL queries when the request requires reading data from a database.
3. **do_rag**: Use this tool to answer general informational questions through retrieval-augmented generation (RAG).
4. **do_internet_search**: Use this tool only when Internet search is explicitely required or when realtime information are required.

## Style Guide
- Always respond in clear, full sentences unless the user specifies otherwise.
- Use proper grammar and spelling in all responses.
- Strive to be concise yet comprehensive, ensuring your answers fully address the user's needs.
"""

In [3]:
# General configuration, read from ./config.toml (relative to the working directory)
config = ConfigReader("./config.toml")

# SQL agent used by the generate_sql tool: it generates, checks and executes SQL
sql_agent = SelectAISQLAgent(config)

#### Definition of Tools

In [4]:
# Tool: Simulate RAG
def do_rag(query: str):
    """
    Answer a general information request with the chat model.

    This is a placeholder for retrieval-augmented generation: no retrieval
    step is performed yet, the question is sent to the LLM as is.

    Args:
        query: the question or information request.

    Returns:
        The model's textual answer (a string).
    """
    print("Calling do_rag...")

    # TODO: retrieve the relevant documents and add them to the prompt
    llm_reply = chat.invoke([HumanMessage(query)])
    if DEBUG:
        print(llm_reply)
    return llm_reply.content


# Tool: Generate SQL
def generate_sql(query: str):
    """
    Generate an SQL query for the user request, run it and return the data.

    Only READ (SELECT) operations are allowed: the query is executed only if
    ``sql_agent.check_sql`` accepts it (presumably it rejects anything that is
    not a read - TODO confirm against SelectAISQLAgent).

    Requires one parameter:
    - query (string): the user request, a description of the data to extract.

    Returns:
        The rows read from the database, converted to a string. If no SQL
        could be generated, or the generated SQL fails validation, a short
        explanatory message is returned instead so the LLM can tell the user.
    """
    print("Calling generate sql...")

    sql = sql_agent.generate_sql(query)

    print(f"SQL generated:\n {sql}")

    if not sql:
        msg = f"No SQL query could be generated for the request: {query}"
        print(msg)
        return msg

    if not sql_agent.check_sql(sql):
        # Previously this path fell through to `return str(rows)` with `rows`
        # unassigned, raising UnboundLocalError inside the agent loop.
        msg = f"The generated SQL did not pass validation and was not executed:\n{sql}"
        print(msg)
        return msg

    rows = sql_agent.execute_sql(sql)
    print("SQL executed, passing data to LLM for answer generation...")

    return str(rows)


def not_allowed(query: str):
    """
    Refuse a forbidden database operation.

    Used for requests that would modify the database
    (INSERT, UPDATE, DELETE, DROP).

    Args:
        query: the original user request.

    Returns:
        A refusal message quoting the request.
    """
    print("Calling not allowed...")

    refusal_template = "I'm sorry to inform you that your request: {} is NOT allowed!!!"
    return refusal_template.format(query)


def do_internet_search(query: str, timeout: float = 30.0):
    """
    Search the Internet with the Tavily API.

    Use ONLY when an Internet search is explicitly requested or when
    real-time information is needed.

    Args:
        query: the search query.
        timeout: seconds to wait for the Tavily API before giving up.

    Returns:
        The decoded JSON body of the Tavily response on HTTP 200; otherwise an
        error message string (HTTP error or network failure), so the agent
        loop keeps running and the LLM can report the problem.
    """
    print("Calling do internet search...")

    headers = {"Content-Type": "application/json"}
    payload = {"api_key": TAVILY_API_KEY, "query": query}

    try:
        # Without a timeout a stalled connection would block the agent forever.
        response = requests.post(
            SEARCH_ENDPOINT, json=payload, headers=headers, timeout=timeout
        )
    except requests.RequestException as exc:
        err_msg = f"Errore: {exc}"
        print(err_msg)
        return err_msg

    if response.status_code != 200:
        err_msg = f"Errore: {response.status_code} - {response.text}"
        print(err_msg)
        return err_msg

    # parse the body once and reuse it
    results = response.json()

    print("")
    print("Search results:")
    print(results)
    print("")

    return results


# Wrap each function as a LangChain Tool; the description is what the model sees.
do_rag_tool = Tool(
    name="do_rag",
    description="Responds to general information requests using an LLM model. Requires one parameter: query (string). Returns the model's response as a string.",
    func=do_rag,
)

generate_sql_tool = Tool(
    name="generate_sql",
    description="Generates and execute SQL query based on the provided input. Requires one parameter: query (string). Returns retrieved data as a string.",
    func=generate_sql,
)

not_allowed_tool = Tool(
    name="not_allowed",
    description="Return the answer when an action is not allowed",
    func=not_allowed,
)

do_internet_search_tool = Tool(
    name="do_internet_search",
    description="Return the answer when an internet search is requested",
    func=do_internet_search,
)

# Tools offered to the model
tools = [do_rag_tool, generate_sql_tool, not_allowed_tool, do_internet_search_tool]

# tool name -> tool, used to dispatch the tool calls requested by the model;
# derived from `tools` so the two can never get out of sync
functions_map = {tool.name: tool for tool in tools}

# Bind tools to chat model
chat_with_tools = chat.bind_tools(tools)

In [5]:
# supporting functions


# gives only the action chosen (which tool is called
def test_with_tools(query):
    messages = [SystemMessage(preamble), HumanMessage(query)]

    ai_msg = chat_with_tools.invoke(messages, is_force_single_step=False)

    print("Msg content: ", ai_msg.content)
    print("")
    # read tool calls from model output
    for call in ai_msg.tool_calls:
        print("Tool call: ", call)


def answer_with_tools(chat_with_tools, query):
    """
    Answer ``query`` with one round of tool use.

    The model is first asked to choose tools. Every requested tool is executed
    and its output is sent back so the model can write the final answer.

    Args:
        chat_with_tools: chat model with the tools bound (as produced by
            ``chat.bind_tools(tools)``).
        query: the user request.

    Returns:
        The model's final answer message.
        - If the model requests no tool, its first reply is returned as is.
        - If the model returns ``None``, an AIMessage carrying an error text is
          returned, so callers can always read ``.content``.
    """
    # we could also add here the chat history (memory)
    messages = [SystemMessage(preamble), HumanMessage(query)]

    ai_msg = chat_with_tools.invoke(messages, is_force_single_step=False)

    if ai_msg is None:
        print("None returned chat_with_tools call...")
        return AIMessage(content="Error in chat_with_tools call")

    if DEBUG:
        print(ai_msg)
        print(ai_msg.content)
        for call in ai_msg.tool_calls:
            print("Tool call: ", call)

    # No tool requested: the model has already answered. Invoking it again
    # would cost an extra LLM call with the model's own reply as the last
    # message of the conversation, and the model may then answer that text
    # instead of the user.
    if not ai_msg.tool_calls:
        return ai_msg

    messages.append(ai_msg)

    for tool_call in ai_msg.tool_calls:
        selected_tool = functions_map.get(tool_call["name"].lower())
        if selected_tool is None:
            # Report a hallucinated/unknown tool to the model instead of
            # raising KeyError and aborting the whole request.
            tool_output = f"Error: unknown tool '{tool_call['name']}'."
        else:
            tool_output = selected_tool.invoke(tool_call["args"])
        messages.append(ToolMessage(tool_output, tool_call_id=tool_call["id"]))

    final_response = chat_with_tools.invoke(messages)

    return final_response


def print_citations(answer):
    """
    Print the citations attached to a model answer, if there are any.

    Args:
        answer: a chat model message; citations are read from
            ``answer.additional_kwargs["citations"]`` when that key exists.

    Prints nothing when the key is missing, ``None`` or empty (the original
    raised KeyError when the key was missing).
    """
    citations = answer.additional_kwargs.get("citations")
    if not citations:
        return

    print()
    print("--- Citations ---")
    for cite in citations:
        print(cite)

## Invoking tools and passing tool outputs back to the model

In [None]:
%%time
# query that needs an answer from the LLM (do_rag tool)
query = "Who is Enrico Fermi?"

answer = answer_with_tools(chat_with_tools, query)
print("")
print(answer.content)

# print_citations(answer)

In [8]:
%%time
query = """I want a list of the top 10 products sold.
For each product I want the product name, number of sales and total amount in euro sold.
Return output as markdown"""

answer = answer_with_tools(chat_with_tools, query)

print("")
display(Markdown(answer.content))

2024-12-11 16:45:35,221 - Generating SQL...


Calling generate sql...
SQL generated:
 SELECT T2.prod_name AS product_name, COUNT(*) AS num_sales, SUM(T1.amount_sold) AS total_amount_in_euro
FROM "SELAI"."SALES" T1
INNER JOIN "SELAI"."PRODUCTS" T2 ON T1.prod_id = T2.prod_id
GROUP BY T2.prod_name
ORDER BY total_amount_in_euro DESC
FETCH FIRST 10 ROWS ONLY


2024-12-11 16:45:39,301 - SQL validated. Executing...
2024-12-11 16:45:39,804 - Executed successfully. Rows fetched: 10


SQL executed, passing data to LLM for answer generation...



| Product Name | Number of Sales | Total Amount in Euro |
|---|---|---|
| Envoy Ambassador | 9591 | 15011642.52 |
| Mini DV Camcorder with 3.5" Swivel LCD | 6160 | 8314815.4 |
| 17" LCD w/built-in HDTV Tuner | 6010 | 7189171.77 |
| Home Theatre Package with DVD-Audio/Video Play | 10903 | 6691996.81 |
| 5MP Telephoto Digital Camera | 6002 | 6312268.4 |
| Envoy 256MB - 40GB | 5766 | 5635963.08 |
| 18" Flat Panel Graphics Monitor | 5205 | 5498727.81 |
| 8.3 Minitower Speaker | 7197 | 3845387.38 |
| Unix/Windows 1-user pack | 16796 | 3543725.89 |
| SIMM- 16MB PCMCIAII card | 15950 | 2572944.13 |

CPU times: user 253 ms, sys: 36.8 ms, total: 290 ms
Wall time: 24.6 s


In [9]:
# Destructive request: expected to be routed to the not_allowed tool.
query = "Please, DROP SALES"
answer = answer_with_tools(chat_with_tools, query)

print()
print(answer.content)

Calling not allowed...

I'm sorry to inform you that your request: DROP SALES is NOT allowed!!!


In [None]:
%%time
query = "What are the latest news regarding Italy? Make a report"

answer = answer_with_tools(chat_with_tools, query)

print("")
print(answer.content)

In [10]:
%%time
query = "Analyze the data provided? Make a report"

answer = answer_with_tools(chat_with_tools, query)

print("")
print(answer.content)


I'm sorry, I don't have access to the data you would like me to analyse. Could you please provide it?
CPU times: user 25.2 ms, sys: 4.57 ms, total: 29.8 ms
Wall time: 4.46 s
