In [None]:
!apt-get install graphviz libgraphviz-dev pkg-config

In [None]:
!pip install -qU \
    datasets==2.19.1 \
    langchain-pinecone==0.1.1 \
    langchain-openai==0.1.9 \
    langchain==0.2.5 \
    langchain-core==0.2.9 \
    langgraph==0.1.1 \
    semantic-router==0.0.48 \
    serpapi==0.1.5 \
    google-search-results==2.4.2 \
    pygraphviz==1.12 \
    fsspec # for visualizing

In [None]:
pip install google-generativeai langchain-google-genai pyowm langchain_community

In [None]:
pip install numpy==2.2.0

In [None]:
import os
import getpass

# Collect every secret interactively so none is hardcoded in the notebook.
# Prompts appear in this order: Gemini, Pinecone, SerpAPI, OpenWeatherMap.

# Gemini key: also exported as GOOGLE_API_KEY, which the setup cell reads back
# with os.getenv("GOOGLE_API_KEY").
GEMINI_API_KEY = getpass.getpass("Enter Gemini API Key: ")
os.environ["GOOGLE_API_KEY"] = GEMINI_API_KEY

# Pinecone key: the client is constructed later with api_key=PC_API_KEY directly.
# NOTE(review): the Pinecone SDK's own env-var convention is PINECONE_API_KEY;
# nothing visible here reads PC_API_KEY from the environment — confirm it is needed.
PC_API_KEY = getpass.getpass("Enter Pinecone API Key: ")
os.environ["PC_API_KEY"] = PC_API_KEY

# SerpAPI key: kept only as a Python variable (it is placed into serpapi_params).
serpapi_key=getpass.getpass("Enter SerpAPI key: ")

# OpenWeatherMap key: exported so the weather tool can read it from os.environ.
OPENWEATHERMAP_API_KEY = getpass.getpass("Enter OpenWeatherMap API Key: ")
os.environ["OPENWEATHERMAP_API_KEY"]=OPENWEATHERMAP_API_KEY

Enter Gemini API Key: ··········
Enter Pinecone API Key: ··········
Enter SerpAPI key: ··········
Enter OpenWeatherMap API Key: ··········


In [None]:

from datasets import load_dataset

# Pre-chunked arXiv papers used to populate the vector index. The whole train
# split is downloaded here, even if the index check below later finds that the
# index is already populated and no embedding is needed.
dataset = load_dataset("jamescalam/ai-arxiv2-semantic-chunks", split="train")

import google.generativeai as genai
from langchain_google_genai import GoogleGenerativeAIEmbeddings

genai.configure(api_key=os.getenv("GOOGLE_API_KEY"))
# NOTE(review): `model` is not referenced by any cell visible here (the agent
# uses a separate ChatGoogleGenerativeAI instance) — confirm before removing.
model = genai.GenerativeModel("gemini-1.5-flash")

# Embedding model used both to populate the index and to embed queries.
# NOTE(review): `quota_location` does not appear to be a documented
# GoogleGenerativeAIEmbeddings argument; it may be silently ignored — confirm
# it actually pins requests to us-central1.
encoder = GoogleGenerativeAIEmbeddings(
    model="models/embedding-001",
    api_key=os.getenv("GOOGLE_API_KEY"),
    quota_location="us-central1"  # intended quota region (previously asia-east1)
)

import time

from pinecone import Pinecone, ServerlessSpec

# Pinecone client, plus the serverless placement used if the index must be
# created (us-west-2 is another valid region).
pc = Pinecone(api_key=PC_API_KEY)
spec = ServerlessSpec(cloud="aws", region="us-east-1")

index_name = "gemini-1-5-flash-research-agent"

# The index dimension must match the encoder's output, so probe it once.
dims = len(encoder.embed_query("test"))

existing_indexes = pc.list_indexes().names()
if index_name not in existing_indexes:
    pc.create_index(name=index_name, dimension=dims, metric="dotproduct", spec=spec)
    # Poll until Pinecone reports the freshly created index as ready.
    while not pc.describe_index(index_name).status['ready']:
        time.sleep(1)

index = pc.Index(index_name)
time.sleep(1)

stats = index.describe_index_stats()
print(f"Index stats: {stats}")

# A non-empty index is reused as-is (handy while testing); only an empty one
# gets populated by the embedding step.
vector_count = stats['total_vector_count']
skip_embedding = vector_count > 0
if skip_embedding:
    print(f"Index already contains {stats['total_vector_count']} vectors. Skipping embedding process.")

# Only run embedding process if index is empty
if not skip_embedding:
    from tqdm.auto import tqdm
    import time

    # For faster testing only the first N_DOCS chunks are embedded. Slicing the
    # Hugging Face dataset with `select` before converting avoids materialising
    # every row as a DataFrame just to keep the first few.
    N_DOCS = 100
    data = dataset.select(range(min(N_DOCS, len(dataset)))).to_pandas()

    # Reduce batch size to avoid rate limits
    batch_size = 32  # Reduced from 128

    print(f"Processing {len(data)} documents in batches of {batch_size}")
    print("This will take some time due to rate limiting...")

    for i in tqdm(range(0, len(data), batch_size)):
        i_end = min(len(data), i + batch_size)
        batch = data[i:i_end].to_dict(orient="records")

        # Embed records one at a time (rate limited) and pair each successful
        # embedding with its id and metadata. A record whose embedding fails is
        # skipped rather than stored with an all-zero placeholder: such a vector
        # carries no semantic signal, and Pinecone rejects dense vectors with no
        # non-zero value, which would make the upsert of the whole batch fail.
        vectors = []
        for j, record in enumerate(batch):
            try:
                embed = encoder.embed_query(record["content"])
            except Exception as e:
                print(f"Error embedding content for ID {record['id']}: {e}")
            else:
                vectors.append((
                    record["id"],
                    embed,
                    {
                        "title": record["title"],
                        "content": record["content"],
                        "arxiv_id": record["arxiv_id"],
                        "references": record["references"].tolist(),
                    },
                ))

            # Small delay between individual embedding calls (not after the last one)
            if j < len(batch) - 1:
                time.sleep(0.1)

        # Upsert a materialised list of (id, values, metadata) tuples; a batch in
        # which every embedding failed has nothing to send.
        if vectors:
            try:
                index.upsert(vectors=vectors)
            except Exception as e:
                print(f"Error upserting batch {i//batch_size + 1}: {e}")

        # Add delay between batches to respect rate limits
        # With 32 items per batch and 150 requests/minute limit, we need ~13 second delays
        if i + batch_size < len(data):  # Don't delay after the last batch
            print(f"Completed batch {i//batch_size + 1}, waiting 15 seconds...")
            time.sleep(15)  # 15 second delay between batches
else:
    print("Using existing embeddings in the index.")

from typing import TypedDict, Annotated, List, Union
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.messages import BaseMessage
import operator

class AgentState(TypedDict):
    """Shared LangGraph state passed between the oracle and the tool nodes.

    - input: the user's query.
    - chat_history: prior conversation messages fed into the prompt.
    - intermediate_steps: the trace of agent actions. Nodes return only their
      new steps; the ``operator.add`` reducer concatenates them onto the list.
    """
    input: str
    chat_history: list[BaseMessage]
    # run_oracle and run_tool both return plain AgentAction objects (never
    # (action, observation) tuples), so the element type is AgentAction.
    intermediate_steps: Annotated[list[AgentAction], operator.add]

import re

import requests
from serpapi import GoogleSearch

# Parameters shared by every SerpAPI request; each search adds its own query.
serpapi_params = dict(engine="google", api_key=serpapi_key)

# Captures the abstract text from an arXiv /abs/ page. DOTALL lets the lazy
# `.*?` group span the newlines inside the abstract.
abstract_pattern = re.compile(
    r"<blockquote class=\"abstract mathjax\">\n[ ]*<span class=\"abstract-title\">Abstract:</span>(.*?)<\/blockquote>",
    flags=re.DOTALL,
)

from langchain_core.tools import tool

@tool("fetch_arxiv")
def fetch_arxiv(arxiv_id: str) -> str:
    """Gets the abstract from an ArXiv paper given the arxiv ID. Useful for
    finding high-level context about a specific paper."""
    # The docstring above is the tool description the LLM sees, so it is kept
    # verbatim; implementation notes live in these comments instead.
    #
    # Fetch the abstract page. The timeout stops a stalled request from hanging
    # the whole agent run, and raise_for_status surfaces HTTP errors instead of
    # parsing an error page as if it were a paper.
    res = requests.get(
        f"https://export.arxiv.org/abs/{arxiv_id}",
        timeout=30,
    )
    res.raise_for_status()
    # Search the page HTML for the abstract block.
    re_match = abstract_pattern.search(res.text)
    if re_match is None:
        # Without this check a missing abstract surfaces as an opaque
        # "'NoneType' object has no attribute 'group'" error.
        raise ValueError(f"Could not find an abstract on the arXiv page for ID {arxiv_id!r}.")
    return re_match.group(1).strip()

@tool("web_search")
def web_search(query: str):
    """Finds general knowledge information using Google search. Can also be used
    to augment more 'general' knowledge to a previous specialist query."""
    # The docstring above is the tool description the LLM sees; kept verbatim.
    search = GoogleSearch({
        **serpapi_params,
        "q": query,
        "num": 5
    })
    response = search.get_dict()
    # SerpAPI reports failures (bad key, no results, ...) in an "error" field;
    # surface it instead of failing later with a bare KeyError.
    if "error" in response:
        raise RuntimeError(f"SerpAPI error: {response['error']}")
    # Individual organic results may omit fields (notably "snippet"), so read
    # them defensively rather than failing the whole call on one sparse result.
    results = response.get("organic_results", [])
    if not results:
        return f"No web results found for query: {query}"
    contexts = "\n---\n".join(
        "\n".join([x.get("title", ""), x.get("snippet", ""), x.get("link", "")])
        for x in results
    )
    return contexts

from langchain_core.tools import tool
from langchain_community.tools.openweathermap import OpenWeatherMapQueryRun

@tool("weather_agent")
def weather_agent(location: str) -> str:
    """Fetches weather information for a given location."""
    # The docstring above is the tool description the LLM sees; kept verbatim.
    from langchain_community.utilities.openweathermap import OpenWeatherMapAPIWrapper

    # OpenWeatherMapQueryRun has no `api_key` field; the key belongs on the API
    # wrapper that actually talks to OpenWeatherMap, so pass it there explicitly.
    wrapper = OpenWeatherMapAPIWrapper(
        openweathermap_api_key=os.environ["OPENWEATHERMAP_API_KEY"]
    )
    context = OpenWeatherMapQueryRun(api_wrapper=wrapper)
    return f"The current weather in {location} is {context.run(location)}"

def format_rag_contexts(matches: list):
    """Render Pinecone query matches as a single text block for the LLM.

    Each match becomes four lines (title, content, arXiv ID, related papers),
    all read from its ``metadata`` and each ending in a newline. Matches are
    separated by a ``---`` divider line. An empty list yields an empty string.
    """
    sections = []
    for match in matches:
        meta = match['metadata']
        lines = [
            f"Title: {meta['title']}",
            f"Content: {meta['content']}",
            f"ArXiv ID: {meta['arxiv_id']}",
            f"Related Papers: {meta['references']}",
        ]
        sections.append("\n".join(lines) + "\n")
    return "\n---\n".join(sections)

@tool("rag_search_filter")
def rag_search_filter(query: str, arxiv_id: str) -> str:
    """Finds information from our ArXiv database using a natural language query
    and a specific ArXiv ID. Allows us to learn more details about a specific paper."""
    # The docstring above doubles as the LLM-facing tool description; left verbatim.
    # The metadata filter restricts the similarity search to chunks of one paper.
    paper_filter = {"arxiv_id": arxiv_id}
    response = index.query(
        vector=encoder.embed_query(query),
        top_k=6,
        include_metadata=True,
        filter=paper_filter,
    )
    return format_rag_contexts(response["matches"])

@tool("rag_search")
def rag_search(query: str) -> str:
    """Finds specialist information on AI using a natural language query."""
    # The docstring above doubles as the LLM-facing tool description; left verbatim.
    # Unfiltered similarity search over the whole index, keeping the 2 best chunks.
    query_vector = encoder.embed_query(query)
    response = index.query(vector=query_vector, top_k=2, include_metadata=True)
    return format_rag_contexts(response["matches"])

@tool("final_answer")
def final_answer(
    introduction: str,
    research_steps: str,
    main_body: str,
    conclusion: str,
    sources: str
):
    """Returns a natural language response to the user in the form of a research
    report. There are several sections to this report, those are:
    - `introduction`: a short paragraph introducing the user's question and the
    topic we are researching.
    - `research_steps`: a few bullet points explaining the steps that were taken
    to research your report.
    - `main_body`: this is where the bulk of high quality and concise
    information that answers the user's question belongs. It is 3-4 paragraphs
    long in length.
    - `conclusion`: this is a short single paragraph conclusion providing a
    concise but sophisticated view on what was found.
    - `sources`: a bulletpoint list provided detailed sources for all information
    referenced during the research process
    """
    # The docstring above is the tool description (with per-argument guidance)
    # that the LLM sees, so it is kept verbatim.
    #
    # This tool exists for its argument schema. The graph ends right after this
    # node runs, and the printed report is assembled from the call's tool_input
    # by build_report, which already turns list-valued research_steps/sources
    # into bullet lists. The return value only becomes the step's log, so no
    # formatting is done here; nothing would ever read it.
    return ""

from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder

# System instructions for the oracle. These anti-loop rules are advisory only;
# run_oracle additionally enforces step limits and tool filtering in code.
system_prompt = """You are the oracle, the great AI decision maker.
Given the user's query you must decide what to do with it based on the
list of tools provided to you.

CRITICAL RULES:
1. NEVER use the same tool more than ONCE with the same or similar input
2. For queries about "news and weather":
   - Use web_search ONCE for news information
   - Use weather_agent ONCE for weather information
   - Then use final_answer to provide the response
3. If you see a tool has been used in the scratchpad, DO NOT use it again
4. Always progress toward final_answer - do not get stuck in loops
5. If you have gathered enough information, use final_answer immediately

Current available tools will be filtered to prevent overuse. Choose the most appropriate tool that hasn't been used yet."""

# Message order: rules, prior chat, the new query, then the rendered tool trace
# presented as an assistant turn through the {scratchpad} variable.
prompt_messages = [
    ("system", system_prompt),
    MessagesPlaceholder(variable_name="chat_history"),
    ("user", "{input}"),
    ("assistant", "scratchpad: {scratchpad}"),
]
prompt = ChatPromptTemplate.from_messages(prompt_messages)

from langchain_core.messages import ToolCall, ToolMessage
from langchain_google_genai import ChatGoogleGenerativeAI

# Gemini chat model the oracle uses to choose tools; temperature 0 keeps its
# tool selection as deterministic as the API allows.
llm = ChatGoogleGenerativeAI(
    temperature=0,
    model="gemini-1.5-flash",
    api_key=GEMINI_API_KEY,
)

# Define tools based on query type
def get_tools_for_query(query: str):
    """Choose which tools the oracle may call for ``query``.

    A query containing both "news" and "weather" (case-insensitive substring
    match) is narrowed to web search, weather and final answer. Every other
    query is offered the full tool set.
    """
    text = query.lower()
    wants_news_and_weather = "news" in text and "weather" in text
    if not wants_news_and_weather:
        return [rag_search_filter, rag_search, fetch_arxiv, web_search, weather_agent, final_answer]
    return [web_search, weather_agent, final_answer]

# define a function to transform intermediate_steps from list
# of AgentAction to scratchpad string
def create_scratchpad(intermediate_steps: list[AgentAction]):
    """Render the executed tool calls as text for the prompt's scratchpad.

    Steps whose ``log`` is still the "TBD" placeholder (chosen by the oracle but
    not yet run) are skipped. Each executed step becomes a "Tool/input/Output"
    entry, and entries are separated by ``---`` divider lines.
    """
    entries = [
        f"Tool: {step.tool}, input: {step.tool_input}\nOutput: {step.log}"
        for step in intermediate_steps
        if step.log != "TBD"
    ]
    return "\n---\n".join(entries)

def _executed_logs(intermediate_steps) -> str:
    """Join the outputs of executed steps (log != "TBD") with blank lines."""
    return "\n\n".join(step.log for step in intermediate_steps if step.log != "TBD")


def _forced_final_answer(intermediate_steps, *, introduction, research_steps,
                         conclusion, sources, main_body=None):
    """Build an oracle state update that jumps straight to final_answer.

    ``main_body`` defaults to the collected outputs of all executed steps.
    """
    if main_body is None:
        main_body = _executed_logs(intermediate_steps)
    action_out = AgentAction(
        tool="final_answer",
        tool_input={
            "introduction": introduction,
            "research_steps": research_steps,
            "main_body": main_body,
            "conclusion": conclusion,
            "sources": sources,
        },
        log="TBD",
    )
    return {"intermediate_steps": [action_out]}


def _weather_location(intermediate_steps):
    """Return the ``location`` argument of the first weather_agent step, if any."""
    for step in intermediate_steps:
        if step.tool == "weather_agent" and isinstance(step.tool_input, dict):
            location = step.tool_input.get("location")
            if location:
                return location
    return None


def run_oracle(state: dict):
    """Decide the agent's next action.

    Returns ``{"intermediate_steps": [AgentAction]}`` with one new step whose
    ``log`` is "TBD"; the router then sends it to the node named by ``tool``.
    Before asking the LLM, several guards may force the next step instead:
    - 5 or more trace entries force final_answer;
    - for "news and weather" queries, a completed weather lookup forces a news
      web_search for the same location, and once both have run, final_answer;
    - tools that already reached their usage limit are hidden from the LLM.
    """
    print("run_oracle")
    steps = state["intermediate_steps"]
    print(f"intermediate_steps: {len(steps)}")

    # Hard limit to prevent infinite recursion
    if len(steps) >= 5:
        print("Maximum steps reached, forcing final_answer")
        return _forced_final_answer(
            steps,
            introduction="Here's the available information based on the research conducted.",
            research_steps="Multiple research steps were taken to gather information.",
            conclusion="This concludes the research based on available data.",
            sources="Various tools and APIs were used to gather this information.",
        )

    # Track used tools. Every tool call appears twice in the trace (the oracle's
    # "TBD" decision plus run_tool's executed result), so a count of 2 means the
    # tool has already run once; the `< 2` limit below therefore allows one
    # call per tool, matching the system prompt's "never use a tool twice" rule.
    used_tools = [step.tool for step in steps]
    tool_counts = {}
    for used_name in used_tools:
        tool_counts[used_name] = tool_counts.get(used_name, 0) + 1

    print(f"Used tools: {used_tools}")
    print(f"Tool counts: {tool_counts}")

    for used_name, count in tool_counts.items():
        if count >= 2:
            print(f"Tool {used_name} used {count} times, limiting further use")

    # Check if this is a news and weather query
    query_lower = state["input"].lower()
    is_news_weather_query = ("news" in query_lower and "weather" in query_lower)

    if is_news_weather_query:
        # Reuse the location the LLM already passed to weather_agent, so the
        # forced news search and the report follow the user's location instead
        # of a hardcoded city.
        location = _weather_location(steps)

        # Force progression: if weather_agent used but no web_search, force web_search
        if "weather_agent" in used_tools and "web_search" not in used_tools:
            print("Weather tool used, forcing web_search for news")
            news_query = f"latest news {location} today" if location else state["input"]
            action_out = AgentAction(
                tool="web_search",
                tool_input={"query": news_query},
                log="TBD"
            )
            return {"intermediate_steps": [action_out]}

        # If both tools used, go to final answer
        if "weather_agent" in used_tools and "web_search" in used_tools:
            print("Both news and weather tools used, forcing final_answer")
            place = location or "the requested location"
            return _forced_final_answer(
                steps,
                introduction=f"Here's the requested news and weather information for {place}.",
                research_steps="Used web search for news and weather agent for weather data.",
                conclusion=f"This concludes the news and weather update for {place}.",
                sources="Web search results and OpenWeatherMap API",
            )

    # Offer the LLM only the tools suited to this query that are still under limit.
    available_tools = get_tools_for_query(state["input"])
    filtered_tools = [t for t in available_tools if tool_counts.get(t.name, 0) < 2]

    # If no tools (other than final_answer) remain, force final_answer
    if not filtered_tools or (len(filtered_tools) == 1 and filtered_tools[0].name == "final_answer"):
        print("No more tools available, forcing final_answer")
        return _forced_final_answer(
            steps,
            introduction="Here's the information gathered so far.",
            research_steps="Research was conducted using available tools.",
            conclusion="This concludes the available information.",
            sources="Multiple sources were consulted.",
        )

    print(f"Available tools after filtering: {[t.name for t in filtered_tools]}")

    # Create oracle with filtered tools; tool_choice="any" requires a tool call.
    oracle = (
        {
            "input": lambda x: x["input"],
            "chat_history": lambda x: x["chat_history"],
            "scratchpad": lambda x: create_scratchpad(
                intermediate_steps=x["intermediate_steps"]
            ),
        }
        | prompt
        | llm.bind_tools(filtered_tools, tool_choice="any")
    )

    out = oracle.invoke(state)

    if not out.tool_calls:
        print("No tool calls returned, forcing final_answer")
        return _forced_final_answer(
            steps,
            introduction="Unable to process the query properly.",
            research_steps="No specific research steps were taken.",
            main_body="The system encountered an issue processing your request.",
            conclusion="Please try rephrasing your query.",
            sources="None",
        )

    # Only the first tool call is acted on.
    first_call = out.tool_calls[0]
    tool_name = first_call["name"]
    tool_args = first_call["args"]

    # Additional check: if tool was used too many times, force final_answer
    if tool_counts.get(tool_name, 0) >= 2:
        print(f"Tool {tool_name} already used too many times, forcing final_answer")
        return _forced_final_answer(
            steps,
            introduction="Research complete based on available information.",
            research_steps="Multiple tools were used to gather information.",
            conclusion="This concludes the research.",
            sources="Various sources were consulted.",
        )

    action_out = AgentAction(
        tool=tool_name,
        tool_input=tool_args,
        log="TBD"
    )

    return {"intermediate_steps": [action_out]}

def router(state: dict):
    """Name the graph node that should run next.

    Follows the ``tool`` of the most recent step. If ``intermediate_steps`` is
    not a non-empty list, it falls back to "final_answer".
    """
    steps = state["intermediate_steps"]
    if not (isinstance(steps, list) and steps):
        # if we output bad format go to final answer
        print("Router invalid format, going to final_answer")
        return "final_answer"
    next_node = steps[-1].tool
    print(f"Router directing to: {next_node}")
    return next_node

# Node/tool name -> tool object. Keys mirror the names given to each @tool(...)
# decorator and to the graph nodes; run_tool dispatches through this mapping.
tool_str_to_func = dict(
    rag_search_filter=rag_search_filter,
    rag_search=rag_search,
    fetch_arxiv=fetch_arxiv,
    web_search=web_search,
    weather_agent=weather_agent,
    final_answer=final_answer,
)

def run_tool(state: dict):
    """Execute the tool chosen by the oracle's most recent step.

    Reads the tool name and arguments from the last AgentAction, invokes the
    matching entry of ``tool_str_to_func``, and returns a new AgentAction whose
    ``log`` holds the stringified output. Any exception is caught and recorded as
    "Error: <message>", so a failing tool does not abort the graph run.
    """
    pending = state["intermediate_steps"][-1]
    tool_name, tool_args = pending.tool, pending.tool_input
    print(f"{tool_name}.invoke(input={tool_args})")

    try:
        result_log = str(tool_str_to_func[tool_name].invoke(input=tool_args))
    except Exception as e:
        print(f"Error running tool {tool_name}: {e}")
        result_log = f"Error: {str(e)}"

    executed = AgentAction(tool=tool_name, tool_input=tool_args, log=result_log)
    return {"intermediate_steps": [executed]}

from langgraph.graph import StateGraph, END

graph = StateGraph(AgentState)

# The oracle decides; every tool node shares the same executor, run_tool.
graph.add_node("oracle", run_oracle)
for node_name in (
    "rag_search_filter",
    "rag_search",
    "fetch_arxiv",
    "web_search",
    "final_answer",
    "weather_agent",
):
    graph.add_node(node_name, run_tool)

graph.set_entry_point("oracle")

# After the oracle runs, router() returns the name of the node to visit next.
graph.add_conditional_edges(source="oracle", path=router)

# Every tool except final_answer loops back to the oracle; final_answer ends the run.
for node_name in tool_str_to_func.keys():
    if node_name == "final_answer":
        continue
    graph.add_edge(node_name, "oracle")
graph.add_edge("final_answer", END)

runnable = graph.compile()

def build_report(output: dict):
    """Format final_answer arguments as a plain-text report.

    ``output`` must have the keys introduction, research_steps, main_body,
    conclusion and sources. ``research_steps`` and ``sources`` may be lists, in
    which case each item becomes a "- " bullet line. Each section is a title, a
    dashed underline of the same width, and its text. Sections are separated by
    a blank line, and the report starts and ends with a newline.
    """
    def as_bullets(value):
        # Only exact lists are bulleted; strings (or anything else) pass through.
        return "\n".join(f"- {item}" for item in value) if type(value) is list else value

    research_steps = as_bullets(output["research_steps"])
    sources = as_bullets(output["sources"])
    sections = (
        ("INTRODUCTION", output["introduction"]),
        ("RESEARCH STEPS", research_steps),
        ("REPORT", output["main_body"]),
        ("CONCLUSION", output["conclusion"]),
        ("SOURCES", sources),
    )
    body = "\n\n".join(f"{title}\n{'-' * len(title)}\n{text}" for title, text in sections)
    return f"\n{body}\n"

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Downloading data:   0%|          | 0.00/253M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/209760 [00:00<?, ? examples/s]

Index stats: {'dimension': 768,
 'index_fullness': 0.0,
 'namespaces': {'': {'vector_count': 656}},
 'total_vector_count': 656}
Index already contains 656 vectors. Skipping embedding process.
Using existing embeddings in the index.



For example, replace imports like: `from langchain_core.pydantic_v1 import BaseModel`
with: `from pydantic import BaseModel`
or the v1 compatibility namespace if you are working in a code base that has not been fully upgraded to pydantic 2 yet. 	from pydantic.v1 import BaseModel

  from langgraph.pregel import Channel, Pregel


In [None]:
# Run the agent graph end to end. The only edge to END leaves final_answer, so
# the last trace entry is that call and its arguments hold the report sections.
out = runnable.invoke({"input": "send me news and weather in Chennai", "chat_history": []})
final_step = out["intermediate_steps"][-1]
print(build_report(output=final_step.tool_input))

run_oracle
intermediate_steps: 0
Used tools: []
Tool counts: {}
Available tools after filtering: ['web_search', 'weather_agent', 'final_answer']
Router directing to: weather_agent
weather_agent.invoke(input={'location': 'Chennai'})
run_oracle
intermediate_steps: 2
Used tools: ['weather_agent', 'weather_agent']
Tool counts: {'weather_agent': 2}
Tool weather_agent used 2 times, limiting further use
Weather tool used, forcing web_search for news
Router directing to: web_search
web_search.invoke(input={'query': 'latest news Chennai today'})
run_oracle
intermediate_steps: 4
Used tools: ['weather_agent', 'weather_agent', 'web_search', 'web_search']
Tool counts: {'weather_agent': 2, 'web_search': 2}
Tool weather_agent used 2 times, limiting further use
Tool web_search used 2 times, limiting further use
Both news and weather tools used, forcing final_answer
Router directing to: final_answer
final_answer.invoke(input={'introduction': "Here's the requested news and weather information for Chenna

In [None]:
# Same end-to-end run for a research-style query; the final trace entry is the
# final_answer call whose arguments are rendered as the report.
out = runnable.invoke({"input": "tell me about RAG", "chat_history": []})
final_step = out["intermediate_steps"][-1]
print(build_report(output=final_step.tool_input))

run_oracle
intermediate_steps: 0
Used tools: []
Tool counts: {}
Available tools after filtering: ['rag_search_filter', 'rag_search', 'fetch_arxiv', 'web_search', 'weather_agent', 'final_answer']
Router directing to: rag_search
rag_search.invoke(input={'query': 'What is RAG?'})
run_oracle
intermediate_steps: 2
Used tools: ['rag_search', 'rag_search']
Tool counts: {'rag_search': 2}
Tool rag_search used 2 times, limiting further use
Available tools after filtering: ['rag_search_filter', 'fetch_arxiv', 'web_search', 'weather_agent', 'final_answer']
Router directing to: rag_search_filter
rag_search_filter.invoke(input={'arxiv_id': '2311.04072', 'query': 'What is RAG?'})
run_oracle
intermediate_steps: 4
Used tools: ['rag_search', 'rag_search', 'rag_search_filter', 'rag_search_filter']
Tool counts: {'rag_search': 2, 'rag_search_filter': 2}
Tool rag_search used 2 times, limiting further use
Tool rag_search_filter used 2 times, limiting further use
Available tools after filtering: ['fetch_arxiv