# <font color="#003660">Applied Machine Learning for Text Analysis (M.184.5331)</font>


# <font color="#003660">Session 8: Retrieval-Augmented Generation</font>

# <font color="#003660">Reranking</font>

<center><br><img width=256 src="https://raw.githubusercontent.com/olivermueller/aml4ta-2021/main/resources/dag.png"/><br></center>

<p>

<div>
    <font color="#085986"><b>By the end of this lesson, you ...</b><br><br>
        ... will know how to implement reranking methods using LangChain. <br>
        ... will know how to implement LangChain chains on your own.
    </font>
</div>
</p>

The content below is heavily inspired by these excellent sources:

* [HuggingFace (2024): NLP Course](https://huggingface.co/learn/nlp-course/)
* [HuggingFace (2024): Open-Source AI Cookbook](https://huggingface.co/learn/cookbook/index)
* [LangChain API Reference (2024)](https://python.langchain.com/api_reference/reference.html)
* [LangChain Docs (2024)](https://python.langchain.com/docs/introduction/)
* [LangChain AI (2024) Cookbook](https://github.com/langchain-ai/langchain/blob/master/cookbook/rewrite.ipynb?ref=blog.langchain.dev)

# Reranking

![](https://github.com/olivermueller/amlta-2024/blob/main/Session_08/imgs/rag_extensions.png?raw=true)

(Source: [Wang et al., 2024](https://doi.org/10.18653/v1/2024.emnlp-main.981))

[Wang et al. (2024)](https://doi.org/10.18653/v1/2024.emnlp-main.981) summarize multiple ways to improve RAG architectures.

This lecture focuses on reranking: re-scoring the retrieved chunks so that the most relevant ones are passed to the generator.

In [None]:
!pip install -U pymupdf4llm datasets transformers faiss-cpu sentence-transformers accelerate langchain langchain-community langchain-huggingface

In [1]:
import os
import re
from tqdm.notebook import tqdm
import pymupdf4llm
import urllib.request  # the submodule must be imported explicitly for urlretrieve

from IPython.display import display, Markdown

from transformers import AutoTokenizer
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_huggingface.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain import hub
from langchain_huggingface import HuggingFacePipeline

from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import CrossEncoderReranker
from langchain_community.cross_encoders import HuggingFaceCrossEncoder
from langchain.retrievers.document_compressors import FlashrankRerank

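import torch

# Fall back to the CPU when no GPU is available.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"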

In [None]:
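# exist_ok=True lets this cell be re-run without raising FileExistsError
os.makedirs("documents", exist_ok=True)
os.makedirs("imgs", exist_ok=True)
os.makedirs("markdown_documents", exist_ok=True)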
urllib.request.urlretrieve("https://raw.githubusercontent.com/olivermueller/amlta-2024/refs/heads/main/Session_08/documents/Game_of_Thrones.pdf", "documents/Game_of_Thrones.pdf")
urllib.request.urlretrieve("https://raw.githubusercontent.com/olivermueller/amlta-2024/refs/heads/main/Session_08/documents/How_I_Met_Your_Mother.pdf", "documents/How_I_Met_Your_Mother.pdf")
urllib.request.urlretrieve("https://raw.githubusercontent.com/olivermueller/amlta-2024/refs/heads/main/Session_08/markdown_documents/Game_of_Thrones.md", "markdown_documents/Game_of_Thrones.md")
urllib.request.urlretrieve("https://raw.githubusercontent.com/olivermueller/amlta-2024/refs/heads/main/Session_08/markdown_documents/How_I_Met_Your_Mother.md", "markdown_documents/How_I_Met_Your_Mother.md")

In [2]:
RETRIEVER_NAME = "jinaai/jina-embeddings-v2-base-en"
GENERATOR_NAME = "Qwen/Qwen2.5-1.5B-Instruct"

# Loading Documents

In [3]:
markdown_documents_path = "markdown_documents"

In [4]:
def remove_markdown_links(text):
    """
    Removes Markdown links from the given text while keeping the link text.
    
    Args:
        text (str): The input Markdown text.
        
    Returns:
        str: The text with Markdown links removed. 
    
    Yeah this was ChatGPT ;)
    """
    # Regex to match Markdown links [text](link)
    pattern = r'\[([^\]]+)\]\([^\)]+\)'
    # Replace the matched pattern with just the text inside the brackets
    cleaned_text = re.sub(pattern, r'\1', text)
    return cleaned_text
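
A quick check of the pattern used in `remove_markdown_links`, repeated here so the snippet runs on its own. The sample strings and the `example.com` URL are made up for illustration. The second case shows a limitation worth knowing: for image syntax the leading `!` is not part of the pattern, so it stays in the text.

```python
import re

# Same regex as in remove_markdown_links: [text](link) -> text
LINK_PATTERN = r'\[([^\]]+)\]\([^\)]+\)'

# Hypothetical sample sentence containing one Markdown link
sample = "Played by [Emilia Clarke](https://example.com/emilia) in the series."
cleaned = re.sub(LINK_PATTERN, r'\1', sample)
print(cleaned)

# Image links are only partly handled: the "!" in front survives
image = "![poster](imgs/poster.png)"
image_cleaned = re.sub(LINK_PATTERN, r'\1', image)
print(image_cleaned)
```

If the leftover `!` matters for your documents, the pattern could be extended with an optional `!?` in front; that is a possible tweak, not something the notebook does.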

In [5]:
markdown_documents = os.listdir(markdown_documents_path)

md_files = []

for markdown_document in markdown_documents:
    markdown_document_path = os.path.join(markdown_documents_path, markdown_document)
    with open(markdown_document_path, encoding="utf-8") as file:
        md_files.append([markdown_document, remove_markdown_links(file.read())])

# Original Chain

In [6]:
embedding_tokenizer = AutoTokenizer.from_pretrained(RETRIEVER_NAME, use_fast=False)
text_splitter = RecursiveCharacterTextSplitter.from_huggingface_tokenizer(
    embedding_tokenizer, 
    chunk_size=512, 
    chunk_overlap=32,
    separators=[
        "# ",
        "## ",
        "### ",
        "#### ",
        "##### ",
        "###### ",
        "\n\n",
        "\n",
        " ",
    ],
    keep_separator=True
)
all_splits = text_splitter.create_documents(
    texts=[x[1] for x in md_files], 
    metadatas=[{"source": x[0]} for x in md_files],   
)
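
To build intuition for `chunk_size` and `chunk_overlap`, here is a minimal sketch of a plain sliding window over a token list. This is a simplification and not what `RecursiveCharacterTextSplitter` does internally: the real splitter first splits on the given separators and then merges pieces up to the size limit. The function name `sliding_window_chunks` and the toy token list are assumptions for illustration.

```python
def sliding_window_chunks(tokens, chunk_size, chunk_overlap):
    """Split a token list into windows of chunk_size that share chunk_overlap tokens."""
    if not 0 <= chunk_overlap < chunk_size:
        raise ValueError("chunk_overlap must be in [0, chunk_size)")
    step = chunk_size - chunk_overlap  # how far the window moves each time
    chunks = []
    for start in range(0, len(tokens), step):
        chunks.append(tokens[start:start + chunk_size])
        # Stop once the current window has reached the end of the input
        if start + chunk_size >= len(tokens):
            break
    return chunks


# Toy input: integers stand in for token ids
tokens = list(range(10))
chunks = sliding_window_chunks(tokens, chunk_size=4, chunk_overlap=1)
for chunk in chunks:
    print(chunk)
```

Each window starts `chunk_size - chunk_overlap` tokens after the previous one, so neighbouring chunks share their boundary tokens and a sentence cut at a chunk border is still partly visible in the next chunk.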

In [7]:
retriever_model = HuggingFaceEmbeddings(
    model_name=RETRIEVER_NAME, 
    model_kwargs={'device': DEVICE, "trust_remote_code": True},
    encode_kwargs={'normalize_embeddings': True}
)
db = FAISS.from_documents(
    all_splits, 
    embedding=retriever_model
)
retriever = db.as_retriever()
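
Why `normalize_embeddings=True`? Assuming the FAISS index compares vectors by Euclidean (L2) distance, which to my understanding is the default in the LangChain wrapper, normalizing makes the L2 ranking identical to a cosine-similarity ranking, because for unit vectors the squared distance equals `2 - 2 * cosine`. The sketch below checks this on made-up 3-dimensional vectors in pure Python; all names and numbers are illustrative.

```python
import math


def normalize(v):
    """Scale a vector to unit length."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def squared_l2(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b))


# Hypothetical query and document embeddings
query = normalize([1.0, 2.0, 0.5])
docs = [normalize(v) for v in ([1.0, 1.9, 0.4], [0.2, -1.0, 3.0], [2.0, 0.1, 0.1])]

# Rank by cosine similarity (higher is better) and by L2 distance (lower is better)
by_cosine = sorted(range(len(docs)), key=lambda i: dot(query, docs[i]), reverse=True)
by_l2 = sorted(range(len(docs)), key=lambda i: squared_l2(query, docs[i]))

print(by_cosine, by_l2)
```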

In [None]:
llm = HuggingFacePipeline.from_model_id(
    model_id=GENERATOR_NAME, 
    task="text-generation", 
    pipeline_kwargs={"return_full_text": False}
)

In [None]:
prompt = hub.pull("rlm/rag-prompt")
print(prompt.messages[0].prompt.template)

## New CrossEncoder Reranker

![](https://raw.githubusercontent.com/UKPLab/sentence-transformers/master/docs/img/Bi_vs_Cross-Encoder.png)

(Source: [SBERT.net, 2024](https://www.sbert.net/examples/applications/cross-encoder/README.html))

Cross-encoder rerankers take a query and a chunk together as one input and output a single relevance score (typically between 0 and 1) for the pair. Because every pair needs its own forward pass, they are too expensive for large-scale tasks such as clustering, but they can improve RAG systems by reranking the small set of retrieved chunks by their similarity to the query ([Reimers and Gurevych, 2019](https://doi.org/10.48550/arXiv.1908.10084)).
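
The rerank step itself is small: score every (query, chunk) pair, sort, keep the top `top_n`. The sketch below shows this with a pluggable scoring function. `toy_overlap_score` is a hypothetical stand-in that counts shared words; it is not a cross-encoder and is only there so the example runs without downloading a model. The candidate sentences are made up as well.

```python
from typing import Callable, List, Tuple


def rerank(
    query: str,
    candidates: List[str],
    score_fn: Callable[[str, str], float],
    top_n: int = 3,
) -> List[Tuple[str, float]]:
    """Score each (query, candidate) pair and return the top_n best pairs."""
    scored = [(doc, score_fn(query, doc)) for doc in candidates]
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return scored[:top_n]


def toy_overlap_score(query: str, doc: str) -> float:
    """Hypothetical stand-in for a cross-encoder: share of query words found in doc."""
    q_words = set(query.lower().split())
    d_words = set(doc.lower().split())
    return len(q_words & d_words) / len(q_words) if q_words else 0.0


# Pretend these came back from the first-stage retriever, in its order
candidates = [
    "The series was filmed in Northern Ireland and Croatia.",
    "Emilia Clarke plays Daenerys Targaryen in the series.",
    "Daenerys Targaryen is a member of House Targaryen.",
]
query = "who plays daenerys targaryen"

top = rerank(query, candidates, toy_overlap_score, top_n=2)
for doc, score in top:
    print(f"{score:.2f}  {doc}")
```

In a real pipeline the scoring function is the expensive part, which is why reranking is applied only to the handful of chunks returned by the retriever and not to the whole corpus.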

In [None]:
model = HuggingFaceCrossEncoder(model_name="BAAI/bge-reranker-base")
compressor = CrossEncoderReranker(model=model, top_n=3)
compression_retriever = ContextualCompressionRetriever(
    base_compressor=compressor, base_retriever=retriever
)

In [None]:
print([x.id for x in retriever.invoke("Who plays Daenerys Targaryen?")])
print([x.id for x in compression_retriever.invoke("Who plays Daenerys Targaryen?")])
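
Two printed id lists are hard to compare by eye. A small helper like the following, a sketch with the hypothetical name `rank_shifts`, maps each id kept by the reranker to its old and new position and lists the ids that were dropped. The ids below are placeholders, not real output from the retrievers.

```python
def rank_shifts(before_ids, after_ids):
    """Map each id kept after reranking to (old_position, new_position)."""
    old_pos = {doc_id: i for i, doc_id in enumerate(before_ids)}
    # old_position is None if the reranker returned an id the base list lacks
    return {doc_id: (old_pos.get(doc_id), new) for new, doc_id in enumerate(after_ids)}


before = ["a", "b", "c", "d"]  # placeholder ids from the base retriever
after = ["c", "a", "d"]        # placeholder ids after reranking with top_n=3

shifts = rank_shifts(before, after)
dropped = [doc_id for doc_id in before if doc_id not in after]

print(shifts)
print("dropped:", dropped)
```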

In [1]:
def invoke_reranking_rag_chain(question, retriever_type="basic"):
    if retriever_type == "basic":
        retrieved_docs = retriever.invoke(question)
    elif retriever_type == "compression":
        # TODO: implement compression retriever in here
        pass
    elif retriever_type == "flash":
        # TODO: Implement Flash retriever

        # IMPORTANT: won't work because of versioning problems in FlashrankRerank

        # your code here
        pass

    else:
        raise NotImplementedError("No other retrievers implemented yet.")
    for doc in retrieved_docs:
        display(Markdown(f"### {doc.metadata['source']}"))
        display(Markdown(doc.page_content))
    print("#" * 50)
    input_prompt = prompt.invoke({"question": question, "context": "\n\n".join(doc.page_content for doc in retrieved_docs)})
    answer = llm.invoke(input_prompt)
    return answer

In [None]:
question = "Who plays Daenerys Targaryen?"
answer = invoke_reranking_rag_chain(question, "basic")
print("Answer:", answer)

In [None]:
question = "Who plays Daenerys Targaryen?"
answer = invoke_reranking_rag_chain(question, "compression")
print("Answer:", answer)

## Flash Retriever
Now try FlashRank yourself:
* [Damodaran, P. (2023). FlashRank GitHub](https://github.com/PrithivirajDamodaran/FlashRank)
* [LangChain Docs](https://python.langchain.com/docs/integrations/retrievers/flashrank-reranker/)

In [19]:
# TODO: Implement Flash retriever

# your code here




In [None]:
def invoke_reranking_rag_chain(question, retriever_type="basic"):
    if retriever_type == "basic":
        retrieved_docs = retriever.invoke(question)
    elif retriever_type == "compression":
        # top_n=3 on the compressor already limits the result to three chunks
        retrieved_docs = compression_retriever.invoke(question)
    elif retriever_type == "flash":
        # TODO: Implement Flash retriever
        

        # your code here
        pass

    else:
        raise NotImplementedError("No other retrievers implemented yet.")
    for doc in retrieved_docs:
        display(Markdown(f"### {doc.metadata['source']}"))
        display(Markdown(doc.page_content))
    print("#" * 50)
    input_prompt = prompt.invoke({"question": question, "context": "\n\n".join(doc.page_content for doc in retrieved_docs)})
    answer = llm.invoke(input_prompt)
    return answer