In [None]:
from langchain_upstage import UpstageLayoutAnalysisLoader, ChatUpstage, UpstageEmbeddings
from langchain_core.prompts import PromptTemplate
from langchain_text_splitters import RecursiveCharacterTextSplitter
from kobert_transformers import get_kobert_model, get_tokenizer
from langdetect import detect
import pandas as pd
import numpy as np
import wikipediaapi
import spacy
import torch
import time
import os
import re

# set parameters
# SECURITY: supply the key through the UPSTAGE_API_KEY environment variable.
# The literal below is kept only as a backward-compatible fallback; because it
# is stored in plain text in this file it should be rotated and then removed.
api_key = os.environ.get("UPSTAGE_API_KEY") or "up_Yvl9vhT0SfWFZRwucrXCW4O4th9IH"
data_path = "."  # folder containing ewha.pdf and the question CSV files

def parse_pdf():
    """Load ewha.pdf from data_path and return it as overlapping text chunks.

    The PDF is read with the Upstage layout-analysis loader in "text" output
    mode, then split into chunks of 1000 characters with a 250-character
    overlap between neighbouring chunks.
    """
    pdf_path = os.path.join(data_path, 'ewha.pdf')
    loader = UpstageLayoutAnalysisLoader(
        api_key=api_key,
        file_path=pdf_path,
        output_type="text",
    )
    documents = loader.load()

    splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=250)
    return splitter.split_documents(documents)

# Parse and chunk the PDF once when this cell runs; main() passes this
# module-level list to the context-extraction functions.
articles = parse_pdf()

In [None]:
def main(file_name, context_extraction_method, uses_wiki):
    """Run the whole pipeline on one question file and print the accuracy.

    Args:
        file_name: CSV file (relative to data_path) with the questions.
        context_extraction_method: "kobert" or "upstage"; selects which
            embedding model ranks the PDF chunks.
        uses_wiki: forwarded to the context finder; controls whether
            Wikipedia summaries are fetched for English questions.

    Raises:
        ValueError: if context_extraction_method is not a supported value.
    """
    # Validate first so a typo fails immediately with a clear error instead of
    # crashing later on an unbound `contexts` variable.
    if context_extraction_method not in ("kobert", "upstage"):
        raise ValueError(
            f"Invalid context extraction method: {context_extraction_method!r} "
            "(expected 'kobert' or 'upstage')"
        )

    prompts, answers = read_file(file_name)
    if context_extraction_method == "kobert":
        contexts = find_context_kobert(prompts, articles, uses_wiki)
    else:
        contexts = find_context_upstage(prompts, articles, uses_wiki)
    responses = get_llm_responses(prompts, contexts)
    print_accuracy(responses, answers)

def read_file(file_name):
    """Read a question CSV located under data_path.

    Returns a pair: the 'prompts' column and the 'answers' column of the
    file, in that order.
    """
    csv_path = os.path.join(data_path, file_name)
    frame = pd.read_csv(csv_path)
    return frame['prompts'], frame['answers']

# Cleans question by removing options
def clean_question(question):
    """Strip the leading label and the answer options from a question.

    The kept text is whatever lies between the first ')' and the first
    '(A)'. Returns (text, True) on success. When either marker is missing,
    or the ')' does not come before the '(A)', a message is printed and
    (question, False) is returned with the question unchanged.
    """
    label_end = question.find(')')
    options_start = question.find('(A)')

    if -1 in (label_end, options_start) or label_end >= options_start:
        print("Could not find the positions or invalid positions.")
        return question, False

    return question[label_end + 1 : options_start].strip(), True


# Extract answer from a response
def extract_answer(response):
    """
    Pull the chosen option out of a model response.

    Returns the first parenthesised capital letter, parentheses included
    (e.g. "(A)"). When the response contains no such token, the result of
    extract_again(response) is returned instead.
    """
    found = re.search(r"\(([A-Z])\)", response)
    if found is None:
        return extract_again(response)
    return found.group(0)
def extract_again(response):
    """Fallback answer extraction for responses without a "(X)" token.

    Patterns are tried from most to least specific so that ordinary
    capitalised words (e.g. "The answer is B") do not win over the letter
    that actually names the option:

    1. a capital letter right after an "answer" marker
       ("Answer: B", "the answer is C");
    2. a capital letter standing alone as its own word;
    3. the first capital letter anywhere (the original behaviour).

    Returns the letter wrapped in parentheses, e.g. "(B)", or None when the
    response contains no capital letter at all.
    """
    patterns = (
        r"[Aa]nswer(?:\s+is)?\s*:?\s*\(?\s*([A-Z])\b",
        r"\b([A-Z])\b",
        r"([A-Z])",
    )
    for pattern in patterns:
        match = re.search(pattern, response)
        if match:
            return f"({match.group(1)})"
    return None

# prints responses
def print_responses(responses, answers):
    """Print each response next to the expected answer, numbered from 1."""
    for number, response in enumerate(responses, start=1):
        expected = answers[number - 1]
        print(f"Question {number} : {response} \t Correct answer: {expected}")

# prints accuracy
def print_accuracy(responses, answers):
    """Print how many responses match the expected answers.

    Each response is reduced to an option token with extract_answer and
    compared for equality with the answer at the same position. The
    percentage is relative to len(responses); an empty run reports 0.00%
    instead of dividing by zero.
    """
    total = len(responses)
    correct = sum(
        1
        for response, answer in zip(responses, answers)
        if extract_answer(response) == answer
    )
    percentage = correct / total * 100 if total else 0.0
    print(f"Correct answers: {correct}/{total} {percentage:.2f}% accuracy")

# Finds best contexts using upstage embedding
def find_context_upstage(prompts, articles, uses_wiki, top_k=10):
    """Pick context passages for every prompt using Upstage embeddings.

    Args:
        prompts: iterable of raw question strings (options included).
        articles: document chunks exposing a ``page_content`` attribute.
        uses_wiki: when true, questions detected as English get Wikipedia
            summaries as context; when false they get an empty context.
        top_k: maximum number of chunks returned per non-English question.
            Fewer are returned when there are fewer chunks available.

    Returns:
        A list with one entry per prompt; each entry is a list of strings.
    """
    embeddings = UpstageEmbeddings(
        api_key=api_key,
        model="embedding-query"
    )
    # Embed every chunk once up front; skip the call for an empty corpus.
    texts = [article.page_content for article in articles]
    doc_matrix = np.asarray(embeddings.embed_documents(texts)) if texts else np.empty((0, 0))
    limit = max(top_k, 0)

    contexts = []
    for number, prompt in enumerate(prompts, start=1):
        print(f"Processing question {number}")
        # The success flag is ignored: if cleaning fails the full prompt is used.
        question, _ = clean_question(prompt)

        # English questions are not answered from the PDF chunks.
        if detect(question) == "en":
            contexts.append(get_contexts_wiki(question) if uses_wiki else [])
            continue

        if doc_matrix.size == 0:
            contexts.append([])
            continue

        query_vector = np.asarray(embeddings.embed_query(question))
        # Dot product of the query against every chunk embedding at once.
        similarities = doc_matrix @ query_vector
        # Stable sort on the negated scores: highest first, ties keep chunk order.
        ranked = np.argsort(-similarities, kind="stable")[:limit]
        contexts.append([articles[i].page_content for i in ranked])
    return contexts

# Finds best context using KoBERT
def find_context_kobert(prompts, articles, uses_wiki, top_k=3):
    """Pick context passages for every prompt using KoBERT embeddings.

    Args:
        prompts: iterable of raw question strings (options included).
        articles: document chunks exposing a ``page_content`` attribute.
        uses_wiki: when true, questions detected as English get Wikipedia
            summaries as context; when false they get an empty context.
        top_k: maximum number of chunks returned per non-English question.
            Fewer are returned when there are fewer chunks available.

    Returns:
        A list with one entry per prompt; each entry is a list of strings.
    """
    # Load KoBERT tokenizer and model
    tokenizer = get_tokenizer()
    model = get_kobert_model()

    def embed_text(text):
        # Tokenize (truncated to 512 tokens) and mean-pool the last hidden
        # state over the sequence dimension.
        inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
        with torch.no_grad():
            outputs = model(**inputs)
        return outputs.last_hidden_state.mean(dim=1)

    # torch.vstack rejects an empty list, so an empty corpus is handled explicitly.
    article_embeddings = (
        torch.vstack([embed_text(article.page_content) for article in articles])
        if len(articles) > 0
        else None
    )
    limit = max(top_k, 0)

    contexts = []
    for number, prompt in enumerate(prompts, start=1):
        print(f"Processing question {number}")
        # The success flag is ignored: if cleaning fails the full prompt is used.
        question, _ = clean_question(prompt)

        # English questions are not answered from the PDF chunks.
        if detect(question) == "en":
            contexts.append(get_contexts_wiki(question) if uses_wiki else [])
            continue

        if article_embeddings is None:
            contexts.append([])
            continue

        question_embedding = embed_text(question)
        # Cosine similarity between the question and every chunk.
        scores = torch.nn.functional.cosine_similarity(
            article_embeddings, question_embedding
        ).tolist()
        # Highest score first; sorted() is stable so ties keep chunk order.
        ranked = sorted(range(len(scores)), key=scores.__getitem__, reverse=True)[:limit]
        contexts.append([articles[i].page_content for i in ranked])
    return contexts

def get_contexts_wiki(question):
    """Collect Wikipedia summaries for the keywords found in a question.

    Every keyword returned by extract_keywords is looked up as a page title
    on the "en" Wikipedia; keywords without an existing page are skipped.
    Returns the list of page summaries in keyword order.
    """
    wikipedia = wikipediaapi.Wikipedia("NLP Project (yanrenyu00@gmail.com)", "en")
    # map() is lazy, so each page is looked up right before it is checked.
    pages = map(wikipedia.page, extract_keywords(question))
    return [page.summary for page in pages if page.exists()]


# Cached spaCy pipeline: loading the model from disk is slow, so it is done
# once on first use instead of once per question.
_SPACY_NLP = None


# Extract keywords of an english question
def extract_keywords(question):
    """Return the noun chunks of an English question as a list of strings.

    Uses spaCy's "en_core_web_sm" pipeline, loaded lazily on the first call
    and reused afterwards.
    """
    global _SPACY_NLP
    if _SPACY_NLP is None:
        _SPACY_NLP = spacy.load("en_core_web_sm")
    doc = _SPACY_NLP(question)
    return [chunk.text for chunk in doc.noun_chunks]

def get_llm_responses(prompts, contexts):
    """Ask the Upstage chat model every question and collect the raw replies.

    For the i-th prompt, the strings in contexts[i] are joined with newlines
    and substituted into the template together with the prompt itself.
    Returns the list of reply texts in prompt order.
    """
    template = PromptTemplate.from_template(
        """
        You are an expert in multiple domains, including Law, Psychology, Business, Philosophy, and History. 
        Your task is to provide the most accurate answer to the given question based on the provided context. 
        Follow these instructions carefully:

        1. Read the question and context carefully.
        2. Your task is to **select the letter of the correct option**. 
        3. If the correct answer is a combination of choices (e.g., "x and y"), **explicitly select the combined option** (e.g., (D): "x and y") rather than any single choice (e.g., (A): "x" or (B): "y").
        4. **Avoid assuming single choices are correct if a combination option matches all criteria**.
        5. If the correct answer is not explicitly available in the context, use your judgment to choose the most plausible answer.
        6. Your final answer should be in the form of the letter corresponding to the answer (e.g., (A), (B), (C), etc.).

        I have provided 2 examples:

        Example:
        Question: Who likes to eat chips?    
        (A): Ha  
        (B): Lam  
        (C): Bun  
        (D): All of the above  
        Context: Ha, Lam, Bun like to eat chips.
        Answer: (D) 

        Example :
        Question : Which Team won the game?
        (A)B
        (B)C
        (C)A
        Context : Team B won the game.
        Answer: (A)


        The answer must look like this always!  
        Answer: (<answer>) where <answer> is a single capital letter

        ---

        ### For Business-related questions that involve calculations (e.g., averages, standard deviations):

        1. If the question involves calculating an average or standard deviation, first follow these steps:

        - **To calculate the average (mean):**
        1. Add up all the values in the dataset.
        2. Divide the total sum by the number of values.
     
        - **To calculate the standard deviation:**
        1. Find the average (mean) of the dataset.
        2. Subtract the average from each data point and square the result.
        3. Calculate the average of these squared differences.
        4. Take the square root of this average to get the standard deviation.

        2. Once you have the correct result, select the corresponding answer from the options.

        ---

        ### Prompt Format:

        Question: {question}  
        Context: {context}  
        Answer: (<answer>)

        """
    )
    chat_model = ChatUpstage(api_key = api_key, model="solar-1-mini-chat")
    chain = template | chat_model

    responses = []
    for position, prompt in enumerate(prompts):
        joined_context = "\n".join(contexts[position])
        reply = chain.invoke({"question": prompt, "context": joined_context})
        responses.append(reply.content)
        time.sleep(0) # Increase delay if API rate limits are hit
    return responses



In [None]:
# Replace with file name to run (the path is resolved relative to data_path).
# Available context extraction methods: upstage/kobert (only applicable for ewha questions)
# uses_wiki: when True, questions detected as English get Wikipedia summaries
# as context; when False they are sent to the model with an empty context.
main(file_name = "testsets/test_sample_MMLU_hard.csv", context_extraction_method = "upstage", uses_wiki = True)