<a href="https://colab.research.google.com/github/donghuna/AI-Expert/blob/main/%EC%9D%B4%EB%8F%99%ED%95%98/ir_bi_cross_encoder.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Document Retrieval with Bi-Encoder and Cross-Encoder

In this course, we will walk through how to use a bi-encoder and a cross-encoder to retrieve documents relevant to a query.

<div style="text-align : center;">
    <img width="800" alt="bi_encoder" src="https://github.com/augustinLib/All-of-NLP/assets/74291999/4c196702-aee7-48b8-8ea9-d655a18bae71">
</div>

## Ready to start

We will use the `sentence-transformers` library to implement the bi-encoder and cross-encoder models. The library provides encoder models that can be used to encode text into embeddings.  

`faiss` is a library that is used to perform similarity search on the embeddings. We will use it to retrieve the most relevant documents to our query.  

For the dataset, we will use `MS-MARCO`, a collection of queries and web pages from Bing search in which each query is paired with documents that are relevant to it.

In [None]:
!pip install -U sentence-transformers

In [None]:
!pip install faiss-gpu

In [None]:
!pip install pandas

In [None]:
!wget --load-cookies ~/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies ~/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1i2Sv9ddy3eWZGNN5_oARmPaCEJJfyCE1' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1i2Sv9ddy3eWZGNN5_oARmPaCEJJfyCE1" -O valid_document.tsv && rm -rf ~/cookies.txt

In [None]:
!wget --load-cookies ~/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies ~/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1FCEaL3ajZiUWHBbxpR76GVtn7Leladye' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1FCEaL3ajZiUWHBbxpR76GVtn7Leladye" -O document_embedding.pkl && rm -rf ~/cookies.txt

### About Sentence-transformers
- https://sbert.net/
- https://huggingface.co/sentence-transformers


In [None]:
# import the necessary libraries
import pandas as pd
import numpy as np
from tqdm import tqdm
from sentence_transformers import SentenceTransformer, CrossEncoder
import faiss
import pickle
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

def visualize_embeddings(query_embeddings, document_embeddings, rank_df, document_df):
    retrieved_docid = rank_df.loc[:, 'retrieved_doc']
    # concat
    retrieved_doc = []
    for doc in retrieved_docid:
        retrieved_doc.extend(doc)

    retrieved_doc_index = []
    for doc in retrieved_doc:
        retrieved_doc_index.append(np.where(document_df['docid'].values == doc)[0][0])

    document_embeddings = document_embeddings[retrieved_doc_index]

    # reduce the dimensionality of the embeddings
    tsne = TSNE(n_components=2, metric='cosine', random_state=42, perplexity=10)  # Adjust the perplexity value here

    embeddings = np.concatenate([query_embeddings, document_embeddings])
    # fit the tsne on the embeddings
    Y = tsne.fit_transform(embeddings)

    # separate the query and document embeddings
    query_embedding = Y[:len(query_embeddings)]
    document_embedding = Y[len(query_embeddings):]

    # build dataframe for visualization
    query_df = pd.DataFrame(query_embedding, columns=['x', 'y'])
    query_df['index'] = [i for i in range(len(query_embedding))]

    document_df = pd.DataFrame(document_embedding, columns=['x', 'y'])
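    # number of retrieved documents per query (assumes the same count for every query)
    docs_per_query = len(document_embedding) // len(query_embeddings)
    document_df['index'] = np.repeat(np.arange(len(query_embeddings)), docs_per_query)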

    # plot the embeddings
    plt.figure(figsize=(5, 5))
    # scatter plot for query embeddings and document embeddings
    # colors are based on the index
    # each color represents a query and its relevant documents
    plt.scatter(query_df['x'], query_df['y'], c=query_df['index'], cmap='tab20', label='Query')
    plt.scatter(document_df['x'], document_df['y'], c=document_df['index'], cmap='tab20', marker='x', label='Document')
    plt.legend()

In [None]:
# get document data
document = pd.read_csv('valid_document.tsv', sep='\t', dtype=str)
# drop the first character of each docid and convert the rest to a 64-bit integer
document["docid"] = document["docid"].str[1:].astype("int64")
document

In [None]:
# docid -> content function
def get_content(docid):
    return document[document['docid'] == docid]['doc_tac'].values[0]

def print_document(content):
    content = content.split('.')
    for line in content:
        print(line)

def print_query_document(query, docid):
    print("Query: ", query)
    print()
    print("Document Content: ")
    print(get_content(docid))

# Retrieve document with Bi-encoder

### We will first use the bi-encoder structure to retrieve 5 documents.  
These 5 documents are the top-5 most relevant documents to the query,  
where relevance is measured by the `inner product` between the query embedding and the document embedding: the higher the inner product, the more relevant the document.

### The bi-encoder structure is composed of two parts:
- `Document encoder`: encodes the document into an embedding
- `Query encoder`: encodes the query into an embedding

In this case, however, we will use the same encoder for both the document and the query (`parameter sharing`).
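
The idea of a shared encoder plus inner-product ranking can be sketched in plain Python. Everything in this sketch is an illustrative assumption rather than part of the notebook: `toy_encode` is only a bag-of-words stand-in for the real `bi_encoder.encode`, and the vocabulary, document texts, and IDs are made up.

```python
import math

def toy_encode(text, vocab):
    """Hypothetical stand-in for an encoder: L2-normalized bag-of-words counts."""
    tokens = text.lower().split()
    vec = [float(tokens.count(word)) for word in vocab]
    norm = math.sqrt(sum(v * v for v in vec))
    return [v / norm for v in vec] if norm > 0 else vec

def inner_product(a, b):
    return sum(x * y for x, y in zip(a, b))

# made-up vocabulary and documents, for illustration only
vocab = ["eye", "light", "pupil", "car", "buy", "scotland"]
documents = {
    101: "the pupil lets light enter the eye",
    102: "people buy a car to commute",
    103: "scotland is part of the united kingdom",
}

# the same encoder is applied to the documents and to the query (parameter sharing)
doc_vectors = {docid: toy_encode(text, vocab) for docid, text in documents.items()}
query_vector = toy_encode("what part of the eye allows light to enter", vocab)

# rank documents by inner product with the query, highest first
scores = {docid: inner_product(query_vector, vec) for docid, vec in doc_vectors.items()}
ranking = sorted(scores, key=scores.get, reverse=True)
print(ranking[0])
```

Because the document vectors do not depend on the query, they can be computed once and reused for every query, which is why the notebook can load precomputed document embeddings from a file.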

In [None]:
# load the pre-trained model
bi_encoder = SentenceTransformer('sentence-transformers/msmarco-MiniLM-L6-cos-v5')

In [None]:

# generate embeddings for the document
def generate_embeddings(text):
    return bi_encoder.encode(text)

def get_document_embeddings(document):
    embeddings = []
    for text in tqdm(document):
        embeddings.append(generate_embeddings(text))

    embeddings = np.array(embeddings)
    return embeddings

# document embeddings
# document_embeddings = get_document_embeddings(document['doc_tac'])

with open('document_embedding.pkl', 'rb') as f:
    document_embeddings = pickle.load(f)

In [None]:
# document embeddings shape will be (number of documents, embedding size)
document_embeddings.shape

In [None]:
# use faiss to index the embeddings
# when we initialize the index, we need to specify the dimension of the embeddings -> document_embeddings.shape[1]
faiss_index = faiss.IndexFlatIP(document_embeddings.shape[1])

# add the document embeddings to the index
faiss_index = faiss.IndexIDMap2(faiss_index)
faiss_index.add_with_ids(document_embeddings, document['docid'].values)

# check the total number of documents in the index -> it will be equal to the number of documents in the document_embeddings
print(faiss_index.ntotal)
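
Conceptually, a flat inner-product index with an ID map performs an exhaustive search and reports external IDs instead of row positions. The following is a minimal sketch of that behavior in plain Python, not faiss's actual implementation; `flat_ip_search`, the toy vectors, and the toy IDs are hypothetical names and values chosen for illustration.

```python
def flat_ip_search(queries, doc_vectors, doc_ids, k):
    """Brute-force top-k inner-product search; returns (scores, ids), one list per query."""
    all_scores, all_ids = [], []
    for q in queries:
        # score every document against the query and remember its external ID
        scored = [(sum(a * b for a, b in zip(q, d)), doc_id)
                  for d, doc_id in zip(doc_vectors, doc_ids)]
        # sort by score, highest first, and keep only the k best
        top = sorted(scored, key=lambda pair: pair[0], reverse=True)[:k]
        all_scores.append([score for score, _ in top])
        all_ids.append([doc_id for _, doc_id in top])
    return all_scores, all_ids

# made-up 2-dimensional vectors and IDs, for illustration only
toy_doc_vectors = [[1.0, 0.0], [0.0, 1.0], [0.6, 0.8]]
toy_doc_ids = [11, 22, 33]
toy_queries = [[1.0, 0.0], [0.0, 1.0]]

scores, ids = flat_ip_search(toy_queries, toy_doc_vectors, toy_doc_ids, k=2)
print(ids)
```

The two return values mirror the shape of what the `search` method gives back: one row of scores and one row of IDs per query, each of length `k`.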

## Queries which we will use to retrieve documents
In this example, we will use the following queries to retrieve documents:
- why do people buy cars

- how many square kilometers is scotland's

- what does terrorism mean and example

- what does the term alien mean

- what part of the eye allows light to enter

The ground truth docids for these queries, in the same order, are:
- 2921145  

- 8041

- 634663

- 1354086

- 24337


Before we perform the document retrieval process for the queries, let's look at each query and its corresponding ground truth document.

In [None]:
query = [
"why do people buy cars",
"how many square kilometers is scotland's",
"what does terrorism mean and example",
"what does the term alien mean",
"what part of the eye allows light to enter"
]

gt_docid = [
    2921145,
    8041,
    634663,
    1354086,
    24337
]

In [None]:
print_query_document(query[0], gt_docid[0])

In [None]:
print_query_document(query[1], gt_docid[1])

In [None]:
print_query_document(query[2], gt_docid[2])

In [None]:
print_query_document(query[3], gt_docid[3])

In [None]:
print_query_document(query[4], gt_docid[4])

In [None]:
# generate embeddings for queries
query_embedding = generate_embeddings(query)
query_embedding = np.array(query_embedding).reshape(len(query), -1)

# query embeddings shape will be (number of query, embedding size)
print(query_embedding.shape)

In [None]:
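# use the faiss index for document search
# pass the query embeddings and the number of results to retrieve (top-k) to the "search" method
# with IndexFlatIP, the first return value holds inner-product scores (higher means more relevant)
# with IndexIDMap2, the second return value holds the docids passed to add_with_ids
distances, indices = faiss_index.search(query_embedding, 5)
print("inner-product scores")
print(distances)
print()

print("top-5 docids")
print(indices)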

In [None]:
def get_rank(query, indices, gt_docid):
    rank_list = []
    for i in range(len(query)):
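        # 1-based position of the ground truth docid among the retrieved docids
        hits = np.where(indices[i] == gt_docid[i])[0]
        # None marks a ground truth document that is missing from the retrieved docids
        rank = int(hits[0]) + 1 if len(hits) > 0 else None
        rank_list.append(rank)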

    rank_df = pd.DataFrame({
        "query" : query,
        "retrieved_doc" : list(indices),
        "rank" : rank_list,
    })

    return rank_df

rank_df = get_rank(query=query, indices=indices, gt_docid=gt_docid)
rank_df


In [None]:
visualize_embeddings(query_embedding, document_embeddings, rank_df, document)

You can see that most of the queries retrieved the ground truth document well with rank 1.  

However, for the last query, the ground truth document is ranked lower than 1.  

How can we increase the rank for the last query as well?


# Rerank with Cross-encoder


### Let's rerank the query that didn't achieve rank 1.  

We retrieved the top-5 candidate documents with a bi-encoder structure.  

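This time, we will use the cross-encoder structure to rerank the candidate documents. Instead of encoding the query and the document separately, a cross-encoder takes each (query, document) pair as a single input and outputs a relevance score. This is typically more accurate than a bi-encoder, but too slow to run over the whole collection, which is why it is applied only to the retrieved candidates.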
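
The reranking step itself is simple once a pair scorer is available. Below is a minimal sketch in plain Python, assuming a hypothetical `toy_pair_score` function that stands in for a real cross-encoder by measuring token overlap; `rerank`, the candidate texts, and the candidate IDs are likewise made up for illustration.

```python
def toy_pair_score(query, document):
    """Hypothetical stand-in for a cross-encoder: fraction of query tokens found in the document."""
    query_tokens = set(query.lower().split())
    doc_tokens = set(document.lower().split())
    return len(query_tokens & doc_tokens) / len(query_tokens)

def rerank(query, candidates, score_fn):
    """Score every (query, document) pair and sort the candidates by score, highest first."""
    scored = [(docid, score_fn(query, text)) for docid, text in candidates]
    return sorted(scored, key=lambda pair: pair[1], reverse=True)

# candidates as they might come back from a first-stage retriever (made-up IDs and texts)
candidates = [
    (1, "the retina converts light into signals"),
    (2, "the pupil is the part of the eye that allows light to enter"),
    (3, "cars need light to drive at night"),
]

reranked = rerank("what part of the eye allows light to enter", candidates, toy_pair_score)
print(reranked[0][0])
```

Note that the scorer sees the query and the document together, so nothing can be precomputed per document. The cost grows with the number of candidates, which is the reason for keeping the candidate list short.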

In [None]:
# cross encoder model
cross_encoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")

In [None]:
# select row whose rank is not 1
not_rank1 = rank_df[rank_df["rank"] != 1].reset_index(drop=True)
not_rank1

In [None]:
# get top-5 document content
content_list = []
top5_docid = not_rank1["retrieved_doc"].values[0]
for docid in top5_docid:
    content_list.append(get_content(docid))

In [None]:
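# the cross-encoder takes a list of (query, document) tuples as input
# content_list holds the candidates of the first row of not_rank1, so pair them with that row's query
# (a separate variable name keeps the original `query` list from being overwritten)
rerank_query = not_rank1["query"].values[0]
cross_input = []
for content in content_list:
    cross_input.append((rerank_query, content))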

cross_input[0]

In [None]:
cross_scores = cross_encoder.predict(cross_input)

In [None]:
docid_score = pd.DataFrame({
    "docid" : top5_docid,
    "score" : cross_scores
})
docid_score = docid_score.sort_values(by="score", ascending=False).reset_index(drop=True)
docid_score

You can see that the ground truth docid has successfully moved up to rank 1.