<a href="https://colab.research.google.com/github/duonghiepit/Review-Text-Retrival/blob/main/project_sentence_transformers_text_retrieval_solution.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install datasets sentence_transformers

In [None]:
from datasets import load_dataset

# Load the 'v1.1' configuration of the 'ms_marco' dataset.
data_set = load_dataset(path='ms_marco', name='v1.1')

In [6]:
# Work only with the 'test' split of the loaded dataset.
subset = data_set['test']

In [24]:
# Parallel collections built from `subset`:
#   queries        - query strings that were kept
#   queries_infos  - per-query dict with id, text and indices of relevant docs
#   corpus         - flat list of passage texts from every kept query
queries_infos = []
queries = []
corpus = []

for sample in subset:
    # Keep only samples whose query type is 'entity'.
    if sample['query_type'] != 'entity':
        continue

    passages = sample['passages']
    selected_flags = passages['is_selected']
    passage_texts = passages['passage_text']

    # The passages of this sample will be appended at the end of `corpus`,
    # so passage number `pos` ends up at global index `offset + pos`.
    offset = len(corpus)
    relevant_docs = [
        offset + pos
        for pos, flag in enumerate(selected_flags)
        if flag == 1
    ]

    # Skip queries without any selected passage; their passages are not
    # added to the corpus, which keeps the computed indices valid.
    if not relevant_docs:
        continue

    queries.append(sample['query'])
    queries_infos.append({
        'query_id': sample['query_id'],
        'query': sample['query'],
        'relevant_docs': relevant_docs,
    })
    corpus += passage_texts

In [25]:
import string
import nltk
from nltk.corpus import stopwords
from nltk.stem import PorterStemmer

# Make sure the NLTK stop-word corpus is available before reading it.
nltk.download('stopwords')

# Shared text-normalisation resources used by later cells.
stemmer = PorterStemmer()
remove_chars = string.punctuation
english_stopwords = stopwords.words('english')

[nltk_data] Downloading package stopwords to /root/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


In [26]:
def text_normalize(text):
    """Normalize a text for matching.

    Steps, in order: lowercase, delete every character listed in the
    module-level ``remove_chars``, drop tokens found in the module-level
    ``english_stopwords``, and stem the remaining tokens with the
    module-level ``stemmer``.

    Args:
        text (str): Raw input text.

    Returns:
        str: The processed tokens joined by single spaces.
    """
    text = text.lower()
    # Strings are immutable: the stripped result must be assigned back,
    # otherwise the punctuation removal has no effect.
    text = text.translate(str.maketrans('', '', remove_chars))
    tokens = [word for word in text.split() if word not in english_stopwords]
    return ' '.join(stemmer.stem(word) for word in tokens)

# Custom search function

In [None]:
import torch
from sentence_transformers import SentenceTransformer, util

# Sentence-embedding model used for both the corpus and the queries.
model = SentenceTransformer(model_name_or_path='all-MiniLM-L6-v2')

# Encode every passage in `corpus` once up front; the result is requested
# as a tensor so it can be compared against query embeddings later.
corpus_embeddings = model.encode(corpus, convert_to_tensor=True)

In [None]:
custom_queries = ['what is facebook']

# Never ask for more hits than there are documents in the corpus.
top_k = min(5, len(corpus))
for query in custom_queries:
    # Embed the query with the same model that produced `corpus_embeddings`.
    query_embeddings = model.encode(query, convert_to_tensor=True)

    # Row 0 of the similarity matrix holds the scores of this single query
    # against every corpus document.
    cos_scores = util.cos_sim(query_embeddings, corpus_embeddings)[0]
    top_results = torch.topk(cos_scores, k=top_k)

    print("\n\n======================")
    print("Query:", query)
    print(f"Top {top_k} most similar sentences in corpus:\n")

    # Convert to plain Python numbers so the indices can be used directly
    # to look up documents in the `corpus` list.
    top_scores = top_results.values.tolist()
    top_doc_indices = top_results.indices.tolist()

    for rank, (score, doc_idx) in enumerate(zip(top_scores, top_doc_indices), start=1):
        print(f'Document rank {rank}:')
        # Look up the retrieved document by its corpus index, not by its rank.
        print(corpus[doc_idx], f'\n(Score: {score:.4f})', '\n')