In [None]:
# Install required packages for Google Colab
%pip install transformers torch sentence-transformers --quiet
%pip install flask flask-cors pyngrok --quiet
%pip install supabase python-dotenv --quiet
%pip install sacremoses --quiet


In [None]:
import os
import json
import logging
import traceback
from datetime import datetime
from typing import List, Dict, Any, Optional
import torch
from getpass import getpass

# AI Models
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoProcessor, AutoModelForImageTextToText

# Supabase
from supabase import create_client

# Flask API
from flask import Flask, request, jsonify
from flask_cors import CORS

# Colab secrets
from google.colab import userdata

# Module-wide logging; the Flask route handlers further down log through `logger`.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Startup banner: import check, session start time and the compute device.
print("üì¶ All packages imported successfully!")
print("üïê RAG session started at: " + datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
print("üîß Using device: " + ("cuda" if torch.cuda.is_available() else "cpu"))

# Selectable Hugging Face checkpoints, keyed by the menu number typed by the user.
# "type" selects the loading/generation path used later:
#   "causal"     -> AutoTokenizer + AutoModelForCausalLM
#   "multimodal" -> AutoProcessor + AutoModelForImageTextToText (text-only use)
MEDICAL_MODELS = {
    "1": dict(
        name="OpenBioLLM-8B",
        path="aaditya/OpenBioLLM-8B",
        description="8B parameter medical LLM optimized for biomedical tasks",
        type="causal",
    ),
    "2": dict(
        name="Med42-v2-8B",
        path="m42-health/Llama3-Med42-8B",
        description="8B parameter medical model from M42 Health based on Llama3",
        type="causal",
    ),
    "3": dict(
        name="MedGemma-4B",
        path="google/medgemma-4b-it",
        description="4B parameter multimodal medical model from Google",
        type="multimodal",
    ),
}

# Show the menu, one "  <key>. <name> - <description>" line per model.
print("\nü§ñ Available Medical Models:")
print("\n".join(
    f"  {menu_key}. {spec['name']} - {spec['description']}"
    for menu_key, spec in MEDICAL_MODELS.items()
))

print("\nPlease select a model by entering the number (1, 2, or 3):")


In [None]:
# Model selection and loading.
# Prompts for a menu number, loads that checkpoint and defines the globals the
# later cells use: `tokenizer` (an AutoProcessor for multimodal models),
# `medical_model`, `device` and `MODEL_INFO` (keys: name, path, choice, type).
# If loading raises, microsoft/DialoGPT-medium is loaded instead.
print("ü§ñ Select your medical model:")
model_choice = input("Enter model number (1, 2, or 3): ").strip()

if model_choice not in MEDICAL_MODELS:
    print(f"‚ùå Invalid choice '{model_choice}'. Defaulting to OpenBioLLM-8B (option 1)")
    model_choice = "1"

selected_model = MEDICAL_MODELS[model_choice]
model_name = selected_model["name"]
model_path = selected_model["path"]
model_type = selected_model["type"]

print(f"üß† Loading {model_name} for medical text generation...")
print(f"üì¶ Model path: {model_path}")
print(f"üîß Model type: {model_type}")

# Resolved once, before loading, so the primary and fallback paths agree.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _ensure_pad_token(tok) -> None:
    """Give ``tok`` a pad token (its eos token, else its unk token) if it has none."""
    if hasattr(tok, "pad_token") and tok.pad_token is None:
        tok.pad_token = tok.eos_token if tok.eos_token is not None else tok.unk_token


try:
    if model_type == "multimodal":
        # Special handling for MedGemma (multimodal)
        print("üî§ Loading processor (for multimodal model)...")
        tokenizer = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)

        # The MedGemma model card loads the weights in bfloat16, and Gemma-family
        # models are prone to overflow in float16; use bf16 when the GPU
        # supports it and keep the previous float16 otherwise.
        use_bf16 = device.type == "cuda" and torch.cuda.is_bf16_supported()

        print("üß† Loading multimodal model (this may take a few minutes)...")
        print("Note: MedGemma will be used in text-only mode for this RAG system")
        medical_model = AutoModelForImageTextToText.from_pretrained(
            model_path,
            torch_dtype=torch.bfloat16 if use_bf16 else torch.float16,
            low_cpu_mem_usage=True,
            device_map="auto",
            trust_remote_code=True
        )
    else:
        # Standard causal LM handling
        print("üî§ Loading tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

        print("üß† Loading model (this may take a few minutes)...")
        medical_model = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
            device_map="auto",
            trust_remote_code=True
        )

    # Processors for multimodal models are left untouched; only causal
    # tokenizers need a pad token for generate().
    if model_type == "causal":
        _ensure_pad_token(tokenizer)

    # Store model info globally
    MODEL_INFO = {
        "name": model_name,
        "path": model_path,
        "choice": model_choice,
        "type": model_type
    }

    print(f"‚úÖ {model_name} loaded and ready on {device}")
    print(f"üéØ Selected model: {model_name}")

except Exception as e:
    print(f"‚ùå Error loading {model_name}: {str(e)}")
    print("üîÑ Falling back to a smaller model...")

    # Fallback to a more reliable model
    fallback_path = "microsoft/DialoGPT-medium"
    print(f"üîÑ Loading fallback model: {fallback_path}")

    tokenizer = AutoTokenizer.from_pretrained(fallback_path)
    # DialoGPT-medium is small enough to run in float32 on CPU, where float16
    # generation is slow or unsupported depending on the torch version.
    medical_model = AutoModelForCausalLM.from_pretrained(
        fallback_path,
        torch_dtype=torch.float16 if device.type == "cuda" else torch.float32,
        low_cpu_mem_usage=True,
        device_map="auto"
    )

    _ensure_pad_token(tokenizer)

    # "type" is required: generate_medical_response branches on
    # MODEL_INFO['type'], so omitting it broke every fallback generation.
    MODEL_INFO = {
        "name": "DialoGPT-medium (Fallback)",
        "path": fallback_path,
        "choice": "fallback",
        "type": "causal"
    }

    print("‚úÖ Fallback model loaded successfully")

# Supabase connection.
# Credentials are read from Colab's secret store (never written in the
# notebook). Retrieval and client-creation errors are reported and re-raised
# so the run stops here instead of failing later with a confusing error.
print("üóÑÔ∏è Setting up Supabase connection using Colab secrets...")
print("\n".join([
    "üìã Required Colab secrets:",
    "   1. SUPABASE_URL - Your project URL (e.g., https://abc123.supabase.co)",
    "   2. SUPABASE_SERVICE_ROLE_KEY - Your service role key (NOT anon key)",
    "   3. NGROK_AUTH_TOKEN - Your ngrok authentication token",
    "",
    "üîë To set these secrets:",
    "   1. Click the üîë key icon in the left sidebar",
    "   2. Add the three secrets listed above",
    "   3. Re-run this cell",
    "",
]))

try:
    supabase_url = userdata.get('SUPABASE_URL')
    supabase_key = userdata.get('SUPABASE_SERVICE_ROLE_KEY')
    ngrok_token = userdata.get('NGROK_AUTH_TOKEN')
    print("‚úÖ Successfully retrieved secrets from Colab")
except Exception as secret_error:
    print("\n".join([
        f"‚ùå Error retrieving secrets: {str(secret_error)}",
        "üîß Make sure you've added the required secrets in Colab:",
        "   ‚Ä¢ SUPABASE_URL",
        "   ‚Ä¢ SUPABASE_SERVICE_ROLE_KEY",
        "   ‚Ä¢ NGROK_AUTH_TOKEN",
    ]))
    raise

# Hard requirements: both values present and an https URL.
if not (supabase_url and supabase_key):
    raise ValueError("‚ùå Both Supabase URL and Service Role Key are required!")
if not supabase_url.startswith('https://'):
    raise ValueError("‚ùå Supabase URL should start with 'https://'")

# Soft heuristics (warn only): the expected service_role key is a long
# string starting with 'eyJ'; a mismatch hints that the anon key was pasted.
if not supabase_key.startswith('eyJ'):
    print("‚ö†Ô∏è WARNING: Service role keys typically start with 'eyJ'\n"
          "   You might be using the anon key instead of service_role key")
if len(supabase_key) < 100:
    print("‚ö†Ô∏è WARNING: Service role keys are typically very long (200+ characters)\n"
          "   You might be using the anon key instead of service_role key")

try:
    supabase = create_client(supabase_url, supabase_key)
    print("‚úÖ Supabase client initialized")
except Exception as client_error:
    print("\n".join([
        f"‚ùå Failed to initialize Supabase client: {str(client_error)}",
        "üîß Common issues:",
        "   ‚Ä¢ Wrong API key type (use service_role, not anon)",
        "   ‚Ä¢ Typo in URL or key",
        "   ‚Ä¢ Key might be expired or regenerated",
    ]))
    raise

# Retrieval / generation settings shared by the RAG class and the API routes.
CONFIG = dict(
    top_k=5,                   # documents requested from the vector search
    similarity_threshold=0.5,  # sent as match_threshold to search_embeddings
    max_context_length=2000,   # character budget for the assembled context
    max_response_length=150,   # max_new_tokens passed to generation
)

print("\n".join([
    "\n‚öôÔ∏è RAG Configuration:",
    f"   üéØ Retrieve top {CONFIG['top_k']} similar documents",
    f"   üìä Similarity threshold: {CONFIG['similarity_threshold']}",
    f"   üìè Max context length: {CONFIG['max_context_length']} chars",
]))


In [None]:
# Supabase document retrieval functions
def query_supabase_documents(query: str, top_k: int = None) -> List[Dict[str, Any]]:
    """Retrieve the documents most similar to ``query`` from Supabase.

    Embeds ``query`` with PubMedBERT (the model used when indexing) and calls
    the ``search_embeddings`` RPC with ``CONFIG['similarity_threshold']``. If
    that path raises, falls back to an unranked ``select`` on the
    ``medical_documents`` table.

    Args:
        query: Free-text question to search for.
        top_k: Maximum number of documents; ``None`` (or 0) means
            ``CONFIG['top_k']``.

    Returns:
        A list of dicts with ``content``, ``similarity_score``, ``metadata``,
        ``rank`` and ``doc_id`` keys, in RPC order. Always a list (possibly
        empty), never ``None``.
    """
    # Resolved before the try so the fallback never receives top_k=None
    # (previously an early failure, e.g. while loading the model, left it None).
    top_k = top_k or CONFIG['top_k']

    try:
        print(f"üîç Loading embedding model for query: {query[:50]}...")
        # Same embedding model used for indexing. Cached on the function object
        # after first use: constructing a SentenceTransformer on every query
        # re-reads the model weights each time.
        embedding_model = getattr(query_supabase_documents, "_embedding_model", None)
        if embedding_model is None:
            from sentence_transformers import SentenceTransformer
            embedding_model = SentenceTransformer('NeuML/pubmedbert-base-embeddings')
            query_supabase_documents._embedding_model = embedding_model

        print("üß† Generating embedding vector...")
        query_embedding = embedding_model.encode([query])[0].tolist()

        # Use the correct RPC function from schema.sql: search_embeddings
        print(f"üîç Searching embeddings with threshold {CONFIG['similarity_threshold']}...")
        result = supabase.rpc('search_embeddings', {
            'query_embedding': query_embedding,
            'match_threshold': CONFIG['similarity_threshold'],
            'match_count': top_k
        }).execute()

        if not result.data:
            print("‚ö†Ô∏è No similar documents found in Supabase")
            return []

        documents = [
            {
                'content': doc.get('chunk_content', ''),  # Field name returned by the RPC
                'similarity_score': doc.get('similarity', 0.0),
                'metadata': {
                    'title': doc.get('title', 'Medical Document'),
                    'source': doc.get('source', 'unknown'),
                    'topic': doc.get('topic', 'general'),
                    'document_type': doc.get('document_type', 'unknown'),
                    'document_id': doc.get('document_id', '')
                },
                'rank': rank,
                'doc_id': doc.get('document_id', '')
            }
            for rank, doc in enumerate(result.data, start=1)
        ]
        print(f"üìä Found {len(documents)} similar documents from Supabase")
        return documents

    except Exception as e:
        print(f"‚ùå Error querying Supabase: {str(e)}")

    # Fallback (reached only if the vector search raised): plain table query,
    # e.g. when the RPC function is not deployed. Results are not ranked by
    # relevance; 0.8 is a placeholder similarity, not a measured score.
    try:
        print("üîÑ Trying fallback query method...")
        result = supabase.table('medical_documents').select('*').limit(top_k).execute()

        documents = [
            {
                'content': doc.get('content', ''),
                'similarity_score': 0.8,  # Default similarity
                'metadata': {
                    'title': doc.get('title', 'Medical Document'),
                    'source': doc.get('source', 'unknown'),
                    'topic': doc.get('topic', 'general'),
                    'document_type': doc.get('document_type', 'unknown'),
                    'document_id': doc.get('id', '')
                },
                'rank': rank,
                'doc_id': doc.get('id', '')
            }
            for rank, doc in enumerate((result.data or [])[:top_k], start=1)
        ]
        if documents:
            print(f"üìä Fallback: Retrieved {len(documents)} documents from Supabase")
        # Previously an empty fallback result fell off the end and returned None.
        return documents

    except Exception as fallback_error:
        print(f"‚ùå Fallback query also failed: {str(fallback_error)}")
        return []

# Test Supabase connection and RPC functions.
# Prints the medical_documents row count and probes the get_document_stats
# RPC. Failures only print guidance (chosen by matching the error text);
# the notebook continues either way.
print("üß™ Testing Supabase connection...")
try:
    # count='exact' makes PostgREST return the table's row total in
    # response.count, and only one row is fetched. The former select('count')
    # gave either a single aggregate row or an unknown-column error, so
    # len(data) never reflected the table size.
    test_result = (
        supabase.table('medical_documents')
        .select('*', count='exact')
        .limit(1)
        .execute()
    )
    doc_count = test_result.count if test_result.count is not None else len(test_result.data or [])
    print(f"‚úÖ Supabase connected - Found {doc_count} documents in database")

    # Test RPC function availability
    print("üß™ Testing RPC functions...")
    try:
        stats_result = supabase.rpc('get_document_stats').execute()
        if stats_result.data:
            print("‚úÖ RPC functions working")
            for stat in stats_result.data[:3]:  # Show first 3 document sources
                # .get(): the row shape is defined in SQL, not here; a missing
                # column should not be reported as an RPC failure.
                print(f"   üìÑ {stat.get('source', 'unknown')}: {stat.get('count', '?')} documents")
        else:
            print("‚ö†Ô∏è RPC function exists but returned no data")
    except Exception as rpc_error:
        print(f"‚ö†Ô∏è RPC function test failed: {str(rpc_error)}")
        print("   Vector search will use fallback method")

except Exception as e:
    error_str = str(e)
    print(f"‚ö†Ô∏è Supabase connection test failed: {error_str}")

    # Provide specific guidance based on error type
    if '401' in error_str or 'Invalid API key' in error_str:
        print("üîß AUTHENTICATION ERROR - Invalid API Key:")
        print("   ‚ùå You're using the wrong API key!")
        print("   üìã To fix this:")
        print("   1. Go to your Supabase project dashboard")
        print("   2. Settings ‚Üí API")
        print("   3. Copy the 'service_role' key (NOT anon key)")
        print("   4. The service_role key is much longer and starts with 'eyJ'")
        print("   5. Re-run Cell 3 with the correct key")
        print("")
        print("   üîç Key differences:")
        print("   ‚Ä¢ anon key: Used for client-side apps (WRONG for this notebook)")
        print("   ‚Ä¢ service_role key: Used for server-side/admin access (CORRECT)")
    elif '404' in error_str:
        print("üîß TABLE NOT FOUND:")
        print("   ‚ùå The 'medical_documents' table doesn't exist!")
        print("   üìã To fix this:")
        print("   1. Run the schema.sql in your Supabase SQL editor")
        print("   2. Or run the embed_documents.py script to create tables")
    elif 'timeout' in error_str.lower():
        print("üîß CONNECTION TIMEOUT:")
        print("   ‚ùå Can't reach Supabase servers")
        print("   üìã Check your internet connection and Supabase URL")
    else:
        print("üîß GENERAL CONNECTION ERROR:")
        print("   üìã Common fixes:")
        print("   ‚Ä¢ Double-check your Supabase URL")
        print("   ‚Ä¢ Verify you're using service_role key (not anon)")
        print("   ‚Ä¢ Check if your project is paused/suspended")
        print("   ‚Ä¢ Ensure database tables exist")

    print("\n   ‚ö†Ô∏è The system will continue but may have limited document retrieval")


In [None]:
def generate_medical_response(
    prompt: str,
    max_new_tokens: int = 150,
    temperature: float = 0.7,
    max_input_tokens: int = 2048,
) -> str:
    """Generate an answer for ``prompt`` with the globally loaded medical model.

    Uses the module-level ``tokenizer``, ``medical_model`` and ``MODEL_INFO``
    set by the model-loading cell. Multimodal models (MedGemma) receive the
    prompt inside a system + user chat template; causal models receive the raw
    prompt. Only newly generated tokens are decoded.

    Args:
        prompt: Complete prompt text (context plus question).
        max_new_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature (sampling is always on, top_p=0.9).
        max_input_tokens: Causal models only; a longer prompt keeps its last
            ``max_input_tokens`` tokens.

    Returns:
        The cleaned response text. Outputs shorter than 10 characters are
        replaced by a generic "consult a healthcare professional" message, and
        an apology string is returned (after printing the error) if anything
        raises.
    """
    try:
        print(f"ü§ñ Generating response using {MODEL_INFO['name']}...")
        # .get(): a MODEL_INFO built by the DialoGPT fallback may lack 'type';
        # treat such a model as a plain causal LM instead of raising KeyError.
        model_type = MODEL_INFO.get('type', 'causal')

        if model_type == 'multimodal':
            # For MedGemma - use chat template format
            messages = [
                {
                    "role": "system",
                    "content": [{"type": "text", "text": "You are a helpful medical assistant. Provide accurate medical information based on the given context."}]
                },
                {
                    "role": "user",
                    "content": [{"type": "text", "text": prompt}]
                }
            ]
            # Cast floating inputs to the model's own dtype rather than a
            # hard-coded bfloat16, so this matches however the weights were
            # loaded; integer token ids are only moved, not cast.
            inputs = tokenizer.apply_chat_template(
                messages, add_generation_prompt=True, tokenize=True,
                return_dict=True, return_tensors="pt"
            ).to(medical_model.device, dtype=medical_model.dtype)
        else:
            encoded = tokenizer(prompt, return_tensors="pt")
            # Keep the *last* max_input_tokens tokens. The prompts built in this
            # notebook end with the question and an "Answer" cue, which the
            # tokenizer's default right-side truncation would cut first. Only
            # over-long prompts are affected (they lose their oldest tokens).
            inputs = {
                name: tensor[:, -max_input_tokens:].to(medical_model.device)
                for name, tensor in encoded.items()
            }

        input_len = inputs["input_ids"].shape[-1]

        # Generation parameters optimized for medical models
        generation_params = {
            "max_new_tokens": max_new_tokens,
            "temperature": temperature,
            "do_sample": True,
            "repetition_penalty": 1.1,
            "top_p": 0.9
        }

        # Explicit pad/eos ids for causal models only (processors are skipped).
        if model_type == 'causal':
            pad_id = getattr(tokenizer, 'pad_token_id', None)
            eos_id = getattr(tokenizer, 'eos_token_id', None)
            if pad_id is not None:
                generation_params["pad_token_id"] = pad_id
            if eos_id is not None:
                generation_params["eos_token_id"] = eos_id

        with torch.no_grad():
            outputs = medical_model.generate(**inputs, **generation_params)

        # Decode only the tokens generated after the prompt, then strip any
        # leftover sentence markers.
        response = tokenizer.decode(outputs[0][input_len:], skip_special_tokens=True)
        response = response.replace("</s>", "").replace("<s>", "").strip()

        # Basic quality check
        if not response or len(response) < 10:
            response = "I understand your question about health. Please consult with a healthcare professional for personalized medical advice."

        print(f"‚úÖ Response generated successfully ({len(response)} characters)")
        return response

    except Exception as e:
        print(f"‚ùå Generation error: {str(e)}")
        traceback.print_exc()
        return "I apologize, but I encountered an error processing your question. Please try rephrasing your question or consult with a healthcare professional."

class WellnessRAGSystem:
    """Retrieval-augmented QA for medical/wellness questions.

    Retrieval goes through ``query_supabase_documents`` (Supabase vector
    search). Generation goes through ``generate_medical_response`` with the
    model described by ``MODEL_INFO``, i.e. whichever model was selected at
    the top of the notebook.

    Args:
        config: Mapping read for ``top_k``, ``max_context_length`` (character
            budget for the assembled context) and ``max_response_length``
            (max new tokens to generate).
    """

    def __init__(self, config):
        self.config = config

    def retrieve_context(self, query: str) -> Dict[str, Any]:
        """Fetch documents for ``query`` and pack them into one context string.

        Documents are taken in retrieval order. Each is rendered as a
        "Source: <source>" line followed by its content, and documents are
        separated by a blank line. Packing stops at the first document that
        would push the full context string past
        ``config['max_context_length']`` characters.

        Returns:
            Dict with the query, the context, every retrieved document, the
            number found, the number actually placed in the context, and the
            context length in characters.
        """
        retrieved_docs = query_supabase_documents(query, self.config['top_k'])
        budget = self.config['max_context_length']
        separator = "\n\n"

        context_parts: List[str] = []
        used_chars = 0
        for doc in retrieved_docs:
            part = f"Source: {doc['metadata']['source']}\n{doc['content']}"
            # Count the header and separator too, so the budget bounds the
            # real context length (previously only the raw content counted).
            cost = len(part) + (len(separator) if context_parts else 0)
            if used_chars + cost > budget:
                # Docs arrive best-first; stop instead of skipping so a
                # lower-ranked doc never displaces a higher-ranked one.
                break
            context_parts.append(part)
            used_chars += cost

        context = separator.join(context_parts)

        return {
            'query': query,
            'context': context,
            'retrieved_documents': retrieved_docs,
            'total_documents_found': len(retrieved_docs),
            # Documents that actually made it into the context (previously
            # this repeated the number retrieved).
            'documents_used': len(context_parts),
            'context_length': len(context)
        }

    def query(self, question: str) -> Dict[str, Any]:
        """Answer ``question``: retrieve context, prompt the model, package the result.

        Returns:
            Dict with ``query``, ``response``, ``sources`` (one entry per
            retrieved document, including ones that did not fit the context)
            and ``metadata``.
        """
        print(f"üîç Processing query: {question}")

        context_result = self.retrieve_context(question)

        print(f"üìä Found {context_result['total_documents_found']} similar documents")
        print(f"üìÑ Using {context_result['documents_used']} documents for context")

        print(f"ü§ñ Generating response using {MODEL_INFO['name']}...")
        # Use the advanced medical prompt structure
        medical_prompt = f"""You are a helpful and accurate medical assistant. Use the following context to answer the question.

Context:
{context_result['context']}

Question: {question}

Answer (based only on the context, no assumptions):"""
        generated_response = generate_medical_response(medical_prompt, self.config['max_response_length'])
        print("‚úÖ Response generated successfully")

        def preview(text: str, limit: int = 150) -> str:
            # Ellipsis only when the content was actually cut.
            return text[:limit] + ("..." if len(text) > limit else "")

        return {
            'query': question,
            'response': generated_response,
            'sources': [
                {
                    'title': doc['metadata'].get('title', 'Medical Document'),
                    'source': doc['metadata']['source'],
                    'topic': doc['metadata']['topic'],
                    'similarity': f"{doc['similarity_score']:.3f}",
                    'rank': doc['rank'],
                    'content_preview': preview(doc['content'])
                }
                for doc in context_result['retrieved_documents']
            ],
            'metadata': {
                'documentsUsed': context_result['documents_used'],
                'totalFound': context_result['total_documents_found'],
                'contextLength': context_result['context_length'],
                'model': MODEL_INFO['name'],
                'model_path': MODEL_INFO['path'],
                'embeddings': 'Supabase pgvector',
                # Completion timestamp (ISO 8601), not a duration, despite the key name.
                'processingTime': datetime.now().isoformat()
            }
        }

# Initialize the RAG system
# One shared instance, used by the sample-question cell and by the /ask route.
rag_system = WellnessRAGSystem(config=CONFIG)
print("‚úÖ WellnessGrid RAG system initialized!")


In [None]:
# Enhanced Flask API setup with chat history support.
# `app` hosts every route defined in this cell.
app = Flask(__name__)
# CORS(app) with no options; flask-cors then allows any origin.
# NOTE(review): confirm that is intended for a publicly tunnelled server.
CORS(app)

# Store chat sessions in memory (in production, use Redis or database).
# session_id -> list of {"question": ..., "answer": ...} dicts appended by
# /ask; contents are lost when the kernel restarts.
chat_sessions: Dict[Any, List[Dict[str, Any]]] = {}

@app.route('/embed', methods=['POST'])
def generate_embedding():
    """POST /embed: embed ``{"text": ...}`` with PubMedBERT.

    Returns ``{"embedding": [...], "dimensions": int, "model": "PubMedBERT"}``.
    Responds 400 when ``text`` is missing, empty or not a string (including a
    missing or non-JSON body), and 500 on model failure.
    """
    try:
        # silent=True: a missing/invalid JSON body becomes None and is answered
        # with a 400 below instead of raising and surfacing as a 500.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            data = {}
        text = data.get("text", "")

        if not isinstance(text, str) or not text:
            return jsonify({"error": "Missing 'text' field."}), 400

        logger.info(f"üîç Generating embedding for text: {text[:100]}...")

        # Load PubMedBERT once and reuse it across requests; constructing a
        # SentenceTransformer per request re-reads the weights every time.
        embedding_model = getattr(generate_embedding, "_embedding_model", None)
        if embedding_model is None:
            from sentence_transformers import SentenceTransformer
            embedding_model = SentenceTransformer('NeuML/pubmedbert-base-embeddings')
            generate_embedding._embedding_model = embedding_model

        # Generate embedding
        embedding = embedding_model.encode([text])[0].tolist()

        logger.info(f"‚úÖ Generated embedding with {len(embedding)} dimensions")

        return jsonify({
            "embedding": embedding,
            "dimensions": len(embedding),
            "model": "PubMedBERT"
        })

    except Exception as e:
        logger.error(f"‚ùå Error in embed endpoint: {str(e)}")
        logger.error(traceback.format_exc())
        return jsonify({"error": f"Embedding generation failed: {str(e)}"}), 500

@app.route('/generate', methods=['POST'])
def generate_text():
    """POST /generate: answer ``query`` from caller-supplied ``context``.

    JSON body:
        query (str, required); context (str); history (list of
        ``{"question", "answer"}`` dicts, of which the last 3 are used);
        max_tokens (positive int, default 200). A ``temperature`` field is
        accepted but currently ignored: generation uses
        generate_medical_response's own default.

    Returns ``{"answer", "model", "model_path", "context_used", "history_used"}``.
    Responds 400 for a missing query or invalid max_tokens, 500 on failure.
    """
    try:
        # silent=True: a missing/invalid JSON body is treated as empty -> 400.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            data = {}
        query = data.get("query", "")
        context = data.get("context") or ""
        history = data.get("history") or []  # chat history support
        if not isinstance(history, list):
            history = []

        if not isinstance(query, str) or not query:
            return jsonify({"error": "Missing 'query' field."}), 400

        try:
            max_tokens = int(data.get("max_tokens", 200))
        except (TypeError, ValueError):
            max_tokens = 0
        if max_tokens < 1:
            return jsonify({"error": "'max_tokens' must be a positive integer."}), 400

        logger.info(f"üî¨ Generating response for query: {query[:100]}...")
        logger.info(f"üìö Context length: {len(context)} characters")
        logger.info(f"üí¨ Chat history: {len(history)} messages")

        # Create enhanced prompt with chat history
        if history:
            # Last 3 exchanges; entries that are not objects are skipped
            # rather than crashing on .get().
            recent = [h for h in history if isinstance(h, dict)][-3:]
            history_context = "\n".join([f"Human: {h.get('question', '')}\nAssistant: {h.get('answer', '')}" for h in recent])
            prompt = f"""Previous conversation:
{history_context}

Context:
{context}

Current question: {query}

Answer based on the context and conversation history:"""
        else:
            prompt = f"""You are a helpful medical assistant. Use the following context to answer the question.

Context:
{context}

Question: {query}

Answer (based only on the context):"""

        # Generate response using the selected medical model
        response = generate_medical_response(prompt, max_tokens)

        return jsonify({
            "answer": response,
            "model": MODEL_INFO['name'],
            "model_path": MODEL_INFO['path'],
            "context_used": len(context) > 0,
            "history_used": len(history) > 0
        })

    except Exception as e:
        logger.error(f"‚ùå Error in generate endpoint: {str(e)}")
        return jsonify({"error": f"Text generation failed: {str(e)}"}), 500

@app.route('/query', methods=['POST'])
def query_docs():
    """POST /query: vector search only, no generation.

    JSON body: ``query`` (str, required) and ``top_k`` (integer, default
    ``CONFIG['top_k']``; 0 also means the default). Returns
    ``{"documents", "total_found", "query"}``. Responds 400 for a missing
    query or an invalid/negative top_k, 500 on failure.
    """
    try:
        # silent=True: a missing/invalid JSON body is treated as empty -> 400.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            data = {}
        query = data.get("query", "")

        if not isinstance(query, str) or not query:
            return jsonify({"error": "Missing 'query' field."}), 400

        # Coerce here so a bad value is a client error, not an RPC failure
        # that silently falls through to the fallback search.
        try:
            top_k = int(data.get("top_k", CONFIG['top_k']))
        except (TypeError, ValueError):
            top_k = -1
        if top_k < 0:
            return jsonify({"error": "'top_k' must be a non-negative integer."}), 400

        results = query_supabase_documents(query, top_k=top_k)

        return jsonify({
            "documents": results,
            "total_found": len(results),
            "query": query
        })

    except Exception as e:
        logger.error(f"‚ùå Error in query endpoint: {str(e)}")
        return jsonify({"error": f"Document query failed: {str(e)}"}), 500

@app.route('/ask', methods=['POST'])
def ask_rag():
    """POST /ask: full RAG answer for ``question``, recorded under ``session_id``.

    JSON body: ``question`` (str, required) and ``session_id`` (default
    "default"). A ``history`` field is accepted but not used, because
    rag_system.query takes only the question, so answers do not depend on
    earlier turns.

    Returns the answer, its sources, the session's last 10 exchanges and the
    RAG metadata. On failure, responds 500 with ``mockMode: True``.
    """
    try:
        # silent=True: a missing/invalid JSON body is treated as empty -> 400.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            data = {}
        question = data.get("question", "")
        # Session ids become dict keys and URL path segments; str() keeps a
        # JSON list from raising TypeError and makes 5 and "5" the same session.
        session_id = str(data.get("session_id") or "default")

        if not isinstance(question, str) or not question:
            return jsonify({"error": "Missing 'question' field."}), 400

        logger.info(f"ü§ñ Processing RAG query: {question[:100]}...")
        logger.info(f"üìù Session ID: {session_id}")

        # Get response using existing query method
        result = rag_system.query(question)

        # Update chat history, keeping only the last 10 exchanges so memory
        # stays bounded.
        session_history = chat_sessions.setdefault(session_id, [])
        session_history.append({"question": question, "answer": result['response']})
        del session_history[:-10]

        return jsonify({
            "response": result['response'],
            "sources": [
                {
                    "title": source['title'],
                    "content": source['content_preview'],
                    "similarity": float(source['similarity'])
                }
                for source in result['sources']
            ],
            "chat_history": session_history,
            "session_id": session_id,
            "mockMode": False,
            "metadata": result['metadata']
        })

    except Exception as e:
        logger.error(f"‚ùå Error in ask endpoint: {str(e)}")
        logger.error(traceback.format_exc())
        return jsonify({
            "response": f"I apologize, but I encountered an error processing your question: {str(e)}",
            "sources": [],
            "chat_history": [],
            "mockMode": True,
            "error": str(e)
        }), 500

@app.route('/chat/history/<session_id>', methods=['GET'])
def get_chat_history(session_id):
    """GET /chat/history/<session_id>: the stored exchanges for one session.

    An unknown id yields an empty history (not a 404).
    """
    try:
        stored = chat_sessions[session_id] if session_id in chat_sessions else []
        return jsonify(
            session_id=session_id,
            history=stored,
            message_count=len(stored),
        )
    except Exception as e:
        return jsonify({"error": str(e)}), 500

@app.route('/chat/clear/<session_id>', methods=['POST'])
def clear_chat_history(session_id):
    """POST /chat/clear/<session_id>: drop a session's exchanges (no-op if unknown).

    Only `chat_sessions` is touched; the RAG system keeps no history of its own.
    """
    try:
        chat_sessions.pop(session_id, None)
        return jsonify({"session_id": session_id, "cleared": True})
    except Exception as e:
        return jsonify({"error": str(e)}), 500

@app.route('/health', methods=['GET'])
def health_check():
    """GET /health: model identity plus live Supabase row counts.

    Always HTTP 200. Status is "healthy" when both count queries succeed;
    otherwise it is "partial", with the error text in "warning".
    """
    def exact_count(table_name: str) -> int:
        # count='exact' makes PostgREST report the row total in response.count
        # while fetching a single row. select('count') + len(data) never gave
        # the table size (one aggregate row, or an unknown-column error).
        response = (
            supabase.table(table_name)
            .select('*', count='exact')
            .limit(1)
            .execute()
        )
        return response.count if response.count is not None else len(response.data or [])

    try:
        doc_count = exact_count('medical_documents')
        embed_count = exact_count('document_embeddings')

        return jsonify({
            "status": "healthy",
            "model": MODEL_INFO['name'],
            "model_path": MODEL_INFO['path'],
            "database": "Supabase + pgvector",
            "documents_in_db": doc_count,
            "embeddings_in_db": embed_count,
            "rag_system": "enhanced",
            "chat_support": True,
            "active_sessions": len(chat_sessions)
        })
    except Exception as e:
        # Same keys as the healthy response (values unknown where the DB failed).
        return jsonify({
            "status": "partial",
            "model": MODEL_INFO['name'],
            "model_path": MODEL_INFO['path'],
            "database": "Supabase (connection issues)",
            "documents_in_db": "unknown",
            "embeddings_in_db": "unknown",
            "rag_system": "enhanced",
            "chat_support": True,
            "active_sessions": len(chat_sessions),
            "warning": str(e)
        })

@app.route('/status', methods=['GET'])
def status():
    """GET /status: diagnostic snapshot (time, model, DB URL prefix, CONFIG, sessions).

    NOTE(review): "connected" is always True because no database call is made
    here. "session_ids" lists every active id, and /chat/history/<id> needs
    nothing but the id, so anyone who can reach this endpoint can read those
    transcripts. Consider removing the ids before sharing the public URL.
    """
    try:
        module_globals = globals()
        timestamp = datetime.now().isoformat()
        if 'MODEL_INFO' in module_globals:
            selected_model_name = MODEL_INFO['name']
        else:
            selected_model_name = "not_selected"
        model_state = "loaded" if 'medical_model' in module_globals else "not_loaded"
        # Only a prefix of the URL is shown.
        masked_url = (supabase_url[:30] + "...") if supabase_url else "not_set"

        payload = {
            "timestamp": timestamp,
            "models": {
                "selected_model": selected_model_name,
                "medical_model": model_state,
                "pubmedbert": "available",
            },
            "database": {"connected": True, "url": masked_url},
            "config": CONFIG,
            "memory": {
                "active_sessions": len(chat_sessions),
                "session_ids": list(chat_sessions.keys()),
            },
        }
        return jsonify(payload)
    except Exception as e:
        return jsonify({"error": str(e)}), 500

print("üåê Enhanced Flask API endpoints configured:")
print("  ‚úÖ POST /embed - Generate embeddings with enhanced error handling")
print(f"  ‚úÖ POST /generate - Generate text with {MODEL_INFO['name']} + chat history")
print("  ‚úÖ POST /ask - Enhanced RAG endpoint with chat history")
print("  ‚úÖ GET /health - Enhanced health check")
print("  ‚úÖ POST /query - Enhanced document query")
print("  ‚úÖ GET /chat/history/<session_id> - Get chat history")
print("  ‚úÖ POST /chat/clear/<session_id> - Clear chat history")
print("  ‚úÖ GET /status - Detailed status information")
print(f"‚úÖ Enhanced Flask server ready with {MODEL_INFO['name']} and chat history support!")


In [None]:
# Test the RAG system
# Smoke test: one end-to-end rag_system.query() call, printing the answer, the
# returned sources and the metadata. Any exception is caught and reported, so
# the notebook still reaches the server cell.
print("üß™ Testing RAG system with sample question...")

test_question = "What are the symptoms of diabetes?"
try:
    print(f"üîç Testing query: {test_question}")
    result = rag_system.query(test_question)
    
    # Question header framed by 80-character rules.
    print("\n" + "=" * 80)
    print(f"‚ùì QUESTION: {result['query']}")
    print("=" * 80)
    
    print(f"\nü§ñ AI RESPONSE:")
    print(f"{result['response']}")
    
    # One numbered entry per item in result['sources'].
    print(f"\nüìö SOURCES ({result['metadata']['documentsUsed']} documents):")
    if result['sources']:
        for i, source in enumerate(result['sources'], 1):
            print(f"   {i}. {source['title']} - {source['source']}")
            print(f"      üìä Similarity: {source['similarity']}")
            print(f"      üìÑ Preview: {source['content_preview']}")
            print()
    else:
        print("   ‚ö†Ô∏è No sources found - this could indicate:")
        print("   ‚Ä¢ No documents in database yet")
        print("   ‚Ä¢ Similarity threshold too high")
        print("   ‚Ä¢ RPC function needs adjustment")
    
    print(f"\nüìä Metadata:")
    print(f"   üîß Model: {result['metadata']['model']}")
    print(f"   üíæ Embeddings: {result['metadata']['embeddings']}")
    print(f"   üìÑ Documents Used: {result['metadata']['documentsUsed']}")
    print(f"   üéØ Total Found: {result['metadata']['totalFound']}")
    
    print("‚úÖ RAG system test completed!")
    
except Exception as e:
    # Best-effort: report and continue rather than stopping the notebook.
    print(f"‚ö†Ô∏è RAG test failed: {str(e)}")
    print("   This might be normal if:")
    print("   ‚Ä¢ Supabase connection needs adjustment")
    print("   ‚Ä¢ No documents have been embedded yet")
    print("   ‚Ä¢ RPC function is not deployed")
    print("   The Flask server will still start and you can test via the API")


In [None]:
# ngrok setup and Flask server startup
from pyngrok import ngrok

# Local port Flask binds to; the ngrok tunnel forwards public traffic to this same port.
FLASK_PORT = 5000

# Quick pre-check to give immediate feedback.
# `app` is included because app.run() at the end of this cell needs it.
print("üîç Pre-flight check...")
required_vars = ['medical_model', 'tokenizer', 'supabase', 'rag_system', 'CONFIG', 'MODEL_INFO', 'app']
missing_vars = [var for var in required_vars if var not in globals()]

if missing_vars:
    print(f"‚ùå Missing required variables: {', '.join(missing_vars)}")
    print("üîß Please run all previous cells (1-7) in order first!")
    print("   Then come back to this cell.")
    raise RuntimeError(f"Required setup incomplete. Missing: {', '.join(missing_vars)}")

print("‚úÖ Pre-flight check passed!")
print(f"üéØ Selected model: {MODEL_INFO['name']}")

print("üîë Using ngrok auth token from Colab secrets...")

# Use the token already retrieved from Colab secrets in an earlier cell;
# globals().get(...) covers both "never defined" and "empty".
if not globals().get('ngrok_token'):
    print("‚ùå Ngrok token not found in secrets!")
    print("üîß Make sure you've added NGROK_AUTH_TOKEN to Colab secrets")
    raise ValueError("Missing NGROK_AUTH_TOKEN in Colab secrets")

ngrok.set_auth_token(ngrok_token)
print("‚úÖ Ngrok token set successfully!")

# Validate setup. Validation failures are reported with troubleshooting hints
# and the server is not started. Errors from the tunnel or the server itself
# are deliberately NOT caught here, so they are not mislabelled as setup problems.
print("üîç Validating setup...")
setup_ok = False

try:
    missing_components = []

    # `globals().get(name) is None` is true both when the name was never
    # defined and when it is bound to None.
    if globals().get('medical_model') is None:
        missing_components.append("Medical model (run Cell 3)")
    else:
        print(f"‚úÖ {MODEL_INFO['name']}: Loaded")

    if globals().get('tokenizer') is None:
        missing_components.append("Tokenizer (run Cell 3)")
    else:
        print("‚úÖ Tokenizer: Loaded")

    if globals().get('supabase') is None:
        missing_components.append("Supabase client (run Cell 3)")
    else:
        print("‚úÖ Supabase client: Initialized")

    if globals().get('rag_system') is None:
        missing_components.append("RAG system (run Cell 5)")
    else:
        print("‚úÖ RAG system: Initialized")

    if 'device' not in globals():
        missing_components.append("Device variable (run Cell 3)")
    else:
        print(f"‚úÖ Device: {device}")

    # If any components are missing, show helpful error message
    if missing_components:
        print("\n‚ùå Missing components:")
        for component in missing_components:
            print(f"   ‚Ä¢ {component}")
        print("\nüîß To fix this:")
        print("   1. Run all previous cells in order (Cells 1-7)")
        print("   2. Wait for each cell to complete before running the next")
        print("   3. Look for any error messages in the cell outputs")
        print("   4. Then run this cell again")
        raise ValueError(f"Missing required components: {', '.join(missing_components)}")

    # Report how many documents are stored. count='exact' asks PostgREST for the
    # total row count of the table (exposed as `.count` on the response), and
    # limit(1) keeps the returned payload tiny. len(data) would only count the
    # rows that came back, not the rows in the table.
    try:
        test_result = (
            supabase.table('medical_documents')
            .select('*', count='exact')
            .limit(1)
            .execute()
        )
        doc_count = test_result.count if test_result.count is not None else 'unknown'
        print(f"‚úÖ Supabase: Connected ({doc_count} documents)")
    except Exception as db_error:
        print(f"‚ö†Ô∏è Supabase: Connection issues ({str(db_error)})")
        print("   RAG system will still work but may have limited retrieval")

    setup_ok = True

except Exception as e:
    print(f"‚ùå Setup validation failed: {str(e)}")
    print("\nüîß Troubleshooting steps:")
    print("   1. ‚ö° Run ALL previous cells (1-7) in order")
    print("   2. üîç Check for errors in cell outputs")
    print("   3. üîÑ Restart kernel if needed: Runtime ‚Üí Restart Runtime")
    print("   4. üìã Re-run cells 1-7, then try this cell again")

if setup_ok:
    # Start ngrok tunnel. ngrok.connect returns an NgrokTunnel object; its
    # `.public_url` attribute is the URL string clients should use.
    print("üåê Starting ngrok tunnel...")
    public_url = ngrok.connect(FLASK_PORT)
    print(f"üåç Public URL: {public_url.public_url}")
    print("üìã Copy this URL to your WellnessGrid app configuration!")

    try:
        # Start Flask app
        print("üöÄ Starting Flask app...")
        print("üì° Available endpoints:")
        print("  ‚úÖ POST /embed - Generate embeddings (required by WellnessGrid)")
        print(f"  ‚úÖ POST /generate - Generate text with {MODEL_INFO['name']} (required by WellnessGrid)")
        print("  ‚úÖ POST /ask - Main RAG endpoint for WellnessGrid")
        print("  ‚úÖ GET /health - Health check")
        print("  ‚úÖ POST /query - Query documents from Supabase")
        print("\nüéØ IMPORTANT: Copy the ngrok URL above to your WellnessGrid .env.local:")
        print(f"   FLASK_API_URL={public_url.public_url}")
        print("\n‚ö†Ô∏è  Keep this cell running to maintain the server!")
        print(f"\nüöÄ Your WellnessGrid RAG system with {MODEL_INFO['name']} is now live!")

        # Blocks until the server stops (e.g. the cell is interrupted).
        app.run(host='0.0.0.0', port=FLASK_PORT, debug=False)
    finally:
        # Close the tunnel once the server stops, so a dead tunnel does not
        # outlive the server it points to across re-runs of this cell.
        try:
            ngrok.disconnect(public_url.public_url)
        except Exception as disconnect_error:
            print(f"‚ö†Ô∏è Could not close ngrok tunnel: {disconnect_error}")


In [None]:
# Quick Test of Flask Endpoints (Optional)
# Run this AFTER starting the Flask server in Cell 8

import requests
import json

def test_flask_endpoints(base_url="http://localhost:5000", health_timeout=5,
                         embed_timeout=10, generate_timeout=15):
    """Smoke-test the /health, /embed and /generate endpoints of the Flask server.

    Args:
        base_url: Root URL of the server. Defaults to the local Flask port; pass
            the public ngrok URL to exercise the server through the tunnel.
            A trailing slash is ignored.
        health_timeout: Timeout in seconds for the /health request.
        embed_timeout: Timeout in seconds for the /embed request.
        generate_timeout: Timeout in seconds for the /generate request. Increase
            it if generation with the loaded model is slower than this.

    Returns:
        None. Results are printed. The run stops early only when the /health
        request itself raises (e.g. the server is unreachable); a non-200 health
        status is reported and the remaining tests still run.

    NOTE(review): ``app.run(...)`` in the server cell blocks until the server
    stops, so within a single Colab kernel this cell cannot run while that cell
    is serving. Call this from another process/notebook (e.g. with ``base_url``
    set to the ngrok URL), or serve Flask from a background thread.
    """
    base_url = base_url.rstrip("/")

    def post_json(path, payload, timeout):
        """POST ``payload`` as JSON to ``base_url + path`` and return the response."""
        return requests.post(f"{base_url}{path}",
                             json=payload,
                             headers={"Content-Type": "application/json"},
                             timeout=timeout)

    print(f"üß™ Testing Flask endpoints at {base_url}...")
    print("‚ö†Ô∏è Make sure Cell 8 (Flask server) is running first!\n")

    # Test 1: Health check. If the request raises, the server is unreachable
    # and the remaining requests would fail the same way, so stop here.
    try:
        print("1. Testing /health endpoint...")
        response = requests.get(f"{base_url}/health", timeout=health_timeout)
        if response.status_code == 200:
            print("‚úÖ Health check passed")
            print(f"   Response: {response.json()}")
        else:
            print(f"‚ùå Health check failed: {response.status_code}")
    except Exception as e:
        print(f"‚ùå Health check error: {str(e)}")
        print("   Make sure Flask server is running (Cell 8)")
        return

    # Test 2: Embedding endpoint
    try:
        print("\n2. Testing /embed endpoint...")
        response = post_json("/embed", {"text": "What is diabetes?"}, embed_timeout)
        if response.status_code == 200:
            # `or []` also covers an explicit null embedding in the JSON body.
            embedding = response.json().get('embedding') or []
            print("‚úÖ Embedding test passed")
            print(f"   Embedding dimensions: {len(embedding)}")
            print(f"   First 3 values: {embedding[:3]}")
        else:
            print(f"‚ùå Embedding test failed: {response.status_code}")
            print(f"   Error: {response.text}")
    except Exception as e:
        print(f"‚ùå Embedding test error: {str(e)}")

    # Test 3: Generation endpoint
    try:
        print("\n3. Testing /generate endpoint...")
        test_data = {
            "query": "What is diabetes?",
            "context": "Diabetes is a chronic condition affecting blood sugar levels.",
            "max_tokens": 50,
            "temperature": 0.7
        }
        response = post_json("/generate", test_data, generate_timeout)
        if response.status_code == 200:
            # `or ''` also covers an explicit null answer in the JSON body.
            answer = response.json().get('answer') or ''
            print("‚úÖ Generation test passed")
            print(f"   Answer length: {len(answer)} characters")
            # Only mark the preview as truncated when it actually is.
            preview = answer if len(answer) <= 100 else f"{answer[:100]}..."
            print(f"   Answer preview: {preview}")
        else:
            print(f"‚ùå Generation test failed: {response.status_code}")
            print(f"   Error: {response.text}")
    except Exception as e:
        print(f"‚ùå Generation test error: {str(e)}")

    print("\nüéØ Test completed!")
    print("If all tests pass, your Flask server is ready for WellnessGrid!")

# Uncomment the line below to run the test
# test_flask_endpoints()
