<a href="https://colab.research.google.com/github/kinjaljoshi/ma_rag_components/blob/master/planning_and_execute.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install langchain-community langchain-experimental
!pip install --upgrade langchain

In [31]:
object_routing_registry = {
    "purchase_order": {
        "access_method": "api",
        "endpoint": "/api/purchase-orders"
    },
    "sales_order": {
        "access_method": "api",
        "endpoint": "/api/sale-orders"
    },
    "__default__": {
        "access_method": "sql"
    }
}


In [32]:
from langchain.chat_models import ChatOpenAI
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain
import os
from google.colab import userdata



api_key = userdata.get('OPENAI_API_KEY')
os.environ["OPENAI_API_KEY"] = api_key

llm = ChatOpenAI(temperature=0)

def format_object_descriptions(registry):
    return "\n".join([
        f"- {k}: Accessed via {v['access_method'].upper()} at {v.get('endpoint', 'SQL engine')}."
        for k, v in registry.items()
    ])

# LLM-based router to find best object name
def get_object_for_query(query: str):
    prompt = PromptTemplate.from_template("""
You are an assistant that maps user queries to internal data objects.

Available objects:
{object_descriptions}

Determine the best matching object for the query:
"{query}"

Respond with only the exact object name from the list above, or "__default__".
""")

    chain = LLMChain(llm=llm, prompt=prompt)
    object_descriptions = format_object_descriptions(object_routing_registry)
    result = chain.run({
        "object_descriptions": object_descriptions,
        "query": query
    }).strip()
    return result if result in object_routing_registry else "__default__"


In [33]:
print(format_object_descriptions(object_routing_registry))

- purchase_order: Accessed via API at /api/purchase-orders.
- sales_order: Accessed via API at /api/sale-orders.
- __default__: Accessed via SQL at SQL engine.


In [34]:
from langchain.agents import Tool

def api_agent(query: str) -> str:
    obj = get_object_for_query(query)
    endpoint = object_routing_registry[obj]["endpoint"]
    return f"[API Agent] Called {endpoint} for: '{query}'"

def sql_agent(query: str) -> str:
    return f"[SQL Agent] Executed SQL for: '{query}'"

tools = [
    Tool(
        name="InvokeAPIAgent",
        func=api_agent,
        description="Use for questions related to purchase_order or sales order objects (via API)."
    ),
    Tool(
        name="InvokeSQLAgent",
        func=sql_agent,
        description="Use this for anything else that involves querying the database."
    )
]


In [35]:
from langchain.agents import initialize_agent, AgentType

agent = initialize_agent(
    tools=tools,
    llm=llm,
    agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
    verbose=True
)


In [36]:
agent.run("Get all approved purchase orders for last week")



[1m> Entering new AgentExecutor chain...[0m
[32;1m[1;3mI should use the API to retrieve purchase orders
Action: InvokeAPIAgent
Action Input: query="approved purchase orders for last week"[0m
Observation: [36;1m[1;3m[API Agent] Called /api/purchase-orders for: 'query="approved purchase orders for last week'[0m
Thought:[32;1m[1;3mI need to specify the date range for last week in the query
Action: InvokeAPIAgent
Action Input: query="approved purchase orders for last week"[0m
Observation: [36;1m[1;3m[API Agent] Called /api/purchase-orders for: 'query="approved purchase orders for last week'[0m
Thought:[32;1m[1;3mI should specify the date range in the query to get the approved purchase orders for last week
Action: InvokeAPIAgent
Action Input: query="approved purchase orders for last week with date range"[0m
Observation: [36;1m[1;3m[API Agent] Called /api/purchase-orders for: 'query="approved purchase orders for last week with date range'[0m
Thought:[32;1m[1;3mI now k

'Get all approved purchase orders for last week by specifying the date range in the query.'

In [37]:
agent.run("Get all Employess expenses for employee id 1234")



[1m> Entering new AgentExecutor chain...[0m
[32;1m[1;3mI need to query the database to retrieve the expenses for employee id 1234.
Action: InvokeSQLAgent
Action Input: SELECT * FROM expenses WHERE employee_id = 1234[0m
Observation: [33;1m[1;3m[SQL Agent] Executed SQL for: 'SELECT * FROM expenses WHERE employee_id = 1234'[0m
Thought:[32;1m[1;3mI have retrieved the expenses for employee id 1234.
Final Answer: All expenses for employee id 1234 have been retrieved.[0m

[1m> Finished chain.[0m


'All expenses for employee id 1234 have been retrieved.'