In [None]:
import numpy as np
import os
import warnings
from pdfminer.converter import TextConverter
from pdfminer.layout import LAParams
from pdfminer.pdfdocument import PDFDocument
from pdfminer.pdfinterp import PDFResourceManager, PDFPageInterpreter
from pdfminer.pdfpage import PDFPage
from pdfminer.pdfparser import PDFParser
from typing import Dict, List, Tuple

import openai
import tiktoken

import re
import io
from tqdm import tqdm
from nltk.tokenize import sent_tokenize
warnings.filterwarnings("ignore")

In [None]:
def load_pdf_document(
        file_path: str, overlap: int = 0) -> Tuple[List[Dict[str, str]], Dict[int, str]]:
    """Load a PDF and split it into sentence-window chunks plus a paragraph index.

    The text of every page is extracted with pdfminer and split into paragraphs
    on blank lines (``\\n\\n`` not preceded by a digit). Words hyphenated across
    a line break are re-joined, and the remaining line breaks become spaces.
    Each regular paragraph is stored in ``para_dict`` and split into sentences
    with NLTK's ``sent_tokenize``. For every sentence, one chunk is emitted
    containing that sentence and up to ``overlap`` following sentences of the
    same paragraph.

    "Header-like" paragraphs are not sentence-split. These are paragraphs whose
    end matches a digit optionally followed by up to three words, or that
    consist of a single word. Each one is emitted verbatim as its own chunk,
    tagged with the page it appeared on and with the id of the next regular
    paragraph. If no regular paragraph follows, it takes the id of the last one.

    Args:
        file_path: Path of the PDF file to read.
        overlap: Number of extra following sentences to include in each chunk.

    Returns:
        ``(doc, para_dict)``. ``doc`` is a list of chunks shaped like
        ``{"page_content": str, "metadata": {"source": file_path, "page": int,
        "para": int}}``. ``para_dict`` maps every ``para`` id used in ``doc``
        to the full paragraph text.
    """
    doc = []
    para_dict = {}
    # (text, page) of header-like paragraphs still waiting for a paragraph id.
    pending_headers = []
    para_id = 0

    def add_chunk(text, page, chunk_para_id):
        # Append one chunk in the output format described in the docstring.
        doc.append({"page_content": text, "metadata": {
            "source": file_path, "page": page, "para": chunk_para_id}})

    resource_manager = PDFResourceManager()
    fake_file_handle = io.StringIO()
    converter = TextConverter(
        resource_manager,
        fake_file_handle,
        laparams=LAParams())
    try:
        page_interpreter = PDFPageInterpreter(resource_manager, converter)
        with open(file_path, 'rb') as file:
            pdf_document = PDFDocument(PDFParser(file))
            for page_number, page in enumerate(
                    PDFPage.create_pages(pdf_document), start=1):
                page_interpreter.process_page(page)
                text = fake_file_handle.getvalue()
                # Empty the buffer so the next page's text is not appended to this one.
                fake_file_handle.truncate(0)
                fake_file_handle.seek(0)

                # Blank lines separate paragraphs, except directly after a digit.
                for raw_paragraph in re.split(r'(?<!\d)\n\n', text):
                    # Re-join words hyphenated across a line break, then turn the
                    # remaining line breaks into single spaces. This keeps the last
                    # word of a line from being glued to the first word of the next.
                    paragraph = re.sub(r"-\n", "", raw_paragraph)
                    paragraph = re.sub(r"\s*\n\s*", " ", paragraph).strip()

                    if re.search(
                            r'\d$|\d \w+$|\d (\w+ ){1,2}\w+$|^\w+$', paragraph):
                        # Keep every header (the old code overwrote earlier ones).
                        pending_headers.append((paragraph, page_number))
                        continue

                    if not paragraph or not re.search('[a-zA-Z0-9]', paragraph):
                        continue

                    # Headers are emitted only once the paragraph they introduce is
                    # stored in para_dict, so their "para" id is always a valid key.
                    for header, header_page in pending_headers:
                        add_chunk(header, header_page, para_id)
                    pending_headers.clear()

                    para_dict[para_id] = paragraph
                    sentences = sent_tokenize(paragraph)
                    for i in range(len(sentences)):
                        window = ' '.join(sentences[i:i + overlap + 1])
                        if window.strip() and re.search('[a-zA-Z]', window):
                            add_chunk(window, page_number, para_id)
                    para_id += 1
    finally:
        # Release pdfminer's converter and the text buffer even if parsing fails.
        converter.close()
        fake_file_handle.close()

    if pending_headers:
        if para_id == 0:
            # No regular paragraph exists at all. Register the headers as a
            # paragraph of their own so their "para" id is still resolvable.
            para_dict[para_id] = ' '.join(header for header, _ in pending_headers)
            para_id += 1
        # Trailing headers attach to the last paragraph.
        for header, header_page in pending_headers:
            add_chunk(header, header_page, para_id - 1)

    return doc, para_dict

In [None]:
# PDF to index (placeholder, TODO: point at the paper to analyze) and the number
# of extra following sentences packed into each chunk.
PDF_PATH = '/path'
SENTENCE_OVERLAP = 3

documents, para_dict = load_pdf_document(PDF_PATH, overlap=SENTENCE_OVERLAP)
print(len(documents))

In [None]:
from InstructorEmbedding import INSTRUCTOR

# Instruction prefix paired with every chunk when encoding with INSTRUCTOR.
DOCUMENT_INSTRUCTION = "Represent the Research paper document chunk for retrieval:"

model = INSTRUCTOR('hkunlp/instructor-xl')

# One [instruction, text] pair per chunk, in the same order as `documents`.
sentence_pairs = [
    [DOCUMENT_INSTRUCTION, doc_chunk["page_content"]] for doc_chunk in documents
]

embeddings = model.encode(
    sentence_pairs,
    batch_size=20,
    normalize_embeddings=True,
    show_progress_bar=True,
)

# Store each embedding on its chunk; zip pairs them by position.
for doc_chunk, vector in zip(documents, embeddings):
    doc_chunk["embeddings"] = np.array(vector)

In [None]:
import umap.umap_ as umap
import hdbscan

# Project the embeddings with UMAP before density clustering.
# NOTE(review): despite its name, `embeddings_2d` holds n_components=10 dimensions.
reducer = umap.UMAP(n_neighbors=50, n_components=10, random_state=4)
embeddings_2d = reducer.fit_transform(embeddings)

# NOTE(review): hdbscan uses label -1 for points it leaves unclustered (noise);
# confirm against the hdbscan docs before relying on cluster ids downstream.
clusterer = hdbscan.HDBSCAN(min_cluster_size=3, gen_min_span_tree=True)
cluster_labels = clusterer.fit(embeddings_2d).labels_

# Tag every chunk with the cluster label computed for it (same order as `documents`).
for doc_chunk, label in zip(documents, cluster_labels):
    doc_chunk["metadata"]["cluster"] = label

In [None]:
def cosine(x, y):
    """Return the cosine similarity of ``x`` and ``y`` as a Python scalar.

    The dot product is divided by the product of the two norms. ``.item()``
    requires the result to contain exactly one element, so e.g. a ``(1, d)``
    array paired with a ``(d,)`` array is accepted.
    """
    dot_product = np.dot(x, y)
    norm_product = np.linalg.norm(x) * np.linalg.norm(y)
    similarity = dot_product / norm_product
    return similarity.item()


def check_no_of_tokens(prompt, model="gpt-3.5-turbo"):
    """Return the number of tokens ``prompt`` encodes to for ``model``.

    The original first assigned ``get_encoding("cl100k_base")`` and then
    immediately overwrote it. That dead assignment is removed.

    Args:
        prompt: Text to measure.
        model: OpenAI model name whose tiktoken encoding is used. Defaults to
            ``"gpt-3.5-turbo"``, matching the previous hard-coded behavior.

    Returns:
        The token count as an int.
    """
    encoding = tiktoken.encoding_for_model(model)
    return len(encoding.encode(prompt))


def create_prompt(query, top_chunks_with_context):
    """Assemble the question-answering prompt for the chat model.

    The prompt consists of these sections, separated by blank lines:
    1. the task statement;
    2. the question;
    3. guidance about the supplied context;
    4. the context paragraphs (themselves blank-line separated);
    5. a closing ``Answer:`` cue.

    Args:
        query: The question text (must be a str).
        top_chunks_with_context: Iterable of context paragraph strings.

    Returns:
        The prompt string.
    """
    task = "Your task is to answer questions based on a research paper."
    question = "Answer the following question: " + query
    guidance = ("Below is the only information you have about the paper, wade through the irrelevant text "
                "and use the useful text. Do not expect extremely precise information.")
    context = "\n\n".join(top_chunks_with_context)
    return "\n\n".join([task, question, guidance, context, "Answer:"])


# Never hard-code credentials in a notebook: read the key from the environment.
openai.api_key = os.environ.get("OPENAI_API_KEY")
if not openai.api_key:
    raise RuntimeError("Set the OPENAI_API_KEY environment variable before running this cell.")

TOPK = 20            # number of most similar chunks considered
PROMPT_LENGTH = 800  # maximum prompt size, in tokens
NOISE_CLUSTER = -1   # label hdbscan gives to chunks outside every cluster

instruction = "Represent the Research paper question for retrieving supporting document:"
query = " How to explain the phenomenon observed in this paper?"
query_embedding = model.encode([[instruction, query]])

# (chunk, cosine similarity to the query) for every chunk.
similarities = [
    (chunk, cosine(query_embedding, chunk["embeddings"].T))
    for chunk in documents]

top_chunks = sorted(similarities, key=lambda x: x[1], reverse=True)[:TOPK]

# Keep the top-K chunks that share a real cluster with one of the two best
# matches. The two best matches are always kept. Noise is not a cluster, so a
# noisy best match must not pull in every other noise chunk.
top2_cluster_ids = {
    chunk["metadata"]["cluster"] for chunk, _ in top_chunks[:2]} - {NOISE_CLUSTER}
filtered_top_chunks = [
    (chunk, score) for rank, (chunk, score) in enumerate(top_chunks)
    if rank < 2 or chunk["metadata"]["cluster"] in top2_cluster_ids]

# Distinct paragraph ids in descending relevance order. dict.fromkeys keeps the
# first occurrence (unlike set()), so the truncation below drops the least
# relevant paragraph rather than an arbitrary one.
top_para_ids = list(dict.fromkeys(
    chunk["metadata"]["para"] for chunk, _ in filtered_top_chunks))
top_chunks_with_context = [para_dict[para_id] for para_id in top_para_ids]

prompt = create_prompt(query, top_chunks_with_context)
# Drop the least relevant paragraph until the prompt fits. Stop when nothing is
# left to drop; otherwise an over-long query would loop forever.
while top_chunks_with_context and check_no_of_tokens(prompt) > PROMPT_LENGTH:
    top_chunks_with_context = top_chunks_with_context[:-1]
    prompt = create_prompt(query, top_chunks_with_context)

print("Number of tokens used in prompt: ", check_no_of_tokens(prompt))

response = openai.ChatCompletion.create(
    model="gpt-3.5-turbo", messages=[{"role": "user", "content": prompt}])
print(response.choices[0].message.content)