**Imports**

In [None]:
from langchain_community.docstore.in_memory import InMemoryDocstore
from langchain_community.vectorstores import FAISS
from langchain_ollama.embeddings import OllamaEmbeddings
from langchain_ollama import OllamaLLM
from langchain.prompts import ChatPromptTemplate
from uuid import uuid4
from langchain_core.documents import Document
from tqdm import tqdm

import pandas as pd
import json
import faiss

from config import prompt

**Create vector database**

In [None]:
embeddings = OllamaEmbeddings(model="all-minilm:33m")

# Size the flat L2 index from the length of a probe embedding so it always
# matches the dimensionality of whatever embedding model is configured above.
embedding_dim = len(embeddings.embed_query("Hello world"))
index = faiss.IndexFlatL2(embedding_dim)

vector_store = FAISS(
    embedding_function=embeddings,
    index=index,
    docstore=InMemoryDocstore(),
    index_to_docstore_id={},
)

with open(r"../data/statpearls_chunks.json", "r", encoding="utf-8") as json_file:
    data = json.load(json_file)

# One Document per chunk record; the chunk text is what gets embedded and the
# remaining fields are carried along as metadata.
documents = [
    Document(
        page_content=row["chunk_text"],
        metadata={
            "chunk_id": row["_id"],
            "chunk_index": row["chunk_index"],
            "source_filename": row["source_filename"],
        },
    )
    for row in tqdm(data, desc="Processing chunks", total=len(data))
]

# A fresh random id for every document.
uuids = [str(uuid4()) for _ in documents]

vector_store.add_documents(documents=documents, ids=uuids)

# Persist the populated store to disk.
vector_store.save_local(r"../data/med_index")

**Load Vector Database**

In [None]:
embeddings = OllamaEmbeddings(model="all-minilm:33m")

# NOTE(review): allow_dangerous_deserialization=True opts in to loading
# serialized (pickle-based) data from disk. Only point this at an index
# directory you created yourself — confirm against the LangChain FAISS docs.
vector_store = FAISS.load_local(
    folder_path=r"../data/med_index",
    embeddings=embeddings,
    allow_dangerous_deserialization=True,
)

**MMLU**

In [None]:
with open(r"../data/benchmark.json", "r", encoding="utf-8") as json_file:
    data = json.load(json_file)

mmlu_benchmark = data["mmlu"]

# Collect the reference answer of every MMLU item, in the benchmark's own
# iteration order.
ground_truth_answers = [mmlu_benchmark[key]["answer"] for key in mmlu_benchmark]


**MedMCQA**

In [None]:
with open(r"../data/benchmark.json", "r", encoding="utf-8") as json_file:
    data = json.load(json_file)

medmcqa_benchmark = data["medmcqa"]

# Collect the reference answer of every MedMCQA item, in the benchmark's own
# iteration order.
ground_truth_answers = [medmcqa_benchmark[key]["answer"] for key in medmcqa_benchmark]

**MedMCQA - Test examples (same as MMLU)**

In [None]:
# Build the list of (question, options) pairs that the perturbation
# experiments iterate over.
test_examples = []

for key in medmcqa_benchmark:
    entry = medmcqa_benchmark[key]
    question, options = entry["question"], entry["options"]
    test_examples.append((question, options))

In [None]:
import spacy
import re

# Shared spaCy pipeline; extract_tokens() below calls it to tokenize context text.
nlp = spacy.load("en_core_web_sm")

def remove_one_word_perturbations(context: str) -> list[tuple[str, str, int]]:
    """
    Generate one-word-removal perturbations from the given context. Used for exhaustive perturbations.

    The context is tokenized with ``extract_tokens`` (punctuation tokens are
    dropped), and for every remaining word a copy of the context is produced
    with exactly that word removed. The surviving words are re-joined with
    single spaces and passed through ``add_newlines_before_documents`` to
    restore the separation between chunks.

    Args:
        context: The retrieved context text to perturb.

    Returns:
        list[tuple[str, str, int]]: One tuple per word, each containing:
            - str: Perturbed context with one word removed
            - str: The removed word
            - int: Index of the removed word in the token list

        The list is empty when the context contains no words.
    """

    words = extract_tokens(context)
    perturbations = []

    for index, removed_word in enumerate(words):
        # Slice around the removed position instead of copying and popping.
        remaining_words = words[:index] + words[index + 1:]
        perturbed_context = add_newlines_before_documents(" ".join(remaining_words))

        perturbations.append((perturbed_context, removed_word, index))

    return perturbations

def remove_word_span(context: str, span_size: int) -> list[tuple[str, str, int]]:
    """
    Generate span-based word removal perturbations from the given context.

    The context is tokenized with ``extract_tokens`` (punctuation tokens are
    dropped). A window of ``span_size`` consecutive words is slid over the
    token list one position at a time; for each position a copy of the context
    is produced with that window removed. The surviving words are re-joined
    with single spaces and passed through ``add_newlines_before_documents``.

    Args:
        context: The retrieved context text to perturb.
        span_size: Number of consecutive words removed per perturbation.
            Must be at least 1.

    Returns:
        list[tuple[str, str, int]]: One tuple per window position, each containing:
            - str: Perturbed context with one span removed
            - str: The removed span text (words joined by single spaces)
            - int: The starting index of the removed span in the original word list

        The list is empty when the context has fewer than ``span_size`` words.

    Raises:
        ValueError: If ``span_size`` is smaller than 1. A zero or negative
            span would otherwise yield perturbations that remove nothing or
            slice the word list in unintended ways.
    """

    # Validate before tokenizing so invalid calls fail fast and cheaply.
    if span_size < 1:
        raise ValueError(f"span_size must be a positive integer, got {span_size}")

    words = extract_tokens(context)
    perturbations = []

    for start in range(len(words) - span_size + 1):
        end = start + span_size
        removed_words = words[start:end]
        remaining_words = words[:start] + words[end:]

        perturbed_context = add_newlines_before_documents(" ".join(remaining_words))

        perturbations.append((perturbed_context, " ".join(removed_words), start))

    return perturbations


def extract_tokens(text):
    """
    Tokenize ``text`` with the module-level spaCy pipeline ``nlp`` and return
    the text of every token that is not flagged as punctuation.
    """
    kept_tokens = []

    for token in nlp(text):
        if token.is_punct:
            continue
        kept_tokens.append(token.text)

    return kept_tokens

def add_newlines_before_documents(text):
    """
    Put a blank line in front of every ``Chunk <digits>`` marker.

    The single space preceding each marker is replaced by two newlines. A
    space sitting at the very start of the text is left untouched, so a
    leading marker is never prefixed with newlines.
    """
    chunk_marker = re.compile(r'(?<!^) (Chunk \d+)')
    return chunk_marker.sub(r'\n\n\1', text)

**Exhaustive Perturbations**

In [None]:
model = OllamaLLM(model="llama3.2:3b-instruct-fp16", temperature=0)

# The chain does not depend on the example, so build it once up front.
chain = prompt | model

examples = test_examples[:]
results = {}

# Unpack each example so the question/options sent to the model always belong
# to the example being processed, rather than to whatever values happen to be
# left in the notebook namespace by an earlier cell.
for example_index, (question, options) in tqdm(enumerate(examples), desc="Processing Perturbations", total=len(examples)):
    retrieved_docs = vector_store.similarity_search(question, k=1)
    context = [doc.page_content for doc in retrieved_docs]

    # Distinct loop variable: the chunk counter must not clobber the example counter.
    context_text = ""
    for chunk_index, c in enumerate(context):
        context_text += f"Chunk {chunk_index}: {c.capitalize()}\n"

    # Only the first character of the response is compared (presumably the
    # option letter — depends on `prompt` in config). Slicing with [:1]
    # yields "" for an empty response instead of raising IndexError.
    original_response = chain.invoke({"paragraph": context_text, "question": question, "options": options})[:1]
    print(f"Response: {original_response}")

    perturbations = remove_word_span(context_text, 5)

    for perturbed_text, removed_token, _ in perturbations:
        temp_answer = chain.invoke({"paragraph": perturbed_text, "question": question, "options": options})[:1]

        # Report only the perturbations that change the model's answer.
        if original_response != temp_answer:
            print(f"Peturbed Text: {perturbed_text}")
            print(f"Temp answer: {temp_answer} | Removed token: {removed_token}")


**Count LLM calls and Tokens**

In [None]:
model = OllamaLLM(model="llama3.2:3b-instruct-fp16", temperature=0)

examples = test_examples[:]
responses = {}

progress = tqdm(enumerate(examples), desc="Processing Perturbations", total=len(examples))

for test_index, example in progress:
    retrieved_docs = vector_store.similarity_search(example[0], k=1)
    context = [doc.page_content for doc in retrieved_docs]

    # Same "Chunk i: ..." layout the perturbation experiment feeds to the model.
    context_text = "".join(
        f"Chunk {index}: {c.capitalize()}\n" for index, c in enumerate(context)
    )

    # Whitespace-separated word count of the unperturbed context.
    paragraph_length = len(context_text.split())

    perturbations = remove_word_span(context_text, 5)
    call_count = len(perturbations)

    # One LLM call per perturbation; tokens are estimated as
    # calls x length of the unperturbed context.
    responses[f"test_{test_index}"] = {
        "llm_calls": call_count,
        "total_tokens": call_count * paragraph_length,
    }

with open(r"../results/medmcqa_calls_amount_simple.json", "w", encoding="utf-8") as f:
    json.dump(responses, f, indent=2, ensure_ascii=False)
