# <font color="#003660">Applied Machine Learning for Text Analysis (M.184.5331)</font>


# <font color="#003660">Session 8: Retrieval-Augmented Generation</font>

# <font color="#003660">Rewriting</font>

<center><br><img width=256 src="https://raw.githubusercontent.com/olivermueller/aml4ta-2021/main/resources/dag.png"/><br></center>

<p>

<div>
    <font color="#085986"><b>By the end of this lesson, you ...</b><br><br>
        ... will know how to implement query rewriting, in particular Rewrite-Retrieve-Read. <br>
        ... will know how to build chains in the LangChain style.
    </font>
</div>
</p>

This notebook is heavily inspired by the following excellent sources:

* [HuggingFace (2024): NLP Course](https://huggingface.co/learn/nlp-course/)
* [HuggingFace (2024): Open-Source AI Cookbook](https://huggingface.co/learn/cookbook/index)
* [LangChain API Reference (2024)](https://python.langchain.com/api_reference/reference.html)
* [LangChain Docs (2024)](https://python.langchain.com/docs/introduction/)
* [LangChain AI (2024) Cookbook](https://github.com/langchain-ai/langchain/blob/master/cookbook/rewrite.ipynb?ref=blog.langchain.dev)

# Rewriting

![](https://github.com/olivermueller/amlta-2024/blob/main/Session_08/imgs/rag_extensions.png?raw=true)

(Source: [Wang et al., 2024](https://doi.org/10.18653/v1/2024.emnlp-main.981))

There are multiple ways to improve RAG architectures, as summarized by [Wang et al. (2024)](https://doi.org/10.18653/v1/2024.emnlp-main.981).

This lecture focuses on one of them: query rewriting.

In [None]:
!pip install -U pymupdf4llm datasets transformers faiss-cpu sentence-transformers accelerate langchain langchain-community langchain-huggingface

In [43]:
import os
import re
from tqdm.notebook import tqdm
import pymupdf4llm
import urllib.request

from IPython.display import display, Markdown

from transformers import AutoTokenizer, set_seed
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_huggingface.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain import hub
from langchain_huggingface import HuggingFacePipeline

from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import CrossEncoderReranker
from langchain_community.cross_encoders import HuggingFaceCrossEncoder

from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough

DEVICE = "cuda"

In [None]:
# Create the working directories (no error if they already exist)
for directory in ["documents", "imgs", "markdown_documents"]:
    os.makedirs(directory, exist_ok=True)

# Use "raw" URLs; the "tree" URLs return GitHub's HTML page instead of the file itself
base_url = "https://github.com/olivermueller/amlta-2024/raw/main/Session_08/"
for path in [
    "documents/Game_of_Thrones.pdf",
    "documents/How_I_Met_Your_Mother.pdf",
    "markdown_documents/Game_of_Thrones.md",
    "markdown_documents/How_I_Met_Your_Mother.md",
]:
    urllib.request.urlretrieve(base_url + path, path)

In [2]:
RETRIEVER_NAME = "jinaai/jina-embeddings-v2-base-en"
GENERATOR_NAME = "Qwen/Qwen2.5-1.5B-Instruct"

# Loading Documents

In [3]:
markdown_documents_path = "markdown_documents"

In [4]:
def remove_markdown_links(text):
    """
    Removes Markdown links from the given text while keeping the link text.
    
    Args:
        text (str): The input Markdown text.
        
    Returns:
        str: The text with Markdown links removed. 
    
    Yeah this was ChatGPT ;)
    """
    # Regex to match Markdown links [text](link)
    pattern = r'\[([^\]]+)\]\([^\)]+\)'
    # Replace the matched pattern with just the text inside the brackets
    cleaned_text = re.sub(pattern, r'\1', text)
    return cleaned_text
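
To see what the regular expression does, here is a small self-contained check. The sample sentence and the `example.org` URL are made up for illustration; the function body is a condensed copy of the cell above.

```python
import re

def remove_markdown_links(text):
    # Same pattern as in the notebook: keep the link text, drop the URL.
    return re.sub(r'\[([^\]]+)\]\([^\)]+\)', r'\1', text)

# Hypothetical sample text (not taken from the course documents)
sample = "See [Daenerys Targaryen](https://example.org/wiki/Daenerys) and ![](imgs/cover.png)."
cleaned = remove_markdown_links(sample)
print(cleaned)
```

Note that an image with empty alt text is left untouched, because the pattern requires at least one character inside the square brackets. URLs that contain a closing parenthesis are cut off at the first `)`, which is a known limitation of this simple pattern.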

In [5]:
markdown_documents = os.listdir(markdown_documents_path)

md_files = []

for markdown_document in markdown_documents:
    markdown_document_path = os.path.join(markdown_documents_path, markdown_document)
    with open(markdown_document_path, encoding="utf-8") as file:
        md_files.append([markdown_document, remove_markdown_links(file.read())])

# Our original chain (Just another query)

In [6]:
embedding_tokenizer = AutoTokenizer.from_pretrained(RETRIEVER_NAME, use_fast=False)
text_splitter = RecursiveCharacterTextSplitter.from_huggingface_tokenizer(
    embedding_tokenizer, 
    chunk_size=512, 
    chunk_overlap=32,
    separators=[
        "# ",
        "## ",
        "### ",
        "#### ",
        "##### ",
        "###### ",
        "\n\n",
        "\n",
        " ",
    ],
    keep_separator=True
)
all_splits = text_splitter.create_documents(
    texts=[x[1] for x in md_files], 
    metadatas=[{"source": x[0]} for x in md_files],
)
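
The roles of `chunk_size` and `chunk_overlap` can be illustrated with a deliberately simplified sliding-window chunker. This is only a sketch under strong assumptions: the real `RecursiveCharacterTextSplitter` additionally tries to break at the listed separators, and the helper name `window_chunks` and the integer "tokens" are invented for this example.

```python
def window_chunks(tokens, chunk_size, chunk_overlap):
    """Split a token list into fixed-size windows that share `chunk_overlap` tokens."""
    if not 0 <= chunk_overlap < chunk_size:
        raise ValueError("chunk_overlap must be >= 0 and smaller than chunk_size")
    step = chunk_size - chunk_overlap  # how far each new window moves forward
    chunks = []
    for start in range(0, len(tokens), step):
        chunks.append(tokens[start:start + chunk_size])
        if start + chunk_size >= len(tokens):
            break  # the last window already reaches the end of the text
    return chunks

# Ten dummy tokens, windows of four tokens, one token shared between neighbours
toy_chunks = window_chunks(list(range(10)), chunk_size=4, chunk_overlap=1)
print(toy_chunks)
```

With the settings used in the notebook (512 tokens per chunk, 32 tokens overlap), neighbouring chunks share a small amount of text, so a sentence that straddles a chunk boundary is less likely to be lost for retrieval.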

In [None]:
retriever_model = HuggingFaceEmbeddings(
    model_name=RETRIEVER_NAME, 
    model_kwargs={'device': DEVICE, "trust_remote_code": True}, 
    encode_kwargs={'normalize_embeddings': True}
)
db = FAISS.from_documents(
    all_splits, 
    embedding=retriever_model
)
retriever = db.as_retriever()

In [None]:
llm = HuggingFacePipeline.from_model_id(
    model_id=GENERATOR_NAME, 
    task="text-generation", 
    pipeline_kwargs={"return_full_text": False}
)

In [None]:
prompt = hub.pull("rlm/rag-prompt")
print(prompt.messages[0].prompt.template)

In [12]:
chain = (
    {"context": retriever, "question": RunnablePassthrough()}
    | prompt
    | llm
    | StrOutputParser()
)
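
The `|` syntax in the chain above relies on Python operator overloading. The following toy class is a minimal sketch of that idea, not LangChain's actual implementation: `Step` and `as_step` are hypothetical names, and the fake retriever and prompt are simple stand-ins. As in LangChain, a dict fans the same input out to each of its values, and plain functions are wrapped automatically.

```python
class Step:
    """Toy stand-in for a runnable: wraps a function and supports `|`."""

    def __init__(self, func):
        self.func = func

    def invoke(self, value):
        return self.func(value)

    def __or__(self, other):
        # step | other: run this step first, then feed its output to `other`
        other = as_step(other)
        return Step(lambda value: other.invoke(self.invoke(value)))

    def __ror__(self, other):
        # dict | step: Python falls back to this method because
        # dict does not know how to combine itself with a Step
        return as_step(other) | self


def as_step(obj):
    """Coerce dicts and plain callables into Step objects."""
    if isinstance(obj, Step):
        return obj
    if isinstance(obj, dict):
        branches = {key: as_step(val) for key, val in obj.items()}
        # A dict passes the same input to every branch and collects the results.
        return Step(lambda value: {k: s.invoke(value) for k, s in branches.items()})
    return Step(obj)


# Stand-ins for retriever, passthrough, and prompt (all hypothetical)
passthrough = Step(lambda x: x)
fake_retriever = Step(lambda q: [f"doc about {q}"])
fake_prompt = Step(lambda d: f"Context: {d['context']}\nQuestion: {d['question']}")

toy_chain = {"context": fake_retriever, "question": passthrough} | fake_prompt | str.upper
toy_result = toy_chain.invoke("dragons")
print(toy_result)
```

The toy chain has the same shape as the RAG chain: the input string is sent both to the retriever and, unchanged, to the `question` slot, and the resulting dict is handed to the prompt.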

In [None]:
chain.invoke("I loooove Naomi Watts! Who plays Daenerys Targaryen?")

As illustrated above, it is really easy to distract an LLM. But how can we prevent this?

# Implementing a rewriter with Rewrite-Retrieve-Read

To improve the answer, we can use query rewriting. This means that we first ask the LLM to rewrite the query and then continue with the rewritten query instead of the original one. Rewriting can remove distracting or biasing parts of the query and can thereby improve answer accuracy.

We will now implement Rewrite-Retrieve-Read (RRR), a strategy introduced by [Ma et al. (2023)](https://aclanthology.org/2023.emnlp-main.322/). Besides RRR, there are several other query transformation methods, which are explained accessibly in this [LangChain blog post](https://blog.langchain.dev/query-transformations/).
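
The three steps can be sketched without any model. Everything in the following example is a stand-in chosen for illustration: the mini corpus `DOCS`, the rule-based `toy_rewrite` (the notebook uses the LLM for this step), the word-overlap `toy_retrieve` (the notebook uses the FAISS retriever), and `toy_read`, which only assembles the prompt instead of calling a generator.

```python
import re

# Hypothetical mini corpus; the notebook retrieves from the FAISS index instead.
DOCS = [
    "Daenerys Targaryen is portrayed by Emilia Clarke in Game of Thrones.",
    "Naomi Watts is an actress.",
    "Ted Mosby is the main character of How I Met Your Mother.",
]

def tokens(text):
    """Lowercased set of words, used for a crude overlap score."""
    return set(re.findall(r"[a-z]+", text.lower()))

def toy_rewrite(query):
    """Rule-based stand-in for the LLM rewriter: keep only the actual questions."""
    sentences = [s.strip() for s in re.findall(r"[^.!?]+[.!?]", query)]
    questions = [s for s in sentences if s.endswith("?")]
    return " ".join(questions) if questions else query

def toy_retrieve(query, k=2):
    """Return up to k documents that share at least one word with the query."""
    q = tokens(query)
    scored = [(len(q & tokens(doc)), doc) for doc in DOCS]
    ranked = sorted(scored, key=lambda pair: pair[0], reverse=True)
    return [doc for score, doc in ranked[:k] if score > 0]

def toy_read(context, question):
    """Stand-in for the reader: only builds the prompt a generator would receive."""
    return f"Context: {' '.join(context)}\nQuestion: {question}"

def rewrite_retrieve_read(query):
    rewritten = toy_rewrite(query)     # 1. rewrite the query
    context = toy_retrieve(rewritten)  # 2. retrieve with the rewritten query
    return toy_read(context, query)    # 3. read: original question plus context

query = "I loooove Naomi Watts! Who plays Daenerys Targaryen?"
print(toy_retrieve(query))              # the distracting document is retrieved as well
print(toy_retrieve(toy_rewrite(query)))  # only the relevant document remains
print(rewrite_retrieve_read(query))
```

In this toy setting, the original query pulls in the document about Naomi Watts, while the rewritten query retrieves only the document about Daenerys Targaryen. A real rewriter and a dense retriever will behave less predictably, so this should be read as an illustration of the data flow rather than of the expected quality.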

In [None]:
set_seed(0)

# Original prompt that wasn't working with the small LLM was this:
# """Provide a better search query for web search engine to answer the given question, end the queries with ’**’. Question: {x} Answer:"""
template = """REWRITE the question! Start your rewritten question with REWRITTEN:. Question: {x} REWRITTEN:"""
rewrite_prompt = ChatPromptTemplate.from_template(template)

def _parse(text):
    # Keep only the text after the last "REWRITTEN:" marker
    return text.strip('"').split("REWRITTEN:")[-1]

rewriter = rewrite_prompt | llm | StrOutputParser() | _parse
rewriter.invoke({"x": "I loooove Naomi Watts! Who plays Daenerys Targaryen?"})
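
Small models do not always stop after the rewritten question, so the raw output may contain extra lines, quotes, or surrounding whitespace. The helper below is a hypothetical, slightly more defensive variant of `_parse`; the name `parse_rewritten` and the sample outputs are assumptions made for illustration, not actual model outputs.

```python
def parse_rewritten(text, marker="REWRITTEN:"):
    """Hypothetical variant of `_parse` that also trims extra lines and quotes."""
    # Keep only what follows the last marker (the whole text if there is none).
    tail = text.split(marker)[-1]
    # The model may keep generating; keep the first non-empty line only.
    lines = [line.strip() for line in tail.splitlines() if line.strip()]
    first = lines[0] if lines else ""
    return first.strip('"').strip()

# Made-up raw outputs that a small LLM might plausibly produce
print(parse_rewritten(' Who plays Daenerys Targaryen?\nQuestion: Who is Ted?'))
print(parse_rewritten('REWRITTEN: "Who plays Daenerys Targaryen?"'))
```

If the parser returns an empty string, it is usually safer to fall back to the original query than to send an empty query to the retriever.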

# The updated chain

In [None]:
set_seed(0)
rewrite_retrieve_read_chain = (
    {
        "context": {"x": RunnablePassthrough()} | rewriter | retriever,
        "question": RunnablePassthrough(),
    }
    | prompt
    | llm
    | StrOutputParser()
)
rewrite_retrieve_read_chain.invoke("I loooove Naomi Watts! Who plays Daenerys Targaryen?")

# Want to learn more about RAG with LangChain?

* [LangChain RAG from scratch](https://github.com/langchain-ai/rag-from-scratch)
* [LangChain Docs](https://python.langchain.com/docs/introduction/)