<a href="https://colab.research.google.com/github/hemhalatha/ML_projects/blob/main/medbot.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
!pip install openai==0.28.0 gradio numpy faiss-cpu sentence-transformers

Collecting openai==0.28.0
  Downloading openai-0.28.0-py3-none-any.whl.metadata (13 kB)
Collecting faiss-cpu
  Downloading faiss_cpu-1.13.0-cp39-abi3-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (7.7 kB)
Downloading openai-0.28.0-py3-none-any.whl (76 kB)
Downloading faiss_cpu-1.13.0-cp39-abi3-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (23.6 MB)
Installing collected packages: faiss-cpu, openai
  Attempting uninstall: openai
    Found existing installation: openai 1.109.1
    Uninstalling openai-1.109.1:
      Successfully uninstalled openai-1.109.1
Successfully installed faiss-cpu-1.13.0 openai-0.28.0


In [3]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [4]:
!pip install google-generativeai gradio sentence-transformers faiss-cpu tqdm beautifulsoup4 requests



In [32]:
import faiss
import json
import numpy as np
from sentence_transformers import SentenceTransformer
import google.generativeai as genai
import os

# ============================================================================
# COMPONENT 1: IMPROVED FAISS Retriever
# ============================================================================

SAVE_DIR = "/content/drive/MyDrive/medical_rag_db"

# Load index
index = faiss.read_index(os.path.join(SAVE_DIR, "faiss_index.bin"))

# Load metadata
with open(os.path.join(SAVE_DIR, "metadata.json"), "r") as f:
    metadata = json.load(f)

# Load the embedder
embed_model = SentenceTransformer("all-MiniLM-L6-v2")

def retrieve_with_rerank(query, top_k=5, initial_k=20):
    """
    Improved retrieval with two-stage approach:
    1. Retrieve more candidates (initial_k)
    2. Re-rank using query-document similarity

    Args:
        query: search query string
        top_k: final number of results to return
        initial_k: number of candidates to retrieve initially

    Returns:
        list of dicts with id, sentence, and score
    """
    # Stage 1: Retrieve more candidates
    query_embedding = embed_model.encode([query], convert_to_numpy=True)
    distances, indices = index.search(query_embedding, initial_k)

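    # Stage 2: Convert distances to scores and re-sort.
    # Note: no cross-encoder is used here; the score is a monotonic
    # transform of the distance value returned by FAISS.
    candidates = []
    for i, idx in enumerate(indices[0]):
        # FAISS pads missing results with -1, which would otherwise index metadata[-1]
        if 0 <= idx < len(metadata):
            chunk_text = metadata[idx].get("chunk", metadata[idx].get("sentence", ""))
            # Map distance d to 1 / (1 + d); this is not cosine similarity
            similarity = float(1 / (1 + distances[0][i]))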

            candidates.append({
                "id": metadata[idx].get("id", idx),
                "sentence": chunk_text,
                "score": similarity,
                "distance": float(distances[0][i])
            })

    # Sort by score and return top_k
    candidates.sort(key=lambda x: x["score"], reverse=True)
    return candidates[:top_k]

def retrieve_with_query_expansion(query, top_k=5):
    """
    Query expansion: generate related queries and aggregate results

    Args:
        query: original search query
        top_k: number of results to return

    Returns:
        list of dicts with id, sentence, and score
    """
    # Generate query variations
    query_variations = [
        query,
        f"What is {query}",
        f"How to treat {query}",
        f"Symptoms of {query}"
    ]

    all_results = {}

    for q_var in query_variations:
        query_embedding = embed_model.encode([q_var], convert_to_numpy=True)
        distances, indices = index.search(query_embedding, top_k * 2)

        for i, idx in enumerate(indices[0]):
            if 0 <= idx < len(metadata):  # skip the -1 padding FAISS uses for missing results
                chunk_id = metadata[idx].get("id", idx)
                similarity = float(1 / (1 + distances[0][i]))

                # Aggregate scores for duplicate chunks
                if chunk_id in all_results:
                    all_results[chunk_id]["score"] = max(all_results[chunk_id]["score"], similarity)
                else:
                    all_results[chunk_id] = {
                        "id": chunk_id,
                        "sentence": metadata[idx].get("chunk", metadata[idx].get("sentence", "")),
                        "score": similarity
                    }

    # Sort and return top_k
    results = sorted(all_results.values(), key=lambda x: x["score"], reverse=True)
    return results[:top_k]

def retrieve_with_filters(query, top_k=5, min_score=0.3, max_length=500):
    """
    Retrieval with filtering for quality control

    Args:
        query: search query string
        top_k: number of results to return
        min_score: minimum similarity score threshold
        max_length: maximum chunk length to consider

    Returns:
        list of dicts with id, sentence, and score
    """
    query_embedding = embed_model.encode([query], convert_to_numpy=True)
    # Retrieve more candidates to account for filtering
    distances, indices = index.search(query_embedding, top_k * 3)

    results = []
    for i, idx in enumerate(indices[0]):
        if 0 <= idx < len(metadata):  # skip the -1 padding FAISS uses for missing results
            chunk_text = metadata[idx].get("chunk", metadata[idx].get("sentence", ""))
            similarity = float(1 / (1 + distances[0][i]))

            # Apply filters
            if (similarity >= min_score and
                len(chunk_text) <= max_length and
                len(chunk_text) > 20):  # Not too short either

                results.append({
                    "id": metadata[idx].get("id", idx),
                    "sentence": chunk_text,
                    "score": similarity
                })

            if len(results) >= top_k:
                break

    return results

def retrieve_hybrid(query, top_k=5, use_rerank=True, use_expansion=False, use_filters=True):
    """
    Hybrid retrieval combining multiple techniques

    Args:
        query: search query string
        top_k: number of results to return
        use_rerank: enable two-stage retrieval
        use_expansion: enable query expansion
        use_filters: enable quality filters

    Returns:
        list of dicts with id, sentence, and score
    """
    if use_expansion:
        results = retrieve_with_query_expansion(query, top_k)
    elif use_rerank:
        results = retrieve_with_rerank(query, top_k, initial_k=top_k * 4)
    else:
        # Basic retrieval
        query_embedding = embed_model.encode([query], convert_to_numpy=True)
        distances, indices = index.search(query_embedding, top_k)

        results = []
        for i, idx in enumerate(indices[0]):
            if 0 <= idx < len(metadata):  # skip the -1 padding FAISS uses for missing results
                results.append({
                    "id": metadata[idx].get("id", idx),
                    "sentence": metadata[idx].get("chunk", metadata[idx].get("sentence", "")),
                    "score": float(1 / (1 + distances[0][i]))
                })

    # Apply filters if enabled
    if use_filters:
        results = [r for r in results if r["score"] >= 0.3 and 20 < len(r["sentence"]) <= 500]

    return results[:top_k]

# ============================================================================
# COMPONENT 2: Prompt Builder
# ============================================================================

def build_prompt(query, retrieved):
    """
    Build a RAG prompt using retrieved chunks only.
    """
    context_parts = []
    for i, r in enumerate(retrieved, start=1):
        snippet = r.get("chunk", r.get("sentence", ""))[:1800]
        context_parts.append(f"[chunk_{i}] {snippet}")

    context_text = "\n\n".join(context_parts)

    user_prompt = f"""
You are a medical information assistant. Use ONLY the provided context paragraphs to answer user questions.
- Do NOT recommend or prescribe medications.
- If asked for urgent/emergency matters, tell the user to seek immediate medical attention.
- If the context is insufficient, recommend seeing a healthcare professional.
- Provide short, clear, educational answers.
- Keep responses friendly and concise (2-3 short paragraphs).
Answer the user question using the context below. Do NOT hallucinate. Use simple terms so the answer is easy to understand.

CONTEXT:
{context_text}

QUESTION:
{query}

Answer format:
Short factual answer
"""
    return user_prompt

# ============================================================================
# COMPONENT 3: Gemini Answer Generator
# ============================================================================

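# Read the API key from the environment instead of hardcoding it in the notebook.
# The variable name GOOGLE_API_KEY is an assumption; set it before running this cell.
genai.configure(api_key=os.environ["GOOGLE_API_KEY"])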
model = genai.GenerativeModel("models/gemini-2.5-flash")
chat = model.start_chat()

def generate_answer_gemini(query, k=5, retrieval_method="hybrid"):
    """
    Generate answer using Gemini with improved retrieval

    Args:
        query: User's question
        k: Number of sentences to retrieve
        retrieval_method: "basic", "rerank", "expansion", "hybrid", or "filtered"
    """
    # Select retrieval method
    if retrieval_method == "rerank":
        retrieved = retrieve_with_rerank(query, top_k=k)
    elif retrieval_method == "expansion":
        retrieved = retrieve_with_query_expansion(query, top_k=k)
    elif retrieval_method == "filtered":
        retrieved = retrieve_with_filters(query, top_k=k)
    elif retrieval_method == "hybrid":
        retrieved = retrieve_hybrid(query, top_k=k)
    else:  # basic
        query_embedding = embed_model.encode([query], convert_to_numpy=True)
        distances, indices = index.search(query_embedding, k)
        retrieved = []
        for i, idx in enumerate(indices[0]):
            if 0 <= idx < len(metadata):  # skip the -1 padding FAISS uses for missing results
                retrieved.append({
                    "id": metadata[idx].get("id", idx),
                    "sentence": metadata[idx].get("chunk", metadata[idx].get("sentence", "")),
                    "score": float(1 / (1 + distances[0][i]))
                })

    # Convert to chunk format for prompt
    retrieved_chunks = []
    for r in retrieved:
        retrieved_chunks.append({
            "chunk": r["sentence"],
            "id": r["id"],
            "score": r["score"]
        })

    # Build prompt and get response
    user_prompt = build_prompt(query, retrieved_chunks)
    response = chat.send_message(user_prompt)

    return {"answer": response.text}

# ============================================================================
# USAGE EXAMPLES
# ============================================================================

# Method 1: Re-ranking (the call that is actually run in this cell)
out = generate_answer_gemini("causes of Pneumonia?", k=5, retrieval_method="rerank")

# Method 2: Hybrid (re-ranking plus quality filters)
# out = generate_answer_gemini("How can I cure flu?", k=5, retrieval_method="hybrid")

# Method 3: With query expansion
# out = generate_answer_gemini("How can I cure flu?", k=5, retrieval_method="expansion")

# Method 4: With filters
# out = generate_answer_gemini("How can I cure flu?", k=5, retrieval_method="filtered")

# Method 5: Basic (original)
# out = generate_answer_gemini("How can I cure flu?", k=5, retrieval_method="basic")

print(out["answer"])


Pneumonia is caused by inflamed or swollen lung tissue due to an infection with a germ. Specifically, viral pneumonia is caused by a virus.
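
The retrieval functions in the cell above turn each FAISS distance `d` into a score `1 / (1 + d)` and, when filtering is on, keep only chunks with a score of at least 0.3 and a length between 21 and 500 characters. The following is a minimal pure-Python sketch of that scoring and filtering step. The helper names, distances, and chunk texts are made up for illustration, and it assumes a smaller distance means a closer match (as with an L2 index).

```python
def distance_to_score(distance):
    """Map a non-negative distance to a score in (0, 1]; smaller distance gives a higher score."""
    return 1.0 / (1.0 + distance)


def passes_filters(score, text, min_score=0.3, min_length=20, max_length=500):
    """Mirror the notebook's filter: a score threshold plus a chunk-length window."""
    return score >= min_score and min_length < len(text) <= max_length


# Hypothetical hits as (distance, chunk text); the values are illustrative only.
hits = [
    (0.5, "Pneumonia is an infection that inflames the air sacs in the lungs."),
    (1.0, "Too short."),
    (4.0, "Gout occurs when uric acid builds up in the blood and inflames a joint."),
]

scored = [(distance_to_score(d), text) for d, text in hits]
kept = [(round(s, 3), text) for s, text in scored if passes_filters(s, text)]

for score, text in kept:
    print(score, text)
```

Only the first hit survives: the second chunk is too short, and the third scores 0.2. Under this mapping, `min_score=0.3` is equivalent to accepting distances up to 7/3 (about 2.33), which may help when tuning the threshold for a particular index.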


In [33]:
out = generate_answer_gemini("I have diarrhea for 5 days and stomach pain. what is the problem i have?", k=3, retrieval_method="expansion")
print(out["answer"])

I cannot diagnose your specific condition based on your symptoms.

However, viral gastroenteritis is an infection of the stomach and intestine caused by a virus, which can lead to symptoms like diarrhea and vomiting.

For a proper diagnosis and treatment plan, it is best to see a healthcare professional.
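
With `retrieval_method="expansion"`, the query is searched under several phrasings, and a chunk found by more than one phrasing keeps its best score. The sketch below shows that merge step in isolation. `merge_by_max_score` and the hit lists are hypothetical stand-ins for the per-phrasing FAISS results, not part of the notebook.

```python
def merge_by_max_score(hits_per_variation, top_k):
    """Merge hit lists from several query phrasings, keeping each chunk's best score."""
    best = {}
    for hits in hits_per_variation:
        for chunk_id, score in hits:
            # Keep the highest score seen so far for this chunk
            if chunk_id not in best or score > best[chunk_id]:
                best[chunk_id] = score
    ranked = sorted(best.items(), key=lambda item: item[1], reverse=True)
    return ranked[:top_k]


# One list of (chunk_id, score) pairs per query phrasing; values are illustrative.
hits_per_variation = [
    [("c1", 0.62), ("c2", 0.55)],
    [("c2", 0.71), ("c3", 0.40)],
    [("c1", 0.58), ("c4", 0.35)],
]

top = merge_by_max_score(hits_per_variation, top_k=3)
print(top)
```

Taking the maximum rather than the sum means a chunk is not rewarded for matching several phrasings at once; it only needs to match one of them well.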


In [34]:
out = generate_answer_gemini("I have neck pain?", k=5, retrieval_method="expansion")
print(out["answer"])

Neck pain is discomfort in any of the structures in your neck. These structures include the muscles, nerves, bones (vertebrae), joints, and the discs between the bones.

Cervical spondylosis is a common disorder that can cause chronic neck pain. It involves wear on the cartilage (disks) and bones of the neck.

If you are experiencing neck pain, it is recommended to see a healthcare professional for a proper diagnosis.


In [35]:
out = generate_answer_gemini("I have joint pain?", k=5, retrieval_method="expansion")
print(out["answer"])

Joint pain can be a symptom of various conditions. One type of arthritis mentioned is gout, which occurs when uric acid builds up in the blood and causes inflammation in the joints.

Acute gout is a painful condition that often affects only one joint, while chronic gout involves repeated episodes of pain and inflammation, potentially affecting more than one joint.

If you are experiencing joint pain, it is recommended to see a healthcare professional for a proper diagnosis and treatment plan.


In [37]:
out = generate_answer_gemini("tell me some disease names?", k=30, retrieval_method="expansion")
print(out["answer"])

Here are some disease names:

*   Tay-Sachs disease
*   Klippel-Trenaunay syndrome (KTS)
*   Necrotizing vasculitis
*   Viral gastroenteritis
*   Osteomyelitis
*   Peutz-Jeghers syndrome (PJS)
*   Williams syndrome
*   Truncus arteriosus
*   Pneumonia
*   Acute mountain sickness
*   Cervical spondylosis
*   Gout
*   Aspergillosis
*   Facioscapulohumeral muscular dystrophy
*   Wolff-Parkinson-White (WPW) syndrome
*   Congenital heart disease (CHD)
*   Food allergy
*   Myocardial infarction (heart attack)
*   Pericarditis
*   Night terrors (sleep terrors)
*   Scoliosis
*   Radial nerve dysfunction
*   Peripartum cardiomyopathy
*   Obesity hypoventilation syndrome (OHS)
*   Hyperkalemia (high potassium level)
*   Abdominal aortic aneurysm (AAA)
*   Myocardial contusion
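
This call uses `k=30` with expansion, so each phrasing requests `top_k * 2 = 60` neighbours. When an index holds fewer vectors than requested, FAISS pads the returned indices with `-1`. A bounds check of the form `idx < len(metadata)` lets `-1` through, and Python then reads it as the last list element. The sketch below illustrates the difference with a made-up two-entry `metadata` list and a hand-written index array.

```python
# Hypothetical metadata and search result; -1 marks slots FAISS could not fill.
metadata = [
    {"id": 0, "sentence": "first chunk"},
    {"id": 1, "sentence": "last chunk"},
]
indices = [1, 0, -1, -1]

# Upper-bound check only: -1 passes and silently selects metadata[-1]
unguarded = [metadata[i]["sentence"] for i in indices if i < len(metadata)]

# Lower and upper bound: padded slots are skipped
guarded = [metadata[i]["sentence"] for i in indices if 0 <= i < len(metadata)]

print(unguarded)
print(guarded)
```

The unguarded list repeats the last chunk once per padded slot, which would put duplicate context into the prompt.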


In [38]:
out = generate_answer_gemini("I have vomiting", k=5, retrieval_method="expansion")
print(out["answer"])

Vomiting can be a symptom of various conditions. For example, viral gastroenteritis, which is an infection of the stomach and intestine caused by a virus, can lead to vomiting and diarrhea.

It is important to see a healthcare professional to understand the cause of your vomiting and to receive appropriate advice and care.
