In [65]:
# Importación de librerias
import os
import sys

!{sys.executable} -m pip install langchain-community langchain psycopg2 faiss-gpu
from langchain_community.llms.ollama import Ollama
from langchain_community.utilities import SQLDatabase
from langchain_community.embeddings import OllamaEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_core.example_selectors import SemanticSimilarityExampleSelector
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool, InfoSQLDatabaseTool, ListSQLDatabaseTool, QuerySQLCheckerTool
from langchain_core.prompts import FewShotPromptTemplate, PromptTemplate, ChatPromptTemplate
from langchain_core.prompts import SystemMessagePromptTemplate
from langchain.agents import AgentExecutor, create_react_agent



In [None]:
# Connection settings. They are read from the environment so that real credentials
# need not be committed with the notebook. The fallbacks reproduce the original
# local development setup.
# NOTE(review): the fallback URI still embeds the default local password; set
# CHINOOK_DATABASE_URI (ideally pointing at a read-only role) outside local dev.
OLLAMA_BASE_URL = os.environ.get("OLLAMA_BASE_URL", "http://localhost:11434")
OLLAMA_MODEL = os.environ.get("OLLAMA_MODEL", "gemma2:2b")
DATABASE_URI = os.environ.get(
    "CHINOOK_DATABASE_URI",
    "postgresql://postgres:postgres@localhost:5432/chinook_serial",
)

# LLM used both by the agent and by the SQL query-checker tool.
llm = Ollama(base_url=OLLAMA_BASE_URL, model=OLLAMA_MODEL)
# Database the SQL tools introspect and query.
database = SQLDatabase.from_uri(DATABASE_URI)

In [None]:
# Dynamic few shots: (natural-language question, SQL) pairs. At prompt time, the
# examples most similar to the user's question are selected and shown to the model.
# All queries target tables in the `public` schema of the connected database.
examples = [
    {"input": question, "query": sql}
    for question, sql in [
        (
            "How many customers are from Canada?",
            "SELECT COUNT(*) FROM public.customer c WHERE upper(c.country) = 'CANADA';",
        ),
        (
            "Who are the top 5 customers by their purchases?",
            "SELECT c.id, c.first_name, c.last_name, SUM(i.total) AS total FROM public.customer c INNER JOIN public.invoice i ON c.id = i.customer_id GROUP BY c.id ORDER BY total DESC LIMIT 5;",
        ),
        (
            "Find all albums for the artist 'AC/DC'.",
            "SELECT * FROM public.album a INNER JOIN public.artist a2 on a2.id = a.artist_id WHERE upper(a2.name) = 'AC/DC';",
        ),
        (
            "List all tracks in the 'Rock' genre.",
            "SELECT * FROM public.track t INNER JOIN public.genre g on t.genre_id = g.id WHERE upper(g.name) = 'ROCK';",
        ),
        (
            "Find the total duration of all tracks.",
            "SELECT SUM(t.milliseconds) FROM public.track t;",
        ),
        (
            "How many tracks are there in the album with ID 5?",
            "SELECT COUNT(*) FROM public.track t WHERE t.album_id = 5;",
        ),
        (
            "Find the total number of Albums available.",
            "SELECT COUNT(*) FROM public.album a;",
        ),
        (
            "List the number of customers group by each country",
            "SELECT c.country, COUNT(*) FROM public.customer c GROUP BY c.country;",
        ),
    ]
]
# Embeddings used to rank the few-shot examples by similarity to the question.
# The Ollama endpoint and model are taken from `llm`, so the LLM and the
# embeddings cannot silently point at different servers or models.
# NOTE(review): gemma2:2b is a generative model; a dedicated Ollama embedding
# model may give better example selection. Confirm before changing.
embeddings = OllamaEmbeddings(base_url=llm.base_url, model=llm.model)
# Index the examples in an in-memory FAISS store and, for each prompt, pick the
# 3 examples whose "input" text is most similar to the user's question. Only the
# "input" key is embedded; the SQL is not.
example_selector = SemanticSimilarityExampleSelector.from_examples(
    examples=examples,
    embeddings=embeddings,
    vectorstore_cls=FAISS,
    k=3,
    input_keys=["input"],
)
# SQL toolset exposed to the agent. Each tool wraps the same `database`; the
# query checker additionally uses `llm` to review SQL before execution.
# NOTE(review): sql_db_query executes whatever SQL the model emits with the
# connection's privileges. Consider connecting with a read-only role.
sql_db_list_tables = ListSQLDatabaseTool(db=database)
sql_db_schema = InfoSQLDatabaseTool(db=database)
sql_db_query = QuerySQLDataBaseTool(db=database)
sql_db_query_checker = QuerySQLCheckerTool(db=database, llm=llm)

# Order matters: it is the order in which tools are listed in the prompt.
tools = [
    sql_db_query,
    sql_db_schema,
    sql_db_list_tables,
    sql_db_query_checker,
]

In [None]:
# Prompts
# Prefix of the ReAct prompt. `{tools}` and `{tool_names}` are filled in by
# `create_react_agent`; the selected few-shot examples are appended after it.
# The format block follows the canonical ReAct wording: the final step is
# "Thought: I now know the final answer", not a separate "Final Thought:" label.
# The DML rule is a prompt-level guard only; it does not restrict the DB user.
system_prefix = """
You are an agent designed to interact with a PostgreSQL database.
Given an input question, create a syntactically correct PostgreSQL query, run it, look at the results and return the answer.
DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.

Answer the following questions as best you can. You have access to the following tools:

{tools}

Use the following format:

Question: the input question you must answer
Thought: you should always think about what to do
Action: the action to take, should be one of [{tool_names}]
Action Input: the input to the action
Observation: the result of the action
... (this Thought/Action/Action Input/Observation can repeat N times)
Thought: I now know the final answer
Final Answer: the final answer to the original input question

Here are some examples of user inputs and their corresponding SQL queries:

"""

# Suffix: the user's question plus the agent's running scratchpad of prior steps.
suffix = """
Begin!

Question: {input}
Thought:{agent_scratchpad}
"""
# Few-shot prompt layout: prefix (instructions + tools), then the examples chosen
# by `example_selector` for this question, then the suffix (question + scratchpad).
dynamic_few_shot_prompt_template = FewShotPromptTemplate(
    example_selector=example_selector,
    example_prompt=PromptTemplate.from_template(
        "User input: {input}\nSQL query: {query}"
    ),
    # Every placeholder that prefix + suffix actually use. Listing only "input"
    # misdeclares the template and would fail if validate_template were enabled.
    input_variables=["input", "tools", "tool_names", "agent_scratchpad"],
    prefix=system_prefix,
    suffix=suffix,
)
# The whole few-shot prompt is sent as a single system message. create_react_agent
# requires the resulting prompt to expose `tools`, `tool_names` and `agent_scratchpad`.
full_prompt = ChatPromptTemplate.from_messages(
    [
        SystemMessagePromptTemplate(prompt=dynamic_few_shot_prompt_template),
    ]
)

In [None]:
# ReAct agent: the LLM chooses among `tools` following the format in `full_prompt`.
agent = create_react_agent(llm=llm, tools=tools, prompt=full_prompt)
# The executor runs the Thought/Action/Observation loop and logs each step.
# With handle_parsing_errors=True, malformed model output is sent back to the
# model as an observation instead of raising.
agent_executor = AgentExecutor(
    agent=agent,
    tools=tools,
    handle_parsing_errors=True,
    verbose=True,
)

In [None]:
# Run the agent end-to-end on a sample natural-language question.
response = agent_executor.invoke(input={"input": "List all customers from USA"})

In [None]:
# Show only the agent's final answer; the full result dict also echoes the input.
response["output"]