<a href="https://colab.research.google.com/github/nrimsky/qa/blob/main/paper_qa.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
%%capture
!pip install -U InstructorEmbedding sentence-transformers pylatexenc faiss-cpu langchain openai

In [None]:
import os
import requests
import shutil
import tarfile
import re
from InstructorEmbedding import INSTRUCTOR
from pylatexenc.latex2text import LatexNodes2Text
from langchain.vectorstores import FAISS
from langchain.embeddings.base import Embeddings
from langchain.chains import RetrievalQA
from langchain.chat_models import ChatOpenAI
from langchain.prompts.chat import (
    ChatPromptTemplate,
    SystemMessagePromptTemplate,
    HumanMessagePromptTemplate,
)
import torch
import os
from google.colab import drive
# Mount the user's Google Drive at /content/drive; get_source() writes the downloaded
# archive and the extracted source files under '/content/drive/My Drive/'.
drive.mount('/content/drive')


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# Read the key with getpass instead of input(): input() echoes what is typed into the
# cell output, which is then stored in the saved notebook. Surrounding whitespace from
# copy/paste is stripped before the key is placed in the environment.
from getpass import getpass

os.environ['OPENAI_API_KEY'] = getpass("Paste OpenAI API Key: ").strip()

In [None]:
# Instantiate the INSTRUCTOR embedding model from the 'hkunlp/instructor-xl' checkpoint.
# encode_instructor() calls model.encode() on this module-level object.
model = INSTRUCTOR('hkunlp/instructor-xl')

In [None]:
# Instruction prefixes that encode_instructor() pairs with each text before embedding:
# one for the paper sections being indexed, one for the user's question at query time.
INDEX_TEXT = "Represent this section of a Machine Learning paper for retrieval given a question about the paper:"
RETRIEVAL_TEXT = "Represent this question about a Machine Learning paper for retrieving relevant sections of the paper:"

# Report whether a CUDA device is available. This is informational: the visible
# notebook code does not read `device` anywhere else.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("Using device", device)

def encode_instructor(instruction, sentences):
    """Embed a batch of sentences, each paired with the same instruction.

    Args:
        instruction: Instruction text placed first in every pair.
        sentences: Iterable of texts to embed.

    Returns:
        Whatever the module-level ``model.encode`` returns for the list of
        ``[instruction, sentence]`` pairs.
    """
    instruction_pairs = []
    for sentence in sentences:
        instruction_pairs.append([instruction, sentence])
    return model.encode(instruction_pairs)

def latex_to_text(latex_str):
    """Convert LaTeX markup to plain text, falling back to the raw input.

    Args:
        latex_str: LaTeX source as a string.

    Returns:
        The result of ``LatexNodes2Text().latex_to_text(latex_str)``; if the
        conversion raises, ``latex_str`` is returned unchanged.
    """
    try:
        return LatexNodes2Text().latex_to_text(latex_str)
    except Exception:
        # Deliberate best-effort: unparseable LaTeX is still usable as raw text.
        # Catch Exception rather than using a bare except so that
        # KeyboardInterrupt and SystemExit still propagate.
        return latex_str

def get_source(arxiv_id, drive_path='/content/drive/My Drive/'):
    """Download an arXiv paper's source archive and return its top-level .tex files.

    The archive is fetched from ``https://arxiv.org/e-print/<arxiv_id>``, saved as
    ``<drive_path><id>.tar.gz``, extracted into ``<drive_path><id>_source_files`` and
    every ``.tex`` file directly inside that directory is read.

    Args:
        arxiv_id: arXiv identifier, e.g. ``'2303.10798'``.
        drive_path: Directory prefix (with trailing separator) under which the
            archive and the extracted files are stored.

    Returns:
        List of strings, one per top-level ``.tex`` file, in sorted filename order.

    Raises:
        RuntimeError: arXiv answered with a status code other than 200.
        ValueError: an archive member would be written outside the extraction
            directory, or no top-level ``.tex`` file was found.
        tarfile.ReadError: the download is not a gzipped tar archive
            (NOTE(review): some submissions may not be served as tar.gz - confirm).
    """
    # Old-style ids contain '/', which would otherwise be read as a sub-directory.
    safe_id = arxiv_id.replace('/', '_')
    archive_path = f'{drive_path}{safe_id}.tar.gz'
    source_dir = f'{drive_path}{safe_id}_source_files'
    source_url = f'https://arxiv.org/e-print/{arxiv_id}'

    # The context manager releases the streamed connection even if writing fails.
    with requests.get(source_url, stream=True, timeout=60) as response:
        if response.status_code != 200:
            # Fail fast: continuing would open a missing (or stale) archive below.
            raise RuntimeError(f'Error: received status code {response.status_code} from arXiv.')
        response.raw.decode_content = True
        with open(archive_path, 'wb') as archive_file:
            shutil.copyfileobj(response.raw, archive_file)

    with tarfile.open(archive_path, 'r:gz') as archive:
        # The archive is downloaded content: refuse member names that would
        # resolve outside the extraction directory before extracting anything.
        dest_root = os.path.realpath(source_dir)
        for member in archive.getmembers():
            member_target = os.path.realpath(os.path.join(dest_root, member.name))
            if os.path.commonpath([dest_root, member_target]) != dest_root:
                raise ValueError(f'Refusing to extract {member.name!r}: path escapes {source_dir}')
        archive.extractall(path=source_dir)

    file_contents = []
    # sorted() makes the order of the returned texts reproducible across runs.
    for tex_file in sorted(os.listdir(source_dir)):
        if not tex_file.endswith('.tex'):
            continue
        # TeX sources are not guaranteed to be valid UTF-8; replace undecodable
        # bytes instead of aborting the whole paper.
        with open(os.path.join(source_dir, tex_file), 'r', encoding='utf-8', errors='replace') as tex_handle:
            file_contents.append(tex_handle.read())

    if not file_contents:
        raise ValueError(f'No .tex files found at the top level of {source_dir}')
    if len(file_contents) == 1:
        print(f"Failed to extract enough source data - file content size = {len(file_contents[0])} chars")
    return file_contents

def clean_text(text):
    """Tidy converted paper text.

    Runs of three or more newlines are collapsed to exactly two, and afterwards
    every run of four or more ``=`` characters is replaced by a single newline.

    Args:
        text: Text to clean.

    Returns:
        The cleaned string.
    """
    blank_run = re.compile("\n{3,}")
    rule_run = re.compile("={4,}")
    collapsed = blank_run.sub("\n\n", text)
    return rule_run.sub("\n", collapsed)

def extract_all_text_chunks(arxiv_id, n_character_chunks=1000):
    """Fetch a paper's LaTeX sources and cut them into fixed-size text chunks.

    Each source file returned by ``get_source`` is converted with
    ``latex_to_text``, tidied with ``clean_text`` and sliced into consecutive
    pieces of at most ``n_character_chunks`` characters.

    Args:
        arxiv_id: arXiv identifier passed through to ``get_source``.
        n_character_chunks: Maximum number of characters per chunk.

    Returns:
        Flat list of chunk strings, file by file, in slice order.
    """
    text_chunks = []
    for raw_latex in get_source(arxiv_id):
        plain = clean_text(latex_to_text(raw_latex))
        for start in range(0, len(plain), n_character_chunks):
            text_chunks.append(plain[start:start + n_character_chunks])
    return text_chunks

class InstructorEmbeddings(Embeddings):
    """LangChain ``Embeddings`` adapter backed by ``encode_instructor``.

    Documents are embedded with the ``INDEX_TEXT`` instruction and queries with
    the ``RETRIEVAL_TEXT`` instruction. Results are converted to plain Python
    lists of floats, the return types the LangChain ``Embeddings`` interface
    declares, rather than passing through whatever ``model.encode`` yields.
    """

    def embed_documents(self, texts):
        """Embed a batch of document texts.

        Args:
            texts: Texts to embed.

        Returns:
            One list of floats per input text, in input order.
        """
        vectors = encode_instructor(INDEX_TEXT, texts)
        return [[float(value) for value in vector] for vector in vectors]

    def embed_query(self, text):
        """Embed a single query string.

        Args:
            text: The query text.

        Returns:
            The query embedding as a list of floats.
        """
        vector = encode_instructor(RETRIEVAL_TEXT, [text])[0]
        return [float(value) for value in vector]

def cli_ask_questions(arxiv_id):
    """Index a paper and answer questions about it in an interactive loop.

    The paper's text chunks are embedded into a FAISS vector store, a
    ``RetrievalQA`` chain of type "stuff" is built on top of it with
    ``gpt-3.5-turbo``, and questions are read from standard input until the
    user types ``quit``.

    Args:
        arxiv_id: arXiv identifier of the paper to load.

    Raises:
        ValueError: no text chunks could be extracted for ``arxiv_id``.
    """
    chunks = extract_all_text_chunks(arxiv_id)
    if not chunks:
        # Building a vector store from zero texts fails with an unhelpful error.
        raise ValueError(f"No text could be extracted for arXiv id {arxiv_id!r}; nothing to index.")
    embeddings = InstructorEmbeddings()
    # Each chunk's position in the list is stored as its metadata.
    vectorstore = FAISS.from_texts(chunks, embeddings, [{"index": i} for i in range(len(chunks))])

    chain_type_kwargs = {
        "prompt": ChatPromptTemplate.from_messages([
            SystemMessagePromptTemplate.from_template(
                "You are a helpful assistant that answers questions about research papers given some snippets from the paper. Whenever possible, you quote directly from the snippets, putting the quote in quotation marks."
            ),
            HumanMessagePromptTemplate.from_template("""
                Some relevant snippets:

                {context}

                Question: {question}
                Answer:
            """)
        ])
    }

    qa = RetrievalQA.from_chain_type(
        llm=ChatOpenAI(model_name='gpt-3.5-turbo'),
        chain_type="stuff",
        retriever=vectorstore.as_retriever(),
        chain_type_kwargs=chain_type_kwargs
    )

    while True:
        # Strip so that input such as "quit " or " quit" still ends the session.
        question = input("Enter your question about the paper (or 'quit' to stop): ").strip()
        if not question:
            # Do not spend an API call on an empty question.
            continue
        if question.lower() == 'quit':
            break
        try:
            print(qa.run(question))
        except Exception as e:
            # Keep the session alive: report the failure and prompt again.
            print("An error occurred while processing your question.")
            print(str(e))


Using device cuda


In [None]:
# Strip surrounding whitespace so a pasted id does not end up in the download URL
# or in the file names derived from it.
arxiv_id = input("Enter the id of the Arxiv paper you want to ask questions about: ").strip()
cli_ask_questions(arxiv_id)

Enter the id of the Arxiv paper you want to ask questions about: 2303.10798
['\n\n', '\n\n§ INTRODUCTION\n\n\t\n    < g r a p h i c s >\n\n The gray “hat” polykite tile\nis an “einstein", an aperiodic monotile. In other words, copies of this tile may be assembled into tilings of the plane (the tile “admits" tilings), yet copies of the tile cannot form periodic tilings, tilings that have translational symmetry.  In fact, the tile admits uncountably many tilings. In Sections\xa0<ref>,\xa0<ref>, and\xa0<ref> we describe how these tilings all arise from substitution rules, showing that they all have the same local structure.\n\nGiven a set of two-dimensional tiles,  the nature of the planar tilings that they admit arises from a deep interaction between the local and the \nglobal.  Constraints on the ways that pairs of tiles can \nbe neighbours determine the structure of an infinite tiling, at all large scales. Constraints encoded in a set of tiles determine the structure of the space of th