### Rewrite-Retrieve-Read

A strategy that prompts the LLM to rewrite the query before performing retrieval, since the user's query might be poorly worded.

In [1]:
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_postgres.vectorstores import PGVector
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import chain
from langchain_core.runnables import Runnable

In [2]:
# Connection URL for the Postgres database that backs the vector store
# (psycopg driver, localhost, port 6024).
# NOTE(review): credentials are hard-coded here; acceptable for a local demo,
# but read them from the environment for anything shared.
connection = 'postgresql+psycopg://langchain:langchain@localhost:6024/langchain'
# Name of the PGVector collection to query.
# Presumably populated with the Harry Potter text in an earlier step - confirm.
collection_name = "Harry_Potter_Complete"
# Embedding model used to embed queries.
# NOTE(review): should match the model used when the collection was indexed - confirm.
embedding_model = OpenAIEmbeddings(model="text-embedding-3-small")

# Vector store handle bound to the collection above.
db = PGVector(
    embeddings=embedding_model,
    connection=connection,
    collection_name=collection_name
)

# No search arguments are passed, so the retriever runs with the library defaults.
retriever = db.as_retriever()

In [3]:
# Answer-generation prompt; it is filled with two variables, `context` and `question`.
prompt = ChatPromptTemplate.from_template("""Answer the question based only on the provided context: {context}
question: {question}
""")
# Chat model used by every chain in this section; temperature=0 requests the
# least random sampling.
llm = ChatOpenAI(model='gpt-3.5-turbo', temperature=0)

In [4]:
@chain
def qa(question):
    """Answer ``question`` using documents retrieved with the question as-is.

    The raw question is sent to the retriever, the page contents of the
    returned documents are joined with blank lines into one context string,
    and the filled answer prompt is passed to the chat model.

    Returns the ``content`` attribute of the model reply when it has one,
    otherwise the reply object unchanged.
    """
    retrieved = retriever.invoke(question)
    passages = [doc.page_content for doc in retrieved]
    filled_prompt = prompt.invoke(
        {"context": "\n\n".join(passages), "question": question}
    )
    reply = llm.invoke(filled_prompt)
    if hasattr(reply, "content"):
        return reply.content
    return reply

In [5]:
# Baseline: ask the plain `qa` chain a question preceded by unrelated sentences.
qa.invoke("""Today I woke up and brushed my teeth, then I sat down to read the news. But then I forgot the food on the cooker. What are the names of the houses in Hogwarts?""")

'The names of the houses in Hogwarts are Gryffindor, Hufflepuff, Ravenclaw, and Slytherin.'

#### I was hoping it would not answer the question, but it did anyway 😅

In [6]:
# Prompt that asks the model for a cleaner search query. The model is told to
# end the query with '**' so that parse_rewriter_output can cut off anything
# that follows. The quotes around the terminator are plain ASCII: the previous
# literal held mis-decoded curly quotes, which put garbage bytes into the
# prompt sent to the model.
rewrite_prompt = ChatPromptTemplate.from_template("""Provide a better search query for web search engine to answer the given question, end the queries with '**'. Question: {question} Answer:""")

# Rewriting step: fill the prompt, then pipe it into the chat model.
rewriter_runnable = rewrite_prompt | llm

def parse_rewriter_output(message):
    """Extract the rewritten search query from the rewriter's reply.

    Args:
        message: Either an object exposing the reply text as ``.content``
            or any other value, which is converted with ``str()``.

    Returns:
        The text that precedes the first ``**`` terminator (the whole text
        when no terminator is present), with surrounding whitespace and
        double quotes removed. May be an empty string.
    """
    text = message.content if hasattr(message, "content") else str(message)
    # Cut at the terminator BEFORE trimming: a closing quote that sits right
    # in front of "**" is then at the end of the string and gets stripped.
    query, _, _ = text.partition("**")
    return query.strip().strip('"').strip()

@chain
def qa_rrr(question: str):
    """Answer ``question`` with the rewrite-retrieve-read strategy.

    The question is first rewritten into a search query by the model, the
    rewritten query drives document retrieval, and the answer is then
    generated from the retrieved context.

    Args:
        question: The raw user question.

    Returns:
        The ``content`` attribute of the model reply when it has one,
        otherwise ``str()`` of the reply.
    """
    # Ask the model for a better search query (the runnable expects a mapping).
    rewritten_msg = rewriter_runnable.invoke({"question": question})
    # Fall back to the raw question when the rewriter yields nothing usable
    # (for example, the reply starts with the "**" terminator), so retrieval
    # never runs on an empty query.
    query = parse_rewriter_output(rewritten_msg) or question

    # Retrieve with the rewritten query and merge the page contents.
    docs = retriever.invoke(query)
    context = "\n\n".join(d.page_content for d in docs)

    # The answer prompt receives the ORIGINAL question, not the rewritten
    # query: rewriting only steers retrieval.
    formatted_prompt = prompt.invoke({"context": context, "question": question})

    answer_msg = llm.invoke(formatted_prompt)
    return answer_msg.content if hasattr(answer_msg, "content") else str(answer_msg)

In [7]:
# Same padded question, this time sent through the rewrite-retrieve-read chain.
qa_rrr.invoke("""Today I woke up and brushed my teeth, then I sat down to read the news. But then I forgot the food on the cooker. 
    What are the names of the houses in Hogwarts?""")

'The names of the houses in Hogwarts are Gryffindor, Hufflepuff, Ravenclaw, and Slytherin.'

### Multi-Query Retriever