<a href="https://colab.research.google.com/github/frank-morales2020/MLxDL/blob/main/AGENTIC_RAG_GEMINI_DEMO.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install --upgrade langchain langgraph langchain-google-genai -q

In [None]:
!pip show langchain

In [None]:
!pip install --upgrade google-generativeai==0.8.5 google-ai-generativelanguage==0.6.15 -q

In [None]:
# Install required packages
!pip install --upgrade langchain langgraph langchain-google-genai -q
!pip install --upgrade google-generativeai==0.8.5 google-ai-generativelanguage==0.6.15 -q

In [None]:
!pip show google-generativeai google-ai-generativelanguage langchain langgraph langchain-google-genai

In [5]:
# Import libraries
import os
import ast
from typing import TypedDict, List
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage, ToolMessage
from google.colab import userdata
import google.generativeai as genai
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain.tools import tool
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langgraph.graph import StateGraph, END

# --- 0. SETUP: ACCESS API KEY AND FIX AUTH ERROR/MODEL CONFIG ---
# Outside Colab, fall back to a shim whose get() reads environment variables,
# so the secret lookup below works the same way in both environments.
try:
    from google.colab import userdata
except ImportError:
    class MockUserdata:
        """Minimal stand-in for `google.colab.userdata` backed by os.environ."""
        def get(self, key): return os.environ.get(key)
    userdata = MockUserdata()

# Disable GCE metadata check to prevent RefreshError
os.environ['GOOGLE_CLOUD_DISABLE_GCE_CHECK'] = 'true'

# Resolve the API key. A non-empty GEMINI_API_KEY environment variable wins;
# otherwise the secret named 'GEMINI' is read and mirrored into the environment
# so later code (and the demo's pre-flight check) can find it there.
# An empty environment variable is treated as "not set" rather than being
# passed to genai.configure() and reported as a success.
GEMINI_API_KEY = os.environ.get('GEMINI_API_KEY') or None
try:
    if GEMINI_API_KEY is None:
        secret = userdata.get('GEMINI')
        if not secret:
            # Fail with a readable reason instead of the opaque TypeError that
            # assigning None into os.environ would otherwise produce.
            raise ValueError("secret 'GEMINI' is not set or is empty")
        os.environ['GEMINI_API_KEY'] = secret
        GEMINI_API_KEY = secret

    genai.configure(api_key=GEMINI_API_KEY)
    print("✅ GEMINI_API_KEY successfully loaded and configured.")
except Exception as e:
    print(f"⚠️ Could not load API key: {e}")
    print("Please ensure your secret is named 'GEMINI' or the environment variable is set.")
    # A key that was found but failed to configure only warns; a missing key is fatal.
    if not GEMINI_API_KEY:
        raise RuntimeError("API Key is missing or failed to load. Cannot proceed.") from e

# Define the state for LangGraph
class AgentState(TypedDict):
    """Graph state passed between the LangGraph nodes: the running conversation."""

    # Complete message history for the run. No reducer is attached, so the
    # value a node returns replaces this key; the nodes in this file therefore
    # always return the old messages plus their new ones.
    messages: List[BaseMessage]

# --- 1. DEFINE THE TOOLS ---
@tool
def internal_knowledge_retriever(query: str) -> str:
    """
    Retrieves detailed, internal documentation or facts from the secure knowledge base.
    Use this tool for questions about company policy, specific projects, or proprietary data.
    """
    # NOTE: `@tool` sends the docstring above to the model verbatim as the tool
    # description, so its wording is effectively part of the prompt.
    import re

    # Filler words present in almost every question. Substring-matching on
    # words like "to", "the" or "a" makes every document look relevant to
    # every query, so they are ignored when scoring relevance.
    stopwords = frozenset({
        "a", "about", "an", "and", "are", "as", "at", "be", "by", "can",
        "did", "do", "does", "for", "from", "how", "in", "is", "it", "me",
        "of", "on", "or", "our", "tell", "that", "the", "to", "was", "we",
        "what", "when", "where", "which", "who", "whom", "why", "with",
    })

    def tokenize(text):
        """Return the set of lowercase alphanumeric words in *text*."""
        # Whole words only, so trailing punctuation ("Zenith,") does not block
        # a match and short words cannot match inside longer ones.
        return set(re.findall(r"[a-z0-9]+", text.lower()))

    knowledge_base = {
        "Refund Policy": "Full refund within 30 days for unused products, 50% refund thereafter.",
        "Project Zenith": "Project Zenith is the new Q4 data initiative focused on multimodal embedding optimization. Lead: Dr. Aris Thorne.",
        "Team Structure": "The Data Science team reports to the CTO, while the Engineering team reports to the COO.",
    }
    query_lower = query.lower().strip()
    query_terms = tokenize(query) - stopwords

    # A document is relevant when the whole (non-empty) query names its title,
    # or when any meaningful query word appears in its title or content.
    retrieved_data = [
        f"Document: {key}. Content: {val}" for key, val in knowledge_base.items()
        if (query_lower and query_lower in key.lower())
        or query_terms & tokenize(f"{key} {val}")
    ]
    if retrieved_data:
        return "\n".join(retrieved_data)
    else:
        return "No highly relevant internal documents found for that query."

@tool
def financial_calculator(expression: str) -> str:
    """
    Performs basic arithmetic calculations. Input must be a valid mathematical expression (e.g., '1200 * 0.85' or '987 / 3 - 12.5').
    """
    # NOTE: `@tool` sends the docstring above to the model verbatim as the tool
    # description, so its wording is effectively part of the prompt.
    #
    # The expression is produced by the LLM (ultimately from user text), so it
    # is never passed to eval(). It is parsed into an AST, every node is checked
    # against a small whitelist, and only then is the tree evaluated directly
    # with the `operator` functions below.
    import operator

    # Whitelisted operators and their implementations. The dicts double as the
    # validation whitelist and as the evaluation dispatch table.
    binary_ops = {
        ast.Add: operator.add,
        ast.Sub: operator.sub,
        ast.Mult: operator.mul,
        ast.Div: operator.truediv,
    }
    unary_ops = {
        ast.UAdd: operator.pos,
        ast.USub: operator.neg,
    }

    def validate_node(node):
        """Return (True, None) if *node* only uses whitelisted constructs, else (False, reason)."""
        if isinstance(node, ast.Expression):
            return validate_node(node.body)
        if isinstance(node, ast.Constant):
            # bool is a subclass of int; exclude it so 'True + 1' is rejected.
            if isinstance(node.value, bool) or not isinstance(node.value, (int, float)):
                return False, "Constants must be numbers"
            return True, None
        if isinstance(node, ast.BinOp):
            if type(node.op) not in binary_ops:
                return False, f"Disallowed binary operator: {type(node.op).__name__}"
            left_ok, left_msg = validate_node(node.left)
            if not left_ok:
                return False, left_msg
            return validate_node(node.right)
        if isinstance(node, ast.UnaryOp):
            if type(node.op) not in unary_ops:
                return False, f"Disallowed unary operator: {type(node.op).__name__}"
            return validate_node(node.operand)
        return False, f"Disallowed construct: {type(node).__name__}"

    def evaluate(node):
        """Compute the value of a tree that has already passed validate_node()."""
        if isinstance(node, ast.Expression):
            return evaluate(node.body)
        if isinstance(node, ast.Constant):
            return node.value
        if isinstance(node, ast.BinOp):
            return binary_ops[type(node.op)](evaluate(node.left), evaluate(node.right))
        return unary_ops[type(node.op)](evaluate(node.operand))

    try:
        # Strip first: unlike eval(), ast.parse(mode='eval') raises
        # IndentationError on leading whitespace the model may add.
        tree = ast.parse(expression.strip(), mode='eval')
        print(f"Parsed AST: {ast.dump(tree, indent=2)}")  # Debug: Log AST for verification

        # Validate the whole tree before evaluating any of it.
        is_valid, error_msg = validate_node(tree)
        if not is_valid:
            return f"Calculation Error: {error_msg}"

        result = evaluate(tree)
        if not isinstance(result, (int, float)):
            return "Calculation Error: Result is not a simple numeric type."
        return f"The result of {expression} is {result}"  # Return formatted result
    except SyntaxError:
        return "Calculation Error: Invalid syntax in arithmetic expression."
    except ZeroDivisionError:
        return "Calculation Error: Division by zero is not allowed."
    except Exception as e:
        # e.g. RecursionError on absurdly nested input, OverflowError on huge ints.
        return f"Calculation Error: Unhandled error {str(e)}. Ensure it is a valid arithmetic expression."

# --- 2. ASSEMBLE THE AGENTIC RAG SYSTEM ---
def setup_agentic_rag_gemini():
    """Build and compile the two-node LangGraph agent (LLM <-> tools).

    The graph starts at the ``agent`` node, which calls Gemini with both tools
    bound. If the reply contains tool calls, the ``tools`` node executes them
    and routes the resulting ``ToolMessage``s back to ``agent`` so it can
    synthesize an answer; a reply with no tool calls ends the run.

    Returns:
        The compiled LangGraph runnable. Invoke it with
        ``{"messages": [HumanMessage(...)]}``; the final answer is the last
        element of the returned ``"messages"`` list.
    """
    # Initialize the LLM with Gemini 2.5 Flash
    llm = ChatGoogleGenerativeAI(
        model="gemini-2.5-flash",
        temperature=0,
        google_api_key=GEMINI_API_KEY
    )
    tools = [internal_knowledge_retriever, financial_calculator]
    # Maps the tool name the model emits in a tool call to the tool object.
    # The loop variable is not named `tool` so the @tool decorator is not shadowed.
    tool_map = {t.name: t for t in tools}

    # System prompt to enforce tool usage and prevent expression reformulation.
    # Adjacent literals are concatenated, so each sentence must end with a space.
    prompt = ChatPromptTemplate.from_messages([
        ("system", (
            "You are a highly capable Agentic RAG system powered by Gemini 2.5 Flash. "
            "For factual company questions (e.g., policies, projects, team structure), ALWAYS use the 'internal_knowledge_retriever' tool. "
            "For ANY mathematical or arithmetic queries (e.g., calculations, equations), ALWAYS use the 'financial_calculator' tool and pass the expression EXACTLY as provided in the query as a string (e.g., '1200 * 0.85' or '987 / 3 - 12.5'). "
            "Do NOT reformulate or modify the arithmetic expression (e.g., do NOT change '987 / 3 - 12.5' to '987 / 3 + (-12.5)'). "
            "Do NOT attempt to answer math questions directly; ALWAYS use the 'financial_calculator' tool. "
            "Process ONE question at a time. If a user asks multiple questions in one message, use the internal_knowledge_retriever tool ONCE with the full original query. "
            "After receiving tool results, synthesize the information and provide a clear final answer to the user."
        )),
        MessagesPlaceholder(variable_name="messages"),
    ])

    # Bind tools to LLM
    llm_with_tools = llm.bind_tools(tools)

    def agent(state: AgentState) -> AgentState:
        """Ask the tool-bound LLM for the next step and append its reply to the history."""
        messages = state["messages"]
        response = llm_with_tools.invoke(prompt.invoke({"messages": messages}))
        print(f"LLM Response: {response}")  # Debug: Log LLM response
        return {"messages": messages + [response]}

    def call_tools(state: AgentState) -> AgentState:
        """Run every tool call on the last message and append one ToolMessage per call.

        Tool exceptions and unknown tool names are reported back to the model as
        ToolMessage text instead of being raised, so the agent can recover.
        """
        last_message = state["messages"][-1]
        new_messages = []

        if hasattr(last_message, "tool_calls") and last_message.tool_calls:
            for tool_call in last_message.tool_calls:
                tool_name = tool_call["name"]
                tool_args = tool_call["args"]
                selected_tool = tool_map.get(tool_name)
                print(f"Invoking tool: {tool_name} with args: {tool_args}")  # Debug: Log tool call

                if selected_tool is None:
                    content = f"Tool error: Tool '{tool_name}' not found."
                else:
                    try:
                        if tool_name == "financial_calculator":
                            # Pass the expression string directly; an empty one is an error.
                            expression = str(tool_args.get("expression", ""))
                            if not expression:
                                raise ValueError("No expression provided for financial_calculator")
                            result = selected_tool.invoke(expression)
                        else:
                            result = selected_tool.invoke(tool_args)
                        content = str(result)
                    except Exception as e:
                        content = f"Tool error ({tool_name}): {str(e)}"

                # `name` identifies which function this response belongs to;
                # `tool_call_id` pairs it with the originating call.
                new_messages.append(ToolMessage(
                    content=content,
                    tool_call_id=tool_call["id"],
                    name=tool_name,
                ))
        else:
            print("No tool calls found in last message.")  # Debug: Log if no tool calls

        return {"messages": state["messages"] + new_messages}

    def router(state: AgentState) -> str:
        """Pick the next node: tool results -> agent, pending tool calls -> tools, else END."""
        last_message = state["messages"][-1]

        # Tool output is ready: let the agent synthesize the answer from it.
        if isinstance(last_message, ToolMessage):
            return "agent"

        # The model asked for tools: execute them.
        if hasattr(last_message, "tool_calls") and last_message.tool_calls:
            return "tools"

        # A plain AI reply is the final answer.
        return END

    # Create the graph
    workflow = StateGraph(AgentState)
    workflow.add_node("agent", agent)
    workflow.add_node("tools", call_tools)
    workflow.set_entry_point("agent")
    workflow.add_conditional_edges("agent", router, {"tools": "tools", END: END})
    workflow.add_conditional_edges("tools", router, {"agent": "agent", END: END})
    agent_executor = workflow.compile()

    return agent_executor

# --- 3. RUN THE DEMO ---
if __name__ == "__main__":
    # An unset and an empty key are both treated as missing.
    if not os.environ.get("GEMINI_API_KEY"):
        print("\n❌ DEMO FAILED: API Key is missing. Please fix the setup in step 0.")
    else:
        def _ask(executor, label, question):
            """Print *question*, run the agent on it, print its final answer.

            Returns the raw graph result and its last message so callers can keep them.
            """
            print(f"\n[USER {label}]: {question}")
            outcome = executor.invoke({"messages": [HumanMessage(content=question)]})
            last = outcome['messages'][-1]
            if hasattr(last, 'content'):
                answer = last.content
            else:
                answer = 'No content in final message.'
            print(f"\n[AGENT FINAL ANSWER {label}]: {answer}")
            return outcome, last

        try:
            agent_executor = setup_agentic_rag_gemini()
            banner = "🌟" * 20
            print("\n" + banner)
            print("    REAL AGENTIC RAG DEMO with GEMINI 2.5 FLASH (LangGraph)    ")
            print(banner)

            # Test 1: Triggers the RAG Tool
            query_1 = "Who leads Project Zenith, and who does the Engineering team report to?"
            result_1, final_message_1 = _ask(agent_executor, 1, query_1)

            print("\n" + "---" * 23)

            # Test 2: Triggers the Calculator Tool
            query_2 = "What is the result of 987 divided by 3, minus 12.5?"
            result_2, final_message_2 = _ask(agent_executor, 2, query_2)

        except Exception as e:
            print(f"\n❌ Execution Error during agent run: {e}")
            print("This usually means a problem with your API key, configuration, or network.")

✅ GEMINI_API_KEY successfully loaded and configured.

🌟🌟🌟🌟🌟🌟🌟🌟🌟🌟🌟🌟🌟🌟🌟🌟🌟🌟🌟🌟
    REAL AGENTIC RAG DEMO with GEMINI 2.5 FLASH (LangGraph)    
🌟🌟🌟🌟🌟🌟🌟🌟🌟🌟🌟🌟🌟🌟🌟🌟🌟🌟🌟🌟

[USER 1]: Who leads Project Zenith, and who does the Engineering team report to?
LLM Response: content='' additional_kwargs={'function_call': {'name': 'internal_knowledge_retriever', 'arguments': '{"query": "Who leads Project Zenith, and who does the Engineering team report to?"}'}} response_metadata={'prompt_feedback': {'block_reason': 0, 'safety_ratings': []}, 'finish_reason': 'STOP', 'model_name': 'gemini-2.5-flash', 'safety_ratings': [], 'grounding_metadata': {}, 'model_provider': 'google_genai'} id='lc_run--2ec65860-bf96-4f3f-bd42-bf7489aecaf2-0' tool_calls=[{'name': 'internal_knowledge_retriever', 'args': {'query': 'Who leads Project Zenith, and who does the Engineering team report to?'}, 'id': '4dfcf66f-c260-4e52-8aa7-722baf96ec1c', 'type': 'tool_call'}]
Invoking tool: internal_knowledge_retriever with args: {'query': 'W