# 0. Library and API code

In [None]:
from pprint import pprint
import os
import re
import warnings
from IPython import get_ipython
import numpy as np
import pandas as pd
import json
import faiss

from langchain_upstage import ChatUpstage
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_upstage import UpstageLayoutAnalysisLoader
from langchain_upstage import UpstageEmbeddings
from langchain_text_splitters import Language, RecursiveCharacterTextSplitter
from openai import OpenAI

import wikipediaapi
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.feature_extraction.text import TfidfVectorizer
from langchain.schema import Document  

In [None]:
# Suppress every Python warning for the rest of the session.
warnings.filterwarnings("ignore")
# Placeholders: replace these before running.
# NOTE(review): despite its name, upstage_api_key_env_name is passed directly as
# `api_key` by load_llm, so it must hold the key itself, not the name of an
# environment variable. UPSTAGE_API_KEY is used by the loader/embedding helpers.
upstage_api_key_env_name = "YOUR_API_KEY"  
UPSTAGE_API_KEY = "YOUR_API_KEY"  
# Project root directory; "data/", "documents/" and "test_sets/" are joined onto it.
file_path = "YOUR_FILE_PATH"

# 1. Baseline

In [None]:
# 1. load llm ('solar-1-mini-chat')
def load_llm(model='solar-1-mini-chat'):
    """
    Build the Upstage chat model client used by the rest of the notebook.
    [params]
        - model(str): name of the chat model (default: 'solar-1-mini-chat')
    [returns]
        - llm: ChatUpstage object created with temperature=0 and top_p=1.0
    """
    # temperature=0 is intended to make the output deterministic;
    # top_p=1.0 leaves the candidate token set unrestricted.
    sampling_options = {"temperature": 0, "top_p": 1.0}
    return ChatUpstage(
        api_key=upstage_api_key_env_name,
        model=model,
        **sampling_options,
    )

In [206]:
# 2. converts document to text
def pdf2splitted_txt(document_path):
    """
    Load a PDF through Upstage layout analysis and cut its text into chunks.
    [params]
        - document_path: path to pdf file (e.g., 'ewha.pdf')
    [returns]
        - splitted_text(list): page_content string of every chunk
          (splitter settings: chunk_size=1000, chunk_overlap=100)
    """
    # 1. load the document as plain text
    print(f"Loading document: {document_path}..")
    loader = UpstageLayoutAnalysisLoader(
        api_key=UPSTAGE_API_KEY,
        file_path=document_path,
        output_type="text",
    )
    loaded_docs = loader.load()

    # 2. split into overlapping chunks
    print("Splitting document..")
    splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
    chunks = splitter.split_documents(loaded_docs)

    # 3. keep only the text of each chunk
    splitted_text = [chunk.page_content for chunk in chunks]

    if splitted_text:
        print(f"Extracted {len(splitted_text)} content chunks.")
    return splitted_text

In [207]:
# 3. embed document text
def embed_document(document_splits):
    """
    Turn each text chunk of a document into an embedding vector.
    [params]
        - document_splits(list): text chunks of the document
    [returns]
        - embedded_vector(list): one embedding per chunk, as returned by
          UpstageEmbeddings.embed_documents
    """
    embedder = UpstageEmbeddings(api_key=UPSTAGE_API_KEY, model="embedding-passage")
    print("Embedding each document splits.")
    embedded_vector = embedder.embed_documents(document_splits)
    if not embedded_vector:
        # Nothing came back; skip the completion message.
        return embedded_vector
    print("Done embedding.")
    return embedded_vector

In [None]:
# 4. store data in vector store(KB)
def vector_store(KB_name, splitted_texts, embedded_texts):
    """
    Store document text and embedding results in a vector store (using FAISS).

    Two files are written below ``file_path``: 'data/{KB_name}.json' (vectors
    plus their source text) and 'data/{KB_name}.faiss' (the FAISS index). The
    target directory is created when it does not exist yet.

    [params]
        - KB_name (str): Name for the KB files; coerced with str().
        - splitted_texts (list): List of text chunks.
        - embedded_texts (list): List of embedding vectors, one per chunk.
    [returns]
        - vector_KB (faiss.IndexFlatL2): FAISS index holding the vectors.
        - text_KB (list): List of {"vector", "metadata": {"content"}} dicts,
          in the same order as the vectors in the index.
    [raises]
        - ValueError: if either list is empty or their lengths differ.
    """
    # 1. Validate parameters
    KB_name = str(KB_name)
    if not splitted_texts:
        raise ValueError("Error: splitted_texts does not exist.")
    if not embedded_texts:
        raise ValueError("Error: embedded_texts does not exist.")
    # retrieve_context maps a FAISS row number straight onto text_KB, so the
    # two lists have to line up one-to-one.
    if len(splitted_texts) != len(embedded_texts):
        raise ValueError(
            f"Error: got {len(splitted_texts)} texts but "
            f"{len(embedded_texts)} embeddings."
        )

    # 2. Prepare text_KB (embedded vectors and corresponding text)
    print(f"Making {KB_name}_text_KB")
    text_KB = [
        {
            "vector": embedding,
            "metadata": {  # customize metadata (e.g., category, keywords..)
                "content": text
            },
        }
        for text, embedding in zip(splitted_texts, embedded_texts)
    ]

    # 3. Save text_KB as JSON
    text_KB_path = os.path.join(file_path, f"data/{KB_name}.json")
    os.makedirs(os.path.dirname(text_KB_path), exist_ok=True)
    with open(text_KB_path, "w", encoding='utf-8') as f:
        json.dump(text_KB, f, ensure_ascii=False, indent=4)
    print(f"{KB_name}_text_KB saved at: {text_KB_path}")

    # 4. Store vectors in FAISS
    print(f"Storing vectors: {KB_name}_vector_KB..")
    dimension = len(embedded_texts[0])
    vectors = np.array(embedded_texts, dtype="float32")
    vector_KB = faiss.IndexFlatL2(dimension)
    vector_KB.add(vectors)

    # 5. Save FAISS index
    vector_KB_path = os.path.join(file_path, f"data/{KB_name}.faiss")
    faiss.write_index(vector_KB, vector_KB_path)
    print(f"FAISS index saved at: {vector_KB_path}")
    print(f"Index size: {vector_KB.ntotal} vectors.")
    return vector_KB, text_KB

In [209]:
# 5. embed query
def embed_query(question):
    """
    Embed one question with the Upstage 'embedding-query' model.
    [params]
        - question(str): Single question to be embedded.
    [returns]
        - query_vector (list): Embedding vector of the given question.
    """
    embedder = UpstageEmbeddings(api_key=UPSTAGE_API_KEY, model="embedding-query")
    return embedder.embed_query(question)

In [210]:
# 6. Retrieve top_k=3 relevant context to solve the question

def retrieve_context(vector_KB, text_KB, query_vector, k=3):
    """
    Look up the k nearest chunks for a query and return their text.

    [params]
    :vector_KB: index object exposing ``search(matrix, k=...)`` (e.g. a FAISS index)
    :text_KB (list): entries shaped like {"metadata": {"content": ...}}, in index order.
    :query_vector: embedding of the query; reshaped to a single float32 row.
    :k (int): Number of top results to retrieve.

    [returns]
    :retrieved_texts (list): contents of the hits whose index falls inside text_KB,
        or ["No relevant context found."] when no hit is usable.
    """
    # The index is queried with a single-row float32 matrix.
    query_matrix = np.array(query_vector, dtype="float32").reshape(1, -1)
    _, indices = vector_KB.search(query_matrix, k=k)

    kb_size = len(text_KB)
    retrieved_texts = []
    for row in indices:
        for idx in row:
            # Skip anything that cannot be mapped back onto text_KB.
            if idx < 0 or idx >= kb_size:
                print(f"[WARNING] Index {idx} is out of range (KB size: {kb_size}).")
                continue
            retrieved_texts.append(text_KB[idx]["metadata"]["content"])

    if retrieved_texts:
        print("[INFO] Relevant text indices:", indices)
        return retrieved_texts

    print("[WARNING] No valid context retrieved.")
    return ["No relevant context found."]

In [None]:
def read_data(data_path):
    """
    Read the test-set CSV and return its 'prompts' and 'answers' columns.

    UTF-8 is tried first because it is self-validating: cp949 bytes almost
    never decode as UTF-8, whereas UTF-8 bytes can decode under cp949 into
    garbage without raising. 'utf-8-sig' also strips a leading BOM so the
    first column name stays 'prompts'.

    [params]
        - data_path: path to a CSV file with 'prompts' and 'answers' columns
    [returns]
        - prompts, answers: the two columns as pandas Series
    """
    try:
        data = pd.read_csv(data_path, encoding='utf-8-sig')
    except UnicodeDecodeError:
        # Legacy Korean Windows export.
        data = pd.read_csv(data_path, encoding='cp949')
    prompts = data['prompts']
    answers = data['answers']
    return prompts, answers


def extract_answer(response):
    """
    Pull the chosen option letter out of a model response.

    Recognises the tagged form "[ANSWER]: (X)" and the Korean tag
    "[정답]: (X)" that the ewha prompt asks for, with optional whitespace
    around the colon. When no tagged answer is present the looser
    extract_again() heuristic is used.

    [params]
        - response (str): raw model output
    [returns]
        - option letter 'A'..'J', or whatever extract_again() returns
    """
    tagged = re.search(r"\[(?:ANSWER|정답)\]\s*:\s*\(([A-J])\)", response)
    if tagged:
        return tagged.group(1)
    return extract_again(response)

def extract_again(response):
    """
    Fallback extraction: return the last stand-alone letter A-J in the response.

    All stand-alone letters are collected and the final one is taken, so the
    result is the last candidate of the whole text even when the response
    spans several lines.

    [params]
        - response (str): raw model output
    [returns]
        - option letter 'A'..'J', or None when no stand-alone letter exists
    """
    candidates = re.findall(r"\b[A-J]\b", response)
    return candidates[-1] if candidates else None

# 2. Ewha

In [None]:
# 0. prompt_template to solve questions about ewha.pdf
# Input variables: {query} (the multiple-choice question) and {context}
# (retrieved text). The model is asked to reason step by step and to label the
# question IN CONTEXT / OUT OF CONTEXT before answering.
# NOTE: Step 5 requests the Korean tag "[정답]: (A)"; in the full pipeline this
# output is passed once more through Answer_prompt_template, which asks for
# "[ANSWER]: (X)".
ewha_prompt_template = PromptTemplate.from_template(
    """
    You are a highly intelligent and helpful assistant.

    You are given a multiple-choice question and a context to solve the question. 
    Your task is to arrive at the correct answer by reasoning through the context step by step.  
    Follow these steps logically and pay special attention to numerical values at each stage:

    Step 1: Carefully read and understand the question and the provided context. Identify the key information in both.
    Step 2: Reason step by step to determine if the question can be answered based on the evidence in the context:
        - If the context provides sufficient evidence, label this as "**IN CONTEXT**".
        - If the context does not provide sufficient evidence, label this as "**OUT OF CONTEXT**".
    Step 3: If "**IN CONTEXT**":
        - Be careful when you bring context, especially it contains numerical values. 
        - Use the evidence in the context to analyze the options ((A): ~, (B): ~, ... etc).
        - Clearly identify the correct answer, paying close attention to any numerical or factual details.
    Step 4: If "**OUT OF CONTEXT**":
        - Use your general knowledge to analyze the options ((A): ~, (B): ~, ... etc).
        - Justify your reasoning and carefully select the best answer while being attentive to numerical accuracy.
    Step 5: Always provide the final answer in the fixed format: "[정답]: (A)"

    Let's work on this problem step by step to ensure accuracy and clarity.

    Question: {query}
    ---
    Context: {context}
    """
)


In [None]:
# 1. predict answer
def predict_answer_ewha(question:str, retrieved_context:list[str], prompt_template_ewha) -> str:
    """
    Predicts an answer to an ewha question.
    [params]
        - question: question to answer
        - retrieved_context: context chunks to use to solve the question;
          joined with newlines before being placed in the prompt
        - prompt_template_ewha: prompt template with {query} and {context}
          variables; when None, the module-level ewha_prompt_template is used
    [returns]
        - prediction: text content of the model's reply
    """
    print("Predicting answer for a query.")
    llm = load_llm()
    # Honour the template the caller passed in instead of always using the global.
    template = ewha_prompt_template if prompt_template_ewha is None else prompt_template_ewha
    chain = template | llm
    prediction = chain.invoke({"query": question, "context": "\n".join(retrieved_context)})
    return prediction.content

In [None]:
#2. solve ewha question (prompt -> finding in KB -> retrieve context -> predict answer)

def solve_question_ewha(prompt,ewha_pdf_path):
    """
    Answer one ewha question: embed it, load (or build) the ewha knowledge
    base, retrieve the top-3 chunks and ask the LLM.
    [params]
        - prompt: the question text
        - ewha_pdf_path: path of the PDF used to build the KB when no cached KB exists
    [returns]
        - prediction: model output returned by predict_answer_ewha
    """
    print("Solving a ewha question.")
    query_embedding = embed_query(prompt)

    # Cached KB files live under <file_path>/data.
    faiss_path = os.path.join(file_path, "data/ewha.faiss")
    json_path = os.path.join(file_path, "data/ewha.json")

    kb_is_cached = os.path.exists(faiss_path) and os.path.exists(json_path)
    if kb_is_cached:
        print("[INFO] Required KB files found. Loading them.")
        vector_index = faiss.read_index(faiss_path)
        with open(json_path, "r", encoding="utf-8") as kb_file:
            text_records = json.load(kb_file)
    else:
        print("[INFO] Required KB files not found. Generating 'ewha.faiss' and 'ewha.json'.")
        chunks = pdf2splitted_txt(ewha_pdf_path)
        chunk_embeddings = embed_document(chunks)
        vector_index, text_records = vector_store("ewha", chunks, chunk_embeddings)

    # Retrieve context, then predict.
    contexts = retrieve_context(vector_index, text_records, query_embedding, k=3)
    return predict_answer_ewha(prompt, contexts, ewha_prompt_template)

# 3. MMLU

In [None]:
# 0. MMLU prompt template
# Input variables: {domain} (subject area used in the expert persona),
# {query} (the multiple-choice question) and {context} (retrieved text).
# The model is told to finish with the tagged form "[ANSWER]: (A)".
MMLU_prompt_template = PromptTemplate.from_template(
   """
   You are an expert in solving multiple-choice questions in the domain of {domain}. 

   You are given a multiple-choice question and a context to solve the question. 
   Your goal is to choose the correct answer based on the context provided.  
   Pay special attention to any numbers, quantities, or statistics mentioned in the question or context.
   Follow these steps while paying special attention to numerical values for each step:

   Step 1: Carefully read the question and the context.
   Step 2: Determine if the question can be answered based on the evidence in the context.
      - If yes, proceed to Step 3.
      - If no, proceed to Step 4.
   Step 3: Print "**IN CONTEXT**", provide evidence from the context.
      - Select the best one from the given options while being careful with find numerical values 
      - Then, proceed to Step 5.
   Step 4: Print "**OUT OF CONTEXT**", use your knowledge to generate an answer
      - Justify your reasoning and select the best one from the given options while being careful with find numerical values.
      - Then, proceed to Step 5.
   Step 5: Analysis options((A) : ~, (B): ~ ,...etc) based on the evidence and precisely select the best answer.
   Step 6: Always provide the selected best answer in the fixed format: "[ANSWER]: (A)"
   
   Question: {query}
   ---
   Context: {context}
   """

)

In [216]:
# 1. Load Wikipedia API
def load_wiki(lang='en'):
    """
    Create the Wikipedia API client.
    [params]
        - lang(str): Wikipedia language edition (default: 'en')
    [returns]
        - wikipediaapi.Wikipedia object; 'LLM Project' is passed as its first
          positional argument (presumably the user agent - confirm against
          the installed wikipediaapi version)
    """
    return wikipediaapi.Wikipedia('LLM Project', language=lang)

In [None]:
# 2. Determine the domain of the question
def _parse_relevance_score(text):
    """Return the relevance score contained in *text* as a float (0.0 if none)."""
    cleaned = text.strip()
    try:
        return float(cleaned)
    except ValueError:
        pass
    # The model sometimes wraps the number in words ("Score: 85", "85/100").
    found = re.search(r"\d+(?:\.\d+)?", cleaned)
    return float(found.group(0)) if found else 0.0


def define_domain(question, domains=("Law", "Psychology", "Business", "Philosophy", "History")):
    """
    Uses LLM to determine the domain of the query.

    Every candidate domain is scored separately by the LLM on a 0-100
    relevance scale; the domain with the highest score wins (the first one
    listed wins ties).

    [param]
        - question: MMLU-pro multiple choice question
        - domains: categories that the query will fall into
    [return]
        - domain: name of the best-scoring domain (stripped when it is a str)
    [raises]
        - ValueError: if domains is empty
    """
    if not domains:
        raise ValueError("domains must contain at least one candidate domain.")

    llm = load_llm()

    # The template does not depend on the loop variable, so build it once.
    comparison_prompt_template = PromptTemplate.from_template(
            """
            Question: {question}
            Domain: {domain}

            Task: Based on the question above, determine how well the question fits the domain "{domain}".
            Rate the relevance of this question to the domain on a scale of 0 to 100, where:
            - 100 means the question is fully relevant to the domain "{domain}".
            - 0 means the question is not relevant to the domain at all.

            Important:
            - Carefully analyze the question.
            - Avoid giving high scores unless there is clear evidence the question belongs to the domain "{domain}".
            Provide only the score (0-100).
            """
    )
    scoring_chain = comparison_prompt_template | llm

    # Ask the LLM to score the question against each domain.
    scores = {}
    for domain in domains:
        response = scoring_chain.invoke({"question": question, "domain": domain})
        scores[domain] = _parse_relevance_score(response.content)

    # Find the domain with the highest score
    best_domain = max(scores, key=scores.get)
    highest_score = scores[best_domain]

    print(f"Selected domain: {best_domain} with score: {highest_score}")

    return best_domain.strip() if isinstance(best_domain, str) else best_domain
    
        

In [None]:
# 3. compute similarity of query and the title of a category or a page
def compute_similarity(query, title):
    """
    Compute similarity between the query and the category/page's title using cosine similarity.
    [params]
        - query (str): query, question
        - title (str): title (or other short text) of a category or page to compare.
    [returns]
        - similarity (float): Cosine similarity score
    """
    # Both texts go through the same embedding model so the vectors share a dimension.
    embedded_query = embed_query(query)
    embedded_title = embed_query(title)

    # cosine_similarity expects 2-D inputs: one row per sample.
    query_vector = np.array(embedded_query).reshape(1, -1)
    title_vector = np.array(embedded_title).reshape(1, -1)

    # The result is a 1x1 matrix; unwrap it so callers get the documented float.
    return float(cosine_similarity(query_vector, title_vector)[0, 0])

In [None]:
# 4. fetch content from the selected page
def fetch_page_content(wiki, page):
    """
    Return the full text of a wiki page, or "" when the page does not exist.
    [params]
        - wiki: object exposing ``page(title)``
        - page: title of the page to fetch
    """
    wiki_page = wiki.page(page)
    return wiki_page.text if wiki_page.exists() else ""

In [None]:
# 5. Search for categories in wiki and get top 10 pages or all existing pages
def _best_page_in_category(wiki, query, domain, max_depth, current_depth, chunk_size, top_n, summary_limit):
    """Return (title, score) of the best-matching page under a category, or (None, -1)."""
    indent = '  ' * current_depth

    # Load the category
    category = wiki.page(f"Category:{domain}")
    if not category.exists():
        print(f"Category '{domain}' does not exist.")
        return None, -1

    print(f"{indent}Searching in category: {domain}")

    members = list(category.categorymembers.values())
    if not members:
        print(f"{indent}No members found in category: {domain}")
        return None, -1

    best_match = None
    highest_similarity = -1

    # Only the first top_n members (pages and subcategories alike) are inspected.
    for member in members[:top_n]:
        if member.ns == wikipediaapi.Namespace.MAIN:  # Page
            print(f"{indent}Checking page: {member.title}")
            summary = wiki.page(member.title).summary[:summary_limit]
            # Compare the query with the (truncated) summary, chunk by chunk.
            for start in range(0, len(summary), chunk_size):
                similarity = compute_similarity(query, summary[start:start + chunk_size])
                if similarity > highest_similarity:
                    highest_similarity = similarity
                    best_match = member.title

        elif member.ns == wikipediaapi.Namespace.CATEGORY and current_depth < max_depth:  # Subcategory
            print(f"{indent}Checking subcategory: {member.title}")
            subcategory_name = member.title.replace("Category:", "")
            sub_match, sub_similarity = _best_page_in_category(
                wiki, query, subcategory_name, max_depth, current_depth + 1, chunk_size, top_n, summary_limit
            )
            # Reuse the score that made the page win inside its subcategory so it
            # is measured the same way as the pages at this level.
            if sub_match and sub_similarity > highest_similarity:
                highest_similarity = sub_similarity
                best_match = sub_match

    print(f"{indent}Best match at this level: {best_match} (Similarity: {highest_similarity})")
    return best_match, highest_similarity


def find_most_relevant_page_with_subcategories(wiki, query, domain, max_depth=1, current_depth=0, chunk_size=1000, top_n=10, summary_limit=500):
    """
    Find the page whose summary is most similar to the query, searching a
    category and its subcategories up to a certain depth.
    [params]
        - wiki: Wikipedia API object.
        - query (str): The query string for comparison.
        - domain (str): The domain of the query, which is the starting category name.
        - max_depth (int): Maximum depth for recursive search (default is 1).
        - current_depth (int): Current recursion depth (default is 0).
        - chunk_size (int): Maximum size of text chunks for comparison.
        - top_n (int): Limit the number of members to check per category.
        - summary_limit (int): Limit the length of the summary for comparison.
    [returns]
        - best_match (str): The title of the most relevant page, or None when
          the category is missing, empty, or yields no scorable page.
    """
    best_match, _ = _best_page_in_category(
        wiki, query, domain, max_depth, current_depth, chunk_size, top_n, summary_limit
    )
    return best_match


In [None]:
# 6. split wiki text in chunks 
from langchain.schema import Document   

def split_wiki(page_text, page_title):
    """
    Cut the text of a wiki page into overlapping chunks.
    [params]
        - page_text: text you want to split
        - page_title: title stored in the chunk metadata and used in the log line
    [returns]
        - split_contents(list): text of each chunk (chunk_size=1000, chunk_overlap=100)
    """
    splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
    # The splitter works on Document objects, so wrap the raw text first.
    wrapped_page = Document(page_content=page_text, metadata={"title": page_title})
    split_contents = [piece.page_content for piece in splitter.split_documents([wrapped_page])]
    if split_contents:
        print(f"Extracted {len(split_contents)} content chunks for {page_title}.")
    return split_contents

In [None]:
# 7-1. solving MMLU-pro question with pdf
def solve_question_MMLU_pdf(question):
    """
    Retrieve context for an MMLU question from the PDF of its domain.

    The domain is chosen with define_domain(); the matching knowledge base
    ('data/{domain}.faiss' + 'data/{domain}.json' under file_path) is loaded
    when both files exist, otherwise it is built from 'documents/{domain}.pdf'.

    [params]
        - question: the MMLU question text
    [returns]
        - retrieved_context (list): top-3 context chunks from retrieve_context
    """
    print("\nSolving an MMLU question with pdf.")
    embedded_prompt = embed_query(question)

    # 1. determine domain
    domain = define_domain(question, domains=["Law", "Psychology", "Business", "Philosophy", "History"])

    # 2. locate the files; joining components keeps the paths valid on every OS
    #    and consistent with where vector_store writes the KB.
    domain_pdf_path = os.path.join(file_path, "documents", f"{domain}.pdf")
    domain_vector_KB_path = os.path.join(file_path, "data", f"{domain}.faiss")
    domain_text_KB_path = os.path.join(file_path, "data", f"{domain}.json")

    # 3. load the KB, creating it first if necessary
    if not os.path.exists(domain_vector_KB_path) or not os.path.exists(domain_text_KB_path):
        print(f"[INFO] Required KB files not found. Generating '{domain}.faiss' and '{domain}.json'.")
        domain_splitted = pdf2splitted_txt(domain_pdf_path)
        domain_embedded = embed_document(domain_splitted)
        domain_vector_KB, domain_text_KB = vector_store(f"{domain}", domain_splitted, domain_embedded)
    else:
        print("[INFO] Required KB files found. Loading them.")
        domain_vector_KB = faiss.read_index(domain_vector_KB_path)
        with open(domain_text_KB_path, "r", encoding="utf-8") as f:
            domain_text_KB = json.load(f)

    # 4. retrieve context
    retrieved_context = retrieve_context(domain_vector_KB, domain_text_KB, embedded_prompt, k=3)
    return retrieved_context

In [None]:
# 7-2. solving MMLU-pro question with wiki
def solve_question_MMLU_wiki(question, wiki):
    """
    Retrieve context for an MMLU question from the most relevant Wikipedia page.

    [params]
        - question: the MMLU question text
        - wiki: Wikipedia API object
    [returns]
        - retrieved_context (list): top-3 context chunks of the selected page,
          or an empty list when no page (or no page text) could be found
    """
    print("\nSolving an MMLU question with wiki.")
    # 1. determine domain
    domain = define_domain(question, domains=["Law", "Psychology", "Business", "Philosophy", "History"])

    # 2. fetch most relevant page
    page_title = find_most_relevant_page_with_subcategories(wiki, question, domain)
    if not page_title:
        # The search returns None for missing/empty categories; there is nothing to index.
        print(f"[WARNING] No relevant wiki page found for domain '{domain}'.")
        return []

    # Page titles may contain characters that are not valid in file names.
    kb_name = re.sub(r'[\\/:*?"<>|]', "_", page_title)
    vector_KB_path = os.path.join(file_path, "data", f"{kb_name}.faiss")
    text_KB_path = os.path.join(file_path, "data", f"{kb_name}.json")

    # 3. check if a KB already exists for that page
    if not os.path.exists(vector_KB_path) or not os.path.exists(text_KB_path):
        print(f"[INFO] Required KB files not found. Generating '{kb_name}.faiss' and '{kb_name}.json'.")
        page_text = fetch_page_content(wiki, page_title)
        page_splitted = split_wiki(page_text, page_title)
        if not page_splitted:
            print(f"[WARNING] Wiki page '{page_title}' has no text to index.")
            return []
        # embed and store into KB
        page_embedded = embed_document(page_splitted)
        page_vector_KB, page_text_KB = vector_store(kb_name, page_splitted, page_embedded)
    else:
        print("[INFO] Required KB files found. Loading them.")
        page_vector_KB = faiss.read_index(vector_KB_path)
        with open(text_KB_path, "r", encoding="utf-8") as f:
            page_text_KB = json.load(f)

    # 4. retrieve context
    embedded_prompt = embed_query(question)
    retrieved_context = retrieve_context(page_vector_KB, page_text_KB, embedded_prompt, k=3)
    return retrieved_context

    

In [None]:
# 8. predict MMLU question answer
def predict_answer_MMLU(domain, prompt, retrieved_context, MMLU_prompt_template):
    """
    Predicts an answer to an MMLU question.
    [params]
        - domain: domain name inserted into the prompt's {domain} slot
        - prompt: question to answer
        - retrieved_context: context to use to solve the question; a list of
          chunks (joined with newlines) or an already-joined string
        - MMLU_prompt_template: template with {domain}, {query} and {context}
    [returns]
        - prediction: text content of the model's reply
    """
    print("Predicting answer for a query.")
    # Build the model here, like the other helpers, instead of relying on a
    # module-level `llm` that only exists once the pipeline cell has run.
    chain = MMLU_prompt_template | load_llm()
    if isinstance(retrieved_context, str):
        context_text = retrieved_context
    else:
        context_text = "\n".join(retrieved_context)
    prediction = chain.invoke({"domain": domain, "query": prompt, "context": context_text})
    return prediction.content

# 4. Both

In [227]:
# Updated Function: test accuracy
def test_accuracy(questions, predictions, answers):
    """
    Print each prediction next to its gold answer and report the accuracy.

    A prediction counts as correct when the extracted option letter occurs
    in the string form of the gold answer. Predictions from which no letter
    can be extracted count as wrong.

    [params]
        - questions: unused; kept for interface compatibility
        - predictions: model outputs, aligned with answers
        - answers: gold answers
    [returns]
        - accuracy (float): percentage of correct answers over len(answers);
          0.0 when answers is empty
    """
    print("[Testing accuracy]")
    total = len(answers)
    cnt = 0
    for answer, prediction in zip(answers, predictions):
        print("-"*70)
        generated_answer = extract_answer(prediction)
        print(prediction)
        if generated_answer is None:
            print("extraction fail")
            continue
        print(f"generated answer: {generated_answer}, answer: {answer}")
        # str() keeps the membership test safe for non-string cells (e.g. NaN).
        if generated_answer in str(answer):
            cnt += 1
    print("="*70)
    accuracy = (cnt / total) * 100 if total else 0.0
    print(f"acc: {accuracy}%")
    return accuracy

In [None]:
# determine question type by Language (ewha or MMLU)

def determine_question_type(prompt: str) -> str:
    """
    Classify a prompt as 'ewha' or 'MMLU' from the script it is written in.

    [params]
    :prompt (str): The prompt text to analyze.

    [returns]
    :question_type (str): 'ewha' if the prompt contains at least one Hangul
        syllable (U+AC00 '가' .. U+D7A3 '힣'), otherwise 'MMLU'.
    """
    question_type = 'MMLU'
    for char in prompt:
        # One Hangul syllable is enough to treat the prompt as Korean.
        if 0xAC00 <= ord(char) <= 0xD7A3:
            question_type = 'ewha'
            break
    print(f"\nQuestion type is {question_type} for prompt: \n>> {prompt}")
    return question_type


In [None]:
# Infer answer prompt template
# Input variables: {query} (the multiple-choice question) and {context}.
# Unlike the reasoning prompts, this one instructs the model to output only
# the tagged answer "[ANSWER]: (X)".
Answer_prompt_template = PromptTemplate.from_template(
   """
   You are a highly reliable assistant. 

   Your goal is to answer multiple-choice questions based on the given context.
   Your final output MUST ONLY contain the answer, formatted as [ANSWER]: (X).

   Instructions:
   1. Carefully read the question and the provided context.
   2. Decide if the question can be answered based on the context:
      - If the context contains sufficient information, select the best option.
      - If the context lacks sufficient information, select the most reasonable option based on your knowledge.
   3. Analyze all options ((A), (B), (C), etc.) and pick the best one.
   4. **FINAL STEP**: Output ONLY the answer in this format: [ANSWER]: (X). Do not include any extra text.

   Example output: `[ANSWER]: (X)`

   Question: {query}
   ---
   Context: {context}
   """
)


In [None]:
# Ask the LLM once more so the answer is restated in the "[ANSWER]: (X)" format
def predict_answer_MMLU_fixed(prompt, retrieved_context, Answer_prompt_template, default_answer="(B)"):
    """
    Asks the LLM to restate an answer in the "[ANSWER]: (X)" format.
    [params]
        - prompt: The question to answer
        - retrieved_context: Context to use to solve the question; either a
          list of text chunks (joined with newlines) or a single string such
          as an earlier free-form prediction (used unchanged)
        - Answer_prompt_template: Template with {query} and {context}
        - default_answer: currently unused; kept so existing callers keep working
    [returns]
        - prediction: the model's reply with surrounding whitespace stripped.
          The reply is not validated here; the format is only requested by
          the prompt.
    """
    print("Predicting answer for a query.")
    # Joining a str would insert a newline between every character, so only
    # join when a collection of chunks was passed.
    if isinstance(retrieved_context, str):
        context_text = retrieved_context
    else:
        context_text = "\n".join(retrieved_context)

    # Build the model here, like the other helpers, instead of relying on a
    # module-level `llm` that only exists once the pipeline cell has run.
    chain = Answer_prompt_template | load_llm()
    prediction = chain.invoke({"query": prompt, "context": context_text})
    return prediction.content.strip()
    

# 5. Full Pipeline

In [None]:

# load each model
llm = load_llm()
wiki = load_wiki()

# predictions: raw model output per question
# formatted_predictions: the same answers restated as "[ANSWER]: (X)"
predictions = []
formatted_predictions = []
# prepare files
ewha_pdf_path = os.path.join(file_path, "documents/ewha.pdf")
test_set_path = os.path.join(file_path, "test_sets/testset.csv")

# prepare test set
prompts, answers = read_data(test_set_path)

for prompt, _answer in zip(prompts, answers):
    question_type = determine_question_type(prompt)
    if question_type == 'ewha':
        prediction = solve_question_ewha(prompt, ewha_pdf_path)
    elif question_type == 'MMLU':
        # NOTE: both solve_question_MMLU_* helpers call define_domain again
        # internally; this call supplies the domain for the final prompt.
        domain = define_domain(prompt, domains=['Law', 'Psychology', 'Business', 'Philosophy', 'History'])
        pdf_retrieved_context = solve_question_MMLU_pdf(prompt)
        wiki_retrieved_context = solve_question_MMLU_wiki(prompt, wiki)
        all_retrieved_context = pdf_retrieved_context + wiki_retrieved_context
        print(all_retrieved_context)
        prediction = predict_answer_MMLU(domain, prompt, all_retrieved_context, MMLU_prompt_template)
    else:
        print(f"Cannot determine the question type of {prompt}.")
        prediction = 'No question type'
    predictions.append(prediction)

    # Pass the free-form prediction as a one-element list: the helper joins its
    # context with "\n", which would otherwise split a bare string per character.
    formatted_prediction = predict_answer_MMLU_fixed(prompt, [prediction], Answer_prompt_template)

    print("this is prediction***")
    print(prediction)
    print("this is answer prediction***")
    print(formatted_prediction)
    formatted_predictions.append(formatted_prediction)

# after getting all predictions to the questions
test_accuracy(prompts, formatted_predictions, answers)


Question type is ewha for prompt: 
>> QUESTION1) 재학 중인 학생이 휴학을 하려면 학기 개시일로부터 며칠 이내에 휴학을 신청하야하나요?
(A) 30일
(B) 45일 
(C) 60일
(D) 90일
Solving a ewha question.
[INFO] Required KB files not found. Generating 'ewha.faiss' and 'ewha.json'.
Loading document: C:/Users/liy35/NLP_project\documents/ewha.pdf..
Splitting document..
Extracted 53 content chunks.
Embedding each document splits.
Done embedding.
Making ewha_text_KB
ewha_text_KB saved at: C:/Users/liy35/NLP_project\data/ewha.json
Storing vectors: ewha_vector_KB..
FAISS index saved at: C:/Users/liy35/NLP_project\data/ewha.faiss
Index size: 53 vectors.
[INFO] Relevant text indices: [[7 2 8]]
Predicting answer for a query.
Predicting answer for a query.
this is prediction***
정답: (D) 90일
this is answer prediction***
[ANSWER]: (D)
90일

Question type is ewha for prompt: 
>> QUESTION2) '재입학은 a회에 한하여 할 수 있다. 다만 제 28조제4호에 의하여 제적된 자는 제적된 날부터 b년이 경과한 후 재입학 할 수 있다.' a와 b가 상수일 때 a+b의 값을 구하면?
(A) 2
(B) 3
(C) 4
(D) A,B,C 중 답 없음
Solving a ewha question.
