### 1. Preprocess Data and Build the Vector Database Using FAISS


In [1]:
# NOTE(review): absolute, machine-specific path. The relative file names used below
# ('ori_pqal.json', "faiss_index_1000_entries") resolve against this directory, so the
# notebook only runs as-is where this path exists. Consider a configurable data directory.
%cd "/Users/rebeccaglick/Desktop/pubmedqa/data"

import json
from langchain.schema import Document
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS

# Load the dataset: a JSON object mapping each PMID to its entry.
# JSON is UTF-8 by specification, so do not rely on the platform's default encoding.
with open('ori_pqal.json', 'r', encoding='utf-8') as f:
    data = json.load(f)

# Number of examples to index (and to evaluate on later). Change it in this one place.
N_EXAMPLES = 1000

# First N_EXAMPLES (pmid, entry) pairs. `items` is reused by the evaluation cell, so the
# documents below are built from the exact same selection instead of re-slicing `data`.
items = list(data.items())[:N_EXAMPLES]

# Convert every selected entry into a LangChain Document.
# Each entry has CONTEXTS (list of passages), LONG_ANSWER, and the ground-truth final_decision;
# only the joined contexts and the long answer go into the indexed text.
documents = []
for pmid, entry in items:
    context = " ".join(entry["CONTEXTS"])
    long_answer = entry["LONG_ANSWER"]
    full_text = f"Context: {context}\n\nConclusion: {long_answer}"
    # Keep the PMID as metadata so a retrieved passage can be traced back to its source entry
    documents.append(Document(page_content=full_text, metadata={"pmid": pmid}))

# Sentence-transformers model used to embed the documents (loaded through HuggingFaceEmbeddings)
embedding_model = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-MiniLM-L6-v2",
)

# Build the FAISS vector store from the documents using that embedding model
vectorstore = FAISS.from_documents(
    documents,
    embedding_model,
)

# Persist the index to a local folder so a later cell can reload it
vectorstore.save_local(
    "faiss_index_1000_entries",
)

/Users/rebeccaglick/Desktop/pubmedqa/data


  embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
  from .autonotebook import tqdm as notebook_tqdm


### 2. Load the Vector Database and Create Retrieval QA Chain

In [2]:
from langchain.vectorstores import FAISS
from langchain.chat_models import ChatOllama
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate

# 1. Reload the saved vector DB from disk.
#    The embedding model uses the same model name as when the index was built.
#    NOTE(review): allow_dangerous_deserialization=True opts in to loading the saved index
#    files; presumably acceptable here because the folder is written by this notebook —
#    do not point this at an index from an untrusted source.
embedding_model = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-MiniLM-L6-v2",
)
vectorstore = FAISS.load_local(
    "faiss_index_1000_entries",
    embedding_model,
    allow_dangerous_deserialization=True,
)

# 2. Retriever that fetches the most similar stored passages for a question.
#    The retrieved passages are passed to the LLM as context (k=2 -> the 2 most similar passages).
retriever = vectorstore.as_retriever(search_kwargs=dict(k=2))

# 3. Local LLaMA 3.2 chat model, accessed via Ollama
llm = ChatOllama(model="llama3.2")

# 4. Prompt for the QA chain: the retrieved passages fill {context}, the query fills {question},
#    and the model is told to reply with a single word (Yes / No / Maybe).
custom_prompt = PromptTemplate(
    template=(
        "Use the following context to answer the question.\n"
        'If the answer is not explicitly clear from the context, respond with "Maybe".\n'
        "\n"
        "Context:\n"
        "{context}\n"
        "\n"
        "Question: {question}\n"
        "Answer with only one word: Yes, No, or Maybe.\n"
        "Answer:"
    ),
    input_variables=["context", "question"],
)

# 5. Assemble the RetrievalQA chain ("stuff" chain type) from the LLM, the retriever,
#    and the custom prompt defined above
qa_chain = RetrievalQA.from_chain_type(
    llm=llm,
    chain_type="stuff",
    retriever=retriever,
    chain_type_kwargs=dict(prompt=custom_prompt),
)

  llm = ChatOllama(model="llama3.2")


### 3. Evaluation

In [3]:
import string 

# Labels the prompt asks the model to choose from; any other output is coerced to "maybe"
VALID_ANSWERS = {"yes", "no", "maybe"}

# Running tallies for the evaluation
correct = 0   # predictions equal to the ground truth
invalid = 0   # LLM errors, empty outputs, or outputs outside VALID_ANSWERS (counted BEFORE coercion)
total = len(items)
predictions = []

# Loop through all of the questions (each entry contains the question and ground truth answer)
for pmid, entry in items:
    question = entry["QUESTION"] # what is passed to the LLM
    truth = entry["final_decision"].lower() # the actual answer, lowercased

    # This retrieval is only for logging; qa_chain was built with the same retriever and
    # retrieves again when invoked. `invoke` replaces the deprecated `get_relevant_documents`.
    retrieved_docs = retriever.invoke(question)

    if not retrieved_docs:
        print(f"[WARN] No documents retrieved for PMID {pmid}: '{question}'")
    else:
        print(f"[INFO] Retrieved {len(retrieved_docs)} docs for PMID {pmid}")
        for i, doc in enumerate(retrieved_docs):
            print(f"\n-- Doc {i+1} (snippet) --\n{doc.page_content[:300]}...\n")

    try:
        response = qa_chain.invoke({"query": question}) # query the LLM using the QA chain defined above
        raw_output = response["result"]
        print(f"\n[LLM RAW OUTPUT for PMID {pmid}]:\n{raw_output}\n")
        # Normalise e.g. "Yes." -> "yes": trim whitespace, lowercase, strip surrounding punctuation
        answer = raw_output.strip().lower().strip(string.punctuation)
    except Exception as e:
        # Best-effort evaluation: report the failure and keep going; the item is counted as invalid below
        print(f"Error with PMID {pmid}: {e}")
        answer = None
    else:
        if not answer:
            print(f"[WARN] Empty response from LLM for PMID {pmid}")
        if answer not in VALID_ANSWERS:
            print(f"[WARN] Unexpected answer from LLM: '{answer}' — defaulting to 'maybe'")

    # Fallback: record that the raw outcome was unusable *before* overwriting it, so the
    # summary below reports a real count instead of always 0.
    if answer not in VALID_ANSWERS:
        invalid += 1
        answer = "maybe"

    print(f"[QUESTION]: {question}")
    print(f"[GROUND TRUTH]: {truth}")
    print(f"[FINAL PREDICTION]: {answer}")

    if answer == truth: # compare prediction of model to ground truth
        correct += 1

    predictions.append(answer)

# Print/calculate final accuracies (guard against an empty `items` list)
accuracy = correct / total if total else 0.0
print(f"\nCorrect answers: {correct} out of {total}")
print(f"Accuracy: {accuracy:.2f}")
print(f"Valid predictions: {total - invalid}")
print(f"Skipped or invalid predictions: {invalid}")


  retrieved_docs = retriever.get_relevant_documents(question)


[INFO] Retrieved 2 docs for PMID 21645374

-- Doc 1 (snippet) --
Context: Programmed cell death (PCD) is the regulated death of cells within an organism. The lace plant (Aponogeton madagascariensis) produces perforations in its leaves through PCD. The leaves of the plant consist of a latticework of longitudinal and transverse veins enclosing areoles. PCD occurs i...


-- Doc 2 (snippet) --
Context: The hypothesis was tested that pectin content and methylation degree participate in regulation of cell wall mechanical properties and in this way may affect tissue growth and freezing resistance over the course of plant cold acclimation and de-acclimation. Experiments were carried on the le...


[LLM RAW OUTPUT for PMID 21645374]:
Yes.

[QUESTION]: Do mitochondria play a role in remodelling lace plant leaves during programmed cell death?
[GROUND TRUTH]: yes
[FINAL PREDICTION]: yes
[INFO] Retrieved 2 docs for PMID 16418930

-- Doc 1 (snippet) --
Context: Assessment of visual acuity depends on

In [9]:
# Ground truth: the lowercased final_decision of every evaluated item, in `items` order
ground_truth_answers = [record["final_decision"].lower() for _, record in items]

print(ground_truth_answers)


In [10]:
# Model predictions: one label per evaluated item, appended in `items` order by the
# evaluation loop, so positions line up with the ground-truth labels of the same items.
print(predictions)