# ToolRAG Agent with LangChain, Granite and watsonx.ai

ToolRAG is a method that helps AI systems become more powerful and useful by combining two key abilities:

- Retrieval-Augmented Generation [RAG](https://research.ibm.com/blog/retrieval-augmented-generation-RAG): This means the AI can look up relevant information from a database or document store before answering.
- Tool use: The AI can choose and use tools (like calculators, search engines, APIs, or code execution) to solve problems.

This approach is part of a larger series on [Agent Architectures](https://github.com/ibm-granite-community/granite-agent-cookbook/blob/main/building_agents.md), which explores how to take agents from prototype to production. ToolRAG is an architecture designed for agents that need to use a large set of tools. Instead of just guessing or generating text, the AI can look things up and take action. Before the LLM begins its main reasoning, ToolRAG pre-filters the available tools by using a vector database to semantically retrieve the ones most relevant to the query. This avoids overloading the LLM's context window with hundreds of tool definitions, which improves efficiency and tool-selection performance for tool-using agents.
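
To make the pre-filtering idea concrete, here is a minimal, dependency-free sketch. It assumes a toy word-count "embedding" in place of a real embedding model, and the names `TOOL_DESCRIPTIONS`, `embed`, `cosine`, and `retrieve_tool_names` are hypothetical helpers for illustration only. They are not part of LangChain or watsonx.ai.

```python
import math
from collections import Counter

# Hypothetical tool catalog: name -> natural-language description.
TOOL_DESCRIPTIONS = {
    "get_weather_forecast": "Provides the current weather forecast for a specified city.",
    "get_stock_price": "Fetches the current stock price for a given ticker symbol.",
    "convert_currency": "Converts a monetary amount between two currencies.",
    "check_developer_skill": "Checks the availability of a developer with a specific skill.",
}

def embed(text: str) -> Counter:
    """Toy 'embedding': lowercase word counts (a stand-in for a real embedding model)."""
    words = "".join(c if c.isalnum() else " " for c in text.lower()).split()
    return Counter(words)

def cosine(a: Counter, b: Counter) -> float:
    """Cosine similarity between two sparse word-count vectors."""
    dot = sum(a[w] * b[w] for w in a)
    norm = math.sqrt(sum(v * v for v in a.values())) * math.sqrt(sum(v * v for v in b.values()))
    return dot / norm if norm else 0.0

def retrieve_tool_names(query: str, k: int = 2) -> list[str]:
    """Rank tools by similarity to the query and keep only the top k."""
    q = embed(query)
    ranked = sorted(
        TOOL_DESCRIPTIONS,
        key=lambda name: cosine(q, embed(TOOL_DESCRIPTIONS[name])),
        reverse=True,
    )
    return ranked[:k]

# Only the selected tool definitions would be passed on to the LLM.
print(retrieve_tool_names("What is the weather forecast for Boston?", k=1))
```

The rest of this recipe follows the same shape, with watsonx embeddings and a Chroma vector store doing the ranking.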

This notebook shows how to build a Retrieval-Augmented Generation (RAG) agent with a large, semantically searchable toolset, powered by LangChain’s agent framework, IBM’s Granite LLM, and watsonx embeddings. You’ll see how to set up credentials, index tools semantically, and orchestrate the agent for robust research and engineering workflows. Let's build a scalable ToolRAG Agent!

## Prerequisites
- Python 3.10+ environment (e.g., Jupyter, Colab, or watsonx.ai)
- IBM watsonx.ai credentials (API key, project ID) for Granite model access.


# Steps

## Step 1. Set up your environment

While you can choose from several tools, this recipe is best suited for a Jupyter Notebook. Jupyter Notebooks are widely used within data science to combine code with various data sources such as text, images and data visualizations. 

You can run this notebook in [Colab](https://colab.research.google.com/), or download it to your system and [run the notebook locally](https://github.com/ibm-granite-community/granite-kitchen/blob/main/recipes/Getting_Started_with_Jupyter_Locally/Getting_Started_with_Jupyter_Locally.md). 

To avoid Python package dependency conflicts, we recommend setting up a [virtual environment](https://docs.python.org/3/library/venv.html).

Note that this notebook is compatible with Python 3.12 as well as Python 3.11, the default in Colab at the time of publishing this recipe. To check your Python version, you can run the `!python --version` command in a code cell.


## Step 2. Set up a watsonx.ai instance

See [Getting Started with IBM watsonx](https://github.com/ibm-granite-community/granite-kitchen/blob/main/recipes/Getting_Started/Getting_Started_with_WatsonX.ipynb) for information on getting ready to use watsonx.ai. 

You will need three credentials from the watsonx.ai setup to add to your environment: `WATSONX_URL`, `WATSONX_APIKEY`, and `WATSONX_PROJECT_ID`.
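
Before running the later cells, it can help to confirm that all three variables are present. The following is a minimal sketch that uses only the standard library. The helper name `missing_watsonx_vars` is an assumption made for this example and is not part of any IBM package.

```python
import os

REQUIRED_VARS = ("WATSONX_URL", "WATSONX_APIKEY", "WATSONX_PROJECT_ID")

def missing_watsonx_vars(env=os.environ) -> list[str]:
    """Return the names of required variables that are unset or empty."""
    return [name for name in REQUIRED_VARS if not env.get(name)]

missing = missing_watsonx_vars()
if missing:
    print(f"Missing credentials: {', '.join(missing)}")
else:
    print("All watsonx.ai credentials are set.")
```

The check reports only variable names and never prints their values, so secrets stay out of the notebook output.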

## Step 3. Install relevant libraries and set up credentials and the Granite model

We'll need a few libraries for this recipe. We will use the LangGraph and LangChain libraries to work with Granite on watsonx.ai.

In [None]:
# Install core libraries
%pip install -qU langchain langchain-ibm langchain-core langgraph grandalf

# Install RAG components (Vector Store and utilities)
%pip install -q chromadb langchain-chroma

# Install IBM specific utility for easy credentials load
%pip install -q ibm-watsonx-ai "git+https://github.com/ibm-granite-community/utils.git"

## Step 4: Authentication and model initialization 

The next step involves initialization of the watsonx LLM (used for the agent's reasoning) and watsonx Embeddings (for Tool-RAG semantic search).

**Note:** Ensure your environment variables (`WATSONX_URL`, `WATSONX_APIKEY`, `WATSONX_PROJECT_ID`) are set before running this cell.

In [None]:
import os
import uuid
import operator
from getpass import getpass
from typing import Annotated, Any, Dict, List, Sequence, TypedDict

from langchain_ibm import WatsonxLLM, WatsonxEmbeddings
from langchain_core.tools import StructuredTool, BaseTool
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_core.messages import HumanMessage, BaseMessage
from langchain_core.documents import Document
from langchain_core.utils.utils import convert_to_secret_str
from langchain_chroma import Chroma
from langchain.chat_models import init_chat_model
from langgraph.graph import StateGraph, END
from langgraph.prebuilt import ToolNode
from ibm_granite_community.notebook_utils import get_env_var
from ibm_watsonx_ai.metanames import EmbedTextParamsMetaNames

# --- Configuration ---
model = "ibm/granite-3-3-8b-instruct"

llm_params = {
    "temperature": 0,
    "max_completion_tokens": 200,
    "repetition_penalty": 1.05,
}

# --- 1. LLM Initialization (Agent's Brain) ---
llm = init_chat_model(
    model=model,
    model_provider="ibm",
    url=convert_to_secret_str(get_env_var("WATSONX_URL")),
    apikey=convert_to_secret_str(get_env_var("WATSONX_APIKEY")),
    project_id=get_env_var("WATSONX_PROJECT_ID"),
    params=llm_params,
)
print(f"LLM initialized: {model}")


# --- 2. Embeddings Initialization (Tool-RAG Indexer) ---
watsonx_embedding = WatsonxEmbeddings(
    model_id="ibm/granite-embedding-278m-multilingual",
    url=get_env_var("WATSONX_URL"),
    apikey=get_env_var("WATSONX_APIKEY"),
    project_id=get_env_var("WATSONX_PROJECT_ID"),
    params={
        # Truncate inputs to 512 tokens, the embedding model's maximum sequence length
        EmbedTextParamsMetaNames.TRUNCATE_INPUT_TOKENS: 512
    }
)
print("Embeddings initialized: ibm/granite-embedding-278m-multilingual")
print("Setup Complete.")

We choose `ibm/granite-embedding-278m-multilingual` because it is an IBM Granite family model, ensuring seamless integration with other watsonx components.

Its multilingual capability is a benefit for real-world scenarios, and the 278M-parameter size offers a good balance between inference speed and embedding quality for semantic retrieval of tool descriptions. We use ChromaDB as the underlying vector store for tool metadata.

## Step 5: Define small tools and data structure

The ToolRAG concept is specifically designed to address the scalability problem that arises when an agent has hundreds or thousands of tools. In classic "tool-calling" architectures, the definitions of all tools must be inserted directly into the LLM's prompt, which quickly hits the context window limit and degrades performance.

For this proof of concept (PoC), we use a small set of five tools for the following reasons:
- Isolation of ToolRAG mechanics: Keeping the total tool count small makes it easy to verify that the core retrieval mechanism works. For a given query, we can visually confirm that the vector store selects the top K=3 most relevant tool definitions (e.g., finance tools for a financial query), even though the LLM could handle all five tools without RAG.
- Focus on selection logic without context overflow: This setup lets us focus on the agent's ability to select a filtered subset of tools (for example, `get_weather_forecast`, `get_stock_price`, and `convert_currency`) from the small indexed pool and then chain their execution correctly, without the added complexity of context window management or large-scale tool embedding.
- Simulating a scalable environment: In a production setting, these five tools would be replaced by hundreds of tools (e.g., 50 financial, 50 IT, and 50 HR tools). The retrieval mechanism shown here is what scales to those environments.
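
As a rough illustration of why this matters at scale, the sketch below compares the prompt cost of sending every tool description with sending only the top K=3. All numbers are assumptions: the catalog is synthetic, and the four-characters-per-token ratio is a common rule of thumb rather than a real tokenizer. Real tool definitions also carry argument schemas, so actual costs would be higher.

```python
def estimate_tokens(descriptions: list[str], chars_per_token: float = 4.0) -> int:
    """Rough token estimate from character counts (a heuristic, not a real tokenizer)."""
    total_chars = sum(len(text) for text in descriptions)
    return round(total_chars / chars_per_token)

# Hypothetical catalog: 150 tools, each with an 80-character description.
catalog = [f"Tool {i:03d}: " + "x" * 70 for i in range(150)]

all_tools_cost = estimate_tokens(catalog)      # every description in the prompt
top_k_cost = estimate_tokens(catalog[:3])      # only K=3 retrieved descriptions

print(f"All tools: ~{all_tools_cost} tokens; top 3 only: ~{top_k_cost} tokens")
```

Under these assumptions the full catalog costs about 3000 tokens per request, while the filtered subset costs about 60.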

In [None]:
def calculate_future_value(principal: float, rate: float, years: int) -> str:
    """Calculates the future value of an investment using compound interest."""
    future_value = principal * ((1 + rate) ** years)
    return f"The future value is ${future_value:.2f}."

def get_stock_price(ticker: str) -> str:
    """Fetches the current or historical stock price for a given ticker symbol (e.g., IBM)."""
    if ticker == "IBM":
        return "The current price for IBM is $313.72."
    return f"Stock price for {ticker} not found in this mock-up."

def check_developer_skill(skill: str) -> str:
    """Checks the availability of an AI developer with a specific programming or ML skill."""
    if 'python' in skill.lower() or 'langchain' in skill.lower():
        return "Yes, we have multiple experienced developers with that skill set."
    return "Developer with that specific skill is currently limited."

def convert_currency(amount: float, from_currency: str, to_currency: str) -> str:
    """Converts a monetary amount between two specified currencies (e.g., USD to EUR)."""
    if from_currency == "USD" and to_currency == "EUR":
        converted = amount * 0.92
        return f"{amount} USD is approximately {converted:.2f} EUR."
    return f"Conversion from {from_currency} to {to_currency} is not supported."

def get_weather_forecast(city: str) -> str:
    """Provides the current weather forecast for a specified city."""
    if 'boston' in city.lower():
        return "Boston's current weather is partly cloudy with a temperature of 5°C."
    return f"Weather data for {city} is not available in this mock-up."

functions = [calculate_future_value, get_stock_price, check_developer_skill, convert_currency, get_weather_forecast]

# Wrap into StructuredTool
tools: List[BaseTool] = []
for func in functions:
    tool = StructuredTool.from_function(
        func=func,
        name=func.__name__, 
        description=func.__doc__.strip()
    )
    tools.append(tool)

print("=== All Initialized Tools ===")
for i, tool in enumerate(tools):
    print(f"  Tool {i+1}: '{tool.name}'")
print(f"\nTotal tools initialized: {len(tools)}")
print("All tools ready for indexing and agent use!\n")

# tool_map for ToolNode
tool_map = {tool.name: tool for tool in tools}

## Step 6: Create tool registry and index tools

Before our agent can perform **Tool-RAG** (Tool Retrieval-Augmented Generation), we need to create a searchable index of our tools. This involves:

- **Extracting Tool Metadata**: Convert each tool's description (natural language summary) into a `Document` object, tagged with its name for later lookup.
- **Embedding & Storing**: Use watsonx embeddings to vectorize descriptions, then store them in Chroma. This enables semantic similarity searches (e.g., "weather-related tools" matches `get_weather_forecast` because their embeddings lie close together in vector space).
- **Retriever Setup**: Configure a retriever with MMR (Maximal Marginal Relevance) search to fetch diverse, relevant tools – avoiding redundant results (e.g., two finance tools for a stock query).
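
To show what MMR does differently from a plain top-k search, here is a small pure-Python sketch of the selection rule. The 2-D vectors are made up for illustration, and `mmr_select` is a hypothetical helper, not LangChain's implementation. The `lambda_mult` parameter mirrors the name LangChain uses for the relevance/diversity trade-off.

```python
import math

def cosine(a, b):
    """Cosine similarity between two dense vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(y * y for y in b))
    return dot / (na * nb) if na and nb else 0.0

def mmr_select(query_vec, doc_vecs, k=2, lambda_mult=0.5):
    """Return indices of k documents chosen by Maximal Marginal Relevance."""
    selected = []
    candidates = list(range(len(doc_vecs)))
    while candidates and len(selected) < k:
        def score(i):
            relevance = cosine(query_vec, doc_vecs[i])
            # Penalize similarity to anything already selected.
            redundancy = max((cosine(doc_vecs[i], doc_vecs[j]) for j in selected), default=0.0)
            return lambda_mult * relevance - (1 - lambda_mult) * redundancy
        best = max(candidates, key=score)
        selected.append(best)
        candidates.remove(best)
    return selected

# Toy vectors: documents 0 and 1 are near-duplicates, document 2 points elsewhere.
query = [1.0, 0.0]
docs = [[1.0, 0.1], [1.0, 0.12], [0.8, -0.6]]

print(mmr_select(query, docs, k=2))                   # balances relevance and diversity
print(mmr_select(query, docs, k=2, lambda_mult=1.0))  # relevance only
```

With `lambda_mult=0.5` the second pick skips the near-duplicate in favor of the more distinct document. With `lambda_mult=1.0` the rule reduces to ordinary similarity ranking.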


In [None]:
# Step 1: Generate Tool Documents for Indexing (List Comprehension)

tool_docs = [
    Document(page_content=tool.description, metadata={"tool_name": tool.name})
    for tool in tools
]

# Step 2: Ingest into Chroma Vector Store
vectorstore = Chroma.from_documents(documents=tool_docs, embedding=watsonx_embedding)

# Step 3: Create Configured Retriever for Graph Integration
tool_retriever = vectorstore.as_retriever(search_type="mmr", search_kwargs={"k": 3})

print(f"Indexed {len(tool_docs)} tools in vectorstore.")


## Step 7: Agent capabilities: tool selection and execution

In this section, we construct a **vanilla LangGraph agent**. This "from-scratch" approach demonstrates the following core concepts:

- **State Management**: Using `TypedDict` for typed, annotated state (messages + retrieved tools).
- **Graph Structure**: Nodes for retrieval, reasoning (LLM), and action (tools). The edges are added for flow control.
- **Tool-RAG Pattern**: Semantically retrieve a subset of tools *before* LLM invocation to handle large toolsets efficiently.

**High-Level Flow**:
1. **Retrieve Tools Node**: Uses ToolRAG to select relevant tools based on query semantics.
2. **LLM Node**: Dynamically binds *only* retrieved tools to the LLM (reduces context bloat).
3. **Conditional Edge**: Routes to tools if calls detected, else END.
4. **Tool Node**: Executes calls using `ToolNode`.
5. **Loop Back**: From tools to LLM for multi-turn reasoning.
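
The flow above can be sketched as an ordinary Python loop, which may help before reading the LangGraph version. Everything here is a stub under stated assumptions: `TOOLS`, `retrieve_tools`, `fake_llm`, and `run_agent` are hypothetical names, the "model" simply requests each retrieved tool once with a hard-coded argument, and no real LLM or vector store is involved.

```python
from typing import Callable

# Hypothetical stand-ins for real tools.
TOOLS: dict[str, Callable[[str], str]] = {
    "get_weather_forecast": lambda city: f"Partly cloudy in {city}.",
    "get_stock_price": lambda ticker: f"Mock price for {ticker}.",
}

def retrieve_tools(query: str) -> list[str]:
    """Stub retrieval: keep tools whose name shares a word with the query."""
    words = set(query.lower().replace("?", "").split())
    return [name for name in TOOLS if words & set(name.split("_"))]

def fake_llm(messages: list[dict], tool_names: list[str]) -> dict:
    """Stub model: request every bound tool once, then answer from the tool results."""
    results = [m["content"] for m in messages if m["role"] == "tool"]
    if not results and tool_names:
        return {"role": "ai", "content": "", "tool_calls": [(n, "Boston") for n in tool_names]}
    return {"role": "ai", "content": " ".join(results) or "No tool needed.", "tool_calls": []}

def run_agent(query: str) -> str:
    messages = [{"role": "human", "content": query}]
    tool_names = retrieve_tools(query)              # 1. retrieve tools node
    while True:
        reply = fake_llm(messages, tool_names)      # 2. LLM node (sees only retrieved tools)
        messages.append(reply)
        if not reply["tool_calls"]:                 # 3. conditional edge: no calls means END
            return reply["content"]
        for name, arg in reply["tool_calls"]:       # 4. tool node executes each call
            messages.append({"role": "tool", "content": TOOLS[name](arg)})
        # 5. loop back to the LLM node with the tool results appended

print(run_agent("What is the weather in Boston?"))
```

LangGraph expresses the same control flow declaratively, with nodes for each numbered step and a conditional edge in place of the `if` check.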

## Step 8: Custom tool retrieval function

This retrieval function bridges the vector store (semantic search) and the agent state (tool objects). It queries the vector store directly rather than going through the MMR `tool_retriever` created in Step 6, and performs the following steps:

- Semantic search: Query the vector store for similar tool descriptions with `.similarity_search_with_score`, which returns `(doc, score)` pairs. The score is a distance, so a lower score means a more relevant tool.

- Map docs to tools: Extract `metadata['tool_name']` and look it up among the full set of tools. The lookup is needed because the vector store holds only documents (text plus metadata), while the agent needs callable `BaseTool` objects for invocation.
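
The mapping step can be illustrated without a vector store. The sketch below assumes search hits shaped like `(metadata, distance)` pairs. The `registry`, `hits`, and `map_hits_to_tools` names and the optional `max_distance` cutoff are hypothetical additions for illustration and do not appear in the notebook's own function.

```python
from typing import Callable

# Hypothetical registry of callables keyed by tool name.
registry: dict[str, Callable[..., str]] = {
    "get_stock_price": lambda ticker: f"Mock price for {ticker}",
    "convert_currency": lambda amount: f"Mock conversion of {amount}",
}

# Hypothetical search hits: (metadata, distance); lower distance means a closer match.
hits = [
    ({"tool_name": "convert_currency"}, 0.61),
    ({"tool_name": "get_stock_price"}, 0.35),
    ({"tool_name": "retired_tool"}, 0.40),  # indexed, but no longer registered
]

def map_hits_to_tools(hits, registry, max_distance=None):
    """Order hits by ascending distance and return the matching (name, callable) pairs."""
    selected = []
    for metadata, distance in sorted(hits, key=lambda hit: hit[1]):
        name = metadata["tool_name"]
        if name not in registry:
            continue  # stale index entry: nothing callable to return
        if max_distance is not None and distance > max_distance:
            continue  # too far from the query to be worth binding
        selected.append((name, registry[name]))
    return selected

print([name for name, _ in map_hits_to_tools(hits, registry)])
print([name for name, _ in map_hits_to_tools(hits, registry, max_distance=0.5)])
```

Skipping names that are missing from the registry keeps a stale index from crashing the agent, and a distance cutoff is one possible way to avoid binding weakly related tools.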


In [None]:
def custom_retrieve_tools(query: str, limit: int = 5) -> List[BaseTool]:
    results = vectorstore.similarity_search_with_score(query, k=limit)
    tool_map_lookup = {tool.name: tool for tool in tools}
    retrieved_tools = [
        tool_map_lookup[doc.metadata['tool_name']] 
        for doc, _ in results 
        if doc.metadata['tool_name'] in tool_map_lookup
    ]
    print(f"[ToolRAG]: Retrieved {len(retrieved_tools)}: {[t.name for t in retrieved_tools]}")
    return retrieved_tools

## Step 9: Create the ToolRAG Agent


Here, we dive into the **ToolRAG pattern** (Tool Retrieval-Augmented Generation): Semantically retrieving a subset of tools before LLM binding/execution. 

This enables multi-tool queries (e.g., "Weather + Stock + Convert") by chaining retrieval with tool calls plus reasoning. 

In [None]:
# Step 1: Define the Agent State Schema
class AgentState(TypedDict):
    messages: Annotated[Sequence[BaseMessage], operator.add]
    retrieved_tools: List[BaseTool]
print("Step 1: AgentState schema defined – Tracks messages and retrieved tools for graph flow.")

# Step 2: Define Retrieval Node (ToolRAG Integration)
def retrieve_tools_node(state: AgentState) -> dict:
    query = state["messages"][-1].content
    retrieved_tools = custom_retrieve_tools(query, limit=3)
    # Return only the updated key: `messages` uses the `operator.add` reducer,
    # so returning the existing messages here would append them a second time.
    return {"retrieved_tools": retrieved_tools}

print("Step 2: retrieve_tools_node defined – Implements ToolRAG as first graph step after START.")

# Step 3: Define LLM Reasoning Node (Dynamic Tool Binding)
def llm_node(state: AgentState) -> AgentState:
    retrieved_tools = state["retrieved_tools"]
    bound_llm = llm.bind_tools(retrieved_tools)
    response = bound_llm.invoke(state["messages"])
    return {"messages": [response], "retrieved_tools": retrieved_tools}
    
print("Step 3: llm_node defined – Dynamically binds ToolRAG tools to LLM for efficient reasoning.")

# Step 4: Define Tool Execution Node (Using Prebuilt ToolNode)
tool_node = ToolNode(tools)

print("Step 4: tool_node created – Prebuilt executor for parallel tool calls.")

# Step 5: Define Conditional Routing (should_continue)
def should_continue(state: AgentState):
    return "tools" if state['messages'][-1].tool_calls else END

print("Step 5: should_continue defined – Routes graph flow based on tool calls.")

# Step 6: Assemble & Compile the Graph

workflow = StateGraph(state_schema=AgentState)
workflow.add_node("retrieve_tools", retrieve_tools_node)
workflow.add_node("llm", llm_node)
workflow.add_node("tools", tool_node)

workflow.set_entry_point("retrieve_tools")

workflow.add_edge("retrieve_tools", "llm")
workflow.add_conditional_edges("llm", should_continue, {"tools": "tools", END: END})
workflow.add_edge("tools", "llm")

# Compile into a runnable graph (final part of Step 6)
graph = workflow.compile()
print("Step 6: Graph assembled & ready. Tool-RAG Agent compiled successfully!")

print("\n=== LangGraph ASCII Visualization ===")

graph.get_graph().print_ascii()
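
The `operator.add` annotation on `messages` means that each node's returned messages are concatenated onto the existing list rather than replacing it. The sketch below approximates that merge in plain Python. It is not LangGraph's implementation, and `apply_update` is a hypothetical helper for illustration.

```python
import operator

def apply_update(state: dict, update: dict, reducers: dict) -> dict:
    """Merge a node's partial update into the state, using a reducer where one is declared."""
    merged = dict(state)
    for key, value in update.items():
        if key in reducers and key in merged:
            merged[key] = reducers[key](merged[key], value)  # e.g. list concatenation
        else:
            merged[key] = value                              # plain overwrite
    return merged

reducers = {"messages": operator.add}
state = {"messages": ["human: hi"], "retrieved_tools": []}

# Returning only the new message appends it once.
good = apply_update(state, {"messages": ["ai: hello"]}, reducers)

# Returning the full history re-adds what is already in the state.
bad = apply_update(state, {"messages": state["messages"]}, reducers)

print(good["messages"])
print(bad["messages"])
```

This is why nodes in a graph with an additive reducer should return only the messages they produced, and why keys without a reducer (such as `retrieved_tools`) are simply overwritten.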



## Step 10: Test the agent with a query that requires three tools

In [None]:

thread_id = str(uuid.uuid4())
config = {"configurable": {"thread_id": thread_id}}

user_query = "What's the weather like in Boston, and can you check the stock price for IBM, then convert 100 USD to EUR?"
final_result = graph.invoke(
    {"messages": [HumanMessage(content=user_query)], "retrieved_tools": []},  # Start with no retrieved tools
    config=config
)

final_message = final_result["messages"][-1]
print("\n--- Final Agent Response ---")
print(final_message.content)

## Step 11: Leveraging pre-built libraries for ToolRAG agents


We built a Tool-RAG agent from scratch using vanilla LangGraph. However, there are several **pre-built libraries** that abstract away the boilerplate (e.g., indexing, retrieval, dynamic binding). These accelerate development for production-scale agents with large numbers of available tools, while preserving the core pattern of semantic retrieval followed by binding only a subset of tools for execution.

Key pre-built libraries for Tool-RAG:
- [langgraph-bigtool](https://github.com/langchain-ai/langgraph-bigtool): LangGraph extension for "Big Tool" agents; automates Tool-RAG with registry-based retrieval
- [LangChain Toolkits](https://docs.langchain.com/oss/javascript/integrations/tools) (e.g., langchain-agents): Modular agent builders with RAG-infused toolkits like create_react_agent
- [LlamaIndex Tool Integration](https://developers.llamaindex.ai/python/framework/module_guides/deploying/agents/tools/): Index-based RAG for tools via VectorStoreIndex