In [None]:
# Automatically reload edited modules (e.g. llamabot) before each cell executes,
# so changes to the library under development are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:
from pyprojroot import here

# The paper used for the retrieval experiments below, resolved against the
# project root (via pyprojroot's `here()`) rather than the current directory.
pdf_path = here() / "data/JMLR-23-0380-1.pdf"
assert pdf_path.exists(), f"Expected PDF not found: {pdf_path}"

In [None]:
from langchain.chat_models import ChatOpenAI
from llamabot.config import default_language_model
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from llama_index import (
    GPTVectorStoreIndex,
    LLMPredictor,
    ServiceContext,
    load_index_from_storage,
)

# Configuration (copied from QueryBot __init__) so the experiments below use
# the same chat-model settings as the bot.
model_name = default_language_model()
temperature = 0.0  # minimize sampling randomness so retrieval setups are comparable
stream = True  # print tokens to stdout as they arrive

chat = ChatOpenAI(
    model_name=model_name,
    temperature=temperature,
    # Follow the `stream` flag instead of hard-coding True: the stdout handler
    # below is only attached when `stream` is set, so a non-streaming config
    # should not request streamed tokens that nothing consumes.
    streaming=stream,
    verbose=True,
    callback_manager=BaseCallbackManager(
        handlers=[StreamingStdOutCallbackHandler()] if stream else []
    ),
)
llm_predictor = LLMPredictor(llm=chat)
service_context = ServiceContext.from_defaults(llm_predictor=llm_predictor)

In [None]:
from llamabot.bot.querybot import make_or_load_index

# Build (or load from cache) two indexes over the same document that differ
# only in chunk size, so retrieval quality can be compared between them.
doc_paths = [pdf_path]
large_chunk_size = 2000
small_chunk_size = 500
chunk_overlap = 0
use_cache = True

# Evaluated in order: the large index first, then the small one.
large_index, small_index = (
    make_or_load_index(doc_paths, size, chunk_overlap, use_cache)
    for size in (large_chunk_size, small_chunk_size)
)

In [None]:
# Retrieve context for one query from both indexes: a few large chunks
# versus many small chunks.
large_similarity_top_k = 5
small_similarity_top_k = 20
large_retriever = large_index.as_retriever(similarity_top_k=large_similarity_top_k)
small_retriever = small_index.as_retriever(similarity_top_k=small_similarity_top_k)

query = "What is Post-training of Feature extractors' algorithm written out explicitly? Translate the symbols into plain English, but retain their original symbols when referring to them."


def retrieve_with_texts(retriever, question):
    """Run `retriever` on `question` and return (retrieved nodes, their node texts)."""
    hits = retriever.retrieve(question)
    return hits, [hit.node.text for hit in hits]


large_source_nodes, large_source_texts = retrieve_with_texts(large_retriever, query)
small_source_nodes, small_source_texts = retrieve_with_texts(small_retriever, query)

In [None]:
# Now build the full query that gets stuffed into `chat`, using the
# large-chunk retrieval results as context.
from langchain.schema import AIMessage, HumanMessage, SystemMessage

faux_chat_history = [
    SystemMessage(content="You are a Q&A bot about papers!"),
    SystemMessage(content="Here is the context you will be working with:"),
    # One system message per retrieved large chunk, in retrieval order.
    *(SystemMessage(content=text) for text in large_source_texts),
    HumanMessage(content=query),
]
# The small-chunk cell below reassigns `faux_chat_history` and `response`;
# keep this run's answer under its own name so both can be compared afterwards.
large_response = response = chat(faux_chat_history)

In [None]:
# Now build the full query that gets stuffed into `chat`, this time using the
# small-chunk retrieval results as context.
from langchain.schema import AIMessage, HumanMessage, SystemMessage

faux_chat_history = [
    SystemMessage(content="You are a Q&A bot about papers!"),
    SystemMessage(content="Here is the context you will be working with:"),
    # One system message per retrieved small chunk, in retrieval order.
    *(SystemMessage(content=text) for text in small_source_texts),
    HumanMessage(content=query),
]
# Keep this run's answer under its own name; `response` is still assigned
# for code that expects the most recent answer there.
small_response = response = chat(faux_chat_history)

In [None]:
# Retrieval scores of the small-chunk results, in retrieval order.
[hit.score for hit in small_source_nodes]

In [None]:
# Retrieval scores of the large-chunk results, in retrieval order.
[hit.score for hit in large_source_nodes]