## Building an ArXiv Research Agent with Pinecone and LlamaIndex Workflows

Follow along in this notebook to learn agentic RAG with Pinecone and LlamaIndex workflows.

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/pinecone-io/examples/blob/master/learn/generation/llama-index/llama-index-research-agent-condensed.ipynb)

In [37]:
!pip install -qU \
    datasets==2.19.1 \
    llama-index-core==0.11.9 \
    llama-index-llms-openai \
    llama-index-utils-workflow==0.2.1 \
    llama-index-embeddings-openai \
    pinecone-notebooks \
    pinecone

# Knowledge Base Setup

Our agent runs against a knowledge base, which requires building a Pinecone index.

If you only want to test the structure of the LlamaIndex workflow, you can skip this step and replace the `rag_search` and `rag_search_filter` tools with placeholders.

For full functionality you do need to run this section, but we'll keep it quick.
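
If you take the placeholder route, a minimal sketch could look like the following. The canned text and the `0000.00000` ID are made up for illustration; the function names and docstrings mirror the tools defined later in this notebook so they can be swapped in without touching the workflow.

```python
async def rag_search(query: str) -> str:
    """Finds specialist information on AI using a natural language query."""
    # Placeholder: returns canned text instead of querying Pinecone.
    return (
        "Title: Placeholder paper\n"
        f"Content: No index configured; the query was '{query}'.\n"
        "ArXiv ID: 0000.00000\n"
        "Related Papers: []\n"
    )


async def rag_search_filter(query: str, arxiv_id: str) -> str:
    """Finds information from our ArXiv database using a natural language query
    and a specific ArXiv ID. Allows us to learn more details about a specific paper."""
    # Placeholder: echoes the requested ID so the agent still sees a plausible context.
    return (
        "Title: Placeholder paper\n"
        f"Content: No index configured; the query was '{query}'.\n"
        f"ArXiv ID: {arxiv_id}\n"
        "Related Papers: []\n"
    )

# In a notebook cell you could check the output with: await rag_search("test")
```

With these in place, skip the cells that define the real `rag_search` and `rag_search_filter` further down.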

## Before you begin:

You'll need an OpenAI API key and a Pinecone API key to complete this tutorial!

## Download a Dataset

The first thing an agent using RAG needs is a source of knowledge. We will use v2 of the AI ArXiv dataset, available on Hugging Face Datasets at `jamescalam/ai-arxiv2-semantic-chunks`.

Note: we're using the prechunked dataset. For the raw version see `jamescalam/ai-arxiv2`.

In [38]:
from datasets import load_dataset

dataset = load_dataset("jamescalam/ai-arxiv2-semantic-chunks", split="train")
dataset

Dataset({
    features: ['id', 'title', 'content', 'prechunk_id', 'postchunk_id', 'arxiv_id', 'references'],
    num_rows: 209760
})

In [39]:
dataset[0]

{'id': '2401.04088#0',
 'title': 'Mixtral of Experts',
 'content': '4 2 0 2 n a J 8 ] G L . s c [ 1 v 8 8 0 4 0 . 1 0 4 2 : v i X r a # Mixtral of Experts Albert Q. Jiang, Alexandre Sablayrolles, Antoine Roux, Arthur Mensch, Blanche Savary, Chris Bamford, Devendra Singh Chaplot, Diego de las Casas, Emma Bou Hanna, Florian Bressand, Gianna Lengyel, Guillaume Bour, Guillaume Lample, Lélio Renard Lavaud, Lucile Saulnier, Marie-Anne Lachaux, Pierre Stock, Sandeep Subramanian, Sophia Yang, Szymon Antoniak, Teven Le Scao, Théophile Gervet, Thibaut Lavril, Thomas Wang, Timothée Lacroix, William El Sayed Abstract We introduce Mixtral 8x7B, a Sparse Mixture of Experts (SMoE) language model. Mixtral has the same architecture as Mistral 7B, with the difference that each layer is composed of 8 feedforward blocks (i.e. experts). For every token, at each layer, a router network selects two experts to process the current state and combine their outputs. Even though each token only sees two experts

## Construct Knowledge Base

First, initialize the encoder model. For this we need an [OpenAI API key](https://platform.openai.com/api-keys).

In [40]:
import os
from getpass import getpass
from llama_index.embeddings.openai import OpenAIEmbedding

os.environ["OPENAI_API_KEY"] = os.getenv("OPENAI_API_KEY") or getpass("OpenAI API key: ")

# batch size parameter aligns with Pinecone upsertion
embed_model = OpenAIEmbedding(model="text-embedding-3-small", embed_batch_size=128)


Initialize our connection to Pinecone:

In [41]:
from pinecone import Pinecone
from pinecone_notebooks.colab import Authenticate

Authenticate()


In [42]:
# initialize connection to pinecone (get API key at app.pinecone.io)
api_key = os.environ.get('PINECONE_API_KEY')

# configure client
pc = Pinecone(api_key=api_key)

In [43]:
from pinecone import ServerlessSpec

spec = ServerlessSpec(
    cloud="aws", region="us-east-1"  # us-east-1
)

In [44]:
dims = len(embed_model.get_text_embedding("some random text"))
dims


1536

In [45]:
import time

index_name = "gpt-4o-research-agent"

# check if index already exists (it shouldn't if this is first time)
if index_name not in pc.list_indexes().names():
    # if does not exist, create index
    pc.create_index(
        index_name,
        dimension=dims,  # dimensionality of embed 3
        metric='dotproduct',
        spec=spec
    )
    # wait for index to be initialized
    while not pc.describe_index(index_name).status['ready']:
        time.sleep(1)

# connect to index
index = pc.Index(index_name)
time.sleep(1)
# view index stats
index.describe_index_stats()

{'dimension': 1536,
 'index_fullness': 0.0,
 'namespaces': {'': {'vector_count': 10000}},
 'total_vector_count': 10000}

Populate the knowledge base:

In [46]:
from tqdm.auto import tqdm

# easier to work with dataset as pandas dataframe
# take first 10k instances as example. Embed more at your convenience
data = dataset.to_pandas().iloc[:10000]

batch_size = 128

for i in tqdm(range(0, len(data), batch_size)):
    i_end = min(len(data), i+batch_size)
    batch = data[i:i_end].to_dict(orient="records")
    # get batch of data
    metadata = [{
        "title": r["title"],
        "content": r["content"],
        "arxiv_id": r["arxiv_id"],
        "references": r["references"].tolist()
    } for r in batch]
    # generate unique ids for each chunk
    ids = [r["id"] for r in batch]
    # get text content to embed
    content = [r["content"] for r in batch]
    # embed text
    embeds = embed_model.get_text_embedding_batch(content)
    # add to Pinecone
    index.upsert(vectors=zip(ids, embeds, metadata))

  0%|          | 0/79 [00:00<?, ?it/s]
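
The progress bar total of 79 follows from the batching arithmetic: 10,000 rows in batches of 128 need `ceil(10000 / 128)` batches, with a short final batch. A small sketch of the same slicing logic, using a hypothetical `batch_ranges` helper that is not part of the notebook:

```python
import math


def batch_ranges(n_rows: int, batch_size: int):
    """Yield (start, end) index pairs covering n_rows in steps of batch_size."""
    for start in range(0, n_rows, batch_size):
        # the last batch is clipped to the end of the data
        yield start, min(n_rows, start + batch_size)


ranges = list(batch_ranges(10_000, 128))
print(len(ranges))                 # 79
print(math.ceil(10_000 / 128))     # 79
print(ranges[-1])                  # (9984, 10000), i.e. a final batch of 16 rows
```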

# Agent Components

## Tools

We define each tool as a separate function. When integrated into our workflow, all of them are executed by the same `run_tool` step, which we will define later.

For now, let's define the functions that our agent will have access to.

In [47]:
import requests
import re

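# our regex: the abstract sits inside a <blockquote class="abstract mathjax">
# element on the arXiv abstract page; anchoring on the closing tag gives the
# lazy group something to stop at
abstract_pattern = re.compile(
    r'<blockquote class="abstract mathjax">\s*<span class="descriptor">Abstract:</span>\s*(.*?)\s*</blockquote>',
    re.DOTALL
)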

async def fetch_arxiv(arxiv_id: str):
    """Gets the abstract from an ArXiv paper given the arxiv ID. Useful for
    finding high-level context about a specific paper."""
    print(">>> fetch_arxiv")
    # get paper page in html
    res = requests.get(
        f"https://export.arxiv.org/abs/{arxiv_id}"
    )
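    # search html for abstract
    re_match = abstract_pattern.search(res.text)
    # return abstract text, or a readable message if the pattern did not match
    if re_match is None:
        return f"No abstract found for ArXiv ID {arxiv_id}"
    return re_match.group(1)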
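
To see how this kind of extraction behaves without making a network request, here is a small sketch run against a hand-written HTML snippet. The snippet and the `extract_abstract` helper are illustrative assumptions, as is the page structure (a `blockquote` with class `abstract mathjax`). Note that a lazy group such as `(.*?)` needs something after it to anchor against, here the closing tag; otherwise it matches the empty string.

```python
import re

# Hand-written sample, assumed to resemble the relevant part of an arXiv abstract page.
sample_html = """
<blockquote class="abstract mathjax">
  <span class="descriptor">Abstract:</span>
  We introduce Mixtral 8x7B, a Sparse Mixture of Experts (SMoE) language model.
</blockquote>
"""

pattern = re.compile(
    r'<blockquote class="abstract mathjax">\s*'
    r'<span class="descriptor">Abstract:</span>\s*(.*?)\s*</blockquote>',
    re.DOTALL,
)


def extract_abstract(html: str) -> str:
    """Return the abstract text, or a fallback message when nothing matches."""
    match = pattern.search(html)
    return match.group(1) if match else "Abstract not found."


print(extract_abstract(sample_html))
print(extract_abstract("<html></html>"))
```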

## RAG Search

We provide two RAG-focused tools for our agent. `rag_search` lets the agent run a simple RAG search across all indexed research papers. `rag_search_filter` also searches, but only within a specific paper, selected via the `arxiv_id` parameter.

We also define the `format_rag_contexts` function to transform our Pinecone results from a JSON object into a readable plaintext format.

In [48]:
def format_rag_contexts(matches: list):
    contexts = []
    for x in matches:
        text = (
            f"Title: {x['metadata']['title']}\n"
            f"Content: {x['metadata']['content']}\n"
            f"ArXiv ID: {x['metadata']['arxiv_id']}\n"
            f"Related Papers: {x['metadata']['references']}\n"
        )
        contexts.append(text)
    context_str = "\n---\n".join(contexts)
    return context_str

async def rag_search_filter(query: str, arxiv_id: str):
    """Finds information from our ArXiv database using a natural language query
    and a specific ArXiv ID. Allows us to learn more details about a specific paper."""
    print(">>> rag_search_filter")
    xq = await embed_model.aget_text_embedding(query)
    xc = index.query(vector=xq, top_k=6, include_metadata=True, filter={"arxiv_id": arxiv_id})
    context_str = format_rag_contexts(xc["matches"])
    return context_str

async def rag_search(query: str):
    """Finds specialist information on AI using a natural language query."""
    print(">>> rag_search")
    xq = await embed_model.aget_text_embedding(query)
    xc = index.query(vector=xq, top_k=2, include_metadata=True)
    context_str = format_rag_contexts(xc["matches"])
    return context_str
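
To make the formatting and the filter concrete, here is a sketch that runs offline on hypothetical matches shaped like the `matches` list of a Pinecone query response. The scores and the second record are invented, and `filter_by_arxiv_id` is only a local stand-in for the server-side metadata filter `{"arxiv_id": arxiv_id}`.

```python
def format_rag_contexts(matches: list) -> str:
    # Same layout as the notebook's formatter: one text block per match.
    contexts = [
        (
            f"Title: {m['metadata']['title']}\n"
            f"Content: {m['metadata']['content']}\n"
            f"ArXiv ID: {m['metadata']['arxiv_id']}\n"
            f"Related Papers: {m['metadata']['references']}\n"
        )
        for m in matches
    ]
    return "\n---\n".join(contexts)


# Hypothetical matches for illustration only.
mock_matches = [
    {"id": "2401.04088#0", "score": 0.62, "metadata": {
        "title": "Mixtral of Experts",
        "content": "We introduce Mixtral 8x7B, a Sparse Mixture of Experts (SMoE) language model.",
        "arxiv_id": "2401.04088",
        "references": [],
    }},
    {"id": "0000.00000#0", "score": 0.41, "metadata": {
        "title": "Placeholder paper",
        "content": "Placeholder content.",
        "arxiv_id": "0000.00000",
        "references": [],
    }},
]


def filter_by_arxiv_id(matches: list, arxiv_id: str) -> list:
    """Keep only the matches whose metadata carries the given ArXiv ID."""
    return [m for m in matches if m["metadata"]["arxiv_id"] == arxiv_id]


# All matches, as rag_search would format them
print(format_rag_contexts(mock_matches))
# Only one paper, as rag_search_filter would format it
print(format_rag_contexts(filter_by_arxiv_id(mock_matches, "2401.04088")))
```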

In [49]:
async def final_answer(
    introduction: str,
    research_steps: str,
    main_body: str,
    conclusion: str,
    sources: str
):
    """Returns a natural language response to the user in the form of a research
    report. There are several sections to this report, those are:
    - `introduction`: a short paragraph introducing the user's question and the
    topic we are researching.
    - `research_steps`: a few bullet points explaining the steps that were taken
    to research your report.
    - `main_body`: this is where the bulk of high quality and concise
    information that answers the user's question belongs. It is 3-4 paragraphs
    long.
    - `conclusion`: this is a short single paragraph conclusion providing a
    concise but sophisticated view on what was found.
    - `sources`: a bullet-point list providing detailed sources for all
    information referenced during the research process
    """
    print(">>> final_answer")
    if type(research_steps) is list:
        research_steps = "\n".join([f"- {r}" for r in research_steps])
    if type(sources) is list:
        sources = "\n".join([f"- {s}" for s in sources])
    return ""

## Oracle LLM

Our prompt for the Oracle emphasizes its decision-making ability. Unlike a prompt template, it has no placeholders: the workflow adds the system prompt, the user `input`, and the chat history to memory at run time.

In [50]:
system_prompt = """You are the oracle, the great AI decision maker.
Given the user's query you must decide what to do with it based on the
list of tools provided to you.

If you see that a tool has been used (in the scratchpad) with a particular
query, do NOT use that same tool with the same query again. Also, do NOT use
any tool more than twice (ie, if the tool appears in the scratchpad twice, do
not use it again).

You should aim to collect information from a diverse range of sources before
providing the answer to the user. Once you have collected plenty of information
to answer the user's question (stored in the scratchpad) use the final_answer
tool."""

The oracle agent will be provided the tools we previously built.

In [51]:
from llama_index.core.tools import FunctionTool

tools = [
    FunctionTool.from_defaults(async_fn=fetch_arxiv),
    FunctionTool.from_defaults(async_fn=rag_search_filter),
    FunctionTool.from_defaults(async_fn=rag_search),
    FunctionTool.from_defaults(async_fn=final_answer),
]
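
`FunctionTool.from_defaults` builds each tool's name and description from the function's name, signature, and docstring, which is why the docstrings above are written as instructions for the LLM. The sketch below uses only the standard library to preview those three pieces; `describe_tool` is a hypothetical helper, not part of LlamaIndex, and the exact description string the library produces may differ.

```python
import inspect


async def rag_search(query: str):
    """Finds specialist information on AI using a natural language query."""
    return ""


def describe_tool(fn) -> dict:
    """Collect the pieces of a function that a tool wrapper can expose to an LLM."""
    return {
        "name": fn.__name__,
        "signature": str(inspect.signature(fn)),
        "description": inspect.getdoc(fn),
    }


print(describe_tool(rag_search))
```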

In [52]:
import os
from llama_index.llms.openai import OpenAI

llm = OpenAI(
    model="gpt-4o",
    additional_kwargs={"tool_choice": "required"}
)

## Events

We need to create a few events for our workflow. LlamaIndex comes with a few predefined event types, two of which we will use (`StartEvent` and `StopEvent`). We also need to define two custom event types:

* `InputEvent` to handle new messages and prepare chat history.

* `ToolCallEvent` to trigger tool calls.

In [53]:
from llama_index.core.llms import ChatMessage
from llama_index.core.tools import ToolSelection, ToolOutput
from llama_index.core.workflow import Event


class InputEvent(Event):
    input: list[ChatMessage]


class ToolCallEvent(Event):
    id: str
    name: str
    params: dict

Now we build the workflow. Workflows consist of a single `Workflow` class with multiple `steps`. Each step is like a compute/execution step in our agentic flow.

We control which step is triggered by using different `Event` types. Each step consumes a different type of event, like `InputEvent` or `ToolCallEvent`. Additionally, our workflow begins and ends with the `StartEvent` and `StopEvent` events respectively.
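
The event-driven control flow can be illustrated with a toy dispatcher in plain Python. This is not how LlamaIndex implements workflows; it only sketches the idea that the type of the event returned by one step decides which step runs next. All class and function names here are hypothetical.

```python
import asyncio
from dataclasses import dataclass


@dataclass
class Start:
    text: str


@dataclass
class Shout:
    text: str


@dataclass
class Stop:
    result: str


async def prepare(ev: Start) -> Shout:
    # first step: consumes Start, emits Shout
    return Shout(text=ev.text.strip())


async def shout(ev: Shout) -> Stop:
    # second step: consumes Shout, emits Stop, which ends the loop
    return Stop(result=ev.text.upper())


# Each event type maps to the step that consumes it.
HANDLERS = {Start: prepare, Shout: shout}


async def run_toy_workflow(text: str) -> str:
    ev = Start(text=text)
    while not isinstance(ev, Stop):
        ev = await HANDLERS[type(ev)](ev)
    return ev.result

# In a notebook cell: await run_toy_workflow("  hello  ")
```

In the real workflow, the `@step` decorator reads the type annotations of each method to build the equivalent of this mapping.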

In [54]:
from typing import Any

from llama_index.core.workflow import Workflow, StartEvent, StopEvent, step
from llama_index.core.memory import ChatMemoryBuffer
from llama_index.core.llms import MessageRole

class ResearchAgent(Workflow):
    def __init__(
        self,
        *args: Any,
        oracle: OpenAI,
        tools: list[FunctionTool],
        timeout: int = 20,
    ):
        super().__init__(*args)
        self._timeout = timeout
        self.oracle = oracle
        self.tools = tools
        self.get_tool = {tool.metadata.get_name(): tool for tool in self.tools}
        # initialize chat history/memory with system prompt
        self.sys_msg = ChatMessage(
            role=MessageRole.SYSTEM,
            content=system_prompt
        )
        self.memory = ChatMemoryBuffer.from_defaults(llm=OpenAI())

    @step
    async def prepare_chat_history(self, ev: StartEvent) -> InputEvent:
        # clear memory
        self.memory = ChatMemoryBuffer.from_defaults(llm=OpenAI())
        self.memory.put(message=self.sys_msg)
        # get user input
        user_input = ev.input
        user_msg = ChatMessage(role="user", content=user_input)
        self.memory.put(message=user_msg)
        # get chat history
        chat_history = self.memory.get()
        # return input event
        return InputEvent(input=chat_history)

    @step
    async def handle_llm_input(self, ev: InputEvent) -> ToolCallEvent | StopEvent:
        chat_history = ev.input
        # get oracle response
        response = await self.oracle.achat_with_tools(
            tools=self.tools,
            chat_history=chat_history,
            tool_choice="required",
        )
        # add response to chat history / memory
        self.memory.put(message=response.message)
        # get tool calls
        tool_calls = self.oracle.get_tool_calls_from_response(
            response
        )
        # if final_answer tool used we return to the user with the StopEvent
        if tool_calls[-1].tool_name == "final_answer":
            return StopEvent(result={"response": tool_calls[-1].tool_kwargs})
        else:
            # return tool call event
            return ToolCallEvent(
                id=tool_calls[-1].tool_id,
                name=tool_calls[-1].tool_name,
                params=tool_calls[-1].tool_kwargs,
            )

    @step
    async def run_tool(self, ev: ToolCallEvent) -> InputEvent:
        tool_name = ev.name
        additional_kwargs = {
            "tool_call_id": ev.id,
            "name": tool_name
        }
        # get chosen tool
        tool = self.get_tool.get(tool_name)
        if not tool:
            tool_msg = ChatMessage(
                role="tool",
                content=f"Tool {tool_name} not found",
                additional_kwargs=additional_kwargs
            )
        else:
            # now call tool
            tool_output = await tool.acall(**ev.params)
            tool_msg = ChatMessage(
                role="tool",
                content=tool_output.content,
                additional_kwargs=additional_kwargs
            )
        self.memory.put(message=tool_msg)
        chat_history = self.memory.get()
        return InputEvent(input=chat_history)



In [55]:
from llama_index.utils.workflow import draw_all_possible_flows

draw_all_possible_flows(ResearchAgent, filename="research_agent.html")

research_agent.html


Initialize the workflow:

In [61]:
agent = ResearchAgent(
    oracle=llm,
    tools=tools,
    timeout=60,
)

In [62]:
res = await agent.run(input="tell me about AI")
res["response"]

>>> rag_search
>>> fetch_arxiv
>>> fetch_arxiv
>>> rag_search_filter


{'introduction': 'Artificial Intelligence (AI) is a broad and dynamic field that encompasses the development of computer systems capable of performing tasks that typically require human intelligence. These tasks include learning, reasoning, problem-solving, perception, language understanding, and interaction. AI has evolved significantly over the years, leading to the creation of sophisticated models and applications that impact various domains.',
 'research_steps': '1. Conducted a general search on AI to gather initial information.\n2. Retrieved abstracts and detailed content from specific ArXiv papers related to AI.\n3. Analyzed the content to extract relevant information about AI and its applications.',
 'main_body': "Artificial Intelligence (AI) has its roots in the mid-20th century, with the term 'artificial intelligence' being coined by John McCarthy in 1956. Since then, AI has grown to encompass a wide range of subfields, including machine learning, natural language processing, 
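
The response is a plain dictionary keyed by the `final_answer` parameters, so it can be rendered into a readable report with a few lines of code. The following is a sketch with a hypothetical `build_report` helper; the section labels are arbitrary choices, and list values are turned into bullet points in case the LLM returns lists instead of strings.

```python
def build_report(output: dict) -> str:
    """Render the final_answer arguments as a plaintext report."""

    def as_text(value) -> str:
        # the LLM may return a list for bullet-style sections
        if isinstance(value, list):
            return "\n".join(f"- {item}" for item in value)
        return value

    sections = [
        ("INTRODUCTION", output["introduction"]),
        ("RESEARCH STEPS", output["research_steps"]),
        ("REPORT", output["main_body"]),
        ("CONCLUSION", output["conclusion"]),
        ("SOURCES", output["sources"]),
    ]
    return "\n\n".join(f"{title}\n{as_text(body)}" for title, body in sections)


# Made-up values, only to show the shape of the output.
example = {
    "introduction": "Intro text.",
    "research_steps": ["Searched the index.", "Fetched an abstract."],
    "main_body": "Body text.",
    "conclusion": "Conclusion text.",
    "sources": ["Mixtral of Experts (2401.04088)"],
}
print(build_report(example))
```

In the notebook you would call it as `print(build_report(res["response"]))`.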

---

Let's test with async!

In [63]:
# use a fresh agent instance; `query` avoids shadowing the built-in `input`
agent1 = ResearchAgent(oracle=llm, tools=tools, timeout=60)
query = "Tell me about MAMBA models and compare them to the Transformer architecture"

res = await agent1.run(input=query)
res["response"]


>>> rag_search
>>> fetch_arxiv
>>> rag_search
>>> fetch_arxiv
>>> rag_search_filter


{'introduction': 'The MAMBA model and the Transformer architecture are two significant advancements in the field of artificial intelligence, particularly in sequence modeling and natural language processing. Both have their unique characteristics and applications, making them suitable for different tasks.',
 'research_steps': '1. Conducted a search for information on MAMBA models in AI.\n2. Retrieved and reviewed the abstract and details of the MAMBA model from ArXiv.\n3. Conducted a search for information on Transformer architecture in AI.\n4. Retrieved and reviewed the abstract and details of the Transformer architecture from ArXiv.\n5. Compiled and compared the information from both models.',
 'main_body': "The MAMBA model, short for 'Linear-Time Sequence Modeling with Selective State Spaces,' is a recent development in AI that aims to achieve Transformer-quality performance in a more efficient manner. According to the research, MAMBA models have shown to exceed the performance of v
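
Because `ResearchAgent` keeps its chat memory on the instance (`self.memory`), running several queries at once is safer with one agent instance per query. The sketch below shows that pattern with `asyncio.gather`. `StubAgent` is a hypothetical stand-in that sleeps briefly instead of calling an LLM, so the example runs without API keys.

```python
import asyncio


class StubAgent:
    """Stand-in for ResearchAgent; returns a canned response after a short delay."""

    async def run(self, input: str) -> dict:
        await asyncio.sleep(0.01)
        return {"response": {"introduction": f"stub answer for: {input}"}}


async def run_many(queries: list[str]) -> list[dict]:
    # one agent per query so no two runs share the same memory
    agents = [StubAgent() for _ in queries]
    return await asyncio.gather(
        *(agent.run(input=q) for agent, q in zip(agents, queries))
    )

# In a notebook cell: results = await run_many(["tell me about AI", "what is RAG?"])
```

`asyncio.gather` returns results in the same order as the queries, regardless of which run finishes first.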

## Clean up


Run the line below if you'd like to delete your index.

In [None]:
#pc.delete_index(index_name)
