# Reference Paper Chunk Retrieval

The reference papers are chunked and indexed so that, for each citation statement, the three most relevant (i.e., most similar) chunks can be retrieved. LlamaIndex is used to implement the whole process.

LlamaIndex:
- GitHub: https://github.com/run-llama/llama_index
- Documentation: https://docs.llamaindex.ai/en/latest/

In [None]:
import pandas as pd
import glob
import xml.etree.ElementTree as ET
import json
import os

from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.core import Settings
from llama_index.core import VectorStoreIndex, Document
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.ingestion import IngestionPipeline
from llama_index.core import StorageContext, load_index_from_storage

In [None]:
# Load the annotated citation dataset (Excel workbook) into a pandas DataFrame
df = pd.read_excel("../data/ReferenceErrorDetection_data_extended_annotation.xlsx")

## Getting Text from Reference Papers

In [None]:
# File extension of the extracted reference papers. It selects the input
# folder ('only_text/' for 'txt') in get_file_path, the parser in
# get_reference_text, and the 'only_text_' prefix of the output folders.
# 'txt' if PBTE was used on TEI documents, otherwise ''
# NOTE(review): get_reference_text only has branches for 'txt' and 'xml' and
# returns None for anything else — confirm whether the alternative value
# should be 'xml' rather than ''.
extension = "txt"

In [None]:
def get_file_path(reference_article_id):
    """Return the path of the extracted file for a reference article.

    Files are searched under ``../data/extractions/`` (inside the
    ``only_text/`` subfolder when the module-level ``extension`` is ``'txt'``)
    using the glob pattern ``<reference_article_id>*.<extension>``.

    Args:
        reference_article_id: ID used as the file-name prefix.

    Returns:
        The first matching path in sorted order, or ``None`` (after printing
        a message) when no file matches.
    """
    # NOTE(review): the wildcard directly follows the ID, so an ID that is a
    # prefix of another ID would also match that other article's file —
    # confirm the file naming scheme rules this out.
    file_pattern = f"../data/extractions/{'only_text/' if extension == 'txt' else ''}{reference_article_id}*.{extension}"

    # glob returns matches in arbitrary, OS-dependent order; sorting makes the
    # chosen file reproducible when several files share the ID prefix.
    file_list = sorted(glob.glob(file_pattern))
    if not file_list:
        # Include the ID: this function is called in loops over many articles.
        print(f"No matching file found for {reference_article_id}.")
        return None
    return file_list[0]

In [None]:
def get_reference_text(reference_article_id):
    """Return the full text of a reference article, or ``None`` if unavailable.

    The file is located with ``get_file_path``. Depending on the module-level
    ``extension`` it is read as plain text (``'txt'``) or parsed as XML
    (``'xml'``), in which case the text of all elements is concatenated.

    Args:
        reference_article_id: ID of the reference article to load.

    Returns:
        The article text as a string, or ``None`` when no file was found or
        ``extension`` is neither ``'txt'`` nor ``'xml'``.
    """
    file_path = get_file_path(reference_article_id)
    if not file_path:
        return None

    if extension == "txt":
        # Explicit encoding so decoding does not depend on the platform
        # default. Assumes the extracted files are UTF-8 — TODO confirm.
        with open(file_path, 'r', encoding='utf-8') as file:
            return file.read()

    if extension == "xml":
        # Concatenate the text of every element in document order.
        root = ET.parse(file_path).getroot()
        return ''.join(root.itertext())

    # Previously this case fell through and returned None without any hint.
    print(f"Unsupported extension '{extension}'; expected 'txt' or 'xml'.")
    return None

## Setting OpenAI key
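An OpenAI API key needs to be generated and stored in a file called "open_ai_key.txt" for the following code to work. The code reads it from `../open_ai_key.txt`, i.e. from the parent directory of the notebook's working directory.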

In [None]:
# Load the OpenAI API key from the key file and drop surrounding whitespace
# (e.g. a trailing newline)
with open('../open_ai_key.txt') as file:
    open_ai_key = file.read()
open_ai_key = open_ai_key.strip()

## Setting up Vector Index

### Reloading or Generating Index

In [None]:
# Name of the OpenAI embedding model; used for the default embed model in
# Settings and for embedding the chunks in create_index
model_embeddings = "text-embedding-3-large"

In [None]:
# Set the embedding model on LlamaIndex's Settings object, using the model
# name and API key defined above
Settings.embed_model = OpenAIEmbedding(model=model_embeddings, api_key=open_ai_key)

In [None]:
def create_index(reference_text, chunk_size, chunk_overlap):
    """Build a vector index over the chunks of one reference text.

    The text is wrapped in a single ``Document``, split by a
    ``SentenceSplitter`` configured with ``chunk_size`` / ``chunk_overlap``
    and embedded with the OpenAI model named by the module-level
    ``model_embeddings``.

    Args:
        reference_text: Full text of the reference article.
        chunk_size: Chunk size passed to the splitter.
        chunk_overlap: Chunk overlap passed to the splitter.

    Returns:
        A ``VectorStoreIndex`` built from the nodes the pipeline produced.
    """
    splitter = SentenceSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    embedder = OpenAIEmbedding(model=model_embeddings, api_key=open_ai_key)

    # Split first, then embed: the transformations run in list order.
    ingestion = IngestionPipeline(transformations=[splitter, embedder])
    embedded_nodes = ingestion.run(documents=[Document(text=reference_text)])
    return VectorStoreIndex(embedded_nodes)

In [None]:
def load_or_create_index(article_id, reference_text, chunk_size, chunk_overlap, only_checking=False):
    """Load the persisted vector index of an article, creating it if needed.

    The index lives in ``../data/vector_indices/<variant><chunk_size>_<chunk_overlap>/<article_id>/``
    where ``<variant>`` is ``only_text_`` when the module-level ``extension``
    is ``'txt'``. If loading fails for any reason, a new index is built with
    ``create_index`` and persisted to that directory.

    Args:
        article_id: ID of the reference article (used in the path and logs).
        reference_text: Full text of the article; required whenever the index
            has to be loaded or created.
        chunk_size: Chunk size used for the index (also part of the path).
        chunk_overlap: Chunk overlap used for the index (also part of the path).
        only_checking: When true and a non-empty index directory already
            exists, return ``True`` without loading it. If no such directory
            exists, the function continues and loads/creates the index as usual.

    Returns:
        ``True`` in the ``only_checking`` shortcut case; otherwise the loaded
        or newly created index. ``None`` if creating the index failed before
        an index object existed.

    Raises:
        ValueError: If ``reference_text`` is ``None`` or empty when the index
            has to be loaded or created.
    """
    index_path = f"../data/vector_indices/{'only_text_' if extension == 'txt' else ''}{chunk_size}_{chunk_overlap}/{article_id}/"
    index = None

    # isdir (rather than exists) so a stray file at this path cannot make
    # os.listdir raise.
    if only_checking and os.path.isdir(index_path) and os.listdir(index_path):
        # f-strings instead of `article_id + "..."`: IDs read from Excel may
        # be numbers, for which string concatenation raises TypeError.
        print(f"{article_id}: Index exists.")
        return True

    # Explicit raise instead of `assert`, which is stripped under `python -O`.
    if reference_text is None or reference_text == '':
        raise ValueError("Reference text cannot be None or empty.")

    try:
        storage_context = StorageContext.from_defaults(persist_dir=index_path)
        index = load_index_from_storage(storage_context)
        print(f"{article_id}: Loaded existing index.")
    except Exception as load_error:
        # Any load failure (typically: nothing persisted yet) triggers a rebuild.
        print(load_error)
        print(f"{article_id}: Creating a new index.")
        try:
            index = create_index(reference_text, chunk_size, chunk_overlap)
            index.storage_context.persist(persist_dir=index_path)
        except Exception as create_error:
            # Best effort: report the problem and let the caller continue.
            print(create_error)
            print(f"{article_id}: Failed to create index.")
            print(reference_text)
    return index

### Creating Indices for all Reference Papers

In [None]:
# Chunking parameters handed to the SentenceSplitter in create_index; they are
# also part of the index / similar-chunks / output folder names.
# Presumably measured in tokens (LlamaIndex SentenceSplitter) — confirm.
chunk_size = 1024
chunk_overlap = 20

In [None]:
# Make sure every downloaded reference paper has a persisted vector index.
for _, row in df.iterrows():
    if row['Reference Article Downloaded'] != 'Yes':
        continue

    reference_article_id = row['Reference Article ID']
    # Empty spreadsheet cells are read as NaN, which is truthy, so a plain
    # truthiness test does not filter them out; check for them explicitly.
    if pd.isna(reference_article_id) or not reference_article_id:
        continue

    reference_text = get_reference_text(reference_article_id)
    # only_checking=True: an index that already exists on disk is not loaded;
    # a missing one is still created.
    index = load_or_create_index(reference_article_id, reference_text, chunk_size, chunk_overlap, only_checking=True)

## Retrieving Top 3 Chunks

In [None]:
from llama_index.core.retrievers import VectorIndexRetriever

def get_top_k_similar_chunks(statement, index, k=3):
    """Retrieve the ``k`` chunks of ``index`` most similar to ``statement``.

    Args:
        statement: Query text (the citation statement).
        index: Vector index of the reference article.
        k: Number of chunks to retrieve (``similarity_top_k``).

    Returns:
        The nodes returned by ``VectorIndexRetriever.retrieve``.
    """
    return VectorIndexRetriever(index=index, similarity_top_k=k).retrieve(statement)

In [None]:
def save_similar_chunks(doc_ids, reference_id, chunk_size, chunk_overlap):
    """Write the retrieved chunk IDs of one reference article to a JSON file.

    The file is ``../data/similar_chunks/<variant><chunk_size>_<chunk_overlap>/<reference_id>.json``
    (``<variant>`` is ``only_text_`` when the module-level ``extension`` is
    ``'txt'``); missing parent directories are created.

    Args:
        doc_ids: JSON-serialisable list of chunk (node) IDs.
        reference_id: ID of the reference article; used as the file name.
        chunk_size: Chunk size, part of the folder name.
        chunk_overlap: Chunk overlap, part of the folder name.
    """
    file_path = f"../data/similar_chunks/{'only_text_' if extension == 'txt' else ''}{chunk_size}_{chunk_overlap}/{reference_id}.json"
    target_dir = os.path.dirname(file_path)
    os.makedirs(target_dir, exist_ok=True)
    with open(file_path, 'w') as out_file:
        json.dump(doc_ids, out_file)

def load_similar_chunks(reference_id, chunk_size, chunk_overlap):
    """Read the stored chunk IDs of one reference article from its JSON file.

    Uses the same path layout as ``save_similar_chunks``.

    Args:
        reference_id: ID of the reference article; used as the file name.
        chunk_size: Chunk size, part of the folder name.
        chunk_overlap: Chunk overlap, part of the folder name.

    Returns:
        The decoded JSON content (the saved chunk IDs).

    Raises:
        FileNotFoundError: If nothing has been saved for this article yet.
    """
    file_path = f"../data/similar_chunks/{'only_text_' if extension == 'txt' else ''}{chunk_size}_{chunk_overlap}/{reference_id}.json"
    with open(file_path, 'r') as in_file:
        return json.load(in_file)

### Saving Document IDs and Contents of Retrieved Chunks to the DF

In [None]:
def get_doc_ids(response):
    """Collect the node IDs of all retrieved nodes.

    Args:
        response: Iterable of retrieved nodes; each must offer ``dict()``
            returning a mapping with the ID under ``['node']['id_']``.

    Returns:
        List of node IDs in retrieval order.
    """
    return [retrieved.dict()['node']['id_'] for retrieved in response]

In [None]:
def save_top_k_chunk_ids(df, chunk_size, chunk_overlap, k=3):
    """Retrieve and cache the top-k chunk IDs for every downloaded reference.

    For each row whose 'Reference Article Downloaded' is 'Yes', the cached
    chunk IDs are used if their JSON file exists; otherwise the chunks most
    similar to the row's 'Corrected Statement' are retrieved from the
    article's vector index and saved. Failures for a single row are printed
    and the loop continues with the next row.

    Args:
        df: DataFrame with the columns 'Reference Article Downloaded',
            'Reference Article ID' and 'Corrected Statement'.
        chunk_size: Chunk size of the index (also part of the cache path).
        chunk_overlap: Chunk overlap of the index (also part of the cache path).
        k: Number of chunks to retrieve per statement.
    """
    for _, row in df.iterrows():
        if row['Reference Article Downloaded'] != 'Yes':
            continue

        reference_article_id = row['Reference Article ID']
        print(f"------ Starting {reference_article_id} ------")

        # NOTE(review): the cache file is keyed by reference ID only, so rows
        # sharing a Reference Article ID reuse the chunks retrieved for the
        # first statement, and k is not part of the key — confirm IDs are
        # unique per row.
        try:
            load_similar_chunks(reference_article_id, chunk_size, chunk_overlap)
            print("Loaded similar chunks successfully.")
        except FileNotFoundError:
            _retrieve_and_save_top_chunks(row, reference_article_id, chunk_size, chunk_overlap, k)
        print("")


def _retrieve_and_save_top_chunks(row, reference_article_id, chunk_size, chunk_overlap, k):
    """Retrieve the top-k chunk IDs for one row's statement and cache them."""
    reference_text = get_reference_text(reference_article_id)
    if not reference_text:
        # load_or_create_index rejects a missing text, which used to abort the
        # whole run; skip this row like every other per-row failure instead.
        print("No reference text available; skipping.")
        return

    index = load_or_create_index(reference_article_id, reference_text, chunk_size, chunk_overlap)
    statement = row["Corrected Statement"]

    print("Receiving top chunks")
    try:
        response = get_top_k_similar_chunks(statement, index, k)
        doc_ids = get_doc_ids(response)

        print("Saving top chunks")
        save_similar_chunks(doc_ids, reference_article_id, chunk_size, chunk_overlap)
    except Exception as e:
        # Best effort: also covers index being None after a failed creation.
        print(e)
        print("Failed to get top chunks.")

In [None]:
# Retrieve and cache the top-3 chunk IDs for every downloaded reference paper
save_top_k_chunk_ids(df, chunk_size, chunk_overlap, k=3)

In [None]:
# Folder for the result DataFrames, named after the extraction variant and the
# chunking parameters; created if it does not exist yet
output_dir = f"../data/dfs/{'only_text_' if extension == 'txt' else ''}{chunk_size}_{chunk_overlap}/"
os.makedirs(output_dir, exist_ok=True)
# Disabled shortcut: presumably reloads a previously saved result instead of
# recomputing it in the cells below — confirm before enabling
# df2 = pd.read_pickle(os.path.join(output_dir, f"ReferenceErrorDetection_data_with_chunk_info.pkl"))

In [None]:
def add_top_k_chunk_ids_and_texts_to_df(df, chunk_size, chunk_overlap, k=3):
    """Attach the cached top-k chunk IDs and their texts to ``df``.

    For every row whose 'Reference Article Downloaded' is 'Yes', the persisted
    vector index and the cached chunk IDs of the article are loaded and the
    chunk texts are looked up in the index's docstore. All other rows get
    ``None`` in both new columns.

    Args:
        df: DataFrame with the columns 'Reference Article Downloaded' and
            'Reference Article ID'. It is modified in place.
        chunk_size: Chunk size of the index (part of the storage paths).
        chunk_overlap: Chunk overlap of the index (part of the storage paths).
        k: Only used to name the new columns ``Top_<k>_Chunk_IDs`` and
            ``Top_<k>_Chunk_Texts``.

    Returns:
        The same ``df`` object with the two columns added.
    """
    chunk_ids_per_row = []
    chunk_texts_per_row = []

    for _, row in df.iterrows():
        # Rows without a downloaded paper still need a placeholder so the
        # lists stay aligned with the DataFrame rows.
        if row['Reference Article Downloaded'] != 'Yes':
            chunk_ids_per_row.append(None)
            chunk_texts_per_row.append(None)
            continue

        reference_article_id = row['Reference Article ID']
        print(f"------ Starting {reference_article_id} ------")

        # Load the persisted index of this article
        index_path = f"../data/vector_indices/{'only_text_' if extension == 'txt' else ''}{chunk_size}_{chunk_overlap}/{reference_article_id}/"
        index = load_index_from_storage(StorageContext.from_defaults(persist_dir=index_path))

        # Resolve the cached chunk IDs to their texts
        doc_ids = load_similar_chunks(reference_article_id, chunk_size, chunk_overlap)
        doc_texts = []
        for doc_id in doc_ids:
            doc_texts.append(index.docstore.docs[doc_id].text)

        chunk_ids_per_row.append(doc_ids)
        chunk_texts_per_row.append(doc_texts)

    df[f'Top_{k}_Chunk_IDs'] = chunk_ids_per_row
    df[f'Top_{k}_Chunk_Texts'] = chunk_texts_per_row
    return df

In [None]:
# Add the top-3 chunk IDs and texts; the function modifies df in place and
# returns it, so df2 refers to the same DataFrame object as df
df2 = add_top_k_chunk_ids_and_texts_to_df(df, chunk_size, chunk_overlap, k=3)

In [None]:
# Recompute the output folder and create it if it is missing
output_dir = f"../data/dfs/{'only_text_' if extension == 'txt' else ''}{chunk_size}_{chunk_overlap}/"
os.makedirs(output_dir, exist_ok=True)

# Persist the enriched DataFrame as a pickle file ...
df2.to_pickle(os.path.join(output_dir, "ReferenceErrorDetection_data_with_chunk_info.pkl"))

# ... and as an Excel file (without the index column)
df2.to_excel(os.path.join(output_dir, "ReferenceErrorDetection_data_with_chunk_info.xlsx"), index=False)