In [1]:
import json
from sentence_transformers import SentenceTransformer, CrossEncoder, util
import gzip
import os
from PyPDF2 import PdfReader

In [2]:
# The natural-language question the pipeline retrieves and re-ranks passages for.
query: str = "What is Atma?"

In [3]:
# How many candidate passages the bi-encoder search should return.
top_k = 32

# Bi-encoder used to embed every passage (and the query) for the first-stage search.
bi_encoder = SentenceTransformer('multi-qa-MiniLM-L6-cos-v1')

In [4]:
# Second-stage model: scores (query, passage) pairs so the bi-encoder's
# candidate list can be re-ranked for better quality.
cross_encoder = CrossEncoder(
    'cross-encoder/ms-marco-MiniLM-L-6-v2'
)


In [5]:
# Path to the source PDF. Set the GITA_PDF_PATH environment variable to point at
# a local copy instead of editing the notebook; the original author's absolute
# path remains the fallback so existing runs behave exactly as before.
FILE_PATH = os.environ.get('GITA_PDF_PATH', 'C:/Users/v-ankbhagat/SemanticRanker/DOCS/Gita.pdf')

In [6]:
def get_pdf_data(file_path):
    """Extract the text of every page of a PDF and concatenate it in page order.

    Args:
        file_path: Path to the PDF, passed directly to ``PdfReader``.

    Returns:
        The text of all readable pages joined with no separator. Extraction is
        best effort: a page that fails is reported and skipped, and the
        remaining pages are still included.

    Raises:
        Whatever ``PdfReader`` raises if the file cannot be opened (for example
        a missing path). Opening the file is intentionally outside the
        best-effort loop, as in the original implementation.
    """
    reader = PdfReader(file_path)
    page_texts = []

    for page_number in range(len(reader.pages)):
        try:
            text = reader.pages[page_number].extract_text()
        except Exception as exc:
            # Best effort: report and skip this page instead of discarding every
            # page after it (the old bare except + return-in-finally did that and
            # also hid the actual error).
            print(f"Error reading page {page_number} of {file_path}: {exc}")
            continue
        # Guard against a None result (may occur for pages without a text
        # layer, depending on the PyPDF2 version) so the join never fails.
        page_texts.append(text or "")

    return "".join(page_texts)

In [7]:
def get_chunks(fulltext: str, chunk_length=500) -> list:
    """Split text into chunks of at most ``chunk_length`` characters.

    While the remaining text is longer than ``chunk_length``, the window of its
    first ``chunk_length`` characters is searched for the last '.':

    * If one is found, the text before it becomes a chunk and the period
      itself is dropped. This can produce an empty chunk when the period is
      the window's first character.
    * If none is found, the window is cut hard at ``chunk_length`` characters
      and no character is lost.

    The leftover tail (length <= ``chunk_length``) is always appended as the
    last chunk, so an empty input yields ``[""]``.

    Args:
        fulltext: The text to split.
        chunk_length: Maximum length of each chunk; must be positive.

    Returns:
        List of chunk strings, in document order.

    Raises:
        ValueError: If ``chunk_length`` is not positive (the loop could never
            make progress).
    """
    if chunk_length <= 0:
        raise ValueError(f"chunk_length must be positive, got {chunk_length}")

    text = fulltext
    chunks = []

    while len(text) > chunk_length:
        # Equivalent to text[:chunk_length].rfind('.') without the copy.
        cut = text.rfind('.', 0, chunk_length)
        if cut == -1:
            # No sentence boundary in the window: hard split, keep every char.
            chunks.append(text[:chunk_length])
            text = text[chunk_length:]
        else:
            chunks.append(text[:cut])
            # Advance past the period in the *remaining* text. Slicing
            # `fulltext` here (the old code) re-read the same window forever.
            text = text[cut + 1:]
    chunks.append(text)

    return chunks

In [8]:
# Extract the PDF text, then split it into passages for retrieval.
full_doc_text = get_pdf_data(FILE_PATH)
print(f'Full doc text length: {len(full_doc_text)}')

chunks = get_chunks(full_doc_text)
# Report the number of chunks; printing the list itself would dump the
# entire document text into the cell output.
print(f"# of chunks: {len(chunks)}")

Full doc text length: 104274


In [None]:
# Embed the passages and the question with the same bi-encoder so they can be
# compared directly; the question embedding is requested as a tensor.
chunk_embeddings = bi_encoder.encode(
    chunks,
    show_progress_bar=True,
)
question_embedding = bi_encoder.encode(
    query,
    convert_to_tensor=True,
)

# Semantic Search

In [None]:
# semantic_search returns one hit list per query; only a single query was
# encoded, so keep just its list of hit dicts.
search_results = util.semantic_search(question_embedding, chunk_embeddings, top_k=top_k)[0]
search_results

In [None]:
# Number of bi-encoder hits kept for the query (expected to be at most top_k).
len(search_results)

# Semantic Reranking

In [None]:
# Pair the question with the text of each bi-encoder hit and score every pair
# with the cross-encoder; scores come back in the same order as search_results.
cross_input = [
    [query, chunks[hit['corpus_id']]]
    for hit in search_results
]
cross_scores = cross_encoder.predict(cross_input)