# Medical Chatbot with RAG, TinyLLaMA, and PDF Knowledge Base

In [None]:
# # Install dependencies (minimal set). To run in a notebook, uncomment the line below and prefix it with `!`.
# pip install -q pypdf2 langchain sentence-transformers faiss-cpu transformers torch accelerate bitsandbytes langchain-community

In [None]:
# Import libraries
import os
import PyPDF2
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain.llms import HuggingFacePipeline
from langchain.chains import RetrievalQA
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, BitsAndBytesConfig
import torch
# from google.colab import files
from typing import List, Dict

In [None]:
# Select the compute device once; the embedding model in build_vector_store reads `device`.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

In [None]:
import os
from typing import List
import PyPDF2

def extract_text_from_pdfs(pdf_directory: str) -> List[str]:
    """Extract the text of every readable PDF file in ``pdf_directory``.

    Entries are visited in sorted name order so the resulting corpus (and any
    index built from it) is the same on every run. The ``.pdf`` extension is
    matched case-insensitively, and entries that are not regular files (for
    example a sub-directory called ``archive.pdf``) are skipped.

    Args:
        pdf_directory: Directory to scan (not recursive).

    Returns:
        One string per PDF that produced non-blank text. The pages of a PDF
        are joined with single spaces.

    Raises:
        OSError: if ``pdf_directory`` itself cannot be listed. Problems with an
            individual file are printed and that file is skipped.
    """
    documents = []
    for filename in sorted(os.listdir(pdf_directory)):
        path = os.path.join(pdf_directory, filename)
        if not filename.lower().endswith(".pdf") or not os.path.isfile(path):
            continue
        try:
            # open() lives inside the try so one unreadable file is reported and
            # skipped instead of aborting the whole scan.
            with open(path, "rb") as f:
                reader = PyPDF2.PdfReader(f)
                # extract_text() can yield None; treat that page as empty.
                text = " ".join(page.extract_text() or "" for page in reader.pages)
        except Exception as e:  # PyPDF2 and the OS raise many unrelated error types
            print(f"[!] Error reading {filename}: {e}")
            continue
        if text.strip():
            documents.append(text)
    return documents


In [None]:
# Process PDFs: one raw-text string per PDF in PDF_DIRECTORY that has extractable text.
PDF_DIRECTORY = "pdfs"
pdf_texts = extract_text_from_pdfs(PDF_DIRECTORY)
print(f"Processed {len(pdf_texts)} PDFs with text content.")
if not pdf_texts:
    # Stop here with an actionable message rather than letting an empty corpus
    # surface later as a generic vector-store failure.
    raise RuntimeError(f"No extractable text found in PDFs under '{PDF_DIRECTORY}'.")

In [None]:
# Create vector database
def build_vector_store(texts: List[str]) -> FAISS:
    """Chunk ``texts``, embed the chunks and index them in a FAISS store.

    Args:
        texts: Raw document texts (one string per source document).

    Returns:
        The in-memory FAISS vector store, or ``None`` (after printing the
        reason) if there is nothing to index or the index cannot be built.
        Failing to persist the index to ``faiss_index/`` only prints a
        warning; the usable in-memory store is still returned.
    """
    # ~800-character chunks; adjacent chunks share up to 150 characters so
    # text cut at a boundary keeps some surrounding context.
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=800,
        chunk_overlap=150,
        length_function=len,
        separators=["\n\n", "\n", ". ", " ", ""],
    )
    chunks = text_splitter.create_documents(texts)
    if not chunks:
        # Empty or whitespace-only input yields no chunks; say so explicitly
        # instead of letting FAISS fail with an opaque error.
        print("Error building vector store: no text chunks to index.")
        return None

    # Lightweight sentence-transformers embedding model. The device is passed
    # as a plain string ("cuda"/"cpu").
    embeddings = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L6-v2",
        model_kwargs={"device": str(device)},
    )

    try:
        vector_store = FAISS.from_documents(chunks, embeddings)
    except Exception as e:
        print(f"Error building vector store: {e}")
        return None

    # Persisting is optional: a failed save must not discard a usable index.
    try:
        vector_store.save_local("faiss_index")
    except Exception as e:
        print(f"Warning: could not save FAISS index to 'faiss_index': {e}")
    return vector_store

In [None]:
vector_store = build_vector_store(pdf_texts)
# build_vector_store reports failure by returning None; halt the notebook here.
if not vector_store:
    raise RuntimeError("Failed to create vector database.")
print("Vector database created successfully.")

In [None]:
# Set up the TinyLlama chat model as a LangChain LLM (no quantization)
def initialize_llm(model_name: str = "TinyLlama/TinyLlama-1.1B-Chat-v1.0") -> HuggingFacePipeline:
    """Load a causal-LM checkpoint and wrap it in a LangChain ``HuggingFacePipeline``.

    Args:
        model_name: Hugging Face Hub id of the chat model to load. Defaults to
            TinyLlama 1.1B Chat.

    Returns:
        A ``HuggingFacePipeline`` for use as the RAG generator, or ``None``
        (after printing the error) if the tokenizer, model or pipeline could
        not be created.
    """
    # float16 saves GPU memory, but half precision is slow or unsupported for
    # many CPU kernels, so use float32 when no CUDA device is available.
    dtype = torch.float16 if torch.cuda.is_available() else torch.float32

    try:
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        # To quantize on a CUDA GPU, add
        # quantization_config=BitsAndBytesConfig(load_in_8bit=True)
        # (the key is `load_in_8bit`, not `load_in_int8`).
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=dtype,
            device_map="auto",  # let accelerate place the weights
            low_cpu_mem_usage=True,
        )

        pipe = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            max_new_tokens=150,
            # Without do_sample=True, transformers decodes greedily and ignores
            # temperature/top_p.
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            repetition_penalty=1.1,
            return_full_text=False,  # return only the generated answer, not the prompt
        )
        return HuggingFacePipeline(pipeline=pipe)
    except Exception as e:
        print(f"Error initializing LLM: {e}")
        return None

In [None]:
llm = initialize_llm()
# initialize_llm returns None when loading fails; stop before building the chain.
if not llm:
    raise RuntimeError("Failed to initialize LLM.")
print("TinyLLaMA initialized successfully.")

In [None]:
# Note for fine-tuning:
# - Use LoRA with `peft` library for efficient fine-tuning on medical Q&A datasets.
# - Example: peft_config = LoraConfig(r=16, lora_alpha=32, target_modules=["q_proj", "v_proj"])
# - Train on datasets like PubMedQA or MedQA for better medical accuracy.


In [None]:
# Configure RAG pipeline
def setup_rag(vector_store, llm) -> RetrievalQA:
    """Combine ``vector_store`` and ``llm`` into a "stuff"-type RetrievalQA chain.

    Retrieval uses MMR (maximal marginal relevance): 5 candidate chunks are
    fetched and 2 are kept for the prompt. Returns the chain, or ``None``
    (after printing the error) if any step of the construction fails.
    """
    try:
        retriever = vector_store.as_retriever(
            search_type="mmr",
            search_kwargs={"k": 2, "fetch_k": 5},
        )
        # "prompt": None supplies no custom prompt, so LangChain presumably falls
        # back to its default QA prompt (confirm for the installed version).
        # Pass a PromptTemplate here to customise it.
        chain = RetrievalQA.from_chain_type(
            llm=llm,
            chain_type="stuff",
            retriever=retriever,
            return_source_documents=False,
            chain_type_kwargs={"prompt": None},
        )
    except Exception as err:
        print(f"Error setting up RAG: {err}")
        return None
    return chain

In [None]:
qa_chain = setup_rag(vector_store, llm)
# setup_rag returns None on failure; there is nothing to chat with without a chain.
if not qa_chain:
    raise RuntimeError("Failed to set up RAG.")
print("RAG pipeline configured.")

In [None]:
import textwrap

def medical_chatbot(qa_chain):
    """Run an interactive console Q&A loop on top of ``qa_chain``.

    Each non-empty line typed by the user is sent to ``qa_chain`` as
    ``{"query": ...}``, and the ``"result"`` of the returned mapping is printed.
    Each line of the answer is wrapped to 100 columns. The loop ends when the
    user types ``exit`` (any case) or when standard input reaches EOF.

    Args:
        qa_chain: Callable taking ``{"query": str}`` and returning a mapping
            with a ``"result"`` string (e.g. a LangChain ``RetrievalQA`` chain).
    """
    print("\n🩺 HealthScribe is ready! Type 'exit' to quit.")
    print("💡 Example questions: 'What are the symptoms of diabetes?' or 'How is hypertension treated?'")

    while True:
        try:
            query = input("\nYou: ").strip()

            # Handle empty input
            if not query:
                print("⚠️  Please enter a valid question.")
                continue

            # Exit condition
            if query.lower() == "exit":
                print("👋 Goodbye! Stay healthy.")
                break

            # Get the answer from the QA chain; tolerate a missing or None result.
            result = qa_chain({"query": query})
            answer = (result.get("result") or "").strip()

            if not answer:
                print("🤖 HealthScribe: Sorry, I couldn't find relevant information. Try rephrasing your question.")
            else:
                # Wrap each line separately. textwrap.fill on the whole answer
                # would merge the model's line breaks (lists, paragraphs).
                wrapped_answer = "\n".join(
                    textwrap.fill(line, width=100) for line in answer.splitlines()
                )
                print(f"🤖 HealthScribe: {wrapped_answer}")

        except EOFError:
            # stdin is closed, so no more input can arrive. EOFError is an
            # Exception, so without this branch the generic handler below would
            # catch it on every iteration and loop forever.
            print("\n👋 Goodbye! Stay healthy.")
            break
        except KeyboardInterrupt:
            print("\n⛔ Interrupted. Type 'exit' to quit.")
        except Exception as e:
            print(f"❌ Error processing query: {e}")
            print("🔄 Please try again.")


In [None]:
# Launch the interactive console loop using the RetrievalQA chain configured above.
medical_chatbot(qa_chain=qa_chain)