In [None]:
# Install required packages for Google Colab
%pip install transformers torch sentence-transformers --quiet
%pip install supabase python-dotenv --quiet
%pip install sacremoses --quiet
%pip install fastapi uvicorn pydantic python-multipart

In [None]:
# Standard library
import asyncio
import json
import logging
import os
import traceback
from datetime import datetime
from getpass import getpass
from typing import List, Dict, Any, Optional

# PyTorch
import torch

# AI Models
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoProcessor, AutoModelForImageTextToText

# Supabase
from supabase import create_client

# FastAPI
from fastapi import FastAPI, HTTPException, BackgroundTasks
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
import uvicorn

# Pydantic models for validation
from pydantic import BaseModel, Field

# Colab secrets
from google.colab import userdata

# Configure logging once for the notebook and create its logger
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Startup banner: import confirmation, session timestamp, compute device
print("üì¶ All packages imported successfully!")
print(f"üïê RAG session started at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
print(f"üîß Using device: {'cuda' if torch.cuda.is_available() else 'cpu'}")

# Catalogue of selectable medical models, keyed by the menu number the user types
MEDICAL_MODELS = {
    "1": dict(
        name="OpenBioLLM-8B",
        path="aaditya/OpenBioLLM-Llama3-8B",
        description="8B parameter medical LLM optimized for biomedical tasks",
        type="causal",
    ),
    "2": dict(
        name="Med42-v2-8B",
        path="m42-health/Llama3-Med42-8B",
        description="8B parameter medical model from M42 Health based on Llama3",
        type="causal",
    ),
}

# Print the menu, one line per catalogue entry
print("\nü§ñ Available Medical Models:")
for key in MEDICAL_MODELS:
    model = MEDICAL_MODELS[key]
    print(f"  {key}. {model['name']} - {model['description']}")

print("\nPlease select a model by entering the number (1 or 2):")
print("Note: MedGemma-4B is available in a separate notebook (query_rag_medgemma.ipynb)")


In [None]:
# Pydantic models for request/response validation
class AskRequest(BaseModel):
    # Request schema: a required question (1-1000 chars) plus an optional
    # session identifier and optional prior chat history.
    question: str = Field(
        ...,
        min_length=1,
        max_length=1000,
        description="The medical question to ask",
    )
    session_id: Optional[str] = Field(
        default="default",
        description="Chat session identifier",
    )
    history: Optional[List[dict]] = Field(
        default=[],
        description="Previous chat history",
    )

class AskResponse(BaseModel):
    # Response schema: generated text, the sources it drew on, the updated
    # history, the session id, a mock-mode flag and free-form metadata.
    response: str = Field(
        ...,
        description="AI generated response",
    )
    sources: List[dict] = Field(
        ...,
        description="Source documents used",
    )
    chat_history: List[dict] = Field(
        ...,
        description="Updated chat history",
    )
    session_id: str = Field(
        ...,
        description="Session identifier",
    )
    mockMode: bool = Field(
        default=False,
        description="Whether response is from mock mode",
    )
    metadata: dict = Field(
        ...,
        description="Response metadata",
    )

class EmbedRequest(BaseModel):
    # Request schema: the text to embed, limited to 1-5000 characters.
    text: str = Field(
        ...,
        min_length=1,
        max_length=5000,
        description="Text to embed",
    )

class EmbedResponse(BaseModel):
    # Response schema: the embedding vector with its dimensionality and the
    # model/device that produced it.
    embedding: List[float] = Field(
        ...,
        description="Generated embedding vector",
    )
    dimensions: int = Field(
        ...,
        description="Embedding dimensions",
    )
    model: str = Field(
        ...,
        description="Model used for embedding",
    )
    device: str = Field(
        ...,
        description="Device used for embedding",
    )

class GenerateRequest(BaseModel):
    # Request schema: query text with optional context and history, plus
    # bounded sampling controls (max_tokens 1-1000, temperature 0.0-2.0).
    query: str = Field(
        ...,
        min_length=1,
        max_length=1000,
        description="Query text",
    )
    context: str = Field(
        default="",
        description="Context information",
    )
    history: List[dict] = Field(
        default=[],
        description="Chat history",
    )
    max_tokens: int = Field(
        default=200,
        ge=1,
        le=1000,
        description="Maximum tokens to generate",
    )
    temperature: float = Field(
        default=0.7,
        ge=0.0,
        le=2.0,
        description="Generation temperature",
    )

class GenerateResponse(BaseModel):
    # Response schema: the generated answer, the model that produced it, and
    # flags recording whether context and history were used.
    answer: str = Field(
        ...,
        description="Generated answer",
    )
    model: str = Field(
        ...,
        description="Model used for generation",
    )
    model_path: str = Field(
        ...,
        description="Model path",
    )
    context_used: bool = Field(
        ...,
        description="Whether context was used",
    )
    history_used: bool = Field(
        ...,
        description="Whether history was used",
    )

class QueryRequest(BaseModel):
    # Request schema: a search query (1-1000 chars) and how many documents
    # to retrieve (1-20, default 5).
    query: str = Field(
        ...,
        min_length=1,
        max_length=1000,
        description="Search query",
    )
    top_k: int = Field(
        default=5,
        ge=1,
        le=20,
        description="Number of documents to retrieve",
    )

class QueryResponse(BaseModel):
    # Response schema: the retrieved documents, how many were found, and the
    # query that was searched.
    documents: List[dict] = Field(
        ...,
        description="Retrieved documents",
    )
    total_found: int = Field(
        ...,
        description="Total documents found",
    )
    query: str = Field(
        ...,
        description="Original query",
    )

class HealthResponse(BaseModel):
    # Status-report schema: active model, embedding device, database
    # counts, RAG/chat status, session count and GPU memory details.
    status: str = Field(
        ...,
        description="System status",
    )
    model: str = Field(
        ...,
        description="Active model name",
    )
    model_path: str = Field(
        ...,
        description="Model path",
    )
    embedding_device: str = Field(
        ...,
        description="Embedding device",
    )
    database: str = Field(
        ...,
        description="Database type",
    )
    documents_in_db: int = Field(
        ...,
        description="Number of documents in database",
    )
    embeddings_in_db: int = Field(
        ...,
        description="Number of embeddings in database",
    )
    rag_system: str = Field(
        ...,
        description="RAG system type",
    )
    chat_support: bool = Field(
        ...,
        description="Chat support status",
    )
    active_sessions: int = Field(
        ...,
        description="Number of active sessions",
    )
    gpu_memory: dict = Field(
        ...,
        description="GPU memory information",
    )

# Cell-completion marker: reaching this line means every schema class above was defined without error
print("‚úÖ Pydantic models defined for request/response validation")

In [None]:
# Memory Management and Background Processing
import gc
import psutil
from typing import Optional

class MemoryManager:
    """Reports GPU/CPU memory pressure and releases cached memory on request.

    All methods are coroutines so they can be awaited from async code, but
    none of them awaits anything internally.

    Attributes:
        gpu_memory_threshold: Fraction of total GPU memory (allocated / total)
            at or above which ``check_gpu_memory`` reports no headroom.
        cpu_memory_threshold: Fraction of system RAM in use at or above which
            ``check_cpu_memory`` reports no headroom.
    """

    def __init__(self):
        self.gpu_memory_threshold = 0.8  # 80% GPU memory usage
        self.cpu_memory_threshold = 0.9  # 90% CPU memory usage

    async def check_gpu_memory(self) -> bool:
        """Return True when GPU usage is below the threshold.

        Returns True when no CUDA device is available (nothing to protect)
        and False when the CUDA query itself fails.
        """
        if not torch.cuda.is_available():
            return True
        try:
            allocated = torch.cuda.memory_allocated()
            total = torch.cuda.get_device_properties(0).total_memory
            return (allocated / total) < self.gpu_memory_threshold
        except Exception:
            # Catch Exception rather than using a bare except so that task
            # cancellation and KeyboardInterrupt are not swallowed.
            return False

    async def check_cpu_memory(self) -> bool:
        """Return True when system RAM usage is below the threshold.

        Best-effort: if the psutil query fails, memory is assumed available.
        """
        try:
            memory = psutil.virtual_memory()
            return memory.percent < (self.cpu_memory_threshold * 100)
        except Exception:
            return True

    async def clear_gpu_cache(self):
        """Release unused cached GPU memory (no-op without CUDA)."""
        if torch.cuda.is_available():
            # Collect first: tensors freed by the garbage collector return
            # their blocks to the caching allocator, which empty_cache() can
            # then hand back to the driver.
            gc.collect()
            torch.cuda.empty_cache()
            print("üßπ GPU cache cleared")

    async def clear_cpu_cache(self):
        """Run a garbage-collection pass to free unreachable Python objects."""
        gc.collect()
        print("üßπ CPU cache cleared")

    async def get_memory_status(self) -> dict:
        """Return a snapshot of memory usage.

        Returns:
            A dict that always contains ``gpu_available`` and
            ``cpu_memory_percent``. When CUDA is available it also contains
            ``gpu_allocated_gb``, ``gpu_total_gb`` and ``gpu_usage_percent``
            as formatted strings, or ``gpu_info: "unavailable"`` if the CUDA
            query fails.
        """
        status = {
            "gpu_available": torch.cuda.is_available(),
            "cpu_memory_percent": psutil.virtual_memory().percent
        }

        if torch.cuda.is_available():
            try:
                allocated = torch.cuda.memory_allocated() / 1024**3
                total = torch.cuda.get_device_properties(0).total_memory / 1024**3
                status.update({
                    "gpu_allocated_gb": f"{allocated:.2f}",
                    "gpu_total_gb": f"{total:.2f}",
                    "gpu_usage_percent": f"{(allocated/total)*100:.1f}%"
                })
            except Exception:
                status["gpu_info"] = "unavailable"

        return status

class BackgroundTaskManager:
    """Owns the embedding/generation queues and the workers that drain them.

    Attributes:
        embedding_queue: Bounded queue (100) of pending embedding tasks.
        generation_queue: Bounded queue (50) of pending generation tasks.
        memory_manager: MemoryManager consulted before each unit of work.
        running: True between start_background_workers() and
            stop_background_workers().
    """

    def __init__(self):
        self.embedding_queue = asyncio.Queue(maxsize=100)
        self.generation_queue = asyncio.Queue(maxsize=50)
        self.memory_manager = MemoryManager()
        self.running = False
        # Strong references to the worker tasks: the event loop only keeps
        # weak references, so an unreferenced task may be garbage-collected
        # while it is still running.
        self._worker_tasks = []

    async def start_background_workers(self):
        """Start the embedding and generation workers (idempotent)."""
        if any(not task.done() for task in self._worker_tasks):
            # Workers are already alive; do not spawn duplicates.
            return

        self.running = True
        self._worker_tasks = [
            asyncio.create_task(self._embedding_worker()),
            asyncio.create_task(self._generation_worker()),
        ]

        print("‚úÖ Background workers started")

    async def stop_background_workers(self):
        """Stop the workers and wait until they have actually finished."""
        self.running = False

        tasks, self._worker_tasks = self._worker_tasks, []
        for task in tasks:
            # Cancel so a worker parked in sleep()/get() exits immediately
            # instead of at its next loop iteration.
            task.cancel()
        if tasks:
            await asyncio.gather(*tasks, return_exceptions=True)

        print("ÔøΩÔøΩ Background workers stopped")

    async def _embedding_worker(self):
        """Background worker for embedding tasks (gated on CPU memory)."""
        await self._worker_loop(
            self.embedding_queue, self.memory_manager.check_cpu_memory, "Embedding"
        )

    async def _generation_worker(self):
        """Background worker for generation tasks (gated on GPU memory)."""
        await self._worker_loop(
            self.generation_queue, self.memory_manager.check_gpu_memory, "Generation"
        )

    async def _worker_loop(self, queue, has_memory, label):
        """Drain ``queue`` while running, backing off when memory is tight.

        Args:
            queue: The asyncio.Queue this worker consumes.
            has_memory: Coroutine function returning True when it is safe
                to process the next item.
            label: Worker name used in error messages.
        """
        while self.running:
            try:
                # Back off for a second whenever memory is under pressure
                if not await has_memory():
                    await asyncio.sleep(1)
                    continue

                if not queue.empty():
                    task = await queue.get()
                    # Process task here (placeholder: items are only dequeued)
                    queue.task_done()

                await asyncio.sleep(0.1)

            except Exception as e:
                print(f"‚ùå {label} worker error: {str(e)}")
                await asyncio.sleep(1)

# Initialize memory and background managers
# Note: BackgroundTaskManager builds its own MemoryManager internally, so the
# module-level `memory_manager` below is a separate instance.
print("üóÉÔ∏è Initializing memory and background managers...")
memory_manager = MemoryManager()
background_manager = BackgroundTaskManager()

# Start background workers
# Top-level `await` is used here, so this cell must run in an environment that
# supports it (e.g. an IPython/Colab kernel); plain Python scripts do not.
await background_manager.start_background_workers()

print("‚úÖ Memory and background managers initialized!")

In [None]:
# Model selection: ask which MEDICAL_MODELS entry to load, defaulting to
# option 1 when the answer is not one of the menu keys
print("ü§ñ Select your medical model:")
model_choice = input("Enter model number (1 or 2): ").strip()

selected_model = MEDICAL_MODELS.get(model_choice)
if selected_model is None:
    print(f"‚ùå Invalid choice '{model_choice}'. Defaulting to OpenBioLLM-8B (option 1)")
    model_choice = "1"
    selected_model = MEDICAL_MODELS[model_choice]

# Unpack the fields the loading code below needs
model_name, model_path, model_type = (
    selected_model[field] for field in ("name", "path", "type")
)

print(f"üß† Loading {model_name} for medical text generation...")
print(f"üì¶ Model path: {model_path}")
print(f"üîß Model type: {model_type}")

# Check GPU memory before loading
if torch.cuda.is_available():
    print(f"üîç GPU Memory before loading: {torch.cuda.memory_allocated()/1024**3:.2f} GB allocated, {torch.cuda.memory_reserved()/1024**3:.2f} GB reserved")

# Device label used in status messages; identical for both the primary and
# the fallback path, so compute it once.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

try:
    # Standard causal LM handling (models 1 and 2 are both causal models)
    # NOTE(review): trust_remote_code=True lets the hub repository run its own
    # Python code at load time - confirm it is actually required for these models.
    print("üî§ Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

    print("üß† Loading model (this may take a few minutes)...")
    medical_model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
        device_map="auto",
        trust_remote_code=True
    )

    # Set pad token if not available (prefer EOS, else UNK)
    if hasattr(tokenizer, 'pad_token') and tokenizer.pad_token is None:
        if tokenizer.eos_token is not None:
            tokenizer.pad_token = tokenizer.eos_token
        else:
            tokenizer.pad_token = tokenizer.unk_token

    # Store model info globally
    MODEL_INFO = {
        "name": model_name,
        "path": model_path,
        "choice": model_choice,
        "type": model_type
    }

    # Check GPU memory after loading
    if torch.cuda.is_available():
        print(f"üîç GPU Memory after loading: {torch.cuda.memory_allocated()/1024**3:.2f} GB allocated, {torch.cuda.memory_reserved()/1024**3:.2f} GB reserved")

    print(f"‚úÖ {model_name} loaded and ready on {device}")
    print(f"üéØ Selected model: {model_name}")

except Exception as e:
    print(f"‚ùå Error loading {model_name}: {str(e)}")
    print("üîÑ Falling back to a smaller model...")

    # Release whatever the failed load left in the CUDA cache so the
    # fallback model has the best chance of fitting.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Fallback to a more reliable model
    fallback_path = "microsoft/DialoGPT-medium"
    print(f"üîÑ Loading fallback model: {fallback_path}")

    # Half precision is only used on GPU; on CPU fall back to float32.
    fallback_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

    tokenizer = AutoTokenizer.from_pretrained(fallback_path)
    medical_model = AutoModelForCausalLM.from_pretrained(
        fallback_path,
        torch_dtype=fallback_dtype,
        low_cpu_mem_usage=True,
        device_map="auto"
    )

    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Same keys as the primary path so MODEL_INFO['type'] is always present.
    MODEL_INFO = {
        "name": "DialoGPT-medium (Fallback)",
        "path": fallback_path,
        "choice": "fallback",
        "type": "causal"
    }

    print(f"‚úÖ Fallback model loaded successfully")

# Pre-load the embedding model now, so that FastAPI request handlers never
# trigger a first-time load (and the CUDA memory contention that comes with it)
print("üîç Pre-loading embedding model to avoid CUDA memory conflicts...")
try:
    from sentence_transformers import SentenceTransformer

    if not torch.cuda.is_available():
        print("üì± No GPU available - using CPU for embedding model")
        embedding_model = SentenceTransformer('NeuML/pubmedbert-base-embeddings', device='cpu')
        EMBEDDING_DEVICE = 'cpu'
    else:
        # Estimate headroom as total device memory minus what torch has allocated
        total_memory = torch.cuda.get_device_properties(0).total_memory
        allocated_memory = torch.cuda.memory_allocated()
        free_memory = total_memory - allocated_memory

        print(f"üîç Available GPU memory: {free_memory/1024**3:.2f} GB")

        # Keep embeddings on the GPU only when at least 2GB is free
        if free_memory >= 2 * 1024**3:
            print("‚úÖ Sufficient GPU memory - using GPU for embedding model")
            embedding_model = SentenceTransformer('NeuML/pubmedbert-base-embeddings')
            EMBEDDING_DEVICE = 'cuda'
        else:
            print("‚ö†Ô∏è Limited GPU memory - using CPU for embedding model")
            embedding_model = SentenceTransformer('NeuML/pubmedbert-base-embeddings', device='cpu')
            EMBEDDING_DEVICE = 'cpu'

    print(f"‚úÖ Embedding model loaded on {EMBEDDING_DEVICE}")

except Exception as e:
    # Best-effort: leave the model unset so it can be loaded later on CPU
    print(f"‚ö†Ô∏è Error pre-loading embedding model: {str(e)}")
    print("üîÑ Will load embedding model on-demand with CPU fallback")
    embedding_model = None
    EMBEDDING_DEVICE = 'cpu'

# Setup Supabase connection using Colab secrets: print instructions first
for banner_line in (
    "üóÑÔ∏è Setting up Supabase connection using Colab secrets...",
    "üìã Required Colab secrets:",
    "   1. SUPABASE_URL - Your project URL (e.g., https://abc123.supabase.co)",
    "   2. SUPABASE_SERVICE_ROLE_KEY - Your service role key (NOT anon key)",
    "   3. NGROK_AUTH_TOKEN - Your ngrok authentication token",
    "",
    "üîë To set these secrets:",
    "   1. Click the üîë key icon in the left sidebar",
    "   2. Add the three secrets listed above",
    "   3. Re-run this cell",
    "",
):
    print(banner_line)

# Read the three secrets; any lookup failure is reported and re-raised
try:
    supabase_url = userdata.get('SUPABASE_URL')
    supabase_key = userdata.get('SUPABASE_SERVICE_ROLE_KEY')
    ngrok_token = userdata.get('NGROK_AUTH_TOKEN')

    print("‚úÖ Successfully retrieved secrets from Colab")

except Exception as secrets_error:
    print(f"‚ùå Error retrieving secrets: {str(secrets_error)}")
    print("üîß Make sure you've added the required secrets in Colab:")
    print("   ‚Ä¢ SUPABASE_URL")
    print("   ‚Ä¢ SUPABASE_SERVICE_ROLE_KEY")
    print("   ‚Ä¢ NGROK_AUTH_TOKEN")
    raise

# Validate the inputs: hard failures first, then heuristic warnings
if not (supabase_url and supabase_key):
    raise ValueError("‚ùå Both Supabase URL and Service Role Key are required!")

if not supabase_url.startswith('https://'):
    raise ValueError("‚ùå Supabase URL should start with 'https://'")

if not supabase_key.startswith('eyJ'):
    print("‚ö†Ô∏è WARNING: Service role keys typically start with 'eyJ'")
    print("   You might be using the anon key instead of service_role key")

if len(supabase_key) < 100:
    print("‚ö†Ô∏è WARNING: Service role keys are typically very long (200+ characters)")
    print("   You might be using the anon key instead of service_role key")

# Build the client; failures are explained and re-raised
try:
    supabase = create_client(supabase_url, supabase_key)
    print("‚úÖ Supabase client initialized")
except Exception as client_error:
    print(f"‚ùå Failed to initialize Supabase client: {str(client_error)}")
    print("üîß Common issues:")
    print("   ‚Ä¢ Wrong API key type (use service_role, not anon)")
    print("   ‚Ä¢ Typo in URL or key")
    print("   ‚Ä¢ Key might be expired or regenerated")
    raise

# RAG tuning knobs: retrieval size, similarity cut-off, and length limits
CONFIG = dict(
    top_k=5,
    similarity_threshold=0.5,
    max_context_length=2000,
    max_response_length=150,
)

# Echo the retrieval-related settings
print(f"\n‚öôÔ∏è RAG Configuration:")
print(f"   üéØ Retrieve top {CONFIG['top_k']} similar documents")
print(f"   üìä Similarity threshold: {CONFIG['similarity_threshold']}")
print(f"   üìè Max context length: {CONFIG['max_context_length']} chars")


In [None]:
# Modular Service Architecture
import asyncio
from concurrent.futures import ThreadPoolExecutor
from abc import ABC, abstractmethod

class BaseService(ABC):
    """Base class for all services.

    Gives every service a small thread pool (2 workers) plus a helper to
    run blocking callables on it without stalling the event loop.
    """

    def __init__(self):
        self.executor = ThreadPoolExecutor(max_workers=2)

    async def run_in_executor(self, func, *args):
        """Run ``func(*args)`` on this service's thread pool and await it.

        Args:
            func: Blocking callable to execute off the event loop.
            *args: Positional arguments forwarded to ``func``.

        Returns:
            Whatever ``func`` returns; exceptions raised by ``func``
            propagate to the awaiting caller.
        """
        # get_running_loop() is the documented way to obtain the loop from
        # inside a coroutine; get_event_loop() is discouraged there.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(self.executor, func, *args)

class EmbedderService(BaseService):
    """Handles text embedding generation (on CPU by default).

    The sentence-transformers model is loaded lazily on first use unless an
    already-loaded model is supplied to the constructor.
    """

    def __init__(self, model_name: str = 'NeuML/pubmedbert-base-embeddings',
                 device: str = 'cpu', model=None):
        """
        Args:
            model_name: Model identifier passed to SentenceTransformer when
                the model has to be loaded.
            device: Device the model is loaded on.
            model: Optional pre-loaded model exposing ``encode``; when given,
                no load happens.
        """
        super().__init__()
        self.model_name = model_name
        self.model = model
        self.device = device

    async def initialize(self):
        """Load the embedding model if it has not been loaded yet."""
        if self.model is None:
            from sentence_transformers import SentenceTransformer
            self.model = SentenceTransformer(self.model_name, device=self.device)
            print(f"‚úÖ EmbedderService: Model loaded on {str(self.device).upper()}")

    async def embed_text(self, text: str) -> List[float]:
        """Return the embedding of a single text as a list of floats."""
        await self.initialize()

        # Encode as a one-element batch on the thread pool, then unwrap it
        embedding = await self.run_in_executor(self.model.encode, [text])
        return embedding[0].tolist()

    async def embed_batch(self, texts: List[str]) -> List[List[float]]:
        """Return one embedding per input text.

        An empty input returns ``[]`` without loading or calling the model.
        """
        texts = list(texts)
        if not texts:
            return []

        await self.initialize()

        # Run batch embedding on the thread pool
        embeddings = await self.run_in_executor(self.model.encode, texts)
        return embeddings.tolist()

class RetrieverService(BaseService):
    """Handles document retrieval from Supabase via the search_embeddings RPC."""

    def __init__(self, supabase_client, embedder_service: EmbedderService,
                 top_k: int = 5, similarity_threshold: float = 0.5):
        """
        Args:
            supabase_client: Client exposing ``rpc(name, params).execute()``.
            embedder_service: Service used to embed the query text.
            top_k: Default number of documents to request.
            similarity_threshold: Value sent as ``match_threshold``.
        """
        super().__init__()
        self.supabase = supabase_client
        self.embedder = embedder_service
        self.top_k = top_k
        self.similarity_threshold = similarity_threshold

    async def retrieve_documents(self, query: str, top_k: Optional[int] = None) -> List[dict]:
        """Retrieve relevant documents using vector search.

        Args:
            query: Text to search for.
            top_k: Number of documents to request; falls back to
                ``self.top_k`` when omitted (or 0).

        Returns:
            A list of dicts with ``content``, ``similarity_score``,
            ``metadata``, ``rank`` (1-based) and ``doc_id``. Returns ``[]``
            when nothing matches or when the search itself fails; embedding
            errors are not caught and propagate to the caller.
        """
        top_k = top_k or self.top_k

        # Generate query embedding
        query_embedding = await self.embedder.embed_text(query)

        try:
            # The Supabase client is synchronous; run it on the thread pool so
            # the network round-trip does not block the event loop.
            rows = await self.run_in_executor(self._search_sync, query_embedding, top_k)
            return [
                self._format_document(rank, row)
                for rank, row in enumerate(rows, start=1)
            ]
        except Exception as e:
            # Deliberate best-effort: report and degrade to "no documents"
            print(f"‚ùå RetrieverService error: {str(e)}")
            return []

    def _search_sync(self, query_embedding, top_k):
        """Call the search_embeddings RPC and return its rows (possibly empty)."""
        result = self.supabase.rpc('search_embeddings', {
            'query_embedding': query_embedding,
            'match_threshold': self.similarity_threshold,
            'match_count': top_k
        }).execute()
        return result.data or []

    @staticmethod
    def _format_document(rank, doc):
        """Map one RPC row to the document dict returned to callers."""
        return {
            'content': doc.get('chunk_content', ''),
            'similarity_score': doc.get('similarity', 0.0),
            'metadata': {
                'title': doc.get('title', 'Medical Document'),
                'source': doc.get('source', 'unknown'),
                'topic': doc.get('topic', 'general'),
                'document_type': doc.get('document_type', 'unknown'),
                'document_id': doc.get('document_id', '')
            },
            'rank': rank,
            'doc_id': doc.get('document_id', '')
        }

class GeneratorService(BaseService):
    """Handles text generation with the loaded medical LLM."""

    def __init__(self, model_name: str, tokenizer, medical_model):
        """
        Args:
            model_name: Display name of the model (used in logs/metadata).
            tokenizer: Tokenizer matching ``medical_model``.
            medical_model: Model object exposing ``generate``.
        """
        super().__init__()
        self.model_name = model_name
        self.tokenizer = tokenizer
        self.medical_model = medical_model
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    async def generate_response(self, prompt: str, max_tokens: int = 150,
                                temperature: float = 0.7) -> str:
        """Generate a response for ``prompt`` without blocking the event loop.

        Args:
            prompt: Full prompt text.
            max_tokens: Maximum number of new tokens to generate.
            temperature: Sampling temperature; values <= 0 switch to greedy
                decoding.

        Returns:
            The generated text, or an apology string containing the error
            message if generation fails.
        """
        try:
            # Run the blocking generate() call on the thread pool
            return await self.run_in_executor(
                self._generate_sync, prompt, max_tokens, temperature
            )
        except Exception as e:
            print(f"‚ùå GeneratorService error: {str(e)}")
            return f"I apologize, but I encountered an error: {str(e)}"

    def _generate_sync(self, prompt: str, max_tokens: int, temperature: float = 0.7) -> str:
        """Synchronous generation method for thread pool."""
        inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048)

        # Send the inputs to wherever the model actually lives instead of
        # assuming "cuda"; fall back to the default device when the model
        # does not report a usable one (e.g. parameters on the meta device).
        target_device = getattr(self.medical_model, "device", None)
        if target_device is None or getattr(target_device, "type", None) == "meta":
            target_device = self.device
        inputs = {k: v.to(target_device) for k, v in inputs.items()}

        input_len = inputs['input_ids'].shape[1]

        generation_params = {
            "max_new_tokens": max_tokens,
            "repetition_penalty": 1.1,
        }
        if temperature is not None and temperature > 0:
            generation_params.update(
                {"do_sample": True, "temperature": temperature, "top_p": 0.9}
            )
        else:
            # Sampling requires a strictly positive temperature, so treat
            # 0 (or less) as a request for deterministic greedy decoding.
            generation_params["do_sample"] = False

        if hasattr(self.tokenizer, 'pad_token_id') and self.tokenizer.pad_token_id is not None:
            generation_params["pad_token_id"] = self.tokenizer.pad_token_id
        if hasattr(self.tokenizer, 'eos_token_id') and self.tokenizer.eos_token_id is not None:
            generation_params["eos_token_id"] = self.tokenizer.eos_token_id

        with torch.no_grad():
            outputs = self.medical_model.generate(**inputs, **generation_params)

        # Decode only the newly generated tokens, not the echoed prompt
        response = self.tokenizer.decode(outputs[0][input_len:], skip_special_tokens=True)
        return response.strip()

class RAGController:
    """Orchestrates the complete RAG pipeline: retrieve, build prompt, generate."""

    PREVIEW_CHARS = 150  # length of the content preview in each source entry

    def __init__(self, embedder: EmbedderService, retriever: RetrieverService,
                 generator: GeneratorService, max_context_length: int = 2000,
                 max_new_tokens: int = 150):
        """
        Args:
            embedder: Embedding service (stored; retrieval uses its own reference).
            retriever: Service returning ranked documents for a query.
            generator: Service turning a prompt into a response.
            max_context_length: Character budget for document content in the prompt.
            max_new_tokens: Token budget passed to the generator.
        """
        self.embedder = embedder
        self.retriever = retriever
        self.generator = generator
        self.max_context_length = max_context_length
        self.max_new_tokens = max_new_tokens

    async def process_query(self, question: str, history: List[dict] = None) -> dict:
        """Complete RAG pipeline: retrieve context and generate response.

        Args:
            question: The user's question.
            history: Optional prior turns as dicts with ``question`` and
                ``answer`` keys; only the last three are used.

        Returns:
            A dict with ``query``, ``response``, ``sources`` (one entry per
            retrieved document) and ``metadata``.

        Raises:
            Any exception from retrieval or result formatting, after it has
            been printed.
        """
        try:
            print(f"üîç Processing query: {question}")

            # Step 1: Retrieve relevant documents
            retrieved_docs = await self.retriever.retrieve_documents(question)
            print(f"üìä Found {len(retrieved_docs)} similar documents")

            # Step 2: Build context from documents
            context_parts = self._select_context(retrieved_docs)
            context = "\n\n".join(context_parts)
            print(f"ÔøΩÔøΩ Using {len(context_parts)} documents for context ({len(context)} chars)")

            # Step 3: Create prompt with history
            prompt = self._build_prompt(question, context, history)

            # Step 4: Generate response
            print(f"ü§ñ Generating response using {self.generator.model_name}...")
            response = await self.generator.generate_response(prompt, self.max_new_tokens)

            # Step 5: Format result
            return {
                'query': question,
                'response': response,
                'sources': [self._format_source(doc) for doc in retrieved_docs],
                'metadata': {
                    'documentsUsed': len(context_parts),
                    'totalFound': len(retrieved_docs),
                    'contextLength': len(context),
                    'model': self.generator.model_name,
                    'processingTime': datetime.now().isoformat()
                }
            }

        except Exception as e:
            print(f"‚ùå RAGController error: {str(e)}")
            raise

    def _select_context(self, retrieved_docs: List[dict]) -> List[str]:
        """Take documents in rank order until the character budget is exhausted."""
        context_parts = []
        total_chars = 0
        for doc in retrieved_docs:
            content = doc['content'] or ''  # tolerate rows whose content is None
            if total_chars + len(content) > self.max_context_length:
                # Stop at the first document that does not fit
                break
            context_parts.append(f"Source: {doc['metadata']['source']}\n{content}")
            total_chars += len(content)
        return context_parts

    @staticmethod
    def _build_prompt(question: str, context: str, history) -> str:
        """Assemble the LLM prompt, including up to three prior turns if given."""
        if history:
            history_context = "\n".join(
                f"Human: {h.get('question', '')}\nAssistant: {h.get('answer', '')}"
                for h in history[-3:]
            )
            return (
                "Previous conversation:\n"
                f"{history_context}\n"
                "\n"
                "Context:\n"
                f"{context}\n"
                "\n"
                f"Current question: {question}\n"
                "\n"
                "Answer based on the context and conversation history:"
            )
        return (
            "You are a helpful medical assistant. Use the following context to answer the question.\n"
            "\n"
            "Context:\n"
            f"{context}\n"
            "\n"
            f"Question: {question}\n"
            "\n"
            "Answer (based only on the context):"
        )

    @classmethod
    def _format_source(cls, doc: dict) -> dict:
        """Summarise one retrieved document for the response payload."""
        content = doc['content'] or ''
        preview = content[:cls.PREVIEW_CHARS]
        if len(content) > cls.PREVIEW_CHARS:
            # Only mark the preview as truncated when text was actually cut
            preview += "..."
        return {
            'title': doc['metadata'].get('title', 'Medical Document'),
            'source': doc['metadata']['source'],
            'topic': doc['metadata']['topic'],
            'similarity': f"{doc['similarity_score']:.3f}",
            'rank': doc['rank'],
            'content_preview': preview
        }

# Wire the services together; each one depends on those created before it
print("üîß Initializing modular services...")
embedder_service = EmbedderService()
retriever_service = RetrieverService(
    supabase_client=supabase,
    embedder_service=embedder_service,
)
generator_service = GeneratorService(
    model_name=MODEL_INFO['name'],
    tokenizer=tokenizer,
    medical_model=medical_model,
)
rag_controller = RAGController(
    embedder=embedder_service,
    retriever=retriever_service,
    generator=generator_service,
)

# Summarise what was created
print("‚úÖ Modular services initialized!")
print("   üìç EmbedderService: CPU-based embedding generation")
print("   üìç RetrieverService: Supabase vector search")
print("   üìç GeneratorService: GPU-based text generation")
print("   üìç RAGController: Pipeline orchestration")

In [None]:
# Test Supabase connection and RPC functions
print("üß™ Testing Supabase connection...")
try:
    # Test basic connection.
    # Ask the API for an exact row count instead of measuring the length of
    # the returned payload (which only counts the rows in this response);
    # limit(1) keeps the transferred data minimal.
    test_result = (
        supabase.table('medical_documents')
        .select('*', count='exact')
        .limit(1)
        .execute()
    )
    doc_count = test_result.count or 0
    print(f"‚úÖ Supabase connected - Found {doc_count} documents in database")

    # Test RPC function availability (failure here is non-fatal)
    print("üß™ Testing RPC functions...")
    try:
        stats_result = supabase.rpc('get_document_stats').execute()
        if stats_result.data:
            print("‚úÖ RPC functions working")
            for stat in stats_result.data[:3]:  # Show first 3 document sources
                print(f"   üìÑ {stat['source']}: {stat['count']} documents")
        else:
            print("‚ö†Ô∏è RPC function exists but returned no data")
    except Exception as rpc_error:
        print(f"‚ö†Ô∏è RPC function test failed: {str(rpc_error)}")
        print("   Vector search will use fallback method")

except Exception as e:
    error_str = str(e)
    print(f"‚ö†Ô∏è Supabase connection test failed: {error_str}")

    # Provide specific guidance based on error type
    if '401' in error_str or 'Invalid API key' in error_str:
        print("üîß AUTHENTICATION ERROR - Invalid API Key:")
        print("   ‚ùå You're using the wrong API key!")
        print("   üìã To fix this:")
        print("   1. Go to your Supabase project dashboard")
        print("   2. Settings ‚Üí API")
        print("   3. Copy the 'service_role' key (NOT anon key)")
        print("   4. The service_role key is much longer and starts with 'eyJ'")
        print("   5. Re-run Cell 3 with the correct key")
        print("")
        print("   üîç Key differences:")
        print("   ‚Ä¢ anon key: Used for client-side apps (WRONG for this notebook)")
        print("   ‚Ä¢ service_role key: Used for server-side/admin access (CORRECT)")
    elif '404' in error_str:
        print("üîß TABLE NOT FOUND:")
        print("   ‚ùå The 'medical_documents' table doesn't exist!")
        print("   üìã To fix this:")
        print("   1. Run the schema.sql in your Supabase SQL editor")
        print("   2. Or run the embed_documents.py script to create tables")
    elif 'timeout' in error_str.lower():
        print("üîß CONNECTION TIMEOUT:")
        print("   ‚ùå Can't reach Supabase servers")
        print("   üìã Check your internet connection and Supabase URL")
    else:
        print("üîß GENERAL CONNECTION ERROR:")
        print("   üìã Common fixes:")
        print("   ‚Ä¢ Double-check your Supabase URL")
        print("   ‚Ä¢ Verify you're using service_role key (not anon)")
        print("   ‚Ä¢ Check if your project is paused/suspended")
        print("   ‚Ä¢ Ensure database tables exist")

    # The failure is reported but not re-raised, so later cells still run
    print("\n   ‚ö†Ô∏è The system will continue but may have limited document retrieval")


In [None]:
def generate_medical_response(prompt: str, max_new_tokens: int = 150) -> str:
    """Generate a medical answer for ``prompt`` with the selected causal LM.

    Uses the notebook-level globals ``tokenizer``, ``medical_model`` and
    ``MODEL_INFO``.

    Args:
        prompt: Full prompt text. It is tokenized with truncation at 2048 tokens.
        max_new_tokens: Upper bound on the number of generated tokens. If the
            first attempt fails with a CUDA / out-of-memory ``RuntimeError``,
            generation is retried once with at most 50 new tokens.

    Returns:
        The decoded continuation (prompt tokens excluded), stripped of
        whitespace and ``<s>``/``</s>`` markers. A canned referral message is
        returned when the result is shorter than 10 characters, and a canned
        apology is returned when any exception occurs (this function does not
        raise).
    """
    try:
        print(f"ü§ñ Generating response using {MODEL_INFO['name']}...")
        
        # Check GPU memory before generation
        if torch.cuda.is_available():
            allocated = torch.cuda.memory_allocated()/1024**3
            reserved = torch.cuda.memory_reserved()/1024**3
            print(f"üîç GPU Memory before generation: {allocated:.2f} GB allocated, {reserved:.2f} GB reserved")
        
        # Standard causal LM handling (models 1 and 2 are both causal models)
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048)
        
        # Put the inputs on the device the model actually lives on. A GPU being
        # present does not guarantee the model was loaded onto it (CPU fallback,
        # non-default CUDA index), and a mismatch makes generate() fail.
        target_device = getattr(medical_model, "device", None)
        if target_device is None and torch.cuda.is_available():
            target_device = "cuda"
        if target_device is not None:
            inputs = {k: v.to(target_device) for k, v in inputs.items()}
        input_len = inputs['input_ids'].shape[1]
        
        # Generation parameters optimized for medical models
        generation_params = {
            "max_new_tokens": max_new_tokens,
            "temperature": 0.7,
            "do_sample": True,
            "repetition_penalty": 1.1,
            "top_p": 0.9
        }
        
        # Forward the tokenizer's special token ids only when they are defined
        pad_token_id = getattr(tokenizer, 'pad_token_id', None)
        if pad_token_id is not None:
            generation_params["pad_token_id"] = pad_token_id
        eos_token_id = getattr(tokenizer, 'eos_token_id', None)
        if eos_token_id is not None:
            generation_params["eos_token_id"] = eos_token_id
        
        with torch.no_grad():
            try:
                outputs = medical_model.generate(
                    **inputs,
                    **generation_params
                )
            except RuntimeError as cuda_error:
                error_text = str(cuda_error).lower()
                if "out of memory" in error_text or "cuda" in error_text:
                    print(f"‚ö†Ô∏è CUDA memory error during generation: {str(cuda_error)}")
                    print("üßπ Clearing GPU cache and retrying...")
                    
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()
                    
                    # Retry once with a smaller generation budget
                    generation_params["max_new_tokens"] = min(generation_params["max_new_tokens"], 50)
                    print(f"üîÑ Retrying with reduced tokens: {generation_params['max_new_tokens']}")
                    
                    outputs = medical_model.generate(
                        **inputs,
                        **generation_params
                    )
                else:
                    # Bare raise keeps the original traceback intact
                    raise
        
        # Decode response (standard causal LM - decode only the generated part)
        response = tokenizer.decode(outputs[0][input_len:], skip_special_tokens=True)
        
        response = response.strip()
        
        # Remove sentence-boundary markers that survive decoding
        response = response.replace("</s>", "").replace("<s>", "").strip()
        
        # Basic quality check: empty or very short output is replaced by a referral
        if not response or len(response) < 10:
            response = "I understand your question about health. Please consult with a healthcare professional for personalized medical advice."
        
        print(f"‚úÖ Response generated successfully ({len(response)} characters)")
        return response
        
    except Exception as e:
        # Best-effort API: report the failure and return a user-facing apology
        print(f"‚ùå Generation error: {str(e)}")
        return "I apologize, but I encountered an error processing your question. Please try rephrasing your question or consult with a healthcare professional."


In [None]:
# FastAPI application object; interactive documentation is served at /docs and /redoc.
app = FastAPI(
    docs_url="/docs",
    redoc_url="/redoc",
    version="2.0.0",
    title="WellnessGrid RAG API",
    description="Medical AI Assistant with RAG capabilities using FastAPI",
)

# CORS: every origin, method and header is accepted.
# NOTE(review): wildcard origins together with allow_credentials=True is very
# permissive - restrict allow_origins before using this outside a demo.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=["*"],  # Configure appropriately for production
)

# Chat history per session id, kept in process memory only
# (in production, use Redis or a database).
chat_sessions = dict()

@app.post("/embed", response_model=EmbedResponse)
async def generate_embedding(request: EmbedRequest):
    """Return the embedding vector for ``request.text``.

    Uses the global ``embedding_model`` when one is already loaded; otherwise
    loads ``NeuML/pubmedbert-base-embeddings`` on CPU and stores it (together
    with ``EMBEDDING_DEVICE``) for later requests.

    Returns:
        EmbedResponse with the vector, its length, a model label and the device.

    Raises:
        HTTPException: 500 when the model cannot be loaded or encoding fails.
            A load failure is reported with its own detail message instead of
            being wrapped a second time by the generic handler.
    """
    global embedding_model, EMBEDDING_DEVICE

    try:
        logger.info(f"üîç Generating embedding for text: {request.text[:100]}...")
        
        # Use pre-loaded embedding model or load with CPU fallback
        if embedding_model is None:
            logger.info("üì• Loading embedding model on-demand with CPU fallback...")
            try:
                from sentence_transformers import SentenceTransformer
                
                # Clear GPU cache first to free up memory
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                    logger.info("üßπ GPU cache cleared")
                
                # Always use CPU for FastAPI requests to avoid CUDA conflicts
                embedding_model = SentenceTransformer('NeuML/pubmedbert-base-embeddings', device='cpu')
                EMBEDDING_DEVICE = 'cpu'
                logger.info("‚úÖ Embedding model loaded on CPU (safer for FastAPI)")
                
            except Exception as load_error:
                logger.error(f"‚ùå Failed to load embedding model: {str(load_error)}")
                raise HTTPException(
                    status_code=500,
                    detail=f"Failed to load embedding model: {str(load_error)}"
                ) from load_error
        
        # Generate embedding
        logger.info(f"üß† Generating embedding on {EMBEDDING_DEVICE}...")
        embedding = embedding_model.encode([request.text])[0].tolist()
        
        logger.info(f"‚úÖ Generated embedding with {len(embedding)} dimensions")
        
        return EmbedResponse(
            embedding=embedding,
            dimensions=len(embedding),
            model=f"PubMedBERT ({EMBEDDING_DEVICE})",
            device=EMBEDDING_DEVICE
        )
        
    except HTTPException:
        # Already a complete API error (model load failure, logged above):
        # pass it through unchanged rather than nesting it in a second 500.
        raise
    except Exception as e:
        logger.error(f"‚ùå Error in embed endpoint: {str(e)}")
        logger.error(traceback.format_exc())
        raise HTTPException(status_code=500, detail=f"Embedding generation failed: {str(e)}")

@app.post("/generate", response_model=GenerateResponse)
async def generate_text(request: GenerateRequest):
    """Build a prompt from query, context and recent history, then generate an answer.

    When ``request.history`` is non-empty, the last three entries (dicts read
    via their ``question`` / ``answer`` keys) are rendered ahead of the context.
    The prompt is handed to ``generator_service.generate_response`` together
    with ``request.max_tokens``.

    Raises:
        HTTPException: 500 for any failure while building the prompt or generating.
    """
    try:
        logger.info(f"üî¨ Generating response for query: {request.query[:100]}...")
        logger.info(f"üìö Context length: {len(request.context)} characters")
        logger.info(f"ÔøΩÔøΩ Chat history: {len(request.history)} messages")

        if not request.history:
            # Single-turn prompt: instruction, context, question
            prompt = (
                "You are a helpful medical assistant. Use the following context to answer the question.\n"
                "\n"
                "Context:\n"
                f"{request.context}\n"
                "\n"
                f"Question: {request.query}\n"
                "\n"
                "Answer (based only on the context):"
            )
        else:
            # Multi-turn prompt: only the three most recent exchanges are replayed
            turns = []
            for past in request.history[-3:]:
                turns.append(f"Human: {past.get('question', '')}\nAssistant: {past.get('answer', '')}")
            history_context = "\n".join(turns)
            prompt = (
                "Previous conversation:\n"
                f"{history_context}\n"
                "\n"
                "Context:\n"
                f"{request.context}\n"
                "\n"
                f"Current question: {request.query}\n"
                "\n"
                "Answer based on the context and conversation history:"
            )

        answer_text = await generator_service.generate_response(prompt, request.max_tokens)

        return GenerateResponse(
            answer=answer_text,
            model=MODEL_INFO['name'],
            model_path=MODEL_INFO['path'],
            context_used=len(request.context) > 0,
            history_used=len(request.history) > 0
        )

    except Exception as e:
        logger.error(f"‚ùå Error in generate endpoint: {e!s}")
        raise HTTPException(status_code=500, detail=f"Text generation failed: {e!s}")

@app.post("/query", response_model=QueryResponse)
async def query_docs(request: QueryRequest):
    """Look up documents for ``request.query`` through ``retriever_service``.

    Returns:
        QueryResponse carrying the retrieved documents, how many there are,
        and the query that was sent.

    Raises:
        HTTPException: 500 when retrieval or response construction fails.
    """
    try:
        matches = await retriever_service.retrieve_documents(request.query, request.top_k)
        reply = QueryResponse(
            documents=matches,
            total_found=len(matches),
            query=request.query
        )
        return reply
    except Exception as e:
        logger.error(f"‚ùå Error in query endpoint: {e!s}")
        raise HTTPException(status_code=500, detail=f"Document query failed: {e!s}")

@app.post("/ask", response_model=AskResponse)
async def ask_rag(request: AskRequest):
    """Answer a question through ``rag_controller`` and record the exchange.

    If the request carries no history and the session id is already known,
    the stored session history is used. The new question/answer pair is then
    appended to ``chat_sessions`` and the session is capped at its 10 most
    recent entries.

    Raises:
        HTTPException: 500 (with the error text in the detail) on any failure.
    """
    try:
        logger.info(f"ü§ñ Processing RAG query: {request.question[:100]}...")
        logger.info(f"üìù Session ID: {request.session_id}")

        sid = request.session_id

        # Fall back to the stored conversation when the caller sent none
        if sid in chat_sessions and not request.history:
            request.history = chat_sessions[sid]

        result = await rag_controller.process_query(request.question, request.history)

        # Record this exchange (the entry is built before the session is created)
        exchange = {"question": request.question, "answer": result['response']}
        chat_sessions.setdefault(sid, []).append(exchange)

        # Keep only last 10 messages to prevent memory issues
        transcript = chat_sessions[sid]
        if len(transcript) > 10:
            chat_sessions[sid] = transcript[-10:]

        return AskResponse(
            response=result['response'],
            sources=result['sources'],
            chat_history=chat_sessions[sid],
            session_id=sid,
            mockMode=False,
            metadata=result['metadata']
        )

    except Exception as e:
        logger.error(f"‚ùå Error in ask endpoint: {e!s}")
        logger.error(traceback.format_exc())
        raise HTTPException(
            status_code=500,
            detail=f"I apologize, but I encountered an error processing your question: {e!s}"
        )

@app.get("/chat/history/{session_id}")
async def get_chat_history(session_id: str):
    """Return the stored exchanges for ``session_id`` (empty list if unknown)."""
    try:
        stored = chat_sessions.get(session_id, [])
        payload = {"session_id": session_id}
        payload["history"] = stored
        payload["message_count"] = len(stored)
        return payload
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

@app.post("/chat/clear/{session_id}")
async def clear_chat_history(session_id: str):
    """Forget the stored history for ``session_id``; unknown ids are accepted too."""
    try:
        # pop() with a default drops the session when present and is a no-op otherwise
        chat_sessions.pop(session_id, None)
        return {"session_id": session_id, "cleared": True}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

def _count_table_rows(table_name: str) -> int:
    """Return how many rows ``table_name`` holds, using PostgREST's exact count.

    Only a single row is requested; the total comes from the ``count`` field
    of the response. If the response carries no count, the number of returned
    rows is used as a fallback.
    """
    result = supabase.table(table_name).select('*', count='exact').limit(1).execute()
    exact_count = getattr(result, 'count', None)
    if exact_count is not None:
        return exact_count
    return len(result.data) if result.data else 0


@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Report service health: model, database row counts, sessions and memory.

    Previously the row totals were taken as ``len(result.data)`` of a
    ``select('count')`` query, which measures the size of the returned payload
    rather than the size of the table. The totals now come from an exact count.

    Raises:
        HTTPException: 503 when the memory status or either Supabase query fails.
    """
    try:
        # Get memory status
        memory_status = await memory_manager.get_memory_status()
        
        # Querying both tables doubles as the Supabase connectivity check
        doc_count = _count_table_rows('medical_documents')
        embed_count = _count_table_rows('document_embeddings')
        
        return HealthResponse(
            status="healthy",
            model=MODEL_INFO['name'],
            model_path=MODEL_INFO['path'],
            # EMBEDDING_DEVICE only exists once an embedding model has been set up
            embedding_device=globals().get('EMBEDDING_DEVICE', 'unknown'),
            database="Supabase + pgvector",
            documents_in_db=doc_count,
            embeddings_in_db=embed_count,
            rag_system="enhanced_modular",
            chat_support=True,
            active_sessions=len(chat_sessions),
            gpu_memory=memory_status
        )
    except Exception as e:
        raise HTTPException(status_code=503, detail=str(e))

@app.get("/status")
async def status():
    """Return a snapshot of model, database, config and session state.

    The ``connected`` flag is a constant ``True``; no database call is made here.

    NOTE(review): the response lists every active session id, and
    ``/chat/history/{session_id}`` returns a conversation given only that id -
    consider whether this should be exposed on a public URL.

    Raises:
        HTTPException: 500 when any of the referenced globals is unavailable.
    """
    try:
        known = globals()
        snapshot = {"timestamp": datetime.now().isoformat()}

        snapshot["models"] = {
            "selected_model": MODEL_INFO['name'] if 'MODEL_INFO' in known else "not_selected",
            "medical_model": "loaded" if 'medical_model' in known else "not_loaded",
            "pubmedbert": "available"
        }

        # Only a 30-character prefix of the Supabase URL is revealed
        if supabase_url:
            url_preview = supabase_url[:30] + "..."
        else:
            url_preview = "not_set"
        snapshot["database"] = {"connected": True, "url": url_preview}

        snapshot["config"] = CONFIG
        snapshot["memory"] = {
            "active_sessions": len(chat_sessions),
            "session_ids": list(chat_sessions.keys())
        }
        return snapshot
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

print("üåê Enhanced FastAPI endpoints configured:")
print("  ‚úÖ POST /embed - Generate embeddings with enhanced error handling")
print(f"  ‚úÖ POST /generate - Generate text with {MODEL_INFO['name']} + chat history")
print("  ‚úÖ POST /ask - Enhanced RAG endpoint with chat history")
print("  ‚úÖ GET /health - Enhanced health check")
print("  ‚úÖ POST /query - Enhanced document query")
print("  ‚úÖ GET /chat/history/{session_id} - Get chat history")
print("  ‚úÖ POST /chat/clear/{session_id} - Clear chat history")
print("  ‚úÖ GET /status - Detailed status information")
print("  ‚úÖ GET /docs - Auto-generated API documentation")
print("  ‚úÖ GET /redoc - Alternative API documentation")
print(f"‚úÖ Enhanced FastAPI server ready with {MODEL_INFO['name']} and chat history support!")

In [None]:
# Test the RAG system
print("ÔøΩÔøΩ Testing RAG system with sample question...")

test_question = "What are the symptoms of diabetes?"
try:
    print(f"üîç Testing query: {test_question}")
    # Use NEW modular controller
    result = await rag_controller.process_query(test_question)
    
    print("\n" + "=" * 80)
    print(f"‚ùì QUESTION: {result['query']}")
    print("=" * 80)
    
    print(f"\nÔøΩÔøΩ AI RESPONSE:")
    print(f"{result['response']}")
    
    print(f"\nÔøΩÔøΩ SOURCES ({result['metadata']['documentsUsed']} documents):")
    if result['sources']:
        for i, source in enumerate(result['sources'], 1):
            print(f"   {i}. {source['title']} - {source['source']}")
            print(f"      üìä Similarity: {source['similarity']}")
            print(f"      ÔøΩÔøΩ Preview: {source['content_preview']}")
            print()
    else:
        print("   ‚ö†Ô∏è No sources found - this could indicate:")
        print("   ‚Ä¢ No documents in database yet")
        print("   ‚Ä¢ Similarity threshold too high")
        print("   ‚Ä¢ RPC function needs adjustment")
    
    print(f"\nüìä Metadata:")
    print(f"   üîß Model: {result['metadata']['model']}")
    print(f"   üìÑ Documents Used: {result['metadata']['documentsUsed']}")
    print(f"   üéØ Total Found: {result['metadata']['totalFound']}")
    
    print("‚úÖ RAG system test completed!")
    
except Exception as e:
    print(f"‚ö†Ô∏è RAG test failed: {str(e)}")
    print("   This might be normal if:")
    print("   ‚Ä¢ Supabase connection needs adjustment")
    print("   ‚Ä¢ No documents have been embedded yet")
    print("   ‚Ä¢ RPC function is not deployed")
    print("   The FastAPI server will still start and you can test via the API")

In [None]:
# GPU Memory Management Utilities
def get_gpu_memory_info():
    """Summarize memory usage of the current CUDA device.

    Returns:
        The string ``"No GPU available"`` when CUDA is unavailable; otherwise a
        dict of formatted strings with the keys ``allocated``, ``reserved``,
        ``total``, ``free`` (total minus allocated) and ``percentage_used``
        (allocated as a share of total).
    """
    if torch.cuda.is_available():
        bytes_per_gb = 1024**3
        dev = torch.cuda.current_device()
        used = torch.cuda.memory_allocated(dev) / bytes_per_gb
        cached = torch.cuda.memory_reserved(dev) / bytes_per_gb
        capacity = torch.cuda.get_device_properties(dev).total_memory / bytes_per_gb
        remaining = capacity - used

        report = {}
        report["allocated"] = f"{used:.2f} GB"
        report["reserved"] = f"{cached:.2f} GB"
        report["total"] = f"{capacity:.2f} GB"
        report["free"] = f"{remaining:.2f} GB"
        report["percentage_used"] = f"{(used/capacity)*100:.1f}%"
        return report

    return "No GPU available"

def clear_gpu_memory():
    """Release PyTorch's cached CUDA memory, or report that no GPU is present."""
    if not torch.cuda.is_available():
        print("üì± No GPU to clear")
        return
    torch.cuda.empty_cache()
    print("üßπ GPU cache cleared")

# Print the current GPU memory snapshot: one line per field when a GPU is
# present, otherwise the plain message returned by get_gpu_memory_info().
print("üîç Current GPU Memory Status:")
memory_info = get_gpu_memory_info()
if not isinstance(memory_info, dict):
    print(f"   {memory_info}")
else:
    for key, value in memory_info.items():
        print(f"   {key}: {value}")

# Usage reminder for the two helpers defined in this cell
print(
    "‚úÖ GPU memory utilities ready!",
    "   Use get_gpu_memory_info() to check memory",
    "   Use clear_gpu_memory() to free cache",
    sep="\n",
)


In [None]:
# ngrok setup and FastAPI server startup
from pyngrok import ngrok

# Quick pre-check to give immediate feedback
print("üîç Pre-flight check...")
required_vars = ['medical_model', 'tokenizer', 'supabase', 'rag_system', 'CONFIG', 'MODEL_INFO']
missing_vars = [var for var in required_vars if var not in globals()]

if missing_vars:
    print(f"‚ùå Missing required variables: {', '.join(missing_vars)}")
    print(f"üîß Please run all previous cells (1-7) in order first!")
    print(f"   Then come back to this cell.")
    raise RuntimeError(f"Required setup incomplete. Missing: {', '.join(missing_vars)}")

print("‚úÖ Pre-flight check passed!")
print(f"üéØ Selected model: {MODEL_INFO['name']}")

print("üîë Using ngrok auth token from Colab secrets...")

# Use the token we already retrieved in Cell 3
if 'ngrok_token' not in globals() or not ngrok_token:
    print("‚ùå Ngrok token not found in secrets!")
    print("üîß Make sure you've added NGROK_AUTH_TOKEN to Colab secrets")
    raise ValueError("Missing NGROK_AUTH_TOKEN in Colab secrets")

ngrok.set_auth_token(ngrok_token)
print("‚úÖ Ngrok token set successfully!")

# Start ngrok tunnel
print("üåê Starting ngrok tunnel...")
public_url = ngrok.connect(8000)  # Changed from 5000 to 8000 for FastAPI
print(f"üåç Public URL: {public_url}")
print("üìã Copy this URL to your WellnessGrid app configuration!")

# Start FastAPI app with uvicorn
print("üöÄ Starting FastAPI app with uvicorn...")
print("üì° Available endpoints:")
print("  ‚úÖ POST /embed - Generate embeddings (required by WellnessGrid)")
print(f"  ‚úÖ POST /generate - Generate text with {MODEL_INFO['name']} (required by WellnessGrid)")
print("  ‚úÖ POST /ask - Main RAG endpoint for WellnessGrid")
print("  ‚úÖ GET /health - Health check")
print("  ‚úÖ POST /query - Query documents from Supabase")
print("  ‚úÖ GET /docs - Interactive API documentation")
print("  ‚úÖ GET /redoc - Alternative API documentation")
print("\nüéØ IMPORTANT: Copy the ngrok URL above to your WellnessGrid .env.local:")
print("   FLASK_API_URL=https://your-ngrok-id.ngrok.io")
print("\n‚ö†Ô∏è  Keep this cell running to maintain the server!")
print(f"\nüöÄ Your WellnessGrid RAG system with {MODEL_INFO['name']} is now live!")

# Run the FastAPI app with uvicorn
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000, log_level="info")

In [None]:
# Quick test of the FastAPI endpoints (optional)
# Run this AFTER starting the FastAPI server (uvicorn on port 8000) in Cell 8

import requests
import json

def test_flask_endpoints(base_url: str = "http://localhost:8000"):
    """Smoke-test /health, /embed and /generate of the running API server.

    The name is historical: the server is FastAPI served by uvicorn. The
    default ``base_url`` matches the port the server cell binds (8000); the
    previous hard-coded port 5000 could never reach it.

    Args:
        base_url: Root URL of the server, without a trailing slash. Pass the
            public ngrok URL to test through the tunnel.

    Returns:
        None. Results are printed. If the health check cannot be performed at
        all, the remaining tests are skipped.
    """
    print("üß™ Testing Flask endpoints locally...")
    print("‚ö†Ô∏è Make sure Cell 8 (Flask server) is running first!\n")
    
    # Test 1: Health check
    try:
        print("1. Testing /health endpoint...")
        response = requests.get(f"{base_url}/health", timeout=5)
        if response.status_code == 200:
            print("‚úÖ Health check passed")
            print(f"   Response: {response.json()}")
        else:
            print(f"‚ùå Health check failed: {response.status_code}")
    except Exception as e:
        print(f"‚ùå Health check error: {str(e)}")
        print("   Make sure Flask server is running (Cell 8)")
        # Server unreachable: the remaining requests would fail the same way
        return
    
    # Test 2: Embedding endpoint
    try:
        print("\n2. Testing /embed endpoint...")
        test_data = {"text": "What is diabetes?"}
        response = requests.post(f"{base_url}/embed", 
                               json=test_data, 
                               headers={"Content-Type": "application/json"},
                               timeout=10)
        if response.status_code == 200:
            result = response.json()
            embedding = result.get('embedding', [])
            print(f"‚úÖ Embedding test passed")
            print(f"   Embedding dimensions: {len(embedding)}")
            print(f"   First 3 values: {embedding[:3]}")
        else:
            print(f"‚ùå Embedding test failed: {response.status_code}")
            print(f"   Error: {response.text}")
    except Exception as e:
        print(f"‚ùå Embedding test error: {str(e)}")
    
    # Test 3: Generation endpoint
    try:
        print("\n3. Testing /generate endpoint...")
        test_data = {
            "query": "What is diabetes?",
            "context": "Diabetes is a chronic condition affecting blood sugar levels.",
            "max_tokens": 50,
            "temperature": 0.7
        }
        response = requests.post(f"{base_url}/generate", 
                               json=test_data, 
                               headers={"Content-Type": "application/json"},
                               timeout=15)
        if response.status_code == 200:
            result = response.json()
            answer = result.get('answer', '')
            print(f"‚úÖ Generation test passed")
            print(f"   Answer length: {len(answer)} characters")
            print(f"   Answer preview: {answer[:100]}...")
        else:
            print(f"‚ùå Generation test failed: {response.status_code}")
            print(f"   Error: {response.text}")
    except Exception as e:
        print(f"‚ùå Generation test error: {str(e)}")
    
    print("\nüéØ Test completed!")
    print("If all tests pass, your Flask server is ready for WellnessGrid!")

# Uncomment the line below to run the test
# test_flask_endpoints()
