# `chatXiv` - Chat with arXiv papers

In [1]:
import os
import json
import arxiv
import requests
import fitz
from pymupdf_rag import to_markdown
from typing import List, Tuple

import chromadb
from llama_index.core import Document, VectorStoreIndex, PromptTemplate
from llama_index.core.schema import NodeWithScore
from llama_index.core.storage.docstore import SimpleDocumentStore
from llama_index.vector_stores.chroma import ChromaVectorStore
from llama_index.core.ingestion import IngestionPipeline
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.core.node_parser import MarkdownNodeParser
import chromadb.utils.embedding_functions as embedding_functions
from chat_utils import ChatLLM, ChatSession

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Make the OpenAI API key available to the cells below through the environment.
# A key that is already exported (shell, .env loader, CI secret) is kept as-is;
# the placeholder is only a fallback so a real key is never overwritten and
# never has to be pasted into the notebook.
os.environ.setdefault('OPENAI_API_KEY', "YOUR API KEY")

# Naive Q&A over a PDF Document

## Ingestion

### Load Data

In [3]:
# Single shared arXiv API client, reused by every download in this notebook.
arxiv_client = arxiv.Client()

def download_pdf(paper_id: str, dirpath: str = './papers'):
    '''Downloads a paper from arXiv given its ID.

    Arguments:
        paper_id (str): The arXiv identifier of the paper, e.g. "2312.00752".
        dirpath (str): Directory the PDF is saved into. Created if it does not exist.

    Returns:
        tuple: `(pdf_path, paper)` where `pdf_path` is whatever `paper.download_pdf`
               returns and `paper` is the search result the PDF was downloaded from.

    Raises:
        ValueError: If arXiv returns no result for `paper_id`.
    '''
    # Make sure the target directory exists before the download writes into it.
    os.makedirs(dirpath, exist_ok=True)

    # Reuse the module-level client instead of constructing a new one per call.
    search = arxiv.Search(id_list=[paper_id])
    paper = next(arxiv_client.results(search), None)
    if paper is None:
        # A bare next() would surface as an opaque StopIteration here.
        raise ValueError(f"No arXiv paper found for id {paper_id!r}")

    return paper.download_pdf(dirpath=dirpath), paper

In [4]:
def load_arxiv_document(paper_id):
    '''Loads an arxiv paper as a Document object.

    The paper is downloaded into `./papers`, converted to markdown and wrapped in a
    `Document` whose `doc_id` is the arXiv ID.

    Arguments:
        paper_id (str): The arXiv identifier of the paper to load.

    Returns:
        Document: The markdown text of the paper with `title`, `authors`,
                  `published` and `filepath` metadata.
    '''
    pdf_path, paper = download_pdf(paper_id, dirpath='./papers')

    # Open the PDF in a `with` block so the file handle is released as soon as
    # the markdown has been extracted, even if the conversion raises.
    with fitz.open(pdf_path) as paper_pdf:
        md_text = to_markdown(paper_pdf)

    return Document(
        doc_id=paper_id,
        text=md_text,
        metadata={
            'title': paper.title,
            # Only the first 10 authors are kept, joined into one string.
            'authors': ", ".join([auth.name for auth in paper.authors[:10]]),
            'published': paper.published.strftime('%Y-%m-%d'),
            'filepath': pdf_path,
        }
    )

In [5]:
# arXiv identifier of the paper to chat with; it is passed to
# load_arxiv_document() below and also becomes the Document's doc_id.
paper_id = "2312.00752"

In [6]:
# Load the documents: download the paper and convert it to a Document.
# Kept as a list because the node parser and ingestion pipeline below are
# called with a list of documents.
documents = [
    load_arxiv_document(paper_id)
]

In [7]:
# Break the document down into nodes
parser = MarkdownNodeParser()

# NOTE(review): `nodes` is not referenced by any later cell visible in this
# notebook; the ingestion pipeline applies `parser` to `documents` again itself.
nodes = parser.get_nodes_from_documents(documents)

### Create the VectorDB

In [8]:
# Creating a Chroma VectorDB instance backed by the local ./chroma_db directory
db = chromadb.PersistentClient(path="./chroma_db")
# Embedding function attached to the Chroma collection itself. It names the same
# model ("text-embedding-3-small") as `embed_model` below.
# NOTE(review): presumably the two should be kept in sync — confirm.
openai_ef = embedding_functions.OpenAIEmbeddingFunction(
                api_key=os.getenv("OPENAI_API_KEY"),
                model_name="text-embedding-3-small"
            )
# Reuses the "chatxiv" collection if it already exists, otherwise creates it
chroma_collection = db.get_or_create_collection("chatxiv", embedding_function=openai_ef)

# Create a Chroma VectorStore wrapping that collection
vector_store = ChromaVectorStore(chroma_collection=chroma_collection)

# Create a Simple Document Store as well (passed to the ingestion pipeline below)
docstore = SimpleDocumentStore()

# Load the embedding model used by the ingestion pipeline and the vector index
embed_model = OpenAIEmbedding(model="text-embedding-3-small", embed_batch_size=100)

In [9]:
# Build the Embedding Pipeline: parse the documents into nodes, embed them, and
# write them to the vector store while tracking ingested documents in the docstore.
PIPELINE_CACHE_DIR = "./pipeline_chatxiv"

pipeline = IngestionPipeline(
    transformations=[
        parser,
        embed_model,
    ],
    vector_store=vector_store,
    docstore=docstore,
)

# Load the locally persisted pipeline state (if it exists). The Chroma collection
# is persistent, so without the previously saved docstore every re-run of this
# notebook would start from an empty docstore and insert the same nodes again.
if os.path.isdir(PIPELINE_CACHE_DIR):
    pipeline.load(PIPELINE_CACHE_DIR)

pipeline.run(documents=documents, show_progress=True)

# Save the pipeline state locally so the next run can pick it up
pipeline.persist(PIPELINE_CACHE_DIR)

Parsing nodes: 100%|██████████| 1/1 [00:00<00:00, 18.62it/s]
Generating embeddings: 100%|██████████| 43/43 [00:02<00:00, 16.46it/s]


In [10]:
# Create vector index for retrieval on top of the already-populated vector store.
# The same `embed_model` used at ingestion time is passed in here as well.
index = VectorStoreIndex.from_vector_store(
    vector_store=vector_store,
    embed_model=embed_model
)

## Retrieval

In [11]:
# Retriever over the vector index, configured with similarity_top_k=3
retriever = index.as_retriever(similarity_top_k=3)

In [12]:
# Example query used to exercise the retriever in the next cell
query_str = "Explain mamba"

In [13]:
# Retrieve the nodes for the query; the trailing expression displays them
# (each result is a NodeWithScore, as the output below shows).
retrieved_nodes = retriever.retrieve(query_str)
retrieved_nodes

[NodeWithScore(node=TextNode(id_='2129918f-3efe-4f77-bd90-2b7111fe6d6b', embedding=None, metadata={'Header_1': 'Mamba: Linear-Time Sequence Modeling with Selective State Spaces', 'Header_2': 'E ## Experimental Details and Additional Results', 'Header_3': 'E.5 ### Efciency Benchmark', 'title': 'Mamba: Linear-Time Sequence Modeling with Selective State Spaces', 'authors': 'Albert Gu, Tri Dao', 'published': '2023-12-01', 'filepath': './papers/2312.00752v1.Mamba__Linear_Time_Sequence_Modeling_with_Selective_State_Spaces.pdf'}, excluded_embed_metadata_keys=[], excluded_llm_metadata_keys=[], relationships={<NodeRelationship.SOURCE: '1'>: RelatedNodeInfo(node_id='2312.00752', node_type=<ObjectType.DOCUMENT: '4'>, metadata={'title': 'Mamba: Linear-Time Sequence Modeling with Selective State Spaces', 'authors': 'Albert Gu, Tri Dao', 'published': '2023-12-01', 'filepath': './papers/2312.00752v1.Mamba__Linear_Time_Sequence_Modeling_with_Selective_State_Spaces.pdf'}, hash='2d42815c6988e92465aaa9c4

In [14]:
# Inspect the metadata of the first result: markdown header keys plus the
# document-level metadata set in load_arxiv_document().
retrieved_nodes[0].metadata

{'Header_1': 'Mamba: Linear-Time Sequence Modeling with Selective State Spaces',
 'Header_2': 'E ## Experimental Details and Additional Results',
 'Header_3': 'E.5 ### Efciency Benchmark',
 'title': 'Mamba: Linear-Time Sequence Modeling with Selective State Spaces',
 'authors': 'Albert Gu, Tri Dao',
 'published': '2023-12-01',
 'filepath': './papers/2312.00752v1.Mamba__Linear_Time_Sequence_Modeling_with_Selective_State_Spaces.pdf'}

## Response Synthesis

In [15]:
# System message (role/content dict) passed to ChatSession as `system_prompt`.
# It restricts the model to answering from the supplied context only.
system_prompt = {
    'role': 'system',
    'content': """You are a Q&A bot. You are here to answer questions based on the context given.
You are prohibited from using prior knowledge and you can only use the context given. If you need 
more information, please ask the user."""
}

In [16]:
# Template that wraps the retrieved text between two separator lines.
# `{context_str}` is filled in by build_context_prompt().
context_prompt = PromptTemplate(
"""Context information to answer the query is below.
---------------------
{context_str}
---------------------""")

def build_context_prompt(retrieved_nodes):
    '''Formats the retrieved nodes into the context prompt.

    Arguments:
        retrieved_nodes: Iterable of retrieval results; each item's `.node`
                         content is read with `metadata_mode='all'`.

    Returns:
        The `context_prompt` template with `context_str` set to the node
        contents joined by blank lines.
    '''
    chunks = []
    for scored in retrieved_nodes:
        chunks.append(scored.node.get_content(metadata_mode='all'))

    joined_context = "\n\n".join(chunks)
    return context_prompt.format(context_str=joined_context)

In [17]:
# Chat client and a session seeded with the system prompt. Both classes come
# from the local chat_utils module.
# NOTE(review): the session presumably tracks the message history in
# `session.thread` (it is appended to below) — confirm against chat_utils.
llm = ChatLLM()
session = ChatSession(llm, system_prompt=system_prompt)

In [18]:
# Add the context as a user message to the message history (thread)
session.thread.append({
    'role': 'user',
    'content': build_context_prompt(retrieved_nodes)
})

In [19]:
# Ask the question on the same session, after the context message added above,
# and print the text of the reply.
response = session.chat("How does mamba work? Explain like I am 5")
print(response.content)

Mamba is like a super smart computer friend that can help with big puzzles really fast! It can look at lots of pieces of information at once and figure out the best way to put them together. It's very good at organizing things quickly and efficiently to solve problems.


# Retrieval as a Tool

In [20]:
from llama_index.core.tools import FunctionTool
from openai.types.chat.chat_completion_tool_message_param import ChatCompletionToolMessageParam

In [21]:
def context_retrieval(search_query: str, top_k: int=3) -> Tuple[str, List[NodeWithScore]]:
    '''
    This function lets you semantically retrieve relevant context chunks from a given document based on a query.

    Arguments:
        search_query (str): The query to search for in the document. Based on the original user query, write a good
                     search query which is more logically sound to retrieve the relevant information from the document.
        top_k (int): The number of top chunks to retrieve from the document. Default is 3. You can increase this number if
                     you feel like you need more information. But ideally you should make multiple calls to retrieve different
                     topics of information.

    Returns:
        str: The top retrieved chunks, formatted as context information for answering the query.
        List[NodeWithScore]: A list of nodes with their scores. Use this to cite the information in the response.
    '''
    # NOTE(review): this function is wrapped with FunctionTool.from_defaults below,
    # which presumably exposes the docstring to the model as the tool description —
    # so the argument names above must match the real signature.

    # The arguments are produced by the model; coerce top_k to a positive integer
    # so a zero or negative value still retrieves at least one chunk.
    top_k = max(1, int(top_k))

    retriever = index.as_retriever(similarity_top_k=top_k)
    retrieved_nodes = retriever.retrieve(search_query)
    return build_context_prompt(retrieved_nodes), retrieved_nodes

In [22]:
# Rebind `llm` and `session` to fresh objects for the tool-based flow.
# This intentionally replaces the session from the naive Q&A section, so the
# manually appended context message is not carried over.
llm = ChatLLM()
session = ChatSession(llm, system_prompt=system_prompt)

In [23]:
# Wrap the retrieval function as a tool; its metadata is converted to the
# OpenAI tool format inside conversation_turn().
context_retrieval_tool = FunctionTool.from_defaults(fn=context_retrieval)

In [24]:
# Dispatch table: maps the function name in a model tool call to the Python
# callable that conversation_turn() executes for it.
available_functions = {
    "context_retrieval": context_retrieval,
}

In [25]:
def conversation_turn(user_query):
    '''Runs one conversation turn, executing any tool calls the model requests.

    The user query is sent together with the retrieval tool definition. If the
    model answers with tool calls, each one is executed, its result is appended
    to the session thread as a tool message, and the model is called again so it
    can answer using those results.

    Arguments:
        user_query (str): The user's message for this turn.

    Returns:
        The model response: the follow-up response if tools were called,
        otherwise the first response.
    '''
    # Send the messages and available functions to the model
    # We will call our chat method from our session object which is already keeping track of the history
    response = session.chat(user_query,
                            model='gpt-3.5-turbo',
                            temperature=1,
                            max_tokens=512,
                            tools=[context_retrieval_tool.metadata.to_openai_tool()], tool_choice='auto')

    tool_calls = response.tool_calls

    # The model answered directly; nothing to execute
    if not tool_calls:
        return response

    for tool_call in tool_calls:
        function_name = tool_call.function.name
        print(f'Retrieving with {tool_call}')

        # Every tool call gets a tool message, even on failure: raising mid-loop
        # would leave tool calls unanswered in the thread.
        # NOTE(review): assumes session.chat() stored the assistant message carrying
        # these tool calls in session.thread — confirm against chat_utils.
        function_to_call = available_functions.get(function_name)
        if function_to_call is None:
            function_response = f"Error: unknown tool '{function_name}'."
        else:
            try:
                function_args = json.loads(tool_call.function.arguments)
                # The second return value (the retrieved nodes) is not used here
                function_response, _ = function_to_call(**function_args)
            except (json.JSONDecodeError, TypeError) as exc:
                # Malformed JSON, or arguments that do not match the signature
                function_response = f"Error: invalid arguments for tool '{function_name}': {exc}"

        # Add the function response to the message thread
        session.thread.append(
            ChatCompletionToolMessageParam(role='tool', tool_call_id=tool_call.id, name=function_name, content=function_response)
        )

    # Get a new response from the model where it can see the function response
    post_function_response = session.chat()

    return post_function_response

In [28]:
# Run one tool-enabled turn; the trailing expression displays the reply text.
# When the model calls the tool, conversation_turn() prints a "Retrieving with ..." line.
response = conversation_turn("HOw does mamba work?")
response.content

Retrieving with ChatCompletionMessageToolCall(id='call_GgdsYoPcpBhFtZWios8xkRKI', function=Function(arguments='{"search_query":"Mamba: Linear-Time Sequence Modeling with Selective State Spaces","top_k":1}', name='context_retrieval'), type='function')


"Mamba works by utilizing a Selective State Space Model with hardware-aware state expansion. This approach involves incorporating a selection mechanism into the state space model to dynamically choose states or components based on specific criteria. Additionally, the model is designed to expand its state space efficiently by leveraging hardware-aware techniques, allowing for increased computational efficiency and performance.\n\nBy combining the Selective State Space Model with hardware-aware state expansion, Mamba aims to enable linear-time sequence modeling with improved adaptability and efficiency in processing sequential data. This approach integrates selective attention and hardware optimization to enhance the model's capabilities in capturing temporal dependencies and patterns in sequences."