# **Install Required Packages**

In [None]:
# Install the third-party packages for this notebook. Each `!` line is handed
# to the shell; `-q` keeps pip's output short.
# NOTE(review): no versions are pinned, so a fresh runtime may resolve different
# releases on each run.
print("Installing required packages...")

# LangChain and its integration packages
!pip install -q langchain langchain-community langchain-openai
# Sentence embeddings and vector stores
!pip install -q sentence-transformers faiss-cpu chromadb
# Biopython (PubMed access) plus HTTP / HTML parsing helpers
!pip install -q biopython requests beautifulsoup4 lxml
# Text-processing libraries and .env support
!pip install -q nltk spacy python-dotenv
# LLM provider SDKs
!pip install -q groq google-generativeai openai
# UI, plotting, data handling and progress bars
!pip install -q gradio plotly pandas numpy tqdm
# PDF reading and file-type detection
!pip install -q pypdf python-magic

# This message is printed unconditionally; pip's exit status is not checked.
print("‚úÖ All packages installed successfully")

# **Mount Google Drive and Setup Project**


In [None]:
from google.colab import drive
import os
from pathlib import Path

# Attach Google Drive so files written below are stored on Drive rather than
# on the Colab runtime's local disk.
drive.mount('/content/drive')

# Create the project root if needed and make it the working directory, so the
# relative paths used by later cells resolve inside the Drive folder.
project_path = '/content/drive/MyDrive/MedAssist_RAG'
Path(project_path).mkdir(parents=True, exist_ok=True)
os.chdir(project_path)

# Sub-directories (relative to the project root) used by the pipeline stages.
folders = [
    'data/raw',
    'data/processed',
    'data/embeddings',
    'src',
    'logs',
    'models',
]
for folder in folders:
    os.makedirs(folder, exist_ok=True)

print(f"‚úÖ Project initialized at: {os.getcwd()}")

# **Enter PubMed Email and Configure API Keys**


In [None]:
import os
from getpass import getpass

# Contact address for the PubMed (Entrez) client. It is stored in the
# environment so the data-collection cell can read it back with os.getenv().
# Surrounding whitespace from a paste is dropped.
PUBMED_EMAIL = input("üìß Enter email for PubMed API: ").strip()
os.environ['PUBMED_EMAIL'] = PUBMED_EMAIL

print("\nüîë Choose LLM Provider:")
print("1. Groq (Recommended - Free & Fast)")
print("2. Google Gemini")

# Keep asking until one of the listed options is entered. Previously any other
# answer (including "1 " with a stray space) skipped both branches and still
# reported success with no provider or key configured.
choice = input("\nEnter choice (1-2): ").strip()
while choice not in ("1", "2"):
    print("Invalid choice - please enter 1 or 2.")
    choice = input("\nEnter choice (1-2): ").strip()

# getpass() keeps the key out of the notebook output; the key and the selected
# provider are exported as environment variables for later cells.
# NOTE(review): the RAG pipeline cell in this notebook only builds a Groq
# client, so option 2 stores a key that the visible code never reads.
if choice == "1":
    GROQ_API_KEY = getpass("üîë Enter Groq API key: ").strip()
    os.environ['GROQ_API_KEY'] = GROQ_API_KEY
    os.environ['LLM_PROVIDER'] = "groq"
    print("‚úÖ Groq configured")
else:
    GEMINI_API_KEY = getpass("üîë Enter Gemini API key: ").strip()
    os.environ['GEMINI_API_KEY'] = GEMINI_API_KEY
    os.environ['LLM_PROVIDER'] = "gemini"
    print("‚úÖ Gemini configured")

print("\n‚úÖ Configuration complete")

# **Data Collection Configuration**


In [None]:
import os

# Collection budget: overall number of papers aimed for, and how many search
# hits to request for each individual topic.
TARGET_PAPERS = 20000
PAPERS_PER_TOPIC = 500

print("\n".join([
    f"üìä Collection Strategy:",
    f"   Target: {TARGET_PAPERS:,} papers",
    f"   Per topic: {PAPERS_PER_TOPIC}",
]))

# Environment variables hold strings only, so both numbers are stringified
# here and parsed back with int() by the collection cell.
os.environ.update({
    'TARGET_PAPERS': str(TARGET_PAPERS),
    'PAPERS_PER_TOPIC': str(PAPERS_PER_TOPIC),
})

# **Define Medical Domains**


In [None]:
# Search topics grouped by medical category. Each topic string is used as a
# PubMed query by the collection cell, and the category name is stored on every
# paper found through it.
# A few topics ("lung cancer", "COVID-19", "tuberculosis") are listed under two
# categories; they are counted twice in `total_topics` below.
MEDICAL_DOMAINS = {
    "cardiovascular": [
        "myocardial infarction",
        "heart failure",
        "atrial fibrillation",
        "coronary artery disease",
        "hypertension",
        "stroke",
        "angina",
    ],
    "endocrine": [
        "diabetes mellitus type 1",
        "diabetes mellitus type 2",
        "thyroid disorders",
        "metabolic syndrome",
        "obesity",
    ],
    "neurological": [
        "alzheimer disease",
        "parkinson disease",
        "multiple sclerosis",
        "epilepsy",
        "migraine",
        "dementia",
    ],
    "respiratory": [
        "asthma",
        "chronic obstructive pulmonary disease",
        "pneumonia",
        "tuberculosis",
        "lung cancer",
        "COVID-19",
    ],
    "gastrointestinal": [
        "inflammatory bowel disease",
        "crohn disease",
        "ulcerative colitis",
        "hepatitis",
        "cirrhosis",
        "pancreatitis",
    ],
    "oncology": [
        "lung cancer",
        "breast cancer",
        "colorectal cancer",
        "chemotherapy",
        "radiation therapy",
        "immunotherapy",
    ],
    "psychiatric": [
        "major depressive disorder",
        "anxiety disorders",
        "schizophrenia",
        "bipolar disorder",
        "ADHD",
        "autism spectrum disorder",
    ],
    "infectious": [
        "HIV AIDS",
        "COVID-19",
        "influenza",
        "tuberculosis",
        "hepatitis B",
        "hepatitis C",
        "sepsis",
    ],
    "medications": [
        "antibiotics",
        "antihypertensives",
        "antidiabetic agents",
        "statins",
        "anticoagulants",
        "beta blockers",
        "ACE inhibitors",
    ],
}

# Number of topic entries across all categories (repeated topics included).
total_topics = sum(map(len, MEDICAL_DOMAINS.values()))
print(f"üìö Medical domains configured: {len(MEDICAL_DOMAINS)} categories, {total_topics} topics")

# **Data Collection from PubMed**


In [None]:
from Bio import Entrez, Medline
import json
import time
from tqdm import tqdm
import pandas as pd
from pathlib import Path
import re

# Give the Entrez client the contact address stored by the configuration cell.
# os.getenv() yields None when PUBMED_EMAIL has not been set.
Entrez.email = os.getenv('PUBMED_EMAIL')

def search_pubmed(query, max_results=2000):
    """Search PubMed for `query` and return the matching PMIDs.

    The query is restricted to English-language records dated 2010-2024 and
    results are requested in relevance order.

    Args:
        query: Free-text topic to search for.
        max_results: Upper bound on the number of IDs requested (`retmax`).

    Returns:
        The "IdList" entry of the parsed search response, or an empty list
        when the request or the parsing fails. Failures are reported on
        stdout rather than hidden, so an empty result can be told apart from
        a broken request.
    """
    search_term = f"({query}) AND 2010:2024[PDAT] AND English[LA]"
    try:
        handle = Entrez.esearch(db="pubmed", term=search_term, retmax=max_results, sort="relevance")
        try:
            record = Entrez.read(handle)
        finally:
            # Release the connection even when the response cannot be parsed.
            handle.close()
        return record["IdList"]
    except Exception as exc:
        # Best effort: one failed topic must not abort the whole collection run.
        # `Exception` (not a bare except) lets Ctrl-C / kernel interrupt through.
        print(f"Warning: PubMed search failed for {query!r}: {exc}")
        return []

def fetch_papers(pmid_list, batch_size=100, delay=0.3):
    """Download MEDLINE records for the given PMIDs and keep those with a usable abstract.

    Records are requested in batches. A record is kept only when its abstract
    ("AB" field) is longer than 200 characters.

    Args:
        pmid_list: Sequence of PubMed IDs to fetch.
        batch_size: Number of IDs sent per request.
        delay: Seconds to pause after every batch, whether it succeeded or
            failed, to space out the requests.

    Returns:
        A list of dicts with the keys "pmid", "title", "abstract", "authors"
        (comma-separated), "journal", "publication_date" and "url". Batches
        that fail are reported on stdout and skipped; records parsed before
        the failure are kept.
    """
    papers = []
    for start in tqdm(range(0, len(pmid_list), batch_size), desc="Fetching", leave=False):
        batch = pmid_list[start:start + batch_size]
        try:
            handle = Entrez.efetch(db="pubmed", id=batch, rettype="medline", retmode="text")
            try:
                for record in Medline.parse(handle):
                    abstract = record.get("AB", "")
                    if abstract and len(abstract) > 200:
                        pmid = record.get("PMID", "")
                        papers.append({
                            "pmid": pmid,
                            "title": record.get("TI", ""),
                            "abstract": abstract,
                            "authors": ", ".join(record.get("AU", [])),
                            "journal": record.get("TA", ""),
                            "publication_date": record.get("DP", ""),
                            "url": f"https://pubmed.ncbi.nlm.nih.gov/{pmid}/"
                        })
            finally:
                # Close the connection even if parsing stops part-way through.
                handle.close()
        except Exception as exc:
            # Best effort: skip the batch, but say so instead of dropping it silently.
            print(f"Warning: failed to fetch {len(batch)} records starting at offset {start}: {exc}")
        finally:
            # Pause after failures too, so retries of later batches stay spaced out.
            time.sleep(delay)
    return papers

print("üîç Starting data collection...")

all_papers = []
target = int(os.getenv('TARGET_PAPERS', '20000'))
per_topic = int(os.getenv('PAPERS_PER_TOPIC', '500'))

for category, topics in MEDICAL_DOMAINS.items():
    # The overall cap is checked once per category, before its topics start,
    # so the final count can exceed `target`.
    if len(all_papers) >= target:
        break
    print(f"\nüìñ {category.upper()}")
    for topic in tqdm(topics):
        papers = fetch_papers(search_pubmed(topic, per_topic))
        # Remember which search produced each paper.
        for p in papers:
            p.update(category=category, topic=topic)
        all_papers += papers
        time.sleep(0.5)
    category_total = sum(1 for p in all_papers if p['category'] == category)
    print(f"‚úÖ {category_total} papers")

# Drop repeats: the first occurrence of each PMID wins, and records without a
# PMID are discarded.
seen = set()
unique_papers = []
for paper in tqdm(all_papers, desc="Deduplicating"):
    pmid = paper.get("pmid")
    if not pmid or pmid in seen:
        continue
    seen.add(pmid)
    unique_papers.append(paper)

# Persist the result as JSON (UTF-8, non-ASCII kept as-is) and as CSV.
output_path = 'data/raw/medical_papers.json'
with open(output_path, 'w', encoding='utf-8') as f:
    json.dump(unique_papers, f, indent=2, ensure_ascii=False)

df = pd.DataFrame(unique_papers)
df.to_csv('data/raw/medical_papers.csv', index=False)

print(f"\n‚úÖ Collected {len(unique_papers):,} unique papers")
print(f"üíæ Saved to {output_path}")

# **Text Preprocessing and Chunking**


In [None]:
import json
import re
from pathlib import Path
from tqdm import tqdm

# Chunking budget, measured in estimated tokens (see estimate_tokens below).
CHUNK_SIZE = 800
CHUNK_OVERLAP = 100

# Load the papers saved by the data-collection cell. That cell writes
# data/raw/medical_papers.json as UTF-8, so the same file name and encoding
# are used here (the previous name, pubmed_papers.json, is never created by
# this notebook).
with open('/content/drive/MyDrive/MedAssist_RAG/data/raw/medical_papers.json', 'r', encoding='utf-8') as f:
    papers = json.load(f)

def clean_text(text):
    """Normalise `text` for chunking.

    Removes every character that is not a word character, whitespace, or one
    of ``- ( ) [ ] . ; , : / %``, then collapses each run of whitespace into a
    single space and trims both ends.

    Args:
        text: The string to clean.

    Returns:
        The cleaned string.
    """
    # Both parentheses are allowed. The earlier pattern listed "(" and a space
    # but not ")", so closing parentheses were stripped from the text.
    text = re.sub(r'[^\w\s\-()\[\].;,:/%]', '', text)
    # Collapse whitespace after the removal so a deleted symbol between two
    # spaces does not leave a double space behind.
    text = re.sub(r'\s+', ' ', text)
    return text.strip()

def estimate_tokens(text):
    """Return a rough token count for `text`: its length in characters divided by four, rounded down."""
    whole_tokens, _remainder = divmod(len(text), 4)
    return whole_tokens

def chunk_text(text, chunk_size=800, overlap=100):
    """Split `text` into sentence-aligned chunks of roughly `chunk_size` estimated tokens.

    The text is cleaned with clean_text(), split after ".", "!" or "?" followed
    by whitespace, and the sentences are packed greedily into chunks. When a
    chunk is closed, its trailing sentences are carried into the next chunk as
    context, up to `overlap` estimated tokens in total.

    Args:
        text: Raw text to split.
        chunk_size: Estimated-token budget per chunk. A single sentence larger
            than the budget is kept whole rather than cut.
        overlap: Estimated-token budget for the sentences repeated at the start
            of the next chunk; 0 disables the overlap.

    Returns:
        A list of chunk strings; empty when the cleaned text is empty.
    """
    text = clean_text(text)
    # Discard empty pieces so blank input produces no chunks at all.
    sentences = [s for s in re.split(r'(?<=[.!?])\s+', text) if s]
    chunks = []
    current = []   # sentences of the chunk being built
    length = 0     # estimated tokens held in `current`

    for sent in sentences:
        sent_len = estimate_tokens(sent)
        if current and length + sent_len > chunk_size:
            chunks.append(' '.join(current))
            # Seed the next chunk with the most recent sentences that fit in
            # the overlap budget (previously two sentences were always carried
            # and `overlap` was ignored).
            carried = []
            carried_len = 0
            for prev in reversed(current):
                prev_len = estimate_tokens(prev)
                if carried_len + prev_len > overlap:
                    break
                carried.insert(0, prev)
                carried_len += prev_len
            current, length = carried, carried_len
        current.append(sent)
        length += sent_len

    if current:
        chunks.append(' '.join(current))
    return chunks

all_chunks = []
chunk_id = 0  # running counter, so chunk ids are unique across all papers

for paper in tqdm(papers, desc="Chunking"):
    title = paper.get('title', '')
    abstract = paper.get('abstract', '')
    full_text = f"Title: {title}\n\nAbstract: {abstract}"

    # Fields shared by every chunk cut from this paper.
    paper_metadata = {
        'title': title,
        'pmid': paper.get('pmid', ''),
        'category': paper.get('category', ''),
        'topic': paper.get('topic', ''),
        'url': paper.get('url', ''),
    }

    for i, chunk in enumerate(chunk_text(full_text, CHUNK_SIZE, CHUNK_OVERLAP)):
        all_chunks.append({
            'chunk_id': f"chunk_{chunk_id}",
            'text': chunk,
            # chunk_index is the chunk's position within its own paper.
            'metadata': {**paper_metadata, 'chunk_index': i},
        })
        chunk_id += 1

# Persist the chunks as UTF-8 JSON for the embedding step.
output_path = 'data/processed/chunks.json'
with open(output_path, 'w', encoding='utf-8') as f:
    json.dump(all_chunks, f, indent=2, ensure_ascii=False)

print(f"‚úÖ Created {len(all_chunks):,} chunks from {len(papers):,} papers")
print(f"üíæ Saved to {output_path}")

# **Generate Embeddings**


In [None]:
import numpy as np
from sentence_transformers import SentenceTransformer
import torch
import pickle
from tqdm import tqdm

# Run the encoder on the GPU when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f"Device: {device}")

model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2", device=device)

# Chunks produced by the preprocessing cell.
with open('/content/drive/MyDrive/MedAssist_RAG/data/processed/chunks.json', 'r') as f:
    chunks = json.load(f)

texts = [entry['text'] for entry in chunks]

print(f"Generating embeddings for {len(texts):,} chunks...")
# normalize_embeddings=True asks the encoder for normalised vectors; the index
# cell below searches them by inner product.
encode_options = dict(
    batch_size=32,
    show_progress_bar=True,
    convert_to_numpy=True,
    normalize_embeddings=True,
)
embeddings = model.encode(texts, **encode_options)

# Store the vectors together with the chunk records they were computed from;
# both are written in the same order as `chunks`.
output_dir = Path('data/embeddings')
np.save(output_dir / 'embeddings.npy', embeddings)

with open(output_dir / 'chunks_metadata.pkl', 'wb') as f:
    pickle.dump(chunks, f)

print(f"‚úÖ Embeddings generated: {embeddings.shape}")
print(f"üíæ Saved to {output_dir}")

# **Build FAISS Vector Store**


In [None]:
import faiss
import numpy as np
import pickle

embeddings = np.load('/content/drive/MyDrive/MedAssist_RAG/data/embeddings/embeddings.npy')
# NOTE: pickle.load can execute code embedded in the file. This is acceptable
# here only because the file is the one written by the embeddings cell.
with open('/content/drive/MyDrive/MedAssist_RAG/data/embeddings/chunks_metadata.pkl', 'rb') as f:
    chunks = pickle.load(f)

dimension = embeddings.shape[1]
num_vectors = embeddings.shape[0]

# Convert to float32 once and reuse the result for both training and adding.
vectors = embeddings.astype('float32')

# Pick the index type from the collection size. Every branch uses the inner
# product metric so the returned scores mean the same thing regardless of
# which branch ran.
if num_vectors < 10000:
    # Small collection: exact search.
    index = faiss.IndexFlatIP(dimension)
    index_type = "Flat"
elif num_vectors < 100000:
    # Medium collection: inverted-file index over `nlist` clusters.
    nlist = min(100, num_vectors // 100)
    quantizer = faiss.IndexFlatIP(dimension)
    index = faiss.IndexIVFFlat(quantizer, dimension, nlist, faiss.METRIC_INNER_PRODUCT)
    index.train(vectors)
    index.nprobe = min(10, nlist)  # clusters visited per query
    index_type = "IVF"
else:
    # Large collection: inverted file plus product quantisation
    # (8 sub-quantisers, 8 bits each).
    nlist = 200
    quantizer = faiss.IndexFlatIP(dimension)
    # The metric is passed explicitly: without it IndexIVFPQ defaults to L2,
    # which disagrees with the inner-product quantizer and with the other
    # branches.
    index = faiss.IndexIVFPQ(quantizer, dimension, nlist, 8, 8, faiss.METRIC_INNER_PRODUCT)
    index.train(vectors)
    index.nprobe = 10
    index_type = "IVFPQ"

index.add(vectors)

faiss.write_index(index, 'data/embeddings/faiss_index.bin')

# Source of the standalone retriever module. It is written to src/retriever.py
# below so that later cells can import MedicalRetriever from it.
retriever_code = '''
import numpy as np
import faiss
import pickle

class MedicalRetriever:
    def __init__(self):
        self.index = faiss.read_index("data/embeddings/faiss_index.bin")
        with open("data/embeddings/chunks_metadata.pkl", "rb") as f:
            self.chunks = pickle.load(f)
        self.embeddings = np.load("data/embeddings/embeddings.npy")

    def search(self, query_embedding, top_k=5):
        if len(query_embedding.shape) == 1:
            query_embedding = query_embedding.reshape(1, -1)
        distances, indices = self.index.search(query_embedding.astype("float32"), top_k)
        results = []
        for idx, score in zip(indices[0], distances[0]):
            # FAISS pads missing hits with index -1; without the lower bound a
            # -1 would wrap around and return the last chunk as a result.
            if 0 <= idx < len(self.chunks):
                results.append({
                    "score": float(score),
                    "text": self.chunks[idx]["text"],
                    "metadata": self.chunks[idx]["metadata"]
                })
        return results
'''

with open('src/retriever.py', 'w') as f:
    f.write(retriever_code)

print(f"‚úÖ FAISS index built: {index_type}, {index.ntotal:,} vectors")

# **Create RAG Pipeline**


In [None]:
from sentence_transformers import SentenceTransformer
from groq import Groq
import sys
import os

# Put the absolute path of 'src' at the front of sys.path so the generated
# retriever module can be imported. The path is resolved against the current
# working directory.
src_path = os.path.abspath('src')
if src_path not in sys.path:
    sys.path.insert(0, src_path)

# Drop any cached 'retriever' module so the import below loads the file from
# disk again instead of reusing a copy imported earlier in this kernel.
if 'retriever' in sys.modules:
    del sys.modules['retriever']

# Must come after the sys.path / sys.modules handling above.
from retriever import MedicalRetriever

# Same model name as the one used to embed the chunks, so query vectors and
# indexed vectors come from the same encoder.
MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
# None when GROQ_API_KEY is not in the environment (the configuration cell
# only sets it when Groq is chosen).
GROQ_API_KEY = os.getenv('GROQ_API_KEY')
LLM_MODEL = "llama-3.3-70b-versatile"
# Generation settings passed to every chat completion request.
TEMPERATURE = 0.1
MAX_TOKENS = 2000

# System message sent with every request.
SYSTEM_PROMPT = '''You are MedAssist, an expert medical research assistant.

Provide comprehensive, evidence-based answers with:
1. Clear medical terminology with explanations
2. Citations to sources as [Source N]
3. Detailed mechanisms and treatments
4. Professional medical language

DISCLAIMER: For educational purposes only. Consult healthcare professionals for medical advice.'''

# Shared components handed to MedicalRAG below.
embedding_model = SentenceTransformer(MODEL_NAME)
retriever = MedicalRetriever()
groq_client = Groq(api_key=GROQ_API_KEY)

class MedicalRAG:
    """Answer questions by retrieving paper chunks and passing them to a chat model.

    The class only coordinates its collaborators: an embedding model with an
    ``encode`` method, a retriever with a ``search`` method, and a chat client
    exposing ``chat.completions.create``.
    """

    def __init__(self, retriever, embedding_model, groq_client, llm_model):
        """Store the collaborators and the name of the chat model to call."""
        self.retriever, self.embedding_model = retriever, embedding_model
        self.groq_client, self.llm_model = groq_client, llm_model

    @staticmethod
    def _format_context(results):
        """Render retrieved chunks as numbered "[Source N: title]" blocks separated by blank lines."""
        blocks = [
            f"[Source {rank}: {hit['metadata']['title']}]\n{hit['text']}"
            for rank, hit in enumerate(results, 1)
        ]
        return "\n\n".join(blocks)

    def query(self, question, top_k=10):
        """Retrieve up to `top_k` chunks for `question` and generate a cited answer.

        Args:
            question: The user's question.
            top_k: Number of chunks requested from the retriever.

        Returns:
            A dict with "answer" (the model's reply text) and "sources" (one
            dict per retrieved chunk with "source_id", "title" and "url";
            ids start at 1 and match the [Source N] labels in the prompt).
        """
        question_vector = self.embedding_model.encode([question], normalize_embeddings=True)
        hits = self.retriever.search(question_vector, top_k=top_k)
        context = self._format_context(hits)

        prompt = (
            "Based on medical research literature, provide a comprehensive answer.\n\n"
            "CONTEXT:\n"
            f"{context}\n\n"
            f"QUESTION: {question}\n\n"
            "Provide a detailed answer with source citations [Source N]."
        )

        conversation = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": prompt},
        ]
        response = self.groq_client.chat.completions.create(
            model=self.llm_model,
            messages=conversation,
            temperature=TEMPERATURE,
            max_tokens=MAX_TOKENS,
        )

        cited = []
        for rank, hit in enumerate(hits, 1):
            meta = hit['metadata']
            cited.append({"source_id": rank, "title": meta['title'], "url": meta['url']})

        return {
            "answer": response.choices[0].message.content,
            "sources": cited,
        }

# Single pipeline instance shared by the helper below and the UI.
rag_system = MedicalRAG(
    retriever=retriever,
    embedding_model=embedding_model,
    groq_client=groq_client,
    llm_model=LLM_MODEL,
)


def ask_medical_question(question, top_k=10):
    """Run `question` through the shared pipeline and return its answer/sources dict."""
    result = rag_system.query(question, top_k=top_k)
    return result

print("‚úÖ RAG pipeline ready")

# **Launch Gradio Interface**


In [None]:
import gradio as gr

def answer_question(question, num_sources):
    """Produce the three markdown panels shown by the UI for one question.

    Args:
        question: Text typed by the user.
        num_sources: How many sources to retrieve and list; converted with int().

    Returns:
        A tuple ``(answer, sources, disclaimer)`` of markdown strings. For a
        blank question, or when anything fails, the first element carries the
        message and the other two are empty strings.
    """
    if len(question.strip()) == 0:
        return "‚ö†Ô∏è Please enter a question", "", ""

    try:
        response = ask_medical_question(question, top_k=int(num_sources))

        answer = f"## üí° Answer\n\n{response['answer']}\n\n---\n\n**Sources:** {len(response['sources'])} papers"

        # One entry per listed source: its number, its title and a link.
        source_entries = [
            f"**[{source['source_id']}]** {source['title']}\n[View Paper]({source['url']})\n\n"
            for source in response['sources'][:int(num_sources)]
        ]
        sources = "## üìö Sources\n\n" + "".join(source_entries)

        disclaimer = """## ‚ö†Ô∏è Disclaimer

**Educational purposes only.** Not a substitute for professional medical advice.

Always consult qualified healthcare professionals for medical decisions."""
    except Exception as e:
        # Show the failure in the answer panel instead of raising into the UI.
        return f"‚ùå Error: {str(e)}", "", ""

    return answer, sources, disclaimer

# Canned example questions for the UI, grouped by the category label shown on
# each accordion.
# NOTE(review): the first question under the cardiovascular key is about
# Type 2 diabetes - confirm whether it belongs in that group.
sample_questions = {
    "ü´Ä Cardiovascular": ["What are treatments for Type 2 diabetes?", "Explain heart failure pathophysiology"],
    "üíä Pharmacology": ["What are statin side effects?", "Explain metformin mechanism"],
    "üß¨ Oncology": ["How does chemotherapy work?", "What is immunotherapy?"],
    "üß† Neurology": ["What causes Alzheimer's disease?", "Explain Parkinson's disease"],
}

# Build the Gradio app: a banner, an input column next to a statistics panel,
# three output tabs, and the wiring from the submit button to answer_question.
with gr.Blocks(theme=gr.themes.Soft(primary_hue="indigo"), title="MedAssist") as demo:
    # Banner. The paper count and provider shown here are fixed text inside
    # the HTML, not values computed from the data.
    gr.HTML("""
        <div style="background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); padding: 30px; border-radius: 15px; color: white; text-align: center;">
            <h1 style="font-size: 3em; margin: 0;">üè• MedAssist</h1>
            <p style="font-size: 1.3em; margin-top: 10px;">AI-Powered Medical Research Assistant</p>
            <p>20,000+ Papers ‚Ä¢ Powered by Groq AI</p>
        </div>
    """)

    with gr.Row():
        # Left column: question box, source-count slider and submit button.
        with gr.Column(scale=2):
            gr.Markdown("## üìù Ask Your Question")
            question = gr.Textbox(label="", placeholder="Type medical question...", lines=4)
            # Slider arguments are minimum 3, maximum 10, initial value 5.
            num_sources = gr.Slider(3, 10, 5, step=1, label="üìä Number of Sources")
            submit_btn = gr.Button("üîç Get Answer", variant="primary", size="lg")

            gr.Markdown("### üí° Sample Questions")
            for category, questions in sample_questions.items():
                with gr.Accordion(category, open=False):
                    for q in questions:
                        # `x=q` binds the current question as a default value, so each
                        # button fills the textbox with its own text rather than the
                        # last value of the loop variable.
                        gr.Button(q, size="sm").click(fn=lambda x=q: x, outputs=question)

        # Right column: statistics panel. These figures are hard-coded in the
        # HTML and are not read from the index or the model configuration.
        with gr.Column(scale=1):
            gr.HTML("""
                <div style="background: linear-gradient(135deg, #f093fb 0%, #f5576c 100%); padding: 20px; border-radius: 10px; color: white;">
                    <h3>üìä Statistics</h3>
                    <p><strong>Papers:</strong> 20,000+</p>
                    <p><strong>Chunks:</strong> 48,955</p>
                    <p><strong>Model:</strong> Llama 3.3 70B</p>
                    <p><strong>Speed:</strong> 800 tokens/sec</p>
                </div>
            """)

    # One tab per string returned by answer_question.
    with gr.Tabs():
        with gr.Tab("üí° Answer"):
            answer_output = gr.Markdown()
        with gr.Tab("üìö Sources"):
            sources_output = gr.Markdown()
        with gr.Tab("‚ö†Ô∏è Disclaimer"):
            disclaimer_output = gr.Markdown()

    # The output order must match the (answer, sources, disclaimer) tuple
    # returned by answer_question.
    submit_btn.click(
        fn=answer_question,
        inputs=[question, num_sources],
        outputs=[answer_output, sources_output, disclaimer_output]
    )

    # Ready-made inputs; each example supplies a question and a source count.
    gr.Examples(
        examples=[["What are treatments for Type 2 diabetes?", 5], ["Explain heart failure", 5]],
        inputs=[question, num_sources]
    )

# share=True asks Gradio for a public link; debug=True is passed through to
# launch() as written.
demo.launch(share=True, debug=True)