##RAG chatbot for medical Q&A

##Installing and importing required packages

In [4]:
!pip install pandas sentence-transformers faiss-cpu --quiet
import pandas as pd
from sentence_transformers import SentenceTransformer, util
import torch

##Loading the dataset and pre-processing

In [6]:
df = pd.read_csv("mle_screening_dataset.csv")
df.dropna(inplace=True)
df.reset_index(drop=True, inplace=True)

print("Dataset sample:")
print(df.shape)
print(df.head())

Dataset sample:
(16401, 2)
                         question  \
0        What is (are) Glaucoma ?   
1        What is (are) Glaucoma ?   
2        What is (are) Glaucoma ?   
3  Who is at risk for Glaucoma? ?   
4       How to prevent Glaucoma ?   

                                              answer  
0  Glaucoma is a group of diseases that can damag...  
1  The optic nerve is a bundle of more than 1 mil...  
2  Open-angle glaucoma is the most common form of...  
3  Anyone can develop glaucoma. Some people are a...  
4  At this time, we do not know how to prevent gl...  


##Approach followed:
A pre-trained sentence-transformer model is used to embed the existing Q&A pairs.

Each question is concatenated with its answer, and the resulting corpus is encoded into dense embeddings.

This approach suits a small dataset: it needs no training and uses far fewer resources than a traditional fine-tuning setup.

Because semantically similar sentences receive nearby embeddings, a paraphrased query can be matched to the stored Q&A pairs by cosine similarity.
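
To make the similarity step concrete, here is a minimal pure-Python sketch of cosine similarity on toy vectors. The three-dimensional vectors and the names `query_vec` and `corpus_vecs` are made up for illustration; the real embeddings are produced by the model and have 384 dimensions.

```python
import math

def cosine_similarity(a, b):
    """Cosine of the angle between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)

# Hypothetical toy "embeddings", not real model output.
query_vec = [1.0, 2.0, 0.0]
corpus_vecs = {
    "paraphrase": [2.0, 4.0, 0.0],  # same direction as the query
    "unrelated": [0.0, 0.0, 3.0],   # orthogonal to the query
}

scores = {name: cosine_similarity(query_vec, vec) for name, vec in corpus_vecs.items()}
best_match = max(scores, key=scores.get)
print(scores, best_match)
```

The vector pointing in the same direction as the query scores close to 1 regardless of its length, while the orthogonal vector scores 0, which is why the measure works for ranking paraphrases.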

In [7]:
model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')

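# Concatenate each question with its answer and pass a plain list of strings to encode()
corpus = (df["question"] + " " + df["answer"]).tolist()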
corpus_embeddings = model.encode(corpus, convert_to_tensor=True)

print(f"Corpus embeddings shape: {corpus_embeddings.shape}")

Corpus embeddings shape: torch.Size([16401, 384])


In [8]:
torch.save(corpus_embeddings, "corpus_embeddings.pt")
print("Embeddings saved to disk.")

Embeddings saved to disk.


Retrieval Code

In [11]:
def retrieve(query, top_k=1):
    query_emb = model.encode(query, convert_to_tensor=True)
    cos_scores = util.cos_sim(query_emb, corpus_embeddings)[0]
    top_results = torch.topk(cos_scores, k=top_k)

    results = []
    for score, idx in zip(top_results.values, top_results.indices):
        idx = idx.item()
        results.append({
            "question": df.iloc[idx]["question"],
            "answer": df.iloc[idx]["answer"],
            "score": score.item(),
            "idx": idx
        })
    return results

In [12]:
# Sample Queries
queries = [
    "What is glaucoma?",
    "How to prevent glaucoma?",
    "Who is at risk for glaucoma?"
]

for q in queries:
    results = retrieve(q)
    for r in results:
        print(f"\nQ: {r['question']}\nA: {r['answer']}\nScore: {r['score']:.4f}")


Q: What is (are) Glaucoma ?
A: Open-angle glaucoma is the most common form of glaucoma. In the normal eye, the clear fluid leaves the anterior chamber at the open angle where the cornea and iris meet. When the fluid reaches the angle, it flows through a spongy meshwork, like a drain, and leaves the eye. Sometimes, when the fluid reaches the angle, it passes too slowly through the meshwork drain, causing the pressure inside the eye to build. If the pressure damages the optic nerve, open-angle glaucoma -- and vision loss -- may result.
Score: 0.7356

Q: How to prevent Glaucoma ?
A: At this time, we do not know how to prevent glaucoma. However, studies have shown that the early detection and treatment of glaucoma, before it causes major vision loss, is the best way to control the disease. So, if you fall into one of the higher risk groups for the disease, make sure to have a comprehensive dilated eye exam at least once every one to two years.  Get tips on finding an eye care professional

##Evaluation pipeline

In [13]:
def evaluate_retrieval(eval_subset, top_k=3):
    mrr_total = 0
    hit_count = 0
    N = len(eval_subset)

    for i, row in eval_subset.iterrows():
        query = row["question"]
        true_answer = row["answer"]
        results = retrieve(query, top_k=top_k)
        ranks = [res["answer"] for res in results]

        # Exact-match check: a hit requires the row's own answer among the top_k results
        if true_answer in ranks:
            hit_count += 1
            rank_position = ranks.index(true_answer) + 1
            mrr_total += 1 / rank_position  # a miss adds nothing to the total

    hit_at_k = hit_count / N
    mrr_score = mrr_total / N
    return mrr_score, hit_at_k
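
As a worked illustration of the two metrics, the sketch below computes MRR and the hit rate from a list of ranks. The helper `mrr_and_hit_rate` and the toy ranks are hypothetical and not part of the notebook. Note that the dataset contains the same question with several different answers (rows 0 to 2 in the sample above), so the correct answer can rank below 1, or be missed, even though every evaluation query comes from the corpus.

```python
def mrr_and_hit_rate(ranks):
    """ranks: 1-based rank of the correct answer per query, or None if it was not retrieved."""
    reciprocal_ranks = [1 / r if r is not None else 0.0 for r in ranks]
    hits = [r is not None for r in ranks]
    n = len(ranks)
    return sum(reciprocal_ranks) / n, sum(hits) / n

# Hypothetical outcomes for four queries with top_k=3:
# found at rank 1, found at rank 2, not found, found at rank 1.
toy_ranks = [1, 2, None, 1]
toy_mrr, toy_hit_at_3 = mrr_and_hit_rate(toy_ranks)
print(f"MRR: {toy_mrr:.4f}, Hit@3: {toy_hit_at_3:.4f}")
```

For these toy ranks the MRR is (1 + 0.5 + 0 + 1) / 4 = 0.625 and the hit rate is 3 / 4 = 0.75.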

##Creating a subset of data to evaluate

In [14]:
from sklearn.model_selection import train_test_split
eval_df, _ = train_test_split(df, test_size=0.9, random_state=42)
eval_df.reset_index(drop=True, inplace=True)

In [19]:
mrr, hit_at_3 = evaluate_retrieval(eval_df, top_k=3)
print(f"Retrieval MRR: {mrr:.4f}, Top-3 Accuracy: {hit_at_3:.4f}")

Retrieval MRR: 0.7514, Top-3 Accuracy: 0.8762


## Generation Phase

Text-generation pipeline for the Llama 3.2 1B model

In [22]:
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("text-generation", model="meta-llama/Llama-3.2-1B")

Device set to use cpu


Condensing the retrieved answers into a short context (first few sentences of each)

In [28]:
import re
def summarize_context(retrieved, max_sentences=3):
    context_list = []
    for r in retrieved:
        sentences = re.split(r'(?<=[.!?]) +', r['answer'])
        context_list.append(" ".join(sentences[:max_sentences]))
    return " ".join(context_list)
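
The sketch below shows how that split pattern behaves on a sample string. The helper name `first_sentences` and the sample texts are made up for illustration; only the regular expression is taken from the cell above.

```python
import re

def first_sentences(text: str, max_sentences: int = 3) -> str:
    """Keep the first max_sentences sentences, using the same split pattern as summarize_context."""
    sentences = re.split(r'(?<=[.!?]) +', text)
    return " ".join(sentences[:max_sentences])

sample = "Glaucoma damages the optic nerve. It can cause blindness! Is it treatable? Early treatment helps."
shortened = first_sentences(sample)
print(shortened)

# Limitation: a period inside an abbreviation is treated as a sentence end.
abbreviation_split = re.split(r'(?<=[.!?]) +', "See Dr. Smith today.")
print(abbreviation_split)
```

The first call keeps three of the four sentences. The second shows that "Dr." is split off as its own sentence, so medical text with many abbreviations may be cut earlier than intended.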

Building a prompt from the query and the retrieved context, then generating an answer

In [34]:
def generate_answer(query, top_k=3):
    retrieved = retrieve(query, top_k=top_k)
    context_text = summarize_context(retrieved)
    prompt = f"""Question: {query}
Context: {context_text}
Please provide a concise, factual answer based only on the context."""
    # Nucleus sampling; by default generated_text contains the prompt followed by the new tokens
    generated = pipe(prompt, max_new_tokens=250, do_sample=True, top_p=0.95, pad_token_id=pipe.tokenizer.eos_token_id)
    return generated[0]['generated_text']


The embedding model retrieves the top 3 answers, which are passed to the Llama model as context so that its answer is grounded in the dataset.
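
By default the text-generation pipeline returns the prompt followed by the continuation, so the question and context are repeated in every answer. A minimal sketch of removing that echo is shown below; `strip_prompt` and the toy strings are hypothetical and not part of the notebook.

```python
def strip_prompt(full_text: str, prompt: str) -> str:
    """Return only the continuation if full_text starts with the prompt."""
    if full_text.startswith(prompt):
        return full_text[len(prompt):].strip()
    return full_text.strip()

# Hypothetical strings standing in for a real prompt and pipeline output.
toy_prompt = (
    "Question: What is glaucoma?\n"
    "Context: Glaucoma is a group of diseases that can damage the optic nerve.\n"
    "Please provide a concise, factual answer based only on the context."
)
toy_output = toy_prompt + " Glaucoma is a group of eye diseases."
answer_only = strip_prompt(toy_output, toy_prompt)
print(answer_only)
```

The pipeline call also accepts `return_full_text=False`, which should have the same effect without post-processing.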


In [30]:
example_queries = [
    "What is glaucoma?",
    "How can glaucoma be prevented?",
    "Who is at risk for glaucoma?"
]
for q in example_queries:
    print(f"\nUser Query: {q}")
    generated_answer = generate_answer(q, top_k=3)
    print(f"Generated Answer:\n{generated_answer}\n")



User Query: What is glaucoma?


Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Generated Answer:
Question: What is glaucoma?
Context: Open-angle glaucoma is the most common form of glaucoma. In the normal eye, the clear fluid leaves the anterior chamber at the open angle where the cornea and iris meet. When the fluid reaches the angle, it flows through a spongy meshwork, like a drain, and leaves the eye. Glaucoma is a group of diseases that can damage the eye's optic nerve and result in vision loss and blindness. The most common form of the disease is open-angle glaucoma. With early treatment, you can often protect your eyes against serious vision loss. The optic nerve is a bundle of more than 1 million nerve fibers. It connects the retina to the brain.
Please provide a concise, factual answer based only on the context. The term "glaucoma" refers to a group of diseases that can damage the eye's optic nerve and result in vision loss and blindness. The most common form of the disease is open-angle glaucoma. With early treatment, you can often protect your eyes agai

Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Generated Answer:
Question: How can glaucoma be prevented?
Context: At this time, we do not know how to prevent glaucoma. However, studies have shown that the early detection and treatment of glaucoma, before it causes major vision loss, is the best way to control the disease. So, if you fall into one of the higher risk groups for the disease, make sure to have a comprehensive dilated eye exam at least once every one to two years. Glaucoma is a group of diseases that can damage the eye's optic nerve. It is a leading cause of blindness in the United States. It usually happens when the fluid pressure inside the eyes slowly rises, damaging the optic nerve. Glaucoma is a group of diseases that can damage the eye's optic nerve and result in vision loss and blindness. The most common form of the disease is open-angle glaucoma. With early treatment, you can often protect your eyes against serious vision loss.
Please provide a concise, factual answer based only on the context. In other words, 

The example outputs are reasonable, but each one echoes the prompt and the retrieved context before the answer, and the 1B model tends to repeat context sentences or stop mid-thought. A larger model would likely give better results. For this use case and the available resources (CPU only), these results are in line with expectations.

In [35]:
!pip install rouge-score --quiet

from rouge_score import rouge_scorer

scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)

def evaluate_generation(eval_subset, top_k=3):
    rouge1_total, rouge2_total, rougeL_total = 0, 0, 0
    N = len(eval_subset)

    for _, row in eval_subset.iterrows():
        query = row["question"]
        true_answer = row["answer"]
        gen_answer = generate_answer(query, top_k=top_k)

        scores = scorer.score(true_answer, gen_answer)
        rouge1_total += scores['rouge1'].fmeasure
        rouge2_total += scores['rouge2'].fmeasure
        rougeL_total += scores['rougeL'].fmeasure

    return {
        "ROUGE-1": rouge1_total / N,
        "ROUGE-2": rouge2_total / N,
        "ROUGE-L": rougeL_total / N
    }

gen_metrics = evaluate_generation(eval_df.head(10), top_k=3)
print(gen_metrics)

{'ROUGE-1': 0.4589396216214527, 'ROUGE-2': 0.30940607354510846, 'ROUGE-L': 0.3470947412186616}


Generation Performance Explanation

In our pipeline, the generation component uses a retrieval-augmented generation (RAG) setup:
1. Top-K retrieval with sentence-transformer embeddings (all-MiniLM-L6-v2) identifies the most relevant answers from the medical Q&A corpus.
2. A text-generation model (Llama 3.2 1B) produces an answer based on the retrieved context.

Evaluation Metrics (ROUGE Scores on 10-sample subset):

ROUGE-1: 0.459

ROUGE-2: 0.309

ROUGE-L: 0.347

Interpretation:

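These scores indicate moderate overlap between the generated and reference answers.

Because generate_answer returns the prompt together with the continuation, the scored text includes the retrieved context, which usually contains the opening sentences of the reference answer. Part of the overlap therefore comes from retrieval rather than generation.

Some repetition and truncation occur due to the long context passages and the small model size.

Despite the moderate ROUGE scores, the generated answers in the examples are mostly relevant and readable, and they stay close to the retrieved context. Ten samples are too few for firm conclusions.

ROUGE penalizes paraphrasing, so factually correct answers that reword the reference can score lower.

Improvements could include using a larger model to make the generation more accurate. The retrieval phase already gives promising results (MRR 0.7514, top-3 accuracy 0.8762). The 1B model was chosen to run efficiently on low-resource setups.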
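
To illustrate why paraphrasing lowers ROUGE, here is a simplified ROUGE-1 F-measure in pure Python. It is only a sketch: it lowercases and splits on whitespace, with no stemming, so its values will not match the `rouge_score` package exactly. The function `rouge1_f` and the example sentences are made up for illustration.

```python
from collections import Counter

def rouge1_f(reference: str, candidate: str) -> float:
    """Unigram-overlap F-measure on lowercased whitespace tokens (no stemming)."""
    ref_counts = Counter(reference.lower().split())
    cand_counts = Counter(candidate.lower().split())
    # Counter intersection keeps the minimum count of each shared token.
    overlap = sum((ref_counts & cand_counts).values())
    if overlap == 0:
        return 0.0
    precision = overlap / sum(cand_counts.values())
    recall = overlap / sum(ref_counts.values())
    return 2 * precision * recall / (precision + recall)

reference = "glaucoma damages the optic nerve"
verbatim = "glaucoma damages the optic nerve"
paraphrase = "glaucoma harms the nerve of the eye"

verbatim_score = rouge1_f(reference, verbatim)      # identical wording
paraphrase_score = rouge1_f(reference, paraphrase)  # same idea, different words
print(verbatim_score, paraphrase_score)
```

The verbatim copy scores 1.0, while the paraphrase shares only three of its seven tokens with the reference (precision 3/7, recall 3/5) and scores about 0.5, even though it says the same thing.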

Fine-tuning the model on the Q&A dataset is another option, but it requires more compute than this RAG pipeline, which needs no training. For better performance, we could try a larger or more specialized model as the generator, which should improve accuracy while still avoiding the cost of full fine-tuning.
