## LLM之RAG实战（二十八）| 探索RAG query重写

### 一、假设文档嵌入（HyDE）

In [None]:
import os

# Placeholder OpenAI API key for this section; replace it with your own.
# NOTE(review): this unconditionally overwrites any OPENAI_API_KEY already
# exported in the shell environment.
os.environ.update(OPENAI_API_KEY='your key')

from llama_index.core import VectorStoreIndex, SimpleDirectoryReader
from llama_index.core.indices.query.query_transform import HyDEQueryTransform
from llama_index.core.query_engine import TransformQueryEngine

In [None]:
# Directory whose files form the corpus (replace with your own path).
dir_path = 'your dir path'

# Read every file under ``dir_path`` and build a vector index over them.
documents = SimpleDirectoryReader(input_dir=dir_path).load_data()
index = VectorStoreIndex.from_documents(documents)

In [None]:
query_str = "what did paul graham do after going to RISD"

# Baseline: answer the question with the plain, untransformed query engine.
query_engine = index.as_query_engine()
response = query_engine.query(query_str)

# Separator, label and answer, each on its own line.
print('--' * 50, 'Base query:', response, sep='\n')

In [None]:
# Wrap the baseline engine so every query is first passed through HyDE.
# include_original=True keeps the original query's embedding strings next to
# the generated hypothetical document.
hyde = HyDEQueryTransform(include_original=True)
hyde_query_engine = TransformQueryEngine(query_engine, transform=hyde)
response = hyde_query_engine.query(query_str)

# Separator, label and answer, each on its own line.
print('--' * 50, 'After HyDEQueryTransform:', response, sep='\n')

In [None]:
class HyDEQueryTransform(BaseQueryTransform):
    """Hypothetical Document Embeddings (HyDE) query transform.

    It uses an LLM to generate a hypothetical answer to a given query and
    uses the resulting document as an embedding string.

    As described in `[Precise Zero-Shot Dense Retrieval without Relevance Labels]
    (https://arxiv.org/abs/2212.10496)`

    This is a copy of the library transform with debug prints added to
    ``_run`` so the strings actually used for retrieval can be inspected.

    NOTE(review): ``BaseQueryTransform``, ``QueryBundle``, ``Settings``,
    ``PromptDictType``, ``LLMPredictorType``, ``BasePromptTemplate``,
    ``Optional`` and ``Dict`` are not imported in the visible cells, and
    ``DEFAULT_HYDE_PROMPT`` is only defined in a later cell (it is looked up
    when ``__init__`` runs, not when the class is defined). The HyDE demo cell
    earlier in the notebook instantiated the *imported* library class, so
    re-run that cell after this one to see the debug output.
    """

    def __init__(
        self,
        llm: Optional[LLMPredictorType] = None,
        hyde_prompt: Optional[BasePromptTemplate] = None,
        include_original: bool = True,
    ) -> None:
        """Initialize HyDEQueryTransform.

        Args:
            llm (Optional[LLMPredictorType]): LLM used to generate the
                hypothetical document. Falls back to ``Settings.llm`` when
                not given.
            hyde_prompt (Optional[BasePromptTemplate]): Custom prompt for
                HyDE. Falls back to ``DEFAULT_HYDE_PROMPT`` when not given.
            include_original (bool): Whether to include the original query's
                embedding strings alongside the hypothetical document.
        """
        super().__init__()

        # ``or`` means any falsy argument (not only None) selects the default.
        self._llm = llm or Settings.llm
        self._hyde_prompt = hyde_prompt or DEFAULT_HYDE_PROMPT
        self._include_original = include_original

    def _get_prompts(self) -> PromptDictType:
        """Get prompts."""
        return {"hyde_prompt": self._hyde_prompt}

    def _update_prompts(self, prompts: PromptDictType) -> None:
        """Update prompts.

        Only the ``"hyde_prompt"`` key is recognized; other keys are ignored.
        """
        if "hyde_prompt" in prompts:
            self._hyde_prompt = prompts["hyde_prompt"]

    def _run(self, query_bundle: QueryBundle, metadata: Dict) -> QueryBundle:
        """Run query transform.

        Asks the LLM for one hypothetical document answering the query and
        returns a new ``QueryBundle`` with the same ``query_str``. Its custom
        embedding strings are that document, followed by the original
        bundle's ``embedding_strs`` when ``include_original`` is set.
        ``metadata`` is not used.
        """
        # TODO: support generating multiple hypothetical docs
        query_str = query_bundle.query_str
        # The query fills the prompt's ``context_str`` placeholder.
        hypothetical_doc = self._llm.predict(self._hyde_prompt, context_str=query_str)
        embedding_strs = [hypothetical_doc]
        if self._include_original:
            embedding_strs.extend(query_bundle.embedding_strs)

        # The following three lines contain the added debug statements.
        # They print every embedding string (including the original query's
        # when ``include_original`` is set), not only the generated document.
        print('-' * 100)
        print("Hypothetical doc:")
        print(embedding_strs)

        return QueryBundle(
            query_str=query_str,
            custom_embedding_strs=embedding_strs,
        )

In [None]:
# Prompt that asks the LLM for a hypothetical answer passage; the HyDE
# transform fills {context_str} with the user's query.
HYDE_TMPL = (
    "please write a passage to answer the question\n"
    "try to include as many key details as possible\n"
    "\n"
    "\n"
    "{context_str}\n"
    'passage:\n'
)

# llama_index's PromptTemplate infers its template variables ({context_str})
# from the template text. ``input_variables`` is a LangChain argument, not one
# of its parameters; unknown keyword arguments are not validated and end up
# stored with the template as extra format values, so it is not passed here.
DEFAULT_HYDE_PROMPT = PromptTemplate(
    template=HYDE_TMPL,
    prompt_type=PromptType.SUMMARY,
)

### 二、重写-检索-读取

In [None]:
import os

# Placeholder OpenAI API key for this section; replace it with your own.
# NOTE(review): this unconditionally overwrites any OPENAI_API_KEY already
# exported in the shell environment.
os.environ.update({"OPENAI_API_KEY": "your_openai_api_key"})

from langchain_community.utilities import DuckDuckGoSearchAPIWrapper
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_openai import ChatOpenAI

In [None]:
# Answer prompt for the retrieve-then-read chains: the model must answer only
# from the search results injected as {context}.
base_template = """
    Answer the users question based only on the following context:
    <context>
    {context}
    </context>
    Question: {question}
"""
# ChatPromptTemplate has no ``template=`` field (its constructor takes a list
# of messages). ``from_template`` wraps the text in a single human message and
# infers the {context}/{question} variables itself.
base_prompt = ChatPromptTemplate.from_template(base_template)

# Chat model with temperature 0, shared by the chains below.
model = ChatOpenAI(temperature=0)

# Web-search wrapper queried by ``retriever``.
search = DuckDuckGoSearchAPIWrapper()

In [None]:
def retriever(query):
    """Search the web for ``query`` and return the wrapper's text output."""
    snippets = search.run(query)
    return snippets

def june_print(msg, res):
    """Print a 100-character dashed rule, then ``msg``, then ``res``, one per line."""
    for item in ('--' * 50, msg, res):
        print(item)

# Baseline retrieve-then-read chain: the raw question goes straight to web
# search, and the search output fills {context} in the answer prompt.
chain = (
    {"context": retriever, "question": RunnablePassthrough()}
    | base_prompt
    | model
    | StrOutputParser()
)

query = "What did the president say about Ketanji Brown Jackson"

answer = chain.invoke(query)
june_print("the result of query", answer)

# NOTE(review): this runs a second, independent search, so the contexts shown
# may differ from the ones the chain used above.
searched_contexts = retriever(query)
june_print("the result of the searched contexts", searched_contexts)

### 现在就开始构建重写器来重写搜索查询。

In [None]:
# Prompt for the query rewriter. The trailing backslashes are line
# continuations inside the string, so the instruction reaches the model as a
# single line (each continuation line's leading spaces are kept). {x} is the
# original question; the model is told to end its query with '**', which
# ``_parse`` strips off.
rewrite_template = """
    Provide a better search query for \
    web search engine to answer the given question, end \
    the queries with '**'. Question: \
    {x} Answer:
"""
rewrite_prompt = ChatPromptTemplate.from_template(rewrite_template)

def _parse(text):
    return text.strip("**")

# Query rewriter: prompt -> chat model -> plain string -> strip the
# trailing '**' marker.
rewriter = (
    rewrite_prompt
    | ChatOpenAI(temperature=0)
    | StrOutputParser()
    | _parse
)
rewritten_query = rewriter.invoke({'x': query})
june_print('rewritten query:', rewritten_query)

In [None]:
# Build rewrite_retrieve_read_chain, which searches with the rewritten query.

# Sub-chain: wrap the incoming question as {"x": question} for the rewrite
# prompt, rewrite it, then search the web with the rewritten query.
rewrite_then_search = {"x": RunnablePassthrough()} | rewriter | retriever

rewrite_retrieve_read_chain = (
    {
        "context": rewrite_then_search,
        # The reader still sees the user's original question.
        "question": RunnablePassthrough(),
    }
    | base_prompt
    | model
    | StrOutputParser()
)

rewrite_retrieve_read_answer = rewrite_retrieve_read_chain.invoke(query)
june_print(
    'the result of the rewrite_retrieve_read_chain:',
    rewrite_retrieve_read_answer,
)


### 三、Step-Back提示

In [None]:
import os

# Placeholder OpenAI API key for this section; replace it with your own.
# NOTE(review): this unconditionally overwrites any OPENAI_API_KEY already
# exported in the shell environment.
os.environ.update({'OPENAI_API_KEY': 'YOUR_API_KEY'})

from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate, FewShotChatMessagePromptTemplate
from langchain_core.runnables import RunnableLambda
from langchain_openai import ChatOpenAI
from langchain_community.utilities import DuckDuckGoSearchAPIWrapper

In [None]:
###     构建一个链并执行原始查询：
def june_print(msg, res):
    """Print ``msg`` and ``res`` under a 100-character dashed separator."""
    print('--' * 50, msg, res, sep='\n')

# The original (specific) question used throughout this section.
question = 'was chatgpt around while trump was president?'

# Answer prompt for the baseline chain: {normal_context} receives the search
# results for the question itself.
base_prompt_template = """
    you are an expert of world knowledge. I am going to ask you a question.
    {normal_context}
    Original Question: {question}
    Answer:
    """

# ChatPromptTemplate has no ``template=`` field (its constructor takes a list
# of messages). ``from_template`` wraps the text in a single human message and
# infers the {normal_context}/{question} variables itself.
base_prompt = ChatPromptTemplate.from_template(base_prompt_template)

# Web search wrapper configured with max_results=4.
search = DuckDuckGoSearchAPIWrapper(max_results=4)


def retriever(query):
    """Search the web for ``query`` and return the wrapper's text output."""
    snippets = search.run(query)
    return snippets

def _get_question(inputs):
    """Return the raw question string from the chain's input mapping."""
    return inputs['question']


# Baseline chain: search the web with the original question and answer from
# those results.
base_chain = (
    {
        "normal_context": RunnableLambda(_get_question) | retriever,
        "question": _get_question,
    }
    | base_prompt
    | ChatOpenAI(temperature=0)
    | StrOutputParser()
)

original_contexts = retriever(question)
june_print('the searched contexts of the original question:', original_contexts)
base_answer = base_chain.invoke({'question': question})
june_print('the result of base_chain:', base_answer)


In [None]:
# Few-shot examples for step-back prompting: each pairs a question with the
# more generic "step-back" question the model should produce for it. The
# second example maps a question to itself.
examples = [
    {
        "input": "could the members of the police perform lawful arrests?",
        "output": "what can the members of the police do?"
    },
    {
        "input": "what is the purpose of the police?",
        "output": "what is the purpose of the police?"
    }
]
# Each example is rendered as a human question / AI answer message pair; the
# task instructions live once in the system message of step_back_prompt.
example_prompt = ChatPromptTemplate.from_messages([
    ("human", "{input}"),
    ("ai", "{output}")
])

# FewShotChatMessagePromptTemplate takes the examples and the per-example
# prompt; unlike FewShotPromptTemplate it has no prefix/suffix fields.
few_shot_prompt = FewShotChatMessagePromptTemplate(
    examples=examples,
    example_prompt=example_prompt,
)

# The system message describes the step-back task; the few-shot pairs then
# show the expected rewrites before the user's actual question.
step_back_prompt = ChatPromptTemplate.from_messages([
    (
        "system",
        "You are an expert at world knowledge. Your task is to step back and "
        "paraphrase a question to a more generic step-back question, which is "
        "easier to answer. Here are a few examples:",
    ),
    few_shot_prompt,
    ("user", "{question}"),
])

# Chain that turns {"question": ...} into its step-back question string.
step_back_question_chain = step_back_prompt | ChatOpenAI(temperature=0) | StrOutputParser()

# Generate the step-back question once and reuse it. Invoking the chain a
# second time costs another LLM call, and the second output may not match
# the question that was printed.
step_back_question = step_back_question_chain.invoke({"question": question})
june_print('the step back question:', step_back_question)
june_print(
    'the searched contexts of the step back question:',
    retriever(step_back_question),
)

In [None]:
# Final answer prompt: combines the search results for the original question
# ({normal_context}) and for its step-back question ({step_back_context}).
response_prompt_template = """
You are a helpful assistant. Given the following extracted parts of a long document and a question, create a final answer. 
If you don't know the answer, just say that you don't know. Don't try to make up an answer.
If the answer is not contained within the text below, say \"I don't know\"

{normal_context}
{step_back_context}

Original question: {question}
Answer:
"""
# ``from_template`` infers {normal_context}, {step_back_context} and
# {question} from the text. Extra keyword arguments are forwarded to
# ``PromptTemplate.from_template``, which already sets ``input_variables``
# itself, so passing it again is redundant and conflicts with that value.
response_prompt = ChatPromptTemplate.from_template(response_prompt_template)

# Inputs for the answer prompt: search results for the original question, for
# its step-back question, and the original question itself.
step_back_inputs = {
    "normal_context": RunnableLambda(lambda x: x["question"]) | retriever,
    "step_back_context": step_back_question_chain | retriever,
    "question": lambda x: x["question"],
}
step_back_chain = (
    step_back_inputs
    | response_prompt
    | ChatOpenAI(temperature=0)
    | StrOutputParser()
)

step_back_answer = step_back_chain.invoke({"question": question})
june_print("Step back chain created. Now let's try it out!", step_back_answer)