# The project: Biomedical Question Answering with Fine-tuned PubMedBERT and Agent-based RAG

This project combines two key components to build a biomedical QA pipeline using the PubMedQA dataset:

---

### 1. Fine-tuning PubMedBERT

- The PubMedQA dataset (`pqa_labeled`) is used, containing biomedical yes/no/maybe questions with evidence-based contexts and final decisions.
- The model `microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext` is fine-tuned using Hugging Face's `Trainer`.
- The input pairs are: question + context → label (`yes`, `no`, or `maybe`).
- After 5 epochs of training, accuracy and a full classification report (precision, recall, F1) are computed on a 10% held-out split.
- Fine-tuned weights are saved locally for downstream use.
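The evaluation metrics listed above can be illustrated without any ML library. The sketch below is a hypothetical helper, `per_class_report`, that computes per-class precision, recall, F1 and support plus overall accuracy, the same quantities `sklearn.metrics.classification_report` prints. It assumes integer labels encoded with the notebook's `no=0, yes=1, maybe=2` mapping; the toy labels are invented and are not results from this project.

```python
from collections import Counter

# Label order follows the notebook's mapping: no=0, yes=1, maybe=2
ID2LABEL = {0: "no", 1: "yes", 2: "maybe"}


def per_class_report(y_true: list[int], y_pred: list[int]) -> dict:
    """Hypothetical helper: per-class precision/recall/F1/support plus accuracy."""
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same length")
    tp = Counter()          # correct predictions per class
    pred_count = Counter()  # how often each class was predicted
    true_count = Counter()  # support: how often each class occurs in y_true
    for t, p in zip(y_true, y_pred):
        true_count[t] += 1
        pred_count[p] += 1
        if t == p:
            tp[t] += 1
    report = {}
    for cid, name in ID2LABEL.items():
        precision = tp[cid] / pred_count[cid] if pred_count[cid] else 0.0
        recall = tp[cid] / true_count[cid] if true_count[cid] else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        report[name] = {"precision": precision, "recall": recall,
                        "f1": f1, "support": true_count[cid]}
    report["accuracy"] = sum(tp.values()) / len(y_true) if y_true else 0.0
    return report


# Toy labels, made up for illustration
y_true = [0, 1, 1, 2, 0, 1]
y_pred = [0, 1, 0, 2, 0, 1]
rep = per_class_report(y_true, y_pred)
print(round(rep["accuracy"], 3), rep["no"], rep["yes"])
```

Precision divides by how often a class was *predicted*, recall by how often it actually *occurs*; this is why the `maybe` row in the reports can have high precision but low recall when the model rarely predicts `maybe`.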

---

### 2. LangGraph Agent-based RAG Pipeline

- A LangGraph pipeline is constructed with 5 agents:
  - **RetrieverAgent**: Encodes the question with a SentenceTransformer, retrieves the 3 most similar abstracts from a FAISS index over the dataset, and keeps their paragraphs as separate context chunks.
  - **ContextPrintAgent**: A silent pass-through node kept for structure; it can optionally print the top contexts.
  - **ClassifierAgent**: Uses the fine-tuned PubMedBERT model to predict the answer from the question and the concatenated context.
  - **ExplainerAgent**: Uses an LLM (e.g., Granite) to explain, based on the retrieved context, why the model chose that answer.
  - **FinalOutputAgent**: Prints the full reasoning, context chunks, model prediction, and ground truth (if available).
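To make the data flow between the five agents concrete independently of LangGraph, here is a minimal plain-Python sketch: each node is a function that takes the shared state dict and returns an updated copy, and the nodes run in the fixed order RetrieverAgent → ContextPrintAgent → ClassifierAgent → ExplainerAgent → FinalOutputAgent. All node bodies are hypothetical stand-ins (a word-overlap "retriever", a rule-based "classifier", a template "explainer") and the two-sentence corpus is invented; none of this is the notebook's models or LangGraph's API.

```python
from typing import Callable, TypedDict


class QAState(TypedDict, total=False):
    question: str
    top_contexts: list[str]
    context: str
    prediction: str
    explanation: str


# Hypothetical toy corpus standing in for the FAISS-indexed PubMedQA contexts
CORPUS = [
    "Vitamin D deficiency was associated with depressive symptoms.",
    "Troponin levels were measured after stress echocardiography.",
]


def retriever_node(state: QAState) -> QAState:
    # Stand-in retriever: keep passages that share at least one word with the question
    words = set(state["question"].lower().split())
    hits = [c for c in CORPUS if words & set(c.lower().split())]
    return {**state, "top_contexts": hits, "context": " ".join(hits)}


def context_print_node(state: QAState) -> QAState:
    # Silent pass-through, like ContextPrintAgent
    return state


def classifier_node(state: QAState) -> QAState:
    # Stand-in classifier: "yes" if anything was retrieved, otherwise "maybe"
    return {**state, "prediction": "yes" if state["context"] else "maybe"}


def explainer_node(state: QAState) -> QAState:
    # Stand-in explainer: a fixed template instead of an LLM call
    n = len(state["top_contexts"])
    return {**state, "explanation": f"Predicted '{state['prediction']}' from {n} retrieved passage(s)."}


def final_output_node(state: QAState) -> QAState:
    # Print the result, like FinalOutputAgent
    print(f"{state['question']} -> {state['prediction']}: {state['explanation']}")
    return state


# Fixed execution order, mirroring the graph's edges
PIPELINE: list[Callable[[QAState], QAState]] = [
    retriever_node, context_print_node, classifier_node, explainer_node, final_output_node,
]


def run(question: str) -> QAState:
    state: QAState = {"question": question}
    for node in PIPELINE:
        state = node(state)  # each node receives the previous node's output
    return state


result = run("Can vitamin D deficiency cause depression?")
```

In the notebook, LangGraph's `StateGraph` plays the role of the hard-coded `PIPELINE` list: edges define the order, and the dict each node returns becomes the state seen by the next node.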

- Printed text is wrapped every 10 words for readability.

---

### Example Output

For a question like:
> "Can vitamin D deficiency cause depression?"

The pipeline retrieves relevant studies, predicts `"yes"`, and explains the answer based on associated biomedical contexts. If the question exists in the dataset, the true label is also shown.
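Showing the true label only for known questions amounts to an exact-match lookup of the question text. A minimal sketch, assuming a dictionary built once from the dataset's questions and `final_decision` labels (the helper names and the two example questions are hypothetical), returns the same `"[unknown]"` placeholder the pipeline prints for new questions:

```python
def build_label_lookup(questions: list[str], labels: list[str]) -> dict[str, str]:
    """Map each dataset question to its label; the first occurrence wins, like list.index."""
    lookup: dict[str, str] = {}
    for q, label in zip(questions, labels):
        lookup.setdefault(q, label)
    return lookup


def ground_truth(question: str, lookup: dict[str, str]) -> str:
    # Exact string match; questions outside the dataset get a placeholder
    return lookup.get(question, "[unknown]")


# Hypothetical questions, not taken from PubMedQA
lookup = build_label_lookup(
    ["Does drug X lower blood pressure?", "Is marker Y prognostic?"],
    ["yes", "maybe"],
)
print(ground_truth("Is marker Y prognostic?", lookup))                     # maybe
print(ground_truth("Can vitamin D deficiency cause depression?", lookup))  # [unknown]
```

A dictionary gives constant-time lookups, whereas scanning a list with `in` and `.index()` is linear in the number of questions; for 1,000 questions either is fast enough.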

---

### Goal

This project demonstrates how combining a **fine-tuned biomedical classifier** with **retrieval-augmented reasoning** and **explainable LLM output** can improve trust and performance in biomedical QA systems.


## The dataset: pubmed_qa

This dataset contains biomedical question-answer pairs derived from PubMed articles, used for training models in biomedical **yes/no/maybe** question answering.

### What the dataset contains

Each row in the `"train"` split includes the following fields:

| Column Name      | Description                                                  |
| ---------------- | ------------------------------------------------------------ |
| `pubid`          | Unique PubMed article identifier                             |
| `question`       | A biomedical yes/no/maybe question derived from the article  |
| `context`        | A dictionary with supporting article text and metadata       |
| `long_answer`    | A full explanatory answer in free text form                  |
| `final_decision` | The ground-truth label: one of `"yes"`, `"no"`, or `"maybe"` |

### Details of the `context` field

The `context` field is a dictionary that includes:

* **`contexts`**: List of supporting paragraph texts (strings)
* **`labels`**: Section labels such as `"METHODS"`, `"RESULTS"`, etc.
* **`meshes`**: MeSH terms (medical subject headings) for the article
* **`reasoning_required_pred`**: Predicted answer based on full reasoning
* **`reasoning_free_pred`**: Prediction assuming minimal reasoning
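Because `context` is a dictionary, calling `str()` on it serializes every key, including `labels`, `meshes`, and the two `reasoning_*_pred` fields that, per the list above, hold predicted answers. A minimal sketch for turning it into plain model input, assuming the dictionary shape described above (the helper name `context_to_text` and the sample values are invented), keeps only `contexts` and optionally prefixes each paragraph with its section label:

```python
def context_to_text(context: dict, with_section_labels: bool = False) -> str:
    """Join the paragraph texts of a PubMedQA-style `context` dict.

    Only `contexts` (and optionally `labels`) are used; `meshes` and the
    `reasoning_*_pred` fields are deliberately left out of the text.
    """
    paragraphs = context.get("contexts", [])
    if not with_section_labels:
        return " ".join(paragraphs)
    labels = context.get("labels", [])
    parts = []
    for i, para in enumerate(paragraphs):
        # Fall back to no prefix if there are fewer labels than paragraphs
        label = labels[i] if i < len(labels) else ""
        parts.append(f"{label}: {para}" if label else para)
    return " ".join(parts)


# Invented example with the field layout described above
sample = {
    "contexts": ["We measured serum levels in 23 patients.", "Most patients were deficient."],
    "labels": ["METHODS", "RESULTS"],
    "meshes": ["Adult", "Vitamin D"],
    "reasoning_required_pred": ["yes"],
    "reasoning_free_pred": ["yes"],
}
print(context_to_text(sample))
print(context_to_text(sample, with_section_labels=True))
```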

### Label for classification

The main label for supervised training is:

* **`final_decision`**: The correct answer to the question (used as the classification target)

Typical mapping for training:

```python
label2id = {"no": 0, "yes": 1, "maybe": 2}
```
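With this mapping, a classifier's three output logits are turned back into a label by taking the index of the largest value and looking it up in the inverse map, which is what the evaluation and `ClassifierAgent` code does with `torch.argmax`. A dependency-free sketch, where `encode` and `decode` are hypothetical helper names and the logit values are made up:

```python
label2id = {"no": 0, "yes": 1, "maybe": 2}
id2label = {v: k for k, v in label2id.items()}  # {0: "no", 1: "yes", 2: "maybe"}


def encode(decision: str) -> int:
    # final_decision string -> integer class id used as the training target
    return label2id[decision.lower()]


def decode(logits: list[float]) -> str:
    # argmax over the three class scores, then id -> label string
    best = max(range(len(logits)), key=lambda i: logits[i])
    return id2label[best]


print(encode("Yes"))             # 1
print(decode([0.2, -1.3, 2.1]))  # maybe
```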


In [None]:
!pip install bitsandbytes accelerate transformers langgraph faiss-cpu evaluate sentence-transformers datasets scikit-learn python-dotenv

In [2]:
from dotenv import load_dotenv
import os

load_dotenv(dotenv_path="env")
hf_token = os.getenv("HF_TOKEN")

In [17]:
# ✅ PubMedQA fine-tuning with the Hugging Face Trainer

from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer,
    DataCollatorWithPadding,
)
from sklearn.metrics import classification_report
import numpy as np
import evaluate

# ------------------------------
# Step 1: Load PubMedQA
# ------------------------------
dataset = load_dataset("pubmed_qa", "pqa_labeled")
dataset = dataset.filter(lambda x: x["final_decision"] in ["yes", "no", "maybe"])

# ------------------------------
# Step 2: Encode labels
# ------------------------------
label2id = {"no": 0, "yes": 1, "maybe": 2}
id2label = {v: k for k, v in label2id.items()}
def encode_labels(example):
    example["label"] = label2id[example["final_decision"]]
    return example

dataset = dataset.map(encode_labels)

# ------------------------------
# Step 3: Tokenization
# ------------------------------
model_ckpt = "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext"
tokenizer = AutoTokenizer.from_pretrained(model_ckpt)

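def preprocess(batch):
    # Encode (question, context) pairs for sequence-pair classification.
    # str() serializes each full `context` dict: the paragraph texts plus the
    # `labels`, `meshes` and `reasoning_*_pred` fields described above.
    return tokenizer(
        list(map(str, batch["question"])),
        list(map(str, batch["context"])),
        truncation=True,
        max_length=512,
        padding="max_length"
    )
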
tokenized = dataset.map(preprocess, batched=True)
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

# ------------------------------
# Step 4: Load model
# ------------------------------
model = AutoModelForSequenceClassification.from_pretrained(
    model_ckpt,
    num_labels=3,
    id2label=id2label,
    label2id=label2id
)

# ------------------------------
# Step 5: Metrics
# ------------------------------
metric = evaluate.load("accuracy")

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    preds = np.argmax(logits, axis=-1)
    return {
        "accuracy": metric.compute(predictions=preds, references=labels)["accuracy"]
    }

# ------------------------------
# Step 6: Training setup
# ------------------------------
train_test = tokenized["train"].train_test_split(test_size=0.1, seed=42)
args = TrainingArguments(
    output_dir="./pubmedqa-bert",
    save_strategy="no",
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=5,
    weight_decay=0.01,
    report_to="none",
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_test["train"],
    eval_dataset=train_test["test"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics
)

# ------------------------------
# Step 7: Train
# ------------------------------
trainer.train()

# ------------------------------
# Step 8: Evaluate + F1
# ------------------------------
predictions = trainer.predict(train_test["test"])
preds = np.argmax(predictions.predictions, axis=-1)
true = predictions.label_ids

print("✅ Classification Report (fine-tuned BERT on PubMedQA):")
print(classification_report(true, preds, target_names=["no", "yes", "maybe"]))


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  trainer = Trainer(


Step    Training Loss
500     0.521


✅ Classification Report (fine-tuned BERT on PubMedQA):
              precision    recall  f1-score   support

          no       0.89      0.85      0.87        40
         yes       0.84      0.96      0.90        45
       maybe       0.91      0.67      0.77        15

    accuracy                           0.87       100
   macro avg       0.88      0.82      0.85       100
weighted avg       0.87      0.87      0.87       100



In [18]:
# Save model and tokenizer after training
model.save_pretrained("./pubmedqa-bert")
tokenizer.save_pretrained("./pubmedqa-bert")


('./pubmedqa-bert/tokenizer_config.json',
 './pubmedqa-bert/special_tokens_map.json',
 './pubmedqa-bert/vocab.txt',
 './pubmedqa-bert/added_tokens.json',
 './pubmedqa-bert/tokenizer.json')

## The Evaluation in This Script: RAG-style Classification

This script evaluates the fine-tuned BERT model in a **Retrieval-Augmented Generation (RAG)**-style workflow, where retrieved passages feed a classifier rather than a text generator, instead of using the standard evaluation performed via `Trainer.evaluate()`.

### How It Works

1. **Retriever with FAISS + SentenceTransformer**:
   - Each question is encoded with a SentenceTransformer.
   - The script uses a FAISS index to retrieve the top-3 most semantically similar contexts from the dataset.
   - These contexts are concatenated to form a single input passage.

2. **Classification using Fine-tuned BERT**:
   - The concatenated context (from retrieval) and the question are passed into a fine-tuned `BERT` model (`pubmedqa-bert`) for classification.
   - The model predicts one of three labels: `yes`, `no`, or `maybe`.

3. **Evaluation Metrics**:
   - Predictions are compared to the original ground-truth labels.
   - Standard classification metrics are reported:
     - Precision
     - Recall
     - F1-score
     - Accuracy (via `sklearn.metrics.classification_report`)
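The retrieval step can be sketched in plain Python. `faiss.IndexFlatIP` performs an exact, brute-force inner-product search, and when vectors are L2-normalized the inner product equals cosine similarity. The sketch below substitutes tiny hand-made 3-dimensional vectors for SentenceTransformer embeddings (all names and values are hypothetical) and reproduces the top-k search and the concatenation of the retrieved texts into one passage:

```python
import math


def l2_normalize(v: list[float]) -> list[float]:
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v] if norm else v


def top_k_inner_product(query: list[float], index: list[list[float]], k: int = 3) -> list[tuple[int, float]]:
    """Exact top-k search by inner product, in the spirit of faiss.IndexFlatIP.search."""
    scores = [(i, sum(a * b for a, b in zip(query, vec))) for i, vec in enumerate(index)]
    scores.sort(key=lambda pair: pair[1], reverse=True)  # highest similarity first
    return scores[:k]


# Hypothetical 3-d "embeddings" for four documents and one question
docs = ["doc A", "doc B", "doc C", "doc D"]
doc_vecs = [l2_normalize(v) for v in ([1, 0, 0], [0.9, 0.1, 0], [0, 1, 0], [0, 0, 1])]
q_vec = l2_normalize([1, 0.2, 0])

hits = top_k_inner_product(q_vec, doc_vecs, k=3)
full_context = " ".join(docs[i] for i, _ in hits)  # concatenated input passage
print([i for i, _ in hits], full_context)
```

A flat index compares the query with every stored vector, which is exact and fast enough for 1,000 abstracts; approximate FAISS indexes only become worthwhile for much larger corpora.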

### Difference from Standard Evaluation

| Feature                | Trainer.evaluate()                             | RAG Evaluation (this script)                         |
|------------------------|------------------------------------------------|------------------------------------------------------|
| Input Context          | Uses original `context` from the dataset       | Uses top-3 retrieved contexts from FAISS             |
| Retrieval involved     | No                                             | Yes                                                  |
| Realistic QA scenario  | Limited                                        | More realistic (simulates open-domain QA setup)      |
| Evaluation target      | Pure model performance on static input         | End-to-end performance under retrieval + reasoning   |

### Purpose

This type of evaluation answers the question:

> *"How well does the fine-tuned model perform when it has to rely on dynamically retrieved context instead of ground-truth context?"*

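This is particularly useful for applications involving **retrieval-based question answering**, such as biomedical QA systems built over large corpora. Note that the evaluation below covers all 1,000 `pqa_labeled` examples, 90% of which were used for fine-tuning, and the FAISS index is built from those same examples, so the scores reflect retrieval over seen data rather than held-out generalization.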
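One way to separate retrieval quality from classifier quality, which the notebook itself does not do, is to measure how often a question's own abstract appears among its top-k retrieved entries (recall@k). Because the index is built from the same rows as the questions, question `i`'s gold context is row `i`. A minimal sketch, assuming the search results are available as lists of row indices such as `I[0]` from `index.search` (the function name and the result lists below are invented):

```python
def recall_at_k(retrieved: list[list[int]], gold: list[int], k: int = 3) -> float:
    """Fraction of queries whose gold row index appears in their top-k results."""
    if len(retrieved) != len(gold):
        raise ValueError("one result list per query is required")
    if not gold:
        return 0.0
    hits = sum(1 for res, g in zip(retrieved, gold) if g in res[:k])
    return hits / len(gold)


# Invented search results for 4 questions; question i's own context is row i
retrieved = [[0, 5, 9], [7, 1, 2], [4, 8, 6], [3, 0, 1]]
gold = [0, 1, 2, 3]
print(recall_at_k(retrieved, gold, k=3))  # 0.75
print(recall_at_k(retrieved, gold, k=1))  # 0.5
```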



In [19]:
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
import faiss
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from sklearn.metrics import classification_report
import numpy as np
from tqdm import tqdm

# ------------------------------
# Load model + tokenizer
# ------------------------------
model_path = "./pubmedqa-bert"
model = AutoModelForSequenceClassification.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)
model.eval()

label_map = {0: "no", 1: "yes", 2: "maybe"}
inv_label_map = {v: k for k, v in label_map.items()}

# ------------------------------
# Load dataset
# ------------------------------
dataset = load_dataset("pubmed_qa", "pqa_labeled", split="train")
dataset = dataset.filter(lambda x: x["final_decision"] in ["yes", "no", "maybe"])
dataset = dataset.select(range(1000))  # adjustable; pqa_labeled contains 1,000 examples
contexts = [str(x) for x in dataset["context"]]  # full context dict serialized as retrievable text
questions = dataset["question"]
true_labels = [x.lower() for x in dataset["final_decision"]]

# ------------------------------
# FAISS Retriever
# ------------------------------
retriever = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
context_embeddings = retriever.encode(contexts, convert_to_numpy=True, show_progress_bar=True)
index = faiss.IndexFlatIP(context_embeddings.shape[1])
index.add(context_embeddings)

# ------------------------------
# Evaluate on full dataset
# ------------------------------
preds = []

for q in tqdm(questions, desc="Evaluating RAG"):
    # Retrieve top-3 contexts
    q_emb = retriever.encode([q], convert_to_numpy=True)
    D, I = index.search(q_emb, 3)
    top_contexts = [str(contexts[int(i)]) for i in I[0]]
    full_context = " ".join(top_contexts)

    # Tokenize input
    inputs = tokenizer(
        q,
        full_context,
        return_tensors="pt",
        truncation=True,
        max_length=512,
        padding="max_length"
    )

    # Predict
    with torch.no_grad():
        logits = model(**inputs).logits
        pred = torch.argmax(logits, dim=1).item()
        preds.append(label_map[pred])



true_numeric = [inv_label_map[x] for x in true_labels]
preds = [inv_label_map[p] for p in preds]


print("✅ RAG Evaluation – PubMedQA (Test Set):")
print(classification_report(true_numeric, preds, target_names=["no", "yes", "maybe"]))



Evaluating RAG: 100%|██████████| 1000/1000 [03:48<00:00,  4.37it/s]

✅ RAG Evaluation – PubMedQA (Test Set):
              precision    recall  f1-score   support

          no       0.96      0.92      0.94       338
         yes       0.92      0.94      0.93       552
       maybe       0.74      0.74      0.74       110

    accuracy                           0.91      1000
   macro avg       0.87      0.87      0.87      1000
weighted avg       0.91      0.91      0.91      1000






## AI Agent

In [9]:
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, pipeline
import torch

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    llm_int8_enable_fp32_cpu_offload=True
)

try:
    GRANITE_MODEL = "ibm-granite/granite-3.3-8b-instruct"

    granite_model = AutoModelForCausalLM.from_pretrained(
        GRANITE_MODEL,
        device_map="auto",
        quantization_config=bnb_config
    )

    granite_tokenizer = AutoTokenizer.from_pretrained(GRANITE_MODEL)

    granite_pipe = pipeline(
        "text-generation",
        model=granite_model,
        tokenizer=granite_tokenizer,
        pad_token_id=granite_tokenizer.eos_token_id,
        return_full_text=False
    )
except Exception as e:
    print("[Warning] Failed to load Granite model:", e)
    granite_pipe = lambda prompt, **kwargs: [{"generated_text": "[granite model unavailable]"}]


Device set to use cuda:0


In [10]:
#!pip install -q langgraph sentence-transformers transformers datasets faiss-cpu

from langgraph.graph import StateGraph, END
from typing import TypedDict
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline, BitsAndBytesConfig
import faiss
import torch
import numpy as np

# ------------------------------
# State definition
# ------------------------------
class QAState(TypedDict):
    question: str
    prediction: str
    context: str
    top_contexts: list[str]
    ground_truth: str
    explanation: str

# ------------------------------
# Load dataset + labels
# ------------------------------
dataset = load_dataset("pubmed_qa", "pqa_labeled", split="train")
dataset = dataset.filter(lambda x: x["final_decision"] in ["yes", "no", "maybe"])
dataset = dataset.select(range(1000))

questions_list = dataset["question"]
contexts_raw = dataset["context"]
labels_raw = dataset["final_decision"]
label_map = {0: "no", 1: "yes", 2: "maybe"}
inv_label_map = {v: k for k, v in label_map.items()}

# ------------------------------
# Encode context embeddings
# ------------------------------
print("Encoding context with SentenceTransformer...")
texts_for_index = [
    " ".join(x["contexts"]) if isinstance(x, dict) and "contexts" in x else str(x)
    for x in contexts_raw
]
retriever = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
embeddings = retriever.encode(texts_for_index, convert_to_numpy=True, show_progress_bar=True)
index = faiss.IndexFlatIP(embeddings.shape[1])
index.add(embeddings)

# ------------------------------
# Load fine-tuned classifier
# ------------------------------
print("Loading fine-tuned BERT classifier...")
model = AutoModelForSequenceClassification.from_pretrained("./pubmedqa-bert")
tokenizer = AutoTokenizer.from_pretrained("./pubmedqa-bert")

def break_lines_every_n_words(text: str, n: int = 10) -> str:
    """Break text every n words."""
    words = text.split()
    lines = [" ".join(words[i:i+n]) for i in range(0, len(words), n)]
    return "\n".join(lines)

# ------------------------------
# Agent 1: Retriever
# ------------------------------
def RetrieverAgent(state: QAState) -> QAState:
    q = state["question"]
    q_emb = retriever.encode([q], convert_to_numpy=True)
    D, I = index.search(q_emb, 3)

    top_contexts = []
    for i in I[0]:
        raw = contexts_raw[int(i)]
        if isinstance(raw, dict) and "contexts" in raw:
            top_contexts.extend(raw["contexts"])
        else:
            top_contexts.append(str(raw))

    full_context = "\n\n".join([f"Context {i+1}:\n{ctx}" for i, ctx in enumerate(top_contexts)])
    if q in questions_list:
        gt = labels_raw[questions_list.index(q)]
    else:
        gt = "[unknown]"  # new question that is not in the dataset

    return {**state, "context": full_context, "top_contexts": top_contexts, "ground_truth": gt}

# ------------------------------
# Agent 2: Context printer
# ------------------------------
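def ContextPrintAgent(state: QAState) -> QAState:
    # Silent pass-through node kept for graph structure.
    # Uncomment the loop below to print the top contexts.
    # for i, ctx in enumerate(state["top_contexts"][:3]):
    #     print(f"Context {i+1}:\n{ctx}")
    return state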

# ------------------------------
# Agent 3: Classifier
# ------------------------------
def ClassifierAgent(state: QAState) -> QAState:
    inputs = tokenizer(
        state["question"],
        state["context"],
        return_tensors="pt",
        truncation=True,
        max_length=512,
        padding="max_length"
    )
    with torch.no_grad():
        logits = model(**inputs).logits
        pred = torch.argmax(logits, dim=1).item()
        label = label_map[pred]
    return {**state, "prediction": label}

# ------------------------------
# Agent 4: LLM explainer
# ------------------------------
def ExplainerAgent(state: QAState) -> QAState:
    prompt = f"""
You are a biomedical QA assistant.

Given the following question and retrieved context, the classifier predicted the answer: "{state['prediction']}".

Your task is to explain WHY this answer is appropriate, based only on the evidence from the context.

---

Question:
{state['question']}

Context:
{state['context']}

Answer:
{state['prediction']}

---

Explanation:
"""
    response = granite_pipe(prompt, max_new_tokens=150)[0]["generated_text"]
    return {**state, "explanation": response.strip()}

# ------------------------------
# Agent 5: Ground truth compare + print final output
# ------------------------------
def FinalOutputAgent(state: QAState) -> QAState:
    print("\n==============================")
    print(f"Question:\n{state['question']}\n")
    print(f"Predicted Answer: {state['prediction']}")
    print(f"Ground Truth: {state['ground_truth']}\n")

    print(f"Top Contexts:")
    for i, ctx in enumerate(state['top_contexts'][:3]):
        print(f"\nContext {i+1}:\n{break_lines_every_n_words(ctx.strip(), 10)}")

    print(f"\nExplanation:\n{break_lines_every_n_words(state['explanation'].strip(), 10)}")
    print("==============================\n")

    return None




# ------------------------------
# Build LangGraph
# ------------------------------
graph = StateGraph(QAState)
graph.add_node("RetrieverAgent", RetrieverAgent)
graph.add_node("ContextPrintAgent", ContextPrintAgent)
graph.add_node("ClassifierAgent", ClassifierAgent)
graph.add_node("ExplainerAgent", ExplainerAgent)
graph.add_node("FinalOutputAgent", FinalOutputAgent)

graph.set_entry_point("RetrieverAgent")
graph.add_edge("RetrieverAgent", "ContextPrintAgent")
graph.add_edge("ContextPrintAgent", "ClassifierAgent")
graph.add_edge("ClassifierAgent", "ExplainerAgent")
graph.add_edge("ExplainerAgent", "FinalOutputAgent")
graph.add_edge("FinalOutputAgent", END)

app = graph.compile()

# ------------------------------
# Run full pipeline
# ------------------------------
question = "Can vitamin D deficiency cause depression?"
_ = app.invoke({"question": question})  # discard output




🔄 Encoding context with SentenceTransformer...



🔄 Loading fine-tuned BERT classifier...

Question:
Can vitamin D deficiency cause depression?

Predicted Answer: yes
Ground Truth: [unknown]

Top Contexts:

Context 1:
The aetiology of osteochondritis dissecans is still unclear. The aim
of this prospective pilot study was to analyse whether vitamin
D insufficiency, or deficiency, might be a contributing etiological factor
in the development of an OCD lesion.

Context 2:
The serum level of vitamin D3 in 23 consecutive patients
(12 male and 11 female) suffering from a stage III,
or stages III and IV, OCD lesion (mostly stage III)
admitted for surgery was measured.

Context 3:
The patients' mean age was 31.3 years and most of
them already exhibited closed epiphyseal plates. In the majority of
patients (18/23), a distinct vitamin D3 deficiency was found, two
patients were vitamin D3-insufficient and, in three patients, the vitamin
D3 level reached the lowest normal value.

Explanation:
The context suggests that vitamin D deficiency is comm

In [11]:
import random
from datasets import load_dataset

# Load the labeled train set from PubMedQA
dataset = load_dataset("pubmed_qa", "pqa_labeled", split="train")
dataset = dataset.filter(lambda x: x["final_decision"] in ["yes", "no", "maybe"])

# Select 5 random indices
random_indices = random.sample(range(len(dataset)), 5)

# Run each selected question through the LangGraph pipeline
for i, idx in enumerate(random_indices, 1):
    question = dataset[idx]["question"]
    print(f"\nQuestion {i}: {question}")
    _ = app.invoke({"question": question})



Question 1: Can dobutamine stress echocardiography induce cardiac troponin elevation?

Question:
Can dobutamine stress echocardiography induce cardiac troponin elevation?

Predicted Answer: yes
Ground Truth: no

Top Contexts:

Context 1:
Elevation of cardiac troponin (cTn) is considered specific for myocardial
damage. Elevated cTn and echocardiogrpahic documentation of wall motion abnormalities
(WMAs) that were recorded after extreme physical effort raise the
question whether dobutamine stress echo (DSE), can also induce elevation
of troponin.

Context 2:
we prospective enrolled stable patients (age>18 years) referred to DSE.
The exam was performed under standardized conditions. Blood samples for
cTnI were obtained at baseline and 18-24 hours after the
test. We aimed to compare between the clinical and echocardiographic
features of patients with elevated cTnI and those without cTnI
elevations.

Context 3:
Fifty-seven consecutive patients were included. The average age was 64.4
± 10.7,

In [20]:
# Define 5 custom biomedical questions related to depression
depression_questions = [
    "Can vitamin D deficiency cause depression?",
    "Is there a link between serotonin levels and depression?",
    "Does cognitive behavioral therapy reduce symptoms of depression?",
    "Can chronic inflammation contribute to the development of depression?",
    "Is depression more prevalent in patients with chronic pain conditions?"
]

# Send each question through the LangGraph RAG pipeline
for i, q in enumerate(depression_questions, 1):
    print(f"\nQuestion {i}: {q}")
    _ = app.invoke({"question": q})



Question 1: Can vitamin D deficiency cause depression?

Question:
Can vitamin D deficiency cause depression?

Predicted Answer: yes
Ground Truth: [unknown]

Top Contexts:

Context 1:
Epidemiological data show significant associations of vitamin D deficiency and
autoimmune diseases. Vitamin D may prevent autoimmunity by stimulating naturally
occurring regulatory T cells.

Context 2:
To elucidate whether vitamin D supplementation increases Tregs frequency (%Tregs)
within circulating CD4+ T cells.

Context 3:
We performed an uncontrolled vitamin D supplementation trial among 50
apparently healthy subjects including supplementation of 140,000 IU at baseline
and after 4 weeks (visit 1). The final follow-up visit
was performed 8 weeks after the baseline examination (visit 2).
Blood was drawn at each study visit to determine 25-hydroxyvitamin
D levels and %Tregs. Tregs were characterized as CD4+CD25++ T
cells with expression of the transcription factor forkhead box P3
and low or absent expre