
# Rerank

Cohere Rerank implementation references:
- [LangChain document: Cohere Reranker](https://python.langchain.com/docs/integrations/retrievers/cohere-reranker)

Cross Encoder implementation references:
- [Issue: how to use reranker model with langchain in retrievalQA case?](https://github.com/langchain-ai/langchain/issues/13076#issuecomment-1839814250)
- [sentence-transformers example: Retrieve & Re-Rank Demo over Simple Wikipedia](https://github.com/UKPLab/sentence-transformers/blob/master/examples/applications/retrieve_rerank/retrieve_rerank_simple_wikipedia.ipynb)

## Install dependencies

In [1]:
!pip install langchain cohere chromadb --quiet

## Load data

In [2]:
import os
import sys

import requests
from tqdm import tqdm


def http_get(url, path) -> None:
    """
    Downloads a URL to a given path on disk
    """
    if os.path.dirname(path) != "":
        os.makedirs(os.path.dirname(path), exist_ok=True)

    req = requests.get(url, stream=True)
    if req.status_code != 200:
        print("Exception when trying to download {}. Response {}".format(url, req.status_code), file=sys.stderr)
        req.raise_for_status()
        return

    download_filepath = path + "_part"
    with open(download_filepath, "wb") as file_binary:
        content_length = req.headers.get("Content-Length")
        total = int(content_length) if content_length is not None else None
        progress = tqdm(unit="B", total=total, unit_scale=True)
        for chunk in req.iter_content(chunk_size=1024):
            if chunk:  # filter out keep-alive new chunks
                progress.update(len(chunk))
                file_binary.write(chunk)

    os.rename(download_filepath, path)
    progress.close()

In [3]:
import json
import gzip
import os

wikipedia_filepath = 'simplewiki-2020-11-01.jsonl.gz'

if not os.path.exists(wikipedia_filepath):
    http_get('http://sbert.net/datasets/simplewiki-2020-11-01.jsonl.gz', wikipedia_filepath)

passages = []
with gzip.open(wikipedia_filepath, 'rt', encoding='utf8') as fIn:
    for line in fIn:
        data = json.loads(line.strip())

        #Add all paragraphs
        #passages.extend(data['paragraphs'])

        #Only add the first paragraph
        passages.append(data['paragraphs'][0])

print("Passages:", len(passages))


Passages: 169597


In [4]:
average = sum((len(p) for p in passages)) / len(passages)
max_length = max((len(p) for p in passages))

print(average, max_length)

235.19272746569808 3644


In [5]:
passages = [p for p in passages if len(p) <= 512]

In [6]:
len(passages)

160443
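The load-and-filter steps above can be folded into one streaming helper. This is a sketch, assuming each JSON line carries a `paragraphs` list as in the cell that reads `simplewiki-2020-11-01.jsonl.gz`; `iter_first_paragraphs` and `load_short_passages` are hypothetical names, and the 512-character cut-off is the one used above.

```python
import gzip
import json
from typing import Iterator, List


def iter_first_paragraphs(path: str) -> Iterator[str]:
    """Yield the first paragraph of each article in a gzipped JSON-lines file."""
    with gzip.open(path, "rt", encoding="utf8") as f_in:
        for line in f_in:
            line = line.strip()
            if not line:
                continue  # skip blank lines
            paragraphs = json.loads(line).get("paragraphs") or []
            if paragraphs:  # skip articles without paragraphs instead of raising IndexError
                yield paragraphs[0]


def load_short_passages(path: str, max_chars: int = 512) -> List[str]:
    """Return first paragraphs whose length is at most max_chars characters."""
    return [p for p in iter_first_paragraphs(path) if len(p) <= max_chars]
```

Calling `load_short_passages(wikipedia_filepath)` should give the same list as the cells above, as long as every article has at least one paragraph; the generator avoids holding the unfiltered list in memory.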

In [None]:
passages[:100]

## Build the vector retriever

In [8]:
from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings
from langchain_community.vectorstores import Chroma


from google.colab import userdata

# Get the Hugging Face token from Colab secrets
inference_api_key = userdata.get('hf_token')

embedding = HuggingFaceInferenceAPIEmbeddings(
    api_key=inference_api_key, model_name="sentence-transformers/all-MiniLM-L6-v2"
)

In [9]:
vectordb = Chroma.from_texts(
    texts=passages[:500],
    embedding=embedding
)

In [12]:
vectordb._collection.count()

500

In [18]:
vectordb.similarity_search_with_score("What's Chinese New Year?")

[(Document(page_content='Chinese New Year, known in China as the SpringFestival and in Singapore as the LunarNewYear, is a holiday on and around the new moon on the first day of the year in the traditional Chinese calendar. This calendar is based on the changes in the moon and is only sometimes changed to fit the seasons of the year based on how the Earth moves around the sun. Because of this, Chinese New Year is never on January1. It moves around between January21 and February20.'),
  0.44440481066703796),
 (Document(page_content='In Hinduism and Buddhism, a dakini is a female being like a goddess. They are mostly found in Tibetan Buddhism.Chinese: 空行母, Pinyin: Kōngxíng Mǔ and 狐仙,Pinyin:Hú xian ;明妃,Pinyin:Míng fēi｝ The dakini inspires spiritual practice. A dakini is often depicted as beautiful and naked. The nakedness represents the freedom of the mind.'),
  1.4540748596191406),
 (Document(page_content='Lu Sheng-Yen (盧勝彥, Lú Shèngyàn) (27 June 1945), is the founder and spiritual leade
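The number next to each `Document` is a distance, not a similarity: lower means closer (the Chinese New Year passage scores 0.444, the unrelated ones above 1.4). Assuming Chroma's default `l2` space reports squared Euclidean distance, and assuming the MiniLM embeddings come back unit-normalized (neither is shown in this notebook), the distance maps to cosine similarity through d = 2 - 2cos. The helper names below are hypothetical.

```python
import math
from typing import Sequence


def squared_l2(a: Sequence[float], b: Sequence[float]) -> float:
    """Squared Euclidean distance between two vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine of the angle between two vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


def l2_distance_to_cosine(d: float) -> float:
    # Valid only for unit-length vectors, where ||a - b||^2 = 2 - 2 * cos(a, b)
    return 1.0 - d / 2.0


if __name__ == "__main__":
    print(round(l2_distance_to_cosine(0.44440481066703796), 2))  # → 0.78
    print(round(l2_distance_to_cosine(1.4540748596191406), 2))  # → 0.27
```

Under those assumptions, the top hit has cosine similarity of roughly 0.78 with the query and the runner-up roughly 0.27.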

## Cohere Rerank

In [20]:
from google.colab import userdata
cohere_token = userdata.get('cohere')

In [21]:
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import CohereRerank

compressor = CohereRerank(cohere_api_key=cohere_token)
compression_retriever = ContextualCompressionRetriever(
    base_compressor=compressor, base_retriever=vectordb.as_retriever()
)

compression_retriever.get_relevant_documents(
    "What's Chinese New Year"
)

[Document(page_content='Chinese New Year, known in China as the SpringFestival and in Singapore as the LunarNewYear, is a holiday on and around the new moon on the first day of the year in the traditional Chinese calendar. This calendar is based on the changes in the moon and is only sometimes changed to fit the seasons of the year based on how the Earth moves around the sun. Because of this, Chinese New Year is never on January1. It moves around between January21 and February20.', metadata={'relevance_score': 0.9995955}),
 Document(page_content='Lu Sheng-Yen (盧勝彥, Lú Shèngyàn) (27 June 1945), is the founder and spiritual leader of the True Buddha School, which is a religious group with teachings taken from Taoism and Buddhism. He is called Master Lu by his followers. Within his sect, he is also known as "Living Buddha Lian Sheng" (蓮生活佛, "Liansheng Huófó"). He is worshipped by his followers as a "Living Buddha".', metadata={'relevance_score': 0.08299415}),
 Document(page_content='In Hi

## Cross Encoder Rerank

Implement a CrossEncoderRerank compressor modeled on CohereRerank.
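The core of such a reranker does not depend on LangChain: score every (query, passage) pair, sort by score, and optionally keep the best few. A minimal sketch: `Doc`, `rerank` and `overlap_score` are hypothetical names, `Doc` stands in for LangChain's `Document`, and `overlap_score` is a toy word-overlap scorer standing in for the cross-encoder, not a real relevance model. The `top_n` cut-off mirrors the idea of CohereRerank's `top_n` parameter.

```python
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional


@dataclass
class Doc:
    # Hypothetical stand-in for langchain_core.documents.Document
    page_content: str
    metadata: Dict[str, Any] = field(default_factory=dict)


def rerank(query: str, docs: List[Doc], score_fn: Callable[[str, str], float],
           top_n: Optional[int] = None) -> List[Doc]:
    """Score each (query, passage) pair, sort by score descending, keep top_n."""
    scored = []
    for doc in docs:
        # Copy the metadata so the caller's documents are not mutated
        meta = dict(doc.metadata)
        meta["relevance_score"] = score_fn(query, doc.page_content)
        scored.append(Doc(doc.page_content, meta))
    scored.sort(key=lambda d: d.metadata["relevance_score"], reverse=True)
    return scored if top_n is None else scored[:top_n]


def overlap_score(query: str, passage: str) -> float:
    # Toy scorer: fraction of query words that also appear in the passage
    q = set(query.lower().split())
    p = set(passage.lower().split())
    return len(q & p) / len(q) if q else 0.0
```

Swapping `overlap_score` for a function that calls the cross-encoder gives the behavior implemented below; unlike that implementation, this sketch copies metadata instead of writing scores into the input documents.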

### Cross-encoder model access wrapper

Call the cross-encoder model through the Hugging Face Inference API.
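The request the wrapper sends can be inspected without touching the network. This sketch uses the standard library's `urllib.request` instead of `requests` so that it only builds the request; it assumes the API accepts the `{"text", "text_pair"}` body used in the cell below and authenticates with an `Authorization: Bearer` header, and that the response is a list of `{"label", "score"}` objects as in the printed outputs. `build_cross_encoder_request` and `parse_top_score` are hypothetical helper names.

```python
import json
import urllib.request

HF_INFERENCE_URL = "https://api-inference.huggingface.co/models"


def build_cross_encoder_request(model_name: str, token: str,
                                query: str, passage: str) -> urllib.request.Request:
    """Build (but do not send) a POST request scoring one (query, passage) pair."""
    body = {"text": query, "text_pair": passage}
    return urllib.request.Request(
        url=f"{HF_INFERENCE_URL}/{model_name}",
        data=json.dumps(body).encode("utf-8"),
        headers={"Authorization": f"Bearer {token}", "Content-Type": "application/json"},
        method="POST",
    )


def parse_top_score(payload: str) -> float:
    """Extract the score from a response like [{"label": "LABEL_0", "score": 0.99}]."""
    result = json.loads(payload)
    return float(result[0]["score"])
```

Sending the request with `urllib.request.urlopen(req)` and passing the decoded body to `parse_top_score` would return the same kind of score as `get_similary_score`.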

In [22]:
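import requests


class CrossEncoder:
    """Thin wrapper around the Hugging Face Inference API for a cross-encoder model."""

    url = "https://api-inference.huggingface.co/models"

    def __init__(self, model_name: str, token: str):
        self.model_name = model_name
        self.token = token

    def get_similary_score(self, query: str, answer: str) -> float:
        """Return the relevance score the model assigns to the (query, answer) pair."""
        url = f"{self.url}/{self.model_name}"
        # The cross-encoder takes the query and the passage as a sentence pair
        body = {"text": query, "text_pair": answer}
        # The Inference API expects the token as a Bearer token in the Authorization header
        headers = {"Authorization": f"Bearer {self.token}"}
        resp = requests.post(url, json=body, headers=headers)
        print(resp.text)  # debug output: the raw JSON response
        resp.raise_for_status()
        result = resp.json()
        # The response looks like [{"label": "LABEL_0", "score": ...}]
        return result[0]["score"]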

In [23]:
from google.colab import userdata

model_name = "cross-encoder/ms-marco-MiniLM-L-12-v2"
token = userdata.get('hf_token')

In [24]:
encoder = CrossEncoder(model_name, token)
print(encoder.get_similary_score("Who like apples?", "I told all my friends that I like apples"))

[{"label":"LABEL_0","score":0.9944524168968201}]
0.9944524168968201


### CrossEncoderRerank implementation

In [25]:
from typing import Optional, Sequence, Dict

from langchain.callbacks.manager import Callbacks
from langchain.retrievers.document_compressors.base import BaseDocumentCompressor
from langchain_core.documents import Document
from langchain_core.pydantic_v1 import Extra, root_validator


class CrossEncoderRerank(BaseDocumentCompressor):
    """Document compressor that reranks documents with a cross-encoder."""

    model: str = "cross-encoder/ms-marco-MiniLM-L-12-v2"
    """Model to use for reranking."""
    encoder: Optional[CrossEncoder] = None
    token: Optional[str] = None
    """Hugging Face API token."""

    class Config:
        """Configuration for this pydantic object."""
        extra = Extra.forbid
        arbitrary_types_allowed = True

    @root_validator(pre=True)
    def validate_encoder(cls, values: Dict) -> Dict:
        """Build the CrossEncoder client from the `model` and `token` arguments."""
        # pre=True runs before field defaults are applied, so fall back to the default model here
        model_name = values.get("model", "cross-encoder/ms-marco-MiniLM-L-12-v2")
        token = values.get("token")
        values["encoder"] = CrossEncoder(model_name, token)
        return values

    def _compute_score(self, query, content) -> float:
        return self.encoder.get_similary_score(query, content)

    def compress_documents(
        self,
        documents: Sequence[Document],
        query: str,
        callbacks: Optional[Callbacks] = None,
    ) -> Sequence[Document]:
        """
        Compress documents using Cross Encoder Inference API.

        Args:
            documents: A sequence of documents to compress.
            query: The query to use for compressing the documents.
            callbacks: Callbacks to run during the compression process.

        Returns:
            A sequence of compressed documents.
        """
        if len(documents) == 0:  # to avoid empty api call
            return []
        doc_list = list(documents)
        # One Inference API call per (query, passage) pair; the score is stored in each document's metadata
        for doc in doc_list:
            score = self._compute_score(query, doc.page_content)
            doc.metadata["relevance_score"] = score
        # Highest relevance first
        doc_list.sort(key=lambda x: x.metadata["relevance_score"], reverse=True)
        return doc_list


In [60]:
reranker = CrossEncoderRerank(model=model_name, token=token)
reranker._compute_score("Who like apples?", "I told all my friends that I like apples")

[{"label":"LABEL_0","score":0.9944524168968201}]


0.9944524168968201

### Using CrossEncoderRerank

In [26]:
from langchain.retrievers import ContextualCompressionRetriever

compressor = CrossEncoderRerank(model=model_name, token=token)
compression_retriever = ContextualCompressionRetriever(
    base_compressor=compressor, base_retriever=vectordb.as_retriever()
)

compression_retriever.get_relevant_documents(
    "What's Chinese New Year?"
)

[{"label":"LABEL_0","score":0.9999721050262451}]
[{"label":"LABEL_0","score":2.4164222850231454e-05}]
[{"label":"LABEL_0","score":4.29442516178824e-05}]
[{"label":"LABEL_0","score":1.2482765669119544e-05}]


[Document(page_content='Chinese New Year, known in China as the SpringFestival and in Singapore as the LunarNewYear, is a holiday on and around the new moon on the first day of the year in the traditional Chinese calendar. This calendar is based on the changes in the moon and is only sometimes changed to fit the seasons of the year based on how the Earth moves around the sun. Because of this, Chinese New Year is never on January1. It moves around between January21 and February20.', metadata={'relevance_score': 0.9999721050262451}),
 Document(page_content='Lu Sheng-Yen (盧勝彥, Lú Shèngyàn) (27 June 1945), is the founder and spiritual leader of the True Buddha School, which is a religious group with teachings taken from Taoism and Buddhism. He is called Master Lu by his followers. Within his sect, he is also known as "Living Buddha Lian Sheng" (蓮生活佛, "Liansheng Huófó"). He is worshipped by his followers as a "Living Buddha".', metadata={'relevance_score': 4.29442516178824e-05}),
 Document(

## Comparing retrieval results for the query "see sharp"

Among the first 500 passages of the dataset, one passage containing "see sharp" cannot be found by search; the cause is still being investigated.

> C# (pronounced "see sharp") is a computer programming language. It is developed by Microsoft. It was created to use all capacities of .NET platform. The first version was released in 2001. The most recent version is C# 8.0, which was released in September 2019. C# is a modern language. C#'s development team is led by Anders Hejlsberg, the creator of Delphi.

In [27]:
!pip install rank_bm25 --quiet

In [28]:
# Test keyword search

from langchain.retrievers import BM25Retriever

# Initialize a BM25 retriever over the same 500 passages used for the vector store
bm25_retriever = BM25Retriever.from_texts(
    passages[:500]
)
bm25_retriever.k = 2

In [32]:
vectordb.similarity_search_with_score("see sharp")

[(Document(page_content='A cross section is what one gets if one cuts an object into slices.'),
  1.4852806329727173),
 (Document(page_content='In geometrical optics, a focus (also called an image point) is the point where light rays that come from a point on the object converge (come together).'),
  1.522754430770874),
 (Document(page_content='"For the band, see U2 (band)."'), 1.6397788524627686),
 (Document(page_content='A high five is a hand gesture done with two people who want to express joy over an common achievement. The five refers to the five fingers on each hand.'),
  1.6554793119430542)]

In [33]:
bm25_retriever.get_relevant_documents("see sharp")

[Document(page_content='"For the band, see U2 (band)."'),
 Document(page_content='Bifocals are eyeglasses with lenses that are split between two different strengths. Usually the lower half of each lens is made to help the wearer read, while the upper one is to help the wearer see at a distance.')]

In [None]:
def keyword_match(keyword):
    return [p for p in passages[:500] if keyword in p]

In [None]:
keyword_match("see sharp")

['C# (pronounced "see sharp") is a computer programming language. It is developed by Microsoft. It was created to use all capacities of .NET platform. The first version was released in 2001. The most recent version is C# 8.0, which was released in September 2019. C# is a modern language. C#\'s development team is led by Anders Hejlsberg, the creator of Delphi.']
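One plausible cause of the BM25 miss is tokenization. Assuming LangChain's `BM25Retriever` splits text on whitespace by default, the C# passage is indexed with the tokens `"see` and `sharp")`, quotes and parentheses attached, so neither query token `see` nor `sharp` matches it, while the two passages that contain a bare `see` are returned. The sketch below is a small pure-Python Okapi BM25 (Lucene-style idf, `k1=1.5`, `b=0.75`; not `rank_bm25` itself) over abbreviated copies of the three passages, comparing whitespace splitting with a lowercase `\w+` tokenizer. All names in it are hypothetical.

```python
import math
import re
from collections import Counter
from typing import Callable, List

# Abbreviated copies of three passages from the first 500 (see the outputs above)
PASSAGES = [
    'C# (pronounced "see sharp") is a computer programming language.',
    '"For the band, see U2 (band)."',
    "Bifocals are eyeglasses with lenses that are split between two different strengths. "
    "Usually the lower half of each lens is made to help the wearer read, "
    "while the upper one is to help the wearer see at a distance.",
]


def whitespace_tokenize(text: str) -> List[str]:
    # Plain str.split(): punctuation stays glued to words, e.g. '"see' and 'sharp")'
    return text.split()


def word_tokenize(text: str) -> List[str]:
    # Lowercase and keep runs of word characters: '"see sharp")' -> ['see', 'sharp']
    return re.findall(r"\w+", text.lower())


def bm25_scores(query: str, docs: List[str], tokenize: Callable[[str], List[str]],
                k1: float = 1.5, b: float = 0.75) -> List[float]:
    """Score each doc against the query with Okapi BM25 (Lucene-style idf)."""
    tokenized = [tokenize(d) for d in docs]
    n_docs = len(tokenized)
    avgdl = sum(len(t) for t in tokenized) / n_docs
    # Document frequency: the number of docs each term appears in
    df = Counter(term for toks in tokenized for term in set(toks))
    scores = []
    for toks in tokenized:
        tf = Counter(toks)
        score = 0.0
        for term in tokenize(query):
            if tf[term] == 0:
                continue  # term absent from this doc contributes nothing
            idf = math.log((n_docs - df[term] + 0.5) / (df[term] + 0.5) + 1)
            norm = tf[term] + k1 * (1 - b + b * len(toks) / avgdl)
            score += idf * tf[term] * (k1 + 1) / norm
        scores.append(score)
    return scores


if __name__ == "__main__":
    # The C# passage scores 0.0 with whitespace tokens and ranks first with word tokens
    for name, tok in [("whitespace", whitespace_tokenize), ("word", word_tokenize)]:
        print(name, [round(s, 3) for s in bm25_scores("see sharp", PASSAGES, tok)])
```

If the hypothesis holds, passing a word-level tokenizer through the `preprocess_func` argument of `BM25Retriever.from_texts` (assuming the installed version exposes it) should let BM25 find the passage. The miss in the vector search is a separate question that this sketch does not address.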