## Semantic Search using [Sentence BERT/Sentence Transformers](https://www.sbert.net/index.html)


In [92]:
!pip install -U sentence-transformers

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [93]:
import os
import pandas as pd
import numpy as np
import torch
#Libraries for sentence transformer
from sentence_transformers import SentenceTransformer, util, InputExample, models
from tqdm.notebook import tqdm_notebook # for progress bar
from torch.utils.data import DataLoader
from sentence_transformers import losses
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator
#Libraries for visualization
import seaborn as sns
import matplotlib.pyplot as plt

### 1.Load Dataset Files

In [94]:
def load_data(path):
    import os
    
    #_____________ Read data from CISI.ALL file and store in dictinary ________________
    
    with open(os.path.join(path, 'CISI.ALL')) as f:
        lines = ""
        for l in f.readlines():
            lines += "\n" + l.strip() if l.startswith(".") else " " + l.strip()
        lines = lines.lstrip("\n").split("\n")
 
    doc_set = {}
    doc_id = ""
    doc_text = ""

    for l in lines:
        if l.startswith(".I"):
            doc_id = l.split(" ")[1].strip() 
        elif l.startswith(".X"):
            doc_set[doc_id] = doc_text.lstrip(" ")
            doc_id = ""
            doc_text = ""
        else:
            doc_text += l.strip()[3:] + " " 

    print(f"Number of documents = {len(doc_set)}")
    print(doc_set["1"]) 
    
    
    #_____________ Read data from CISI.QRY file and store in dictinary ________________
    
    with open(os.path.join(path, 'CISI.QRY')) as f:
        lines = ""
        for l in f.readlines():
            lines += "\n" + l.strip() if l.startswith(".") else " " + l.strip()
        lines = lines.lstrip("\n").split("\n")
          
    qry_set = {}
    qry_id = ""
    for l in lines:
        if l.startswith(".I"):
            qry_id = l.split(" ")[1].strip() 
        elif l.startswith(".W"):
            qry_set[qry_id] = l.strip()[3:]
            qry_id = ""

    print(f"\n\nNumber of queries = {len(qry_set)}")    
    print(qry_set["1"]) 
    
    
    #_____________ Read data from CISI.REL file and store in dictinary ________________
    
    rel_set = {}
    with open(os.path.join(path, 'CISI.REL')) as f:
        for l in f.readlines():
            qry_id = l.lstrip(" ").strip("\n").split("\t")[0].split(" ")[0] 
            doc_id = l.lstrip(" ").strip("\n").split("\t")[0].split(" ")[-1]

            if qry_id in rel_set:
                rel_set[qry_id].append(doc_id)
            else:
                rel_set[qry_id] = []
                rel_set[qry_id].append(doc_id)

    print(f"\n\nNumber of mappings = {len(rel_set)}")
    print(rel_set["1"]) 
    
    doc_set = {int(id):doc for (id,doc) in doc_set.items()}
    qry_set = {int(id):qry for (id,qry) in qry_set.items()}
    rel_set = {int(qid):list(map(int, did_lst)) for (qid,did_lst) in rel_set.items()}
    
    return doc_set, qry_set, rel_set

In [95]:
doc_set, qry_set, rel_set = load_data('')

Number of documents = 1460
18 Editions of the Dewey Decimal Classifications Comaromi, J.P. The present study is a history of the DEWEY Decimal Classification.  The first edition of the DDC was published in 1876, the eighteenth edition in 1971, and future editions will continue to appear as needed.  In spite of the DDC's long and healthy life, however, its full story has never been told.  There have been biographies of Dewey that briefly describe his system, but this is the first attempt to provide a detailed history of the work that more than any other has spurred the growth of librarianship in this country and abroad. 


Number of queries = 112
What problems and concerns are there in making up descriptive titles? What difficulties are involved in automatically retrieving articles from approximate titles? What is the usual relevance of the content of articles to their titles?


Number of mappings = 76
['28', '35', '38', '42', '43', '52', '65', '76', '86', '150', '189', '192', '193', '1

In [96]:
import fileinput
nltk.download('stopwords')
nltk.download('punkt')
import nltk
import os
from nltk.tokenize import word_tokenize
from nltk.corpus import stopwords
import string
from nltk import word_tokenize,sent_tokenize
import nltk

def lower_str(para):
    return para.lower()

def get_tokenize(para):
    return word_tokenize(para)

def get_stop_word(word_token):
    stop_words = set(stopwords.words('english'))
    without_stop_word=[]
    for w in word_token:
        if w not in stop_words:
            without_stop_word.append(w)
    return without_stop_word

def remove_punctuation(test_str):
    test_str = test_str.translate(str.maketrans('', '', string.punctuation))
    return test_str
def remove_blank_spces(tokens):
     new_token=[]
     for  i in tokens:
         if i==" ":
             continue
         else:
             new_token.append(i)
     return new_token



def edit_file(content):


        content= content

        # print("intital content",content)
        content= lower_str(content)
        content=remove_punctuation(content)
        tokens=get_tokenize(content)
        final_tokens=get_stop_word(tokens)
        # print("final_content",final_tokens)

        final_items=""
        for item in final_tokens:
            final_items=final_items + item + " "
        return final_items



[nltk_data] Downloading package stopwords to /root/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [97]:
for i in doc_set.keys():
  doc_set[i]=edit_file(doc_set[i])


In [98]:
print(qry_set)

{1: 'What problems and concerns are there in making up descriptive titles? What difficulties are involved in automatically retrieving articles from approximate titles? What is the usual relevance of the content of articles to their titles?', 2: 'How can actually pertinent data, as opposed to references or entire articles themselves, be retrieved automatically in response to information requests?', 3: 'What is information science?  Give definitions where possible.', 4: 'Image recognition and any other methods of automatically transforming printed text into computer-ready form.', 5: 'What special training will ordinary researchers and businessmen need for proper information management and unobstructed use of information retrieval systems? What problems are they likely to encounter?', 6: 'What possibilities are there for verbal communication between computers and humans, that is, communication via the spoken word?', 7: 'Describe presently working and planned systems for publishing and pri

In [99]:
for i in qry_set.keys():
  qry_set[i]=edit_file(qry_set[i])

In [100]:
print(qry_set)

{1: 'problems concerns making descriptive titles difficulties involved automatically retrieving articles approximate titles usual relevance content articles titles ', 2: 'actually pertinent data opposed references entire articles retrieved automatically response information requests ', 3: 'information science give definitions possible ', 4: 'image recognition methods automatically transforming printed text computerready form ', 5: 'special training ordinary researchers businessmen need proper information management unobstructed use information retrieval systems problems likely encounter ', 6: 'possibilities verbal communication computers humans communication via spoken word ', 7: 'describe presently working planned systems publishing printing original papers computer saving byproduct articles coded dataprocessing form use retrieval ', 8: 'describe information retrieval indexing languages bearing science general ', 9: 'possibilities automatic grammatical contextual analysis articles inc

### 2.Train-Validation Data Split

In [101]:
#Split data into 70-30 
qry_set_train_split = int(len(qry_set)*70/100)
doc_set_train_split = int(len(doc_set)*70/100 )
print(f"Train split\n qry length: {qry_set_train_split}, doc length: {doc_set_train_split}\n")

qry_set_valid_split = int(len(qry_set)-qry_set_train_split)
doc_set_valid_split = int(len(doc_set)-doc_set_train_split)
print(f"Train split\n qry length: {qry_set_valid_split}, doc length: {doc_set_valid_split}")

Train split
 qry length: 78, doc length: 1022

Train split
 qry length: 34, doc length: 438


In [102]:
def is_similar(query_id, doc_id):
  return 0.9 if query_id in rel_set.keys() and doc_id in rel_set[query_id] else 0.1

In [103]:
#Reformat the input data for Finetuning Sentence Transformer using MNR Loss
train_samples = []

for query_id in tqdm_notebook(range(qry_set_train_split)):
  for doc_id in range(doc_set_train_split):
    train_samples.append(
        InputExample(
            texts=[qry_set[query_id+1], doc_set[doc_id+1]],
            label= is_similar(query_id+1, doc_id+1)
            ))

  0%|          | 0/78 [00:00<?, ?it/s]

In [104]:
#Reformat the input data for Finetuning Sentence Transformer using MNR Loss
validation_samples = []

for query_id in tqdm_notebook(range(qry_set_train_split, len(qry_set))):
  for doc_id in range(doc_set_train_split, len(doc_set)):
    validation_samples.append(
        InputExample(
            texts=[qry_set[query_id+1], doc_set[doc_id+1]],
            label= is_similar(query_id+1, doc_id+1)
            ))

  0%|          | 0/34 [00:00<?, ?it/s]

### 3. Finetune SBERT

In [121]:
#Define Model, Dataloader, loss, evaluator
model = SentenceTransformer('all-MiniLM-L6-v2')

loader = DataLoader(train_samples, shuffle=True, batch_size=16)

loss = losses.CosineSimilarityLoss(model)

In [122]:
#train for a single epoch and warmup for the first 10% of our training steps.
model.fit(
    train_objectives=[(loader, loss)],
    epochs=3,
    warmup_steps = 100
)

Epoch:   0%|          | 0/3 [00:00<?, ?it/s]

Iteration:   0%|          | 0/4983 [00:00<?, ?it/s]

Iteration:   0%|          | 0/4983 [00:00<?, ?it/s]

Iteration:   0%|          | 0/4983 [00:00<?, ?it/s]

### 4.Compute embeddings and similarity

In [123]:
#compute the embeddings for the corpus
docs_embeddings = model.encode(list(doc_set.values()), convert_to_tensor=True)

In [124]:
docs_embeddings.shape

torch.Size([1460, 384])

In [125]:
#compute the embeddings for the queries
query_embeddings = model.encode(list(qry_set.values()), convert_to_tensor=True)

In [126]:
'''By default, 
query_chunk_size: process 100 queries simultaneously. 
corpus_chunk_size: scans 100k entries in the corpus at a time.
score_function: cosine similarity for computing scores.
'''
#Returns list of Top 10 docs_id and scores for each query
top_k_scores = util.semantic_search(query_embeddings, docs_embeddings, top_k=10)

In [127]:
#Get predicted TOP-10 doc_id for each query
predicted_doc_scores = {}

for i, query in enumerate(top_k_scores):
  predicted_doc_scores[i+1] = []
  for doc_scores in query:
    predicted_doc_scores[i+1].append(doc_scores['corpus_id']+1)

### 6. Performance Metrics 

#### 1] Recall@K

In [128]:
# Recall@K = TP/(TP+FN)
def recall_k(ground_truth, predictions, k):
  avg_recall = 0
  for query_id in ground_truth:    
    truth_set = set(ground_truth[query_id])
    pred_set = set(predictions[query_id][:k])
    result = round(len(truth_set & pred_set) / float(len(truth_set)), 2) 
    avg_recall += result
    avg_recall /= len(ground_truth)

  return round(avg_recall, 3)

In [129]:
print(f"Recall@10 = {recall_k(rel_set, predicted_doc_scores, 10)}") 

Recall@10 = 0.009


#### 2] Precision@K [order-unaware]

In [130]:
# Precision@K = TP/(TP+FP)
def precision_k(ground_truth, predictions, k):
  avg_precision = 0
  for query_id in ground_truth:    
    truth_set = set(ground_truth[query_id]) 
    pred_set = set(predictions[query_id][:k])
    result = round(len(truth_set & pred_set) / float(len(pred_set)), 2) 
    avg_precision += result
  avg_precision /= len(ground_truth)

  return round(avg_precision, 3)

In [131]:
print(f"Precision@10 = {precision_k(rel_set, predicted_doc_scores, 10)}")

Precision@10 = 0.721


#### 3] Mean Reciprocal Rank (MRR) [order-aware]

In [132]:
def get_first_relevent_docid(predictions, truth):
    for doc_id in predictions:
        is_exist = doc_id in truth 
        if is_exist:
            return predictions.index(doc_id)+1 
    else:
        return -1

In [133]:
def mrr(doc_scores, rel_set):
    Q = len(rel_set) 
    cumulative_reciprocal = 0  
    
    for query_id in rel_set:
        first_result = get_first_relevent_docid(doc_scores[query_id], rel_set[query_id])
        first_result_rank = len(doc_scores[1])+1 if first_result<1 else first_result 
        reciprocal = 1 / first_result_rank
        cumulative_reciprocal += reciprocal
        
    mrr = 1/Q * cumulative_reciprocal 
    return round(mrr,3)

In [134]:
mrr = mrr(predicted_doc_scores, rel_set)
print(f"Mean Reciprocal Rank (MRR): {mrr}")

Mean Reciprocal Rank (MRR): 0.883


#### 4.4 Mean Average Precision (MAP) [order-aware]

In [135]:
def map_k(rel_set, doc_scores, K):
    Q = len(rel_set) 
    avg_precision = [] 

    for query_id in rel_set:
        precision_relevance_summation = 0

        for k in range(0,K):
            # calculate precision@k
            truth_set = set(rel_set[query_id])
            pred_set = set(doc_scores[query_id][:k+1])
            precision_at_k = round(len(truth_set & pred_set) / float(len(pred_set)), 2) 
            
            rel_k = 1 if doc_scores[query_id][k] in rel_set[query_id] else 0 
            precision_relevance_summation += precision_at_k * rel_k 
            
        avg_precision_q = precision_relevance_summation / len(rel_set[query_id])
        avg_precision.append(avg_precision_q)

    map_k = sum(avg_precision) / Q 
    return round(map_k, 3)

In [136]:
map_10 = map_k(rel_set, predicted_doc_scores, K=10)
print(f"MAP@10 = {map_10}")

MAP@10 = 0.282
